当前位置: 首页 > news >正文

使用小尺寸的图像进行逐像素语义分割训练,出现样本不均衡训练效果问题

在使用小尺寸图像进行逐像素语义分割训练时,确实可能出现样本不均衡问题,且这种问题可能比大尺寸图像更显著


1. 小尺寸图像如何加剧样本不均衡?

(1) 局部裁剪导致类别分布偏差
  • 问题:遥感图像中某些类别(如道路、建筑)可能稀疏分布。小尺寸裁剪后,部分训练样本可能完全不含某些类别(例如一块纯农田的补丁),导致模型对这些类别缺乏学习机会。
  • 示例
    • 原图中“道路”占比5%,若裁剪为 256x256 的小图,部分小图中可能完全无道路像素。
    • 极端情况下,某些类别可能仅在极少数小图中出现,形成“长尾分布”。
(2) 批次内类别覆盖不足
  • 问题:小尺寸图像的批训练(batch training)中,若单个批次内缺少某些类别,梯度更新会偏向多数类。
  • 示例:若一个batch中80%的补丁以“植被”为主,模型会倾向于将模糊区域预测为植被。
(3) 像素级不平衡放大
  • 问题:即使原图类别均衡,小尺寸裁剪可能导致局部像素比例失衡。
    • 例如,原图中“水体”占10%,但某个小图中水体可能占90%(河流区域)或0%(干旱区域)。

2. 样本不均衡的典型影响

  • 模型偏向多数类:对高频类别(如植被、背景)过拟合,低频类别(如车辆、道路)漏检。
  • 边界模糊:模型对类别交界处的预测置信度低,导致分割边缘不连续。
  • 评估指标失真:全局指标(如整体准确率)虚高,但关键类别(如灾害损毁区域)的IoU/F1值极低。

3. 针对小尺寸图像的解决方案

(1) 数据层面的优化
  • 定向裁剪(Guided Cropping)
    • 根据类别分布优先裁剪包含稀有类别的小图。
    • 工具:使用滑动窗口统计每个候选补丁的类别比例,筛选包含目标类别的补丁。
  • 过采样(Oversampling)
    • 对包含稀有类别的小图增加采样概率。
    • 例如:若某小图中含“道路”,则其在训练集中的出现次数增加3倍。
  • 数据增强强化
    • 对小图中稀有类别区域进行针对性增强:
      • 局部旋转、缩放、亮度调整(避免全局变换导致稀有目标失真)。
      • 复制-粘贴增强(Copy-Paste):将稀有目标粘贴到其他背景中(如将车辆粘贴到农田补丁上)。
(2) 损失函数设计
  • 加权交叉熵(Weighted Cross-Entropy)
    • 根据类别像素频率反向加权,例如权重与类别频率成反比:
      weight = 1 / (class_freq + epsilon)  # 防止除零
      
  • Focal Loss
    • 抑制易分类样本(如背景)的损失贡献,聚焦难样本(如小目标):
      loss = -α * (1 - p)^γ * log(p)  # α平衡类别,γ聚焦难样本
      
  • Dice Loss / Tversky Loss
    • 直接优化分割重叠区域(IoU),对类别不平衡更鲁棒:
      Dice Loss = 1 - (2*|X∩Y|) / (|X| + |Y|)
      Tversky Loss = 1 - (|X∩Y|) / (|X∩Y| + α|X-Y| + β|Y-X|)  # 调整α,β权衡假阳/假阴
      
(3) 模型架构改进
  • 上下文感知模块
    • 使用空洞卷积(Dilated Convolution)或注意力机制(如SE Block、Non-local Networks),增强模型对稀疏目标的捕捉能力。
  • 多尺度特征融合
    • 通过金字塔池化(PSPNet)或U-Net++结构,融合不同尺度的特征,缓解因小尺寸输入丢失的上下文信息。
  • 辅助监督(Auxiliary Supervision)
    • 在中间层添加辅助损失函数,强制模型关注细粒度特征。
(4) 训练策略调整
  • 小批次大迭代
    • 使用小batch size但增加迭代次数,确保稀有类别在多个epoch中被充分学习。
  • 动态类别权重
    • 根据当前batch内的类别分布实时调整损失权重。
  • 困难样本挖掘(Hard Example Mining)
    • 在每个epoch后,筛选对稀有类别预测误差大的样本,下一轮训练中增加其采样概率。

4. 实验验证建议

  • 监控类别指标:除了整体准确率,跟踪每个类别的IoU、F1-score。
  • 可视化错误样本:检查模型在稀有类别上的失败案例,针对性优化数据或模型。
  • 消融实验:对比不同损失函数、数据增强策略的效果。

