当前位置: 首页 > news >正文

【深度学习】卷积网络代码实战ResNet

        ResNet (Residual Network) 是由微软研究院的何凯明等人在2015年提出的一种深度卷积神经网络结构。ResNet的设计目标是解决深层网络训练中的梯度消失和梯度爆炸问题,进一步提高网络的表现。下面是一个ResNet模型实现,使用PyTorch框架来展示如何实现基本的ResNet结构。这个例子包括了一个基本的残差块(Residual Block)以及ResNet-18的实现,代码结构分为model.py(模型文件)和train.py(训练文件)。

model.py 

      首先,我们导入所需要的包 

import torch
from torch import nn
from torch.nn import functional as F

        然后,定义Resnet Block(ResBlk)类。

class ResBlk(nn.Module):def __init__(self):super(ResBlk, self).__init__()self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1)self.bn1 = nn.BatchNorm2d(ch_out)self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)self.bn2 = nn.BatchNorm2d(ch_out)self.extra = nn.Sequential()if ch_out != ch_inself.extra = nn.Sequential(nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1)nn.BatchNorm2d(ch_out))def forward(self, x):out = F.relu(self.bn1(self.conv1(x)))out = F.relu(self.bn2(self.conv2(x)))out = self.extra(x) + outreturn out

        最后,根据ResNet18的结构对ResNet Block进行堆叠。

class Resnet18(nn.Module):def __init__(self):super(Resnet18, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)nn.BatchNorm2d(64))self.blk1 = ResBlk(64, 128)self.blk2 = ResBlk(128, 256)self.blk3 = ResBlk(256, 512)self.blk4 = ResBlk(512, 1024)self.outlayer = nn.Linear(512, 10)def forward(self, x):x = F.relu(self.conv1(x))x = self.blk1(x)x = self.blk2(x)x = self.blk3(x)x = self.blk4(x)# print('after conv1:', x.shape)x = F.adaptive_avg_pool2d(x, [1,1])x = x.view(x.size(0), -1)x = self.outlayer(x)return x

        其中,在网络结构搭建过程中,需要用到中间阶段的图片参数,用下述测试过程求得。

def main():tmp = torch.randn(2, 3, 32, 32)out = blk(tmp)print('block', out.shape)x = torch.randn(2, 3, 32, 32)model = ResNet18()out = model(x)print('resnet:', out.shape)

train.py

        首先,导入所需要的包

import torch
from torchvision import datasets
from torchvision import transforms
from torch import nn, optimizer

        然后,定义main()函数

def main():batchsz = 32cifar_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor()]), download=True)cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)cifar_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor()]), download=True)cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)x, label = iter(cifar_train).next()print('x:', x.shape, 'label:', label.shape)device = torch.device('cuda')model = ResNet18().to(device)criteon = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=1e-3)print(model)for epoch in range(100):for batchidx, (x, label) in enumerate(cifar_train):x, label = x.to(device), label.to(device)logits = model(x)loss = criteon(logitsm label)optimizer.zero_grad()loss.backward()optimizer.step()print(loss.item())with torch.no_grad():total_correct = 0total_num = 0for x, label in cifar_test:x, label = x.to(device), label.to(device)logits = model(x)pred = logits.argmax(dim=1)total_correct += torch.eq(pred, label).floot().sum().item()total_num += x.size(0)acc = total_correct / total_numprint(epoch, acc)

http://www.lryc.cn/news/512011.html

相关文章:

  • org.apache.zookeeper.server.quorum.QuorumPeerMain
  • oscp学习之路,Kioptix Level2靶场通关教程
  • SkyWalking java-agent 是如何工作的,自己实现一个监控sql执行耗时的agent
  • 每天40分玩转Django:Django表单集
  • 查看vue的所有版本号和已安装的版本
  • 钉钉h5微应用,鉴权提示dd.config错误说明,提示“jsapi ticket读取失败
  • 【openGauss】正则表达式次数符号“{}“在ORACLE和openGauss中的差异
  • 宏任务和微任务的区别
  • 数据库系统原理复习汇总
  • Linux day1204
  • 如何在 Ubuntu 22.04 上安装并开始使用 RabbitMQ
  • 【OpenGL ES】GLSL基础语法
  • 如何使用交叉编译器调试C语言程序在安卓设备中运行
  • Java全栈项目 - 智能考勤管理系统
  • Linux Shell : Process Substitution
  • JOGL 从入门到精通:开启 Java 3D 图形编程之旅
  • 汽车网络安全基线安全研究报告
  • Eclipse 修改项目栏字体大小
  • 【PCIe 总线及设备入门学习专栏 5.1 -- PCIe 引脚 PRSNT 与热插拔】
  • 【YOLO】YOLOv5原理
  • uniapp中wx.getFuzzyLocation报错如何解决
  • opencv图像直方图
  • OpenCV计算机视觉 03 椒盐噪声的添加与常见的平滑处理方式(均值、方框、高斯、中值)
  • 【嵌入式C语言】内存分布
  • 【brainpan靶场渗透】
  • Java实现观察者模式
  • 通过百度api处理交通数据
  • 探索CSDN博客数据:使用Python爬虫技术
  • b站ip属地评论和主页不一样怎么回事
  • 如何查看服务器内存占用情况?