当前位置: 首页 > news >正文

Pytorch复习笔记--torch.nn.functional.interpolate()和cv2.resize()的使用与比较

1--前言

        博主在处理图片尺度问题时,习惯使用 cv2.resize() 函数;但当图片数据需用显卡加速运算时,数据需要在 GPU 和 CPU 之间不断迁移,导致程序运行效率降低;

        Pytorch 提供了一个类似于 cv2.resize() 的采样函数,即 torch.nn.functional.interpolate(),支持最近邻插值(nearest)和双线性插值(bilinear)等功能,通过设置合理的插值方式可以取得与 cv2.resize() 函数完全一样的效果。

2--代码测试

        ① 最近邻方法('nearnest' 和 cv2.INTER_NEAREST):

import torch
import cv2
import torch.nn.functional as F
import numpy as npinput_data1 = torch.randint(low = 0, high = 255, size = [40, 40, 3])
input_data2 = np.array(input_data1, dtype = np.uint8)input_data1 = input_data1.permute(2, 0, 1).unsqueeze(0).float() # [1, 3, 40, 40]
output_data1 = F.interpolate(input_data1, size = (224, 224), mode='nearest').float() # [1, 3, 224, 224]
output_data2 = cv2.resize(input_data2, dsize = (224, 224), interpolation=cv2.INTER_NEAREST) # [224, 224, 3]data1 = np.array(output_data1.squeeze(0).permute(1, 2, 0), dtype=np.uint8)
data2 = np.array(output_data2, dtype=np.uint8)print(data1 == data2)print("All done !")

        ② 双线性插值方法('bilinear' 和 cv2.INTER_LINEAR):

import torch
import cv2
import torch.nn.functional as F
import numpy as npinput_data1 = torch.randint(low = 0, high = 255, size = [40, 40, 3])
input_data2 = np.array(input_data1, dtype = np.uint8)input_data1 = input_data1.permute(2, 0, 1).unsqueeze(0).float() # [1, 3, 40, 40]
output_data1 = F.interpolate(input_data1, size = (224, 224), mode='bilinear').float() # [1, 3, 224, 224]
output_data2 = cv2.resize(input_data2, dsize = (224, 224), interpolation=cv2.INTER_LINEAR) # [224, 224, 3]data1 = np.array(output_data1.squeeze(0).permute(1, 2, 0), dtype=np.uint8)
data2 = np.array(output_data2, dtype=np.uint8)print(data1 == data2)print("All done !")

上面两个测试代码的结果表明,在采取相同插值方式的前提下,torch.nn.functional.interpolate() 和 cv2.resize() 两个方法的功能是完全等价的,处理后的数据相同;

3--相关补充

        ① 使用 torch.nn.functional.interpolate()的注意事项:

1. 插值方法(mode)与输入数据的维度(minibatch, channels, [optional depth], [optional height], width)密切相关,目前支持的数据维度有以下几种:

        ① 3D张量输入:minibatch, channels, width;

        ② 4D张量输入:minibatch, channels, height, width;

        ③ 5D张量输入:minibatch, channels, depth, height, width;

2. 插值方法和输入维度的关系如下:

http://www.lryc.cn/news/6653.html

相关文章:

  • ASP.NET Core MVC 项目 AOP之ActionFilterAttribute
  • 浅析EasyCVR安防视频能力在智慧小区建设场景中的应用及意义
  • Python的深、浅拷贝到底是怎么回事?一篇解决问题
  • TCP协议十大特性
  • 2.14作业【GPIIO控制LED】
  • 5min搞定linux环境Jenkins的安装
  • Cortex-M0存储器系统
  • 软件测试——测试用例之场景法
  • 英文写作中的常用的衔接词
  • 新库上线 | CnOpenData中国地方政府债券信息数据
  • Python 条件语句
  • C语言思维导图大总结 可用于期末考试 C语言期末考试题库
  • 从零实现深度学习框架——再探多层双向RNN的实现
  • Flink 连接流详解
  • 分享112个HTML电子商务模板,总有一款适合您
  • 2023备战金三银四,Python自动化软件测试面试宝典合集(八)
  • J-Link RTT Viewer使用教程(附代码)
  • C语言——指针、数组的经典笔试题目
  • 【C语言】程序环境和预处理|预处理详解|定义宏(上)
  • 上海霄腾自动化装备盛装亮相2023生物发酵展
  • python+flask开发mock服务
  • 数据库(三)
  • 2023软考纸质证书领取通知来了!
  • Python requests模块
  • 工业智能网关解决方案:物联网仓储环境监测系统
  • Linux进程线程管理
  • 分享111个HTML电子商务模板,总有一款适合您
  • 百度前端必会手写面试题整理
  • ubuntu 安装支持GPU的Docker详细步骤
  • usbmon+tcpdump+wireshark USB抓包