'訓練GAN,你應該知道的二三事'

數學 人工智能 機器之心 2019-07-20
"

作者:追一科技 AI Lab 研究員 Miracle


寫在前面的話


筆者接觸 GAN 也有一段時間了,從一開始的小白,到現在被 GANs 虐了千百遍但依然深愛著 GANs 的小白,被 GANs 的對抗思維所折服,被 GANs 能夠生成萬物的能力所驚歎。我覺得 GANs 在某種程度上有點類似於中國太極,『太極生兩儀,兩儀生四象』,太極闡明瞭宇宙從無極而太極,以至萬物化生的過程,太極也是講究陰陽調和。(哈哈,這麼說來 GANs 其實在中國古代就已經有了發展雛形了。)

眾所周知,GANs 的訓練尤其困難,筆者自從跳入了 GANs 這個領域(坑),就一直在跟如何訓練 GANs 做「對抗訓練」,受啟發於 ganhacks,並結合自己的經驗記錄總結了一些常用的訓練 GANs 的方法,以備後用。

(⚠️本篇不是 GANs 的入門掃盲篇,初學者慎入。)

什麼是 GANs?


GANs(Generative Adversarial Networks)可以說是一種強大的「萬能」數據分佈擬合器,主要由一個生成器(generator)和判別器(discriminator)組成。生成器主要從一個低維度的數據分佈中不斷擬合真實的高維數據分佈,而判別器主要是為了區分數據是來源於真實數據還是生成器生成的數據,他們之間相互對抗,不斷學習,最終達到Nash均衡,即任何一方的改進都不會導致總體的收益增加,這個時候判別器再也無法區分是生成器生成的數據還是真實數據。

GANs 最初由 Ian Goodfellow [1] 於 2014 年提出,目前已經在圖像、語音、文字等方面得到廣泛研究和應用,特別是在圖像生成方面,可謂是遍地開花,例如圖像風格遷移(style transfer)、圖像修復(image inpainting)、超分辨率(super resolution)等。

GANs 出了什麼問題?


GANs 通常被定義為一個 minimax 的過程:

"

作者:追一科技 AI Lab 研究員 Miracle


寫在前面的話


筆者接觸 GAN 也有一段時間了,從一開始的小白,到現在被 GANs 虐了千百遍但依然深愛著 GANs 的小白,被 GANs 的對抗思維所折服,被 GANs 能夠生成萬物的能力所驚歎。我覺得 GANs 在某種程度上有點類似於中國太極,『太極生兩儀,兩儀生四象』,太極闡明瞭宇宙從無極而太極,以至萬物化生的過程,太極也是講究陰陽調和。(哈哈,這麼說來 GANs 其實在中國古代就已經有了發展雛形了。)

眾所周知,GANs 的訓練尤其困難,筆者自從跳入了 GANs 這個領域(坑),就一直在跟如何訓練 GANs 做「對抗訓練」,受啟發於 ganhacks,並結合自己的經驗記錄總結了一些常用的訓練 GANs 的方法,以備後用。

(⚠️本篇不是 GANs 的入門掃盲篇,初學者慎入。)

什麼是 GANs?


GANs(Generative Adversarial Networks)可以說是一種強大的「萬能」數據分佈擬合器,主要由一個生成器(generator)和判別器(discriminator)組成。生成器主要從一個低維度的數據分佈中不斷擬合真實的高維數據分佈,而判別器主要是為了區分數據是來源於真實數據還是生成器生成的數據,他們之間相互對抗,不斷學習,最終達到Nash均衡,即任何一方的改進都不會導致總體的收益增加,這個時候判別器再也無法區分是生成器生成的數據還是真實數據。

GANs 最初由 Ian Goodfellow [1] 於 2014 年提出,目前已經在圖像、語音、文字等方面得到廣泛研究和應用,特別是在圖像生成方面,可謂是遍地開花,例如圖像風格遷移(style transfer)、圖像修復(image inpainting)、超分辨率(super resolution)等。

GANs 出了什麼問題?


GANs 通常被定義為一個 minimax 的過程:

訓練GAN,你應該知道的二三事

其中 P_r 是真實數據分佈,P_z 是隨機噪聲分佈。乍一看這個目標函數,感覺有點相互矛盾,其實這就是 GANs 的精髓所在—— 對抗訓練。

在原始的 GANs 中,判別器要不斷的提高判別是非的能力,即儘可能的將真實樣本分類為正例,將生成樣本分類為負例,所以判別器需要優化如下損失函數:

"

作者:追一科技 AI Lab 研究員 Miracle


寫在前面的話


筆者接觸 GAN 也有一段時間了,從一開始的小白,到現在被 GANs 虐了千百遍但依然深愛著 GANs 的小白,被 GANs 的對抗思維所折服,被 GANs 能夠生成萬物的能力所驚歎。我覺得 GANs 在某種程度上有點類似於中國太極,『太極生兩儀,兩儀生四象』,太極闡明瞭宇宙從無極而太極,以至萬物化生的過程,太極也是講究陰陽調和。(哈哈,這麼說來 GANs 其實在中國古代就已經有了發展雛形了。)

眾所周知,GANs 的訓練尤其困難,筆者自從跳入了 GANs 這個領域(坑),就一直在跟如何訓練 GANs 做「對抗訓練」,受啟發於 ganhacks,並結合自己的經驗記錄總結了一些常用的訓練 GANs 的方法,以備後用。

(⚠️本篇不是 GANs 的入門掃盲篇,初學者慎入。)

什麼是 GANs?


GANs(Generative Adversarial Networks)可以說是一種強大的「萬能」數據分佈擬合器,主要由一個生成器(generator)和判別器(discriminator)組成。生成器主要從一個低維度的數據分佈中不斷擬合真實的高維數據分佈,而判別器主要是為了區分數據是來源於真實數據還是生成器生成的數據,他們之間相互對抗,不斷學習,最終達到Nash均衡,即任何一方的改進都不會導致總體的收益增加,這個時候判別器再也無法區分是生成器生成的數據還是真實數據。

GANs 最初由 Ian Goodfellow [1] 於 2014 年提出,目前已經在圖像、語音、文字等方面得到廣泛研究和應用,特別是在圖像生成方面,可謂是遍地開花,例如圖像風格遷移(style transfer)、圖像修復(image inpainting)、超分辨率(super resolution)等。

GANs 出了什麼問題?


GANs 通常被定義為一個 minimax 的過程:

訓練GAN,你應該知道的二三事

其中 P_r 是真實數據分佈,P_z 是隨機噪聲分佈。乍一看這個目標函數,感覺有點相互矛盾,其實這就是 GANs 的精髓所在—— 對抗訓練。

在原始的 GANs 中,判別器要不斷的提高判別是非的能力,即儘可能的將真實樣本分類為正例,將生成樣本分類為負例,所以判別器需要優化如下損失函數:

訓練GAN,你應該知道的二三事

作為對抗訓練,生成器需要不斷將生成數據分佈拉到真實數據分佈,Ian Goodfellow 首先提出瞭如下式的生成器損失函數:

"

作者:追一科技 AI Lab 研究員 Miracle


寫在前面的話


筆者接觸 GAN 也有一段時間了,從一開始的小白,到現在被 GANs 虐了千百遍但依然深愛著 GANs 的小白,被 GANs 的對抗思維所折服,被 GANs 能夠生成萬物的能力所驚歎。我覺得 GANs 在某種程度上有點類似於中國太極,『太極生兩儀,兩儀生四象』,太極闡明瞭宇宙從無極而太極,以至萬物化生的過程,太極也是講究陰陽調和。(哈哈,這麼說來 GANs 其實在中國古代就已經有了發展雛形了。)

眾所周知,GANs 的訓練尤其困難,筆者自從跳入了 GANs 這個領域(坑),就一直在跟如何訓練 GANs 做「對抗訓練」,受啟發於 ganhacks,並結合自己的經驗記錄總結了一些常用的訓練 GANs 的方法,以備後用。

(⚠️本篇不是 GANs 的入門掃盲篇,初學者慎入。)

什麼是 GANs?


GANs(Generative Adversarial Networks)可以說是一種強大的「萬能」數據分佈擬合器,主要由一個生成器(generator)和判別器(discriminator)組成。生成器主要從一個低維度的數據分佈中不斷擬合真實的高維數據分佈,而判別器主要是為了區分數據是來源於真實數據還是生成器生成的數據,他們之間相互對抗,不斷學習,最終達到Nash均衡,即任何一方的改進都不會導致總體的收益增加,這個時候判別器再也無法區分是生成器生成的數據還是真實數據。

GANs 最初由 Ian Goodfellow [1] 於 2014 年提出,目前已經在圖像、語音、文字等方面得到廣泛研究和應用,特別是在圖像生成方面,可謂是遍地開花,例如圖像風格遷移(style transfer)、圖像修復(image inpainting)、超分辨率(super resolution)等。

GANs 出了什麼問題?


GANs 通常被定義為一個 minimax 的過程:

訓練GAN,你應該知道的二三事

其中 P_r 是真實數據分佈,P_z 是隨機噪聲分佈。乍一看這個目標函數,感覺有點相互矛盾,其實這就是 GANs 的精髓所在—— 對抗訓練。

在原始的 GANs 中,判別器要不斷的提高判別是非的能力,即儘可能的將真實樣本分類為正例,將生成樣本分類為負例,所以判別器需要優化如下損失函數:

訓練GAN,你應該知道的二三事

作為對抗訓練,生成器需要不斷將生成數據分佈拉到真實數據分佈,Ian Goodfellow 首先提出瞭如下式的生成器損失函數:

訓練GAN,你應該知道的二三事

由於在訓練初期階段,生成器的能力比較弱,判別器這時候也比較弱,但仍然可以足夠精準的區分生成樣本和真實樣本,這樣 D(x) 就非常接近1,導致 log(1-D(x)) 達到飽和,後續網絡就很難再調整過來。為了解決訓練初期階段飽和問題,作者提出了另外一個損失函數,即:

"

作者:追一科技 AI Lab 研究員 Miracle


寫在前面的話


筆者接觸 GAN 也有一段時間了,從一開始的小白,到現在被 GANs 虐了千百遍但依然深愛著 GANs 的小白,被 GANs 的對抗思維所折服,被 GANs 能夠生成萬物的能力所驚歎。我覺得 GANs 在某種程度上有點類似於中國太極,『太極生兩儀,兩儀生四象』,太極闡明瞭宇宙從無極而太極,以至萬物化生的過程,太極也是講究陰陽調和。(哈哈,這麼說來 GANs 其實在中國古代就已經有了發展雛形了。)

眾所周知,GANs 的訓練尤其困難,筆者自從跳入了 GANs 這個領域(坑),就一直在跟如何訓練 GANs 做「對抗訓練」,受啟發於 ganhacks,並結合自己的經驗記錄總結了一些常用的訓練 GANs 的方法,以備後用。

(⚠️本篇不是 GANs 的入門掃盲篇,初學者慎入。)

什麼是 GANs?


GANs(Generative Adversarial Networks)可以說是一種強大的「萬能」數據分佈擬合器,主要由一個生成器(generator)和判別器(discriminator)組成。生成器主要從一個低維度的數據分佈中不斷擬合真實的高維數據分佈,而判別器主要是為了區分數據是來源於真實數據還是生成器生成的數據,他們之間相互對抗,不斷學習,最終達到Nash均衡,即任何一方的改進都不會導致總體的收益增加,這個時候判別器再也無法區分是生成器生成的數據還是真實數據。

GANs 最初由 Ian Goodfellow [1] 於 2014 年提出,目前已經在圖像、語音、文字等方面得到廣泛研究和應用,特別是在圖像生成方面,可謂是遍地開花,例如圖像風格遷移(style transfer)、圖像修復(image inpainting)、超分辨率(super resolution)等。

GANs 出了什麼問題?


GANs 通常被定義為一個 minimax 的過程:

訓練GAN,你應該知道的二三事

其中 P_r 是真實數據分佈,P_z 是隨機噪聲分佈。乍一看這個目標函數,感覺有點相互矛盾,其實這就是 GANs 的精髓所在—— 對抗訓練。

在原始的 GANs 中,判別器要不斷的提高判別是非的能力,即儘可能的將真實樣本分類為正例,將生成樣本分類為負例,所以判別器需要優化如下損失函數:

訓練GAN,你應該知道的二三事

作為對抗訓練,生成器需要不斷將生成數據分佈拉到真實數據分佈,Ian Goodfellow 首先提出瞭如下式的生成器損失函數:

訓練GAN,你應該知道的二三事

由於在訓練初期階段,生成器的能力比較弱,判別器這時候也比較弱,但仍然可以足夠精準的區分生成樣本和真實樣本,這樣 D(x) 就非常接近1,導致 log(1-D(x)) 達到飽和,後續網絡就很難再調整過來。為了解決訓練初期階段飽和問題,作者提出了另外一個損失函數,即:

訓練GAN,你應該知道的二三事

以上面這個兩個生成器目標函數為例,簡單地分析一下GAN模型存在的幾個問題:

Ian Goodfellow 論文裡面已經給出,固定 G 的參數,我們得到最優的 D^*:

"

作者:追一科技 AI Lab 研究員 Miracle


寫在前面的話


筆者接觸 GAN 也有一段時間了,從一開始的小白,到現在被 GANs 虐了千百遍但依然深愛著 GANs 的小白,被 GANs 的對抗思維所折服,被 GANs 能夠生成萬物的能力所驚歎。我覺得 GANs 在某種程度上有點類似於中國太極,『太極生兩儀,兩儀生四象』,太極闡明瞭宇宙從無極而太極,以至萬物化生的過程,太極也是講究陰陽調和。(哈哈,這麼說來 GANs 其實在中國古代就已經有了發展雛形了。)

