将棋駒画像分類をMobileNetV2で最初から学習させる

前回の将棋駒画像分類の話の続き。

memo.sugyan.com

TensorFlow Hubの学習済みモデルを利用して 最終層にあたる部分だけ(?)を再学習させることで簡単に特定ドメインの画像分類のモデルを作成した。 …が、結果としてあまり精度が良くなくて、特に未学習の画像に対してかなりの高確率で誤分類してまっていた。

やはりもっと色々なバリエーションの駒画像を用意して学習させる必要があるか… と思ったが、その前にもう一つ試しておきたかったのでやってみた。

「学習済みモデルは本当にこのドメインの分類に適した特徴を学習できているのか?」

TensorFlow Hub で公開されている学習済みモデルは ILSVRC-2012-CLS 用のデータセットを用いて1000クラス分類のために学習してある。 けど、これが必ずしも自分がやろうとしている将棋駒画像の分類に適した特徴を抽出できるようになっているとは限らない。 もしかしたら、自分の用意したデータに適するよう全層を最初から学習させたら 学習済みモデルをretrainしたものより良くなることもあるのでは…?

というもの。

MobileNet V2

TensorFlow Models のrepositoryを見てみると、MobileNetV2モデルについてのコードや説明が載っている。

V1やV2の違いはよく分かっていないけど、とりあえず読んでみると tensorflow.contrib.slim を使ってモデルの構造が記述されていて、簡単に利用できるようになっているようだ。

from nets.mobilenet import mobilenet_v2

with tf.contrib.slim.arg_scope(mobilenet_v2.training_scope()):
    logits, endpoints = mobilenet_v2.mobilenet(input_tensor)

という形で training_scope() の中で呼ぶことで学習モードとしてモデルを定義できるらしい。 なのでこれに合わせて入力と学習手続きを定義していけば良さそう。

Inputs

とりあえず前回の記事で使った retrain.py に倣って学習用画像データを保存したディレクトリから 一定の割合で "training" と "validation" で別々のデータに分かれるようにしてリストを取得。

def create_image_lists(image_dir, validation_percentage):
    result = collections.OrderedDict()
    sub_dirs = [d for d in tf.gfile.ListDirectory(image_dir) if tf.gfile.IsDirectory(os.path.join(image_dir, d))]
    for sub_dir in sub_dirs:
        file_list = []
        dir_name = os.path.basename(sub_dir)
        file_glob = os.path.join(image_dir, dir_name, '*.jpg')
        file_list.extend(tf.gfile.Glob(file_glob))
        training_images = []
        validation_images = []
        for file_name in file_list:
            base_name = os.path.basename(file_name)
            # https://github.com/tensorflow/hub/blob/master/examples/image_retraining/retrain.py
            hashed = hashlib.sha1(tf.compat.as_bytes(file_name)).hexdigest()
            percentage_hash = int(hashed, 16) % 100
            if percentage_hash < validation_percentage:
                validation_images.append(base_name)
            else:
                training_images.append(base_name)
        result[dir_name] = {
            'training': training_images,
            'validation': validation_images,
        }
    return result

Datasets

取得した画像のリストを使って、入力データを作成する。 今回はせっかくなので tf.data APIを利用してみることにした。

基本的には、 tf.data.Datasetfrom_...Dataset instanceを作り、そこからiteratorを取得、その tf.data.Iterator instanceから get_next() を呼ぶことで入力Tensorを作ることができる、というものらしい。 画像分類のためのデータセットを作る場合は、対応するimages, labelsのセットを用意して使えば良い。 画像のファイルパスから内容を展開するなどの中間処理をする必要がある場合は Dataset.map で処理を書く。

def parser(file_path, label_index):
    image = tf.image.decode_jpeg(tf.read_file(file_path), channels=3)
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize_images(image, [96, 96])
    return image, tf.to_int64(label_index)

image_files, labels = [...], [...]
dataset = tf.data.Dataset.from_tensor_slices((image_files, labels))
dataset = dataset.map(parser)
dataset = dataset.repeat()
dataset = dataset.batch(FLAGS.batch_size)

