Preface: This post implements a Variational Autoencoder (VAE); its training results are shown below. During training, pay attention to tuning the kld term in forward (it needs hyperparameter adjustment).
The project consists of two files: vae.py and main.py.

Contents:
1 vae
2 main

1 vae

Filename: vae.py

Purpose: implements the Variational Autoencoder (VAE). To give an autoencoder's latent space more regular structure, constraints are added during training; this yields the variational autoencoder (VAE), defined as an autoencoder trained in a regularized way to avoid overfitting, so that the latent space has good properties that enable content generation. The architecture of a VAE is much like that of an autoencoder; the difference is that the encoder no longer maps an input to a single point, but to a distribution.
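For reference, the kld term computed in forward() below is the closed-form KL divergence between the approximate posterior N(u, sigma^2) and the standard normal prior N(0, 1), a standard VAE identity:

\mathrm{KL}\big(\mathcal{N}(\mu,\sigma^{2})\,\big\|\,\mathcal{N}(0,1)\big) = \frac{1}{2}\sum_{i}\big(\mu_{i}^{2}+\sigma_{i}^{2}-\log\sigma_{i}^{2}-1\big)

Training then minimizes the MSE reconstruction error plus this KL term, which is exactly how main.py combines loss and kld into -ELBO.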
# -*- coding: utf-8 -*-
"""
Created on Wed Aug 30 14:19:19 2023

@author: chengxf2
"""
import torch
from torch import nn

# ae: AutoEncoder with a variational latent space
class VAE(nn.Module):

    def __init__(self, hidden_size=20):
        super(VAE, self).__init__()

        self.encoder = nn.Sequential(
            nn.Linear(in_features=784, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=64),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=hidden_size),
            nn.ReLU()  # note: this ReLU constrains both u and sigma (split in forward) to be >= 0
        )

        # hidden: [batch_size, 10] after chunking
        h_dim = int(hidden_size / 2)
        self.hDim = h_dim

        self.decoder = nn.Sequential(
            nn.Linear(in_features=h_dim, out_features=64),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=784),
            nn.Sigmoid()
        )

    def forward(self, x):
        """
        :param x: [batch, 1, 28, 28]
        :return: xHat [batch, 1, 28, 28], kld (scalar)
        """
        batchSz = x.size(0)
        # flatten
        x = x.view(batchSz, 784)
        # encoder
        h = self.encoder(x)
        # split the tensor along the given dimension: the first half of the
        # units is taken as the mean u, the second half as sigma
        u, sigma = h.chunk(2, dim=1)
        # Reparameterization trick:
        # randn_like samples from a standard normal ~ N(0, 1)
        # h.shape: [batchSz, self.hDim]
        h = u + sigma * torch.randn_like(sigma)
        # kld: the 1e-8 keeps the log argument away from zero when sigma^2 == 0
        kld = 0.5 * torch.sum(
            torch.pow(u, 2) + torch.pow(sigma, 2)
            - torch.log(1e-8 + torch.pow(sigma, 2)) - 1
        )
        # MSELoss returns a mean, so kld is averaged as well. (MNIST images are
        # 28x28 = 784 pixels, so the 32*32 here is only approximate and in
        # effect acts as a KLD weight; this is the knob mentioned in the preface.)
        kld = kld / (batchSz * 32 * 32)
        xHat = self.decoder(h)
        # reshape
        xHat = xHat.view(batchSz, 1, 28, 28)
        return xHat, kld
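As a quick sanity check (a minimal sketch, not part of the original project), the model can be probed with a random batch to confirm the tensor shapes:

import torch
from vae import VAE

model = VAE()                    # hidden_size=20, so u and sigma are 10-d each
x = torch.randn(4, 1, 28, 28)    # a fake batch of MNIST-sized images
x_hat, kld = model(x)
print(x_hat.shape)               # torch.Size([4, 1, 28, 28])
print(kld.item())                # scalar KL term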
2 main

Filename: main.py
Purpose: trains the model and evaluates it on the test set.

# -*- coding: utf-8 -*-
"""
Created on Wed Aug 30 14:24:10 2023

@author: chengxf2
"""
import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import time
from torch import optim, nn
from vae import VAE
import visdom


def main():
    batchNum = 32
    lr = 1e-3
    epochs = 20
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    torch.manual_seed(1234)

    viz = visdom.Visdom()
    viz.line([0], [-1], win='train_loss', opts=dict(title='train loss'))

    tf = transforms.Compose([transforms.ToTensor()])
    mnist_train = datasets.MNIST('mnist', True, transform=tf, download=True)
    train_data = DataLoader(mnist_train, batch_size=batchNum, shuffle=True)

    mnist_test = datasets.MNIST('mnist', False, transform=tf, download=True)
    test_data = DataLoader(mnist_test, batch_size=batchNum, shuffle=True)

    global_step = 0
    model = VAE().to(device)
    criteon = nn.MSELoss().to(device)                  # reconstruction loss
    optimizer = optim.Adam(model.parameters(), lr=lr)  # gradient update rule

    print("\n ----main-----")
    for epoch in range(epochs):
        start = time.perf_counter()
        for step, (x, y) in enumerate(train_data):
            # x: [b, 1, 28, 28]
            x = x.to(device)
            x_hat, kld = model(x)
            loss = criteon(x_hat, x)
            if kld is not None:
                # maximize ELBO = -(reconstruction error) - KL divergence,
                # i.e. minimize loss = -ELBO
                elbo = -loss - 1.0 * kld
                loss = -elbo
            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            viz.line(Y=[loss.item()], X=[global_step], win='train_loss', update='append')
            global_step += 1
        end = time.perf_counter()
        interval = int(end - start)
        print('epoch: %d' % epoch,
              '\t training time: %ds' % interval,
              '\t total loss: %4.7f' % loss.item(),
              '\t KL divergence: %4.7f' % kld.item())

        # visualize reconstructions on one test batch
        x, target = next(iter(test_data))
        x = x.to(device)
        with torch.no_grad():
            x_hat, kld = model(x)
        tip = 'hat' + str(epoch)
        viz.images(x, nrow=8, win='x', opts=dict(title='x'))
        viz.images(x_hat, nrow=8, win='x_hat', opts=dict(title=tip))


if __name__ == '__main__':
    main()

Reference: 课时118 变分Auto-Encoder实战-2_哔哩哔哩_bilibili
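A practical note: visdom requires a running server (python -m visdom.server) before the script is launched. If you would rather avoid that dependency, a minimal alternative sketch is to write the test-batch reconstructions to disk with torchvision's save_image; the results/ directory and file names here are illustrative, not part of the original project:

import os
from torchvision.utils import save_image

os.makedirs('results', exist_ok=True)
# inside the epoch loop, after x_hat is computed for the test batch:
save_image(x, 'results/epoch_%d_x.png' % epoch, nrow=8)
save_image(x_hat, 'results/epoch_%d_x_hat.png' % epoch, nrow=8)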