眾所周知,GANs 的訓練尤其困難,筆者自從跳入了 GANs 這個領域(坑),就一直在跟如何訓練 GANs 做「對抗訓練」,受啟發於 ganhacks,並結合自己的經驗記錄總結了一些常用的訓練 GANs 的方法,以備後用。

(⚠️本篇不是 GANs 的入門掃盲篇,初學者慎入。)

什麼是 GANs?


GANs(Generative Adversarial Networks)可以說是一種強大的「萬能」數據分佈擬合器,主要由一個生成器(generator)和判別器(discriminator)組成。生成器主要從一個低維度的數據分佈中不斷擬合真實的高維數據分佈,而判別器主要是為了區分數據是來源於真實數據還是生成器生成的數據,他們之間相互對抗,不斷學習,最終達到Nash均衡,即任何一方的改進都不會導致總體的收益增加,這個時候判別器再也無法區分是生成器生成的數據還是真實數據。

GANs 最初由 Ian Goodfellow [1] 於 2014 年提出,目前已經在圖像、語音、文字等方面得到廣泛研究和應用,特別是在圖像生成方面,可謂是遍地開花,例如圖像風格遷移(style transfer)、圖像修復(image inpainting)、超分辨率(super resolution)等。

GANs 出了什麼問題?


GANs 通常被定義為一個 minimax 的過程:

訓練GAN,你應該知道的二三事

其中 P_r 是真實數據分佈,P_z 是隨機噪聲分佈。乍一看這個目標函數,感覺有點相互矛盾,其實這就是 GANs 的精髓所在—— 對抗訓練。

在原始的 GANs 中,判別器要不斷的提高判別是非的能力,即儘可能的將真實樣本分類為正例,將生成樣本分類為負例,所以判別器需要優化如下損失函數:

訓練GAN,你應該知道的二三事

作為對抗訓練,生成器需要不斷將生成數據分佈拉到真實數據分佈,Ian Goodfellow 首先提出瞭如下式的生成器損失函數:

訓練GAN,你應該知道的二三事

由於在訓練初期階段,生成器的能力比較弱,判別器這時候也比較弱,但仍然可以足夠精準的區分生成樣本和真實樣本,這樣 D(x) 就非常接近1,導致 log(1-D(x)) 達到飽和,後續網絡就很難再調整過來。為了解決訓練初期階段飽和問題,作者提出了另外一個損失函數,即:

訓練GAN,你應該知道的二三事

以上面這個兩個生成器目標函數為例,簡單地分析一下GAN模型存在的幾個問題:

Ian Goodfellow 論文裡面已經給出,固定 G 的參數,我們得到最優的 D^*:

訓練GAN,你應該知道的二三事

也就是說,只有當 P_r=P_g 時候,不管是真實樣本和生成樣本,判別器給出的概率都是 0.5,這個時候就無法區分樣本到底是來自於真實樣本還是來自於生成樣本,這是最理想的情況。

1. 對於第一種目標函數

在最優判別器下 D^* 下,我們給損失函數加上一個與 G 無關的項,(3) 式變成:

"

作者:追一科技 AI Lab 研究員 Miracle


寫在前面的話


筆者接觸 GAN 也有一段時間了,從一開始的小白,到現在被 GANs 虐了千百遍但依然深愛著 GANs 的小白,被 GANs 的對抗思維所折服,被 GANs 能夠生成萬物的能力所驚歎。我覺得 GANs 在某種程度上有點類似於中國太極,『太極生兩儀,兩儀生四象』,太極闡明瞭宇宙從無極而太極,以至萬物化生的過程,太極也是講究陰陽調和。(哈哈,這麼說來 GANs 其實在中國古代就已經有了發展雛形了。)

眾所周知,GANs 的訓練尤其困難,筆者自從跳入了 GANs 這個領域(坑),就一直在跟如何訓練 GANs 做「對抗訓練」,受啟發於 ganhacks,並結合自己的經驗記錄總結了一些常用的訓練 GANs 的方法,以備後用。

(⚠️本篇不是 GANs 的入門掃盲篇,初學者慎入。)

什麼是 GANs?


GANs(Generative Adversarial Networks)可以說是一種強大的「萬能」數據分佈擬合器,主要由一個生成器(generator)和判別器(discriminator)組成。生成器主要從一個低維度的數據分佈中不斷擬合真實的高維數據分佈,而判別器主要是為了區分數據是來源於真實數據還是生成器生成的數據,他們之間相互對抗,不斷學習,最終達到Nash均衡,即任何一方的改進都不會導致總體的收益增加,這個時候判別器再也無法區分是生成器生成的數據還是真實數據。

GANs 最初由 Ian Goodfellow [1] 於 2014 年提出,目前已經在圖像、語音、文字等方面得到廣泛研究和應用,特別是在圖像生成方面,可謂是遍地開花,例如圖像風格遷移(style transfer)、圖像修復(image inpainting)、超分辨率(super resolution)等。

GANs 出了什麼問題?


GANs 通常被定義為一個 minimax 的過程:

訓練GAN,你應該知道的二三事

其中 P_r 是真實數據分佈,P_z 是隨機噪聲分佈。乍一看這個目標函數,感覺有點相互矛盾,其實這就是 GANs 的精髓所在—— 對抗訓練。

在原始的 GANs 中,判別器要不斷的提高判別是非的能力,即儘可能的將真實樣本分類為正例,將生成樣本分類為負例,所以判別器需要優化如下損失函數:

訓練GAN,你應該知道的二三事

作為對抗訓練,生成器需要不斷將生成數據分佈拉到真實數據分佈,Ian Goodfellow 首先提出瞭如下式的生成器損失函數:

訓練GAN,你應該知道的二三事

由於在訓練初期階段,生成器的能力比較弱,判別器這時候也比較弱,但仍然可以足夠精準的區分生成樣本和真實樣本,這樣 D(x) 就非常接近1,導致 log(1-D(x)) 達到飽和,後續網絡就很難再調整過來。為了解決訓練初期階段飽和問題,作者提出了另外一個損失函數,即:

訓練GAN,你應該知道的二三事

以上面這個兩個生成器目標函數為例,簡單地分析一下GAN模型存在的幾個問題:

Ian Goodfellow 論文裡面已經給出,固定 G 的參數,我們得到最優的 D^*:

訓練GAN,你應該知道的二三事

也就是說,只有當 P_r=P_g 時候,不管是真實樣本和生成樣本,判別器給出的概率都是 0.5,這個時候就無法區分樣本到底是來自於真實樣本還是來自於生成樣本,這是最理想的情況。

1. 對於第一種目標函數

在最優判別器下 D^* 下,我們給損失函數加上一個與 G 無關的項,(3) 式變成:

訓練GAN,你應該知道的二三事

注意,該式子其實就是判別器的損失函數的相反數。

把最優判別器 D^* 帶入,可以得到:

"

作者:追一科技 AI Lab 研究員 Miracle


寫在前面的話


筆者接觸 GAN 也有一段時間了,從一開始的小白,到現在被 GANs 虐了千百遍但依然深愛著 GANs 的小白,被 GANs 的對抗思維所折服,被 GANs 能夠生成萬物的能力所驚歎。我覺得 GANs 在某種程度上有點類似於中國太極,『太極生兩儀,兩儀生四象』,太極闡明瞭宇宙從無極而太極,以至萬物化生的過程,太極也是講究陰陽調和。(哈哈,這麼說來 GANs 其實在中國古代就已經有了發展雛形了。)

眾所周知,GANs 的訓練尤其困難,筆者自從跳入了 GANs 這個領域(坑),就一直在跟如何訓練 GANs 做「對抗訓練」,受啟發於 ganhacks,並結合自己的經驗記錄總結了一些常用的訓練 GANs 的方法,以備後用。

(⚠️本篇不是 GANs 的入門掃盲篇,初學者慎入。)

什麼是 GANs?


GANs(Generative Adversarial Networks)可以說是一種強大的「萬能」數據分佈擬合器,主要由一個生成器(generator)和判別器(discriminator)組成。生成器主要從一個低維度的數據分佈中不斷擬合真實的高維數據分佈,而判別器主要是為了區分數據是來源於真實數據還是生成器生成的數據,他們之間相互對抗,不斷學習,最終達到Nash均衡,即任何一方的改進都不會導致總體的收益增加,這個時候判別器再也無法區分是生成器生成的數據還是真實數據。

GANs 最初由 Ian Goodfellow [1] 於 2014 年提出,目前已經在圖像、語音、文字等方面得到廣泛研究和應用,特別是在圖像生成方面,可謂是遍地開花,例如圖像風格遷移(style transfer)、圖像修復(image inpainting)、超分辨率(super resolution)等。

GANs 出了什麼問題?


GANs 通常被定義為一個 minimax 的過程:

訓練GAN,你應該知道的二三事

其中 P_r 是真實數據分佈,P_z 是隨機噪聲分佈。乍一看這個目標函數,感覺有點相互矛盾,其實這就是 GANs 的精髓所在—— 對抗訓練。

在原始的 GANs 中,判別器要不斷的提高判別是非的能力,即儘可能的將真實樣本分類為正例,將生成樣本分類為負例,所以判別器需要優化如下損失函數:

訓練GAN,你應該知道的二三事

作為對抗訓練,生成器需要不斷將生成數據分佈拉到真實數據分佈,Ian Goodfellow 首先提出瞭如下式的生成器損失函數:

訓練GAN,你應該知道的二三事

由於在訓練初期階段,生成器的能力比較弱,判別器這時候也比較弱,但仍然可以足夠精準的區分生成樣本和真實樣本,這樣 D(x) 就非常接近1,導致 log(1-D(x)) 達到飽和,後續網絡就很難再調整過來。為了解決訓練初期階段飽和問題,作者提出了另外一個損失函數,即:

訓練GAN,你應該知道的二三事

以上面這個兩個生成器目標函數為例,簡單地分析一下GAN模型存在的幾個問題:

Ian Goodfellow 論文裡面已經給出,固定 G 的參數,我們得到最優的 D^*:

訓練GAN,你應該知道的二三事

也就是說,只有當 P_r=P_g 時候,不管是真實樣本和生成樣本,判別器給出的概率都是 0.5,這個時候就無法區分樣本到底是來自於真實樣本還是來自於生成樣本,這是最理想的情況。

1. 對於第一種目標函數

在最優判別器下 D^* 下,我們給損失函數加上一個與 G 無關的項,(3) 式變成:

訓練GAN,你應該知道的二三事

注意,該式子其實就是判別器的損失函數的相反數。

把最優判別器 D^* 帶入,可以得到:

訓練GAN,你應該知道的二三事

到這裡,我們就可以看清楚我們到底在優化什麼東西了,在最優判別器的情況下,其實我們在優化兩個分佈的 JS 散度。當然在訓練過程中,判別器一開始不是最優的,但是隨著訓練的進行,我們優化的目標也逐漸接近JS散度,而問題恰恰就出現在這個 JS 散度上面。一個直觀的解釋就是隻要兩個分佈之間的沒有重疊或者重疊部分可以忽略不計,那麼大概率上我們優化的目標就變成了一個常數 -2log2,這種情況通過判別器傳遞給生成器的梯度就是零,也就是說,生成器不可能從判別器那裡學到任何有用的東西,這也就導致了無法繼續學習。

Arjovsky [2] 以其精湛的數學技巧提供一個更嚴謹的一個數學推導(手動截圖原論文了)。

"

作者:追一科技 AI Lab 研究員 Miracle


寫在前面的話


筆者接觸 GAN 也有一段時間了,從一開始的小白,到現在被 GANs 虐了千百遍但依然深愛著 GANs 的小白,被 GANs 的對抗思維所折服,被 GANs 能夠生成萬物的能力所驚歎。我覺得 GANs 在某種程度上有點類似於中國太極,『太極生兩儀,兩儀生四象』,太極闡明瞭宇宙從無極而太極,以至萬物化生的過程,太極也是講究陰陽調和。(哈哈,這麼說來 GANs 其實在中國古代就已經有了發展雛形了。)

眾所周知,GANs 的訓練尤其困難,筆者自從跳入了 GANs 這個領域(坑),就一直在跟如何訓練 GANs 做「對抗訓練」,受啟發於 ganhacks,並結合自己的經驗記錄總結了一些常用的訓練 GANs 的方法,以備後用。

(⚠️本篇不是 GANs 的入門掃盲篇,初學者慎入。)

什麼是 GANs?


GANs(Generative Adversarial Networks)可以說是一種強大的「萬能」數據分佈擬合器,主要由一個生成器(generator)和判別器(discriminator)組成。生成器主要從一個低維度的數據分佈中不斷擬合真實的高維數據分佈,而判別器主要是為了區分數據是來源於真實數據還是生成器生成的數據,他們之間相互對抗,不斷學習,最終達到Nash均衡,即任何一方的改進都不會導致總體的收益增加,這個時候判別器再也無法區分是生成器生成的數據還是真實數據。

GANs 最初由 Ian Goodfellow [1] 於 2014 年提出,目前已經在圖像、語音、文字等方面得到廣泛研究和應用,特別是在圖像生成方面,可謂是遍地開花,例如圖像風格遷移(style transfer)、圖像修復(image inpainting)、超分辨率(super resolution)等。

GANs 出了什麼問題?


GANs 通常被定義為一個 minimax 的過程:

訓練GAN,你應該知道的二三事

其中 P_r 是真實數據分佈,P_z 是隨機噪聲分佈。乍一看這個目標函數,感覺有點相互矛盾,其實這就是 GANs 的精髓所在—— 對抗訓練。

在原始的 GANs 中,判別器要不斷的提高判別是非的能力,即儘可能的將真實樣本分類為正例,將生成樣本分類為負例,所以判別器需要優化如下損失函數:

訓練GAN,你應該知道的二三事

作為對抗訓練,生成器需要不斷將生成數據分佈拉到真實數據分佈,Ian Goodfellow 首先提出瞭如下式的生成器損失函數:

