当前位置: 首页 > news >正文

CV 医学影像分类、分割、目标检测,之【肝脏分割】项目拆解

CV 医学影像分类、分割、目标检测,之【肝脏分割】项目拆解

    • 第1-4行:导入基础库
    • 第6-8行:科学计算和可视化
    • 第10-15行:图像处理库
    • 第18-26行:数据集构建函数
    • 第27-30行:路径设置和数据加载
    • 第33-41行:glob方式读取文件
    • 第44-65行:自定义Dataset类
    • 第68-86行:数据变换定义
    • 第88-93行:创建数据集和加载器
    • 第95-100行:数据可视化
    • 第103-105行:加载预训练模型
    • 第111-115行:GPU设置
    • 第117-121行:优化器设置
    • 第123行:损失函数
    • 第125-195行:训练函数核心逻辑
    • 第197-209行:训练循环
    • 第211-213行:加载保存的模型
    • 第215-220行:测试集预测
    • 第222-225行:预测结果处理
    • 第227-236行:可视化对比
    • 第238-243行:图像显示调试
    • 第245-247行:CPU模式预测
    • 第249-259行:最终可视化
    • 代码整体架构回顾

 


第1-4行:导入基础库

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data

问1:为什么要导入torch?
答1:PyTorch是深度学习框架,提供张量运算和自动求导功能。

问2:什么是张量?
答2:多维数组,是神经网络中数据的基本表示形式。

问3:nn是什么的缩写?
答3:Neural Network,包含构建神经网络的层和损失函数。

问4:为什么要单独导入functional?
答4:F包含无状态的函数操作,如激活函数、池化等。

问5:torch.utils.data的作用是什么?
答5:提供数据加载和批处理的工具类。

第6-8行:科学计算和可视化

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

问1:numpy和torch的关系是什么?
答1:numpy处理CPU数组,torch可在GPU运算,两者可互相转换。

问2:%matplotlib inline是什么语法?
答2:Jupyter魔法命令,让图表直接显示在notebook中。

问3:为什么不用plt.show()?
答3:inline模式自动显示,无需手动调用show。

第10-15行:图像处理库

import torchvision
from torchvision import transforms
import os
import glob
from PIL import Image
from tqdm import tqdm

问1:torchvision的核心功能是什么?
答1:提供计算机视觉的数据集、模型和图像变换。

问2:transforms具体做什么变换?
答2:图像预处理,如缩放、裁剪、归一化、数据增强。

问3:glob模块的作用?
答3:通过通配符模式查找文件路径。

问4:PIL和OpenCV的区别?
答4:PIL更轻量,适合基础图像IO;OpenCV功能更强但更重。

问5:tqdm是什么的缩写?
答5:阿拉伯语taqaddum(进展),用于显示进度条。

第18-26行:数据集构建函数

def make_dataset(root):imgs = []labels=[]n = len(os.listdir(root)) // 2  #因为数据集中一套训练数据包含有训练图和mask图,所以要除2

问1:为什么要除以2?
答1:每个样本有原图和掩码图两个文件。

问2:什么是mask图?
答2:标注图,标记每个像素属于哪个类别(肝脏或背景)。

问3://和/的区别?
答3://是整除运算,返回整数;/返回浮点数。

    for i in range(n):img = os.path.join(root, "%03d.png" % i)mask = os.path.join(root, "%03d_mask.png" % i)

问4:%03d是什么格式化语法?
答4:格式化为3位数字,不足用0填充(000, 001, 002…)。

问5:为什么用os.path.join而不是字符串拼接?
答5:自动处理不同操作系统的路径分隔符(/或\)。

第27-30行:路径设置和数据加载

root='E:/肝脏CT边缘分割/data/liver/liver/train/'
root_test='E:/肝脏CT边缘分割/data/liver/liver/val/'
train_imgs,train_labels=make_dataset(root)
test_imgs,test_labels=make_dataset(root_test)

问1:为什么分train和val?
答1:训练集用于学习,验证集用于评估泛化能力。

