当前位置: 首页 > news >正文

变分自编码器VAE的Pytorch实现

一、导入第三方库

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']  # 设置中文字体
plt.rcParams['axes.unicode_minus'] = False   # 正常显示负号
from torchvision import transforms
import os
from PIL import Image
from torch.utils.data import Dataset,DataLoader

二、手写数字数据集准备

#手写数字数据集
class MINISTDataset(Dataset):
    """Handwritten-digit image dataset whose labels are encoded in the file names.

    Each file name is split on "_" and the third field, truncated at its first
    ".", is parsed as the integer label (so ``a_b_7.jpg`` yields label ``7``).

    Args:
        files: Image file names, relative to ``root_dir``.
        root_dir: Directory that contains the image files.
        transform: Optional callable applied to every loaded PIL image.

    Raises:
        IndexError: If a file name has fewer than three "_"-separated fields.
        ValueError: If the label field cannot be parsed as an integer.
    """

    def __init__(self, files, root_dir, transform=None):
        self.files = files
        self.root_dir = root_dir
        self.transform = transform
        # Labels are parsed eagerly, so a malformed file name fails here
        # rather than in the middle of training.
        self.labels = [int(f.split("_")[2].split(".")[0]) for f in files]

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for position ``idx``.

        The image is opened as 8-bit grayscale ("L") and passed through
        ``transform`` when one was supplied.
        """
        img_path = os.path.join(self.root_dir, self.files[idx])
        # Image.open is lazy and keeps the file handle open; the context
        # manager releases it, while convert() returns an independent copy.
        with Image.open(img_path) as raw:
            img = raw.convert("L")
        if self.transform:
            img = self.transform(img)
        return img, self.labels[idx]

三、VAE模型的pytorch代码

#编码器
class Encoder(nn.Module):
    """Convolutional encoder that maps 1x28x28 images to latent Gaussian parameters.

    Two conv/ReLU/max-pool stages are followed by a shared hidden layer and
    two parallel linear heads, one for the mean and one for the log-variance.

    Args:
        latent_dim: Size of the latent vector produced by each head
            (defaults to 80, the value the original code hard-coded).
    """

    def __init__(self, latent_dim=80):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 10, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(10, 20, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        # For a 28x28 input the spatial size goes 28 -> 24 -> 12 -> 8 -> 4,
        # so the flattened feature size is 20 * 4 * 4 = 320.
        self.fc1 = nn.Linear(320, 160)
        self.fc21 = nn.Linear(160, latent_dim)  # mean head
        self.fc22 = nn.Linear(160, latent_dim)  # log-variance head
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``(mu, log_var)``, each of shape ``(batch, latent_dim)``."""
        batch_size = x.size(0)
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(batch_size, -1)
        h = self.relu(self.fc1(x))
        mu = self.fc21(h)
        log_var = self.fc22(h)
        return mu, log_var


# Decoder
class Decoder(nn.Module):
    """Fully connected decoder that maps a latent vector to a flattened 28x28 image.

    The final Sigmoid keeps every output value in [0, 1].

    Args:
        latent_dim: Size of the input latent vector (defaults to 80, the
            value the original code hard-coded).
    """

    def __init__(self, latent_dim=80):
        super().__init__()
        self.main = nn.Sequential(
            nn.Linear(latent_dim, 160),
            nn.ReLU(),
            nn.Linear(160, 320),
            nn.ReLU(),
            nn.Linear(320, 28 * 28),
            nn.Sigmoid(),
        )

    def forward(self, z):
        """Decode ``z`` (last dimension ``latent_dim``) into 784 pixel values."""
        return self.main(z)


