如何做服装的微商城网站建设,广州有几个区哪个区最好,做个网站多少钱,腾讯云购买域名后如何建网站2023.8.27 在进行深度学习的进阶的时候#xff0c;我发了生成对抗网络是一个很神奇的东西#xff0c;为什么它可以“将一堆随机噪声经过生成器变成一张图片”#xff0c;特此记录一下学习心得。
一、生成对抗网络百科 2014年#xff0c;还在蒙特利尔读博士的Ian Goodfello…2023.8.27 在进行深度学习的进阶的时候我发了生成对抗网络是一个很神奇的东西为什么它可以“将一堆随机噪声经过生成器变成一张图片”特此记录一下学习心得。
一、生成对抗网络百科 2014年还在蒙特利尔读博士的Ian Goodfellow发表了论 文《Generative Adversarial Networks》网址 https://arxiv.org/abs/1406.2661将生成对抗网络引入 深度学习领域。2016年GAN热潮席卷AI领域顶级会议 从ICLR到NIPS大量高质量论文被发表和探讨。Yann LeCun曾评价GAN是“20年来机器学习领域最酷的想法”。
机器学习的模型可大体分为两类生成模型 Generative Model和判别模型Discriminative Model。判别模型需要输入变量 通过某种模型来 预测 。生成模型是给定某种隐含信息来随机产生观 测数据。 GAN百科 GAN生成对抗网络的系统全面介绍醍醐灌顶_打灰人的博客-CSDN博客 二、GAN代码
训练代码 epoch1000时的效果就不错啦
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as pltclass Generator(nn.Module): # 生成器def __init__(self, latent_dim):super(Generator, self).__init__()self.model nn.Sequential(nn.Linear(latent_dim, 256),nn.LeakyReLU(0.2),nn.Linear(256, 512),nn.LeakyReLU(0.2),nn.Linear(512, 1024),nn.LeakyReLU(0.2),nn.Linear(1024, 784),nn.Tanh())def forward(self, z):img self.model(z)img img.view(img.size(0), 1, 28, 28)return imgclass Discriminator(nn.Module): # 判别器def __init__(self):super(Discriminator, self).__init__()self.model nn.Sequential(nn.Linear(784, 512),nn.LeakyReLU(0.2),nn.Linear(512, 256),nn.LeakyReLU(0.2),nn.Linear(256, 1),nn.Sigmoid())def forward(self, img):img img.view(img.size(0), -1)validity self.model(img)return validitydef gen_img_plot(model, test_input):pred np.squeeze(model(test_input).detach().cpu().numpy())fig plt.figure(figsize(4, 4))for i in range(16):plt.subplot(4, 4, i 1)plt.imshow((pred[i] 1) / 2)plt.axis(off)plt.show(blockFalse)plt.pause(3) # 停留0.5splt.close()# 调用GPU
device torch.device(cuda:0 if torch.cuda.is_available() else cpu)# 超参数设置
lr 0.0001
batch_size 128
latent_dim 100
epochs 1000# 数据集载入和数据变换
transform transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])# 训练数据
train_dataset datasets.MNIST(root./data, trainTrue, transformtransform, downloadFalse)
train_loader DataLoader(train_dataset, batch_sizebatch_size, shuffleTrue)# 测试数据 torch.randn()函数的作用是生成一组均值为0方差为1(即标准正态分布)的随机数
# test_data torch.randn(batch_size, latent_dim).to(device)
test_data torch.FloatTensor(batch_size, latent_dim).to(device)# 实例化生成器和判别器并定义损失函数和优化器
generator Generator(latent_dim).to(device)
discriminator Discriminator().to(device)
adversarial_loss nn.BCELoss()
optimizer_G optim.Adam(generator.parameters(), lrlr)
optimizer_D optim.Adam(discriminator.parameters(), lrlr)# 开始训练模型
for epoch in range(epochs):for i, (imgs, _) in enumerate(train_loader):batch_size imgs.shape[0]real_imgs imgs.to(device)# 训练判别器z torch.FloatTensor(batch_size, latent_dim).to(device)z.data.normal_(0, 1)fake_imgs generator(z) # 生成器生成假的图片real_labels torch.full((batch_size, 1), 1.0).to(device)fake_labels torch.full((batch_size, 1), 0.0).to(device)real_loss adversarial_loss(discriminator(real_imgs), real_labels)fake_loss adversarial_loss(discriminator(fake_imgs.detach()), fake_labels)d_loss (real_loss fake_loss) / 2optimizer_D.zero_grad()d_loss.backward()optimizer_D.step()# 训练生成器z.data.normal_(0, 1)fake_imgs generator(z)g_loss adversarial_loss(discriminator(fake_imgs), real_labels)optimizer_G.zero_grad()g_loss.backward()optimizer_G.step()torch.save(generator.state_dict(), Generator_mnist.pth)print(fEpoch [{epoch}/{epochs}] Loss_D: {d_loss.item():.4f} Loss_G: {g_loss.item():.4f})# gen_img_plot(Generator, test_data)
gen_img_plot(generator, test_data)测试代码
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import randomdevice torch.device(cuda:0 if torch.cuda.is_available() else cpu)class Generator(nn.Module): # 生成器def __init__(self, latent_dim):super(Generator, self).__init__()self.model nn.Sequential(nn.Linear(latent_dim, 256),nn.LeakyReLU(0.2),nn.Linear(256, 512),nn.LeakyReLU(0.2),nn.Linear(512, 1024),nn.LeakyReLU(0.2),nn.Linear(1024, 784),nn.Tanh())def forward(self, z):img self.model(z)img img.view(img.size(0), 1, 28, 28)return img# test_data torch.FloatTensor(128, 100).to(device)
test_data torch.randn(128, 100).to(device) # 随机噪声model Generator(100).to(device)
model.load_state_dict(torch.load(Generator_mnist.pth))
model.eval()pred np.squeeze(model(test_data).detach().cpu().numpy())for i in range(64):plt.subplot(8, 8, i 1)plt.imshow((pred[i] 1) / 2)plt.axis(off)
plt.savefig(fnameimage.png, figsize[5, 5])
plt.show()三、结果 在超参数设置 epoch1000batch_size128lr0.0001latent_dim 100 时gan生成的权重测的结果如图所示 四GAN的损失函数曲线 一开始训练时我的gan的损失函数的曲线是类似这样的就是知乎这文章里一样生成器损失函数的曲线一直发散。首先这个loss的曲线一看就是网络崩了一般正常的情况d_loss的值会一直下降然后收敛而g_loss的曲线会先增大后减少最后同样也会收敛。其次网络拿到手以后先不要训练太多次容易出现过拟合的情况。 生成对抗网络的损失函数图像如下合理吗 - 知乎 这是训练了10轮的生成器和鉴别器的损失函数值变化吧 效果如图所示