当前位置: 首页 > news >正文

[oneAPI] 手写数字识别-GAN

[oneAPI] 手写数字识别-GAN

  • 手写数字识别
    • 参数与包
    • 加载数据
    • 模型
    • 训练过程
    • 结果
  • oneAPI

比赛:https://marketing.csdn.net/p/f3e44fbfe46c465f4d9d6c23e38e0517
Intel® DevCloud for oneAPI:https://devcloud.intel.com/oneapi/get_started/aiAnalyticsToolkitSamples/

手写数字识别

使用了pytorch以及Intel® Optimization for PyTorch,通过优化扩展了 PyTorch,使英特尔硬件的性能进一步提升,让手写数字识别问题更加的快速高效
在这里插入图片描述

使用MNIST数据集,该数据集包含了一系列以黑白图像表示的手写数字,每个图像的大小为28x28像素,数据集组成如下:

  • 训练集:包含60,000个图像和标签,用于训练模型。
  • 测试集:包含10,000个图像和标签,用于测试模型的性能。

每个图像都被标记为0到9之间的一个数字,表示图像中显示的手写数字。这个数据集常常被用来验证图像分类模型的性能,特别是在计算机视觉领域。

参数与包

import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_imageimport intel_extension_for_pytorch as ipex# Device configuration
device = torch.device('xpu' if torch.cuda.is_available() else 'cpu')# Hyper-parameters
latent_size = 64
hidden_size = 256
image_size = 784
num_epochs = 200
batch_size = 100
sample_dir = 'samples'

加载数据

# Create a directory if not exists
if not os.path.exists(sample_dir):os.makedirs(sample_dir)# Image processing
# transform = transforms.Compose([
#                 transforms.ToTensor(),
#                 transforms.Normalize(mean=(0.5, 0.5, 0.5),   # 3 for RGB channels
#                                      std=(0.5, 0.5, 0.5))])
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.5],  # 1 for greyscale channelsstd=[0.5])])# MNIST dataset
mnist = torchvision.datasets.MNIST(root='./data/',train=True,transform=transform,download=True)# Data loader
data_loader = torch.utils.data.DataLoader(dataset=mnist,batch_size=batch_size,shuffle=True)

模型

# Discriminator
D = nn.Sequential(nn.Linear(image_size, hidden_size),nn.LeakyReLU(0.2),nn.Linear(hidden_size, hidden_size),nn.LeakyReLU(0.2),nn.Linear(hidden_size, 1),nn.Sigmoid())# Generator 
G = nn.Sequential(nn.Linear(latent_size, hidden_size),nn.ReLU(),nn.Linear(hidden_size, hidden_size),nn.ReLU(),nn.Linear(hidden_size, image_size),nn.Tanh())

训练过程

