算法描述
生成对抗网络(Generative Adversarial Nets)模型中的两位博弈方分别有生成网络(Generator)与判别网络(Discriminator)充当。当生成网络G捕捉到样本数据分布,用服从某一分布的噪声z生成一个类似真实训练数据的样本,与真实样本越接近越好;判别网络D一般是一个二分类模型,在本文中D是一个多分类器,用于估计一个样本来自于真实数据的概率,如果样本来自于真实数据,则D输出大概率,否则输出小概率。本文中,判别网络需要在此基础上实现分类功能。
在训练的过程中,需要固定一方,更新另一方的网络状态,如此交替进行。在整个训练的过程中,双方都极力优化自己的网络,从而形成竞争对抗,知道双方达到一个动态的平衡。此时生成网络训练出来的数据与真实数据的分布几乎相同,判别网络也无法再判断出真伪。
本文中生成对抗网络主要分为两部分,生成网络(Generator)与判别网络(Discriminator)。向生成网络内输入噪声,通过多次反卷积的方式得到一个28x28x1的图像作为X_fake,此时将真实的图像X_real与生成器生成的X_fake放入判别网络,判别网络使用多次卷积与Sigmoid函数并通过交叉熵函数计算出判别网络的损失函数D_loss,通过判别网络的损失函数D_loss计算得到生成网络损失函数G_loss。使用G_loss与D_loss对生成网络与判别网络进行参数调整。

算法流程
1.输入噪声z
2.通过生成网络G得到X_fake=G(z)
3.从数据集中获取真实数据X_real
4.通过判别网络D计算D(real logits)=D(X_real)
5.通过判别网络D计算D(fake logits)=D(X_fake)
6.使用交叉熵函数做损失函数根据D(real logits)计算D(loss real)
7.使用交叉熵函数做损失函数根据D(fake logits)计算D(loss fak

2558

被折叠的 条评论
为什么被折叠?



