WGAN的原理及实现

Posted by nickiwei on October 1, 2017

欢迎转载, 转载请注明出处及链接。

完整代码库请查看我的GithubRepo: https://github.com/nick6918/MyDeepLearning .部分代码参考了Stanford CS231n 课程作业。

DCGAN的问题

传统的DCGAN,利用一个卷积网络来拟合指定类型图像的生成分布。在保证相当优秀的performance的基础上, DCGAN仍存在以下问题:

1, 不稳定

DCGAN的核心问题在于:判别器训练得太好,生成器梯度消失,生成器loss降不下去;判别器训练得不好,生成器梯度不准,四处乱跑。

原论文中详细分析了为什么会出现这样的情况, 具体可查看论文, 总结来说, 在(近似)最优判别器下,最小化生成器的loss等价于最小化Pr与Pg之间的JS散度,而由于Pr与Pg几乎不可能有不可忽略的重叠,所以无论它们相距多远JS散度都是常数log 2,最终导致生成器的梯度(近似)为0,梯度消失。

2, 难于调优

由于DCGAN的D_Loss和G_Loss均不能有效反映出训练进程, 需要通过对若干结果进行枚举选择最好的一组训练参数。

3, 多样性缺乏

优化生成器loss函数,会等价于最小化一个不合理的距离衡量,导致两个问题,一是梯度不稳定,二是collapse mode即多样性不足。

WGAN的优势

与DCGAN相对应, WGAN可以说成功解决了上述多个问题:

1, 提高了模型的鲁棒性

彻底解决GAN训练不稳定的问题,不再需要小心平衡生成器和判别器的训练程度

2, 样本多样性

基本解决了collapse mode的问题,确保了生成样本的多样性

3, 良好的训练指标

训练过程中终于有一个像交叉熵、准确率这样的数值来指示训练的进程,这个数值越小代表GAN训练得越好,代表生成器产生的图像质量越高(如题图所示)

3, 并不要求额外的精心的拓扑结构设计

以上一切好处不需要精心设计的网络架构,最简单的多层全连接网络就可以做到

Before WGAN

在WGAN第一篇论文中, 作者提出了加噪方案, 具体来说, 就是给生成样本和训练样本同步加噪, 强行让它们产生不可忽略的重叠。而一旦存在重叠,JS散度就能真正发挥作用,此时如果两个分布越靠近,它们“弥散”出来的部分重叠得越多,JS散度也会越小而不会一直是一个常数,于是(在第一种原始GAN形式下)梯度消失的问题就解决了。

这种方案下, 我们可以放心把分类器调至最优, 然后再来训练生成器, 大大简化了GAN的训练, 但是, 我们仍然没有一个好的训练指标。

WGAN距离

如前所述, 我们可以把原始的GAN的距离看成是训练分布和生成分布的JS散度, 正是由于JS散度容易坍塌到log2常数(分布间不具有不可忽略的重叠部分), 才导致梯度消失。加噪方案是强行让二者出现重叠, 便于训练; WGAN距离则是从根本上替换掉难于训练的JS散度。

WGAN

我们的目标转化为优化Pr, Pg使得W最小。 经过化简(具体查看论文), 这一公式转化为

WGAN

其中, fw是一个数学上可求的复杂距离公式, 我们希望再用一个神经网络来拟合他。 这是WGAN最精髓的地方, 我们希望用一个神经网络来拟合分布之间的距离, 然后再用一个神经网络来优化这个神经网络构成的距离。

注意, 我们所做的这些的核心都在于, 将G(z)视为常量的情况下, 去优化Disciminator. 具体算法如下:

WGAN

对比GAN的训练算法

WGAN

可见, 之前我们训练GAN时, 每个iteration训练k次Disciminator(通常为1), 1次Generator. 现在, 每个iteration,我们训练 Ncritic 次Discriminator, 且之前对discrimator的训练为沿D_loss下降, 现在我们对D_loss本身也需要梯度下降。

WGAN

我们将w沿着梯度更新后的值作为梯度去更新w。

最终的实现如下:

def wgangp_loss(logits_real, logits_fake, batch_size, x, G_sample):
    """Compute the WGAN-GP loss."""
    
    D_loss = - tf.reduce_mean(logits_real) + tf.reduce_mean(logits_fake)
    G_loss = - tf.reduce_mean(logits_fake)

    # lambda from the paper
    lam = 10
    
    # random sample of batch_size (tf.random_uniform)
    eps = tf.random_uniform([batch_size,1], minval=0.0, maxval=1.0)
    x_hat = eps*x+(1-eps)*G_sample
    #diff = G_sample - x
    #interp = x + (eps * diff)
    
    # Gradients of Gradients is kind of tricky!
    with tf.variable_scope('',reuse=True) as scope:
        grad_D_x_hat = tf.gradients(discriminator(x_hat), x_hat)
    
    grad_norm = tf.norm(grad_D_x_hat[0], axis=1, ord='euclidean')
    grad_pen = tf.reduce_mean(tf.square(grad_norm-1))
    #slopes = tf.sqrt(tf.reduce_sum(tf.square(grad_D_x_hat), reduction_indices=[1]))
    #grad_pen = tf.reduce_mean((slopes - 1.) ** 2)
       
    D_loss += lam*grad_pen

    return D_loss, G_loss

除此之外, WGAN还需要注意以下几点:

loss中

1, D_loss的更新不能使用带moment的方法, 如adam.

2, loss计算不取sigmoid.

training中,

1, 判别器参数更新后clip到c内。

2, 判别器最后一层不取sigmoid.

小结

WGAN的数学原理确实比较难于理解, 但是我们只要记住WGAN的核心思想是将距离公式本身也用神经网络拟合并用梯度下降, 具体来说, 就是D_loss本身在用于更新w之前, 本身也需要梯度上升, 具体的上升值为D(x_hat)的梯度的norm+mean.


快速联系作者

欢迎关注我的知乎: https://www.zhihu.com/people/NickWey

或直接在Github上联系我: https://github.com/nick6918