# Device setting
D = D.to(device)
G = G.to(device)# Binary cross entropy loss and optimizer
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)'''
Apply Intel Extension for PyTorch optimization against the model object and optimizer object.
'''
D, d_optimizer = ipex.optimize(D, optimizer=d_optimizer)
G, g_optimizer = ipex.optimize(G, optimizer=g_optimizer)def denorm(x):out = (x + 1) / 2return out.clamp(0, 1)def reset_grad():d_optimizer.zero_grad()g_optimizer.zero_grad()# Start training
total_step = len(data_loader)
for epoch in range(num_epochs):for i, (images, _) in enumerate(data_loader):images = images.reshape(batch_size, -1).to(device)# Create the labels which are later used as input for the BCE lossreal_labels = torch.ones(batch_size, 1).to(device)fake_labels = torch.zeros(batch_size, 1).to(device)# ================================================================== ##                      Train the discriminator                       ## ================================================================== ## Compute BCE_Loss using real images where BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x))# Second term of the loss is always zero since real_labels == 1outputs = D(images)d_loss_real = criterion(outputs, real_labels)real_score = outputs# Compute BCELoss using fake images# First term of the loss is always zero since fake_labels == 0z = torch.randn(batch_size, latent_size).to(device)fake_images = G(z)outputs = D(fake_images)d_loss_fake = criterion(outputs, fake_labels)fake_score = outputs# Backprop and optimized_loss = d_loss_real + d_loss_fakereset_grad()d_loss.backward()d_optimizer.step()# ================================================================== ##                        Train the generator                         ## ================================================================== ## Compute loss with fake imagesz = torch.randn(batch_size, latent_size).to(device)fake_images = G(z)outputs = D(fake_images)# We train G to maximize log(D(G(z)) instead of minimizing log(1-D(G(z)))# For the reason, see the last paragraph of section 3. https://arxiv.org/pdf/1406.2661.pdfg_loss = criterion(outputs, real_labels)# Backprop and optimizereset_grad()g_loss.backward()g_optimizer.step()if (i + 1) % 200 == 0:print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'.format(epoch, num_epochs, i + 1, total_step, d_loss.item(), g_loss.item(),real_score.mean().item(), fake_score.mean().item()))# Save real imagesif (epoch + 1) == 1:images = images.reshape(images.size(0), 1, 28, 28)save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))# Save sampled imagesfake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch + 1)))# Save the model checkpoints 
torch.save(G.state_dict(), 'G.ckpt')
torch.save(D.state_dict(), 'D.ckpt')

结果

在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

oneAPI

import intel_extension_for_pytorch as ipex# Device configuration
device = torch.device('xpu' if torch.cuda.is_available() else 'cpu')# Device setting
D = D.to(device)
G = G.to(device)# Binary cross entropy loss and optimizer
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)'''
Apply Intel Extension for PyTorch optimization against the model object and optimizer object.
'''
D, d_optimizer = ipex.optimize(D, optimizer=d_optimizer)
G, g_optimizer = ipex.optimize(G, optimizer=g_optimizer)
http://www.lryc.cn/news/127108.html

相关文章:

  • 爬虫逆向实战(十五)--阿某某营登录
  • 【计组】校验码(奇偶校验码、海明校验码、CRC)
  • File Inclusion
  • 函数性能探测:更简单高效的 Serverless 规格选型方案
  • 嵌入式Linux Qt5 (C++)开发栏目概述
  • C语言“牵手”微店商品详情数据方法,微店商品详情API接口申请指南
  • C++ volatile
  • 空洞卷积学习笔记
  • WPF中的UseLayoutRounding和SnapsToDevicePixels
  • Windows权限维持—自启动映像劫持粘滞键辅助屏保后门WinLogon
  • Mysql之explain详解
  • 每天一道leetcode:1926. 迷宫中离入口最近的出口(图论中等广度优先遍历)
  • Mysql_5.7下载安装与配置基础操作教程
  • 【业务功能篇68】电商项目相关核心设计
  • 微信开发之一键退出群聊的技术实现
  • 〔012〕Stable Diffusion 之 中文提示词自动翻译插件 篇
  • 【C++】一文带你初识C++继承
  • SDK是什么,SDK和API有什么区别
  • golang中使用chan控制协程并发简单事例
  • 好用画流程图软件推荐 excalidraw
  • 【RP2040】香瓜树莓派RP2040之搭建开发环境(windows)
  • 基于springboot线上礼品商城
  • 开源,微信小程序 美食便签地图(FoodNoteMap)的设计与开发
  • kubernetes集群(k8s)之安装部署Calico 网络
  • 【C/C++】C++11 Lambda 表达式捕获变量使用技巧
  • 大势智慧软硬件技术答疑第八期
  • jvm-jvm与java体系结构
  • Three.js 实现材质边缘通道发光效果
  • 【MFC】10.MFC六大机制:RTTI(运行时类型识别),动态创建机制,窗口切分,子类化-笔记
  • ui设计师个人年终工作总结2020最新范文5篇