当前位置: 首页 > news >正文

深度学习篇---预训练模型

在深度学习中,预训练模型(Pretrained Model) 是提升开发效率和模型性能的 “利器”。无论是图像识别、自然语言处理还是语音识别,预训练模型都被广泛使用。下面从概念、使用原因、场景、作用等方面详细介绍,并结合 Python 代码展示常用预训练模型的使用。

一、什么是预训练模型?(通俗易懂版)

可以把预训练模型理解为:“别人已经训练好的‘半成品模型’,你可以直接拿来用,或者稍作修改就能适配自己的任务”

举个例子:假设你想训练一个 “识别猫和狗” 的模型,需要大量图片和算力。但有人已经用百万张图片(如 ImageNet 数据集)训练了一个 “能识别 1000 种物体” 的模型,这个模型已经学会了 “边缘、纹理、形状” 等通用视觉特征(比如 “猫有耳朵、狗有尾巴”)。你可以直接用这个模型,要么直接预测猫和狗,要么在它的基础上再用少量猫和狗的图片 “微调”,就能快速得到一个好模型。

简言之,预训练模型是 “前人训练好的成果,你可以站在它的肩膀上做开发”。

二、为什么要用预训练模型?

  1. 节省时间和算力
    训练一个复杂模型(如 ResNet、BERT)可能需要几天甚至几周,还需要高性能 GPU。预训练模型已经完成了大部分计算,直接用或微调只需几小时,适合个人或小团队(没有超强算力)。

  2. 数据量少时也能出效果
    深度学习需要大量数据(如几十万张图片),但实际场景中可能只有几千张数据(如自己拍的猫和狗图片)。预训练模型已经 “见过” 海量数据,学到了通用特征,用少量数据微调就能达到不错的效果(否则从头训练可能过拟合)。

  3. 性能更优
    预训练模型通常基于大规模数据集(如 ImageNet 有 1400 万张图片)和优化的网络结构,其学到的特征更通用、更鲁棒。在此基础上微调的模型,性能往往比 “从头训练” 好很多。

三、预训练模型的使用场景

  1. 快速开发原型
    当你需要快速验证一个想法(比如 “用模型识别工厂的零件是否合格”),可以直接用预训练模型做初步测试,不需要从零开始训练。

  2. 数据量有限的任务
    比如医学影像识别(数据少且标注成本高)、小众物体识别(如特定品种的花),用预训练模型微调能显著提升精度。

  3. 迁移学习任务
    从 “通用任务” 迁移到 “具体任务”:比如用 “识别 1000 类物体” 的预训练模型,迁移到 “识别 5 类水果” 的任务;用 “通用文本分类” 的 BERT,迁移到 “情感分析” 任务。

  4. 边缘设备部署
    很多预训练模型有 “轻量化版本”(如 MobileNet、EfficientNet-Lite),适合在手机、摄像头等边缘设备上部署(算力有限但需要快速推理)。

四、预训练模型的作用

  1. 提供通用特征提取能力
    预训练模型的前半部分(如 CNN 的卷积层、Transformer 的编码器)已经学会了通用特征(如图像的边缘、纹理,文本的语义关系),可以直接作为 “特征提取器” 使用。

  2. 加速模型收敛
    微调时,模型参数不需要从 0 开始学习,而是在预训练的 “好起点” 上优化,训练速度更快(比如原本需要 100 个 epoch,微调可能只需 20 个)。

  3. 降低过拟合风险
    预训练模型学到的通用特征能 “抵抗” 小数据集的噪声,减少模型对训练数据的过度依赖(过拟合)。

五、Python 中常用的预训练模型及使用代码

计算机视觉(图像任务) 为例,PyTorch 的torchvision库和 TensorFlow 的tf.keras.applications提供了大量预训练模型。下面以 PyTorch 为例,介绍最常用的模型及代码。

常用预训练模型(图像任务)
模型名称特点适用场景
ResNet(ResNet50/101)结构深、精度高,适合需要高精度的任务图像分类、特征提取
VGG16/VGG19结构简单、特征提取能力强迁移学习、细粒度分类
MobileNetV2/V3轻量化、计算量小手机、摄像头等边缘设备部署
EfficientNet精度与效率平衡(比 ResNet 好且更轻量)兼顾精度和速度的场景
Faster R-CNN经典目标检测模型目标检测(定位 + 分类)
代码示例:使用预训练模型进行图像分类

ResNet50为例,展示 “加载预训练模型→预处理图像→推理预测” 的完整流程。

步骤 1:安装依赖

确保安装了torchtorchvision

pip install torch torchvision
步骤 2:加载预训练模型并查看结构
import torch
from torchvision import models# 加载预训练的ResNet50(pretrained=True表示加载预训练权重)
resnet50 = models.resnet50(pretrained=True)
# 设置为评估模式(关闭 dropout、batchnorm等训练时的层)
resnet50.eval()# 查看模型结构(简化输出)
print("ResNet50结构概览:")
print(resnet50)

模型结构说明

  • 前半部分是conv1layer4的卷积层(特征提取);
  • 后半部分是avgpool(全局平均池化)和fc(全连接层,输出 1000 类,对应 ImageNet 的 1000 个类别)。
步骤 3:图像预处理(必须与预训练一致)

预训练模型对输入图像有固定要求(如尺寸、归一化参数),需严格匹配:

