人工智能進階-TensorFlow創建TFRecord數據集(附實例代碼)

Python 人工智能 圖像處理 操作系統 HELIX人工智能實驗室 2019-07-12

在測試TensorFlow實例時發現好多的圖片數據集都是處理好的,直接在庫中調用,比如Mnist,CIFAR-10等等。但是在運行自己編制的項目的時候,如何去讀取自己的數據集呢?其實,一方面TensorFlow官方已經給出方法,那就是將圖片製作成tfrecord格式的數據,供TensorFlow讀取。另一方面可以用Python以及Python的圖像處理第三方庫都有數據集讀取製作的方法。

人工智能進階-TensorFlow創建TFRecord數據集(附實例代碼)

1. TFRecord數據集介紹

TFRecord是一種二進制文件,可以支持多線程數據讀取,可以通過batch_size和epoch參數來控制訓練時單次batch的大小和樣本迭代次數,同時能更好地利用內存和方便數據的複製和移動,所以是TensorFlow進行大規模深度學習訓練的首選。

每個訓練樣本在TFRecord中稱為example,TensorFlow使用tf.train.Example協議來存儲訓練樣本,每個example本質上是一個字典dict類型,用來存儲一個訓練樣本的多個feature信息(如input、label、mask等等),且每個feature信息必須是TensorFlow預定義好的類型(ByteList,FloatList以及Int64List中的一種)。最後,example通過SerializeToString()方法將樣例序列化成字符串存儲,TensorFlow通過TFRecordWriter將這些序列化之後的字符串存成TFRecord形式。

2. 創建TFRecord文件

# coding:utf-8
import tensorflow as tf
import numpy as np
from PIL import Image
import os
# 這是設置的路徑,可以根據您的需要修改
image_train_path = 'E:\\GitLab\\VSCode\\images'
label_train_path = 'E:\\GitLab\\VSCode\\images\\train.txt'
tfRecord_train = 'E:\\GitLab\\VSCode\\images\\test.tfrecords'
image_test_path = 'E:\\GitLab\\VSCode\\images'
label_test_path = 'E:\\GitLab\\VSCode\\images\\test.txt'
tfRecord_test = 'E:\\GitLab\\VSCode\\images\\test01.tfrecords'
data_path = 'E:\\GitLab\\VSCode\\images'
# 設置長寬像素點個數
resize_height = 28
resize_width = 28
# 生成tfrecords文件
def write_tfRecord(tfRecordName, image_path, label_path):
writer = tf.python_io.TFRecordWriter(tfRecordName) # 新建一個writer
num_pic = 0
f = open(label_path, 'r')
contents = f.readlines() # 一次全部讀入,速度比較快
f.close()
for content in contents:
'''
該目錄下的文件下的txt內容為:
0_5.jpg 5
1_0.jpg 0
2_4.jpg 4
.......
'''
value = content.split() # 用空格分開
img_path = image_path + value[0]
img = Image.open(img_path)
img_raw = img.tobytes() # 轉化為二進制文件
labels = [0] * 10
labels[int(value[1])] = 1 # 設置標籤位為1
# 用tf.train.Example的協議存儲訓練數據,訓練數據的特徵用鍵值對的形式表示
example = tf.train.Example(features=tf.train.Features(feature={
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=labels))
})) # 把每張圖片和標籤封裝到example中
writer.write(example.SerializeToString()) # 將example序列化(把數據序列化成字符串存儲)
num_pic += 1
print("the number of picture:", num_pic)
writer.close() # 關閉writer
print("write tfrecord successful")
# 產生數據集
def generate_tfRecord():
isExists = os.path.exists(data_path) # 判斷路徑是否存在
if not isExists: # 如果不存在
os.makedirs(data_path) # 新建一個目錄
print('The directory was created successfully')
else:
print('directory already exists')
# 生成tfRecords文件
write_tfRecord(tfRecord_train, image_train_path, label_train_path)
write_tfRecord(tfRecord_test, image_test_path, label_test_path)
# 解析tfrecords文件
def read_tfRecord(tfRecord_path):
# [tfRecord_path]為文件的路徑,如果文件比較大可以寫多個
filename_queue = tf.train.string_input_producer(
[tfRecord_path], shuffle=True)
reader = tf.TFRecordReader() # 新建一個reader
_, serialized_example = reader.read(
filename_queue) # 將讀出的每個樣本保存在serialize_example中
features = tf.parse_single_example(serialized_example,
features={
# 10分類寫10
'label': tf.FixedLenFeature([10], tf.int64),
'img_raw': tf.FixedLenFeature([], tf.string)
}) # 解序列化
img = tf.decode_raw(features['img_raw'], tf.uint8) # 恢復img_raw 到 img
img.set_shape([784]) # 把img的shape設為[1,784]
img = tf.cast(img, tf.float32) * (1. / 255) # 歸一化到0-1
label = tf.cast(features['label'], tf.float32) # 同時把label值也設為浮點型
return img, label
# 批獲取訓練集或測試集的內容和標籤
def get_tfrecord(num, isTrain=True):
if isTrain: # 獲取訓練集,isTrain參數設置為True
tfRecord_path = tfRecord_train
else: # 獲取測試集,isTrain參數設置為False
tfRecord_path = tfRecord_test
img, label = read_tfRecord(tfRecord_path)
# 從總樣本中順序獲取capactiy組數據,打亂順序,每次輸出batch_size組,用了2個線程
img_batch, label_batch = tf.train.shuffle_batch([img, label],
batch_size=num,
num_threads=2,
capacity=1000,
min_after_dequeue=700)
return img_batch, label_batch
def main():
generate_tfRecord()
if __name__ == '__main__':
main()

數據集創建結果:

人工智能進階-TensorFlow創建TFRecord數據集(附實例代碼)

3. 製作數據集流程

人工智能進階-TensorFlow創建TFRecord數據集(附實例代碼)

人工智能進階-TensorFlow創建TFRecord數據集(附實例代碼)

相關推薦

推薦中...