昨天花了些時間複習 tf.keras API (點我)。除了建個簡單的卷積神經網路 (Convolution Nerual Network, CNN),也丟了最簡單的 MNIST 數字集來訓練,很快就複習一遍相關的程式,反而整理程式寫文章花了許多時間QQ。雖說是複習 API,但還是試了不少新東西,但主要就是用 tf.data.Dataset 讀取資料的方式,所以就決定單獨寫一篇文章來紀錄一下。
tf.data.Dataset
先來看官方解說,內容就直接貼上來了。
The tf.data.Dataset API supports writing descriptive and efficient input pipelines. Dataset usage follows a common pattern:
- Create a source dataset from your input data.
- Apply dataset transformations to preprocess the data.
- Iterate over the dataset and process the elements.
Iteration happens in a streaming fashion, so the full dataset does not need to fit into memory.
簡單講就是用他們方式建立數據集 (Dataset),是最有效率的方法,也能藉由他們的方法進行前處理 (Preprocessing)。
官方範例
看完簡單的說明之後,基本上就是來了解如何使用 Dataset 這個 API 了。那我這邊先列出一些官方的範例,來快速了解一下怎麼去使用這個 API,以及這個 API 可以做什麼。
The simplest way to create a dataset is to create it from a python list:
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
for element in dataset:
print(element)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(3, shape=(), dtype=int32)
To process lines from files, use tf.data.TextLineDataset:
dataset = tf.data.TextLineDataset(["file1.txt", "file2.txt"])
To process records written in the TFRecord format, use TFRecordDataset:
dataset = tf.data.TFRecordDataset(["file1.tfrecords", "file2.tfrecords"])
To create a dataset of all files matching a pattern, use tf.data.Dataset.list_files:
dataset = tf.data.Dataset.list_files("/path/*.txt")
See tf.data.FixedLengthRecordDataset and tf.data.Dataset.from_generator for more ways to create datasets.
從上面範例可以看到,至少有六種的 API 可以讀取資料,這樣要應用各種形式資料就更彈性了。
那我自己有使用過的方式有兩種,分別是:
tf.data.Dataset.from_tensor_slices:這可以將 Python 的 List、Dict、NumPy 陣列,包裝成 Dataset 物件。tf.data.Dataset.from_generator:這是將 Python 的產生器包裝成 Dataset 物件。
參數輸入的方式,就直接點連結,去看官方的說明了。接下來就針對這兩個 API 來練習,程式一樣是找官方的範例。
tf.data.Dataset.from_tensor_slices
這個 API 很有意思的是,Python 容器型態 list、dict、tuple、NumPy Array 都支援。例如:
dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4]) # 1d list
dataset = tf.data.Dataset.from_tensor_slices([[1,2,3,4],[5,6,7,8]]) # 2d list
dataset = tf.data.Dataset.from_tensor_slices({"a": [1,1,1], "b";: [2,2,2]}) # dict
dataset = tf.data.Dataset.from_tensor_slices((images, labels)) # images & labels is np.array
上面的範例都可以轉成 Dataset 阿,這等於支援所有 Python 的資料結構了,非常強大阿!!!
再來讓我們來看個更貼近實際應用的例子。
轉換 NumPy arrays (引用官方範例)
train, test = tf.keras.datasets.fashion_mnist.load_data()
images, labels = train
images = images/255
dataset = tf.data.Dataset.from_tensor_slices((images, labels))
print(dataset)
會印出下面的結果,
<TensorSliceDataset shapes: ((28, 28), ()), types: (tf.float64, tf.uint8)>
通常都是會使用 **NumPy**、**scikit-image** 或是 **opencv** 來處理機器視覺的資料,通常資料型態會是 **`Numpy Array`** 型態,所以用上面那種轉換方式其實就會很常用到。
tf.data.Dataset.from\_generator
-------------------------------
再來另一個範例就是使用 **Keras** 影像前處理工具,由於這個是利用 Python 的"產生器"的方式來讀取資料並進行影像處理,因此用 **`from_generator` 這個 API 就可以很輕鬆的轉換成 Dataset 了。**
關於 **Keras** 的影像前處理工具可以去看: [**`preprocessing.image.ImageDataGenerator`**](https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image/ImageDataGenerator)。
官方範例是用花來做,程式如下:
下載花的資料
flowers = tf.keras.utils.get_file( “flower_photos”, “https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz", untar=True)
建立 ImageDataGenerator
img_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255, rotation_range=20)
讀取資料
images, labels = next(img_gen.flow_from_directory(flowers))
顯示資料 type 和 shape
print(images.dtype, images.shape) print(labels.dtype, labels.shape)
建立 Dataset 實例
ds = tf.data.Dataset.from_generator( lambda: img_gen.flow_from_directory(flowers), output_types=(tf.float32, tf.float32), output_shapes=([32,256,256,3], [32,5]) )
秀出資料的一些資訊
ds.element_spec
for images, label in ds.take(1): print(“images.shape: “, images.shape) print(“labels.shape: “, labels.shape)
跑完大概會印出下面的資訊。(因為沒實際跑,所以不確定,反正先記下官方的結果)
Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz 228818944/228813984 [==============================] - 6s 0us/step
Found 3670 images belonging to 5 classes.
float32 (32, 256, 256, 3) float32 (32, 5)
(TensorSpec(shape=(32, 256, 256, 3), dtype=tf.float32, name=None), TensorSpec(shape=(32, 5), dtype=tf.float32, name=None))
Found 3670 images belonging to 5 classes. images.shape: (32, 256, 256, 3) labels.shape: (32, 5)
Dataset 的方法使用範例
---------------
剛剛其實都只講了怎麼轉換而已,再來就是要把看看 **`Dataset`** 真正好用的地方。當然就一樣使用官方範例來說明。
### 先引入相關模組
import numpy as np import tensorflow as tf
### 讀取資料
DATA_URL = “https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz";
path = tf.keras.utils.get_file(“mnist.npz”, DATA_URL) with np.load(path) as data: train_examples = data[“x_train”] train_labels = data[“y_train”] test_examples = data[“x_test”] test_labels = data[“y_tes”]
### 設定 Dataset 的批次大小以及打亂順序。
之前都是要去 **`fit()`** 那邊設定,現在建完 **`Dataset`** 就直接設定這些東西了,其實資料集相關處理就在資料集設定,這樣其實蠻直覺的,代表建立出來的 **`Dataset`** 物件,可以用一些方法去設定如何讀取資料,打亂資料,而這些屬性等同於跟著這個物件了。
train_dataset = tf.data.Dataset.from_tensor_slices((train_examples, train_labels)) test_dataset = tf.data.Dataset.from_tensor_slices((test_examples, test_labels))
BATCH_SIZE = 64 SHUFFLE_BUFFER_SIZE = 100
train_dataset = train_dataset.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE) test_dataset = test_dataset.batch(BATCH_SIZE)
### 建立神經網路模型
model = tf.keras.Sequential([ tf.keras.layers.Flatten(input_shape=(28, 28)), tf.keras.layers.Dense(128, activation=“relu”), tf.keras.layers.Dense(10) ])
model.compile(optimizer=tf.keras.optimizers.RMSprop(), loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=[“sparse_categorical_accuracy”])
### 訓練模型、評估模型
model.fit(train_dataset, epochs=10)
model.evaluate(test_dataset)
```
經過 **`Dataset`** 的建立之後,訓練以及評估就變得精簡許多了,也代表我們可以**更容易替換**不同 **`Dataset`** 來訓練模型了。
#### 更詳細的使用過程,可以去看這個連結 ([Loading NumPy arrays](https://www.tensorflow.org/tutorials/load_data/numpy))。
結語
--
其實講了這麼多,還是要有個目標再來測試程式碼,才會比較快。這邊的官方範例說明只是冰山一角,還有很多東西其實沒講到,但我會用的也就那兩個 **`Dataset`** API,一個是**`tf.data.Dataset.from_tensor_slices`**,另一個是**`tf.data.Dataset.from_generator`**,所以也不需要全部都會用。
其實最後一個範例,除了 **`shuffle()`** 跟 **`batch()`**,還有一個 **`Dataset`** 的方法沒講到,就是 **`map()`** 函數,基本上前處理都靠它了,也可以藉由它減少記憶體使用,但這個函數就留到下一篇再來講吧!
預計這會做兩篇文章,下一篇文章就是我自己真正實作的程式碼了,今天這篇文章只是先來快速認知一下這東西的強大 (但我文筆不好,可能表現不出來),不知道有沒有感受到,雖然這裡應該只有我自己會看,但還是說一下,如果有更好的說明方式或是有錯誤的地方,歡迎留言告知。