The sky is the limit

Vue.js、PHP、Java、Cordova、Monacaを中心にハイブリッドアプリ開発、PWA開発など効率的なWEB、iOS、Androidアプリ開発の情報を共有します。

【機械学習】TensorFlowの基本分類チュートリアル「Fashion MNIST」を実行しながら解説 - Part1 -【TensorFlow】

【機械学習】TensorFlowの基本分類チュートリアル「Fashion MNIST」を実行しながら解説 - Part1 -【TensorFlow】

f:id:duo-taro100:20160218004611p:plain

最近はVue.jsは一旦休憩して、TensorFlowの勉強をしています。
そこで、TensorFlowの基本分類チュートリアルである「Fashion MNIST」を実行しつつ、解説していきたいと思います。
大まかに、こんなことをやっているんだと理解していただければ嬉しいです。

このチュートリアルについて

このチュートリアルは、機械学習の「Hello World!」とも言われる、「MNIST」というものを少しひねった基本分類の機械学習になります。
「Fashion MNIST」と呼ばれ、ファッション関連の画像を使って訓練し、投入した画像がどの分類に該当するかを推測することを目的としています。
やっていることは「MNIST」とあまり変わりないですが、若干難易度が上がっています。「MNIST」に触れたことのない方は、一度試してみてください。

ここでは既にTensorFlowを利用できる環境があることを前提として進めます。
環境構築がまだの方は以下の公式ドキュメントからお願いします。

www.tensorflow.org

全体の動き

まずは全体のソースコードと動きです。
チュートリアル上で、途中の結果を確かめるために用意しているコードなどは省略しています。

from tensorflow.examples.tutorials.mnist import input_data

import tensorflow as tf
from tensorflow import keras

# Helper libraries
import numpy as np
import matplotlib.pyplot as plt

# get data
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# normalization
train_images = train_images / 255.0
test_images = test_images / 255.0

class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

# create model
model = keras.Sequential([
	# trans to one demention array
    keras.layers.Flatten(input_shape=(28, 28)),
    # has 128 nodes
    keras.layers.Dense(128, activation=tf.nn.relu),
    # has 10 nodes and the ten nodes value's sum is 1
    keras.layers.Dense(10, activation=tf.nn.softmax)
])

