当前位置: 首页 > news >正文

林业网站模板南昌做微信网站

林业网站模板,南昌做微信网站,更改备案网站名称,wordpress 技术类主题引言 在考虑生成对抗网络的文献时#xff0c;Wasserstein GAN 因其与传统 GAN 相比的训练稳定性而成为关键概念之一。在本文中#xff0c;我将介绍基于梯度惩罚的 WGAN 的概念。文章的结构安排如下#xff1a; WGAN 背后的直觉#xff1b;GAN 和 WGAN 的比较#xff1b;…引言 在考虑生成对抗网络的文献时Wasserstein GAN 因其与传统 GAN 相比的训练稳定性而成为关键概念之一。在本文中我将介绍基于梯度惩罚的 WGAN 的概念。文章的结构安排如下 WGAN 背后的直觉GAN 和 WGAN 的比较基于梯度惩罚的WGAN的数学背景使用 PyTorch 从头开始​​在CelebA-Face 数据集上实现WGAN 结果讨论。 WGAN 背后的直觉 GAN 最初由Ian J. Goodfellow 等人发明。在 GAN 中有一个由生成器和判别器进行的双玩家最小最大游戏。早期 GAN 的主要问题是模式崩溃和梯度消失问题。为了克服这些问题长期以来发明了许多技术。WGAN 是试图克服传统 GAN 的这些问题的方法之一。 GAN 与 WGAN 与传统的 GAN 相比WGAN 有一些改进/变化。 评论家而非判别器W-Loss 代替 BCE Loss使用梯度惩罚/权重剪裁进行权重正则化。 传统GAN的判别器被“Critic”取代。从实现的角度来看这只不过是最后一层没有 Sigmoid 激活的判别器。 我们稍后将讨论 WGAN 损失函数和权重正则化。 数学背景 损失函数 这是基于梯度惩罚的 WGAN 的完整损失函数。 等式 1. 具有梯度惩罚的完整 WGAN 损失函数 — [3] 看起来很吓人吧让我们分解一下这个方程。 第 1 部分原始批评损失 该方程产生的值应由生成器正向最大化同时由批评家负向最大化。请注意这里的 x_CURL 是生成器 (G(z)) 生成的图像。 这里D 在最后一层没有 Sigmoid 激活因此 D(*) 可以是任何实数。这给出了地球移动器的真实分布和生成分布之间的距离的近似值 - [1]。我们在这里想做的是 评论家的观点通过最大化等式 2结果的负值/最小化正值尽可能地将评论家对真实图像和生成图像的输出分布分开。这反映了评论家的目标即为真实图像提供更高的分数为更低的分数到生成的图像。生成器的观点尝试通过以相反的方向分离真实图像和生成图像的输出分布来抵消评论家的努力。这最终使式 2 的结果的正值最大化。这反映了生成器的目标是通过欺骗 Critic 来提高生成图像的 Critic 分数。 在这里你可能已经注意到Critic over Discriminator这个名字的出现是因为 Critic 不区分真假图像只是给出一个无界的分数。 为了确保方程有效我们需要确保 Critic 函数是 1-Lipschitz 连续的 — [1]。 1-Lipschitz连续性 函数 f(x) 是 1-L 连续的梯度应始终小于或等于 1。 为了确保这种1-Lipschitz连续性文献中主要提出了2种方法。 Weight Clipping——这是 WGAN 论文 [2] 附带的初始方法梯度惩罚方法——这是在最初的论文之后作为改进提出的[3]。 在本文中我们将重点关注基于梯度惩罚的 WGAN。 第二部分梯度惩罚 这是 Gulrajani 等人提出的梯度惩罚。——[3]。这里我们通过减小 Critic 梯度的 L2 范数与 1 之间的平方距离来强制 Critic 的梯度为 1。注意我们不能强制 Critic 的梯度为 0因为这会导致梯度消失问题。 等等x(^)是什么 考虑到 1-Lipschitz 连续性的定义所有 x 的梯度应≤1。但实际上确保所有可能的图像都满足这种条件是很困难的。因此我们使用 x(^) 表示使用真实图像和生成图像作为梯度惩罚的数据点的随机插值图像。这确保了 Critic 的梯度将通过查看训练期间遇到的一组公平的数据点/图像进行正则化。 Pytorch实现 在这里我将介绍大家应该做的必要更改以便将传统的 GAN 更改为 WGAN。 对于下面的实现我将使用我在之前有关 DCGAN 的文章中详细解释的模型和训练原理。 数据集 Celeba-face 数据集用于训练。下载、预处理、制作数据加载器脚本如代码1所示。 import zipfile import os if not os.path.isfile(celeba.zip):!mkdir data_faces wget https://s3-us-west-1.amazonaws.com/udacity-dlnfd/datasets/celeba.zip with zipfile.ZipFile(celeba.zip,r) as zip_ref:zip_ref.extractall(data_faces/)from torch.utils.data import DataLoadertransform transforms.Compose([transforms.Resize((img_size,img_size)),transforms.ToTensor(),transforms.Normalize((0.5,0.5, 0.5),(0.5, 0.5, 0.5))])dataset datasets.ImageFolder(data_faces, transformtransform) data_loader DataLoader(dataset,batch_sizebatch_size,shuffleTrue)生成器和评论家 Critic 与 Discriminator 相同但不包含最后一层 Sigmoid 激活。 class Generator(nn.Module):def __init__(self,noise_channels,img_channels,hidden_G):super(Generator,self).__init__()self.Gnn.Sequential(conv_trans_block(noise_channels,hidden_G*16,kernal_size4,stride1,padding0),conv_trans_block(hidden_G*16,hidden_G*8),conv_trans_block(hidden_G*8,hidden_G*4),conv_trans_block(hidden_G*4,hidden_G*2),nn.ConvTranspose2d(hidden_G*2,img_channels,kernel_size4,stride2,padding1),nn.Tanh())def forward(self,x):return self.G(x)class Critic(nn.Module):def __init__(self,img_channels,hidden_D):super(Critic,self).__init__()self.Dnn.Sequential(conv_block(img_channels,hidden_G),conv_block(hidden_G,hidden_G*2),conv_block(hidden_G*2,hidden_G*4),conv_block(hidden_G*4,hidden_G*8),nn.Conv2d(hidden_G*8,1,kernel_size4,stride2,padding0))def forward(self,x):return self.D(x)Generator 和 Critic 的支持块如下面的代码 3 所示。 class conv_trans_block(nn.Module):def __init__(self,in_channels,out_channels,kernal_size4,stride2,padding1):super(conv_trans_block,self).__init__()self.blocknn.Sequential(nn.ConvTranspose2d(in_channels,out_channels,kernal_size,stride,padding),nn.BatchNorm2d(out_channels),nn.ReLU())def forward(self,x):return self.block(x)class conv_block(nn.Module):def __init__(self,in_channels,out_channels,kernal_size4,stride2,padding1):super(conv_block,self).__init__()self.blocknn.Sequential(nn.Conv2d(in_channels,out_channels,kernal_size,stride,padding),nn.BatchNorm2d(out_channels),nn.LeakyReLU(0.2))def forward(self,x):return self.block(x)损失函数 与任何其他典型的损失函数不同损失函数可能有点棘手因为它包含梯度。在这里我们将使用梯度惩罚来实现 W-loss稍后可以将其插入 WGAN 模型中。 def get_gen_loss(crit_fake_pred):gen_loss -torch.mean(crit_fake_pred)return gen_lossdef get_crit_loss(crit_fake_pred, crit_real_pred, gradient_penalty, c_lambda):crit_loss torch.mean(crit_fake_pred)- torch.mean(crit_real_pred) c_lambda* gradient_penaltyreturn crit_loss让我们分解一下代码 4 中所示的损失函数。 生成器损失 - 生成器损失不受梯度惩罚的影响。因此它必须仅最大化 D(x_CURL)/ D(G(z)) 项这意味着最小化 -D(G(z))。这是在第 2 行中实现的。批评者损失 - 批评者损失包含等式 1 中所示损失的 2 个部分。在第 6 行中前两项给出等式 2 中解释的原始批评者损失而最后一项给出等式 3 中解释的梯度惩罚。 梯度惩罚可以按照下面的代码 5 来实现 - [1]。 def get_gradient(crit, real_imgs, fake_imgs, epsilon):mixed_imgs real_imgs* epsilon fake_imgs*(1- epsilon)mixed_scores crit(mixed_imgs)gradient torch.autograd.grad(outputs mixed_scores,inputs mixed_imgs,grad_outputs torch.ones_like(mixed_scores),create_graphTrue,retain_graphTrue)[0]return gradientdef gradient_penalty(gradient):gradient gradient.view(len(gradient), -1)gradient_norm gradient.norm(2, dim1)penalty torch.nn.MSELoss()(gradient_norm, torch.ones_like(gradient_norm))return penalty在代码 5 中get_gradient()函数返回从x_hat 混合图像开始到Critic 输出 (mixed_scores)结束的所有网络梯度。这将在gradient_penalty()函数中使用它返回Critic梯度的1和L2范数之间的均方距离。 减少 Critic 的损失最终会减少这种梯度惩罚。这确保了 Critic 函数保留了 1-Lipschitz 连续性。 训练 训练将与上一篇文章中的几乎相同。但这里的损失与传统的 GAN 损失不同。我已经使用WANDB记录我的结果。如果您有兴趣记录结果WANDB 是一个非常好的工具。 CCritic(img_channels,hidden_C).to(device) GGenerator(noise_channels,img_channels,hidden_G).to(device)#CC.apply(init_weights) #GG.apply(init_weights)wandb.watch(G, logall, log_freq10) wandb.watch(C, logall, log_freq10)opt_Ctorch.optim.Adam(C.parameters(),lrlr, betas(0.5,0.999)) opt_Gtorch.optim.Adam(G.parameters(),lrlr, betas(0.5,0.999))gen_repeats1 crit_repeats3noise_for_generatetorch.randn(batch_size,noise_channels,1,1).to(device)losses_C[] losses_G[]for epoch in range(1,epochs1):loss_C_epoch[]loss_G_epoch[]for idx,(x,_) in enumerate(data_loader):C.train()G.train()xx.to(device)x_lenx.shape[0]### Train Closs_C_iter0for _ in range(crit_repeats):opt_C.zero_grad()ztorch.randn(x_len,noise_channels,1,1).to(device)real_imgsxfake_imgsG(z).detach()real_C_outC(real_imgs)fake_C_outC(fake_imgs)epsilon torch.rand(len(x),1,1,1, device device, requires_gradTrue)gradient get_gradient(C, real_imgs, fake_imgs.detach(), epsilon)gp gradient_penalty(gradient)loss_C get_crit_loss(fake_C_out, real_C_out, gp, c_lambda10)loss_C.backward()opt_C.step()loss_C_iterloss_C.item()/crit_repeats### Train Gloss_G_iter0for _ in range(gen_repeats):opt_G.zero_grad()ztorch.randn(x_len,noise_channels,1,1).to(device)fake_C_out C(G(z))loss_G get_gen_loss(fake_C_out)loss_G.backward()opt_G.step()loss_G_iterloss_G.item()/gen_repeats结果 这是经过 10 个 epoch 训练后获得的结果。与传统 GAN 一样生成的图像随着时间的推移变得更加真实。WANDB 项目的所有结果都可以在这里找到。 结论 生成对抗网络一直是深度学习社区的热门话题。由于 GAN 传统训练方法的缺点WGAN 随着时间的推移变得越来越流行。这主要是因为它对模式崩溃具有鲁棒性并且不存在梯度消失问题。在本文中我们实现了一个能够生成人脸的简单 WGAN 模型。 请随意查看 GitHub 代码。如有任何意见、建议和意见我们将不胜感激。 Reference [1] GAN specialization on coursera [2] Arjovsky, Martin et al. “Wasserstein GAN” [3] Gulrajani, Ishaan et al. “Improved Training of Wasserstein GANs” [4] Goodfellow, Ian et al. “Generative Adversarial Networks” [5] Vincent Herrmann, “Wasserstein GAN and the Kantorovich-Rubinstein Duality” [6] Karras, Tero et al. “A Style-Based Generator Architecture for Generative Adversarial Networks” 本文译自Udith Haputhanthri的博文。
http://www.zqtcl.cn/news/280946/

