当前位置: 首页 > news >正文

LeNet网络的实现

LeNet网络的实现


import torch
from torch import nn
from d2l import torch as d2lx = 28
net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(16 * (x/4 - 2) * (x/4 - 2), 120), nn.Sigmoid(),nn.Linear(120, 84), nn.Sigmoid(),nn.Linear(84, 10))

输入图像是单通道 x*x大小

  1. 卷积层。
    输入一个通道,输出六个通道,卷积核大小5*5,填充2,步幅1,因此输出图像大小不变。
  2. 平均汇聚层。
    核大小2*2,步幅2,因此输出图像大小减半。(x/2)(x/2)
  3. 卷积层。
    输入6通道,输出16通道,核大小5,输出图像大小减4.(x/2-4) (x/2 - 4)
  4. 平均汇聚层。
    核大小2*2,步幅2,输出大小减半。(x/4-2)(x/4-2)
  5. 全连接层。
    输入大小: 16 * (x/4 - 2) * (x/4 - 2)
    输出大小: 10

测试函数

def evaluate_accuracy_gpu(net , data_iter,device=None):if isinstance(net , nn.Module):net.eval()if not device:# 获取第一个参数所在的设备,把以后的数据放在同一个设备上device = next(iter(net.parameters())).devicemetric = d2l.Accumulator(2)with torch.no_grad():for X , y in data_iter:if isinstance(X , list):X = [x.to(device) for x in X]else:X = X.to(device)y = y.to(device)metric.add(d2l.accuracy(net(X),y) , y.numel())return metric[0] / metric[1]

训练和测试

def train_ch6(net , train_iter , test_iter, num_epochs , lr ,device):# 初始化权重def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)print( ' training on ' , device)net.to(device)# 优化器  optimizer = torch.optim.SGD(net.parameters(), lr)# 损失函数loss = nn.CrossEntropyLoss()animator = d2l.Animator(xlabel='epoch',xlim=[1 , num_epochs],legend=['train loss','train acc','test acc'])timer , num_batches = d2l.Timer() , len(train_iter)for epoch in range(num_epochs):metric = d2l.Accumulator(3)net.train()for i , (X, y ) in enumerate(train_iter):timer.start()optimizer.zero_grad()X , y = X.to(device) , y.to(device)y_hat = net(X)l = loss(y_hat , y)l.backward()optimizer.step()with torch.no_grad():metric.add(l*X.shape[0] , d2l.accuracy(y_hat , y) , X.shape[0])timer.stop()train_l = metric[0] / metric[2]train_acc = metric[1] / metric[2]if (i+1) % (num_batches //5) ==0 or i ==num_batches - 1:animator.add(epoch + (i+1) /num_batches,(train_l , train_acc , None))test_acc = evaluate_accuracy_gpu(net , test_iter)animator.add(epoch+1 ,(None , None , test_acc))print(f'loss {train_l:.3f},train_acc {train_acc:.3f} ,  'f'test_acc{test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec'f'on {str(device)}')
http://www.lryc.cn/news/383816.html

相关文章:

  • 华为HCIP Datacom H12-821 卷6
  • 深入理解SQL优化:理论与实践的结合
  • PostgreSQL 高级功能与扩展(九)
  • 【LinuxC语言】UDP数据收发
  • 【深度学习驱动流体力学】计算流体力学openfoam-paraview与python3交互
  • EWM学习之旅-1-EWM100
  • qt中的枚举值-QMetaEnum
  • 这才是CSDN最系统的网络安全学习路线(建议收藏)
  • 微软Edge浏览器多用户配置文件管理:个性化浏览体验
  • 10.2 JavaEE——Spring MVC入门程序
  • Python 处理大量数据的相关库和框架推荐
  • 【unity笔记】七、Mirror插件使用
  • 掌握SEO:如何优化用ChatGPT生成的文章以提升搜索排名
  • Java面试问题(一)
  • Firewalld防火墙基础
  • 解决Java中多线程同步问题的方案
  • 每日一练 - RSTP与STP收敛速度对比
  • ZS-20H型水泥胶砂振实台
  • 力扣377 组合总和Ⅳ Java版本
  • 昇思25天学习打卡营第3天 | 数据集 Dataset
  • 交换机三层架构及对流量的转发机制
  • 开发者配置项、开发者选项自定义
  • 【Java】解决Java报错:IndexOutOfBoundsException in Collections
  • C++编程(三)面向对象
  • Batch入门教程
  • 49-2 内网渗透 - 使用UACME Bypass UAC
  • Django 表单使用示例:数据格式校验
  • OkHttp框架源码深度剖析【Android热门框架分析第一弹】
  • 【MySQL】数据库——备份与恢复,日志管理1
  • 什么样的企业适合SD-WAN网络专线?