问2:val和test的区别?
答2:val用于调参,test用于最终评估,这里val实际当test用。

第33-41行:glob方式读取文件

pic=glob.glob('E:/肝脏CT边缘分割/data/liver/liver/train/*.png')

问1:*.png通配符匹配什么?
答1:匹配所有.png结尾的文件。

for i in range(len(pic)):if 'mask' in pic[i]:lable_k.append(pic[i])

问2:为什么要检查’mask’字符串?
答2:区分原始图像和标注图像文件。

第44-65行:自定义Dataset类

class LiverDataset(data.Dataset):def __init__(self, imgs,labels, imgs_transform=None,labels_transform=None):

问1:为什么要继承data.Dataset?
答1:PyTorch要求自定义数据集必须实现__len__和__getitem__方法。

问2:transform参数的作用?
答2:数据预处理和增强的函数管道。

    def __getitem__(self, index):imgs= self.imgs[index]labels=self.labels[index]

问3:__getitem__什么时候被调用?
答3:DataLoader迭代时自动调用,获取单个样本。

        imgs_open = Image.open(imgs)labels_open = Image.open(labels)

问4:Image.open返回什么类型?
答4:PIL.Image对象,还未加载到内存。

        imgs_f=transformer(imgs_open)labels_f=transformer_label(labels_open)

问5:transformer做了什么转换?
答5:PIL图像→调整尺寸→转为张量→归一化到[0,1]。

第68-86行:数据变换定义

train_transforms_img = transforms.Compose([transforms.Resize((128,128)),transforms.ToTensor(),transforms.RandomHorizontalFlip(0.2),
])

问1:Compose的作用是什么?
答1:将多个变换串联成管道,依次执行。

问2:(128,128)为什么要统一尺寸?
答2:神经网络要求批次内所有图像尺寸一致。

问3:ToTensor具体做什么?
答3:HWC格式→CHW格式,[0,255]→[0,1]。

问4:RandomHorizontalFlip(0.2)的0.2是什么?
答4:20%概率水平翻转,用于数据增强。

问5:为什么测试集没有RandomHorizontalFlip?
答5:测试时要保持一致性,不做随机增强。

第88-93行:创建数据集和加载器

dl_train=data.DataLoader(train_data,batch_size=16,shuffle=True)

问1:DataLoader的核心功能?
答1:批量加载、打乱顺序、并行处理、内存优化。

问2:batch_size=16意味着什么?
答2:每次迭代返回16个样本组成的批次。

问3:为什么要shuffle?
答3:打破数据顺序相关性,提高训练稳定性。

第95-100行:数据可视化

img,lable= next(iter(dl_train))
img.shape
a=img[0].permute(1,2,0).numpy()

问1:next(iter())是什么模式?
答1:获取迭代器的下一个元素,这里是第一批数据。

问2:permute(1,2,0)在做什么?
答2:CHW→HWC,因为matplotlib需要HWC格式。

问3:为什么要.numpy()?
答3:matplotlib不能直接显示tensor,需要numpy数组。

第103-105行:加载预训练模型

Net=torchvision.models.segmentation.fcn_resnet50(pretrained=False, progress=True, num_classes=2)

问1:FCN是什么的缩写?
答1:Fully Convolutional Network,全卷积网络。

问2:ResNet50的50指什么?
答2:网络有50层深度。

问3:为什么num_classes=2?
答3:二分类:背景和肝脏。

问4:pretrained=False的影响?
答4:随机初始化权重,不用ImageNet预训练权重。

第111-115行:GPU设置

if torch.cuda.is_available():model.to('cuda')

问1:cuda是什么?
答1:NVIDIA的并行计算平台,用GPU加速。

问2:为什么要检查is_available?
答2:防止没有GPU时程序崩溃。

问3:.to(‘cuda’)做了什么?
答3:将模型参数从CPU内存移到GPU显存。

第117-121行:优化器设置

