网站正能量免费下载,手机制作报价单app,专业做网站广州,汽车网站大全汽车网阶段1#xff1a;GAN是个啥#xff1f;
生成式对抗网络#xff08;Generative Adversarial Networks, GAN#xff09;#xff0c;名字听着就有点“对抗”的意思#xff0c;没错#xff01;它其实是两个神经网络互相斗智斗勇的游戏#xff1a;
生成器#xff08;Gene…阶段1GAN是个啥
生成式对抗网络Generative Adversarial Networks, GAN名字听着就有点“对抗”的意思没错它其实是两个神经网络互相斗智斗勇的游戏
生成器Generator负责造假生成一些以假乱真的数据。判别器Discriminator负责打假判断数据是真还是假。
想象一下生成器是个假币制造商判别器是个验钞机。假币制造商不断提升造假能力验钞机也不断升级打假技巧。最终的目标是生成的假币足以以假乱真让验钞机无法区分。
生成式对抗网络GAN是一种由 Ian Goodfellow 和他的团队在2014年提出的深度学习模型。GAN 本质上是一种用于生成与真实数据分布相似的“新数据”的方法常用于图像生成、风格转换和数据增强等任务。
一、GAN 的基本概念
1. 两个网络生成器Generator和判别器Discriminator
GAN 的核心思想是利用两个神经网络相互对抗
生成器 (G) 学习生成接近真实数据的“假数据”。其目标是“骗过”判别器使其认为假数据是真的。判别器 (D) 学习区分真实数据和生成器生成的假数据。其目标是提高“识别假数据的能力”。
两者形成了一种动态博弈
生成器不断改进以生成更逼真的数据。判别器不断改进以更准确地区分真假数据。
最终目标生成器生成的数据和真实数据难以区分判别器无法给出明确的判断。
2. 训练目标
GAN 的训练目标可以通过以下损失函数来描述
判别器的损失最大化真实数据的得分最小化假数据的得分。生成器的损失最小化判别器对假数据的判断分数即尽量骗过判别器。
数学公式为 这里
D(x)D(x) 表示判别器给真实数据 xx 的打分。G(z)G(z) 表示生成器根据随机噪声 zz 生成的假数据。
3. GAN 的对抗过程
训练过程通常分为两步
更新判别器 让判别器学习如何区分真实和假数据。更新生成器 让生成器学习生成更真实的数据以骗过判别器。 二、直观例子警察与造假者
你可以将 GAN 的训练过程类比为“警察判别器与造假者生成器”之间的较量
一开始造假者技术拙劣警察很容易识破假币。随着时间推移造假者的造假技术逐渐提高而警察也在不断升级检测手段。最终假币与真币变得极为相似警察几乎无法分辨。 阶段2从头写个最简单的GAN
import torch
import torch.nn as nn
import torch.optim as optim# 1. 生成器Generator简单的全连接网络
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.model nn.Sequential(nn.Linear(100, 256), # 输入 100 维噪声输出中间隐藏层 256 维nn.ReLU(), # 激活函数 ReLU增加非线性nn.Linear(256, 784), # 隐藏层输出 784 维数据28x28 图像展平后nn.Tanh() # 将输出限制到 [-1, 1]方便后续训练)def forward(self, z):return self.model(z)# 2. 判别器Discriminator另一个简单的全连接网络
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model nn.Sequential(nn.Linear(784, 256), # 输入 784 维数据展平的图像nn.LeakyReLU(0.2), # 激活函数允许小负值更鲁棒nn.Linear(256, 1), # 输出一个值真 or 假nn.Sigmoid() # 输出概率范围 [0, 1])def forward(self, x):return self.model(x)# 初始化网络
G Generator()
D Discriminator()# 优化器
optimizer_G optim.Adam(G.parameters(), lr0.0002)
optimizer_D optim.Adam(D.parameters(), lr0.0002)# 损失函数二分类交叉熵
criterion nn.BCELoss()代码释疑
这段代码实现了生成式对抗网络GAN的生成器Generator和判别器Discriminator并为它们设置了优化器和损失函数。以下是对相关内容的详细解释帮助你理解各个部分的功能。 1. Generator 类生成器
生成器的作用是生成假数据用来骗过判别器。
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.model nn.Sequential(nn.Linear(100, 256), # 输入 100 维噪声输出中间隐藏层 256 维nn.ReLU(), # 激活函数 ReLU增加非线性nn.Linear(256, 784), # 隐藏层输出 784 维数据28x28 图像展平后nn.Tanh() # 将输出限制到 [-1, 1]方便后续训练)def forward(self, z):return self.model(z)关键点 输入 生成器的输入是一个随机噪声 z形状为 [batch_size, 100]。噪声是生成器的起点让它从随机性中学习目标数据分布。 输出 输出 784 个值对应一张 28x28 的图像展平如 MNIST 数据。使用 Tanh 将输出限制在 [-1, 1] 区间通常是为了和真实数据的归一化范围一致。 2. Discriminator 类判别器
判别器的作用是判断输入数据是真实的还是生成的。
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model nn.Sequential(nn.Linear(784, 256), # 输入 784 维数据展平的图像nn.LeakyReLU(0.2), # 激活函数允许小负值更鲁棒nn.Linear(256, 1), # 输出一个值真 or 假nn.Sigmoid() # 输出概率范围 [0, 1])def forward(self, x):return self.model(x)关键点 输入 输入是展平的 28x28 图像784 维可以是真实数据或生成器的假数据。 输出 输出是一个概率值0 表示假1 表示真。使用 Sigmoid 将值映射到 [0, 1] 区间。 LeakyReLU 激活函数 LeakyReLU(0.2) 在输入为负值时保留一定斜率0.2解决 ReLU 的“死区”问题使训练更稳定。 3. 优化器
优化器用于更新模型的参数使损失函数逐渐减小。
optimizer_G optim.Adam(G.parameters(), lr0.0002)
optimizer_D optim.Adam(D.parameters(), lr0.0002)Adam 优化器 一种改进的梯度下降算法适用于深度学习模型尤其是 GAN。自动调整学习率提高收敛速度。 学习率 (lr0.0002) 学习率设置为 0.0002是 GAN 训练中一个常见的经验值。 目标 optimizer_G 优化生成器的参数使其生成更逼真的数据。optimizer_D 优化判别器的参数使其更好地区分真假数据。 4. 损失函数BCELoss
BCELoss 是二分类交叉熵损失函数用于计算判别器和生成器的损失。
criterion nn.BCELoss()什么是交叉熵
交叉熵是一种用来衡量两个概率分布相似度的损失函数公式如下 yiy_i真实标签1 表示真0 表示假。pip_i模型预测的概率值判别器的输出。
在 GAN 中的作用 判别器的损失 判别器的目标是区分真实数据和生成器生成的假数据。对于真实数据y 1对于假数据y 0。损失函数让判别器尽量输出接近真实标签的概率。 生成器的损失 生成器的目标是让判别器认为假数据是真实的。生成器通过 GAN 的损失函数间接影响判别器的输出目标是让判别器输出 y 1。 5. 上述代码小结
生成器 (G) 学习生成逼真的假数据。判别器 (D) 学习区分真实数据和假数据。损失函数 (BCELoss) 衡量模型输出概率和目标标签之间的差异。优化器 (Adam) 调整模型参数使损失函数最小化。
在训练过程中
生成器试图最小化生成器的损失。判别器试图最大化判别器的准确率。
这段代码是 GAN 的基础骨架你可以在此基础上进行实验比如用它来生成 MNIST 图像
题外话
PyTorch简称 torch是一个流行的开源深度学习框架它提供了许多用于构建和训练神经网络的功能。它特别以易用性、灵活性和性能而著称是机器学习和深度学习领域的常用工具之一。下面我们来了解一下 PyTorch 的作用以及在这段 GAN 代码中它是如何发挥作用的。
1. PyTorch 的基本功能
PyTorch 提供了以下几个关键功能 张量Tensor PyTorch 中的核心数据结构是张量torch.Tensor类似于 NumPy 的数组但是张量支持 GPU 加速。张量是神经网络中的数据载体存储输入数据、权重、偏置等。 自动求导Autograd PyTorch 提供自动求导功能能够计算神经网络中每一层的梯度简化了反向传播算法的实现。当你定义模型并传入数据后PyTorch 会自动计算损失函数的梯度并更新模型的参数。 构建和训练神经网络 使用 torch.nn 提供的模块可以方便地构建神经网络的各层如全连接层、卷积层、激活函数等。torch.optim 提供了优化算法如 SGD、Adam来训练模型。 GPU 加速 PyTorch 可以利用 GPU如 CUDA来加速计算。你可以将张量和模型移动到 GPU 上这样就能提高训练速度。 阶段3它们怎么斗起来
核心是两步
训练判别器真图片标为1假图片标为0看看它能不能区分真伪。训练生成器假图片骗过判别器努力让判别器给它打1分。
from torchvision import datasets, transforms
from torch.utils.data import DataLoader# 数据加载MNIST 数据集
transform transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
mnist datasets.MNIST(root./data, trainTrue, downloadTrue, transformtransform)
dataloader DataLoader(mnist, batch_size64, shuffleTrue)# 训练循环
epochs 10
for epoch in range(epochs):for i, (real_imgs, _) in enumerate(dataloader):# 1. 训练判别器 # 真数据real_imgs real_imgs.view(real_imgs.size(0), -1) # 展平图片real_labels torch.ones(real_imgs.size(0), 1) # 真图片标签为1# 假数据z torch.randn(real_imgs.size(0), 100) # 随机噪声fake_imgs G(z)fake_labels torch.zeros(real_imgs.size(0), 1) # 假图片标签为0# 判别器的预测和损失real_preds D(real_imgs)fake_preds D(fake_imgs.detach()) # 假图片不更新生成器loss_real criterion(real_preds, real_labels)loss_fake criterion(fake_preds, fake_labels)loss_D loss_real loss_fake# 优化判别器optimizer_D.zero_grad()loss_D.backward()optimizer_D.step()# 2. 训练生成器 z torch.randn(real_imgs.size(0), 100)fake_imgs G(z)fake_preds D(fake_imgs)loss_G criterion(fake_preds, real_labels) # 欺骗判别器的损失# 优化生成器optimizer_G.zero_grad()loss_G.backward()optimizer_G.step()# 打印进度if i % 200 0:print(fEpoch [{epoch1}/{epochs}], Step [{i}/{len(dataloader)}], fD Loss: {loss_D.item():.4f}, G Loss: {loss_G.item():.4f})这段代码实现了一个基本的 生成式对抗网络GAN 训练过程使用 MNIST 数据集 生成与真实手写数字类似的图像。执行这段代码会产生以下几个结果
1. 数据加载MNIST 数据集
首先代码通过 torchvision 中的 datasets.MNIST 加载了 MNIST 数据集。这个数据集包含了 60,000 张手写数字的训练图像和 10,000 张测试图像这里只使用了训练集。数据被转换为 PyTorch 张量并做了标准化处理使每个像素值在 [-1, 1] 之间。然后DataLoader 将数据划分为批次batch每次加载 64 张图像。
2. 训练循环
接下来代码进入训练循环在每个 epoch 中它会进行以下操作
1训练判别器Discriminator 真数据 从 MNIST 数据集中提取实际的手写数字图像将图像展平为 784 维28x28 的像素展平。创建真实标签所有真实图像的标签为 1。 假数据 从随机噪声 z100 维的向量中生成假图像。创建假的标签所有生成的假图像标签为 0。 判别器损失 判别器会分别计算它对真实数据和假数据的预测使用二元交叉熵损失 BCELoss 计算真实数据和假数据的损失。loss_real 是判别器对真实图像的损失loss_fake 是对假图像的损失最终判别器的总损失是两者之和 loss_D。 优化判别器 使用 optimizer_D.zero_grad() 清除先前的梯度进行反向传播并更新判别器的参数。
2训练生成器Generator
生成假图像 使用随机噪声 z 通过生成器生成一批假图像。生成器损失 生成器的目标是欺骗判别器让它认为生成的假图像是真实的。因此生成器的损失是判别器对这些假图像的预测希望是 1的损失即 loss_G。优化生成器 使用 optimizer_G.zero_grad() 清除先前的梯度进行反向传播并更新生成器的参数。
3. 打印进度
每训练 200 个批次代码会打印出当前 epoch 和 step 的进度并显示判别器和生成器的损失
Epoch [1/10], Step [0/938], D Loss: 0.6881, G Loss: 0.7014
Epoch [1/10], Step [200/938], D Loss: 0.6834, G Loss: 0.7102
...实际运行效果 执行结果 训练输出 在训练过程中随着生成器和判别器的不断优化你会看到输出的 D Loss判别器损失和 G Loss生成器损失。初始时这两个损失通常较大因为模型还没有学会如何生成和判断图像。随着训练的进行损失会逐渐减小表示生成器和判别器在相互博弈中逐渐变得更强。 图像生成 由于 GAN 的训练是一个对抗过程因此每个 epoch 训练后生成器的输出图像会逐渐接近真实图像的分布。生成器在训练中会变得越来越善于生成逼真的手写数字图像直到它能够生成看起来很像 MNIST 数据集中的真实数字。
小结
判别器学习区分真实和假图像给出图像是“真”还是“假”的概率。生成器学习生成越来越像真实手写数字的图像目的是“欺骗”判别器使判别器认为生成的假图像是真实的。
执行完这段代码后生成器G会经过 10 个 epoch 的训练逐步学会生成类似 MNIST 手写数字的图像。你可以根据损失值的变化和生成的图像的质量观察训练过程的进展。 阶段4GAN生成的图像是啥样
每训练一段时间我们让生成器画个画看看它有没有长进
import matplotlib.pyplot as pltdef show_images(generator, num_images16):z torch.randn(num_images, 100) # 随机噪声fake_imgs generator(z).view(num_images, 1, 28, 28) # 恢复图片形状fake_imgs (fake_imgs 1) / 2.0 # 把值范围从 [-1, 1] 变到 [0, 1]grid torch.cat([fake_imgs[i] for i in range(num_images)], dim2).squeeze(0)plt.imshow(grid.detach().numpy(), cmapgray)plt.axis(off) # 不显示坐标轴plt.savefig(generated_images.png, bbox_inchestight) # 保存图像到文件plt.close() # 关闭图形窗口 这是最终生成地图像 局部放大 是不是可以联想到生成式对抗网络的应用场景相当广泛比如半导体晶圆缺陷检测领域医学影像疾病识别领域等等。 阶段5GAN训练的问题
GAN不是一帆风顺的训练GAN像哄熊孩子生成器和判别器常常互相欺负对方导致训练不稳定。 怎么办我们可以尝试改进
改网络结构比如用更强大的卷积网络。改损失函数比如使用Wasserstein GAN。调参改动学习率、优化器等等。
这就是生成式对抗网络的基础啦希望它的斗智斗勇能让你觉得有趣你也可以试试用它生成其他类型的数据比如音乐、画作或者文字