当前位置: 首页 > news >正文

Day51 复习日-模型改进

day43对自己找的数据集用简单cnn训练,现在用预训练,加入注意力等

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
import numpy as np# 设置中文字体支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # 检查 GPU 是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")# 1. 数据预处理
# # 计算均值和方差(仅运行一次)
# def calculate_mean_std(dataloader):
#     mean = torch.zeros(3)
#     std = torch.zeros(3)
#     total_images = 0
#     for images, _ in dataloader:
#         batch_size = images.size(0)
#         images = images.view(batch_size, 3, -1)
#         mean += images.mean(2).sum(0)
#         std += images.std(2).sum(0)
#         total_images += batch_size
#     mean /= total_images
#     std /= total_images
#     return mean, std# # 用无增强的dataloader计算(避免增强影响统计)
# temp_transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
# temp_dataset = datasets.ImageFolder(root=your_data_root, transform=temp_transform)
# temp_loader = DataLoader(temp_dataset, batch_size=32, shuffle=False)
# mean, std = calculate_mean_std(temp_loader)
# print(f"数据集均值:{mean},方差:{std}")train_transform = transforms.Compose([transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),transforms.RandomHorizontalFlip(),transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),transforms.RandomRotation(15),transforms.ToTensor(),# transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))  # 使用ImageNet的均值和方差transforms.Normalize((0.4790, 0.4813, 0.4370), (0.2123, 0.2066, 0.2085))
])test_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),# transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))transforms.Normalize((0.4790, 0.4813, 0.4370), (0.2123, 0.2066, 0.2085))
])# 2. 加载自定义数据集
full_dataset = datasets.ImageFolder(root=r"BengaliFishImages\fish_images",  transform=train_transform
)
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])# 3. 创建数据加载器
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)# 4. 定义注意力机制# SE注意力机制模块
class SEBlock(nn.Module):def __init__(self, channel, reduction=16):super(SEBlock, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Sequential(nn.Linear(channel, channel // reduction, bias=False),nn.ReLU(inplace=True),nn.Linear(channel // reduction, channel, bias=False),nn.Sigmoid())def forward(self, x):b, c, _, _ = x.size()y = self.avg_pool(x).view(b, c)y = self.fc(y).view(b, c, 1, 1)return x * y.expand_as(x)# CBAM注意力机制模块
class ChannelAttention(nn.Module):def __init__(self, in_planes, ratio=16):super(ChannelAttention, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.max_pool = nn.AdaptiveMaxPool2d(1)self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)self.relu1 = nn.ReLU()self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))out = avg_out + max_outreturn self.sigmoid(out)class SpatialAttention(nn.Module):def __init__(self, kernel_size=7):super(SpatialAttention, self).__init__()self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):avg_out = torch.mean(x, dim=1, keepdim=True)max_out, _ = torch.max(x, dim=1, keepdim=True)x = torch.cat([avg_out, max_out], dim=1)x = self.conv1(x)return self.sigmoid(x)class CBAMBlock(nn.Module):def __init__(self, channel, ratio=16, kernel_size=7):super(CBAMBlock, self).__init__()self.channel_attention = ChannelAttention(channel, ratio)self.spatial_attention = SpatialAttention(kernel_size)def forward(self, x):x = x * self.channel_attention(x)x = x * self.spatial_attention(x)return x# 5. 定义改进的CNN模型(可选择添加SE或CBAM注意力)
class ImprovedCNN(nn.Module):def __init__(self, num_classes=20, attention_type=None):super(ImprovedCNN, self).__init__()self.attention_type = attention_type# 第一个卷积块self.conv1 = nn.Conv2d(3, 32, 3, padding=1)self.bn1 = nn.BatchNorm2d(32)self.relu1 = nn.ReLU()self.pool1 = nn.MaxPool2d(2, 2)  # 128 -> 64if attention_type == 'se':self.att1 = SEBlock(32)elif attention_type == 'cbam':self.att1 = CBAMBlock(32)# 第二个卷积块self.conv2 = nn.Conv2d(32, 64, 3, padding=1)self.bn2 = nn.BatchNorm2d(64)self.relu2 = nn.ReLU()self.pool2 = nn.MaxPool2d(2)  # 64 -> 32if attention_type == 'se':self.att2 = SEBlock(64)elif attention_type == 'cbam':self.att2 = CBAMBlock(64)# 第三个卷积块self.conv3 = nn.Conv2d(64, 128, 3, padding=1)self.bn3 = nn.BatchNorm2d(128)self.relu3 = nn.ReLU()self.pool3 = nn.MaxPool2d(2)  # 32 -> 16if attention_type == 'se':self.att3 = SEBlock(128)elif attention_type == 'cbam':self.att3 = CBAMBlock(128)# 第四个卷积块self.conv4 = nn.Conv2d(128, 256, 3, padding=1)self.bn4 = nn.BatchNorm2d(256)self.relu4 = nn.ReLU()self.pool4 = nn.MaxPool2d(2)  # 16 -> 8if attention_type == 'se':self.att4 = SEBlock(256)elif attention_type == 'cbam':self.att4 = CBAMBlock(256)# 全连接层self.fc1 = nn.Linear(256 * 8 * 8, 512)self.dropout = nn.Dropout(p=0.5)self.fc2 = nn.Linear(512, num_classes)def forward(self, x):# 卷积块 1x = self.conv1(x)x = self.bn1(x)x = self.relu1(x)x = self.pool1(x)if self.attention_type is not None:x = self.att1(x)# 卷积块 2x = self.conv2(x)x = self.bn2(x)x = self.relu2(x)x = self.pool2(x)if self.attention_type is not None:x = self.att2(x)# 卷积块 3x = self.conv3(x)x = self.bn3(x)x = self.relu3(x)x = self.pool3(x)if self.attention_type is not None:x = self.att3(x)# 卷积块 4x = self.conv4(x)x = self.bn4(x)x = self.relu4(x)x = self.pool4(x)if self.attention_type is not None:x = self.att4(x)# 全连接层x = x.view(-1, 256 * 8 * 8)x = self.fc1(x)x = self.relu4(x)x = self.dropout(x)x = self.fc2(x)return x# 6. 基于预训练模型的分类器
def create_pretrained_model(model_name, num_classes=20, freeze_feature=True, attention_type=None):"""创建基于预训练模型的分类器Args:model_name: 预训练模型名称,如'resnet50', 'vgg16', 'mobilenet_v2'num_classes: 分类类别数freeze_feature: 是否冻结特征提取部分attention_type: 注意力类型,None, 'se' 或 'cbam'Returns:构建好的模型"""if model_name == 'resnet50':model = models.resnet50(pretrained=True)# 冻结特征提取部分if freeze_feature:for param in model.parameters():param.requires_grad = False# 添加注意力机制(可选)if attention_type == 'se':model.layer4[0].conv1 = nn.Sequential(model.layer4[0].conv1,SEBlock(512))elif attention_type == 'cbam':model.layer4[0].conv1 = nn.Sequential(model.layer4[0].conv1,CBAMBlock(512))# 替换最后的全连接层num_ftrs = model.fc.in_featuresmodel.fc = nn.Sequential(nn.Linear(num_ftrs, 512),nn.ReLU(),nn.Dropout(0.5),nn.Linear(512, num_classes))elif model_name == 'vgg16':model = models.vgg16(pretrained=True)if freeze_feature:for param in model.features.parameters():param.requires_grad = False# 添加注意力机制(可选)if attention_type is not None:att_module = SEBlock(512) if attention_type == 'se' else CBAMBlock(512)model.features = nn.Sequential(*list(model.features.children()),att_module)# 替换分类器num_ftrs = model.classifier[6].in_featuresmodel.classifier[6] = nn.Sequential(nn.Linear(num_ftrs, 512),nn.ReLU(),nn.Dropout(0.5),nn.Linear(512, num_classes))elif model_name == 'mobilenet_v2':model = models.mobilenet_v2(pretrained=True)if freeze_feature:for param in model.features.parameters():param.requires_grad = False# 添加注意力机制(可选)if attention_type is not None:att_module = SEBlock(1280) if attention_type == 'se' else CBAMBlock(1280)model.features = nn.Sequential(*list(model.features.children()),att_module)# 替换分类器num_ftrs = model.classifier[1].in_featuresmodel.classifier = nn.Sequential(nn.Dropout(0.2),nn.Linear(num_ftrs, 512),nn.ReLU(),nn.Dropout(0.5),nn.Linear(512, num_classes))else:raise ValueError(f"不支持的模型名称: {model_name}")return model# 7. 训练与测试函数(保持原有功能,略作调整)
def train(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs):model.train()all_iter_losses = []iter_indices = []train_acc_history = []test_acc_history = []train_loss_history = []test_loss_history = []for epoch in range(epochs):running_loss = 0.0correct = 0total = 0for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()iter_loss = loss.item()all_iter_losses.append(iter_loss)iter_indices.append(epoch * len(train_loader) + batch_idx + 1)running_loss += iter_loss_, predicted = output.max(1)total += target.size(0)correct += predicted.eq(target).sum().item()if (batch_idx + 1) % 100 == 0:print(f'Epoch: {epoch+1}/{epochs} | Batch: {batch_idx+1}/{len(train_loader)} 'f'| 单Batch损失: {iter_loss:.4f} | 累计平均损失: {running_loss/(batch_idx+1):.4f}')epoch_train_loss = running_loss / len(train_loader)epoch_train_acc = 100. * correct / totaltrain_acc_history.append(epoch_train_acc)train_loss_history.append(epoch_train_loss)# 测试阶段model.eval()test_loss = 0correct_test = 0total_test = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += criterion(output, target).item()_, predicted = output.max(1)total_test += target.size(0)correct_test += predicted.eq(target).sum().item()epoch_test_loss = test_loss / len(test_loader)epoch_test_acc = 100. * correct_test / total_testtest_acc_history.append(epoch_test_acc)test_loss_history.append(epoch_test_loss)scheduler.step(epoch_test_loss)print(f'Epoch {epoch+1}/{epochs} 完成 | 训练准确率: {epoch_train_acc:.2f}% | 测试准确率: {epoch_test_acc:.2f}%')plot_iter_losses(all_iter_losses, iter_indices)plot_epoch_metrics(train_acc_history, test_acc_history, train_loss_history, test_loss_history)return epoch_test_acc# 8. 绘图函数(保持不变)
def plot_iter_losses(losses, indices):plt.figure(figsize=(10, 4))plt.plot(indices, losses, 'b-', alpha=0.7, label='Iteration Loss')plt.xlabel('Iteration(Batch序号)')plt.ylabel('损失值')plt.title('每个 Iteration 的训练损失')plt.legend()plt.grid(True)plt.tight_layout()plt.show()def plot_epoch_metrics(train_acc, test_acc, train_loss, test_loss):epochs = range(1, len(train_acc) + 1)plt.figure(figsize=(12, 4))plt.subplot(1, 2, 1)plt.plot(epochs, train_acc, 'b-', label='训练准确率')plt.plot(epochs, test_acc, 'r-', label='测试准确率')plt.xlabel('Epoch')plt.ylabel('准确率 (%)')plt.title('训练和测试准确率')plt.legend()plt.grid(True)plt.subplot(1, 2, 2)plt.plot(epochs, train_loss, 'b-', label='训练损失')plt.plot(epochs, test_loss, 'r-', label='测试损失')plt.xlabel('Epoch')plt.ylabel('损失值')plt.title('训练和测试损失')plt.legend()plt.grid(True)plt.tight_layout()plt.show()# 9. 模型训练配置与执行
def main():# 选择模型类型: 'custom' (自定义CNN), 'resnet50', 'vgg16', 'mobilenet_v2'model_type = 'resnet50'  # 可更换为其他模型# 选择注意力机制: None, 'se', 'cbam'attention_type = 'cbam'  # 可更换为其他注意力类型或None# 训练参数epochs = 30  # 预训练模型通常需要更少的epochsnum_classes = 20# 初始化模型if model_type == 'custom':print(f"使用自定义CNN模型,注意力机制: {attention_type}")model = ImprovedCNN(num_classes=num_classes, attention_type=attention_type).to(device)else:print(f"使用预训练{model_type}模型,注意力机制: {attention_type}")# model = create_pretrained_model(#     model_name=model_type,#     num_classes=num_classes,#     freeze_feature=False,  # 设为True表示只训练顶层,False表示微调整个模型#     attention_type=attention_type# ).to(device)# 使用预训练模型,先冻结特征层model = create_pretrained_model(model_name=model_type,num_classes=num_classes,freeze_feature=True,  # 先冻结特征层,只训练顶层attention_type=None  # 禁用注意力).to(device)# 定义损失函数、优化器和学习率调度器criterion = nn.CrossEntropyLoss()# optimizer = optim.Adam(model.parameters(), lr=0.001)# scheduler = optim.lr_scheduler.ReduceLROnPlateau(#     optimizer, mode='min', patience=3, factor=0.5# )# 调整优化器和学习率optimizer = optim.Adam(model.parameters(), lr=1e-4)  # 更小的学习率scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, factor=0.5, min_lr=1e-6)# 开始训练print(f"开始训练...")final_accuracy = train(model, train_loader, test_loader, criterion, optimizer, scheduler, device, epochs)print(f"训练完成!最终测试准确率: {final_accuracy:.2f}%")# 保存模型model_filename = f"{model_type}_{attention_type if attention_type else 'no_att'}_fish_model.pth"torch.save(model.state_dict(), model_filename)print(f"模型已保存为: {model_filename}")if __name__ == "__main__":main()