optim = torch.optim.Adam(model.parameters(), lr=0.001)
exp_lr_scheduler = lr_scheduler.StepLR(optim, step_size=7, gamma=0.1)

问1:Adam相比SGD的优势?
答1:自适应学习率,结合动量和RMSprop。

问2:lr=0.001是什么单位?
答2:学习率,每次参数更新的步长系数。

问3:StepLR的作用?
答3:每7个epoch将学习率乘以0.1。

问4:为什么要衰减学习率?
答4:前期快速下降,后期精细调整。

第123行:损失函数

loss_fn = nn.CrossEntropyLoss()

问1:交叉熵损失适用于什么任务?
答1:多分类任务,这里是像素级二分类。

问2:CrossEntropyLoss包含什么操作?
答2:LogSoftmax + NLLLoss。

第125-195行:训练函数核心逻辑

def fit(epoch, model, trainloader, testloader):model.train()

问1:model.train()改变了什么?
答1:启用Dropout和BatchNorm的训练模式。

        y_pred = model(x)y_pred=y_pred['out']

问2:为什么要取[‘out’]?
答2:FCN模型返回字典,'out’是主输出,可能还有辅助输出。

        y= torch.squeeze(y).long()

问3:squeeze()去除什么?
答3:去除大小为1的维度。

问4:.long()转换成什么类型?
答4:64位整型,CrossEntropyLoss要求的标签类型。

        loss = loss_fn(y_pred,y)optim.zero_grad()loss.backward()optim.step()

问5:这四行的执行顺序为什么重要?
答5:前向传播→清零梯度→反向传播→更新参数,顺序错误会累积梯度。

        intersection = torch.logical_and(y, y_pred)union = torch.logical_or(y, y_pred)batch_iou = torch.sum(intersection) / torch.sum(union)

问6:IoU是什么的缩写?
答6:Intersection over Union,交并比。

问7:IoU衡量什么?
答7:预测区域和真实区域的重叠程度,分割质量指标。

    torch.save(static_dict,'./data/checkpoint/{}_train_acc_{}_test_acc_{}.pth'.format(epoch,round(epoch_acc, 3),round(epoch_test_acc,3)))

问8:.pth是什么格式?
答8:PyTorch的模型权重文件格式。

问9:为什么文件名包含准确率?
答9:方便识别最佳模型,无需打开文件查看。

第197-209行:训练循环

epochs = 50
train_loss = []
train_acc = []
test_loss = []
test_acc = []

问1:epoch是什么概念?
答1:完整遍历一次训练集。

问2:50个epoch意味着什么?
答2:每个样本被模型"看"50次。

问3:为什么要记录loss和acc的列表?
答3:绘制学习曲线,诊断过拟合/欠拟合。

for epoch in range(epochs):epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc = fit(epoch,model,dl_train,dl_test)

问4:fit函数返回4个值的顺序是什么?
答4:训练损失→训练准确率→测试损失→测试准确率。

问5:为什么每个epoch都要测试?
答5:监控泛化能力,及时发现过拟合。

    train_loss.append(epoch_loss)train_acc.append(epoch_acc)

问6:append和extend的区别?
答6:append添加单个元素,extend添加可迭代对象的所有元素。

第211-213行:加载保存的模型

my_model = Net
PATH='./data/checkpoint/2_train_acc_0.99_test_acc_0.949.pth'
my_model.load_state_dict(torch.load(PATH))

问1:my_model = Net创建了新模型吗?
答1:只是引用,指向同一个对象。

问2:state_dict包含什么?
答2:所有层的权重和偏置参数。

问3:torch.load做了什么?
答3:反序列化,从文件恢复张量数据。

问4:为什么用load_state_dict而不是直接赋值?
答4:确保参数正确映射到模型结构。

问5:0.99训练准确率vs 0.949测试准确率说明什么?
答5:轻微过拟合,但泛化能力良好。

第215-220行:测试集预测

image, mask = next(iter(dl_test))
image=image.to('cuda')
pred_mask = my_model(image)
pred_mask=pred_mask['out']