# Variational autoencoder
class VAE(nn.Module):
    """Variational autoencoder that chains an encoder, a sampling step and a decoder.

    Args:
        encoder: Callable returning ``(mu, log_var)`` for an input batch.
        decoder: Callable mapping a latent sample ``z`` to a reconstruction.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def reparameterize(self, mu, log_var):
        """Sample ``z = mu + eps * std`` with ``eps`` drawn from N(0, I).

        Writing the sample this way keeps it differentiable with respect to
        ``mu`` and ``log_var``.
        """
        std = torch.exp(0.5 * log_var)  # standard deviation from log-variance
        eps = torch.randn_like(std)     # noise from a standard normal
        return mu + eps * std

    def forward(self, x):
        """Return ``(reconstruction, mu, log_var)`` for the input batch ``x``."""
        mu, log_var = self.encoder(x)
        z = self.reparameterize(mu, log_var)
        return self.decoder(z), mu, log_var

四、主程序

if __name__ == "__main__":
    # Resize to 28x28 and convert to a tensor; ToTensor scales pixels to
    # [0, 1], which BCELoss requires of its targets.
    # Named `transform` so the imported `transforms` module is not shadowed.
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
    ])

    # Paths
    base_dir = 'C:\\Users\\Administrator\\PycharmProjects\\CNN'
    train_dir = os.path.join(base_dir, "minist_train")
    gen_dir = "C:\\Users\\Administrator\\PycharmProjects\\CNN\\vae_gen_img"
    loss_fig_path = "C:\\Users\\Administrator\\PycharmProjects\\CNN\\vae_gen_loss.jpg"
    # savefig does not create missing directories.
    os.makedirs(gen_dir, exist_ok=True)

    # Collect the image file names in the training folder
    train_files = [f for f in os.listdir(train_dir) if f.endswith('.jpg')]

    # Dataset and data loader
    train_dataset = MINISTDataset(train_files, train_dir, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

    # Hyper-parameters
    num_epochs = 50
    lr = 0.001

    # Model, loss and optimizer
    encoder = Encoder()
    decoder = Decoder()
    vae = VAE(encoder, decoder)
    criterion = nn.BCELoss()
    optimizer = optim.Adam(vae.parameters(), lr=lr, betas=(0.5, 0.999))

    # Average loss of every epoch, for the loss curve
    epoch_loss = []
    for epoch in range(num_epochs):
        total_loss = 0.0
        for images, _ in train_loader:
            optimizer.zero_grad()
            outputs, mu, log_var = vae(images)
            # Reconstruction loss against the flattened input, plus the KL
            # divergence to N(0, I); both are means, and KL is weighted by 0.1.
            reconstruction_loss = criterion(outputs, images.view(images.size(0), -1))
            kl_divergence = -0.5 * torch.mean(1 + log_var - mu.pow(2) - log_var.exp())
            loss = reconstruction_loss + 0.1 * kl_divergence
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(train_loader)
        epoch_loss.append(avg_loss)
        print("Epoch", epoch, "  Loss:", avg_loss)

        # Every 5 epochs, decode 9 random latent vectors into a 3x3 grid
        if (epoch + 1) % 5 == 0:
            with torch.no_grad():
                z = torch.randn(9, 80)
                samples = decoder(z).view(-1, 28, 28)
            plt.figure(figsize=(9, 9))
            for i in range(9):
                plt.subplot(3, 3, i + 1)
                plt.imshow(samples[i].numpy(), cmap="gray")
                plt.axis("off")
            name = f"vae_gen_img_{epoch}.jpg"
            gen_name = os.path.join(gen_dir, name)
            plt.savefig(gen_name, dpi=300)
            plt.close()

    # Plot the loss curve
    plt.figure(figsize=(12, 6))
    # The label gives plt.legend() an entry to show.
    plt.plot(epoch_loss, color="tomato", label="loss")
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.title("损失函数曲线图")
    plt.legend()
    plt.grid()
    plt.savefig(loss_fig_path)
    plt.close()

五、运行结果

5.1 损失函数曲线图

5.2 生成的图像

这里只展示一部分

vae_gen_img_4.jpg

vae_gen_img_29.jpg

vae_gen_img_49.jpg

六、VAE的完整代码

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']  # 设置中文字体
plt.rcParams['axes.unicode_minus'] = False   # 正常显示负号
from torchvision import transforms
import os
from PIL import Image
from torch.utils.data import Dataset,DataLoader#手写数字数据集
class MINISTDataset(Dataset):
    """Handwritten-digit image dataset whose labels are encoded in the file names.

    Each file name is split on "_" and the third field, truncated at its first
    ".", is parsed as the integer label (so ``a_b_7.jpg`` yields label ``7``).

    Args:
        files: Image file names, relative to ``root_dir``.
        root_dir: Directory that contains the image files.
        transform: Optional callable applied to every loaded PIL image.

    Raises:
        IndexError: If a file name has fewer than three "_"-separated fields.
        ValueError: If the label field cannot be parsed as an integer.
    """

    def __init__(self, files, root_dir, transform=None):
        self.files = files
        self.root_dir = root_dir
        self.transform = transform
        # Labels are parsed eagerly, so a malformed file name fails here
        # rather than in the middle of training.
        self.labels = [int(f.split("_")[2].split(".")[0]) for f in files]

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for position ``idx``.

        The image is opened as 8-bit grayscale ("L") and passed through
        ``transform`` when one was supplied.
        """
        img_path = os.path.join(self.root_dir, self.files[idx])
        # Image.open is lazy and keeps the file handle open; the context
        # manager releases it, while convert() returns an independent copy.
        with Image.open(img_path) as raw:
            img = raw.convert("L")
        if self.transform:
            img = self.transform(img)
        return img, self.labels[idx]