# compile model
model.compile(optimizer=tf.train.AdamOptimizer(), 
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# train model
# 1. send training data to model
# 2. study relationship image and label
# 3. confirm to match test_image and test_label
model.fit(train_images, train_labels, epochs=5)

# evaluate sccuracy
test_loss, test_acc = model.evaluate(test_images, test_labels)

# prediction
predictions = model.predict(test_images)

# Grab an image from the test dataset
img = test_images[0]

# Add the image to a batch where it's the only member.
img = (np.expand_dims(img,0))

predictions = model.predict(img)
prediction = predictions[0]

result = np.argmax(prediction)
print(result)

これを動かした結果は以下の画像のようになります。

/anaconda3/lib/python3.6/site-packages/h5py/__init__.py:34: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
from ._conv import register_converters as _register_converters
Epoch 1/5
2018-07-20 23:53:28.865592: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
60000/60000 [==============================] - 6s 102us/step - loss: 0.5030 - acc: 0.8248
Epoch 2/5
60000/60000 [==============================] - 6s 95us/step - loss: 0.3825 - acc: 0.8635
Epoch 3/5
60000/60000 [==============================] - 5s 78us/step - loss: 0.3407 - acc: 0.8759
Epoch 4/5
60000/60000 [==============================] - 5s 83us/step - loss: 0.3164 - acc: 0.8844
Epoch 5/5
60000/60000 [==============================] - 4s 72us/step - loss: 0.2984 - acc: 0.8900
10000/10000 [==============================] - 0s 34us/step
9

細かい解説はこの後にやりますが、ここでは最後に「9」と表示されています。
これは投入データの分類が「9」であると推測されたことを示しています。
何がなんだか分からないかもしれませんが、とにかく進みましょう。

早速ソースコードを見ていきます。

チュートリアルに必要なインポート

まずは、TensorFlowをはじめとして、必要なライブラリなどをインポートします。

from tensorflow.examples.tutorials.mnist import input_data

import tensorflow as tf
from tensorflow import keras

# Helper libraries
import numpy as np
import matplotlib.pyplot as plt
チュートリアル用のデータをインポート

チュートリアルで利用できるデータを取得するために以下のインポート文を記述します。

from tensorflow.examples.tutorials.mnist import input_data
TensorflowとKerasをインポート

主役のTensorflowと、ディープニューラルネットワークの実装を簡潔に可能にするライブラリである「Keras」をインポートします。

import tensorflow as tf
from tensorflow import keras

「Keras」については以下のページを参照してください。
Keras Documentation

NumPyとmatplotlibをインポート

最後に、数値計算のためのモジュールであるNumPyとPythonでグラフを描画する時に使用されるmatplotlibをインポートします。

# Helper libraries
import numpy as np
import matplotlib.pyplot as plt

NumPyは数値計算が得意です。行列計算なども可能で、機械学習の計算には欠かせません。
http://www.numpy.org/

matplotlibは前述した通り、グラフ描画用です。
Pythonでグラフ描画などを行う際によく使われます。
matplotlibはチュートリアル上では使っていますが、当ページで紹介した最初のソースコードでは使っていませんので、ここで出番は終わりです。
気になる方は、チュートリアル通りに進めて、途中の画像表示を確認して見てください。

必要なライブラリなどのインポートは完了です。
続いて、チュートリアルで利用するデータを読み込みましょう。

データの取得

データの読み込み方法

このチュートリアルでは訓練用の画像が60,000枚、評価用の画像が10,000枚用意されています。
TensorFlowからFashion MNISTにアクセスします。

# access to fashion_mnist
fashion_mnist = keras.datasets.fashion_mnist

上記のFashion MNISTを読み込みます。
左から訓練用の画像、訓練用のラベル、評価用の画像、評価用のラベルという順序で読み込まれます。
これらはそれぞれNumPy配列として取得されることに注意してください。

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
データの形式と分類

画像データ(ここではtrain_imagesやtest_images)は0〜255ピクセルで表される28x28のNumPy配列です。
例えば、train_imagesの1つ目のデータを、matplotlibを使って画像として表示します。またそのデータのNumPy配列も表示してみましょう。 

plt.figure()
plt.imshow(train_images[0])
plt.colorbar()
plt.gca().grid(False)
plt.show()

print(train_images[0])

上記を実行すると、matplotlibを使って表示した画像です。

f:id:duo-taro100:20180803172251p:plain

続いて、NumPy配列が以下のように帰ってきます。

[[  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   1   0   0  13  73   0   0   1   4   0   0   0   0   1   1   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   3   0  36 136 127  62  54   0   0   0   1   3   4   0   0   3]
 [  0   0   0   0   0   0   0   0   0   0   0   0   6   0 102 204 176 134 144 123  23   0   0   0   0  12  10   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0 155 236 207 178 107 156 161 109  64  23  77 130  72  15]
 [  0   0   0   0   0   0   0   0   0   0   0   1   0  69 207 223 218 216 216 163 127 121 122 146 141  88 172  66]
 [  0   0   0   0   0   0   0   0   0   1   1   1   0 200 232 232 233 229 223 223 215 213 164 127 123 196 229   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0 183 225 216 223 228 235 227 224 222 224 221 223 245 173   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0 193 228 218 213 198 180 212 210 211 213 223 220 243 202   0]
 [  0   0   0   0   0   0   0   0   0   1   3   0  12 219 220 212 218 192 169 227 208 218 224 212 226 197 209  52]
 [  0   0   0   0   0   0   0   0   0   0   6   0  99 244 222 220 218 203 198 221 215 213 222 220 245 119 167  56]
 [  0   0   0   0   0   0   0   0   0   4   0   0  55 236 228 230 228 240 232 213 218 223 234 217 217 209  92   0]
 [  0   0   1   4   6   7   2   0   0   0   0   0 237 226 217 223 222 219 222 221 216 223 229 215 218 255  77   0]
 [  0   3   0   0   0   0   0   0   0  62 145 204 228 207 213 221 218 208 211 218 224 223 219 215 224 244 159   0]
 [  0   0   0   0  18  44  82 107 189 228 220 222 217 226 200 205 211 230 224 234 176 188 250 248 233 238 215   0]
 [  0  57 187 208 224 221 224 208 204 214 208 209 200 159 245 193 206 223 255 255 221 234 221 211 220 232 246   0]
 [  3 202 228 224 221 211 211 214 205 205 205 220 240  80 150 255 229 221 188 154 191 210 204 209 222 228 225   0]
 [ 98 233 198 210 222 229 229 234 249 220 194 215 217 241  65  73 106 117 168 219 221 215 217 223 223 224 229  29]
 [ 75 204 212 204 193 205 211 225 216 185 197 206 198 213 240 195 227 245 239 223 218 212 209 222 220 221 230  67]
 [ 48 203 183 194 213 197 185 190 194 192 202 214 219 221 220 236 225 216 199 206 186 181 177 172 181 205 206 115]
 [  0 122 219 193 179 171 183 196 204 210 213 207 211 210 200 196 194 191 195 191 198 192 176 156 167 177 210  92]
 [  0   0  74 189 212 191 175 172 175 181 185 188 189 188 193 198 204 209 210 210 211 188 188 194 192 216 170   0]
 [  2   0   0   0  66 200 222 237 239 242 246 243 244 221 220 193 191 179 182 182 181 176 166 168  99  58   0   0]
 [  0   0   0   0   0   0   0  40  61  44  72  41  35   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0]]

分かりにくいので、画像と配列を横に並べてみました。

f:id:duo-taro100:20180803172309p:plain

見ていただくとわかるように、画像の濃淡をNUmPy配列で表しています。
この形を見るとどうやらブーツのようなものを表しているようです。

次に、ラベルデータ(ここではtrain_labelsやtest_labels)ですが、これらは0〜9の整数配列です。
この0〜9の値は、それぞれ分類されるクラスを表しています。
このチュートリアルでは以下のようなクラス分けがされています。

0 => Tシャツ/トップ
1 => ズボン
2 => セーター
3 => ドレス
4 => コート
5 => サンダル
6 => シャツ
7 => スニーカー
8 => バッグ
9 => アンクルブーツ

そこで、先ほどの画像データに対応するラベル(train_labelsの1つ目のデータ)も表示します。

print(train_labels[0])

すると整数の「9」と表示されます。
一つ目の画像データは「9」、つまりアンクルブーツに分類されることを示しているわけです。

さて、データの意味と形式がわかったところで、機械学習実装の本題に入っていきます。
ただ、ここからも解説が長くなりますので、実装については次回解説したいと思います。

次回は以下のリンクから。
www.sky-limit-future.com