訓練GAN,你應該知道的二三事

由於在訓練初期階段,生成器的能力比較弱,判別器這時候也比較弱,但仍然可以足夠精準的區分生成樣本和真實樣本,這樣 D(x) 就非常接近1,導致 log(1-D(x)) 達到飽和,後續網絡就很難再調整過來。為了解決訓練初期階段飽和問題,作者提出了另外一個損失函數,即:

訓練GAN,你應該知道的二三事

以上面這個兩個生成器目標函數為例,簡單地分析一下GAN模型存在的幾個問題:

Ian Goodfellow 論文裡面已經給出,固定 G 的參數,我們得到最優的 D^*:

訓練GAN,你應該知道的二三事

也就是說,只有當 P_r=P_g 時候,不管是真實樣本和生成樣本,判別器給出的概率都是 0.5,這個時候就無法區分樣本到底是來自於真實樣本還是來自於生成樣本,這是最理想的情況。

1. 對於第一種目標函數

在最優判別器下 D^* 下,我們給損失函數加上一個與 G 無關的項,(3) 式變成:

訓練GAN,你應該知道的二三事

注意,該式子其實就是判別器的損失函數的相反數。

把最優判別器 D^* 帶入,可以得到:

訓練GAN,你應該知道的二三事

到這裡,我們就可以看清楚我們到底在優化什麼東西了,在最優判別器的情況下,其實我們在優化兩個分佈的 JS 散度。當然在訓練過程中,判別器一開始不是最優的,但是隨著訓練的進行,我們優化的目標也逐漸接近JS散度,而問題恰恰就出現在這個 JS 散度上面。一個直觀的解釋就是隻要兩個分佈之間的沒有重疊或者重疊部分可以忽略不計,那麼大概率上我們優化的目標就變成了一個常數 -2log2,這種情況通過判別器傳遞給生成器的梯度就是零,也就是說,生成器不可能從判別器那裡學到任何有用的東西,這也就導致了無法繼續學習。

Arjovsky [2] 以其精湛的數學技巧提供一個更嚴謹的一個數學推導(手動截圖原論文了)。

訓練GAN,你應該知道的二三事

在 Theorm2.4 成立的情況下:

"

作者:追一科技 AI Lab 研究員 Miracle


寫在前面的話


筆者接觸 GAN 也有一段時間了,從一開始的小白,到現在被 GANs 虐了千百遍但依然深愛著 GANs 的小白,被 GANs 的對抗思維所折服,被 GANs 能夠生成萬物的能力所驚歎。我覺得 GANs 在某種程度上有點類似於中國太極,『太極生兩儀,兩儀生四象』,太極闡明瞭宇宙從無極而太極,以至萬物化生的過程,太極也是講究陰陽調和。(哈哈,這麼說來 GANs 其實在中國古代就已經有了發展雛形了。)

眾所周知,GANs 的訓練尤其困難,筆者自從跳入了 GANs 這個領域(坑),就一直在跟如何訓練 GANs 做「對抗訓練」,受啟發於 ganhacks,並結合自己的經驗記錄總結了一些常用的訓練 GANs 的方法,以備後用。

(⚠️本篇不是 GANs 的入門掃盲篇,初學者慎入。)

什麼是 GANs?


GANs(Generative Adversarial Networks)可以說是一種強大的「萬能」數據分佈擬合器,主要由一個生成器(generator)和判別器(discriminator)組成。生成器主要從一個低維度的數據分佈中不斷擬合真實的高維數據分佈,而判別器主要是為了區分數據是來源於真實數據還是生成器生成的數據,他們之間相互對抗,不斷學習,最終達到Nash均衡,即任何一方的改進都不會導致總體的收益增加,這個時候判別器再也無法區分是生成器生成的數據還是真實數據。

GANs 最初由 Ian Goodfellow [1] 於 2014 年提出,目前已經在圖像、語音、文字等方面得到廣泛研究和應用,特別是在圖像生成方面,可謂是遍地開花,例如圖像風格遷移(style transfer)、圖像修復(image inpainting)、超分辨率(super resolution)等。

GANs 出了什麼問題?


GANs 通常被定義為一個 minimax 的過程:

訓練GAN,你應該知道的二三事

其中 P_r 是真實數據分佈,P_z 是隨機噪聲分佈。乍一看這個目標函數,感覺有點相互矛盾,其實這就是 GANs 的精髓所在—— 對抗訓練。

在原始的 GANs 中,判別器要不斷的提高判別是非的能力,即儘可能的將真實樣本分類為正例,將生成樣本分類為負例,所以判別器需要優化如下損失函數:

訓練GAN,你應該知道的二三事

作為對抗訓練,生成器需要不斷將生成數據分佈拉到真實數據分佈,Ian Goodfellow 首先提出瞭如下式的生成器損失函數:

訓練GAN,你應該知道的二三事

由於在訓練初期階段,生成器的能力比較弱,判別器這時候也比較弱,但仍然可以足夠精準的區分生成樣本和真實樣本,這樣 D(x) 就非常接近1,導致 log(1-D(x)) 達到飽和,後續網絡就很難再調整過來。為了解決訓練初期階段飽和問題,作者提出了另外一個損失函數,即:

訓練GAN,你應該知道的二三事

以上面這個兩個生成器目標函數為例,簡單地分析一下GAN模型存在的幾個問題:

Ian Goodfellow 論文裡面已經給出,固定 G 的參數,我們得到最優的 D^*:

訓練GAN,你應該知道的二三事

也就是說,只有當 P_r=P_g 時候,不管是真實樣本和生成樣本,判別器給出的概率都是 0.5,這個時候就無法區分樣本到底是來自於真實樣本還是來自於生成樣本,這是最理想的情況。

1. 對於第一種目標函數

在最優判別器下 D^* 下,我們給損失函數加上一個與 G 無關的項,(3) 式變成:

訓練GAN,你應該知道的二三事

注意,該式子其實就是判別器的損失函數的相反數。

把最優判別器 D^* 帶入,可以得到:

訓練GAN,你應該知道的二三事

到這裡,我們就可以看清楚我們到底在優化什麼東西了,在最優判別器的情況下,其實我們在優化兩個分佈的 JS 散度。當然在訓練過程中,判別器一開始不是最優的,但是隨著訓練的進行,我們優化的目標也逐漸接近JS散度,而問題恰恰就出現在這個 JS 散度上面。一個直觀的解釋就是隻要兩個分佈之間的沒有重疊或者重疊部分可以忽略不計,那麼大概率上我們優化的目標就變成了一個常數 -2log2,這種情況通過判別器傳遞給生成器的梯度就是零,也就是說,生成器不可能從判別器那裡學到任何有用的東西,這也就導致了無法繼續學習。

Arjovsky [2] 以其精湛的數學技巧提供一個更嚴謹的一個數學推導(手動截圖原論文了)。

訓練GAN,你應該知道的二三事

在 Theorm2.4 成立的情況下:

訓練GAN,你應該知道的二三事

拋開上面這些文縐縐的數學表述,其實上面講的核心內容就是當兩個分佈的支撐集是沒有交集的或者說是支撐集是低維的流形空間,隨著訓練的進行,判別器不斷接近最優判別器,會導致生成器的梯度處處都是為0。

2. 對於第二種目標函數

同樣在最優判別器下,優化 (4) 式等價優化如下

"

作者:追一科技 AI Lab 研究員 Miracle


寫在前面的話


筆者接觸 GAN 也有一段時間了,從一開始的小白,到現在被 GANs 虐了千百遍但依然深愛著 GANs 的小白,被 GANs 的對抗思維所折服,被 GANs 能夠生成萬物的能力所驚歎。我覺得 GANs 在某種程度上有點類似於中國太極,『太極生兩儀,兩儀生四象』,太極闡明瞭宇宙從無極而太極,以至萬物化生的過程,太極也是講究陰陽調和。(哈哈,這麼說來 GANs 其實在中國古代就已經有了發展雛形了。)

眾所周知,GANs 的訓練尤其困難,筆者自從跳入了 GANs 這個領域(坑),就一直在跟如何訓練 GANs 做「對抗訓練」,受啟發於 ganhacks,並結合自己的經驗記錄總結了一些常用的訓練 GANs 的方法,以備後用。

(⚠️本篇不是 GANs 的入門掃盲篇,初學者慎入。)

什麼是 GANs?


GANs(Generative Adversarial Networks)可以說是一種強大的「萬能」數據分佈擬合器,主要由一個生成器(generator)和判別器(discriminator)組成。生成器主要從一個低維度的數據分佈中不斷擬合真實的高維數據分佈,而判別器主要是為了區分數據是來源於真實數據還是生成器生成的數據,他們之間相互對抗,不斷學習,最終達到Nash均衡,即任何一方的改進都不會導致總體的收益增加,這個時候判別器再也無法區分是生成器生成的數據還是真實數據。

GANs 最初由 Ian Goodfellow [1] 於 2014 年提出,目前已經在圖像、語音、文字等方面得到廣泛研究和應用,特別是在圖像生成方面,可謂是遍地開花,例如圖像風格遷移(style transfer)、圖像修復(image inpainting)、超分辨率(super resolution)等。

GANs 出了什麼問題?


GANs 通常被定義為一個 minimax 的過程:

訓練GAN,你應該知道的二三事

其中 P_r 是真實數據分佈,P_z 是隨機噪聲分佈。乍一看這個目標函數,感覺有點相互矛盾,其實這就是 GANs 的精髓所在—— 對抗訓練。

在原始的 GANs 中,判別器要不斷的提高判別是非的能力,即儘可能的將真實樣本分類為正例,將生成樣本分類為負例,所以判別器需要優化如下損失函數:

訓練GAN,你應該知道的二三事

作為對抗訓練,生成器需要不斷將生成數據分佈拉到真實數據分佈,Ian Goodfellow 首先提出瞭如下式的生成器損失函數:

訓練GAN,你應該知道的二三事

由於在訓練初期階段,生成器的能力比較弱,判別器這時候也比較弱,但仍然可以足夠精準的區分生成樣本和真實樣本,這樣 D(x) 就非常接近1,導致 log(1-D(x)) 達到飽和,後續網絡就很難再調整過來。為了解決訓練初期階段飽和問題,作者提出了另外一個損失函數,即:

訓練GAN,你應該知道的二三事

以上面這個兩個生成器目標函數為例,簡單地分析一下GAN模型存在的幾個問題:

Ian Goodfellow 論文裡面已經給出,固定 G 的參數,我們得到最優的 D^*:

訓練GAN,你應該知道的二三事

也就是說,只有當 P_r=P_g 時候,不管是真實樣本和生成樣本,判別器給出的概率都是 0.5,這個時候就無法區分樣本到底是來自於真實樣本還是來自於生成樣本,這是最理想的情況。

1. 對於第一種目標函數

在最優判別器下 D^* 下,我們給損失函數加上一個與 G 無關的項,(3) 式變成:

訓練GAN,你應該知道的二三事

注意,該式子其實就是判別器的損失函數的相反數。

把最優判別器 D^* 帶入,可以得到:

訓練GAN,你應該知道的二三事

到這裡,我們就可以看清楚我們到底在優化什麼東西了,在最優判別器的情況下,其實我們在優化兩個分佈的 JS 散度。當然在訓練過程中,判別器一開始不是最優的,但是隨著訓練的進行,我們優化的目標也逐漸接近JS散度,而問題恰恰就出現在這個 JS 散度上面。一個直觀的解釋就是隻要兩個分佈之間的沒有重疊或者重疊部分可以忽略不計,那麼大概率上我們優化的目標就變成了一個常數 -2log2,這種情況通過判別器傳遞給生成器的梯度就是零,也就是說,生成器不可能從判別器那裡學到任何有用的東西,這也就導致了無法繼續學習。

Arjovsky [2] 以其精湛的數學技巧提供一個更嚴謹的一個數學推導(手動截圖原論文了)。

訓練GAN,你應該知道的二三事

在 Theorm2.4 成立的情況下:

訓練GAN,你應該知道的二三事

拋開上面這些文縐縐的數學表述,其實上面講的核心內容就是當兩個分佈的支撐集是沒有交集的或者說是支撐集是低維的流形空間,隨著訓練的進行,判別器不斷接近最優判別器,會導致生成器的梯度處處都是為0。

2. 對於第二種目標函數

同樣在最優判別器下,優化 (4) 式等價優化如下

訓練GAN,你應該知道的二三事

仔細盯著上面式子幾秒鐘,不難發現我們優化的目標是相互悖論的,因為 KL 散度和 JS 散度的符號相反,優化 KL 是把兩個分佈拉近,但是優化 -JS 是把兩個分佈推遠,這「一推一拉」就會導致梯度更新非常不穩定。此外,我們知道 KL 不是對稱的,對於生成器無法生成真實樣本的情況,KL 對 loss 的貢獻非常大,而對於生成器生成的樣本多樣性不足的時候,KL 對 loss 的貢獻非常小。

"

作者:追一科技 AI Lab 研究員 Miracle


寫在前面的話


筆者接觸 GAN 也有一段時間了,從一開始的小白,到現在被 GANs 虐了千百遍但依然深愛著 GANs 的小白,被 GANs 的對抗思維所折服,被 GANs 能夠生成萬物的能力所驚歎。我覺得 GANs 在某種程度上有點類似於中國太極,『太極生兩儀,兩儀生四象』,太極闡明瞭宇宙從無極而太極,以至萬物化生的過程,太極也是講究陰陽調和。(哈哈,這麼說來 GANs 其實在中國古代就已經有了發展雛形了。)

眾所周知,GANs 的訓練尤其困難,筆者自從跳入了 GANs 這個領域(坑),就一直在跟如何訓練 GANs 做「對抗訓練」,受啟發於 ganhacks,並結合自己的經驗記錄總結了一些常用的訓練 GANs 的方法,以備後用。

