论文笔记2-TransGAN:-Two Pure Transformers Can Make One Strong GAN, and That Can Scale Up

TransGAN

       文章开篇的模型命名缩写就明示了该模型的构成,即又突然说form构成的Gan网络,这里我们对其中相关前置知识会简单提及,包括Gan训练的一些改进和transform的结构更改。但是关于Gan和Transformer的基础知识我们不会过多提及,网络上有大量的帖子,肯定理解和阐述都比我清楚的多。关于tf,这里推荐一下帖子The Annotated Transformer

contributions

       作者的贡献自己总结为三点:

1). 第一个由完全非卷积构成的Gan网络模型,调整了attention机制细节,平衡了内存,全局特征和空间差异。

2). 一些训练角度的小trick。

3). STL-10和CIFAR-10等数据集达到SOTA效果。

       在这里不仅介绍一些文中用到的,还会介绍一下在Transformer和Gan领域比较有名的同时也是文中提到的相关方法。

Wasserstein GAN

       众所周知,Gan和强化学习都是出了名的难训练。从14年被提出开始,Gan一直有着众多问题,比如训练困难、生成器和判别器的loss无法指示训练进程、生成样本缺乏多样性等。DCGAN依靠枚举搜索更好的架构,没有解决问题,而是避开差的结果。作者在两篇论文里——第一篇《Towards Principled Methods for Training Generative Adversarial Networks》从数学理论上分析了原始GAN的问题所在,从而针对性地给出了改进要点;第二篇《Wasserstein GAN》里面,又再从这个改进点出发推了公式定理,最终给出了改进的算法实现流程。具体的数学推导我实话实说搞不太懂,有兴趣可以自己去搜一下。

       长话短说,原始GAN问题的根源可以归结为两点,一是等价优化的距离衡量(KL散度、JS散度)不合理,二是生成器随机初始化后的生成分布很难与真实分布有不可忽略的重叠。最后对原始模型修改如下:判别器去掉sigmod,loss不取log,更新判别器参数后绝对值截断,淘汰使用动量的优化算法。文中引入了Wasserstein距离,由于它相对KL散度与JS散度具有优越的平滑特性,理论上可以解决梯度消失问题。接着通过数学变换将Wasserstein距离写成可求解的形式,利用一个参数数值范围受限的判别器神经网络来最大化这个形式,就可以近似Wasserstein距离。在此近似最优判别器下优化生成器使得Wasserstein距离缩小,就能有效拉近生成分布与真实分布。WGAN既解决了训练不稳定的问题,也提供了一个可靠的训练进程指标,而且该指标确实与生成样本的质量高度相关。

Style-Based generator

       典型的风格迁移工作,在判别器中,本文提出了感知路径长度(perceptual path length )和线性可分性(linear separability)来评估生成器。通过这种评估方法发现本文比传统的生成网络允许更加线性的、耦合性更低的变量因子。最后还提出了一个新人脸数据集,这里略过。



Progressive Training

       核心思想是逐步训练生成器和分别器:从低分辨率开始,随着训练进程推进,逐步增加新的层来提炼细节。这种方法不仅加快了训练速度并且更加稳定,可以产生高质量的图像。同时提出了一些实施的细节对于消除生成器和分辨器的不好的竞争,也就是训练的小trick。

       这种递进训练的方法保存了图像粗粒度的特征同时用残差结构和结合新训练的细粒度特征,最后用线性叠加整合,小尺寸用插值法直接增大。另外,本文在提高生产图片多样性方面,人工的在网络加入一些信息,如方差或者标准差。在应对信号幅值过大的问题上,还提出了pixel-normalization,抑制信号的幅值。

Spectral Normalization

       上面提到过的Wasserstein-Gan从每层神经网络的参数矩阵的谱范数角度,引入利普希茨连续性约束,使神经网络对输入扰动具有较好的非敏感性,从而使训练过程更稳定,更容易收敛。比如:深度学习模型存在“对抗攻击样本”,比如图片只改变一个像素就给出完全不一样的分类结果,这就是模型对输入过于敏感的案例。也就是说,在局部最小点附件,一点小小的变动将产生较大的影响,导致泛化性能不好。

       而本文即在上述工作基础上的延续,Spectral normalization for generative adversarial network” (以下简称 Spectral Norm) 使用一种更优雅的方式使得判别器 D 满足利普希茨连续性,限制了函数变化的剧烈程度,从而使模型更稳定。

