当前位置：首页 > news > 正文

DAY 54 Inception网络及其思考

作业：一次稍微带点学术味道的练习：

  1. 在 CIFAR-10 上训练 Inception 网络，并观察其测试精度；
  2. 消融实验：分别引入残差机制和 CBAM 模块，与基础网络进行消融对比。
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torch.nn.functional as F
    from torchvision import datasets, transforms
    from torch.utils.data import DataLoader
    import torchvision.models as models
    import time
    import copy# 设置随机种子确保可复现性
    # Fix the global torch RNG so weight init and data shuffling are repeatable.
    torch.manual_seed(42)

    # Per-channel normalization statistics shared by the train and test pipelines.
    # NOTE(review): (0.2023, 0.1994, 0.2010) is the widely copied CIFAR-10 "std";
    # the per-channel pixel std of the training set is usually quoted as roughly
    # (0.2470, 0.2435, 0.2616). Kept unchanged so results stay comparable.
    CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
    CIFAR10_STD = (0.2023, 0.1994, 0.2010)

    # Training pipeline: random crop (with 4px padding) + horizontal flip
    # augmentation, then tensor conversion and normalization.
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ])

    # Test pipeline: no augmentation, only tensor conversion and normalization.
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ])

    # Load CIFAR-10 into ./data (downloaded if not already present).
    trainset = datasets.CIFAR10(root='./data', train=True,
                                download=True, transform=transform_train)
    trainloader = DataLoader(trainset, batch_size=128,
                             shuffle=True, num_workers=2)

    testset = datasets.CIFAR10(root='./data', train=False,
                               download=True, transform=transform_test)
    testloader = DataLoader(testset, batch_size=100,
                            shuffle=False, num_workers=2)

    # CBAM module definitions
    class ChannelAttention(nn.Module):
        """Channel-attention gate from CBAM.

        The input is squeezed over its spatial dimensions with both average and
        max pooling. Each pooled descriptor passes through the same bottleneck
        MLP, implemented as two 1x1 convs. The two results are summed and
        squashed with a sigmoid, producing one gate value per channel with
        shape (N, C, 1, 1).

        Args:
            in_channels: Number of input channels C.
            reduction_ratio: Bottleneck reduction factor. The hidden width is
                ``in_channels // reduction_ratio``, clamped to at least 1.
        """

        def __init__(self, in_channels, reduction_ratio=16):
            super(ChannelAttention, self).__init__()
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
            self.max_pool = nn.AdaptiveMaxPool2d(1)
            # Clamp so narrow inputs (C < reduction_ratio) still get a real
            # bottleneck instead of a zero-width conv layer.
            hidden = max(1, in_channels // reduction_ratio)
            self.fc = nn.Sequential(
                nn.Conv2d(in_channels, hidden, 1, bias=False),
                nn.ReLU(),
                nn.Conv2d(hidden, in_channels, 1, bias=False),
            )

        def forward(self, x):
            avg_out = self.fc(self.avg_pool(x))
            max_out = self.fc(self.max_pool(x))
            return torch.sigmoid(avg_out + max_out)


    class SpatialAttention(nn.Module):
        """Spatial-attention gate from CBAM.

        Builds two (N, 1, H, W) maps, the channel-wise mean and the channel-wise
        max, and concatenates them. A ``kernel_size`` x ``kernel_size`` conv with
        ``kernel_size // 2`` padding (size-preserving for odd kernels) turns them
        into one map. Returns its sigmoid, shape (N, 1, H, W).
        """

        def __init__(self, kernel_size=7):
            super(SpatialAttention, self).__init__()
            self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)

        def forward(self, x):
            avg_out = torch.mean(x, dim=1, keepdim=True)
            max_out, _ = torch.max(x, dim=1, keepdim=True)
            out = self.conv(torch.cat([avg_out, max_out], dim=1))
            return torch.sigmoid(out)


    class CBAM(nn.Module):
        """Convolutional Block Attention Module.

        Applies channel attention and then spatial attention. Each gate
        multiplies the features element-wise, so the output has the same shape
        as the input.
        """

        def __init__(self, in_channels, reduction_ratio=16, kernel_size=7):
            super(CBAM, self).__init__()
            self.channel_att = ChannelAttention(in_channels, reduction_ratio)
            self.spatial_att = SpatialAttention(kernel_size)

        def forward(self, x):
            x = x * self.channel_att(x)
            x = x * self.spatial_att(x)
            return x

    # Inception module definition
    class InceptionModule(nn.Module):
        """GoogLeNet-style Inception block with four parallel branches.

        All branches use stride 1 with size-preserving padding, so H and W are
        unchanged:

        1. 1x1 conv -> ``ch1x1`` channels
        2. 1x1 reduce to ``ch3x3red`` -> 3x3 conv -> ``ch3x3`` channels
        3. 1x1 reduce to ``ch5x5red`` -> 5x5 conv -> ``ch5x5`` channels
        4. 3x3 max-pool -> 1x1 projection -> ``pool_proj`` channels

        Every conv is followed by BatchNorm and ReLU. The branch outputs are
        concatenated along dim 1, so the block emits
        ``ch1x1 + ch3x3 + ch5x5 + pool_proj`` channels (``self.out_channels``).
        """

        def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
            super(InceptionModule, self).__init__()
            # Width of the concatenated output, handy for sizing the next layer.
            self.out_channels = ch1x1 + ch3x3 + ch5x5 + pool_proj

            # 1x1 conv branch
            self.branch1 = nn.Sequential(
                nn.Conv2d(in_channels, ch1x1, kernel_size=1),
                nn.BatchNorm2d(ch1x1),
                nn.ReLU(True),
            )
            # 1x1 reduction -> 3x3 conv branch
            self.branch2 = nn.Sequential(
                nn.Conv2d(in_channels, ch3x3red, kernel_size=1),
                nn.BatchNorm2d(ch3x3red),
                nn.ReLU(True),
                nn.Conv2d(ch3x3red, ch3x3, kernel_size=3, padding=1),
                nn.BatchNorm2d(ch3x3),
                nn.ReLU(True),
            )
            # 1x1 reduction -> 5x5 conv branch
            self.branch3 = nn.Sequential(
                nn.Conv2d(in_channels, ch5x5red, kernel_size=1),
                nn.BatchNorm2d(ch5x5red),
                nn.ReLU(True),
                nn.Conv2d(ch5x5red, ch5x5, kernel_size=5, padding=2),
                nn.BatchNorm2d(ch5x5),
                nn.ReLU(True),
            )
            # 3x3 max-pool -> 1x1 projection branch
            self.branch4 = nn.Sequential(
                nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
                nn.Conv2d(in_channels, pool_proj, kernel_size=1),
                nn.BatchNorm2d(pool_proj),
                nn.ReLU(True),
            )

        def forward(self, x):
            outputs = [self.branch1(x), self.branch2(x), self.branch3(x), self.branch4(x)]
            return torch.cat(outputs, 1)

    # Baseline Inception network
    class BasicInception(nn.Module):
        """Baseline Inception classifier for 3-channel inputs.

        Layers:

        - Stem: 3x3 conv to 64 channels, then BN and ReLU.
        - inception1: 64 -> 128 channels, then a stride-2 max-pool.
        - inception2: 128 -> 240 channels, then a stride-2 max-pool.
        - inception3: 240 -> 256 channels.
        - Head: global average pool and a linear layer to ``num_classes`` logits.
        """

        def __init__(self, num_classes=10):
            super(BasicInception, self).__init__()
            self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
            self.bn1 = nn.BatchNorm2d(64)
            self.relu1 = nn.ReLU(True)
            self.inception1 = InceptionModule(64, 32, 48, 64, 8, 16, 16)     # -> 128 ch
            self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
            self.inception2 = InceptionModule(128, 64, 64, 96, 16, 48, 32)   # -> 240 ch
            self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
            self.inception3 = InceptionModule(240, 96, 48, 104, 8, 24, 32)   # -> 256 ch
            self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
            self.fc = nn.Linear(256, num_classes)

        def forward(self, x):
            x = self.relu1(self.bn1(self.conv1(x)))
            x = self.pool1(self.inception1(x))
            x = self.pool2(self.inception2(x))
            x = self.inception3(x)
            x = torch.flatten(self.global_pool(x), 1)
            return self.fc(x)

    # Inception network with residual connections
    class ResidualInception(nn.Module):
        """Inception classifier with projection shortcuts around two blocks.

        The layers match ``BasicInception``. In addition, inception1 and
        inception2 each get a 1x1-conv projection shortcut that is added to the
        block output before a ReLU.

        The shortcut must match the block output exactly. Every Inception
        branch keeps H and W (stride 1), so the projections also use stride 1
        and only change the channel count (64 -> 128 and 128 -> 240). The
        first shortcut is taken from the stem output, which has the 64
        channels ``res_conv1`` expects, not from the raw 3-channel image.
        """

        def __init__(self, num_classes=10):
            super(ResidualInception, self).__init__()
            self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
            self.bn1 = nn.BatchNorm2d(64)
            self.relu1 = nn.ReLU(True)
            self.inception1 = InceptionModule(64, 32, 48, 64, 8, 16, 16)     # -> 128 ch
            self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
            self.inception2 = InceptionModule(128, 64, 64, 96, 16, 48, 32)   # -> 240 ch
            self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
            self.inception3 = InceptionModule(240, 96, 48, 104, 8, 24, 32)   # -> 256 ch
            self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
            self.fc = nn.Linear(256, num_classes)
            # 1x1 projection shortcuts. Stride 1 keeps H/W equal to the
            # Inception outputs; stride 2 would halve them and break the add.
            self.res_conv1 = nn.Conv2d(64, 128, kernel_size=1, stride=1)
            self.res_conv2 = nn.Conv2d(128, 240, kernel_size=1, stride=1)

        def forward(self, x):
            x = self.relu1(self.bn1(self.conv1(x)))

            # Stage 1: the shortcut comes from the 64-channel stem output.
            identity = self.res_conv1(x)
            x = F.relu(self.inception1(x) + identity)
            x = self.pool1(x)

            # Stage 2: the shortcut comes from the pooled 128-channel features.
            identity = self.res_conv2(x)
            x = F.relu(self.inception2(x) + identity)
            x = self.pool2(x)

            x = self.inception3(x)
            x = torch.flatten(self.global_pool(x), 1)
            return self.fc(x)

    # Inception network with CBAM modules
    class CBAMInception(nn.Module):
        """Inception classifier with a CBAM attention block after each Inception block.

        The layers match ``BasicInception``. The output of inception1 (128 ch),
        inception2 (240 ch) and inception3 (256 ch) is refined by a CBAM module
        of matching width before the following pooling step. CBAM preserves
        the tensor shape.
        """

        def __init__(self, num_classes=10):
            super(CBAMInception, self).__init__()
            self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
            self.bn1 = nn.BatchNorm2d(64)
            self.relu1 = nn.ReLU(True)
            self.inception1 = InceptionModule(64, 32, 48, 64, 8, 16, 16)     # -> 128 ch
            self.cbam1 = CBAM(128)
            self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
            self.inception2 = InceptionModule(128, 64, 64, 96, 16, 48, 32)   # -> 240 ch
            self.cbam2 = CBAM(240)
            self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
            self.inception3 = InceptionModule(240, 96, 48, 104, 8, 24, 32)   # -> 256 ch
            self.cbam3 = CBAM(256)
            self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
            self.fc = nn.Linear(256, num_classes)

        def forward(self, x):
            x = self.relu1(self.bn1(self.conv1(x)))
            x = self.pool1(self.cbam1(self.inception1(x)))
            x = self.pool2(self.cbam2(self.inception2(x)))
            x = self.cbam3(self.inception3(x))
            x = torch.flatten(self.global_pool(x), 1)
            return self.fc(x)

    # Training function
    def train_model(model, criterion, optimizer, scheduler, num_epochs=25,
                    dataloader=None, device=None):
        """Train ``model`` and return it with its best-epoch weights loaded.

        The function runs ``num_epochs`` passes over the training loader and
        steps the LR scheduler once per epoch. It keeps a copy of the weights
        from the epoch with the highest training accuracy. There is no separate
        validation split here, so selection is by training accuracy.

        Args:
            model: Network to train; updated in place.
            criterion: Loss taking ``(outputs, labels)``.
            optimizer: Optimizer over ``model``'s parameters.
            scheduler: Epoch-based LR scheduler (e.g. StepLR), stepped once per epoch.
            num_epochs: Number of training epochs.
            dataloader: Iterable of ``(inputs, labels)`` batches with a sized
                ``.dataset``. Defaults to the module-level ``trainloader``.
            device: Device batches are moved to. Defaults to the device of the
                model's first parameter, or CPU if the model has none.

        Returns:
            The same ``model`` object with the best-epoch weights loaded.

        Raises:
            ValueError: If the loader's dataset is empty.
        """
        if dataloader is None:
            dataloader = trainloader
        if device is None:
            first_param = next(iter(model.parameters()), None)
            device = first_param.device if first_param is not None else torch.device('cpu')

        num_samples = len(dataloader.dataset)
        if num_samples == 0:
            raise ValueError('train_model received an empty dataset')

        since = time.time()
        best_model_wts = copy.deepcopy(model.state_dict())
        best_acc = 0.0

        for epoch in range(num_epochs):
            print(f'Epoch {epoch}/{num_epochs - 1}')
            print('-' * 10)

            model.train()
            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()
                outputs = model(inputs)
                preds = outputs.argmax(dim=1)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                # Loss is a batch mean; weight by batch size for the epoch mean.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += (preds == labels).sum().item()

            # Epoch-based schedulers (StepLR step_size is in epochs) must be
            # stepped once per epoch, after the optimizer steps, not per batch.
            scheduler.step()

            epoch_loss = running_loss / num_samples
            epoch_acc = running_corrects / num_samples
            print(f'Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # Snapshot the weights of the best epoch so far.
            if epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
            print()

        time_elapsed = time.time() - since
        print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
        print(f'Best train Acc: {best_acc:.4f}')

        model.load_state_dict(best_model_wts)
        return model

    # Evaluation function
    def evaluate_model(model, dataloader=None, device=None):
        """Compute and print top-1 accuracy (in percent) of ``model`` on a loader.

        Args:
            model: Classifier whose argmax over dim 1 is the predicted class.
            dataloader: Iterable of ``(inputs, labels)`` batches. Defaults to
                the module-level ``testloader``.
            device: Device batches are moved to. Defaults to the device of the
                model's first parameter, or CPU if the model has none.

        Returns:
            Accuracy as a float percentage in [0, 100].

        Raises:
            ValueError: If the loader yields no samples.
        """
        if dataloader is None:
            dataloader = testloader
        if device is None:
            first_param = next(iter(model.parameters()), None)
            device = first_param.device if first_param is not None else torch.device('cpu')

        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in dataloader:
                inputs = inputs.to(device)
                labels = labels.to(device)
                predicted = model(inputs).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        if total == 0:
            raise ValueError('evaluate_model received an empty dataloader')
        accuracy = 100 * correct / total
        # Report the real sample count instead of a hard-coded 10000.
        print(f'Accuracy of the network on the {total} test images: {accuracy:.2f}%')
        return accuracy

    # Select the device
    # Compute device: the first CUDA GPU if available, otherwise the CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


    def _run_experiment(title, model_cls, num_epochs=10):
        """Train and evaluate one ablation variant with the shared recipe.

        Every variant uses the same setup:

        - Optimizer: SGD with lr=0.001 and momentum=0.9.
        - Schedule: StepLR with step_size=7 and gamma=0.1.
        - Budget: the same number of epochs.

        The only difference between runs is therefore the architecture. The
        global torch RNG is re-seeded before each model is built, so earlier
        runs do not shift later runs' initialization or shuffling.

        Returns:
            ``(trained_model, test_accuracy_percent)``.
        """
        print(title)
        torch.manual_seed(42)
        model = model_cls().to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
        model = train_model(model, criterion, optimizer, scheduler, num_epochs=num_epochs)
        return model, evaluate_model(model)


    # Guard the experiment driver. DataLoader(num_workers=2) starts worker
    # processes. Under the "spawn" start method (the default on Windows and
    # macOS) each worker re-imports this script, so without the guard the
    # workers would try to re-run the experiments and fail during bootstrapping.
    if __name__ == "__main__":
        print(f"Using device: {device}")

        # Test accuracy of each variant, keyed by variant name.
        results = {}

        # Experiment 1: baseline Inception network
        model_basic, basic_accuracy = _run_experiment(
            "===== 实验1: 基础Inception网络 =====", BasicInception)
        results["Basic Inception"] = basic_accuracy

        # Experiment 2: Inception network with residual connections
        model_residual, residual_accuracy = _run_experiment(
            "\n===== 实验2: 带残差连接的Inception网络 =====", ResidualInception)
        results["Residual Inception"] = residual_accuracy

        # Experiment 3: Inception network with CBAM modules
        model_cbam, cbam_accuracy = _run_experiment(
            "\n===== 实验3: 带CBAM模块的Inception网络 =====", CBAMInception)
        results["CBAM Inception"] = cbam_accuracy

        # Side-by-side comparison of the three variants.
        print("\n===== 实验结果对比 =====")
        print("{:<20} {:<10}".format("模型", "准确率 (%)"))
        print("-" * 30)
        for name, acc in results.items():
            print("{:<20} {:<10.2f}".format(name, acc))

http://www.lryc.cn/news/570991.html

相关文章:

  • 人物点评: 马云的野心(zz)
  • 狱搜导航-个性化导航自定义导航网站,搜索导航,简洁清晰大气,支持各种自定义
  • 中国第二大传感器企业宝座易主!14家公司新上市!国产传感器风起云涌!(附最新市值排名榜单)
  • Codeigniter 4基础教程(1)-- Wamp+CodeIgniter 4以及helloworld
  • 计算机主机漏电,电脑主机箱漏电六大原因和解决方法
  • Windows安装与配置Git cz (commitizen)
  • 包含15个APP客户端UI界面的psd适用于餐厅咖啡店面包店快餐店
  • d3dx9_42.dll丢失的解决方法-d3dx9_42.dll缺失下载方法
  • InstallShield 2010打包安装程序,安装完成后执行某个程序
  • 送两本《ECharts数据可视化:入门、实战与进阶》
  • 蓝屏错误代码0x0000007E的解决方法及编程示例
  • linux内核(二)内核移植(DM365-DM368开发攻略——linux-2.6.32的移植)
  • Internet Explorer 已不再尝试还原此网站。该网站看上去仍有问题。
  • 关于部分网页打不开的解决方法详解
  • 学生学籍管理系统页面源代码html_110.188.251:8088四川大学锦江学院教务管理系统...
  • 盗版xp成功验证成正版,享受正版增值服务!—— 完美解决XP被黑和盗版提示
  • 英语学习漫谈
  • 导航条——flash导航条
  • 音频毒品
  • 002微信小程序模板与配置
  • 国开电大 管理心理学 形考任务1-4
  • 电阻篇---上拉电阻
  • 解决安装程序无法初始化。请下载Adobe
  • 设置hosts文件,屏蔽百度和谷歌的网页广告。
  • 内含干货PPT下载|一站式数据管理DMS关键技术解读
  • 地质地貌卫星影像集锦(一 典型地貌篇)
  • 经典DOS怀旧游戏-《炎龙骑士团》系列
  • 《Java编程思想》读书笔记:第十二章
  • ABAP 格式与JSON和XML格式互转
  • 微信小程序—跳一跳,Android游戏助手(外挂)使用教程,【吐血整理】