当前位置: 首页 > news >正文

19.数据增强技术

19.1 图像水平翻转与垂直翻转

import torch
import torchvision
import matplotlib.pyplot as plt
from PIL import Image
def apply(img, aug, num_rows=2, num_cols=4, scale=1.5):Y = [aug(img) for _ in range(num_rows * num_cols)]figsize = (num_cols * scale, num_rows * scale)fig, axes = plt.subplots(num_rows, num_cols, figsize=figsize)axes = axes.flatten()  for ax, y in zip(axes, Y):ax.imshow(y)ax.axis('off')plt.tight_layout()plt.show()
apply(image,torchvision.transforms.RandomHorizontalFlip())#左右翻转
apply(image,torchvision.transforms.RandomVerticalFlip())#上下翻转

在这里插入图片描述

19.1 图像随机裁剪与色彩调整

import torch
import torchvision
import matplotlib.pyplot as plt
from PIL import Image
def apply(img, aug, num_rows=2, num_cols=4, scale=1.5):Y = [aug(img) for _ in range(num_rows * num_cols)]figsize = (num_cols * scale, num_rows * scale)fig, axes = plt.subplots(num_rows, num_cols, figsize=figsize)axes = axes.flatten()  for ax, y in zip(axes, Y):ax.imshow(y)ax.axis('off')plt.tight_layout()plt.show()
apply(image,torchvision.transforms.RandomResizedCrop(size=(200,200),scale=(0.2,1),ratio=(1,2)))#随机裁剪
apply(image,torchvision.transforms.ColorJitter(brightness=0.5,contrast=0.5,saturation=0.5,hue=0.5))#明亮度/对比度等调整

在这里插入图片描述

19.3 图像整体增强变换

import torch
import torchvision
import matplotlib.pyplot as plt
from PIL import Image
def apply(img, aug, num_rows=2, num_cols=4, scale=1.5):Y = [aug(img) for _ in range(num_rows * num_cols)]figsize = (num_cols * scale, num_rows * scale)fig, axes = plt.subplots(num_rows, num_cols, figsize=figsize)axes = axes.flatten()  for ax, y in zip(axes, Y):ax.imshow(y)ax.axis('off')plt.tight_layout()plt.show()
loc_aug=torchvision.transforms.RandomHorizontalFlip()
shape_aug = torchvision.transforms.RandomResizedCrop((200, 200), scale=(0.1, 1), ratio=(0.5, 2))
color_aug = torchvision.transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)
augs = torchvision.transforms.Compose([loc_aug,color_aug, shape_aug])
apply(image, augs)

在这里插入图片描述

19.4 基于CiFar-10数据集的图像增强效果对比

在这里插入图片描述

