网站 pinghei,泰国如何做网站推广,深圳公司排名名字,太原开发网站公司生成对抗神经网络GAN#xff0c;发挥神经网络的想象力#xff0c;可以说是十分厉害了
参考
1、AI作家 2、将模糊图变清晰(去雨#xff0c;去雾#xff0c;去抖动#xff0c;去马赛克等)#xff0c;这需要AI具有“想象力”#xff0c;能脑补情节#xff1b; 3、进行数…生成对抗神经网络GAN发挥神经网络的想象力可以说是十分厉害了
参考
1、AI作家 2、将模糊图变清晰(去雨去雾去抖动去马赛克等)这需要AI具有“想象力”能脑补情节 3、进行数据增强根据已有数据生成更多新数据供以feed可以减缓模型过拟合现象。
那到底是怎么实现的呢 GAN中有两大组成部分G和D
G是generator生成器: 负责凭空捏造数据出来
D是discriminator判别器: 负责判断数据是不是真数据
示例图如下 给一个随机噪声z通过G生成一张假图然后用D去分辨是真图还是假图。假设G生成了一张图在D那里的得分很高那么G就很成功的骗过了D如果D很轻松的分辨出了假图那么G的效果不好那么就需要调整参数了。 G和D是两个单独的网络那么他们的参数都是训练好的吗并不是两个网络的参数是需要在博弈的过程中分别优化的。
下面就是一个训练的过程 GAN在一轮反向传播中分为两步先训练D在训练G。
训练D时上一轮G产生的图片和真实图片一起作为x进行输入假图为0真图标签为1通过x生成一个score通过score和标签y计算损失就可以进行反向传播了。
训练G时G和D是一个整体取名为D_on_G。输入随机噪声G产生一个假图D去分辨score 1就是需要我们需要优化的目标意思就是我们要让生成的图片变成真的。这里的D是不需要参与梯度计算的我们通过反向传播来优化G让他生成更加真实的图片。这就好比如果你参加考试你别指望能改变老师的评分标准 GAN无监督学习cGAN是有监督的以后会学习的。怎么理解无监督学习呢这里给的真图是没有经过人工标注的只知道这是真的D是不知道这是什么的只需要分辨真假。G也不知道生成了什么只需要学真图去骗D。 具体如何实施呢
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_imagedevice torch.device(cuda if torch.cuda.is_available() else cpu)
latent_size 64
hidden_size 256
image_size 784
num_epochs 200
batch_size 100
sample_dir samples
注意这里有个归一化的过程MNIST是单通道但是如果mean0.50.50.5会报错因为是对3通道操作 。
if not os.path.exists(sample_dir):os.makedirs(sample_dir)transform transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean(0.5,), # 3 for RGB channelsstd(0.5,))])# MNIST dataset
mnist torchvision.datasets.MNIST(root./data/,trainTrue,transformtransform,downloadTrue)
# Data loader
data_loader torch.utils.data.DataLoader(datasetmnist,batch_sizebatch_size, shuffleTrue) 定义生成器和判别器
生成器可以看到输入的维度为64是一组噪声图像通过生成器将特征扩大到了MNIST图像大小784。
判别器输入维度为图像大小最后输出特征个数为1采用sigmoid激活不用softmax的
# Discriminator
D nn.Sequential(nn.Linear(image_size, hidden_size),nn.LeakyReLU(0.2),nn.Linear(hidden_size, hidden_size),nn.LeakyReLU(0.2),nn.Linear(hidden_size, 1),nn.Sigmoid())# Generator
G nn.Sequential(nn.Linear(latent_size, hidden_size),nn.ReLU(),nn.Linear(hidden_size, hidden_size),nn.ReLU(),nn.Linear(hidden_size, image_size),nn.Tanh())
# Device setting
D D.to(device)
G G.to(device)# Binary cross entropy loss and optimizer
criterion nn.BCELoss()
d_optimizer torch.optim.Adam(D.parameters(), lr0.0002)
g_optimizer torch.optim.Adam(G.parameters(), lr0.0002)def denorm(x):out (x 1) / 2return out.clamp(0, 1)def reset_grad():d_optimizer.zero_grad()g_optimizer.zero_grad() 重点看训练部分我们到底是如何来训练GAN的。
判别器部分判别器的损失值分为两部分一将mini_batch定义为正样本告诉他我是正品所以设置标签为1。优化判别器判断正品的能力二生成一幅赝品再给判别器判别这时候赝品的标签为0优化判断赝品的能力。所以总损失为这两部分之和计算梯度优化判别器参数。
G_on_D输入一个噪声让生成器生成一幅图像然后让D去判别计算和正品之间的距离即损失。反向传播优化G的参数。
# Start training
total_step len(data_loader)
for epoch in range(num_epochs):for i, (images, _) in enumerate(data_loader):images images.reshape(batch_size, -1).to(device)# Create the labels which are later used as input for the BCE lossreal_labels torch.ones(batch_size, 1).to(device)fake_labels torch.zeros(batch_size, 1).to(device)# ## Train the discriminator ## ## Compute BCE_Loss using real images where BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x))# Second term of the loss is always zero since real_labels 1outputs D(images)d_loss_real criterion(outputs, real_labels)real_score outputs# Compute BCELoss using fake images# First term of the loss is always zero since fake_labels 0z torch.randn(batch_size, latent_size).to(device)fake_images G(z)outputs D(fake_images)d_loss_fake criterion(outputs, fake_labels)fake_score outputs# Backprop and optimized_loss d_loss_real d_loss_fakereset_grad()d_loss.backward()d_optimizer.step()# ## Train the generator ## ## Compute loss with fake imagesz torch.randn(batch_size, latent_size).to(device)fake_images G(z)outputs D(fake_images)# We train G to maximize log(D(G(z)) instead of minimizing log(1-D(G(z)))# For the reason, see the last paragraph of section 3. https://arxiv.org/pdf/1406.2661.pdfg_loss criterion(outputs, real_labels)# Backprop and optimizereset_grad()g_loss.backward()g_optimizer.step()if (i1) % 200 0:print(Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f} .format(epoch, num_epochs, i1, total_step, d_loss.item(), g_loss.item(), real_score.mean().item(), fake_score.mean().item()))# Save real imagesif (epoch1) 1:images images.reshape(images.size(0), 1, 28, 28)save_image(denorm(images), os.path.join(sample_dir, real_images.png))# Save sampled imagesfake_images fake_images.reshape(fake_images.size(0), 1, 28, 28)save_image(denorm(fake_images), os.path.join(sample_dir, fake_images-{}.png.format(epoch1))) 训练完了怎么用
只要用我们的生成器就可以随意生成了。
import matplotlib.pyplot as plt
z torch.randn(1,latent_size).to(device)
output G(z)
plt.imshow(output.cpu().data.numpy().reshape(28,28),cmapgray)
plt.show() 下面就是随机生成的图像了