相关文章:

  • 阿里云域名怎么做网站对网站进行seo优化
  • 响应式网站建设合同11月将现新冠感染高峰
  • 做网站客户一般会问什么问题百度云网盘资源分享网站
  • 网站设计中超链接怎么做艺术设计
  • 卡盟网站建设wordpress优化代码
  • 做网站需要什么技术员商城型网站开发网站建设
  • discuz做地方门户网站网站大全免费完整版
  • 莆田人做的网站一天赚2000加微信
  • 阿里云网站访问不了怎么办做网站二维码
  • 手机商城网站建设可采用的基本方式有
  • 网站备案管理做广告公司网站建设价格
  • 绵阳专业网站建设公司上海外贸公司排名榜
  • 如何做英文系统下载网站快速排名工具免费
  • 苏州建网站必去苏州聚尚网络网页视频提取在线工具
  • 网站建设服务市场分析百度集团
  • 网站怎么企业备案信息做网站业务员如何跟客户沟通
  • 如何网站推广知名的集团门户网站建设费用
  • 网站入口设计规范专门做喷涂设备的网站
  • 最简单网站开发软件有哪些企业管理培训课程培训机构
  • 桂城网站制作公司wordpress 导航网站
  • 一个公司做网站需要注意什么条件网站备案 登陆
  • 百度网站介绍显示图片装修公司一般多少钱一平方
  • 网站销售如何做业绩我找伟宏篷布我做的事ko家的网站
  • 建立网站有哪些步骤?jsp网站开发详细教程
  • 网站怎么做直播功能旅游做攻略用什么网站
  • 企业外贸营销型网站如何写好软文推广
  • 免费建站的网址个人网站建设程序设计
  • 淘宝网站建设违规吗上海大公司
  • 大淘客怎么自己做网站自己开网站能赚钱吗
  • 大型门户网站开发北京网站建设管庄