################################################################################################################
#ResNet
################################################################################################################
import torch
import torchvision
from torch import nn
import matplotlib.pyplot as plt
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import torchvision.models as models
from sklearn.metrics import accuracy_score
from torch.nn import functional as F
plt.rcParams['font.family']=['Times New Roman']
class Reshape(torch.nn.Module):def forward(self,x):return x.view(-1,1,28,28)#[bs,1,28,28]
def plot_metrics(train_loss_list, train_acc_list, test_acc_list, title='Training Curve'):epochs = range(1, len(train_loss_list) + 1)plt.figure(figsize=(4, 3))plt.plot(epochs, train_loss_list, label='Train Loss')plt.plot(epochs, train_acc_list, label='Train Acc',linestyle='--')plt.plot(epochs, test_acc_list, label='Test Acc', linestyle='--')plt.xlabel('Epoch')plt.ylabel('Value')plt.title(title)plt.legend()plt.grid(True)plt.tight_layout()plt.show()
def train_model(model,train_data,test_data,num_epochs):train_loss_list = []train_acc_list = []test_acc_list = []for epoch in range(num_epochs):total_loss=0total_acc_sample=0total_samples=0loop1=tqdm(train_data,desc=f"EPOCHS[{epoch+1}/{num_epochs}]")for X,y in loop1:#X=X.reshape(X.shape[0],-1)#print(X.shape)X=X.to(device)y=y.to(device)y_hat=model(X)loss=CEloss(y_hat,y)optimizer.zero_grad()loss.backward()optimizer.step()#loss累加total_loss+=loss.item()*X.shape[0]y_pred=y_hat.argmax(dim=1).detach().cpu().numpy()y_true=y.detach().cpu().numpy()total_acc_sample+=accuracy_score(y_pred,y_true)*X.shape[0]#保存样本数total_samples+=X.shape[0]test_acc_samples=0test_samples=0loop2=tqdm(test_data,desc=f"EPOCHS[{epoch+1}/{num_epochs}]")for X,y in loop2:X=X.to(device)y=y.to(device)#X=X.reshape(X.shape[0],-1)y_hat=model(X)y_pred=y_hat.argmax(dim=1).detach().cpu().numpy()y_true=y.detach().cpu().numpy()test_acc_samples+=accuracy_score(y_pred,y_true)*X.shape[0]#保存样本数test_samples+=X.shape[0]avg_train_loss=total_loss/total_samplesavg_train_acc=total_acc_sample/total_samplesavg_test_acc=test_acc_samples/test_samplestrain_loss_list.append(avg_train_loss)train_acc_list.append(avg_train_acc)test_acc_list.append(avg_test_acc)print(f"Epoch {epoch+1}: Loss: {avg_train_loss:.4f},Trian Accuracy: {avg_train_acc:.4f},test Accuracy: {avg_test_acc:.4f}")plot_metrics(train_loss_list, train_acc_list, test_acc_list)return model
def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)
################################################################################################################
#这里选取一个是翻转,一个是归一化,一个是调整明亮度,最后是tensor化
transforms_train=transforms.Compose([transforms.RandomHorizontalFlip(),transforms.ColorJitter(brightness=0.5),transforms.ToTensor(),transforms.Normalize((0.5,),(0.5,))])
transforms_test=transforms.Compose([transforms.RandomHorizontalFlip(),transforms.ColorJitter(brightness=0.5),transforms.ToTensor(),transforms.Normalize((0.5,),(0.5,))])
train_img=torchvision.datasets.CIFAR10(root="./data",train=True,transform=transforms_train,download=True)
test_img=torchvision.datasets.CIFAR10(root="./data",train=False,transform=transforms_test,download=True)
train_data=DataLoader(train_img,batch_size=128,num_workers=4,shuffle=True)
test_data=DataLoader(test_img,batch_size=128,num_workers=4,shuffle=False)
################################################################################################################
device=torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
model=models.resnet50(pretrained=True)#直接调用ResNet-50进行训练
model.fc=nn.Linear(model.fc.in_features,10)
model.to(device)
model.apply(init_weights)
optimizer=torch.optim.SGD(model.parameters(),lr=0.05,momentum=0.9)
CEloss=nn.CrossEntropyLoss()
model=train_model(model,train_data,test_data,num_epochs=20)
################################################################################################################

在这里插入图片描述

http://www.lryc.cn/news/588508.html

相关文章:

  • LeetCode 692题解 | 前K个高频单词
  • JAVA进阶--JVM
  • C#——数据与变量
  • Brooks 低温泵On-Board Cryopump 安装和维护手法Installation and Maintenance Manual
  • 文献查找任务及其方法
  • 信息学奥赛一本通 1549:最大数 | 洛谷 P1198 [JSOI2008] 最大数
  • 7.14练习案例总结
  • YOLOv11开发流程
  • 数字化红头文件生成工具:提升群聊与团队管理效率的创新方案
  • H264的帧内编码和帧间编码
  • 每日mysql
  • RAG索引流程中的文档解析:工业级实践方案与最佳实践
  • 【Linux网络】:HTTP(应用层协议)
  • 学习软件测试的第十五天
  • 【DOCKER】-6 docker的资源限制与监控
  • Linux操作系统之信号:信号的产生
  • 深入学习前端 Proxy 和 Reflect:现代 JavaScript 元编程核心
  • Modbus 开发工具实战:ModScan32 与 Wireshark 抓包分析(二)
  • Swift 解 LeetCode 326:两种方法判断是否是 3 的幂,含循环与数学技巧
  • [硬件电路-21]:模拟信号处理运算与数字信号处理运算的详细比较
  • 无人机迫降模式模块运行方式概述!
  • ICMP隧道工具完全指南:原理、实战与防御策略
  • Datawhale AI夏令营大模型 task2.1
  • 【科研绘图系列】R语言绘制世界地图
  • 硬盘爆满不够用?这个免费神器帮你找回50GB硬盘空间
  • 【React Natve】NetworkError 和 TouchableOpacity 组件
  • 网络编程(TCP连接)
  • 代理模式详解:代理、策略与模板方法模式
  • 暑期自学嵌入式——Day02(C语言阶段)
  • PyTorch张量(Tensor)创建的方式汇总详解和代码示例