变分自编码器（VAE）的 PyTorch 实现
一、导入第三方库
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei'] # 设置中文字体
plt.rcParams['axes.unicode_minus'] = False # 正常显示负号
from torchvision import transforms
import os
from PIL import Image
from torch.utils.data import Dataset,DataLoader
二、手写数字数据集准备
#手写数字数据集
class MINISTDataset(Dataset):
    """Handwritten-digit image dataset backed by a directory of image files.

    The integer label of every sample is parsed from its file name: the name
    is split on ``_`` and the third field, with everything from its first
    ``.`` onward removed, is converted with ``int``. For example
    ``"train_12_7.jpg"`` yields label ``7``.

    Args:
        files: File names (relative to ``root_dir``) of the images.
        root_dir: Directory that contains the image files.
        transform: Optional callable applied to each grayscale PIL image.

    Raises:
        IndexError: If a file name has fewer than three ``_``-separated fields.
        ValueError: If the label field is not an integer.
    """

    def __init__(self, files, root_dir, transform=None):
        self.files = files
        self.root_dir = root_dir
        self.transform = transform
        # Parse all labels once up front so __getitem__ only does image I/O.
        self.labels = [int(f.split("_")[2].split(".")[0]) for f in files]

    def __len__(self):
        """Return the number of image files in the dataset."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is converted to single-channel mode ``"L"`` and then passed
        through ``transform`` when one was supplied.
        """
        img_path = os.path.join(self.root_dir, self.files[idx])
        # The context manager closes the underlying file as soon as the
        # converted copy has been produced.
        with Image.open(img_path) as raw:
            img = raw.convert("L")
        if self.transform:
            img = self.transform(img)
        return img, self.labels[idx]
三、VAE模型的 PyTorch 代码
#编码器
class Encoder(nn.Module):
    """Convolutional encoder that outputs the latent Gaussian parameters.

    Two conv -> ReLU -> max-pool stages are followed by a fully connected
    layer and two parallel linear heads of width 80. The first linear layer
    takes 320 flattened features, which is what the conv stack produces for
    1x28x28 inputs (20 channels x 4 x 4).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 10, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(10, 20, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.fc1 = nn.Linear(320, 160)
        self.fc21 = nn.Linear(160, 80)  # head for the latent mean
        self.fc22 = nn.Linear(160, 80)  # head for the latent log-variance
        self.relu = nn.ReLU()

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Image batch; the first dimension is the batch size.

        Returns:
            Tuple ``(mu, log_var)``, each of shape ``(batch, 80)``.
        """
        batch_size = x.size(0)
        x = self.conv1(x)
        x = self.conv2(x)
        # Flatten the feature maps so they can feed the linear layers.
        x = x.view(batch_size, -1)
        h = self.relu(self.fc1(x))
        mu = self.fc21(h)
        log_var = self.fc22(h)
        return mu, log_var


# Decoder
class Decoder(nn.Module):
    """Fully connected decoder from an 80-d latent vector to 784 pixel values.

    Three linear layers (80 -> 160 -> 320 -> 28*28) with ReLU in between and
    a final sigmoid, so every output value lies in ``[0, 1]``.
    """

    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            nn.Linear(80, 160),
            nn.ReLU(),
            nn.Linear(160, 320),
            nn.ReLU(),
            nn.Linear(320, 28 * 28),
            nn.Sigmoid(),
        )

    def forward(self, z):
        """Decode latent vectors ``z`` (last dimension 80) into flat images."""
        return self.main(z)


