【PyTorch】Tutorial: torch.nn.Hardshrink
torch.nn.Hardshrink
CLASS torch.nn.Hardshrink(lambd=0.5)
Parameters
- lambd (float) – the λ value for the Hardshrink formulation. Default: 0.5
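The sketch below shows what `lambd` controls. It is minimal, written in pure Python (no PyTorch needed), and applies the HardShrink rule element-wise to a plain list at several thresholds: values with |x| ≤ lambd become 0 and the rest pass through unchanged. The helper name `hardshrink_list` and the sample values are illustrative assumptions. They are not part of the PyTorch API.

```python
def hardshrink_list(values, lambd=0.5):
    """Hypothetical helper: element-wise HardShrink on a list of floats."""
    # Keep x when x > lambd or x < -lambd; otherwise output 0.0.
    return [x if (x > lambd or x < -lambd) else 0.0 for x in values]


xs = [-1.2, -0.4, 0.0, 0.3, 0.8]

for lambd in (0.0, 0.5, 1.0):
    print(lambd, hardshrink_list(xs, lambd))
# 0.0 [-1.2, -0.4, 0.0, 0.3, 0.8]
# 0.5 [-1.2, 0.0, 0.0, 0.0, 0.8]
# 1.0 [-1.2, 0.0, 0.0, 0.0, 0.0]
```

A larger `lambd` widens the band that is zeroed out. With `lambd=0.0`, every value is returned unchanged.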
Definition
$$
\text{HardShrink}(x) = \begin{cases} x, & \text{if } x > \lambda \\ x, & \text{if } x < -\lambda \\ 0, & \text{otherwise} \end{cases}
$$
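A direct scalar transcription of the three cases makes the boundary behaviour explicit. Both comparisons are strict, so an input exactly equal to ±λ falls into the "otherwise" branch and maps to 0. This is a minimal pure-Python sketch of the formula only, not PyTorch's actual implementation. The function name `hard_shrink` is hypothetical.

```python
def hard_shrink(x: float, lambd: float = 0.5) -> float:
    """Hypothetical scalar version of the HardShrink formula."""
    if x > lambd:    # case 1: x > lambda  -> x
        return x
    if x < -lambd:   # case 2: x < -lambda -> x
        return x
    return 0.0       # case 3 (otherwise): -lambda <= x <= lambda -> 0


for x in (-2.0, -0.5, -0.1, 0.0, 0.5, 0.7):
    print(x, "->", hard_shrink(x))
```

With the default λ = 0.5, only -2.0 and 0.7 pass through. The boundary values -0.5 and 0.5 are zeroed along with everything between them.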
Plot
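The shape of the curve can be read off a few sample points. The pure-Python sketch below uses the default λ = 0.5 and an arbitrarily chosen grid from -1 to 1. It tabulates HardShrink(x) and draws a crude text bar for |y|. The output is 0 on [-λ, λ] and equal to x outside that band, so the graph jumps by λ at x = ±λ. The helper name `hard_shrink` is hypothetical.

```python
def hard_shrink(x: float, lambd: float = 0.5) -> float:
    """Hypothetical scalar HardShrink: x if |x| > lambd, else 0."""
    return x if abs(x) > lambd else 0.0


# Sample x from -1.0 to 1.0 in steps of 0.25 (integer steps avoid float drift).
grid = [i / 4 for i in range(-4, 5)]
samples = [(x, hard_shrink(x)) for x in grid]

for x, y in samples:
    # Crude text "plot": bar length proportional to |y|.
    print(f"{x:+.2f} -> {y:+.2f} {'#' * int(abs(y) * 8)}")
```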
Code
```python
import torch
import torch.nn as nn

m = nn.Hardshrink()     # default lambd=0.5
input = torch.randn(2)  # random input, so the printed values differ on every run
output = m(input)
print("input: ", input)    # e.g. input:  tensor([ 0.2078, -1.4333])
print("output: ", output)  # e.g. output:  tensor([ 0.0000, -1.4333])
```
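You can check the sample output by hand. Apply the rule to the logged input values 0.2078 and -1.4333 and you get the logged output. 0.2078 lies inside [-0.5, 0.5] and is zeroed, while -1.4333 < -0.5 passes through. The plain-Python sketch below does the same thing. The list stands in for the 2-element tensor, and `hardshrink_elementwise` is a hypothetical helper.

```python
def hardshrink_elementwise(values, lambd=0.5):
    """Hypothetical helper: apply HardShrink to each element, as an element-wise activation does."""
    return [x if (x > lambd or x < -lambd) else 0.0 for x in values]


logged_input = [0.2078, -1.4333]  # values copied from the sample run
result = hardshrink_elementwise(logged_input)
print(result)  # [0.0, -1.4333]
```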
【Reference】
Hardshrink — PyTorch 1.13 documentation