(⚠️本篇不是 GANs 的入門掃盲篇,初學者慎入。)

什麼是 GANs?


GANs(Generative Adversarial Networks)可以說是一種強大的「萬能」數據分佈擬合器,主要由一個生成器(generator)和判別器(discriminator)組成。生成器主要從一個低維度的數據分佈中不斷擬合真實的高維數據分佈,而判別器主要是為了區分數據是來源於真實數據還是生成器生成的數據,他們之間相互對抗,不斷學習,最終達到Nash均衡,即任何一方的改進都不會導致總體的收益增加,這個時候判別器再也無法區分是生成器生成的數據還是真實數據。

GANs 最初由 Ian Goodfellow [1] 於 2014 年提出,目前已經在圖像、語音、文字等方面得到廣泛研究和應用,特別是在圖像生成方面,可謂是遍地開花,例如圖像風格遷移(style transfer)、圖像修復(image inpainting)、超分辨率(super resolution)等。

GANs 出了什麼問題?


GANs 通常被定義為一個 minimax 的過程:

訓練GAN,你應該知道的二三事

其中 P_r 是真實數據分佈,P_z 是隨機噪聲分佈。乍一看這個目標函數,感覺有點相互矛盾,其實這就是 GANs 的精髓所在—— 對抗訓練。

在原始的 GANs 中,判別器要不斷的提高判別是非的能力,即儘可能的將真實樣本分類為正例,將生成樣本分類為負例,所以判別器需要優化如下損失函數:

訓練GAN,你應該知道的二三事

作為對抗訓練,生成器需要不斷將生成數據分佈拉到真實數據分佈,Ian Goodfellow 首先提出瞭如下式的生成器損失函數:

訓練GAN,你應該知道的二三事

由於在訓練初期階段,生成器的能力比較弱,判別器這時候也比較弱,但仍然可以足夠精準的區分生成樣本和真實樣本,這樣 D(x) 就非常接近1,導致 log(1-D(x)) 達到飽和,後續網絡就很難再調整過來。為了解決訓練初期階段飽和問題,作者提出了另外一個損失函數,即:

訓練GAN,你應該知道的二三事

以上面這個兩個生成器目標函數為例,簡單地分析一下GAN模型存在的幾個問題:

Ian Goodfellow 論文裡面已經給出,固定 G 的參數,我們得到最優的 D^*:

訓練GAN,你應該知道的二三事

也就是說,只有當 P_r=P_g 時候,不管是真實樣本和生成樣本,判別器給出的概率都是 0.5,這個時候就無法區分樣本到底是來自於真實樣本還是來自於生成樣本,這是最理想的情況。

1. 對於第一種目標函數

在最優判別器下 D^* 下,我們給損失函數加上一個與 G 無關的項,(3) 式變成:

訓練GAN,你應該知道的二三事

注意,該式子其實就是判別器的損失函數的相反數。

把最優判別器 D^* 帶入,可以得到:

訓練GAN,你應該知道的二三事

到這裡,我們就可以看清楚我們到底在優化什麼東西了,在最優判別器的情況下,其實我們在優化兩個分佈的 JS 散度。當然在訓練過程中,判別器一開始不是最優的,但是隨著訓練的進行,我們優化的目標也逐漸接近JS散度,而問題恰恰就出現在這個 JS 散度上面。一個直觀的解釋就是隻要兩個分佈之間的沒有重疊或者重疊部分可以忽略不計,那麼大概率上我們優化的目標就變成了一個常數 -2log2,這種情況通過判別器傳遞給生成器的梯度就是零,也就是說,生成器不可能從判別器那裡學到任何有用的東西,這也就導致了無法繼續學習。

Arjovsky [2] 以其精湛的數學技巧提供一個更嚴謹的一個數學推導(手動截圖原論文了)。

訓練GAN,你應該知道的二三事

在 Theorm2.4 成立的情況下:

訓練GAN,你應該知道的二三事

拋開上面這些文縐縐的數學表述,其實上面講的核心內容就是當兩個分佈的支撐集是沒有交集的或者說是支撐集是低維的流形空間,隨著訓練的進行,判別器不斷接近最優判別器,會導致生成器的梯度處處都是為0。

2. 對於第二種目標函數

同樣在最優判別器下,優化 (4) 式等價優化如下

訓練GAN,你應該知道的二三事

仔細盯著上面式子幾秒鐘,不難發現我們優化的目標是相互悖論的,因為 KL 散度和 JS 散度的符號相反,優化 KL 是把兩個分佈拉近,但是優化 -JS 是把兩個分佈推遠,這「一推一拉」就會導致梯度更新非常不穩定。此外,我們知道 KL 不是對稱的,對於生成器無法生成真實樣本的情況,KL 對 loss 的貢獻非常大,而對於生成器生成的樣本多樣性不足的時候,KL 對 loss 的貢獻非常小。

訓練GAN,你應該知道的二三事

而 JS 是對稱的,不會改變 KL 的這種不公平的行為。這就解釋了我們經常在訓練階段經常看見兩種情況,一個是訓練 loss 抖動非常大,訓練不穩定;另外一個是即使達到了穩定訓練,生成器也大概率上只生成一些安全保險的樣本,這樣就會導致模型缺乏多樣性。

此外,在有監督的機器學習裡面,經常會出現一些過擬合的情況,然而 GANs 也不例外。當生成器訓練得越來越好時候,生成的數據越接近於有限樣本集合裡面的數據。特別是當訓練集裡面包含有錯誤數據時候,判別器會過擬合到這些錯誤的數據,對於那些未見的數據,判別器就不能很好的指導生成器去生成可信的數據。這樣就會導致 GANs 的泛化能力比較差。

綜上所述,原始的 GANs 在訓練穩定性、模式多樣性以及模型泛化性能方面存在著或多或少的問題,後續學術上的工作大多也是基於此進行改進(填坑)。

訓練 GAN 的常用策略

上一節都是基於一些簡單的數學或者經驗的分析,但是根本原因目前沒有一個很好的理論來解釋;儘管理論上的缺陷,我們仍然可以從一些經驗中發現一些實用的 tricks,讓你的 GANs 不再難訓。這裡列舉的一些 tricks 可能跟 ganhacks 裡面的有些重複,更多的是補充,但是為了完整起見,部分也添加在這裡。

1. model choice

如果你不知道選擇什麼樣的模型,那就選擇 DCGAN[3] 或者 ResNet[4] 作為 base model。

2. input layer

假如你的輸入是一張圖片,將圖片數值歸一化到 [-1, 1];假如你的輸入是一個隨機噪聲的向量,最好是從 N(0, 1) 的正態分佈裡面採樣,不要從 U(0,1) 的均勻分佈裡採樣。

3. output layer

使用輸出通道為 3 的卷積作為最後一層,可以採用 1x1 或者 3x3 的 filters,有的論文也使用 9x9 的 filters。(注:ganhacks 推薦使用 tanh)

4. transposed convolution layer

在做 decode 的時候,儘量使用 upsample+conv2d 組合代替 transposed_conv2d,可以減少 checkerboard 的產生 [5];

在做超分辨率等任務上,可以採用 pixelshuffle [6]。在 tensorflow 裡,可以用 tf.depth_to_sapce 來實現 pixelshuffle 操作。

5. convolution layer

由於筆者經常做圖像修復方向相關的工作,推薦使用 gated-conv2d [7]。

6. normalization

雖然在 resnet 裡的標配是 BN,在分類任務上表現很好,但是圖像生成方面,推薦使用其他 normlization 方法,例如 parameterized 方法有 instance normalization [8]、layer normalization [9] 等,non-parameterized 方法推薦使用 pixel normalization [10]。假如你有選擇困難症,那就選擇大雜燴的 normalization 方法——switchable normalization [11]。

7. discriminator

想要生成更高清的圖像,推薦 multi-stage discriminator [10]。簡單的做法就是對於輸入圖片,把它下采樣(maxpooling)到不同 scale 的大小,輸入三個不同參數但結構相同的 discriminator。

8. minibatch discriminator

由於判別器是單獨處理每張圖片,沒有一個機制能告訴 discriminator 每張圖片之間要儘可能的不相似,這樣就會導致判別器會將所有圖片都 push 到一個看起來真實的點,缺乏多樣性。minibatch discriminator [22] 就是這樣這個機制,顯式地告訴 discriminator 每張圖片應該要不相似。在 tensorflow 中,一種實現 minibatch discriminator 方式如下:

"

作者:追一科技 AI Lab 研究員 Miracle


寫在前面的話


筆者接觸 GAN 也有一段時間了,從一開始的小白,到現在被 GANs 虐了千百遍但依然深愛著 GANs 的小白,被 GANs 的對抗思維所折服,被 GANs 能夠生成萬物的能力所驚歎。我覺得 GANs 在某種程度上有點類似於中國太極,『太極生兩儀,兩儀生四象』,太極闡明瞭宇宙從無極而太極,以至萬物化生的過程,太極也是講究陰陽調和。(哈哈,這麼說來 GANs 其實在中國古代就已經有了發展雛形了。)

眾所周知,GANs 的訓練尤其困難,筆者自從跳入了 GANs 這個領域(坑),就一直在跟如何訓練 GANs 做「對抗訓練」,受啟發於 ganhacks,並結合自己的經驗記錄總結了一些常用的訓練 GANs 的方法,以備後用。

(⚠️本篇不是 GANs 的入門掃盲篇,初學者慎入。)

什麼是 GANs?


GANs(Generative Adversarial Networks)可以說是一種強大的「萬能」數據分佈擬合器,主要由一個生成器(generator)和判別器(discriminator)組成。生成器主要從一個低維度的數據分佈中不斷擬合真實的高維數據分佈,而判別器主要是為了區分數據是來源於真實數據還是生成器生成的數據,他們之間相互對抗,不斷學習,最終達到Nash均衡,即任何一方的改進都不會導致總體的收益增加,這個時候判別器再也無法區分是生成器生成的數據還是真實數據。

GANs 最初由 Ian Goodfellow [1] 於 2014 年提出,目前已經在圖像、語音、文字等方面得到廣泛研究和應用,特別是在圖像生成方面,可謂是遍地開花,例如圖像風格遷移(style transfer)、圖像修復(image inpainting)、超分辨率(super resolution)等。

GANs 出了什麼問題?


GANs 通常被定義為一個 minimax 的過程:

訓練GAN,你應該知道的二三事

其中 P_r 是真實數據分佈,P_z 是隨機噪聲分佈。乍一看這個目標函數,感覺有點相互矛盾,其實這就是 GANs 的精髓所在—— 對抗訓練。

在原始的 GANs 中,判別器要不斷的提高判別是非的能力,即儘可能的將真實樣本分類為正例,將生成樣本分類為負例,所以判別器需要優化如下損失函數:

訓練GAN,你應該知道的二三事

作為對抗訓練,生成器需要不斷將生成數據分佈拉到真實數據分佈,Ian Goodfellow 首先提出瞭如下式的生成器損失函數:

訓練GAN,你應該知道的二三事

由於在訓練初期階段,生成器的能力比較弱,判別器這時候也比較弱,但仍然可以足夠精準的區分生成樣本和真實樣本,這樣 D(x) 就非常接近1,導致 log(1-D(x)) 達到飽和,後續網絡就很難再調整過來。為了解決訓練初期階段飽和問題,作者提出了另外一個損失函數,即:

訓練GAN,你應該知道的二三事

以上面這個兩個生成器目標函數為例,簡單地分析一下GAN模型存在的幾個問題:

Ian Goodfellow 論文裡面已經給出,固定 G 的參數,我們得到最優的 D^*:

訓練GAN,你應該知道的二三事

也就是說,只有當 P_r=P_g 時候,不管是真實樣本和生成樣本,判別器給出的概率都是 0.5,這個時候就無法區分樣本到底是來自於真實樣本還是來自於生成樣本,這是最理想的情況。

1. 對於第一種目標函數

在最優判別器下 D^* 下,我們給損失函數加上一個與 G 無關的項,(3) 式變成:

訓練GAN,你應該知道的二三事

注意,該式子其實就是判別器的損失函數的相反數。

把最優判別器 D^* 帶入,可以得到:

訓練GAN,你應該知道的二三事

到這裡,我們就可以看清楚我們到底在優化什麼東西了,在最優判別器的情況下,其實我們在優化兩個分佈的 JS 散度。當然在訓練過程中,判別器一開始不是最優的,但是隨著訓練的進行,我們優化的目標也逐漸接近JS散度,而問題恰恰就出現在這個 JS 散度上面。一個直觀的解釋就是隻要兩個分佈之間的沒有重疊或者重疊部分可以忽略不計,那麼大概率上我們優化的目標就變成了一個常數 -2log2,這種情況通過判別器傳遞給生成器的梯度就是零,也就是說,生成器不可能從判別器那裡學到任何有用的東西,這也就導致了無法繼續學習。

Arjovsky [2] 以其精湛的數學技巧提供一個更嚴謹的一個數學推導(手動截圖原論文了)。

訓練GAN,你應該知道的二三事

在 Theorm2.4 成立的情況下:

訓練GAN,你應該知道的二三事

拋開上面這些文縐縐的數學表述,其實上面講的核心內容就是當兩個分佈的支撐集是沒有交集的或者說是支撐集是低維的流形空間,隨著訓練的進行,判別器不斷接近最優判別器,會導致生成器的梯度處處都是為0。

2. 對於第二種目標函數

同樣在最優判別器下,優化 (4) 式等價優化如下

訓練GAN,你應該知道的二三事

仔細盯著上面式子幾秒鐘,不難發現我們優化的目標是相互悖論的,因為 KL 散度和 JS 散度的符號相反,優化 KL 是把兩個分佈拉近,但是優化 -JS 是把兩個分佈推遠,這「一推一拉」就會導致梯度更新非常不穩定。此外,我們知道 KL 不是對稱的,對於生成器無法生成真實樣本的情況,KL 對 loss 的貢獻非常大,而對於生成器生成的樣本多樣性不足的時候,KL 對 loss 的貢獻非常小。

