刷劇不忘學習:用Keras識別辛普森一家人物|教程+代碼

機器學習 辛普森一家 Kaggle Medium 量子位 2017-07-07

王小新 編譯自 Medium
量子位 出品 | 公眾號 QbitAI

Alexandre Attia是《辛普森一家》的狂熱粉絲。他看了一系列辛普森劇集,想建立一個能識別其中人物的神經網絡。

接下來讓我們跟著他的文章來了解下該如何建立一個用於識別《辛普森一家》中各個角色的神經網絡。

刷劇不忘學習:用Keras識別辛普森一家人物|教程+代碼

要實現這個項目不是很困難,但可能會比較耗時,因為需要手動標註每個人物的多張照片。

目前在網上沒有《辛普森一家》人物的訓練數據集,所以我正在標註各類圖片來構建訓練數據集。這個數據集的第一個版本已經掛在Kaggle上了,將持續進行更新,希望這個數據集能幫到大家。

在學了用TensorFlow構建不同項目後,我決定用Keras,因為它比TensorFlow更為簡單易上手,而且以TensorFlow作為後端,具有很強的兼容性。Keras是Francois Chollet用Python語言編寫的一個深度學習庫。

本文基於卷積神經網絡(CNN)來完成此項目,CNN網絡是一種能夠學習許多特徵的多層前饋神經網絡。

準備數據集

該數據集目前有18類,有以下人物:Homer,Marge,Lisa,Bart,Burns,Grampa,Flanders,Moe,Krusty,Sideshow Bob,Skinner,Milhouse等。

我的目標是達到20類,當然類別越多越好。各類樣本的大小不一,圖片背景也不盡相同,主要是從第4至24季的劇集中提取出來的。

刷劇不忘學習:用Keras識別辛普森一家人物|教程+代碼

部分人物的圖片

在訓練集中,每個人物各大約包括1000個樣本(還在標註數據來達到這個數量)。每個人物不一定處於圖像中間,有時周圍還帶有其他人物。

刷劇不忘學習:用Keras識別辛普森一家人物|教程+代碼

人物的樣本量分佈

通過label_data.py函數,我們可以從AVI電影中標註數據:得到裁剪後的圖片(左部分或右部分),或者完整版,然後僅需輸入人物名稱的一部分,如對Charles Montgomery Burns輸入burns。

添加數據時,我也使用了Keras模型。對視頻進行截圖,每一幀可轉化得到3張圖片,分別是左部分、右部分和完整版,然後通過編寫算法來分類每張圖片。

之後,我檢查了此算法的分類效果,雖然是手動的,但這是一個漸進的過程,速度將會不斷提升,特別是對出現頻率較低的小類別人物。

數據預處理

在預處理圖片時,第一步是調整樣本大小。為了節省數據內存,先將樣本轉換為float32類型,併除以255進行歸一化。

然後,使用Keras的自帶函數,將各類人物的標籤從名字轉換為數字,再利用one-hot編碼轉換成矢量:

import keras

進而,使用sklearn庫的train_test_split函數,將數據集分成訓練集和測試集。

構建模型

刷劇不忘學習:用Keras識別辛普森一家人物|教程+代碼

現在讓我們開始進入最有趣的部分:定義網絡模型。

首先,我們構建了一個前饋網絡,包括4個帶有ReLU激活函數的卷積層和一個全連接的隱藏層(隨著數據量的增大,可能會進一步加深網絡)。

這個模型與Keras文檔中的CIFAR示例模型比較相近,接下來還會使用更多數據對其他模型進行測試。我還在模型中加入了Dropout層來防止網絡過擬合。在輸出層中,使用softmax函數來輸出各類的所屬概率。

損失函數為分類交叉熵(Categorical Cross Entropy)。優化器optimizer使用了隨機梯度下降中的RMS Prop方法,通過該權重臨近窗口的梯度平均值來確定該點的學習率。

訓練模型

這個模型在訓練集上迭代訓練了200次,其中批次大小為32。

由於目前的數據集樣本不多,我還用了數據增強操作,使用Keras庫可以很快地實現。

這實際上是對圖片進行一些隨機變化,如小角度旋轉和加噪聲等,所以輸入模型的樣本都不大相同。這有助於防止模型過擬合,提高模型的泛化能力。