# Encoder
class Encoder(nn.Module):
    """Convolutional encoder that maps 1x28x28 images to latent Gaussian parameters.

    Two conv/ReLU/max-pool stages are followed by a shared hidden layer and
    two parallel linear heads, one for the mean and one for the log-variance.

    Args:
        latent_dim: Size of the latent vector produced by each head
            (defaults to 80, the value the original code hard-coded).
    """

    def __init__(self, latent_dim=80):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 10, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(10, 20, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        # For a 28x28 input the spatial size goes 28 -> 24 -> 12 -> 8 -> 4,
        # so the flattened feature size is 20 * 4 * 4 = 320.
        self.fc1 = nn.Linear(320, 160)
        self.fc21 = nn.Linear(160, latent_dim)  # mean head
        self.fc22 = nn.Linear(160, latent_dim)  # log-variance head
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``(mu, log_var)``, each of shape ``(batch, latent_dim)``."""
        batch_size = x.size(0)
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(batch_size, -1)
        h = self.relu(self.fc1(x))
        mu = self.fc21(h)
        log_var = self.fc22(h)
        return mu, log_var


# Decoder
class Decoder(nn.Module):
    """Fully connected decoder that maps a latent vector to a flattened 28x28 image.

    The final Sigmoid keeps every output value in [0, 1].

    Args:
        latent_dim: Size of the input latent vector (defaults to 80, the
            value the original code hard-coded).
    """

    def __init__(self, latent_dim=80):
        super().__init__()
        self.main = nn.Sequential(
            nn.Linear(latent_dim, 160),
            nn.ReLU(),
            nn.Linear(160, 320),
            nn.ReLU(),
            nn.Linear(320, 28 * 28),
            nn.Sigmoid(),
        )

    def forward(self, z):
        """Decode ``z`` (last dimension ``latent_dim``) into 784 pixel values."""
        return self.main(z)


# Variational autoencoder
class VAE(nn.Module):
    """Variational autoencoder that chains an encoder, a sampling step and a decoder.

    Args:
        encoder: Callable returning ``(mu, log_var)`` for an input batch.
        decoder: Callable mapping a latent sample ``z`` to a reconstruction.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def reparameterize(self, mu, log_var):
        """Sample ``z = mu + eps * std`` with ``eps`` drawn from N(0, I).

        Writing the sample this way keeps it differentiable with respect to
        ``mu`` and ``log_var``.
        """
        std = torch.exp(0.5 * log_var)  # standard deviation from log-variance
        eps = torch.randn_like(std)     # noise from a standard normal
        return mu + eps * std

    def forward(self, x):
        """Return ``(reconstruction, mu, log_var)`` for the input batch ``x``."""
        mu, log_var = self.encoder(x)
        z = self.reparameterize(mu, log_var)
        return self.decoder(z), mu, log_var


if __name__ == "__main__":
    # Resize to 28x28 and convert to a tensor; ToTensor scales pixels to
    # [0, 1], which BCELoss requires of its targets.
    # Named `transform` so the imported `transforms` module is not shadowed.
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
    ])

    # Paths
    base_dir = 'C:\\Users\\Administrator\\PycharmProjects\\CNN'
    train_dir = os.path.join(base_dir, "minist_train")
    gen_dir = "C:\\Users\\Administrator\\PycharmProjects\\CNN\\vae_gen_img"
    loss_fig_path = "C:\\Users\\Administrator\\PycharmProjects\\CNN\\vae_gen_loss.jpg"
    # savefig does not create missing directories.
    os.makedirs(gen_dir, exist_ok=True)

    # Collect the image file names in the training folder
    train_files = [f for f in os.listdir(train_dir) if f.endswith('.jpg')]

    # Dataset and data loader
    train_dataset = MINISTDataset(train_files, train_dir, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

    # Hyper-parameters
    num_epochs = 50
    lr = 0.001

    # Model, loss and optimizer
    encoder = Encoder()
    decoder = Decoder()
    vae = VAE(encoder, decoder)
    criterion = nn.BCELoss()
    optimizer = optim.Adam(vae.parameters(), lr=lr, betas=(0.5, 0.999))

    # Average loss of every epoch, for the loss curve
    epoch_loss = []
    for epoch in range(num_epochs):
        total_loss = 0.0
        for images, _ in train_loader:
            optimizer.zero_grad()
            outputs, mu, log_var = vae(images)
            # Reconstruction loss against the flattened input, plus the KL
            # divergence to N(0, I); both are means, and KL is weighted by 0.1.
            reconstruction_loss = criterion(outputs, images.view(images.size(0), -1))
            kl_divergence = -0.5 * torch.mean(1 + log_var - mu.pow(2) - log_var.exp())
            loss = reconstruction_loss + 0.1 * kl_divergence
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(train_loader)
        epoch_loss.append(avg_loss)
        print("Epoch", epoch, "  Loss:", avg_loss)

        # Every 5 epochs, decode 9 random latent vectors into a 3x3 grid
        if (epoch + 1) % 5 == 0:
            with torch.no_grad():
                z = torch.randn(9, 80)
                samples = decoder(z).view(-1, 28, 28)
            plt.figure(figsize=(9, 9))
            for i in range(9):
                plt.subplot(3, 3, i + 1)
                plt.imshow(samples[i].numpy(), cmap="gray")
                plt.axis("off")
            name = f"vae_gen_img_{epoch}.jpg"
            gen_name = os.path.join(gen_dir, name)
            plt.savefig(gen_name, dpi=300)
            plt.close()

    # Plot the loss curve
    plt.figure(figsize=(12, 6))
    # The label gives plt.legend() an entry to show.
    plt.plot(epoch_loss, color="tomato", label="loss")
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.title("损失函数曲线图")
    plt.legend()
    plt.grid()
    plt.savefig(loss_fig_path)
    plt.close()

http://www.lryc.cn/news/620083.html

相关文章:

  • day39_2025-08-13
  • Go 微服务限流与熔断最佳实践:滑动窗口、令牌桶与自适应阈值
  • Day19 C 语言标准 IO 机制
  • React useMemo 深度指南:原理、误区、实战与 2025 最佳实践
  • React常见的Hooks
  • 万字详解C++11列表初始化与移动语义
  • OpenCV的实际应用
  • 类和对象----中
  • 【COMSOL】Comsol学习案例时的心得记录分享
  • Mysql数据库迁移到GaussDB注意事项
  • pycharm配置连接服务器
  • 3.Cursor提效应用场景实战
  • MySQL相关概念和易错知识点(6)(视图、用户管理)
  • 大厂语音合成成本深度对比:微软 / 阿里 / 腾讯 / 火山 API 计费拆解与技术选型指南
  • trace分析之查找点击事件
  • cisco无线WLC flexconnect配置
  • python类--python011
  • 数仓建模理论-数据域和主题域
  • 8.13服务器安全检测技术和防御技术
  • 免费生成视频,Coze扣子工作流完全免费的视频生成方案,实现图生视频、文生视频
  • [ Mybatis 多表关联查询 ] resultMap
  • LeetCode Day5 -- 二叉树
  • 使用 HTML5 Canvas 打造炫酷的数字时钟动画
  • Kubernetes-03:Service
  • 对线面试官之幂等和去重
  • 【OpenGL】LearnOpenGL学习笔记07 - 摄像机
  • 会议征稿!IOP出版|第二届人工智能、光电子学与光学技术国际研讨会(AIOT2025)
  • 【Android】RecyclerView多布局展示案例
  • [系统架构设计师]架构设计专业知识(二)
  • Linux 计划任务