当前位置: 首页 > news >正文

深入浅出Pytorch函数——torch.nn.Linear

分类目录:《深入浅出Pytorch函数》总目录


对输入数据做线性变换 y = x A T + b y=xA^T+b y=xAT+b

语法

torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)

参数

  • in_features:[int] 每个输入样本的大小
  • out_features :[int] 每个输出样本的大小
  • bias:[bool] 若设置为False,则该层不会学习偏置项目,默认值为True

变量形状

  • 输入变量: ( N , in_features ) (N, \text{in\_features}) (N,in_features)
  • 输出变量: ( N , out_features ) (N, \text{out\_features}) (N,out_features)

变量

  • weight:模块中形状为 ( out_features , in_features ) (\text{out\_features}, \text{in\_features}) (out_features,in_features)的可学习权重项
  • bias :模块中形状为 out_features \text{out\_features} out_features的可学习偏置项

实例

>>> m = nn.Linear(20, 30)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])

函数实现

class Linear(Module):r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`This module supports :ref:`TensorFloat32<tf32_on_ampere>`.On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.Args:in_features: size of each input sampleout_features: size of each output samplebias: If set to ``False``, the layer will not learn an additive bias.Default: ``True``Shape:- Input: :math:`(*, H_{in})` where :math:`*` means any number ofdimensions including none and :math:`H_{in} = \text{in\_features}`.- Output: :math:`(*, H_{out})` where all but the last dimensionare the same shape as the input and :math:`H_{out} = \text{out\_features}`.Attributes:weight: the learnable weights of the module of shape:math:`(\text{out\_features}, \text{in\_features})`. The values areinitialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where:math:`k = \frac{1}{\text{in\_features}}`bias:   the learnable bias of the module of shape :math:`(\text{out\_features})`.If :attr:`bias` is ``True``, the values are initialized from:math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where:math:`k = \frac{1}{\text{in\_features}}`Examples::>>> m = nn.Linear(20, 30)>>> input = torch.randn(128, 20)>>> output = m(input)>>> print(output.size())torch.Size([128, 30])"""__constants__ = ['in_features', 'out_features']in_features: intout_features: intweight: Tensordef __init__(self, in_features: int, out_features: int, bias: bool = True,device=None, dtype=None) -> None:factory_kwargs = {'device': device, 'dtype': dtype}super().__init__()self.in_features = in_featuresself.out_features = out_featuresself.weight = Parameter(torch.empty((out_features, in_features), **factory_kwargs))if bias:self.bias = Parameter(torch.empty(out_features, **factory_kwargs))else:self.register_parameter('bias', None)self.reset_parameters()def reset_parameters(self) -> None:# Setting a=sqrt(5) in kaiming_uniform is the same as initializing with# uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see# https://github.com/pytorch/pytorch/issues/57109init.kaiming_uniform_(self.weight, a=math.sqrt(5))if self.bias is not None:fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0init.uniform_(self.bias, -bound, bound)def forward(self, input: Tensor) -> Tensor:return F.linear(input, self.weight, self.bias)def extra_repr(self) -> str:return 'in_features={}, out_features={}, bias={}'.format(self.in_features, self.out_features, self.bias is not None)
http://www.lryc.cn/news/130001.html

相关文章:

  • Vue3.2+TS的defineExpose的应用
  • 牛客网Python入门103题练习|【08--元组】
  • Jenkins改造—nginx配置鉴权
  • (二)VisionOS平台概述
  • 菜单中的类似iOS中开关的样式
  • Vue 2 动态组件和异步组件
  • MongoDB升级经历(4.0.23至5.0.19)
  • iPhone上的个人热点丢失了怎么办?如何修复iPhone上不见的个人热点?
  • AI 媒人:为什么图形神经网络比 MLP 更好?
  • 信息学奥赛一本通 1984:【19CSPJ普及组】纪念品 | 洛谷 P5662 [CSP-J2019] 纪念品
  • JVM——JVM参数指南
  • 马上七夕到了,用各种编程语言实现10种浪漫表白方式
  • Spring Clould 注册中心 - Eureka,Nacos
  • 使用appuploader工具发布证书和描述性文件教程
  • 【面试八股文】每日一题:谈谈你对IO的理解
  • 200. 岛屿数量
  • 【LeetCode】581.最短无序连续子数组
  • 曲面(弧面、柱面)展平(拉直)瓶子标签识别ocr
  • 知识继承概述
  • 深度剖析数据在内存中的存储
  • 【ARM Linux 系统稳定性分析入门及渐进10 -- GDB 初始化脚本介绍及使用】
  • AQS源码解读
  • QT实现天气预报
  • 【马蹄集】第二十三周——进位制专题
  • [足式机器人]Part3 变分法Ch01-1 数学预备知识——【读书笔记】
  • 计算机网络----CRC冗余码的运算
  • 将Nginx源码数组结构(ngx_array.c)和内存池代码单独编译运行,附代码
  • java forEach中不能使用break和continue的原因
  • [杂项]水浒英雄谱系列电影列表
  • 6.RocketMQ之索引文件ConsumeQueue