问1:为什么只将image移到cuda?
答1:mask用于对比显示,在CPU上处理即可。

问2:模型在GPU,数据在CPU会怎样?
答2:报错,张量和模型必须在同一设备。

mask=torch.squeeze(mask)
mask.shape

问3:squeeze前后shape变化是什么?
答3:可能从[16,1,128,128]变为[16,128,128]。

问4:为什么mask需要squeeze?
答4:去除通道维度,便于可视化。

第222-225行:预测结果处理

pred_mask
pred_mask.shape
pred_mask=pred_mask.cpu()

问1:为什么要.cpu()?
答1:matplotlib不能直接处理GPU张量。

问2:.cpu()和.numpy()能连用吗?
答2:可以,.cpu().numpy()是常见模式。

问3:pred_mask.shape可能是什么?
答3:[16, 2, 128, 128],批次×类别×高×宽。

第227-236行:可视化对比

num=3
plt.figure(figsize=(10, 10))
for i in range(num):plt.subplot(num, 3, i*num+1)

问1:subplot(num, 3, inum+1)的参数含义?
答1:num行3列,第i
num+1个子图。

问2:inum+1的计算逻辑错了吗?
答2:是的,应该是i
3+1,这是代码bug。

问3:figsize=(10,10)的单位是什么?
答3:英寸,影响图像显示大小。

    plt.imshow(image[i].permute(1,2,0).cpu().numpy())

问4:permute为什么必须在cpu()之前?
答4:不是必须,但在GPU上permute更快。

    plt.imshow(mask[i].cpu().numpy())

问5:mask为什么不需要permute?
答5:mask已经是[H,W]格式,单通道。

    plt.imshow(torch.argmax(pred_mask[i].permute(1,2,0), axis=-1).detach().numpy())

问6:argmax在做什么?
答6:选择概率最大的类别索引。

问7:axis=-1是哪个维度?
答7:最后一个维度,这里是通道维。

问8:为什么需要detach()?
答8:断开计算图,防止梯度计算。

问9:detach()和requires_grad=False的区别?
答9:detach创建新张量,requires_grad修改原张量属性。

第238-243行:图像显示调试

image, mask = next(iter(dl_test))
mask
plt.figure(figsize=(10, 10))
mask[1]
image=image*255
plt.imshow(image[1].permute(1,2,0).cpu().numpy())

问1:image*255的目的?
答1:[0,1]范围恢复到[0,255]范围。

问2:为什么单独打印mask和mask[1]?
答2:调试查看数据结构和值。

问3:这段代码有问题吗?
答3:有,matplotlib的imshow期望[0,1]或[0,255]的uint8。

第245-247行:CPU模式预测

image, mask = next(iter(dl_test))
my_model=my_model.cpu()
pred_mask = my_model(image)

问1:为什么要切换到CPU?
答1:可能是为了调试或没有GPU环境。

问2:CPU预测的缺点?
答2:速度慢,特别是批量数据。

第249-259行:最终可视化

plt.figure(figsize=(10, 10))
for i in range(num):plt.subplot(num, 3, i*num+1)image=image/255plt.imshow(image[i].permute(1,2,0).cpu().numpy())

问1:image/255这里有问题吗?
答1:有,image已经是[0,1],再除255会变成[0,0.004]。

问2:这个错误会导致什么?
答2:图像几乎全黑。

    mask=mask/255plt.imshow(mask[i].permute(1,2,0).cpu().numpy())

问3:mask/255的必要性?
答3:mask如果是0/1二值,除255后会接近全黑。

    plt.imshow(torch.argmax(pred_mask[i].permute(1,2,0), axis=-1).detach().numpy())

问4:这里为什么没有处理pred_mask的值范围?
答4:argmax返回0或1,matplotlib能正确显示。

代码整体架构回顾

问1:整个pipeline的数据流是什么?
答1:图像文件→PIL读取→Transform→Tensor→模型→预测→可视化。