datagen = ImageDataGenerator(

在CPU上訓練模型時會耗費較長時間,所以我使用AWS EC2上的GPU資源:每次迭代需要8秒鐘,一共使用了20分鐘。在訓練深度學習模型時,這已經是較快了。

在200次迭代後,我們畫出了模型指標,可以看出性能已經較為穩定,沒有明顯的過擬合現象,且實際正確率較高。

刷劇不忘學習:用Keras識別辛普森一家人物|教程+代碼

訓練時驗證集和訓練集的損失值和正確率

評估模型

由於當前樣本量較小,所以很難得到準確的模型精度。但隨著訓練集樣本的增多,這將更貼近實際的模型性能。我們使用sklearn庫很快地輸出了各類的識別效果。

刷劇不忘學習:用Keras識別辛普森一家人物|教程+代碼

各類別的識別效果

從上圖可以看出,模型的正確率(f1-score)較高:除了Lisa,其餘各類的正確率都超過了80%。Lisa類的平均正確率為82%,可能是在樣本中Lisa與其他人物混在一起。

刷劇不忘學習:用Keras識別辛普森一家人物|教程+代碼

各類別的交叉關係圖

的確,Lisa樣本中經常帶有Bart,所以正確率較低可能受到Bart的影響。

添加閾值來提高正確率

為了提高模型正確率和減少召回率,我添加了一個閾值。

在討論閾值之前,先介紹下關於召回和正確率的關係圖。

刷劇不忘學習:用Keras識別辛普森一家人物|教程+代碼

各類別的交叉關係圖

現在統計下正確預測和錯誤預測的相關數據:最佳概率預測,兩個最相似人物的概率差和標準偏差STD。

  • 正確預測:最大值為0.83,最優點概率差為0.773,STD值為0.21;

  • 錯誤預測:最大值為0.27,最優點概率差為0.092,STD值為0.07。

如果人物1的預測正確率太低,預測人物2時標準偏差太高或是兩個最相似人物間的概率差太低,那麼可以認為網絡沒有學習到這個人物。

因此,對兩個類別,繪製測試集的3個指標,希望找到一個超平面來分離正確預測和錯誤預測。

刷劇不忘學習:用Keras識別辛普森一家人物|教程+代碼

測試集中多個指標的散點圖

上圖中,想要通過直線或是設置閾值,來分離出正確預測和錯誤預測,這是不容易實現的。當然還可以看出,錯誤預測的樣本一般在圖表的左下方,但在這個位置也分佈了很多正確預測樣本。如果設置了一個閾值(關於最相似人物間的概率差和概率),則實際召回率也會降低。

我們希望在提高準確性的同時,而不會很大程度上影響召回率,因此要為每個人物或是低正確率的人物(如Lisa Simpson)來繪製這些散點圖。

此外,對於沒有主角或是不存在人物的樣本,加入閾值後效果很好。目前我在模型中添加了一個“無人物”的類別,可以添加閾值來處理。我認為很難在最佳概率預測、概率差和標準偏差之間找到平衡點,所以我重點關注最佳預測概率。

關於最佳預測概率的召回率和正確率

在模型中,很難平衡好召回率與正確率之間的關係,同時也無法同時提高召回率和正確率。所以往往根據實際目標,來提高單個值。

對於預測類別的概率最小值,畫出F1-score、召回率和正確率來比較效果。

刷劇不忘學習:用Keras識別辛普森一家人物|教程+代碼

正確率、召回率和F1-score與預測類別概率最小值的關係

從圖10中看出,模型效果取決於不同人物。重點研究Lisa Simpson類別,為該類添加概率最小值0.2可能會提高效果,但是組合所有類別後,這個閾值並不完全適用。

所以考慮全局效果,對於預測類別的概率最小值,應該增加一個合適的閾值,且不能位於區間[0.2,0.4]內。

可視化預測人物

刷劇不忘學習:用Keras識別辛普森一家人物|教程+代碼

12個不同人物的實際類別和預測類別

在圖11中,用於分類人物的神經網絡效果很好,故應用到視頻中實時預測。在實際中,每張圖片的預測時間不超過0.1s,可以做到每秒預測多幀。

相關鏈接

1. 辛普森一家的人物數據集:

https://www.kaggle.com/alexattia/the-simpsons-characters-dataset

2. 完整項目代碼:

https://github.com/alexattia/SimpsonRecognition

【完】

一則通知

量子位正在組建自動駕駛技術群,面向研究自動駕駛相關領域的在校學生或一線工程師。李開復、王詠剛、王乃巖等大牛都在群裡。歡迎大家加量子位微信(qbitbot),備註“自動駕駛”申請加入哈~

相關推薦

推薦中...