Subscribed unsubscribe Subscribe Subscribe

TensorFlow 0.12 のEmbedding Visualizationを動かす

TensorFlow Python

というTweetをしたところ、結構な反応があったので せっかくなので記事にしておく。

元々自分が知ったのはこの記事からだったのですが。 qiita.com

どうやらTensorFlow 0.12(現時点ではまだRC0)にはTensorBoardに"Embedding Visualization"というのが追加されたそうで。

https://www.tensorflow.org/versions/r0.12/how_tos/embedding_viz/index.html

これに従って可視化したい変数を保存しておけば、tensorboardでそれらを可視化できる、ということで 自分のこれまでに作ってきたアイドル顔画像の分類器とデータセットを使って冒頭のTweetの画面は出来たのだけど、誰もがそういうイイカンジの分類器やデータを持っているわけじゃなくてすぐには触れないかも、と思ったのでデモ用のリポジトリを作っておきました。

github.com

中身

TensorFlowは学習済みのモデルも提供していて、1000クラスの画像分類を行うInception-v3なんかを簡単に動かして試せるようになっている。

https://www.tensorflow.org/versions/master/tutorials/image_recognition/index.html#usage-with-python-api

ので、今回はそれを使ってみることにした。

tensorflow.models.image.imagenetclassify_image.pyというのが入っていて、これを使うと学習済みのモデルをダウンロードして使えるようになる。

from tensorflow.models.image.imagenet import classify_image

classify_image.maybe_download_and_extract()
classify_image.create_graph()

with tf.Session() as sess:
    sess.run(...)

便利!

で、そのclassify_image.pyに以下のようなコメントが書いてある。

    # Some useful tensors:
    # 'softmax:0': A tensor containing the normalized prediction across
    #   1000 labels.
    # 'pool_3:0': A tensor containing the next-to-last layer containing 2048
    #   float description of the image.
    # 'DecodeJpeg/contents:0': A tensor containing a string providing JPEG
    #   encoding of the image.
    # Runs the softmax tensor by feeding the image_data as input to the graph.

つまり'DecodeJpeg/contents:0'に入力のJPEGバイナリを入れて、'pool_3:0'を取り出せばその画像の特徴成分が抽出できるはず、と。 普通にImageNetの画像を入れて分類結果を見ても面白くないので、なんか違った画像を入力して特徴を可視化してみよう。

というところで自分のスマホを漁ってみたら、何故かうどんの写真がいっぱい出てきたので 今回のデモにはそれらを同梱しておきました。(ラーメンも1枚混ざってる)

f:id:sugyan:20161205015626p:plain

で、これらを入力してsess.run(pool_3)みたいに実行するとその画像に対する2048個の中間層出力が得られるので、これらを集めることで可視化対象の2D Tensorを得られる。 これを変数に入れてtf.train.Saverで保存すれば良いだけ、らしい。

embedding_var = tf.Variable(tf.pack([tf.squeeze(x) for x in outputs], axis=0), trainable=False, name='pool3')
sess.run(tf.variables_initializer([embedding_var]))
saver = tf.train.Saver([embedding_var])
saver.save(sess, <checkpoint path>)

これで点は表示されるけど、ラベルとか画像は別のファイルに書き出して関連付ける必要があるらしい。

config = projector.ProjectorConfig()
embedding = config.embeddings.add()
embedding.tensor_name = embedding_var.name
summary_writer = tf.train.SummaryWriter(<checkpoint dir>)

...

embedding.metadata_path = metadata_path
embedding.sprite.image_path = image_path
embedding.sprite.single_image_dim.extend([100, 100])
projector.visualize_embeddings(summary_writer, config)

ラベルについてはTSVとのことで、今回はファイル名だけを付けておいた。TensorBoard上でクリックすると表示されるはず。

f:id:sugyan:20161205021832p:plain

サムネ画像については、順番に並べたsprite画像を用意する必要がある、とのことで今回のように入力の元画像が全部分かっていればあらかじめ作っておくこともできるけど、 いちおうtf.concatとかで画像を表現するtensorを繋げて作ったりもできる。

デモ用画像は300x300で用意したけど、リサイズして100x100を繋げたものを生成するようにした。

thumbnail = tf.cast(tf.image.resize_images(tf.image.decode_jpeg(jpeg_data), [100, 100]), tf.uint8)

...

size = int(math.sqrt(len(images))) + 1
while len(images) < size * size:
    images.append(np.zeros((100, 100, 3), dtype=np.uint8))
rows = []
for i in range(size):
    rows.append(tf.concat(1, images[i*size:(i+1)*size]))
jpeg = tf.image.encode_jpeg(tf.concat(0, rows))
with open(image_path, 'wb') as f:
    f.write(sess.run(jpeg))

これで、以下のような画像が生成されるはず。

f:id:sugyan:20161205021851j:plain

で、めでたくデータが生成されて保存されればtensorboardで動かせる。 10数点しかないと寂しい感じになってよく分からないけど、雰囲気くらいは味わえるのでは…

あんまり特徴から分布を上手く作れている気もしないけど、T-SNEで次元削除した場合は釜揚げうどんが近いものになっているのは確認できた。

f:id:sugyan:20161205021838p:plain

注意点

Qiitaの参照記事にも書いてあるけど、0.12rc0ではまだちょっとバグもあるようで たぶんpython3だと上手く動かなくて、lib/python3.5/site-packages/tensorflow/tensorboard/plugins/projector/plugin.pyとかを書き換えてやる必要がありそう。

既に報告されていて修正も進んでいるようなので次のリリースでは直っていると思うけども…