用Keras+TF,實現ImageNet數據集日常對象的識別

編程語言 Python 人工智能 GitHub 量子位 2017-04-25

王新民 編譯自 Deep Learning Sandbox博客

量子位 出品 | 公眾號 QbitAI

用Keras+TF,實現ImageNet數據集日常對象的識別

在計算機視覺領域裡,有3個最受歡迎且影響非常大的學術競賽:ImageNet ILSVRC(大規模視覺識別挑戰賽),PASCAL VOC(關於模式分析,統計建模和計算學習的研究)和微軟COCO圖像識別大賽。這些比賽大大地推動了在計算機視覺研究中的多項發明和創新,其中很多都是免費開源的。

博客Deep Learning Sandbox作者Greg Chu打算通過一篇文章,教你用Keras和TensorFlow,實現對ImageNet數據集中日常物體的識別。

量子位翻譯了這篇文章:

你想識別什麼?

看看ILSVRC競賽中包含的物體對象。如果你要研究的物體對象是該列表1001個對象中的一個,運氣真好,可以獲得大量該類別圖像數據!以下是這個數據集包含的部分類別:

椅子
汽車鍵盤箱子
嬰兒床旗杆iPod播放器
輪船麵包車項鍊
降落傘枕頭桌子
錢包球拍步槍
校車薩克斯管足球
襪子舞臺火爐
火把吸塵器自動售貨機
眼鏡紅綠燈菜餚
盤子西蘭花紅酒

表1 ImageNet ILSVRC的類別摘錄

完整類別列表見:https://gist.github.com/gregchu/134677e041cd78639fea84e3e619415b

如果你研究的物體對象不在該列表中,或者像醫學圖像分析中具有多種差異較大的背景,遇到這些情況該怎麼辦?可以藉助遷移學習(transfer learning)和微調(fine-tuning),我們以後再另外寫文章講。

圖像識別

圖像識別,或者說物體識別是什麼?它回答了一個問題:“這張圖像中描繪了哪幾個物體對象?”如果你研究的是基於圖像內容進行標記,確定盤子上的食物類型,對癌症患者或非癌症患者的醫學圖像進行分類,以及更多的實際應用,那麼就能用到圖像識別。

Keras和TensorFlow

Keras是一個高級神經網絡庫,能夠作為一種簡單好用的抽象層,接入到數值計算庫TensorFlow中。另外,它可以通過其keras.applications模塊獲取在ILSVRC競賽中獲勝的多個卷積網絡模型,如由Microsoft Research開發的ResNet50網絡和由Google Research開發的InceptionV3網絡,這一切都是免費和開源的。具體安裝參照以下說明進行操作:

Keras安裝:https://keras.io/#installation

TensorFlow安裝:https://www.tensorflow.org/install/

實現過程

我們的最終目標是編寫一個簡單的python程序,只需要輸入本地圖像文件的路徑或是圖像的URL鏈接就能實現物體識別。

以下是輸入非洲大象照片的示例:

1. python classify.py --image African_Bush_Elephant.jpg

2. python classify.py --image_url http://i.imgur.com/wpxMwsR.jpg

輸入:

用Keras+TF,實現ImageNet數據集日常對象的識別

輸出將如下所示:

用Keras+TF,實現ImageNet數據集日常對象的識別

該圖像最可能的前3種預測類別及其相應概率

預測功能

我們接下來要載入ResNet50網絡模型。首先,要加載keras.preprocessingkeras.applications.resnet50模塊,並使用在ImageNet ILSVRC比賽中已經訓練好的權重。

想了解ResNet50的原理,可以閱讀論文《基於深度殘差網絡的圖像識別》。地址:https://arxiv.org/pdf/1512.03385.pdf

import numpy as np

from keras.preprocessing import image

from keras.applications.resnet50

import ResNet50, preprocess_input, decode_predictions model = ResNet50(weights='imagenet')

接下來定義一個預測函數:

def predict(model, img, target_size, top_n=3):
 """Run model prediction on image
 Args:
 model: keras model
 img: PIL format image
 target_size: (width, height) tuple
 top_n: # of top predictions to return
 Returns:
 list of predicted labels and their probabilities
 """
 if img.size != target_size:
 img = img.resize(target_size)
 x = image.img_to_array(img)
 x = np.expand_dims(x, axis=0)
 x = preprocess_input(x)
 preds = model.predict(x) 

return decode_predictions(preds, top=top_n)[0]

在使用ResNet50網絡結構時需要注意,輸入大小target_size必須等於(224,224)。許多CNN網絡結構具有固定的輸入大小,ResNet50正是其中之一,作者將輸入大小定為(224,224)

image.img_to_array:將PIL格式的圖像轉換為numpy數組。

np.expand_dims:將我們的(3,224,224)大小的圖像轉換為(1,3,224,224)。因為model.predict函數需要4維數組作為輸入,其中第4維為每批預測圖像的數量。這也就是說,我們可以一次性分類多個圖像。

preprocess_input:使用訓練數據集中的平均通道值對圖像數據進行零值處理,即使得圖像所有點的和為0。這是非常重要的步驟,如果跳過,將大大影響實際預測效果。這個步驟稱為數據歸一化。

model.predict:對我們的數據分批處理並返回預測值。

decode_predictions:採用與model.predict函數相同的編碼標籤,並從ImageNet ILSVRC集返回可讀的標籤。

keras.applications模塊還提供4種結構:ResNet50、InceptionV3、VGG16、VGG19和XCeption,你可以用其中任何一種替換ResNet50。更多信息可以參考https://keras.io/applications/。

繪圖

我們可以使用matplotlib函數庫將預測結果做成柱狀圖,如下所示:

def plot_preds(image, preds):

主體部分

為了實現以下從網絡中加載圖片的功能:

1. python classify.py --image African_Bush_Elephant.jpg

2. python classify.py --image_url http://i.imgur.com/wpxMwsR.jpg

我們將定義主函數如下:

if __name__=="__main__":
 a = argparse.ArgumentParser()
 a.add_argument("--image",

help="path to image") a.add_argument("--image_url",

help="url to image") args = a.parse_args()

if args.image is None and args.image_url is None: a.print_help() sys.exit(1)

if args.image is not None: img = Image.open(args.image) print_preds(predict(model, img, target_size))

if args.image_url is not None: response = requests.get(args.image_url) img = Image.open(BytesIO(response.content)) print_preds(predict(model, img, target_size))

其中在寫入image_url功能後,用python中的Requests庫就能很容易地從URL鏈接中下載圖像。

完工

將上述代碼組合起來,你就創建了一個圖像識別系統。項目的完整程序和示例圖像請查看GitHub鏈接:

https://github.com/DeepLearningSandbox/DeepLearningSandbox/tree/master/image_recognition

招聘

我們正在招募編輯記者、運營等崗位,工作地點在北京中關村,期待你的到來,一起體驗人工智能的風起雲湧。

相關細節,請在公眾號對話界面,回覆:“招聘”兩個字。

One More Thing…

今天AI界還有哪些事值得關注?在量子位(QbitAI)公眾號會話界面回覆“今天”,看我們全網蒐羅的AI行業和研究動態。筆芯~

另外,歡迎加量子位小助手的微信:qbitbot,如果你研究或者從事AI領域,小助手會把你帶入量子位的交流群裡。

相關推薦

推薦中...