iterator = dataset.make_initializable_iterator()
inputs, labels = iterator.get_next()

このようにして入力のbatchを作成できる。

今回は traing用のデータセットと validation用のデータセットを明示的に分けようと思って、 retrain.py と同様に一定の割合で振り分け、それぞれのデータを元に別々の Dataset, Iterator を返すようにした。

中身が違うだけで同shapeの入力を扱う場合には "reinitializable iterator" というのがあって、単一のiteratorから各Dataset用にinitializerを作ってそれを実行することで get_next() で得られる値を切り換える、という仕組みもあるようなのだけど、trainingとvalidationの切り替えを頻繁に行う場合はそのたびにinitializeすることになり、そうなるとshuffleする際に大きめの buffer_size を確保する必要がありそうで、そうすると学習開始時にかなり大きくメモリ確保してバッファリング処理をするようになって ちょっと微妙だなーと思って 結局ここでは "reinitializable iterator" は使わず 別々の Dataset として処理するようにした。 shuffleの重要度、validationの頻度などによっては普通に便利に使用できるのかもしれない。ちょっとよく分からない。

def shogi_inputs(image_lists):
    class_count = len(image_lists.keys())
    t_count, v_count = 0, 0
    for l in image_lists.values():
        t_count += len(l['training'])
        v_count += len(l['validation'])

    def generate_dataset(category):
        images = []
        labels = []
        label_names = []
        for label_index in range(class_count):
            label_name = list(image_lists.keys())[label_index]
            label_names.append(label_name)
            category_list = image_lists[label_name][category]
            for basename in category_list:
                images.append(os.path.join(FLAGS.image_dir, label_name, basename))
                labels.append(label_index)
        zipped = list(zip(images, labels))
        random.shuffle(zipped)
        return tf.data.Dataset.from_tensor_slices((
            [e[0] for e in zipped],
            [e[1] for e in zipped]))

    def parser(file_path, label_index):
        image = tf.image.decode_jpeg(tf.read_file(file_path), channels=3)
        image = tf.image.convert_image_dtype(image, tf.float32)
        image = tf.image.resize_images(image, [96, 96])
        return image, tf.to_int64(label_index)

    t_dataset = generate_dataset('training')
    t_dataset = t_dataset.map(parser)
    t_dataset = t_dataset.repeat()
    t_dataset = t_dataset.shuffle(FLAGS.batch_size * 10)
    t_dataset = t_dataset.batch(FLAGS.batch_size)

    v_dataset = generate_dataset('validation')
    v_dataset = v_dataset.map(parser)
    v_dataset = v_dataset.repeat()
    v_dataset = v_dataset.batch(FLAGS.batch_size * 5)

    return [
        t_dataset.make_initializable_iterator(),
        v_dataset.make_initializable_iterator(),
        t_count,
    ]

Training and Validation

それぞれの Dataset を使って入力画像とラベルのTensorを得られるので、実際にモデルに入力して得た結果を使って training operation と validation accuracy を作る。 trainingのあたりは models/research/slim/nets/mobilenet_v1_train.py という MobileNetV1用のスクリプトがあったのでそれを真似している。

t_iter, v_iter, training_count = shogi_inputs(image_lists)
t_inputs, t_labels = t_iter.get_next()
v_inputs, v_labels = v_iter.get_next()
with slim.arg_scope(mobilenet_v2.training_scope()):
    t_logits, _ = mobilenet_v2.mobilenet(t_inputs, num_classes=class_count)
    v_logits, _ = mobilenet_v2.mobilenet(v_inputs, num_classes=class_count, reuse=True)

# training
tf.losses.sparse_softmax_cross_entropy(t_labels, t_logits)
total_loss = tf.losses.get_total_loss(name='total_loss')
num_epochs_per_decay = 2.5
decay_steps = int(training_count / FLAGS.batch_size * um_epochs_per_decay)
learning_rate = tf.train.exponential_decay(
    0.045,
    tf.train.get_or_create_global_step(),
    decay_steps,
    _LEARNING_RATE_DECAY_FACTOR,
    staircase=True)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_tensor = slim.learning.create_train_op(
        total_loss,
        optimizer=tf.train.GradientDescentOptimizer(learning_rate))