訓練GAN,你應該知道的二三事

而 JS 是對稱的,不會改變 KL 的這種不公平的行為。這就解釋了我們經常在訓練階段經常看見兩種情況,一個是訓練 loss 抖動非常大,訓練不穩定;另外一個是即使達到了穩定訓練,生成器也大概率上只生成一些安全保險的樣本,這樣就會導致模型缺乏多樣性。

此外,在有監督的機器學習裡面,經常會出現一些過擬合的情況,然而 GANs 也不例外。當生成器訓練得越來越好時候,生成的數據越接近於有限樣本集合裡面的數據。特別是當訓練集裡面包含有錯誤數據時候,判別器會過擬合到這些錯誤的數據,對於那些未見的數據,判別器就不能很好的指導生成器去生成可信的數據。這樣就會導致 GANs 的泛化能力比較差。

綜上所述,原始的 GANs 在訓練穩定性、模式多樣性以及模型泛化性能方面存在著或多或少的問題,後續學術上的工作大多也是基於此進行改進(填坑)。

訓練 GAN 的常用策略

上一節都是基於一些簡單的數學或者經驗的分析,但是根本原因目前沒有一個很好的理論來解釋;儘管理論上的缺陷,我們仍然可以從一些經驗中發現一些實用的 tricks,讓你的 GANs 不再難訓。這裡列舉的一些 tricks 可能跟 ganhacks 裡面的有些重複,更多的是補充,但是為了完整起見,部分也添加在這裡。

1. model choice

如果你不知道選擇什麼樣的模型,那就選擇 DCGAN[3] 或者 ResNet[4] 作為 base model。

2. input layer

假如你的輸入是一張圖片,將圖片數值歸一化到 [-1, 1];假如你的輸入是一個隨機噪聲的向量,最好是從 N(0, 1) 的正態分佈裡面採樣,不要從 U(0,1) 的均勻分佈裡採樣。

3. output layer

使用輸出通道為 3 的卷積作為最後一層,可以採用 1x1 或者 3x3 的 filters,有的論文也使用 9x9 的 filters。(注:ganhacks 推薦使用 tanh)

4. transposed convolution layer

在做 decode 的時候,儘量使用 upsample+conv2d 組合代替 transposed_conv2d,可以減少 checkerboard 的產生 [5];

在做超分辨率等任務上,可以採用 pixelshuffle [6]。在 tensorflow 裡,可以用 tf.depth_to_sapce 來實現 pixelshuffle 操作。

5. convolution layer

由於筆者經常做圖像修復方向相關的工作,推薦使用 gated-conv2d [7]。

6. normalization

雖然在 resnet 裡的標配是 BN,在分類任務上表現很好,但是圖像生成方面,推薦使用其他 normlization 方法,例如 parameterized 方法有 instance normalization [8]、layer normalization [9] 等,non-parameterized 方法推薦使用 pixel normalization [10]。假如你有選擇困難症,那就選擇大雜燴的 normalization 方法——switchable normalization [11]。

7. discriminator

想要生成更高清的圖像,推薦 multi-stage discriminator [10]。簡單的做法就是對於輸入圖片,把它下采樣(maxpooling)到不同 scale 的大小,輸入三個不同參數但結構相同的 discriminator。

8. minibatch discriminator

由於判別器是單獨處理每張圖片,沒有一個機制能告訴 discriminator 每張圖片之間要儘可能的不相似,這樣就會導致判別器會將所有圖片都 push 到一個看起來真實的點,缺乏多樣性。minibatch discriminator [22] 就是這樣這個機制,顯式地告訴 discriminator 每張圖片應該要不相似。在 tensorflow 中,一種實現 minibatch discriminator 方式如下:

訓練GAN,你應該知道的二三事

上面是通過一個可學習的網絡來顯示度量每個樣本之間的相似度,PGGAN 裡提出了一個更廉價的不需要學習的版本,即通過統計每個樣本特徵每個像素點的標準差,然後取他們的平均,把這個平均值複製到與當前 feature map 一樣空間大小單通道,作為一個額外的 feature maps 拼接到原來的 feature maps 裡,一個簡單的 tensorflow 實現如下:

"

作者:追一科技 AI Lab 研究員 Miracle


寫在前面的話


筆者接觸 GAN 也有一段時間了,從一開始的小白,到現在被 GANs 虐了千百遍但依然深愛著 GANs 的小白,被 GANs 的對抗思維所折服,被 GANs 能夠生成萬物的能力所驚歎。我覺得 GANs 在某種程度上有點類似於中國太極,『太極生兩儀,兩儀生四象』,太極闡明瞭宇宙從無極而太極,以至萬物化生的過程,太極也是講究陰陽調和。(哈哈,這麼說來 GANs 其實在中國古代就已經有了發展雛形了。)

眾所周知,GANs 的訓練尤其困難,筆者自從跳入了 GANs 這個領域(坑),就一直在跟如何訓練 GANs 做「對抗訓練」,受啟發於 ganhacks,並結合自己的經驗記錄總結了一些常用的訓練 GANs 的方法,以備後用。

(⚠️本篇不是 GANs 的入門掃盲篇,初學者慎入。)

什麼是 GANs?


GANs(Generative Adversarial Networks)可以說是一種強大的「萬能」數據分佈擬合器,主要由一個生成器(generator)和判別器(discriminator)組成。生成器主要從一個低維度的數據分佈中不斷擬合真實的高維數據分佈,而判別器主要是為了區分數據是來源於真實數據還是生成器生成的數據,他們之間相互對抗,不斷學習,最終達到Nash均衡,即任何一方的改進都不會導致總體的收益增加,這個時候判別器再也無法區分是生成器生成的數據還是真實數據。

GANs 最初由 Ian Goodfellow [1] 於 2014 年提出,目前已經在圖像、語音、文字等方面得到廣泛研究和應用,特別是在圖像生成方面,可謂是遍地開花,例如圖像風格遷移(style transfer)、圖像修復(image inpainting)、超分辨率(super resolution)等。

GANs 出了什麼問題?


GANs 通常被定義為一個 minimax 的過程:

訓練GAN,你應該知道的二三事

其中 P_r 是真實數據分佈,P_z 是隨機噪聲分佈。乍一看這個目標函數,感覺有點相互矛盾,其實這就是 GANs 的精髓所在—— 對抗訓練。

在原始的 GANs 中,判別器要不斷的提高判別是非的能力,即儘可能的將真實樣本分類為正例,將生成樣本分類為負例,所以判別器需要優化如下損失函數:

訓練GAN,你應該知道的二三事

作為對抗訓練,生成器需要不斷將生成數據分佈拉到真實數據分佈,Ian Goodfellow 首先提出瞭如下式的生成器損失函數:

訓練GAN,你應該知道的二三事

由於在訓練初期階段,生成器的能力比較弱,判別器這時候也比較弱,但仍然可以足夠精準的區分生成樣本和真實樣本,這樣 D(x) 就非常接近1,導致 log(1-D(x)) 達到飽和,後續網絡就很難再調整過來。為了解決訓練初期階段飽和問題,作者提出了另外一個損失函數,即:

訓練GAN,你應該知道的二三事

以上面這個兩個生成器目標函數為例,簡單地分析一下GAN模型存在的幾個問題:

Ian Goodfellow 論文裡面已經給出,固定 G 的參數,我們得到最優的 D^*:

訓練GAN,你應該知道的二三事

也就是說,只有當 P_r=P_g 時候,不管是真實樣本和生成樣本,判別器給出的概率都是 0.5,這個時候就無法區分樣本到底是來自於真實樣本還是來自於生成樣本,這是最理想的情況。

1. 對於第一種目標函數

在最優判別器下 D^* 下,我們給損失函數加上一個與 G 無關的項,(3) 式變成:

訓練GAN,你應該知道的二三事

注意,該式子其實就是判別器的損失函數的相反數。

把最優判別器 D^* 帶入,可以得到:

訓練GAN,你應該知道的二三事

到這裡,我們就可以看清楚我們到底在優化什麼東西了,在最優判別器的情況下,其實我們在優化兩個分佈的 JS 散度。當然在訓練過程中,判別器一開始不是最優的,但是隨著訓練的進行,我們優化的目標也逐漸接近JS散度,而問題恰恰就出現在這個 JS 散度上面。一個直觀的解釋就是隻要兩個分佈之間的沒有重疊或者重疊部分可以忽略不計,那麼大概率上我們優化的目標就變成了一個常數 -2log2,這種情況通過判別器傳遞給生成器的梯度就是零,也就是說,生成器不可能從判別器那裡學到任何有用的東西,這也就導致了無法繼續學習。

Arjovsky [2] 以其精湛的數學技巧提供一個更嚴謹的一個數學推導(手動截圖原論文了)。

訓練GAN,你應該知道的二三事

在 Theorm2.4 成立的情況下:

訓練GAN,你應該知道的二三事

拋開上面這些文縐縐的數學表述,其實上面講的核心內容就是當兩個分佈的支撐集是沒有交集的或者說是支撐集是低維的流形空間,隨著訓練的進行,判別器不斷接近最優判別器,會導致生成器的梯度處處都是為0。

2. 對於第二種目標函數

同樣在最優判別器下,優化 (4) 式等價優化如下

訓練GAN,你應該知道的二三事

仔細盯著上面式子幾秒鐘,不難發現我們優化的目標是相互悖論的,因為 KL 散度和 JS 散度的符號相反,優化 KL 是把兩個分佈拉近,但是優化 -JS 是把兩個分佈推遠,這「一推一拉」就會導致梯度更新非常不穩定。此外,我們知道 KL 不是對稱的,對於生成器無法生成真實樣本的情況,KL 對 loss 的貢獻非常大,而對於生成器生成的樣本多樣性不足的時候,KL 對 loss 的貢獻非常小。

訓練GAN,你應該知道的二三事

而 JS 是對稱的,不會改變 KL 的這種不公平的行為。這就解釋了我們經常在訓練階段經常看見兩種情況,一個是訓練 loss 抖動非常大,訓練不穩定;另外一個是即使達到了穩定訓練,生成器也大概率上只生成一些安全保險的樣本,這樣就會導致模型缺乏多樣性。

此外,在有監督的機器學習裡面,經常會出現一些過擬合的情況,然而 GANs 也不例外。當生成器訓練得越來越好時候,生成的數據越接近於有限樣本集合裡面的數據。特別是當訓練集裡面包含有錯誤數據時候,判別器會過擬合到這些錯誤的數據,對於那些未見的數據,判別器就不能很好的指導生成器去生成可信的數據。這樣就會導致 GANs 的泛化能力比較差。

綜上所述,原始的 GANs 在訓練穩定性、模式多樣性以及模型泛化性能方面存在著或多或少的問題,後續學術上的工作大多也是基於此進行改進(填坑)。

訓練 GAN 的常用策略

上一節都是基於一些簡單的數學或者經驗的分析,但是根本原因目前沒有一個很好的理論來解釋;儘管理論上的缺陷,我們仍然可以從一些經驗中發現一些實用的 tricks,讓你的 GANs 不再難訓。這裡列舉的一些 tricks 可能跟 ganhacks 裡面的有些重複,更多的是補充,但是為了完整起見,部分也添加在這裡。

1. model choice

如果你不知道選擇什麼樣的模型,那就選擇 DCGAN[3] 或者 ResNet[4] 作為 base model。

2. input layer

假如你的輸入是一張圖片,將圖片數值歸一化到 [-1, 1];假如你的輸入是一個隨機噪聲的向量,最好是從 N(0, 1) 的正態分佈裡面採樣,不要從 U(0,1) 的均勻分佈裡採樣。

3. output layer

使用輸出通道為 3 的卷積作為最後一層,可以採用 1x1 或者 3x3 的 filters,有的論文也使用 9x9 的 filters。(注:ganhacks 推薦使用 tanh)

4. transposed convolution layer

在做 decode 的時候,儘量使用 upsample+conv2d 組合代替 transposed_conv2d,可以減少 checkerboard 的產生 [5];

在做超分辨率等任務上,可以採用 pixelshuffle [6]。在 tensorflow 裡,可以用 tf.depth_to_sapce 來實現 pixelshuffle 操作。

5. convolution layer

由於筆者經常做圖像修復方向相關的工作,推薦使用 gated-conv2d [7]。

6. normalization

雖然在 resnet 裡的標配是 BN,在分類任務上表現很好,但是圖像生成方面,推薦使用其他 normlization 方法,例如 parameterized 方法有 instance normalization [8]、layer normalization [9] 等,non-parameterized 方法推薦使用 pixel normalization [10]。假如你有選擇困難症,那就選擇大雜燴的 normalization 方法——switchable normalization [11]。

7. discriminator

想要生成更高清的圖像,推薦 multi-stage discriminator [10]。簡單的做法就是對於輸入圖片,把它下采樣(maxpooling)到不同 scale 的大小,輸入三個不同參數但結構相同的 discriminator。

8. minibatch discriminator

由於判別器是單獨處理每張圖片,沒有一個機制能告訴 discriminator 每張圖片之間要儘可能的不相似,這樣就會導致判別器會將所有圖片都 push 到一個看起來真實的點,缺乏多樣性。minibatch discriminator [22] 就是這樣這個機制,顯式地告訴 discriminator 每張圖片應該要不相似。在 tensorflow 中,一種實現 minibatch discriminator 方式如下:

訓練GAN,你應該知道的二三事

上面是通過一個可學習的網絡來顯示度量每個樣本之間的相似度,PGGAN 裡提出了一個更廉價的不需要學習的版本,即通過統計每個樣本特徵每個像素點的標準差,然後取他們的平均,把這個平均值複製到與當前 feature map 一樣空間大小單通道,作為一個額外的 feature maps 拼接到原來的 feature maps 裡,一個簡單的 tensorflow 實現如下:

