GAN系列 您所在的位置:网站首页 pytorch源码多少行 GAN系列

GAN系列

2023-05-14 11:34| 来源: 网络整理| 查看: 265

2014年,Ian Goodfellow和他在University of Montreal 的同事们发表了一篇令人惊叹的论文,向世界介绍了对抗生成网络, GANs。通过计算图和博弈论的创新组合,他们表明,如果有足够的建模能力,两个相互对抗的模型将能够通过普通的传统反向传播进行共同训练。

模型扮演了两个截然不同的对抗的角色。给定一些实际数据集R,G是生成器,试图创建看起来就像真实数据的假数据,而D是鉴别器,从真实集合或G中获取数据并标记差异。Goodfellow的比喻(也是一个很好的比喻)是,G就像一群伪造者,试图将真实的绘画与他们的输出(伪造结果)相匹配,而D则是试图分辨出来的侦探团队。(除非在这种情况下,伪造者G永远不会看到原始数据,这样只有鉴别器D的判断。G就像盲目的伪造者。)

在理想的情况下,D和G都会随着对抗时间的推移而变得更好,直到G基本上成为真正文章的“主要伪造者”,并且D at a loss,“无法区分这两种分布”。

在实践中,Goodfellow所展示的是G将能够对原始数据集执行一种无监督学习形式,找到某种方式以(可能)低得多的方式表示该数据。正如Yann LeCun所说,无监督学习是真正人工智能的“蛋糕”。

这种强大的技术似乎只需要成吨的代码才能开始,对吧?不。使用PyTorch,我们实际上可以在50行代码中创建一个非常简单的GAN。实际上只有5个组件需要考虑:

R:原始的真实数据集I:作为熵源进入生成器Generator的随机噪声G:尝试复制/模仿原始数据集的生成器D:试图分辨生成器G的输出和真实数据R的鉴别器实际的“训练”的循环,使得生成器G能够骗过鉴别器D,同时D能提防G。

R:在我们的例子中,我们将从最简单的R - a bell curve开始。该函数采用均值和标准差,并返回一个函数,该函数从高斯函数中提供具有这些参数的样本数据的正确形状。在我们的示例代码中,我们将使用4.0的平均值和1.25的标准差。I:对生成器的输入也是随机的,但是为了使我们的工作更加困难,让我们使用uniform distribution而不是normal distribution。这意味着我们的模型G不能简单地将输入移位/缩放到复制R,而是必须以非线性方式重塑数据。 G:生成器是标准的前馈图feedforward graph - 两个隐藏层,三个linear maps。我们正在使用双曲线切线hyperbolic tangent 激活函数。G 从I 中得到uniformly distribution的数据样本,以某种方式模仿normally distribution的样本R,而不用知道R是什么。D:鉴别器D和生成器D的代码非常相似;:包含两个隐藏层和三个linear maps的前馈图。这里的激活函数是一个sigmoid函数 。它将从真实数据集R或生成器G的输出中获取样本,并将输出0到1之间的单个标量,解释为“假”与“真实”。换句话说,这大约是最基本的神经网络可以做到的。最后,训练循环在两种模式之间交替:首先在真实数据和假数据上训练鉴别器D,通过使用准确标签(就像是警察学校~); 然后训练生成器G骗过D,使用不准确标签(这更像是来自Ocean's Eleven十一罗汉的骗术)。

即使你之前没有见过PyTorch,也可以理解以上的代码代表的意思。在第一个(绿色)部分,我们向鉴别器D送入两种类型的数据,并对D的猜测与实际标签,应用损失计算differentiable criterion。这个喂数据是“forward”的一步; 然后我们显式调用'backward()'来计算梯度gradients,然后在d_optimizer step()调用中用它来更新鉴别器D参数。生成器G被使用但未在此训练。

然后在最后一个(红色)部分,我们对生成器G做同样的事情- 注意我们也通过鉴别器D(实际上是给伪造者一个侦探来练习)运行生成器G的输出,但是在此步中,我们不优化或改变鉴别器D。因为不希望“侦探”鉴别器D学习错误的标签。因此,我们只调用g_optimizer.step()。

如上所述…… 就是这样。还有一些其他样板代码,但GAN特定的东西只是那5个组件。

经过几千轮D和G之间的这种对抗,可以得到什么?鉴别器D很快就会好起来(当生成器G慢慢向上移动时),但是一旦达到某种程度的力量,G就会有一个有价值的对手并开始改善。

超过5,000轮训练,训练鉴别器D 20次,然后生成器G每轮20次,生成器G的输出平均值超过4.0,但随后回到相当稳定,正确的范围(如下图左图)。同样,标准偏差最初在错误的方向上下降,但随后上升到所需的1.25范围(如下图右图),与真实数据集R匹配。

所以基本的数据最终与真实数据R匹配。更好的情形是怎么样?分布的形状看起来是否合适?毕竟,当然可以uniform distribution,平均值为4.0,标准偏差为1.25,但这与真实数据集R不匹配。让我们看看生成器G的最终分布:

结果并不差。右尾比左边有点胖,但是,歪斜和峰度是原始高斯分布的模样。

生成器G几乎完美地恢复原始数据的分布R - 并且鉴别器D在角落里畏缩,对自己喃喃自语,无法从中说出事实。这正是我们想要的行为(参见Goodfellow论文的图1)。这里总共才不到50行代码。

重要提醒:GAN很挑剔,而且很脆弱。当它们进入奇怪的状态时,往往不会在没有一点哄骗的情况下出来。运行示例代码十次(每次超过5,000轮)显示以下十个分布:

如上图所示,十次运行中的八次导致相当不错的最终分布 - 类似于高斯分布。但其中两次没有生成这样的分布。在一种情况下(第5次运行),有一个凹面分布,平均值约为6.0,在最后一次运行中(#10),在-11处有一个狭窄的峰值!当您开始在几乎任何环境中应用GAN时,您会看到这种现象 - GAN并不像平均监督学习工作流程那样稳定。但是当它们工作时,它们看起来几乎是神奇的。

完整代码链接:

使用50行PyTorch代码构建对抗生成网络 (GANs)



【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有