当前位置: 首页 > news >正文

day41

# 原始模型(2层卷积)
class OriginalCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 3)
        self.conv2 = nn.Conv2d(16, 32, 3)
        self.fc = nn.Linear(32*5*5, 10)
    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, 2)
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, 2)
        return self.fc(x.flatten(1))

# 修改1:减少卷积层(1层卷积)
class ReducedCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3)  # 通道数翻倍补偿层数减少
        self.fc = nn.Linear(32*13*13, 10)
    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, 2)
        return self.fc(x.flatten(1))

# 修改2:增加注意力(简单通道注意力)
class SECNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, 3)
        self.conv2 = nn.Conv2d(16, 32, 3)
        self.se = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(),
                               nn.Linear(32, 8), nn.ReLU(),
                               nn.Linear(8, 32), nn.Sigmoid(), nn.Unflatten(1, (32,1,1)))
        self.fc = nn.Linear(32*5*5, 10)
    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.max_pool2d(x, 2)
        x = torch.relu(self.conv2(x))
        x = x * self.se(x)  # 应用注意力
        x = torch.max_pool2d(x, 2)
        return self.fc(x.flatten(1))

# 训练函数(支持不同调度器)
def train(model, scheduler_type, epochs=5):
    model = model.to(device)
    opt = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    
    # 调度器
    if scheduler_type == 'cosine':
        scheduler = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    elif scheduler_type == 'step':
        scheduler = optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.5)
    else:
        scheduler = None
    
    loss_list = []
    for e in range(epochs):
        model.train()
        loss_sum = 0
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            opt.zero_grad()
            loss = nn.CrossEntropyLoss()(model(x), y)
            loss.backward()
            opt.step()
            loss_sum += loss.item()
        loss_list.append(loss_sum / len(loader))
        if scheduler: scheduler.step()
        print(f"Epoch {e+1}, Loss: {loss_list[-1]:.4f}")
    return loss_list

# 对比实验
def compare():
    models = {'Original': OriginalCNN(), 'Reduced': ReducedCNN(), 'SE': SECNN()}
    schedulers = ['cosine', 'step', None]
    
    results = {}
    for name, model in models.items():
        results[name] = {}
        for sched in schedulers:
            s_name = 'Cosine' if sched=='cosine' else 'StepLR' if sched=='step' else 'None'
            print(f"\n=== {name} + {s_name} ===")
            results[name][s_name] = train(model, sched, epochs=5)
    
    # 可视化
    plt.figure(figsize=(10, 6))
    for name, scheds in results.items():
        for s_name, loss in scheds.items():
            plt.plot(loss, label=f"{name} + {s_name}")
    plt.title("Loss Comparison")
    plt.legend()
    plt.grid(True)
    plt.show()

if __name__ == "__main__":
    compare()
@浙大疏锦行https://blog.csdn.net/weixin_45655710

http://www.lryc.cn/news/574799.html

相关文章:

  • 深入理解 BOM:浏览器对象模型详解
  • IoTDB的基本概念及常用命令
  • 【css】增强 CSS 的复用性与灵活性的Mixins
  • ArkUI-X通过Stage模型开发Android端应用指南(二)
  • 【软考高级系统架构论文】### 论软件系统架构评估
  • linux grep的一些坑
  • 接口自动化测试之 pytest 接口关联框架封装
  • Unity_导航操作(鼠标控制人物移动)_运动动画
  • matplotilb实现对MACD的实战
  • SQL关键字三分钟入门:UPDATE —— 修改数据
  • Camera Sensor接口协议全解析(五)SLVS-EC接口深度解析
  • Stable Diffusion 项目实战落地:打造完美海报的秘密武器 第二篇:边缘柔化、蒙版处理与图生图技术大揭秘!
  • 如何通过nvm切换本地node环境详情教程(已装过node.js更改成nvm)
  • 2025.6.24总结
  • useState为异步,测试一下编码时候是否考虑?
  • Unity反射机制
  • mongoose解析http字段值
  • Spring Boot 的Banner的介绍和设置
  • 中科米堆3D扫描逆向建模方案:汽车轮毂三维扫描抄数建模
  • elk+filebeat收集springboot项目日志
  • iwebsec靶场-文件上传漏洞
  • 串口助手实例
  • lib61850 代码结构与系统架构深度分析
  • 鸿蒙OH南向开发 轻量系统内核(LiteOS-M)【异常调测】
  • 针对基于深度学习的侧信道分析(DLSCA)进行超参数的贝叶斯优化
  • vue 3 计算器
  • Nginx性能优化配置指南
  • 6.24_JAVA_微服务_Elasticsearch搜索
  • vscode + Jlink 一键调试stm32 单片机程序(windows系统版)
  • Git简介和常用命令