# validation accuracy
indices = tf.argmax(v_logits, axis=1)
correct = tf.equal(indices, v_labels)
accuracy = tf.reduce_mean(tf.to_float(correct))

これが出来たら、あとは学習を回すだけ。 ここも mobilenet_v1_train.py では slim.learning.train というのを使っていたので、それを真似してみた。

slim.learning.train はどうやらtrain tensorを渡すと定期的に記録しながら指定したstep数の学習を回してくれるようなのだけど、例えば50 stepごとにvalidationを回して記録したい、とか思うと train_step_fn で各stepで何をするかを定義してやる必要があるようだ。 正直ここまでやるんだったら slim.learning.train は使わずに普通にfor loop回すだけでも良いような気もする…。

# train step function
def train_step(sess, train_op, global_step, train_step_kwargs):
    start_time = time.time()
    total_loss, np_global_step = sess.run([train_op, global_step])
    time_elapsed = time.time() - start_time
    # validation
    if np_global_step % 50 == 0:
        logging.info('validation accuracy: %.4f', sess.run(accuracy))

    if 'should_log' in train_step_kwargs:
        if sess.run(train_step_kwargs['should_log']):
            logging.info('global step %d: loss = %.4f (%.3f sec/step)',
                         np_global_step, total_loss, time_elapsed)
    if 'should_stop' in train_step_kwargs:
        should_stop = sess.run(train_step_kwargs['should_stop'])
    else:
        should_stop = False
    return total_loss, should_stop

# start training
g = tf.Graph()
with g.as_default():
    init_op = tf.group(t_init, v_init, tf.global_variables_initializer())
    slim.learning.train(
        train_tensor,
        FLAGS.checkpoint_dir,
        graph=g,
        number_of_steps=FLAGS.number_of_steps,
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs,
        local_init_op=init_op,
        train_step_fn=train_step,
        global_step=tf.train.get_global_step())

Results

実際このスクリプトを回して学習開始してみると、最初は当然3%くらいの正答率で そこから徐々にlossが減少し正答率が上昇していくのが観測できる。 1,000 step前後で正答率98%〜 と、十分に高い精度まで上がるようだった。

f:id:sugyan:20180707000650p:plain

f:id:sugyan:20180707000700p:plain

さすがにCPUだと1 stepにも数秒かかるくらいのスピードでそれなりに時間がかかる。 EC2 P3 instanceを使って回してみたところ、1,000 step程度ならものの数分で終了するようだった。

Evaluation

この学習済みモデルを使って、実際に前回記事と同じ画像を与えてどう識別されるかを確認してみる。

1200 stepほど学習した後のcheckpointファイルからモデルを復元し、局面図の画像を分割してそれぞれ識別してみる。

1. 学習データに使った素材で作ったもの

f:id:sugyan:20180501012958p:plain

前回の結果:

-KI * +TO * -OU * +TO+TO+NY
 * -FU *  *  * +FU *  *  * 
+RY-TO *  * -FU-FU * -FU-KY
 * +KE+KE-GI-KI-TO *  * +RY
-UM-NK+KY+TO * -TO *  *  * 
 * +FU *  *  * -GI-GI+FU * 
 *  * +FU * +TO-GI *  * +KE
 *  *  *  * +TO * +FU * +KY
+UM *  *  *  *  * +KI * -KI

今回の結果:

-KI * +TO * -OU * +TO+TO+NY
 * -FU *  *  * +FU *  *  * 
+RY-TO *  * -FU-FU * -FU-KY
 * +KE+KE-GI-KI-TO *  * +RY
-UM-NK+KY+TO * -TO *  *  * 
 * +FU *  *  * -GI-GI+FU * 
 *  * +FU * +TO-GI *  * +KE
 *  *  *  * +TO * +FU * +KY
+UM *  *  *  *  * +KI * -KI

