当前位置: 首页 > news >正文

昇思学习打卡-14-ResNet50迁移学习

文章目录

  • 数据集可视化
  • 预训练模型的使用
    • 部分实现
  • 推理

  • 迁移学习:在一个很大的数据集上训练得到一个预训练模型,然后使用该模型来初始化网络的权重参数或作为固定特征提取器应用于特定的任务中。
  • 本章学习使用的是前面学过的ResNet50,使用迁移学习的方法对ImageNet数据集中的狼和狗图像进行分类。

数据集可视化

在这里插入图片描述

预训练模型的使用

  • 搭建好模型框架后,通过将pretrained参数设置为True来下载ResNet50的预训练模型,并将权重参数加载到网络中。
  • 使用固定特征进行训练的时候,需要冻结除最后一层之外的所有网络层。通过设置 requires_grad == False 冻结参数,以便不在反向传播中计算梯度。

部分实现

import matplotlib.pyplot as plt
import os
import time
# 修改参数1pretrained=True
net_work = resnet50(pretrained=True)# 全连接层输入层的大小
in_channels = net_work.fc.in_channels
# 输出通道数大小为狼狗分类数2
head = nn.Dense(in_channels, 2)
# 重置全连接层
net_work.fc = head# 平均池化层kernel size为7
avg_pool = nn.AvgPool2d(kernel_size=7)
# 重置平均池化层
net_work.avg_pool = avg_pool# 冻结除最后一层外的所有参数
for param in net_work.get_parameters():if param.name not in ["fc.weight", "fc.bias"]:# 修改参数2param.requires_grad = False# 定义优化器和损失函数
opt = nn.Momentum(params=net_work.trainable_params(), learning_rate=lr, momentum=0.5)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')def forward_fn(inputs, targets):logits = net_work(inputs)loss = loss_fn(logits, targets)return lossgrad_fn = ms.value_and_grad(forward_fn, None, opt.parameters)def train_step(inputs, targets):loss, grads = grad_fn(inputs, targets)opt(grads)return loss# 实例化模型
model1 = train.Model(net_work, loss_fn, opt, metrics={"Accuracy": train.Accuracy()})

推理

在这里插入图片描述
此章节学习到此结束,感谢昇思平台。

http://www.lryc.cn/news/397131.html

相关文章:

  • 软件开发面试题C#,.NET知识点(续)
  • 2019年美赛题目Problem A: Game of Ecology
  • 沙龙回顾|MongoDB如何充当企业开发加速器?
  • 云端编码:将您的技术API文档安全存储在iCloud的最佳实践
  • 在Spring Boot项目中集成单点登录解决方案
  • Java-常用API
  • Python从Excel表中查找指定数据填入新表
  • 从零开始实现大语言模型(三):Token Embedding与位置编码
  • 视频怎么压缩变小?最佳视频压缩器
  • LLM - 绝对与相对位置编码 与 RoPE 旋转位置编码 源码
  • B3917 [语言月赛 202401] 小跳蛙
  • Bash ——shell
  • PyTorch复现PointNet——模型训练+可视化测试显示
  • 分享五款软件,成为高效生活的好助手
  • 代码随想录算法训练营DAY58|101.孤岛的总面积、102.沉没孤岛、103. 水流问题、104.建造最大岛屿
  • 韦尔股份:深蹲起跳?
  • docs | 使用 sphinx 转化rst文件为html文档
  • 【ChatGPT 消费者偏好】第二弹:ChatGPT在日常生活中的使用—推文分享—2024-07-10
  • Webpack配置及工作流程
  • 华为ensp实现防火墙的区域管理与用户认证
  • 深入解析 Laravel 策略路由:提高应用安全性与灵活性的利器
  • Java | Leetcode Java题解之第228题汇总区间
  • 使用Simulink基于模型设计(三):建模并验证系统
  • 基于go 1.19的站点模板爬虫
  • 0基础学会在亚马逊云科技AWS上搭建生成式AI云原生Serverless问答QA机器人(含代码和步骤)
  • [PaddlePaddle飞桨] PaddleOCR图像小模型部署
  • C语言 | Leetcode C语言题解之第227题基本计算题II
  • kafka.common.KafkaException: Socket server failed to bind to xx:9092
  • 【JS+H5+CSS实现烟花特效】
  • uniapp小程序使用webview 嵌套 vue 项目