# Variational autoencoder
class VAE(nn.Module):
    """Variational autoencoder composed of an encoder and a decoder.

    Args:
        encoder: Callable returning ``(mu, log_var)`` for an input batch.
        decoder: Callable mapping a latent sample to a reconstruction.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def reparameterize(self, mu, log_var):
        """Draw ``z = mu + eps * std`` with ``eps`` sampled from N(0, I).

        Sampling the noise separately keeps ``z`` differentiable with respect
        to ``mu`` and ``log_var``.
        """
        std = torch.exp(0.5 * log_var)  # standard deviation from log-variance
        eps = torch.randn_like(std)  # noise from a standard normal
        return mu + eps * std

    def forward(self, x):
        """Encode ``x``, sample a latent and decode it.

        Returns:
            Tuple ``(reconstruction, mu, log_var)``.
        """
        mu, log_var = self.encoder(x)
        z = self.reparameterize(mu, log_var)
        return self.decoder(z), mu, log_var
四、主程序
def main():
    """Train the VAE on the digit images, saving samples and the loss curve.

    Every fifth epoch a 3x3 grid of images decoded from random latents is
    written to the ``vae_gen_img`` directory; after training the per-epoch
    average loss is plotted to ``vae_gen_loss.jpg``.

    Raises:
        FileNotFoundError: If the training directory holds no ``.jpg`` files.
    """
    # Resize to 28x28 and convert to a tensor (ToTensor scales 8-bit pixel
    # values to [0, 1]). Named `transform` so the imported `transforms`
    # module is not shadowed.
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
    ])

    # Paths
    base_dir = 'C:\\Users\\Administrator\\PycharmProjects\\CNN'
    train_dir = os.path.join(base_dir, "minist_train")
    gen_dir = os.path.join(base_dir, "vae_gen_img")
    loss_path = os.path.join(base_dir, "vae_gen_loss.jpg")
    # savefig does not create missing directories.
    os.makedirs(gen_dir, exist_ok=True)

    # Collect the image file names in the training directory.
    train_files = [f for f in os.listdir(train_dir) if f.endswith('.jpg')]
    if not train_files:
        raise FileNotFoundError(f"No .jpg images found in {train_dir}")

    # Dataset and data loader
    train_dataset = MINISTDataset(train_files, train_dir, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

    # Hyper-parameters
    num_epochs = 50
    lr = 0.001

    # Model, loss and optimizer
    encoder = Encoder()
    decoder = Decoder()
    vae = VAE(encoder, decoder)
    criterion = nn.BCELoss()
    optimizer = optim.Adam(vae.parameters(), lr=lr, betas=(0.5, 0.999))

    # Average loss of every epoch, for the final plot.
    epoch_loss = []
    for epoch in range(num_epochs):
        total_loss = 0.0
        for images, _ in train_loader:
            optimizer.zero_grad()
            outputs, mu, logvar = vae(images)
            # Reconstruction loss against the flattened input plus a
            # down-weighted KL term.
            targets = images.view(images.size(0), -1)
            reconstruction_loss = criterion(outputs, targets)
            kl_divergence = -0.5 * torch.mean(
                1 + logvar - mu.pow(2) - logvar.exp()
            )
            loss = reconstruction_loss + 0.1 * kl_divergence
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(train_loader)
        epoch_loss.append(avg_loss)
        print("Epoch", epoch, " Loss:", avg_loss)

        # Every fifth epoch, decode nine random latents into a 3x3 grid.
        if (epoch + 1) % 5 == 0:
            with torch.no_grad():
                z = torch.randn(9, 80)
                # Decode all nine latents in one batch.
                samples = decoder(z).view(-1, 28, 28)
            plt.figure(figsize=(9, 9))
            for i, sample in enumerate(samples):
                plt.subplot(3, 3, i + 1)
                plt.imshow(sample.numpy(), cmap="gray")
                plt.axis("off")
            name = f"vae_gen_img_{epoch}.jpg"
            plt.savefig(os.path.join(gen_dir, name), dpi=300)
            plt.close()

    # Plot the loss curve.
    plt.figure(figsize=(12, 6))
    # The label gives plt.legend() an artist to show.
    plt.plot(epoch_loss, color="tomato", label="loss")
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.title("损失函数曲线图")
    plt.legend()
    plt.grid()
    plt.savefig(loss_path)
    plt.close()


if __name__ == "__main__":
    main()
五、运行结果
5.1 损失函数曲线图
5.2 生成的图像
这里只展示一部分
vae_gen_img_4.jpg
vae_gen_img_29.jpg
vae_gen_img_49.jpg
六、VAE的完整代码
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei'] # 设置中文字体
plt.rcParams['axes.unicode_minus'] = False # 正常显示负号
from torchvision import transforms
import os
from PIL import Image
from torch.utils.data import Dataset,DataLoader#手写数字数据集
class MINISTDataset(Dataset):
    """Handwritten-digit image dataset backed by a directory of image files.

    The integer label of every sample is parsed from its file name: the name
    is split on ``_`` and the third field, with everything from its first
    ``.`` onward removed, is converted with ``int``. For example
    ``"train_12_7.jpg"`` yields label ``7``.

    Args:
        files: File names (relative to ``root_dir``) of the images.
        root_dir: Directory that contains the image files.
        transform: Optional callable applied to each grayscale PIL image.

    Raises:
        IndexError: If a file name has fewer than three ``_``-separated fields.
        ValueError: If the label field is not an integer.
    """

    def __init__(self, files, root_dir, transform=None):
        self.files = files
        self.root_dir = root_dir
        self.transform = transform
        # Parse all labels once up front so __getitem__ only does image I/O.
        self.labels = [int(f.split("_")[2].split(".")[0]) for f in files]

    def __len__(self):
        """Return the number of image files in the dataset."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is converted to single-channel mode ``"L"`` and then passed
        through ``transform`` when one was supplied.
        """
        img_path = os.path.join(self.root_dir, self.files[idx])
        # The context manager closes the underlying file as soon as the
        # converted copy has been produced.
        with Image.open(img_path) as raw:
            img = raw.convert("L")
        if self.transform:
            img = self.transform(img)
        return img, self.labels[idx]


# Encoder
class Encoder(nn.Module):
    """Convolutional encoder that outputs the latent Gaussian parameters.

    Two conv -> ReLU -> max-pool stages are followed by a fully connected
    layer and two parallel linear heads of width 80. The first linear layer
    takes 320 flattened features, which is what the conv stack produces for
    1x28x28 inputs (20 channels x 4 x 4).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 10, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(10, 20, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.fc1 = nn.Linear(320, 160)
        self.fc21 = nn.Linear(160, 80)  # head for the latent mean
        self.fc22 = nn.Linear(160, 80)  # head for the latent log-variance
        self.relu = nn.ReLU()

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Image batch; the first dimension is the batch size.

        Returns:
            Tuple ``(mu, log_var)``, each of shape ``(batch, 80)``.
        """
        batch_size = x.size(0)
        x = self.conv1(x)
        x = self.conv2(x)
        # Flatten the feature maps so they can feed the linear layers.
        x = x.view(batch_size, -1)
        h = self.relu(self.fc1(x))
        mu = self.fc21(h)
        log_var = self.fc22(h)
        return mu, log_var


# Decoder
class Decoder(nn.Module):
    """Fully connected decoder from an 80-d latent vector to 784 pixel values.

    Three linear layers (80 -> 160 -> 320 -> 28*28) with ReLU in between and
    a final sigmoid, so every output value lies in ``[0, 1]``.
    """

    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            nn.Linear(80, 160),
            nn.ReLU(),
            nn.Linear(160, 320),
            nn.ReLU(),
            nn.Linear(320, 28 * 28),
            nn.Sigmoid(),
        )

    def forward(self, z):
        """Decode latent vectors ``z`` (last dimension 80) into flat images."""
        return self.main(z)


# Variational autoencoder
class VAE(nn.Module):
    """Variational autoencoder composed of an encoder and a decoder.

    Args:
        encoder: Callable returning ``(mu, log_var)`` for an input batch.
        decoder: Callable mapping a latent sample to a reconstruction.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def reparameterize(self, mu, log_var):
        """Draw ``z = mu + eps * std`` with ``eps`` sampled from N(0, I).

        Sampling the noise separately keeps ``z`` differentiable with respect
        to ``mu`` and ``log_var``.
        """
        std = torch.exp(0.5 * log_var)  # standard deviation from log-variance
        eps = torch.randn_like(std)  # noise from a standard normal
        return mu + eps * std

    def forward(self, x):
        """Encode ``x``, sample a latent and decode it.

        Returns:
            Tuple ``(reconstruction, mu, log_var)``.
        """
        mu, log_var = self.encoder(x)
        z = self.reparameterize(mu, log_var)
        return self.decoder(z), mu, log_var


def main():
    """Train the VAE on the digit images, saving samples and the loss curve.

    Every fifth epoch a 3x3 grid of images decoded from random latents is
    written to the ``vae_gen_img`` directory; after training the per-epoch
    average loss is plotted to ``vae_gen_loss.jpg``.

    Raises:
        FileNotFoundError: If the training directory holds no ``.jpg`` files.
    """
    # Resize to 28x28 and convert to a tensor (ToTensor scales 8-bit pixel
    # values to [0, 1]). Named `transform` so the imported `transforms`
    # module is not shadowed.
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
    ])

    # Paths
    base_dir = 'C:\\Users\\Administrator\\PycharmProjects\\CNN'
    train_dir = os.path.join(base_dir, "minist_train")
    gen_dir = os.path.join(base_dir, "vae_gen_img")
    loss_path = os.path.join(base_dir, "vae_gen_loss.jpg")
    # savefig does not create missing directories.
    os.makedirs(gen_dir, exist_ok=True)

    # Collect the image file names in the training directory.
    train_files = [f for f in os.listdir(train_dir) if f.endswith('.jpg')]
    if not train_files:
        raise FileNotFoundError(f"No .jpg images found in {train_dir}")

    # Dataset and data loader
    train_dataset = MINISTDataset(train_files, train_dir, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

    # Hyper-parameters
    num_epochs = 50
    lr = 0.001

    # Model, loss and optimizer
    encoder = Encoder()
    decoder = Decoder()
    vae = VAE(encoder, decoder)
    criterion = nn.BCELoss()
    optimizer = optim.Adam(vae.parameters(), lr=lr, betas=(0.5, 0.999))

    # Average loss of every epoch, for the final plot.
    epoch_loss = []
    for epoch in range(num_epochs):
        total_loss = 0.0
        for images, _ in train_loader:
            optimizer.zero_grad()
            outputs, mu, logvar = vae(images)
            # Reconstruction loss against the flattened input plus a
            # down-weighted KL term.
            targets = images.view(images.size(0), -1)
            reconstruction_loss = criterion(outputs, targets)
            kl_divergence = -0.5 * torch.mean(
                1 + logvar - mu.pow(2) - logvar.exp()
            )
            loss = reconstruction_loss + 0.1 * kl_divergence
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(train_loader)
        epoch_loss.append(avg_loss)
        print("Epoch", epoch, " Loss:", avg_loss)

        # Every fifth epoch, decode nine random latents into a 3x3 grid.
        if (epoch + 1) % 5 == 0:
            with torch.no_grad():
                z = torch.randn(9, 80)
                # Decode all nine latents in one batch.
                samples = decoder(z).view(-1, 28, 28)
            plt.figure(figsize=(9, 9))
            for i, sample in enumerate(samples):
                plt.subplot(3, 3, i + 1)
                plt.imshow(sample.numpy(), cmap="gray")
                plt.axis("off")
            name = f"vae_gen_img_{epoch}.jpg"
            plt.savefig(os.path.join(gen_dir, name), dpi=300)
            plt.close()

    # Plot the loss curve.
    plt.figure(figsize=(12, 6))
    # The label gives plt.legend() an artist to show.
    plt.plot(epoch_loss, color="tomato", label="loss")
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.title("损失函数曲线图")
    plt.legend()
    plt.grid()
    plt.savefig(loss_path)
    plt.close()


if __name__ == "__main__":
    main()