訓練GAN,你應該知道的二三事

9. GAN loss

除了第二節提到的原始 GANs 中提出的兩種 loss,還可以選擇 wgan loss [12]、hinge loss、lsgan loss [13]等。wgan loss 使用 Wasserstein 距離(推土機距離)來度量兩個分佈之間的差異,lsgan 採用類似最小二乘法的思路設計損失函數,最後演變成用皮爾森卡方散度代替了原始 GAN 中的 JS 散度,hinge loss 是遷移了 SVM 裡面的思想,在 SAGAN [14] 和 BigGAN [15] 等都是採用該損失函數。

"

作者:追一科技 AI Lab 研究員 Miracle


寫在前面的話


筆者接觸 GAN 也有一段時間了,從一開始的小白,到現在被 GANs 虐了千百遍但依然深愛著 GANs 的小白,被 GANs 的對抗思維所折服,被 GANs 能夠生成萬物的能力所驚歎。我覺得 GANs 在某種程度上有點類似於中國太極,『太極生兩儀,兩儀生四象』,太極闡明瞭宇宙從無極而太極,以至萬物化生的過程,太極也是講究陰陽調和。(哈哈,這麼說來 GANs 其實在中國古代就已經有了發展雛形了。)

眾所周知,GANs 的訓練尤其困難,筆者自從跳入了 GANs 這個領域(坑),就一直在跟如何訓練 GANs 做「對抗訓練」,受啟發於 ganhacks,並結合自己的經驗記錄總結了一些常用的訓練 GANs 的方法,以備後用。

(⚠️本篇不是 GANs 的入門掃盲篇,初學者慎入。)

什麼是 GANs?


GANs(Generative Adversarial Networks)可以說是一種強大的「萬能」數據分佈擬合器,主要由一個生成器(generator)和判別器(discriminator)組成。生成器主要從一個低維度的數據分佈中不斷擬合真實的高維數據分佈,而判別器主要是為了區分數據是來源於真實數據還是生成器生成的數據,他們之間相互對抗,不斷學習,最終達到Nash均衡,即任何一方的改進都不會導致總體的收益增加,這個時候判別器再也無法區分是生成器生成的數據還是真實數據。

GANs 最初由 Ian Goodfellow [1] 於 2014 年提出,目前已經在圖像、語音、文字等方面得到廣泛研究和應用,特別是在圖像生成方面,可謂是遍地開花,例如圖像風格遷移(style transfer)、圖像修復(image inpainting)、超分辨率(super resolution)等。

GANs 出了什麼問題?


GANs 通常被定義為一個 minimax 的過程:

訓練GAN,你應該知道的二三事

其中 P_r 是真實數據分佈,P_z 是隨機噪聲分佈。乍一看這個目標函數,感覺有點相互矛盾,其實這就是 GANs 的精髓所在—— 對抗訓練。

在原始的 GANs 中,判別器要不斷的提高判別是非的能力,即儘可能的將真實樣本分類為正例,將生成樣本分類為負例,所以判別器需要優化如下損失函數:

訓練GAN,你應該知道的二三事

作為對抗訓練,生成器需要不斷將生成數據分佈拉到真實數據分佈,Ian Goodfellow 首先提出瞭如下式的生成器損失函數:

訓練GAN,你應該知道的二三事

由於在訓練初期階段,生成器的能力比較弱,判別器這時候也比較弱,但仍然可以足夠精準的區分生成樣本和真實樣本,這樣 D(x) 就非常接近1,導致 log(1-D(x)) 達到飽和,後續網絡就很難再調整過來。為了解決訓練初期階段飽和問題,作者提出了另外一個損失函數,即:

訓練GAN,你應該知道的二三事

以上面這個兩個生成器目標函數為例,簡單地分析一下GAN模型存在的幾個問題:

Ian Goodfellow 論文裡面已經給出,固定 G 的參數,我們得到最優的 D^*:

訓練GAN,你應該知道的二三事

也就是說,只有當 P_r=P_g 時候,不管是真實樣本和生成樣本,判別器給出的概率都是 0.5,這個時候就無法區分樣本到底是來自於真實樣本還是來自於生成樣本,這是最理想的情況。

1. 對於第一種目標函數

在最優判別器下 D^* 下,我們給損失函數加上一個與 G 無關的項,(3) 式變成:

訓練GAN,你應該知道的二三事

注意,該式子其實就是判別器的損失函數的相反數。

把最優判別器 D^* 帶入,可以得到:

訓練GAN,你應該知道的二三事

到這裡,我們就可以看清楚我們到底在優化什麼東西了,在最優判別器的情況下,其實我們在優化兩個分佈的 JS 散度。當然在訓練過程中,判別器一開始不是最優的,但是隨著訓練的進行,我們優化的目標也逐漸接近JS散度,而問題恰恰就出現在這個 JS 散度上面。一個直觀的解釋就是隻要兩個分佈之間的沒有重疊或者重疊部分可以忽略不計,那麼大概率上我們優化的目標就變成了一個常數 -2log2,這種情況通過判別器傳遞給生成器的梯度就是零,也就是說,生成器不可能從判別器那裡學到任何有用的東西,這也就導致了無法繼續學習。

Arjovsky [2] 以其精湛的數學技巧提供一個更嚴謹的一個數學推導(手動截圖原論文了)。

訓練GAN,你應該知道的二三事

在 Theorm2.4 成立的情況下:

訓練GAN,你應該知道的二三事

拋開上面這些文縐縐的數學表述,其實上面講的核心內容就是當兩個分佈的支撐集是沒有交集的或者說是支撐集是低維的流形空間,隨著訓練的進行,判別器不斷接近最優判別器,會導致生成器的梯度處處都是為0。

2. 對於第二種目標函數

同樣在最優判別器下,優化 (4) 式等價優化如下

訓練GAN,你應該知道的二三事

仔細盯著上面式子幾秒鐘,不難發現我們優化的目標是相互悖論的,因為 KL 散度和 JS 散度的符號相反,優化 KL 是把兩個分佈拉近,但是優化 -JS 是把兩個分佈推遠,這「一推一拉」就會導致梯度更新非常不穩定。此外,我們知道 KL 不是對稱的,對於生成器無法生成真實樣本的情況,KL 對 loss 的貢獻非常大,而對於生成器生成的樣本多樣性不足的時候,KL 對 loss 的貢獻非常小。

訓練GAN,你應該知道的二三事

而 JS 是對稱的,不會改變 KL 的這種不公平的行為。這就解釋了我們經常在訓練階段經常看見兩種情況,一個是訓練 loss 抖動非常大,訓練不穩定;另外一個是即使達到了穩定訓練,生成器也大概率上只生成一些安全保險的樣本,這樣就會導致模型缺乏多樣性。

此外,在有監督的機器學習裡面,經常會出現一些過擬合的情況,然而 GANs 也不例外。當生成器訓練得越來越好時候,生成的數據越接近於有限樣本集合裡面的數據。特別是當訓練集裡面包含有錯誤數據時候,判別器會過擬合到這些錯誤的數據,對於那些未見的數據,判別器就不能很好的指導生成器去生成可信的數據。這樣就會導致 GANs 的泛化能力比較差。

綜上所述,原始的 GANs 在訓練穩定性、模式多樣性以及模型泛化性能方面存在著或多或少的問題,後續學術上的工作大多也是基於此進行改進(填坑)。

訓練 GAN 的常用策略

上一節都是基於一些簡單的數學或者經驗的分析,但是根本原因目前沒有一個很好的理論來解釋;儘管理論上的缺陷,我們仍然可以從一些經驗中發現一些實用的 tricks,讓你的 GANs 不再難訓。這裡列舉的一些 tricks 可能跟 ganhacks 裡面的有些重複,更多的是補充,但是為了完整起見,部分也添加在這裡。

1. model choice

如果你不知道選擇什麼樣的模型,那就選擇 DCGAN[3] 或者 ResNet[4] 作為 base model。

2. input layer

假如你的輸入是一張圖片,將圖片數值歸一化到 [-1, 1];假如你的輸入是一個隨機噪聲的向量,最好是從 N(0, 1) 的正態分佈裡面採樣,不要從 U(0,1) 的均勻分佈裡採樣。

3. output layer

使用輸出通道為 3 的卷積作為最後一層,可以採用 1x1 或者 3x3 的 filters,有的論文也使用 9x9 的 filters。(注:ganhacks 推薦使用 tanh)

4. transposed convolution layer

在做 decode 的時候,儘量使用 upsample+conv2d 組合代替 transposed_conv2d,可以減少 checkerboard 的產生 [5];

在做超分辨率等任務上,可以採用 pixelshuffle [6]。在 tensorflow 裡,可以用 tf.depth_to_sapce 來實現 pixelshuffle 操作。

5. convolution layer

由於筆者經常做圖像修復方向相關的工作,推薦使用 gated-conv2d [7]。

6. normalization

雖然在 resnet 裡的標配是 BN,在分類任務上表現很好,但是圖像生成方面,推薦使用其他 normlization 方法,例如 parameterized 方法有 instance normalization [8]、layer normalization [9] 等,non-parameterized 方法推薦使用 pixel normalization [10]。假如你有選擇困難症,那就選擇大雜燴的 normalization 方法——switchable normalization [11]。

7. discriminator

想要生成更高清的圖像,推薦 multi-stage discriminator [10]。簡單的做法就是對於輸入圖片,把它下采樣(maxpooling)到不同 scale 的大小,輸入三個不同參數但結構相同的 discriminator。

8. minibatch discriminator

由於判別器是單獨處理每張圖片,沒有一個機制能告訴 discriminator 每張圖片之間要儘可能的不相似,這樣就會導致判別器會將所有圖片都 push 到一個看起來真實的點,缺乏多樣性。minibatch discriminator [22] 就是這樣這個機制,顯式地告訴 discriminator 每張圖片應該要不相似。在 tensorflow 中,一種實現 minibatch discriminator 方式如下:

訓練GAN,你應該知道的二三事

上面是通過一個可學習的網絡來顯示度量每個樣本之間的相似度,PGGAN 裡提出了一個更廉價的不需要學習的版本,即通過統計每個樣本特徵每個像素點的標準差,然後取他們的平均,把這個平均值複製到與當前 feature map 一樣空間大小單通道,作為一個額外的 feature maps 拼接到原來的 feature maps 裡,一個簡單的 tensorflow 實現如下:

訓練GAN,你應該知道的二三事

9. GAN loss

除了第二節提到的原始 GANs 中提出的兩種 loss,還可以選擇 wgan loss [12]、hinge loss、lsgan loss [13]等。wgan loss 使用 Wasserstein 距離(推土機距離)來度量兩個分佈之間的差異,lsgan 採用類似最小二乘法的思路設計損失函數,最後演變成用皮爾森卡方散度代替了原始 GAN 中的 JS 散度,hinge loss 是遷移了 SVM 裡面的思想,在 SAGAN [14] 和 BigGAN [15] 等都是採用該損失函數。

訓練GAN,你應該知道的二三事

ps: 我自己經常使用沒有 relu 的 hinge loss 版本。

10. other loss

  • perceptual loss [17]
  • style loss [18]
  • total variation loss [17]
  • l1 reconstruction loss


通常情況下,GAN loss 配合上面幾種 loss,效果會更好。

11. gradient penalty

Gradient penalty 首次在 wgan-gp 裡面提出來的,記為 1-gp,目的是為了讓 discriminator 滿足 1-lipchitchz 連續,後續 Mescheder, Lars M. et al [19] 又提出了只針對正樣本或者負樣本進行梯度懲罰,記為 0-gp-sample。Thanh-Tung, Hoang et al [20] 提出了 0-gp,具有更好的訓練穩定性。三者的對比如下:

"

作者:追一科技 AI Lab 研究員 Miracle


寫在前面的話


筆者接觸 GAN 也有一段時間了,從一開始的小白,到現在被 GANs 虐了千百遍但依然深愛著 GANs 的小白,被 GANs 的對抗思維所折服,被 GANs 能夠生成萬物的能力所驚歎。我覺得 GANs 在某種程度上有點類似於中國太極,『太極生兩儀,兩儀生四象』,太極闡明瞭宇宙從無極而太極,以至萬物化生的過程,太極也是講究陰陽調和。(哈哈,這麼說來 GANs 其實在中國古代就已經有了發展雛形了。)

眾所周知,GANs 的訓練尤其困難,筆者自從跳入了 GANs 這個領域(坑),就一直在跟如何訓練 GANs 做「對抗訓練」,受啟發於 ganhacks,並結合自己的經驗記錄總結了一些常用的訓練 GANs 的方法,以備後用。

(⚠️本篇不是 GANs 的入門掃盲篇,初學者慎入。)

什麼是 GANs?


GANs(Generative Adversarial Networks)可以說是一種強大的「萬能」數據分佈擬合器,主要由一個生成器(generator)和判別器(discriminator)組成。生成器主要從一個低維度的數據分佈中不斷擬合真實的高維數據分佈,而判別器主要是為了區分數據是來源於真實數據還是生成器生成的數據,他們之間相互對抗,不斷學習,最終達到Nash均衡,即任何一方的改進都不會導致總體的收益增加,這個時候判別器再也無法區分是生成器生成的數據還是真實數據。

GANs 最初由 Ian Goodfellow [1] 於 2014 年提出,目前已經在圖像、語音、文字等方面得到廣泛研究和應用,特別是在圖像生成方面,可謂是遍地開花,例如圖像風格遷移(style transfer)、圖像修復(image inpainting)、超分辨率(super resolution)等。

GANs 出了什麼問題?


GANs 通常被定義為一個 minimax 的過程:

訓練GAN,你應該知道的二三事