小尺寸图像训练会放大样本不均衡问题,但通过定向数据采样、损失函数优化、模型结构改进三者结合,可显著缓解影响。关键是根据任务特点(如目标大小、类别分布)选择组合策略,例如:

  • 稀疏小目标:Focal Loss + Copy-Paste增强 + 空洞卷积。
  • 长尾分布:加权交叉熵 + 过采样 + 动态类别权重。

在 PyTorch 中,虽然没有直接解决语义分割样本不均衡的“万能模块”,但可以通过组合现有模块社区成熟库高效实现解决方案。


1. 数据层面:加权采样与增强

(1) 加权随机采样(WeightedRandomSampler)

PyTorch 内置 WeightedRandomSampler,可对包含稀有类别的图像补丁过采样:

import numpy as npdef compute_weight_for_patch(patch):image, mask = patch# 假设 mask 是一个二维数组,每个像素值表示类别标签# 计算每个类别的像素数量class_counts = np.bincount(mask.flatten())# 计算总像素数量total_pixels = mask.size# 计算每个类别的比例class_ratios = class_counts / total_pixels# 计算所有类别的权重class_weights = 1.0 / (class_ratios + 1e-6)  # 避免除以零,添加一个小的常数# 应用 sigmoid 函数class_weights = 1.0 / (1.0 + np.exp(-class_weights))# 计算样本的权重sample_weight = np.sum(class_weights)print("Total samples weights:", sample_weight)return class_weights
from torch.utils.data import WeightedRandomSampler# 假设 dataset 返回 (image, mask),且每个样本有一个权重 weight
weights = [compute_weight_for_patch(patch) for patch in dataset]  # 根据补丁中稀有类别比例计算权重
sampler = WeightedRandomSampler(weights, num_samples=len(dataset), replacement=True)
dataloader = DataLoader(dataset, batch_size=16, sampler=sampler)
(2) 数据增强库(Albumentations)

Albumentations 提供针对分割任务的增强,支持对特定类别区域增强:

import albumentations as Atransform = A.Compose([A.RandomCrop(256, 256),A.OneOf([A.RandomRotate90(),A.HorizontalFlip(),A.VerticalFlip()]),A.RandomBrightnessContrast(p=0.5),# 对特定类别区域增强(如仅增强“车辆”区域)A.RandomCropNearBBox(p=0.5, max_part_shift=0.3)
])

2. 损失函数:直接调用社区实现

(1) Focal Loss

使用 torchvision.ops 或第三方库:

# 使用 torchvision(需 0.10+ 版本)
from torchvision.ops import sigmoid_focal_lossloss = sigmoid_focal_loss(outputs, targets, alpha=0.25, gamma=2, reduction="mean")# 或自定义多类别 Focal Loss
class FocalLoss(nn.Module):def __init__(self, alpha=0.25, gamma=2):super().__init__()self.alpha = alphaself.gamma = gammadef forward(self, inputs, targets):ce_loss = F.cross_entropy(inputs, targets, reduction="none")pt = torch.exp(-ce_loss)loss = self.alpha * (1 - pt) ** self.gamma * ce_lossreturn loss.mean()
(2) Dice Loss

社区标准实现(或使用 segmentation_models_pytorch 库):

class DiceLoss(nn.Module):def __init__(self, smooth=1e-6):super().__init__()self.smooth = smoothdef forward(self, inputs, targets):inputs = F.softmax(inputs, dim=1)targets = F.one_hot(targets, num_classes=inputs.shape[1]).permute(0, 3, 1, 2)intersection = (inputs * targets).sum()union = inputs.sum() + targets.sum()dice = (2 * intersection + self.smooth) / (union + self.smooth)return 1 - dice
(3) 直接调用 segmentation_models_pytorch 损失函数
import segmentation_models_pytorch as smploss = smp.losses.DiceLoss(mode="multiclass", classes=[0, 1, 2])  # 指定关注类别
loss = smp.losses.FocalLoss(mode="multiclass", normalized=True)   # 归一化版本

3. 模型层面:集成注意力与多尺度模块

(1) 使用预建模型库

segmentation_models_pytorch(SMP)提供即用的模型和模块:

import segmentation_models_pytorch as smpmodel = smp.Unet(encoder_name="resnet34",encoder_weights="imagenet",in_channels=3,classes=5,decoder_attention_type="scse",  # 添加空间-通道注意力
)
(2) 空洞卷积(Dilated Convolution)

直接使用 PyTorch 的 Conv2d 实现:

class DilatedConvBlock(nn.Module):def __init__(self, in_channels, out_channels, dilation_rate=2):super().__init__()self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=dilation_rate, dilation=dilation_rate)self.norm = nn.BatchNorm2d(out_channels)self.act = nn.ReLU()def forward(self, x):return self.act(self.norm(self.conv(x)))# 在 U-Net 的 decoder 中插入空洞卷积块