问2:训练时的梯度流是什么?
答2:损失→反向传播→梯度累积→优化器更新→清零。

问3:这个代码的核心任务是什么?
答3:医学图像语义分割,识别CT图中的肝脏区域。

问4:为什么选择FCN-ResNet50?
答4:FCN适合像素级分类,ResNet解决深度网络退化。

问5:代码中有哪些明显的bug?
答5

  • subplot索引计算错误(i*num+1)
  • 图像归一化重复(image/255)
  • mask不必要的维度操作

问6:如何改进这个代码?
答6

  • 添加早停机制
  • 使用更好的数据增强
  • 添加学习率调度
  • 保存最佳模型而非所有模型
  • 使用Dice Loss替代交叉熵

如何选择合适的模型?

答14:权衡表:

  • 速度优先:LRASPP > FCN > DeepLabV3
  • 精度优先:DeepLabV3 > FCN > LRASPP
  • 显存限制:LRASPP < FCN < DeepLabV3

问1:改模型,为什么看起来只需要改一行?

答1:PyTorch的torchvision.models.segmentation中的模型都遵循相同的接口规范。

问2:这个接口规范是什么?

答2:输入是[B,C,H,W]的张量,输出是包含’out’键的字典。

# 1. 模型替换
Net=torchvision.models.segmentation.deeplabv3_resnet50(pretrained=False, progress=True, num_classes=2)# 2. 可能需要调整batch_size
dl_train=data.DataLoader(train_data,batch_size=8,shuffle=True)  # 16→8# 3. 可能需要调整学习率
optim = torch.optim.Adam(model.parameters(), lr=0.0005)  # 0.001→0.0005# 4. 可能需要改变调度器
from torch.optim.lr_scheduler import PolynomialLR
exp_lr_scheduler = PolynomialLR(optim, total_iters=epochs, power=0.9)# 5. 如果显存不足,还可以添加梯度累积
accumulation_steps = 2  # 每2步更新一次
http://www.lryc.cn/news/619336.html

相关文章:

  • windows常用的快捷命令
  • 机器学习实战·第三章 分类(2)
  • docker 容器内编译onnxruntime
  • git clone 支持在命令行临时设置proxy
  • CV 医学影像分类、分割、目标检测,之【腹腔多器官语义分割】项目拆解
  • 何解决PyCharm中pip install安装Python报错ModuleNotFoundError: No module named ‘json’问题
  • Video_AVI_Packet(2)
  • 基于RTSP|RTMP低延迟视频链路的多模态情绪识别系统构建与实现
  • 日志数据链路的 “搬运工”:Flume 分布式采集的组件分工与原理
  • 进阶向:Python编写自动化邮件发送程序
  • Jenkins一直无法启动,怎么办?
  • 论文分享 | Flashboom:一种声东击西攻击手段以致盲基于大语言模型的代码审计
  • 守拙以致远:个人IP的长青之道|创客匠人
  • Hive 创建事务表的方法
  • 自建知识库,向量数据库 体系建设(四)之文本向量与相似度计算——仙盟创梦IDE
  • java中list的api详细使用
  • 无人机航拍数据集|第15期 无人机人员目标检测YOLO数据集4923张yolov11/yolov8/yolov5可训练
  • pt-online-schema-change 全解析:MySQL 表结构变更的安全之道
  • clickhouse集群的安装与部署
  • Vue3 使用 echarts 甘特图(GanttChart)
  • Java -- Vector底层结构-- ArrayList和LinkedList的比较
  • C++主流string的使用
  • 工业元宇宙:迈向星辰大海的“玄奘之路”
  • C++ 类和对象4---(初始化列表,类型转化,static成员)
  • nuxt相比于vue的优点
  • java-泛型接口
  • C++多态:理解面向对象的“一个接口,多种实现”
  • 智能算法流程图在临床工作中的编程视角系统分析
  • 【算法】位运算经典例题
  • 论“证明的终点”:从“定义域 = 正确”看西方体系的自证困境