さすがにこれは全部正解するようだ。

2. Shogipicで生成された局面図

f:id:sugyan:20180501013010p:plain

前回の結果:

-KI * +TO * -OU * +TO+TO+NY
 * -UM *  *  * +FU *  *  * 
+RY-TO+UM+UM-UM-FU ? -UM-KY
 * +KE+KE-KA-KI-TO ?  * +RY
-UM-KA+KY+TO * -TO *  *  * 
 * +FU ?  ?  * -KE-KE+FU * 
 *  * +FU ? +TO-KE ?  * +KE
 *  *  *  * +TO * +FU * +KY
+UM *  *  *  *  * +KI * -KI

今回の結果:

-KI * +TO * -OU * +NK+NK-NG
 * -FU *  *  * +FU *  *  * 
+RY-TO *  * -FU-FU * -FU-KY
 * +KE+KE-GI-KI-TO *  * +RY
-UM-NK+KY+TO * -TO *  *  * 
 * +FU *  *  * -GI-GI+FU * 
 *  * +FU * +TO-GI *  * +KE
 *  *  *  * +TO * +FU * +KY
+UM *  *  *  *  * +KI * -KI

何故か1段目の「と金」が「成桂」になっていたりといった誤答はあるが、前回のように駒の無いはずのところで駒があると判別してしまうような間違いは無くなっているようだ。

3. 激指 14の局面スクリーンショット

f:id:sugyan:20180501013026p:plain

前回の結果:

-KI * +TO * -OU * +TO+TO+NY
 * -FU *  *  * +FU *  *  * 
+RY-TO *  * -FU-FU * -FU-KY
 * +HI+HI-GI-KI-TO *  * +RY
-UM-NG+KY+TO * -TO *  *  * 
 * +FU *  *  * -GI-GI+FU * 
 *  * +OU * +TO-GI *  * +HI
 *  *  *  * +TO * +FU * +KY
+RY *  *  *  *  * +KI * -KI

今回の結果:

-KI *  *  * -OU *  *  * +NK
 * -FU *  *  * +FU *  *  * 
+RY-TO *  * -FU-FU * -FU-KY
 * +KE+KE-GI-KI-TO *  * +RY
-UM-NK+KY *  * -TO *  *  * 
 * +FU *  *  * -GI-GI+KY * 
 *  * +FU * +TO-GI *  * +KE
 *  *  *  *  *  * +KY * +KY
+RY *  *  *  *  * +KI * -KI

攻方の「と金」が空白として識別されて消えてしまっていたり、「歩兵」が「香車」と認識されていたり、間違いが多い…

4. Shogi.ioの局面スクリーンショット

f:id:sugyan:20180501013041p:plain

前回の結果:

-NG * +TO * -FU * +TO+TO+UM
 * -FU *  *  * +NK *  *  * 
-UM-TO ?  * -FU-FU * -FU-NG
 * +KE+OU-GI-NG-TO *  * +KE
-NG-UM+NG+TO * -TO *  *  * 
 * +NK ?  *  * -GI-GI-UM * 
 *  * +NK * +TO-GI *  * +KE
 *  *  *  * +TO * -UM * +NG
-UM *  *  *  *  * +NG * -NG

今回の結果:

-NG *  *  * +OU *  *  * -TO
 * -NK *  *  * +NY *  *  * 
-UM-TO *  * -NK-NK * -NK-NG
 * +KE+KE-NG-NG-TO *  * +GI
-UM-NG-UM *  * -TO *  *  * 
 * +NY *  *  * +RY-HI+NY * 
 *  * +NG *  * +RY *  * +KE
 *  *  *  *  *  * +NY * +FU
-KE *  *  *  *  * +FU * -NG

前回のも全然合ってなくてひどかったが、今回のはもっとひどい…

まとめ

TensorFlowでMobileNetV2を最初から学習させることができた、けど別にそれが出来たからといって性能が改善するわけでもない。

ラクして生成したものだけを使って…ではなく、ちゃんと様々な学習用データセットを用意して 学習させていくしかなさそう。

Repository