4. 类别权重计算工具

(1) 自动计算类别权重
from sklearn.utils.class_weight import compute_class_weight# 统计训练集所有像素的类别分布
class_counts = np.bincount(all_pixel_labels.flatten())
class_weights = compute_class_weight(class_weight="balanced", classes=np.arange(num_classes), y=all_pixel_labels.flatten()
)
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(device)# 在损失函数中使用
criterion = nn.CrossEntropyLoss(weight=class_weights)

5. 完整 Pipeline 示例

import torch
from torch.utils.data import DataLoader, WeightedRandomSampler
import segmentation_models_pytorch as smp
import albumentations as A# 1. 定义数据集和采样器
dataset = YourDataset(transform=albumentations_transform)
weights = compute_patch_weights(dataset)  # 根据补丁中目标类别比例计算
sampler = WeightedRandomSampler(weights, len(dataset), replacement=True)
dataloader = DataLoader(dataset, batch_size=16, sampler=sampler)# 2. 定义模型和损失
model = smp.Unet(encoder_name="resnet34", classes=5, decoder_attention_type="scse")
criterion = smp.losses.DiceLoss(mode="multiclass") + smp.losses.FocalLoss(mode="multiclass")# 3. 训练循环
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
for epoch in range(100):for images, masks in dataloader:outputs = model(images)loss = criterion(outputs, masks)loss.backward()optimizer.step()

关键工具总结

问题类型PyTorch 原生支持推荐第三方库(直接调用)
数据采样WeightedRandomSamplerAlbumentations(定向增强)
损失函数自定义(需手写)segmentation_models_pytorch.losses
模型结构手动添加模块(空洞卷积、注意力)segmentation_models_pytorch 预建模型
类别权重计算sklearn.utils.class_weight内置自动统计工具(如 SMP 数据集类)

注意事项

  1. 灵活组合策略:例如同时使用 WeightedRandomSamplerFocal Loss 可能过度偏向少数类,需通过实验调整。
  2. 监控类别指标:使用 torchmetrics 库计算每个类别的 IoU:
    from torchmetrics import JaccardIndex
    iou = JaccardIndex(num_classes=5, task="multiclass")
    iou.update(outputs, targets)
    print(f"IoU: {iou.compute()}")
    
  3. 混合精度训练:使用 torch.cuda.amp 加速训练,缓解显存压力:
    scaler = torch.cuda.amp.GradScaler()
    with torch.cuda.amp.autocast():outputs = model(images)loss = criterion(outputs, masks)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    
http://www.lryc.cn/news/527696.html

相关文章:

  • 0.91英寸OLED显示屏一种具有小尺寸、高分辨率、低功耗特性的显示器件
  • 读书笔记--分布式服务架构对比及优势
  • HTML5 新的 Input 类型详解
  • ESP32-CAM实验集(WebServer)
  • Case逢无意难休——深度解析JAVA中case穿透问题
  • Golang笔记——常用库context和runtime
  • 2000-2020年各省第二产业增加值占GDP比重数据
  • unity商店插件A* Pathfinding Project如何判断一个点是否在导航网格上?
  • Day24-【13003】短文,数据结构与算法开篇,什么是数据元素?数据结构有哪些类型?什么是抽象类型?
  • 富文本 tinyMCE Vue2 组件使用简易教程
  • 强化学习在自动驾驶中的实现与挑战
  • 记录 | MaxKB创建本地AI智能问答系统
  • 特种作业操作之低压电工考试真题
  • [免费]基于Python的Django博客系统【论文+源码+SQL脚本】
  • Cannot resolve symbol ‘XXX‘ Maven 依赖问题的解决过程
  • 我们需要有哪些知识体系,知识体系里面要有什么哪些内容?
  • 什么是vue.js组件开发,我们需要做哪些准备工作?
  • 网络工程师 (3)指令系统基础
  • 第4章 神经网络【1】——损失函数
  • 【Python】第五弹---深入理解函数:从基础到进阶的全面解析
  • 【MQ】如何保证消息队列的高性能?
  • RAG是否被取代(缓存增强生成-CAG)吗?
  • 用C++编写一个2048的小游戏
  • 为何SAP S4系统中要设置MRP区域?MD04中可否同时显示工厂级、库存地点级的数据?
  • Windows10官方系统下载与安装保姆级教程【U盘-官方ISO直装】
  • 第05章 07 切片图等值线代码一则
  • 【深度学习】线性回归的简洁实现
  • 渗透测试技法之口令安全
  • 【R语言】数学运算
  • 小游戏源码开发搭建技术栈和服务器配置流程