[PyTorch] Tutorial: torch.nn.Hardswish
torch.nn.Hardswish
Signature
CLASS torch.nn.Hardswish(inplace=False)
Parameters
- inplace (bool) – if True, performs the operation in-place, overwriting the input tensor. Default: False
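The following is a small sketch of what `inplace=True` means in practice. The module writes its result back into the input tensor's storage instead of allocating a new tensor. The tensor values and variable names are illustrative only.

```python
import torch
import torch.nn as nn

x = torch.tensor([-4.0, -1.0, 0.5, 4.0])

# Default (inplace=False): returns a new tensor and leaves x unchanged
expected = nn.Hardswish()(x)
print(x)

# inplace=True: the result is written back into x's own storage
y = nn.Hardswish(inplace=True)(x)
print(y.data_ptr() == x.data_ptr())  # the output shares x's storage
print(torch.equal(x, expected))      # x now holds the Hardswish values
```

Overwriting the input saves one allocation. Avoid it when the original values of the tensor are still needed afterwards.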
Definition
$$\text{Hardswish}(x) = \begin{cases} 0 & \text{if } x \le -3, \\ x & \text{if } x \ge +3, \\ x \cdot (x + 3)/6 & \text{otherwise} \end{cases}$$
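The piecewise definition can be translated directly into tensor operations and checked against PyTorch's functional form `torch.nn.functional.hardswish`. The sketch below does this. The helper names `hardswish_piecewise` and `hardswish_relu6` are hypothetical. The second helper uses the equivalent closed form x · ReLU6(x + 3) / 6. It agrees with the three cases because ReLU6 clamps x + 3 to [0, 6], which gives 0 for x ≤ −3 and 6 for x ≥ 3.

```python
import torch
import torch.nn.functional as F

def hardswish_piecewise(x: torch.Tensor) -> torch.Tensor:
    # Branch-by-branch translation of the three cases in the definition
    return torch.where(
        x <= -3,
        torch.zeros_like(x),                          # x <= -3  -> 0
        torch.where(x >= 3, x, x * (x + 3) / 6),      # x >= 3 -> x, else x(x+3)/6
    )

def hardswish_relu6(x: torch.Tensor) -> torch.Tensor:
    # Equivalent closed form: x * ReLU6(x + 3) / 6
    return x * F.relu6(x + 3) / 6

x = torch.linspace(-5, 5, steps=101)  # covers all three regions
print(torch.allclose(hardswish_piecewise(x), F.hardswish(x)))
print(torch.allclose(hardswish_relu6(x), F.hardswish(x)))
```

Both branches meet continuously at the breakpoints: x(x+3)/6 equals 0 at x = −3 and equals 3 at x = 3.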
Plot
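The original figure is not reproduced here. The script below is a sketch for redrawing the curve, assuming matplotlib is installed. The output filename `hardswish.png` and the plotting range are arbitrary choices.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen so the script also runs without a display
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

x = torch.linspace(-6, 6, steps=500)
y = nn.Hardswish()(x)

fig, ax = plt.subplots(figsize=(5, 3.5))
ax.plot(x.numpy(), y.numpy(), label="Hardswish(x)")
ax.axvline(-3, linestyle="--", color="gray", linewidth=0.8)  # lower breakpoint x = -3
ax.axvline(3, linestyle="--", color="gray", linewidth=0.8)   # upper breakpoint x = 3
ax.set_xlabel("x")
ax.set_ylabel("Hardswish(x)")
ax.grid(True)
ax.legend()
fig.savefig("hardswish.png", dpi=150, bbox_inches="tight")
```

In the plot, the curve is flat at 0 left of x = −3 and coincides with y = x right of x = 3. Between the two breakpoints it follows the quadratic x(x + 3)/6.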
Code
import torch
import torch.nn as nn

m = nn.Hardswish()
input = torch.randn(4)  # random input, so the values differ on every run
output = m(input)
print("input: ", input)    # sample run: input: tensor([-0.5567, -0.4911, 0.2918, 2.1492])
print("output: ", output)  # sample run: output: tensor([-0.2267, -0.2054, 0.1601, 1.8445])
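Every input in the sample run lies strictly between −3 and 3, so each output comes from the middle branch x·(x + 3)/6. The plain-Python check below plugs the printed four-decimal values back into that branch. Because the printed inputs are rounded, agreement is only expected to about 1e-4.

```python
# Inputs and outputs copied from the sample run's printed comments
inputs = [-0.5567, -0.4911, 0.2918, 2.1492]
outputs = [-0.2267, -0.2054, 0.1601, 1.8445]

for x, y in zip(inputs, outputs):
    middle = x * (x + 3) / 6  # middle branch, valid for -3 < x < 3
    print(f"{x:+.4f} -> {middle:+.4f} (printed {y:+.4f})")
    assert abs(middle - y) < 1e-3  # small gap comes from input rounding
```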