当前位置: 首页 > article >正文

Python打卡训练营Day43

DAY 43 复习日

作业:

kaggle找到一个图像数据集,用cnn网络进行训练并且用grad-cam做可视化

数据集地址:Lung Nodule Malignancy(肺结节良恶性判断)

进阶:将代码拆分成多个文件

import os
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torch
from torchvision import transforms
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt# 1. 读标签并映射 0/1
df = pd.read_csv('archive/malignancy.csv')

# 2. Hold out 20% of the NoduleID values for validation, stratified on the
#    malignancy column (the original note calls this key "patch_id").
ids = df['NoduleID'].values
labels = df['malignancy'].values
train_ids, val_ids = train_test_split(
    ids,
    test_size=0.2,
    random_state=42,
    stratify=labels,
)

# Route each row to the side of the split its NoduleID landed on, then
# renumber both frames from zero so positional indexing works.
_in_train = df['NoduleID'].isin(train_ids)
_in_val = df['NoduleID'].isin(val_ids)
train_df = df[_in_train].reset_index(drop=True)
val_df = df[_in_val].reset_index(drop=True)

# 3. Dataset: read a multi-page TIFF one page at a time
class LungTBDataset(Dataset):
    """Serve ``(image, label)`` pairs from a single multi-page TIFF.

    Each row of ``df`` provides ``NoduleID``, which is used as the zero-based
    page index into the TIFF at ``tif_path``, and ``malignancy``, which is
    returned as the integer label.

    Args:
        tif_path: Path of the multi-page TIFF file.
        df: DataFrame with ``NoduleID`` and ``malignancy`` columns.
        transform: Optional callable applied to the RGB PIL image.
    """

    def __init__(self, tif_path, df, transform=None):
        self.tif_path = tif_path
        self.df = df
        self.transform = transform
        # Frame count of the TIFF; filled in on first successful open so the
        # file is not re-scanned for every sample.
        self._total_pages = None

    def __len__(self):
        return len(self.df)

    def _page_count(self, tif):
        """Return the frame count of ``tif``, computed once and then cached."""
        if self._total_pages is None:
            # Pillow only defines n_frames on multi-frame capable plugins,
            # so fall back to a single frame when it is absent.
            self._total_pages = getattr(tif, 'n_frames', 1)
        return self._total_pages

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``.

        A page index past the end of the file is clamped to the last page.
        If the page cannot be read, a black 224x224 RGB placeholder is
        returned instead and a warning is emitted.
        """
        row = self.df.iloc[idx]
        pid = int(row['NoduleID'])
        label = int(row['malignancy'])
        try:
            with Image.open(self.tif_path) as tif:
                total_pages = self._page_count(tif)
                if pid >= total_pages:
                    pid = total_pages - 1  # fall back to the last frame
                tif.seek(pid)
                # convert() loads the pixels into a new image, so the result
                # stays usable after the file is closed.
                img = tif.convert('RGB')
        except (OSError, EOFError, ValueError) as exc:
            # Best effort: keep training running, but do not hide the
            # failure. Programming errors (e.g. NameError) are deliberately
            # NOT caught here so they surface instead of silently turning
            # every sample into a black image.
            import warnings

            warnings.warn(
                f'could not read page {pid} of {self.tif_path!r} ({exc!r}); '
                'using a black placeholder image',
                RuntimeWarning,
            )
            img = Image.new('RGB', (224, 224), (0, 0, 0))
        if self.transform:
            img = self.transform(img)
        return img, label


# 4. Transforms & DataLoader
# Resize every image to 224x224, convert it to a tensor, then normalise each
# channel (the mean/std values are the ones commonly quoted as the ImageNet
# channel statistics).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# Both splits read pages from the same TIFF; only the label frame differs.
_tif_path = 'archive/ct_tiles.tif'
train_ds = LungTBDataset(_tif_path, train_df, transform)
val_ds = LungTBDataset(_tif_path, val_df, transform)

# Shared loader settings; only the training loader shuffles.
_loader_options = dict(batch_size=16, num_workers=0, pin_memory=True)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_options)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_options)

# 5. A simple CNN (3 convolution stages + fully connected head)
class SimpleCNN(nn.Module):
    """Three conv/ReLU/max-pool stages followed by a two-layer classifier.

    The first linear layer expects ``128 * 28 * 28`` features, so the network
    only accepts 3-channel 224x224 inputs (each stage halves the spatial
    size: 224 -> 112 -> 56 -> 28).

    Args:
        num_classes: Number of output logits. Defaults to 2.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        # Convolution stages. Each is a Sequential so that index 0 is the
        # Conv2d itself (the script hooks ``conv3[0]`` for Grad-CAM).
        self.conv1 = self._conv_stage(3, 32)    # 224 -> 112
        self.conv2 = self._conv_stage(32, 64)   # 112 -> 56
        self.conv3 = self._conv_stage(64, 128)  # 56 -> 28
        # Fully connected classifier.
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 28 * 28, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, num_classes),
        )

    @staticmethod
    def _conv_stage(in_channels, out_channels):
        """Build one 3x3 conv -> in-place ReLU -> 2x2 max-pool stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
        )

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)  # last convolution stage
        x = self.fc(x)
        return x


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)


def _train_one_epoch():
    """Run one optimisation pass over train_loader; return the mean sample loss."""
    model.train()
    loss_sum = 0
    for batch, targets in train_loader:
        batch = batch.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        batch_loss = criterion(model(batch), targets)
        batch_loss.backward()
        optimizer.step()
        # criterion averages over the batch, so weight by batch size to get
        # a per-sample mean over the whole epoch.
        loss_sum += batch_loss.item() * batch.size(0)
    return loss_sum / len(train_ds)


def _validation_accuracy():
    """Return the fraction of val_ds samples whose arg-max logit is correct."""
    model.eval()
    hits = 0
    with torch.no_grad():
        for batch, targets in val_loader:
            batch = batch.to(device)
            targets = targets.to(device)
            predicted = model(batch).argmax(dim=1)
            hits += (predicted == targets).sum().item()
    return hits / len(val_ds)


# 6. Training + validation loop
num_epochs = 5
for epoch in range(num_epochs):
    epoch_loss = _train_one_epoch()
    val_acc = _validation_accuracy()
    print(f'Epoch {epoch+1}/{num_epochs}  Loss={epoch_loss:.4f}  ValAcc={val_acc:.4f}')

# 7. Minimal Grad-CAM
class GradCAM:
    """Minimal Grad-CAM for the first sample of a batch.

    A forward hook on ``target_conv`` stores a copy of that layer's output
    and attaches a tensor hook to it, which captures the gradient of the
    chosen logit with respect to the same output during ``backward()``.

    A tensor hook is used instead of a module backward hook because
    ``Module.register_backward_hook`` is deprecated, and its replacement
    (``register_full_backward_hook``) rejects in-place modification of the
    module output, e.g. by a following ``ReLU(inplace=True)``.

    Args:
        model: The network to explain.
        target_conv: The sub-module whose output feature maps are weighted.
    """

    def __init__(self, model, target_conv):
        self.model = model
        self.target_conv = target_conv
        self.grad = None
        self.activation = None
        # Keep the handle so the hook can be detached again via remove().
        self._handles = [target_conv.register_forward_hook(self._forward)]

    def _forward(self, module, inp, outp):
        # clone(): detach() alone shares storage with ``outp``, so a
        # following in-place op would silently overwrite the stored copy.
        self.activation = outp.detach().clone()
        # The hook stays installed for every forward pass, including ones
        # run under no_grad, where registering a tensor hook would raise.
        if outp.requires_grad:
            outp.register_hook(
                lambda grad: self._backward(module, None, (grad,))
            )

    def _backward(self, module, grad_in, grad_out):
        self.grad = grad_out[0].detach()

    def remove(self):
        """Detach the forward hook from ``target_conv``."""
        for handle in self._handles:
            handle.remove()
        self._handles = []

    def __call__(self, x, class_idx=None):
        """Compute the class-activation map for ``x[0]``.

        Args:
            x: Input batch; only the first sample is explained.
            class_idx: Logit to explain. Defaults to the arg-max logit of
                the first sample.

        Returns:
            A 2-D ``numpy`` array with the spatial size of the target
            layer's output, scaled to ``[0, 1]``. An all-zero map is
            returned unchanged instead of being divided by zero.

        Raises:
            RuntimeError: If no gradient reached the target layer.
        """
        self.grad = None
        self.model.zero_grad()
        # Gradients are required even if the caller is inside no_grad().
        with torch.enable_grad():
            out = self.model(x)
            if class_idx is None:
                class_idx = out[0].argmax().item()
            loss = out[0, class_idx]
            loss.backward()
        if self.grad is None or self.activation is None:
            raise RuntimeError(
                'GradCAM: no gradient reached the target layer; is it part '
                'of the forward pass and does it require grad?'
            )
        # Channel weights: spatial mean of the gradient, shape (C,).
        weights = self.grad[0].mean(dim=(1, 2))
        cam = (weights.view(-1, 1, 1) * self.activation[0]).sum(dim=0)
        cam = torch.relu(cam)
        cam = cam - cam.min()
        peak = cam.max()
        if peak > 0:  # avoid 0/0 -> NaN on an all-zero map
            cam = cam / peak
        return cam.cpu().numpy()


# 8. Pick one validation image and visualise it
model.eval()
imgs, labs = next(iter(val_loader))
# Keep the batch dimension on the image (shape 1xCxHxW) and take the label
# as a plain Python number.
img = imgs[0:1].to(device)
lab = labs[0].item()

# conv3 is a Sequential; its element 0 is the Conv2d used as the CAM target.
target_layer = model.conv3[0]
gradcam = GradCAM(model, target_layer)
# conv3[0] runs before conv3's own max-pool, so for a 224x224 input the map
# is 56x56.
heatmap = gradcam(img)

# Upsample to 224x224 by going through an 8-bit PIL image.
heatmap = np.uint8(255 * heatmap)
heatmap = Image.fromarray(heatmap).resize((224, 224), resample=Image.BILINEAR)
heatmap = np.array(heatmap) / 255.0

# Undo the input normalisation: Normalize computes (x - mean) / std, so
# mean = -m/s and std = 1/s map the tensor back to x * s + m.
_norm_mean = [0.485, 0.456, 0.406]
_norm_std = [0.229, 0.224, 0.225]
inv_norm = transforms.Normalize(
    mean=[-m / s for m, s in zip(_norm_mean, _norm_std)],
    std=[1 / s for s in _norm_std],
)
img_show = inv_norm(img[0]).permute(1, 2, 0).cpu().numpy()
img_show = np.clip(img_show, 0, 1)

# Left panel: the input image. Right panel: the same image with the heatmap
# blended on top.
plt.figure(figsize=(8, 4))

plt.subplot(1, 2, 1)
plt.imshow(img_show)
plt.title(f'Label={lab}')
plt.axis('off')

plt.subplot(1, 2, 2)
plt.imshow(img_show, alpha=0.6)
plt.imshow(heatmap, cmap='jet', alpha=0.4)
plt.title('Grad-CAM')
plt.axis('off')

plt.tight_layout()
plt.show()

代码没问题,但跑得很慢,不知道是什么原因。

浙大疏锦行-CSDN博客

http://www.lryc.cn/news/2397396.html

相关文章:

  • PHP7+MySQL5.6 查立得轻量级公交查询系统
  • 如何做好一个决策:基于 Excel的决策树+敏感性分析应用(针对多个变量)
  • Azure DevOps 管道部署系列之一本地服务器
  • DeepSeekMath:突破开放式语言模型中数学推理能力的极限
  • QT 5.15.2 程序中文乱码
  • Celery简介
  • StarRocks物化视图
  • vue2源码解析——响应式原理
  • 基于 GitLab CI + Inno Setup 实现 Windows 程序自动化打包发布方案
  • 做好 4个基本动作,拦住性能优化改坏原功能的bug
  • 【HarmonyOS 5】针对 Harmony-Cordova 性能优化,涵盖原生插件开发、线程管理和资源加载等关键场景
  • 零基础认知企业级数据分析平台如何落实数据建模(GAI)
  • web架构2------(nginx多站点配置,include配置文件,日志,basic认证,ssl认证)
  • AI 的早期萌芽?用 Swift 演绎约翰·康威的「生命游戏」
  • 【DBA】MySQL经典250题,改自OCP英文题库中文版(2025完整版)
  • Cursor 编辑器介绍:专为程序员打造的 AI 编程 IDE
  • go|channel源码分析
  • 【大模型学习】项目练习:视频文本生成器
  • 【Rust】Rust获取命令行参数以及IO操作
  • 【Redis】Zset 有序集合
  • manus对比ChatGPT-Deep reaserch进行研究类论文数据分析!谁更胜一筹?
  • 【 HarmonyOS 5 入门系列 】鸿蒙HarmonyOS示例项目讲解
  • AWS Transit Gateway实战:构建DMZ隔离架构,实现可控的网络互通
  • 用提示词写程序(3),VSCODE+Claude3.5+deepseek开发edge扩展插件V2
  • 栈与队列:数据结构的有序律动
  • 初识PS(Photoshop)
  • go语言的GMP(基础)
  • 电路图识图基础知识-高、低压供配电系统电气系统的继电自动装置(十三)
  • JDK21深度解密 Day 9:响应式编程模型重构
  • 在 Linux 服务器上无需 sudo 权限解压/打包 .7z 的方法(实用命令)