当前位置: 首页 > news >正文

pytorch nearest upsample整数型tensor

在用 torch.nn.Upsample 给分割 label 上采样时报错:RuntimeError: "upsample_nearest2d_out_frame" not implemented for 'Long'

参考 [1-3],用 [3] 给出的实现。稍微扩展一下,支持 h、w 用不同的 scale factor,并测试其与 PyTorch 的几个 upsample 类的异同,验证 [3] 的实现用 nearest 插值。

Code

  • linear 要 3D 输入、trilinear 要 5D 输入,故此两种插值法没比。
import torch
import torch.nn as nnclass UpsampleDeterministic(nn.Module):"""deterministic upsample with `nearest` interpolation"""def __init__(self, scale_factor=2):"""Input:scale_factor: int or (int, int), ratio to scale (along heigth & width)"""super(UpsampleDeterministic, self).__init__()if isinstance(scale_factor, (tuple, list)):assert len(scale_factor) == 2self.scale_h, self.scale_w = scale_factorelse:self.scale_h = self.scale_w = scale_factorassert isinstance(self.scale_h, int) and isinstance(self.scale_w, int)def forward(self, x):"""Input:x: [n, c, h, w], torch.TensorOutput:upsampled x': [n, c, h * scale_h, w * scale_w]"""return x[:, :, :, None, :, None].expand(-1, -1, -1, self.scale_h, -1, self.scale_w).reshape(x.size(0), x.size(1), x.size(2) * self.scale_h, x.size(3) * self.scale_w)# 随机数据
x = torch.rand(2, 3, 4, 4) # [n, c, h, w]
# [3] 的实现
us_det = UpsampleDeterministic((2, 3))
# pytorch 自带的几种实现
us_list = {mode: nn.Upsample(scale_factor=(2, 3), mode=mode)for mode in ('nearest', 'bilinear', 'bicubic')}
# linear: 3D
# trilinear: 5Dy_det = us_det(x)
print(y_det.size())
for us_name, us in us_list.items():y = us(x)print(us_name, y.size(), (y_det != y).sum())

输出:

torch.Size([2, 3, 8, 12])
nearest torch.Size([2, 3, 8, 12]) tensor(0)
bilinear torch.Size([2, 3, 8, 12]) tensor(507)
bicubic torch.Size([2, 3, 8, 12]) tensor(576)

可见 [3] 的实现与 nearest 结果一致。

References

  1. 请慎用torch.nn.Upsample
  2. PyTorch中模型的可复现性
  3. Non Deterministic Behaviour even after cudnn.deterministic = True and cudnn.benchmark=False #12207
http://www.lryc.cn/news/292398.html

相关文章:

  • MySQL的SQL MODE
  • GO EASY 框架 之 NET 05
  • 【教程】谈一谈 IPA 上传到 App Store Connect 的几种方法
  • 面试经典 150 题 -- 滑动窗口 (总结)
  • JDK8对List对象根据属性排序
  • 【2024美国大学生数学建模竞赛】2024美赛C题网球运动中的势头,网球教练4.0没人比我更懂这个题了!!!
  • python的Flask生产环境部署说明照做成功
  • EXCEL VBA调用百度api识别身份证
  • 【每日一题】7.LeetCode——合并两个有序链表
  • 【零基础学习CAPL】——CAN报文的发送(按下按钮同时周期性发送)
  • 六、Nacos源码系列:Nacos健康检查
  • 2024美赛C题思路/代码:网球中的动量
  • ConcurrentHashMap原理详解(太细了)
  • EasyExcel根据对应的实体类模板完成多个sheet的写入与读取
  • 在企业数字化转型过程中,IT运维发挥着怎样的价值?
  • 01-工厂模式 ( Factory Pattern )
  • 【LeetCode】每日一题 2024_2_2 石子游戏 VI(排序、贪心)
  • 一站式在线协作开源办公软件ONLYOFFICE,协作更安全更便捷
  • Java进击框架:Spring-综合(十)
  • 2024年第九届信号与图像处理国际会议(ICSIP 2024)
  • webassembly003 MINISIT mnist/convert-h5-to-ggml.py
  • fetch和axios的区别
  • 【unity小技巧】FPS简单的射击换挡瞄准动画控制
  • 如何获取时间戳
  • VSCode 设置代理
  • 保姆级教程: 零门槛制作AI微信红包封面之入门篇
  • Redis核心技术与实战【学习笔记】 - 17.Redis 缓存异常:缓存雪崩、击穿、穿透
  • Leetcode—2670. 找出不同元素数目差数组【简单】
  • App ICP备案获取iOS和Android的公钥和证书指纹
  • 猿创征文 | 项目整合KafkaStream实现文章热度实时计算