20211201200248

       在一维空间中,很容易看出 y=sin(x) 是 1-Lipschitz的,它的最大斜率是 1。对于一个矩阵A,除以它的 spectral norm(即A^tA最大特征值的开根号)可以使其具有 1-Lipschitz continuity。(证明参考here)

       对于整个网络,由于激活函数通常都是满足 1-Lipschitz的,所以只需要保证卷积部分也满足,就可以推至整个网络。这里需要对各层的卷积核W除以其最大奇异值,但是由于每个迭代对层做SVD分解计算量过于离谱,文中使用了power iteration 的方法去迭代得到奇异值的近似解。很显然,为了保证得到的Lipschitz 连续性,就不能继续使用BatchNorm。

Technical Approach

Memory-friendly Generator

       众所周知,原始的transformer模型的自注意力有着O(n^2)的成本,如果直接将图片展成词向量,那么即使是32*32的低分辨率图像,也会生成1024长度的词向量,这种方案不具有可伸缩性。因此,逐渐增加序列减少嵌入维度不仅可以提取粗细粒度,同时也解决第一步图像的序列转化问题。

20211201202515

       每个阶段最后有一个上采样模块,使用双三次上采样将图片分辨率改变,分辨率更高的阶段将抛弃双三次采样而使用pixivshuffle,这是为了将嵌入维度减少到1/4,。总体来说,这种升级的金字塔结构减轻了内存和计算爆炸,不停重复直到达到需要的分辨率。

Multi-scale Discriminator

       分成不同大小的patch作为输入,每个尺度的做法与生成器类似,将一维句子重塑为二维特征图,并在每个阶段之间采用平均池化层对特征图分辨率进行下采样。 通过在每个阶段递归地形成变换器块,形成了一个金字塔架构,其中提取了多尺度表示。 在这些块的末尾,一个 [cls] 标记被附加到一维序列的开头,然后由分类头获取以输出真/假预测。

Grid Self-Attention: A Scalable Variant of Self-Attention for Image Generation

       self-attention代价太大了,不利于长序列建模,因此本文在这里为高分辨率的生成任务设计了Grid-attention。这里叫做网格注意力,字面意思,在分辨率小时候保持标准自注意力机制,分辨率较大时,将图片分割成不同的片,每个片内进行qkv的计算。这种边界造成的影响在训练前期比较明显,但是随着迭代进行,它会逐渐消失。

20211203152345

Exploring the Training Recipe

数据增强

       在本文设计上使用饿了3个基本的数据增强基线,{translation,Cutout,Color}给TransGan带来巨大的提升,反之CNN构成的Gan网络就几乎没有提升。

相对位置编码

       与绝对位置相比,相对位置编码学习了本地内容之间更强的“关系”,在大规模案例中带来了重要的性能提升,并从那时起得到广泛使用。我们还观察到它不断改进 TransGAN,尤其是在更高分辨率 解决方案数据集。 因此,我们将其应用于生成器和鉴别器的可学习绝对位置编码之上。

归一化层修改

       原始Transformer默认使用normalize层,但是本文增寻一项前者的工作,将其替换为令牌规模缩放层,即:

20211203154235

       这里C是词嵌入的维度,X和Y对应令牌缩放前后的大小。
和一些最新的有学习参数的normaliztion层不同,对于transGan来说,这种简单的缩放效果最好,也进一步提升了FID。

参考

  1. Spectral Normalization 谱归一化
  2. 《Progressive Training of Multi-level Wavelet Residual Networks for Image Denoising》阅读笔记
  3. 令人拍案叫绝的Wasserstein GAN
  4. [Style Transfer]——A Style-Based Generator Architecture for Generative Adversarial Networks
打赏
  • 版权声明: 本博客所有文章除特别声明外,著作权归作者所有。转载请注明出处!
  • Copyrights © 2015-2023 Tritonchen
  • 访问人数: | 浏览次数:

请我喝杯咖啡吧~

微信