RandomCrop for COCO-format datasets
Changes to transforms.py
Add a RandomCrop class
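```python
import random


class RandomCrop(object):
    """Randomly crop the image together with its bboxes (and masks, if present)."""

    def __init__(self, output_size):
        # Side length of the square crop, in pixels
        self.output_size = output_size

    def __call__(self, image, target):
        # image is a CHW tensor (this transform runs after ToTensor)
        height, width = image.shape[-2:]
        th = self.output_size
        tw = self.output_size
        if width == tw and height == th:
            return image, target

        # Top-left corner of the crop window
        x = random.randint(0, width - tw)
        y = random.randint(0, height - th)
        image = image[:, y:y+th, x:x+tw]

        # Shift boxes ([x1, y1, x2, y2]) into the crop's coordinate frame.
        # Note: boxes are only shifted here; they are not clipped to the
        # crop window, and boxes that fall outside it are not removed.
        bbox = target["boxes"]
        bbox[:, [0, 2]] = bbox[:, [0, 2]] - x
        bbox[:, [1, 3]] = bbox[:, [1, 3]] - y
        target["boxes"] = bbox
        if "masks" in target:
            target["masks"] = target["masks"][:, y:y+th, x:x+tw]
        return image, target
```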
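The sketch below shows what the box update in `__call__` does. It applies the same arithmetic to a single box given as `[x1, y1, x2, y2]`, using plain Python lists instead of tensors. The helper name `shift_box` is hypothetical and only mirrors the two subtraction lines above.

```python
def shift_box(box, x, y):
    """Mirror of the RandomCrop box update: subtract the crop offset (x, y).

    Like the original transform, no clipping is applied.
    """
    x1, y1, x2, y2 = box
    return [x1 - x, y1 - y, x2 - x, y2 - y]


# Crop window of size 100 whose top-left corner is at x=30, y=0.
crop_x, crop_y = 30, 0

inside = shift_box([40, 10, 80, 50], crop_x, crop_y)      # fully inside the crop
straddling = shift_box([10, 10, 50, 50], crop_x, crop_y)  # starts left of the crop
outside = shift_box([0, 10, 20, 50], crop_x, crop_y)      # entirely left of the crop

print(inside)      # [10, 10, 50, 50]
print(straddling)  # [-20, 10, 20, 50]
print(outside)     # [-30, 10, -10, 50]
```

The second and third boxes end up with coordinates outside the range [0, 100]. The crop therefore returns targets that lie partly or completely outside the cropped image.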
Changes to train.py
Add the RandomCrop transform
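```python
# `transforms` is the project's own transforms.py (edited above),
# whose transforms take and return (image, target) pairs.
data_transform = {
    "train": transforms.Compose([transforms.ToTensor(),
                                 transforms.RandomHorizontalFlip(0.5),
                                 transforms.RandomCrop(1024)]),
    "val": transforms.Compose([transforms.ToTensor()])
}
```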
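`RandomCrop(1024)` assumes every training image is at least 1024 pixels in both height and width. If an image is smaller, `width - tw` or `height - th` becomes negative, and `random.randint` raises a `ValueError` because the range is empty. One way to guard against this is to shrink the crop to the image size. That choice is an assumption of this sketch, not part of the original code, and `crop_window` is a hypothetical helper.

```python
import random


def crop_window(height, width, size, rng=random):
    """Return (x, y, th, tw) for a crop of at most `size` pixels per side.

    Assumption: if the image is smaller than `size` along an axis, the crop
    is shrunk to the image size on that axis instead of failing.
    """
    th = min(size, height)
    tw = min(size, width)
    # randint is inclusive on both ends, so offsets range over 0..(dim - crop)
    x = rng.randint(0, width - tw)
    y = rng.randint(0, height - th)
    return x, y, th, tw


# An 800 x 1333 image (height x width) with a 1024 crop: the height is too
# small, so the crop height is reduced to 800 and y is always 0.
x, y, th, tw = crop_window(800, 1333, 1024)
print(y, th, tw)  # 0 800 1024
```

With this guard the output is no longer guaranteed to be exactly `output_size` x `output_size`. If fixed-size inputs are required, padding small images is the alternative.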
Errors during training:
The loss climbed above 50.0.
Partway through training, once the loss exceeded 100, training failed with a "loss is nan" error.
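The post does not identify the cause of the NaN. One plausible cause, offered here as an assumption, is the box handling in `RandomCrop`. Boxes are shifted by the crop offset but never clipped to the crop window or discarded. The training targets can therefore lie partly or entirely outside the cropped image, which can lead to very large box-regression targets.

The sketch below clips and filters boxes in plain Python. `clip_and_filter` and `min_size` are hypothetical names. For tensors, torchvision provides `torchvision.ops.clip_boxes_to_image` and `torchvision.ops.remove_small_boxes` for the same two steps. Whichever version is used, the kept indices must also be applied to `labels`, `masks`, and any other per-object fields in `target`.

```python
def clip_and_filter(boxes, labels, tw, th, min_size=1.0):
    """Clip already-shifted [x1, y1, x2, y2] boxes to a tw x th crop and
    drop boxes whose width or height falls below min_size.

    Returns the kept boxes, their labels, and the kept indices so that other
    per-object fields (e.g. masks) can be filtered the same way.
    """
    kept_boxes, kept_labels, keep = [], [], []
    for i, (box, label) in enumerate(zip(boxes, labels)):
        x1, y1, x2, y2 = box
        # Clamp every coordinate into the crop window.
        x1, x2 = min(max(x1, 0), tw), min(max(x2, 0), tw)
        y1, y2 = min(max(y1, 0), th), min(max(y2, 0), th)
        # A box entirely outside the crop collapses to zero width or height.
        if x2 - x1 >= min_size and y2 - y1 >= min_size:
            kept_boxes.append([x1, y1, x2, y2])
            kept_labels.append(label)
            keep.append(i)
    return kept_boxes, kept_labels, keep


# Shifted boxes from a 100 x 100 crop: inside, straddling, and fully outside.
boxes = [[10, 10, 50, 50], [-20, 10, 20, 50], [-30, 10, -10, 50]]
labels = [1, 2, 3]
print(clip_and_filter(boxes, labels, tw=100, th=100))
# ([[10, 10, 50, 50], [0, 10, 20, 50]], [1, 2], [0, 1])
```

If every box is dropped, the crop leaves the image with no objects. This sketch leaves that case to the caller, for example by retrying the crop with a new offset.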