模板代码概述
1. 数据集函数
class MyDataset(Dataset):def __init__(self, img_id_list, IMG_SIZE, mode='train', augmentation=False):"""传参,定义参数1. 数据集列表,- 本地数据,文件名/图片名- API,图片ID2. 图片读取尺寸3. 训练模式or推理模式4. 是否做Data augmentation..."""passdef __getitem__(self, idx):"""读取下一个样本1. 读取本地图片,或读API接口获取base64格式图片2. 预处理, 如变换图片尺寸3. 若训练集,读取Mask图片4. Data augmentation"""passdef __len__(self):"""定义样本个数"""pass
def prepare_trainset():"""1. 切分数据集,训练集/验证集2. 定义MyDataset训练集、MyDataset验证集3. 定义Pytorch的DataLoadertrain_dl = DataLoader(train_dataset,batch_size=16,shuffle=True,#sampler=sampler,num_workers=8,drop_last=True)val_dl = DataLoader(val_dataset,batch_size=16,shuffle=False,#sampler=sampler,num_workers=8,drop_last=True)"""pass
2. Utils函数
3. 分割的评估函数

4. 训练脚本
def run_training():"""training pipline1. 读取network- 加载预训练模型- 定义训练全部层的参数/哪几层参数- 定义学习率/为每一层定义学习率- 定义优化函数optimizer、学习率变化方案scheduler- 2. 训练N_EPOCH次迭代,每一个迭代内:- 用DataLoader循环读取训练集上每一个batch数据(N个图片、N个mask)- 将N个图片传入network,输出模型最后一层的预测(sigmoid概率)- 计算这个batch上的loss、metric,并存下来- 反向传播,更新参数(.backward())(是否梯度累加)- 计算所有batch上loss、metric的总体均值,代表这个EPOCH- 用DataLoader循环读取验证集上每一个batch数据,与以上操作相似,计算验证集上的loss、metric,用于决定哪一个EPOCH停止训练- 更新logging、保存checkpoint"""pass
5. Unet介绍