@浙大疏锦行 

 

http://www.lryc.cn/news/579811.html

相关文章:

  • TCP、HTTP/1.1 和HTTP/2 协议
  • 怎么更改cursor字体大小
  • JavaEE初阶第七期:解锁多线程,从 “单车道” 到 “高速公路” 的编程升级(五)
  • ElasticSearch快速入门-1
  • MSPM0G3507学习笔记(一) 重置版:适配逐飞库的ti板环境配置
  • 服装零售企业跨区域运营难题破解方案
  • 如何将大型视频文件从 iPhone 传输到 PC
  • PoE 延长器——让网络部署更自由
  • 第十章:HIL-SERL 真实机器人训练实战
  • Docker拉取bladex 、 sentinel-dashboard
  • 【阿里巴巴JAVA开发手册】IDE的text file encoding设置为UTF-8; IDE中文件的换行符使用Unix格式,不要使用Windows格式。
  • Android BitmapRegionDecoder 详解
  • Java启动脚本
  • vue create 和npm init 创建项目对比
  • error MSB8041: 此项目需要 MFC 库。从 Visual Studio 安装程序(单个组件选项卡)为正在使用的任何工具集和体系结构安装它们。
  • React 渲染深度解密:从 JSX 到 DOM 的初次与重渲染全流程
  • 最快实现的前端灰度方案
  • 因果语言模型、自回归语言模型、仅解码器语言模型都是同一类模型
  • 同步(Synchronization)和互斥(Mutual Exclusion)关系
  • 【机器人】复现 DOV-SG 机器人导航 | 动态开放词汇 | 3D 场景图
  • (超详细)数据库项目初体验:使用C语言连接数据库完成短地址服务(本地运行版)
  • 敏捷开发在国际化团队管理中的落地
  • 二维码驱动的独立站视频集成方案
  • 碰一碰发视频源码搭建与定制化开发:支持OEM
  • 译码器Multisim电路仿真汇总——硬件工程师笔记
  • TensorFlow 安装使用教程
  • MySQL数据库----DML语句
  • 【2.4 漫画SpringBoot实战】
  • 【模糊集合】示例
  • vue-37(模拟依赖项进行隔离测试)