from torchvision import transforms
from PIL import Image# 定义预处理流程(与ResNet训练时的预处理一致)
preprocess = transforms.Compose([transforms.Resize(256),  # 缩放到256x256transforms.CenterCrop(224),  # 中心裁剪到224x224transforms.ToTensor(),  # 转为张量并归一化到0-1# 标准化(使用ImageNet的均值和标准差,必须与预训练一致)transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
步骤 4:用预训练模型进行推理(预测图像类别)
# 读取一张测试图片(如一只猫)
img = Image.open("cat.jpg")  # 替换为你的图片路径
# 预处理
input_tensor = preprocess(img)
# 增加批次维度(模型要求输入是(batch_size, channels, H, W),这里batch_size=1)
input_batch = input_tensor.unsqueeze(0)# 用GPU加速(如果有)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
resnet50.to(device)
input_batch = input_batch.to(device)# 推理(关闭梯度计算,加快速度)
with torch.no_grad():output = resnet50(input_batch)# 输出是1000类的概率(logits),取最大概率的类别
predicted_class = torch.argmax(output[0]).item()# 加载ImageNet的类别名称(1000类)
from torchvision.datasets import ImageNet
# 注意:ImageNet数据集需手动下载,这里简化为加载类别名称(可网上搜索获取)
with open("imagenet_classes.txt") as f:  # 包含1000类名称的文件classes = [line.strip() for line in f.readlines()]print(f"预测类别:{classes[predicted_class]}")

说明imagenet_classes.txt包含 ImageNet 的 1000 个类别名称(如 “猫”“狗”“汽车”),可从网上下载(搜索 “imagenet classes list”)。

步骤 5:微调预训练模型(适配自定义任务)

如果要解决自己的分类任务(如识别 “猫、狗、鸟”3 类),需要微调模型:

# 1. 修改输出层(将1000类改为3类)
num_classes = 3  # 自定义类别数
resnet50.fc = torch.nn.Linear(resnet50.fc.in_features, num_classes)# 2. 冻结部分层(可选,加速训练)
# 冻结前几层(保留预训练的通用特征),只训练最后几层
for param in list(resnet50.parameters())[:-10]:  # 冻结除最后10层外的参数param.requires_grad = False# 3. 定义损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(resnet50.parameters(), lr=0.001)# 4. 加载自定义数据集(假设已通过DataLoader准备好)
# train_loader = ...(自定义数据集的DataLoader)# 5. 微调训练
resnet50.train()
for epoch in range(10):  # 训练10个epochrunning_loss = 0.0for images, labels in train_loader:images, labels = images.to(device), labels.to(device)# 前向传播outputs = resnet50(images)loss = criterion(outputs, labels)# 反向传播+优化optimizer.zero_grad()loss.backward()optimizer.step()running_loss += loss.item()print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader):.4f}")

微调关键

  • 修改输出层以匹配自定义类别数;
  • 可选冻结部分层(减少计算量,保留通用特征);
  • 用较小的学习率(避免破坏预训练的好参数)。

六、总结(通俗易懂版)

预训练模型就像 “已经学过基础知识的学霸”:

  • 如果你想快速解决一个问题(如识别图片里的东西),可以直接让学霸帮你 “答题”(推理);
  • 如果你想让学霸学新技能(如识别你的 3 种宠物),只需让他在已有知识上 “稍作练习”(微调),比教一个零基础的人(从头训练)快得多,效果也好得多。

无论是小团队、个人开发者还是企业,合理使用预训练模型都能大幅提升效率,少走弯路。

http://www.lryc.cn/news/599976.html

相关文章:

  • 升级目标API级别到35,以Android15为目标平台(三 View绑定篇)
  • 【应急响应】进程隐藏技术与检测方式(二)
  • 三坐标和激光跟踪仪的区别
  • 重庆市傲雄司法鉴定所获准新增四项司法鉴定资质
  • 认识编程(3)-语法背后的认知战争:类型声明的前世今生
  • 利用Trae将原型图转换为可执行的html文件,感受AI编程的魅力
  • 使用python的头文件Matplotlib时plt.show()【标题字体过小】问题根源与解决方案
  • java每日精进 7.25【流程设计3.0(网关+边界事件)】
  • 【Linux系统】基础IO(下)
  • 解决笔记本合盖开盖DPI缩放大小变 (异于网传方法,Win11 24H2)
  • STM32的WI-FI通讯(HAL库)
  • 【电赛学习笔记】MaxiCAM 项目实践——二维云台追踪指定目标
  • 嵌入式Linux裸机开发笔记8(IMX6ULL)主频和时钟配置实验(3)
  • vue 渲染 | 不同类型的元素渲染的方式(vue组件/htmlelement/纯 html)
  • linux配置ntp时间同步
  • 前端核心进阶:从原理到手写Promise、防抖节流与深拷贝
  • ERNIE-4.5-0.3B 实战指南:文心一言 4.5 开源模型的轻量化部署与效能跃升
  • Agentic RAG理解和简易实现
  • 计算机体系结构中的中断服务程序ISR是什么?
  • haproxy集群
  • Java测试题(上)
  • Spring之【Bean后置处理器】
  • sam2环境安装
  • JAVA语法糖
  • JAVA同城服务家政服务家政派单系统源码微信小程序+微信公众号+APP+H5
  • 探索 Sui 上 BTCfi 的各类资产
  • 在DolphinScheduler执行Python问题小记
  • DP4871音频放大芯片3W功率单通道AB类立体声/音频放大器
  • 3N90-ASEMI电源管理领域专用3N90
  • 【前端】JavaScript文件压缩指南