其中 P_r 是真實數據分佈,P_z 是隨機噪聲分佈。乍一看這個目標函數,感覺有點相互矛盾,其實這就是 GANs 的精髓所在—— 對抗訓練。

在原始的 GANs 中,判別器要不斷的提高判別是非的能力,即儘可能的將真實樣本分類為正例,將生成樣本分類為負例,所以判別器需要優化如下損失函數:

訓練GAN,你應該知道的二三事

作為對抗訓練,生成器需要不斷將生成數據分佈拉到真實數據分佈,Ian Goodfellow 首先提出瞭如下式的生成器損失函數:

訓練GAN,你應該知道的二三事

由於在訓練初期階段,生成器的能力比較弱,判別器這時候也比較弱,但仍然可以足夠精準的區分生成樣本和真實樣本,這樣 D(x) 就非常接近1,導致 log(1-D(x)) 達到飽和,後續網絡就很難再調整過來。為了解決訓練初期階段飽和問題,作者提出了另外一個損失函數,即:

訓練GAN,你應該知道的二三事

以上面這個兩個生成器目標函數為例,簡單地分析一下GAN模型存在的幾個問題:

Ian Goodfellow 論文裡面已經給出,固定 G 的參數,我們得到最優的 D^*:

訓練GAN,你應該知道的二三事

也就是說,只有當 P_r=P_g 時候,不管是真實樣本和生成樣本,判別器給出的概率都是 0.5,這個時候就無法區分樣本到底是來自於真實樣本還是來自於生成樣本,這是最理想的情況。

1. 對於第一種目標函數

在最優判別器下 D^* 下,我們給損失函數加上一個與 G 無關的項,(3) 式變成:

訓練GAN,你應該知道的二三事

注意,該式子其實就是判別器的損失函數的相反數。

把最優判別器 D^* 帶入,可以得到:

訓練GAN,你應該知道的二三事

到這裡,我們就可以看清楚我們到底在優化什麼東西了,在最優判別器的情況下,其實我們在優化兩個分佈的 JS 散度。當然在訓練過程中,判別器一開始不是最優的,但是隨著訓練的進行,我們優化的目標也逐漸接近JS散度,而問題恰恰就出現在這個 JS 散度上面。一個直觀的解釋就是隻要兩個分佈之間的沒有重疊或者重疊部分可以忽略不計,那麼大概率上我們優化的目標就變成了一個常數 -2log2,這種情況通過判別器傳遞給生成器的梯度就是零,也就是說,生成器不可能從判別器那裡學到任何有用的東西,這也就導致了無法繼續學習。

Arjovsky [2] 以其精湛的數學技巧提供一個更嚴謹的一個數學推導(手動截圖原論文了)。

訓練GAN,你應該知道的二三事

在 Theorm2.4 成立的情況下:

訓練GAN,你應該知道的二三事

拋開上面這些文縐縐的數學表述,其實上面講的核心內容就是當兩個分佈的支撐集是沒有交集的或者說是支撐集是低維的流形空間,隨著訓練的進行,判別器不斷接近最優判別器,會導致生成器的梯度處處都是為0。

2. 對於第二種目標函數

同樣在最優判別器下,優化 (4) 式等價優化如下

訓練GAN,你應該知道的二三事

仔細盯著上面式子幾秒鐘,不難發現我們優化的目標是相互悖論的,因為 KL 散度和 JS 散度的符號相反,優化 KL 是把兩個分佈拉近,但是優化 -JS 是把兩個分佈推遠,這「一推一拉」就會導致梯度更新非常不穩定。此外,我們知道 KL 不是對稱的,對於生成器無法生成真實樣本的情況,KL 對 loss 的貢獻非常大,而對於生成器生成的樣本多樣性不足的時候,KL 對 loss 的貢獻非常小。

訓練GAN,你應該知道的二三事

而 JS 是對稱的,不會改變 KL 的這種不公平的行為。這就解釋了我們經常在訓練階段經常看見兩種情況,一個是訓練 loss 抖動非常大,訓練不穩定;另外一個是即使達到了穩定訓練,生成器也大概率上只生成一些安全保險的樣本,這樣就會導致模型缺乏多樣性。

此外,在有監督的機器學習裡面,經常會出現一些過擬合的情況,然而 GANs 也不例外。當生成器訓練得越來越好時候,生成的數據越接近於有限樣本集合裡面的數據。特別是當訓練集裡面包含有錯誤數據時候,判別器會過擬合到這些錯誤的數據,對於那些未見的數據,判別器就不能很好的指導生成器去生成可信的數據。這樣就會導致 GANs 的泛化能力比較差。

綜上所述,原始的 GANs 在訓練穩定性、模式多樣性以及模型泛化性能方面存在著或多或少的問題,後續學術上的工作大多也是基於此進行改進(填坑)。

訓練 GAN 的常用策略

上一節都是基於一些簡單的數學或者經驗的分析,但是根本原因目前沒有一個很好的理論來解釋;儘管理論上的缺陷,我們仍然可以從一些經驗中發現一些實用的 tricks,讓你的 GANs 不再難訓。這裡列舉的一些 tricks 可能跟 ganhacks 裡面的有些重複,更多的是補充,但是為了完整起見,部分也添加在這裡。

1. model choice

如果你不知道選擇什麼樣的模型,那就選擇 DCGAN[3] 或者 ResNet[4] 作為 base model。

2. input layer

假如你的輸入是一張圖片,將圖片數值歸一化到 [-1, 1];假如你的輸入是一個隨機噪聲的向量,最好是從 N(0, 1) 的正態分佈裡面採樣,不要從 U(0,1) 的均勻分佈裡採樣。

3. output layer

使用輸出通道為 3 的卷積作為最後一層,可以採用 1x1 或者 3x3 的 filters,有的論文也使用 9x9 的 filters。(注:ganhacks 推薦使用 tanh)

4. transposed convolution layer

在做 decode 的時候,儘量使用 upsample+conv2d 組合代替 transposed_conv2d,可以減少 checkerboard 的產生 [5];

在做超分辨率等任務上,可以採用 pixelshuffle [6]。在 tensorflow 裡,可以用 tf.depth_to_sapce 來實現 pixelshuffle 操作。

5. convolution layer

由於筆者經常做圖像修復方向相關的工作,推薦使用 gated-conv2d [7]。

6. normalization

雖然在 resnet 裡的標配是 BN,在分類任務上表現很好,但是圖像生成方面,推薦使用其他 normlization 方法,例如 parameterized 方法有 instance normalization [8]、layer normalization [9] 等,non-parameterized 方法推薦使用 pixel normalization [10]。假如你有選擇困難症,那就選擇大雜燴的 normalization 方法——switchable normalization [11]。

7. discriminator

想要生成更高清的圖像,推薦 multi-stage discriminator [10]。簡單的做法就是對於輸入圖片,把它下采樣(maxpooling)到不同 scale 的大小,輸入三個不同參數但結構相同的 discriminator。

8. minibatch discriminator

由於判別器是單獨處理每張圖片,沒有一個機制能告訴 discriminator 每張圖片之間要儘可能的不相似,這樣就會導致判別器會將所有圖片都 push 到一個看起來真實的點,缺乏多樣性。minibatch discriminator [22] 就是這樣這個機制,顯式地告訴 discriminator 每張圖片應該要不相似。在 tensorflow 中,一種實現 minibatch discriminator 方式如下:

訓練GAN,你應該知道的二三事

上面是通過一個可學習的網絡來顯示度量每個樣本之間的相似度,PGGAN 裡提出了一個更廉價的不需要學習的版本,即通過統計每個樣本特徵每個像素點的標準差,然後取他們的平均,把這個平均值複製到與當前 feature map 一樣空間大小單通道,作為一個額外的 feature maps 拼接到原來的 feature maps 裡,一個簡單的 tensorflow 實現如下:

訓練GAN,你應該知道的二三事

9. GAN loss

除了第二節提到的原始 GANs 中提出的兩種 loss,還可以選擇 wgan loss [12]、hinge loss、lsgan loss [13]等。wgan loss 使用 Wasserstein 距離(推土機距離)來度量兩個分佈之間的差異,lsgan 採用類似最小二乘法的思路設計損失函數,最後演變成用皮爾森卡方散度代替了原始 GAN 中的 JS 散度,hinge loss 是遷移了 SVM 裡面的思想,在 SAGAN [14] 和 BigGAN [15] 等都是採用該損失函數。

訓練GAN,你應該知道的二三事

ps: 我自己經常使用沒有 relu 的 hinge loss 版本。

10. other loss

  • perceptual loss [17]
  • style loss [18]
  • total variation loss [17]
  • l1 reconstruction loss


通常情況下,GAN loss 配合上面幾種 loss,效果會更好。

11. gradient penalty

Gradient penalty 首次在 wgan-gp 裡面提出來的,記為 1-gp,目的是為了讓 discriminator 滿足 1-lipchitchz 連續,後續 Mescheder, Lars M. et al [19] 又提出了只針對正樣本或者負樣本進行梯度懲罰,記為 0-gp-sample。Thanh-Tung, Hoang et al [20] 提出了 0-gp,具有更好的訓練穩定性。三者的對比如下:

訓練GAN,你應該知道的二三事

12. Spectral normalization [21]

譜歸一化是另外一個讓判別器滿足 1-lipchitchz 連續的利器,建議在判別器和生成器裡同時使用。

ps: 在個人實踐中,它比梯度懲罰更有效。

13. one-size label smoothing [22]

平滑正樣本的 label,例如 label 1 變成 0.9-1.1 之間的隨機數,保持負樣本 label 仍然為 0。個人經驗表明這個 trick 能夠有效緩解訓練不穩定的現象,但是不能根本解決問題,假如模型不夠好的話,隨著訓練的進行,後期 loss 會飛。

14. add supervised labels

  • add labels
  • conditional batch normalization


15. instance noise (decay over time)

在原始 GAN 中,我們其實在優化兩個分佈的 JS 散度,前面的推理表明在兩個分佈的支撐集沒有交集或者支撐集是低維的流形空間,他們之間的 JS 散度大概率上是 0;而加入 instance noise 就是強行讓兩個分佈的支撐集之間產生交集,這樣 JS 散度就不會為 0。新的 JS 散度變為:

"

作者:追一科技 AI Lab 研究員 Miracle


寫在前面的話


筆者接觸 GAN 也有一段時間了,從一開始的小白,到現在被 GANs 虐了千百遍但依然深愛著 GANs 的小白,被 GANs 的對抗思維所折服,被 GANs 能夠生成萬物的能力所驚歎。我覺得 GANs 在某種程度上有點類似於中國太極,『太極生兩儀,兩儀生四象』,太極闡明瞭宇宙從無極而太極,以至萬物化生的過程,太極也是講究陰陽調和。(哈哈,這麼說來 GANs 其實在中國古代就已經有了發展雛形了。)

眾所周知,GANs 的訓練尤其困難,筆者自從跳入了 GANs 這個領域(坑),就一直在跟如何訓練 GANs 做「對抗訓練」,受啟發於 ganhacks,並結合自己的經驗記錄總結了一些常用的訓練 GANs 的方法,以備後用。

(⚠️本篇不是 GANs 的入門掃盲篇,初學者慎入。)

什麼是 GANs?


GANs(Generative Adversarial Networks)可以說是一種強大的「萬能」數據分佈擬合器,主要由一個生成器(generator)和判別器(discriminator)組成。生成器主要從一個低維度的數據分佈中不斷擬合真實的高維數據分佈,而判別器主要是為了區分數據是來源於真實數據還是生成器生成的數據,他們之間相互對抗,不斷學習,最終達到Nash均衡,即任何一方的改進都不會導致總體的收益增加,這個時候判別器再也無法區分是生成器生成的數據還是真實數據。

GANs 最初由 Ian Goodfellow [1] 於 2014 年提出,目前已經在圖像、語音、文字等方面得到廣泛研究和應用,特別是在圖像生成方面,可謂是遍地開花,例如圖像風格遷移(style transfer)、圖像修復(image inpainting)、超分辨率(super resolution)等。

GANs 出了什麼問題?


GANs 通常被定義為一個 minimax 的過程:

訓練GAN,你應該知道的二三事

其中 P_r 是真實數據分佈,P_z 是隨機噪聲分佈。乍一看這個目標函數,感覺有點相互矛盾,其實這就是 GANs 的精髓所在—— 對抗訓練。

在原始的 GANs 中,判別器要不斷的提高判別是非的能力,即儘可能的將真實樣本分類為正例,將生成樣本分類為負例,所以判別器需要優化如下損失函數:

訓練GAN,你應該知道的二三事

作為對抗訓練,生成器需要不斷將生成數據分佈拉到真實數據分佈,Ian Goodfellow 首先提出瞭如下式的生成器損失函數:

訓練GAN,你應該知道的二三事

由於在訓練初期階段,生成器的能力比較弱,判別器這時候也比較弱,但仍然可以足夠精準的區分生成樣本和真實樣本,這樣 D(x) 就非常接近1,導致 log(1-D(x)) 達到飽和,後續網絡就很難再調整過來。為了解決訓練初期階段飽和問題,作者提出了另外一個損失函數,即:

訓練GAN,你應該知道的二三事

以上面這個兩個生成器目標函數為例,簡單地分析一下GAN模型存在的幾個問題:

Ian Goodfellow 論文裡面已經給出,固定 G 的參數,我們得到最優的 D^*:

訓練GAN,你應該知道的二三事

也就是說,只有當 P_r=P_g 時候,不管是真實樣本和生成樣本,判別器給出的概率都是 0.5,這個時候就無法區分樣本到底是來自於真實樣本還是來自於生成樣本,這是最理想的情況。

1. 對於第一種目標函數

在最優判別器下 D^* 下,我們給損失函數加上一個與 G 無關的項,(3) 式變成:

訓練GAN,你應該知道的二三事

注意,該式子其實就是判別器的損失函數的相反數。

把最優判別器 D^* 帶入,可以得到:

訓練GAN,你應該知道的二三事

到這裡,我們就可以看清楚我們到底在優化什麼東西了,在最優判別器的情況下,其實我們在優化兩個分佈的 JS 散度。當然在訓練過程中,判別器一開始不是最優的,但是隨著訓練的進行,我們優化的目標也逐漸接近JS散度,而問題恰恰就出現在這個 JS 散度上面。一個直觀的解釋就是隻要兩個分佈之間的沒有重疊或者重疊部分可以忽略不計,那麼大概率上我們優化的目標就變成了一個常數 -2log2,這種情況通過判別器傳遞給生成器的梯度就是零,也就是說,生成器不可能從判別器那裡學到任何有用的東西,這也就導致了無法繼續學習。

Arjovsky [2] 以其精湛的數學技巧提供一個更嚴謹的一個數學推導(手動截圖原論文了)。

訓練GAN,你應該知道的二三事

在 Theorm2.4 成立的情況下:

訓練GAN,你應該知道的二三事

拋開上面這些文縐縐的數學表述,其實上面講的核心內容就是當兩個分佈的支撐集是沒有交集的或者說是支撐集是低維的流形空間,隨著訓練的進行,判別器不斷接近最優判別器,會導致生成器的梯度處處都是為0。

2. 對於第二種目標函數

同樣在最優判別器下,優化 (4) 式等價優化如下

訓練GAN,你應該知道的二三事

仔細盯著上面式子幾秒鐘,不難發現我們優化的目標是相互悖論的,因為 KL 散度和 JS 散度的符號相反,優化 KL 是把兩個分佈拉近,但是優化 -JS 是把兩個分佈推遠,這「一推一拉」就會導致梯度更新非常不穩定。此外,我們知道 KL 不是對稱的,對於生成器無法生成真實樣本的情況,KL 對 loss 的貢獻非常大,而對於生成器生成的樣本多樣性不足的時候,KL 對 loss 的貢獻非常小。

訓練GAN,你應該知道的二三事

而 JS 是對稱的,不會改變 KL 的這種不公平的行為。這就解釋了我們經常在訓練階段經常看見兩種情況,一個是訓練 loss 抖動非常大,訓練不穩定;另外一個是即使達到了穩定訓練,生成器也大概率上只生成一些安全保險的樣本,這樣就會導致模型缺乏多樣性。

此外,在有監督的機器學習裡面,經常會出現一些過擬合的情況,然而 GANs 也不例外。當生成器訓練得越來越好時候,生成的數據越接近於有限樣本集合裡面的數據。特別是當訓練集裡面包含有錯誤數據時候,判別器會過擬合到這些錯誤的數據,對於那些未見的數據,判別器就不能很好的指導生成器去生成可信的數據。這樣就會導致 GANs 的泛化能力比較差。

綜上所述,原始的 GANs 在訓練穩定性、模式多樣性以及模型泛化性能方面存在著或多或少的問題,後續學術上的工作大多也是基於此進行改進(填坑)。

訓練 GAN 的常用策略

上一節都是基於一些簡單的數學或者經驗的分析,但是根本原因目前沒有一個很好的理論來解釋;儘管理論上的缺陷,我們仍然可以從一些經驗中發現一些實用的 tricks,讓你的 GANs 不再難訓。這裡列舉的一些 tricks 可能跟 ganhacks 裡面的有些重複,更多的是補充,但是為了完整起見,部分也添加在這裡。

1. model choice

如果你不知道選擇什麼樣的模型,那就選擇 DCGAN[3] 或者 ResNet[4] 作為 base model。

2. input layer

假如你的輸入是一張圖片,將圖片數值歸一化到 [-1, 1];假如你的輸入是一個隨機噪聲的向量,最好是從 N(0, 1) 的正態分佈裡面採樣,不要從 U(0,1) 的均勻分佈裡採樣。

3. output layer

使用輸出通道為 3 的卷積作為最後一層,可以採用 1x1 或者 3x3 的 filters,有的論文也使用 9x9 的 filters。(注:ganhacks 推薦使用 tanh)

4. transposed convolution layer

在做 decode 的時候,儘量使用 upsample+conv2d 組合代替 transposed_conv2d,可以減少 checkerboard 的產生 [5];

在做超分辨率等任務上,可以採用 pixelshuffle [6]。在 tensorflow 裡,可以用 tf.depth_to_sapce 來實現 pixelshuffle 操作。

5. convolution layer

由於筆者經常做圖像修復方向相關的工作,推薦使用 gated-conv2d [7]。

6. normalization

雖然在 resnet 裡的標配是 BN,在分類任務上表現很好,但是圖像生成方面,推薦使用其他 normlization 方法,例如 parameterized 方法有 instance normalization [8]、layer normalization [9] 等,non-parameterized 方法推薦使用 pixel normalization [10]。假如你有選擇困難症,那就選擇大雜燴的 normalization 方法——switchable normalization [11]。

7. discriminator

想要生成更高清的圖像,推薦 multi-stage discriminator [10]。簡單的做法就是對於輸入圖片,把它下采樣(maxpooling)到不同 scale 的大小,輸入三個不同參數但結構相同的 discriminator。

8. minibatch discriminator

由於判別器是單獨處理每張圖片,沒有一個機制能告訴 discriminator 每張圖片之間要儘可能的不相似,這樣就會導致判別器會將所有圖片都 push 到一個看起來真實的點,缺乏多樣性。minibatch discriminator [22] 就是這樣這個機制,顯式地告訴 discriminator 每張圖片應該要不相似。在 tensorflow 中,一種實現 minibatch discriminator 方式如下:

訓練GAN,你應該知道的二三事

上面是通過一個可學習的網絡來顯示度量每個樣本之間的相似度,PGGAN 裡提出了一個更廉價的不需要學習的版本,即通過統計每個樣本特徵每個像素點的標準差,然後取他們的平均,把這個平均值複製到與當前 feature map 一樣空間大小單通道,作為一個額外的 feature maps 拼接到原來的 feature maps 裡,一個簡單的 tensorflow 實現如下:

訓練GAN,你應該知道的二三事

9. GAN loss

除了第二節提到的原始 GANs 中提出的兩種 loss,還可以選擇 wgan loss [12]、hinge loss、lsgan loss [13]等。wgan loss 使用 Wasserstein 距離(推土機距離)來度量兩個分佈之間的差異,lsgan 採用類似最小二乘法的思路設計損失函數,最後演變成用皮爾森卡方散度代替了原始 GAN 中的 JS 散度,hinge loss 是遷移了 SVM 裡面的思想,在 SAGAN [14] 和 BigGAN [15] 等都是採用該損失函數。

訓練GAN,你應該知道的二三事

ps: 我自己經常使用沒有 relu 的 hinge loss 版本。

10. other loss

  • perceptual loss [17]
  • style loss [18]
  • total variation loss [17]
  • l1 reconstruction loss


通常情況下,GAN loss 配合上面幾種 loss,效果會更好。

11. gradient penalty

Gradient penalty 首次在 wgan-gp 裡面提出來的,記為 1-gp,目的是為了讓 discriminator 滿足 1-lipchitchz 連續,後續 Mescheder, Lars M. et al [19] 又提出了只針對正樣本或者負樣本進行梯度懲罰,記為 0-gp-sample。Thanh-Tung, Hoang et al [20] 提出了 0-gp,具有更好的訓練穩定性。三者的對比如下:

訓練GAN,你應該知道的二三事

12. Spectral normalization [21]

譜歸一化是另外一個讓判別器滿足 1-lipchitchz 連續的利器,建議在判別器和生成器裡同時使用。

ps: 在個人實踐中,它比梯度懲罰更有效。

13. one-size label smoothing [22]

平滑正樣本的 label,例如 label 1 變成 0.9-1.1 之間的隨機數,保持負樣本 label 仍然為 0。個人經驗表明這個 trick 能夠有效緩解訓練不穩定的現象,但是不能根本解決問題,假如模型不夠好的話,隨著訓練的進行,後期 loss 會飛。

14. add supervised labels

  • add labels
  • conditional batch normalization


15. instance noise (decay over time)

在原始 GAN 中,我們其實在優化兩個分佈的 JS 散度,前面的推理表明在兩個分佈的支撐集沒有交集或者支撐集是低維的流形空間,他們之間的 JS 散度大概率上是 0;而加入 instance noise 就是強行讓兩個分佈的支撐集之間產生交集,這樣 JS 散度就不會為 0。新的 JS 散度變為:

訓練GAN,你應該知道的二三事

16. TTUR [23]

在優化 G 的時候,我們默認是假定我們的 D 的判別能力是比當前的 G 的生成能力要好的,這樣 D 才能指導 G 朝更好的方向學習。通常的做法是先更新 D 的參數一次或者多次,然後再更新 G 的參數,TTUR 提出了一個更簡單的更新策略,即分別為 D 和 G 設置不同的學習率,讓 D 收斂速度更快。

17. training strategy

  • PGGAN [10]


PGGAN 是一個漸進式的訓練技巧,因為要生成高清(eg, 1024x1024)的圖片,直接從一個隨機噪聲生成這麼高維度的數據是比較難的;既然沒法一蹴而就,那就循序漸進,首先從簡單的低緯度的開始生成,例如 4x4,然後 16x16,直至我們所需要的圖片大小。在 PGGAN 裡,首次實現了高清圖片的生成,並且可以做到以假亂真,可見其威力。此外,由於我們大部分的操作都是在比較低的維度上進行的,訓練速度也不比其他模型遜色多少。

  • coarse-to-refine


coarse-to-refine 可以說是 PGGAN 的一個特例,它的做法就是先用一個簡單的模型,加上一個 l1 loss,訓練一個模糊的效果,然後再把這個模糊的照片送到後面的 refine 模型裡,輔助對抗 loss 等其他 loss,訓練一個更加清晰的效果。這個在圖片生成裡面廣泛應用。

18. Exponential Moving Average [24]

EMA主要是對歷史的參數進行一個指數平滑,可以有效減少訓練的抖動。強烈推薦!!!

總結

訓練 GAN 是一個精(折)細(磨)的活,一不小心你的 GAN 可能就是一部驚悚大片。筆者結合自己的經驗以及看過的一些文獻資料,列出了常用的 tricks,在此拋磚引玉,由於筆者能力和視野有限,有些不正確之處或者沒補全的 tricks,還望斧正。

最後,祝大家煉丹愉快,不服就 GAN。: )

參考文獻

  • [1]. Goodfellow, Ian, et al. "Generative adversarial nets." Advances in neural information processing systems. 2014.
  • [2]. Arjovsky, Martín and Léon Bottou. “Towards Principled Methods for Training Generative Adversarial Networks.” CoRR abs/1701.04862 (2017): n. pag.
  • [3]. Radford, Alec et al. “Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks.” CoRR abs/1511.06434 (2016): n. pag.
  • [4]. He, Kaiming et al. “Deep Residual Learning for Image Recognition.” 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (2016): 770-778.
  • [5]. https://distill.pub/2016/deconv-checkerboard/
  • [6]. Shi, Wenzhe et al. “Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network.” 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (2016): 1874-1883.
  • [7]. Yu, Jiahui et al. “Free-Form Image Inpainting with Gated Convolution.” CoRRabs/1806.03589 (2018): n. pag.
  • [8]. Ulyanov, Dmitry et al. “Instance Normalization: The Missing Ingredient for Fast Stylization.” CoRR abs/1607.08022 (2016): n. pag.
  • [9]. Ba, Jimmy et al. “Layer Normalization.” CoRR abs/1607.06450 (2016): n. pag.
  • [10]. Karras, Tero et al. “Progressive Growing of GANs for Improved Quality, Stability, and Variation.” CoRR abs/1710.10196 (2018): n. pag.
  • [11]. Luo, Ping et al. “Differentiable Learning-to-Normalize via Switchable Normalization.” CoRRabs/1806.10779 (2018): n. pag.
  • [12]. Arjovsky, Martín et al. “Wasserstein GAN.” CoRR abs/1701.07875 (2017): n. pag.
  • [13]. Mao, Xudong, et al. "Least squares generative adversarial networks." Proceedings of the IEEE International Conference on Computer Vision. 2017.
  • [14]. Zhang, Han, et al. "Self-attention generative adversarial networks." arXiv preprint arXiv:1805.08318 (2018).
  • [15]. Brock, Andrew, Jeff Donahue, and Karen Simonyan. "Large scale gan training for high fidelity natural image synthesis." arXiv preprint arXiv:1809.11096 (2018).
  • [16]. Gulrajani, Ishaan et al. “Improved Training of Wasserstein GANs.” NIPS (2017).
  • [17]. Johnson, Justin et al. “Perceptual Losses for Real-Time Style Transfer and Super-Resolution.” ECCV (2016).
  • [18]. Liu, Guilin et al. “Image Inpainting for Irregular Holes Using Partial Convolutions.” ECCV(2018).
  • [19]. Mescheder, Lars M. et al. “Which Training Methods for GANs do actually Converge?” ICML(2018).
  • [20]. Thanh-Tung, Hoang et al. “Improving Generalization and Stability of Generative Adversarial Networks.” CoRR abs/1902.03984 (2018): n. pag.
  • [21]. Yoshida, Yuichi and Takeru Miyato. “Spectral Norm Regularization for Improving the Generalizability of Deep Learning.” CoRR abs/1705.10941 (2017): n. pag.
  • [22]. Salimans, Tim et al. “Improved Techniques for Training GANs.” NIPS (2016).
  • [23]. Heusel, Martin et al. “GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium.” NIPS (2017).
  • [24]. Yazici, Yasin et al. “The Unusual Effectiveness of Averaging in GAN Training.” CoRRabs/1806.04498 (2018): n. pag.
"

相關推薦

推薦中...