学習済みMobileNetV2モデルによる推論をTensorFlow.jsとWebWorkerを使ってブラウザ上で実行

将棋駒画像分類の話の続き。

memo.sugyan.com

学習させたモデルでの分類結果を実際に試すときに Web上でもインタラクティブに出来ると便利そう、と思ってやってみた。

学習済みモデルの変換

まずは学習済みのモデルをTensorFlow.js用に変換する必要がある。ここまではPythonの領域。

普通に tf.train.Saver を使っていると、checkpoint形式でパラメータが保存される。

./logdir
├── checkpoint
├── events.out.tfevents.1537850055.****.local
├── graph.pbtxt
├── model.ckpt-1500.data-00000-of-00001
├── model.ckpt-1500.index
└── model.ckpt-1500.meta

このままではダメなので Frozen Model の形式に変換したいところ。 TensorFlow の freeze_graph を使って

$ freeze_graph --input_checkpoint logdir/model.ckpt-1500 --output_node_names 'MobilenetV2/Logits/output' --input_graph logdir/graph.pbtxt --output_graph output_graph.pb

のようにすれば Frozen Model の形式にはなるのだけど、前回の記事のように学習時に作ったモデルそのままだとダメらしい。

import tensorflow.contrib.slim as slim
from nets.mobilenet import mobilenet_v2


def train():

    ...

    inputs = ...
    class_count = ...
    with slim.arg_scope(mobilenet_v2.training_scope(is_training=True)):
        logits, _ = mobilenet_v2.mobilenet(inputs, num_classes=class_count)

のように training_scope() 上で作られたモデルは dropout や batch normalization などが入った training mode として動作するものになるので、このまま freeze_graph しても挙動は training mode のままになってしまうっぽい。

なので、一度 non-training mode として復元してから freeze するようにする。 freeze_graph コマンドも中では tensorflow.python.framework.graph_util を使って変数を定数に変換しているだけなので、これを利用する。

import tensorflow as tf
from tensorflow.python.framework import graph_util
from nets.mobilenet import mobilenet_v2

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('checkpoint_path', 'logdir/model.ckpt',
                           '''Path to checkpoint file''')
tf.app.flags.DEFINE_string("output_graph", 'output_graph.pb',
                           """Path to write the frozen 'GraphDef'""")


def main(argv=None):
    class_count = ...
    placeholder = tf.placeholder(tf.float32, shape=(None, 96, 96, 3))
    logits, _ = mobilenet_v2.mobilenet(placeholder, class_count)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, FLAGS.checkpoint_path)
        output = graph_util.convert_variables_to_constants(
            sess, tf.get_default_graph().as_graph_def(), ['MobilenetV2/Logits/output'])
    with open(FLAGS.output_graph, 'wb') as f:
        f.write(output.SerializeToString())


if __name__ == '__main__':
    tf.app.run(main)

training_scope() をつけずに呼ぶことで(もしくは明示的に training_scope(is_training=False) のscopeにしてもよい) 、 non-training mode でのモデルが出来る。 そこから tr.train.Saver で restore した上で graph_util.convert_variables_to_constants() を呼んで、serializeすれば目的の Frozen Model が出来上がる。

label情報も含める

上記の Frozen Model で一応「入力画像に対する推論結果」を得られるが、その出力は単に計算結果の logits としての数値列で、「最も数値の高かったindexは どのlabelに対応するか」という情報が無い。 別ファイルで保存してあればそれを読んで照らし合わせてやればいいけど、ここから変換してJavaScriptの世界で使おうとするところでそれは面倒。 label情報も Frozen Model に含めてやりたい。

labels.txt にlabel名が羅列してあるとしたら、それを "," 区切りで繋げた文字列とかを定数として定義して freeze 時に加えてやれば良い。

def main(argv=None):
    with tf.gfile.Open(FLAGS.labels) as f:
        labels = [line.strip() for line in f.readlines()]
    labels_str = tf.constant(list(','.join(labels).encode()), dtype=tf.int32, name='labels')

    placeholder = tf.placeholder(tf.float32, shape=(None, 96, 96, 3))
    logits, _ = mobilenet_v2.mobilenet(placeholder, len(labels))
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, FLAGS.checkpoint_path)
        output = graph_util.convert_variables_to_constants(
            sess, tf.get_default_graph().as_graph_def(), ['MobilenetV2/Logits/output', 'labels'])
    with open(FLAGS.output_graph, 'wb') as f:
        f.write(output.SerializeToString())

tf.string の型にしてやれば良いだけかと思ったけど、次のconvert時に tf.string はsupportされていない、とエラーになってしまうので、仕方ないのでバイト列に変換したものを tf.int32 の1階Tensorとして無理矢理使う。 graph_util.convert_variables_to_constants() の第3引数 output_node_names は複数を指定できるので、 'MobilenetV2/Logits/output''labels' と2つ指定するよう変更。

Web format に変換

ここまで出来れば、あとは tensorflowjs_converter で変換するだけ。

$ pip install tensorflowjs
$ mkdir js
$ tensorflowjs_converter --input_format tf_frozen_model --output_node_names 'MobilenetV2/Logits/output,labels' output_graph.pb ./js
$ ls js
group1-shard1of3
group1-shard2of3
group1-shard3of3
tensorflowjs_model.pb
weights_manifest.json

といった感じで コマンド一発で変換されたファイルが生成される。 あとはこれらを静的に配信するようにしておけば、JavaScriptの世界で利用できるようになる。はず。

TensorFlow.jsで推論

$ npm install @tensorflow/tfjs

必要なのは上記のみ。 前節で生成したファイルを配信しているURLを指定して loadFrozenModel() を呼ぶことでモデルが読み込まれる。

import * as tf from "@tensorflow/tfjs";
import { loadFrozenModel } from "@tensorflow/tfjs-converter";

const MODEL_URL = "****/tensorflowjs_model.pb";
const WEIGHTS_URL = "****/weights_manifest.json";

loadFrozenModel(
    MODEL_URL, WEIGHTS_URL,
).then((model: tf.FrozenModel) => {
    const labelsTensor: tf.Tensor = tf.tidy(() => {
        return model.execute(tf.tensor([], [0, 96, 96, 3]), "labels") as tf.Tensor;
    });
    const labels: string[] = String.fromCharCode(...labelsTensor.dataSync()).split(",");
    labelsTensor.dispose();

    ...
});

計算して出力する tensor の名前を model.execute() の第2引数で指定できるので、まずは適当な空の入力(定数を取り出すだけなので入力は何でも良い) と "labels" を上記のように指定することで、label情報を取り出せる。

tf.Tensor まわりはちょっとクセがあるので注意する。

不要になった tensordispose() で明示的に破棄する、もしくは tf.tidy() の中で処理を書いてGPUが余分なメモリ消費を消費しないよう気をつける必要がある、ということらしい。 計算結果の中身は data()dataSync()TypedArray 形式で得ることができる。

ということで実際に画像データ(ここでは ImageData[])を受け取って推論した softmax 出力の上位3つの結果を得るのは以下のような感じ。 こっちは "MobilenetV2/Logits/output" を出力する tensor として指定。 tf.topk() というAPIもあるし それを使っても良かったかも…

loadFrozenModel(
    MODEL_URL, WEIGHTS_URL,
).then((model: tf.FrozenModel) => {

    ...

    const data: ImageData[] = ...
    const softmax: tf.Tensor = tf.tidy(() => {
        const tensors: tf.Tensor[] = data.map((d: ImageData) => tf.fromPixels(d));
        const inputs: tf.Tensor = tf.stack(tensors).toFloat().div(tf.scalar(255.0));
        const logits: tf.Tensor = model.execute(inputs, "MobilenetV2/Logits/output") as tf.Tensor;
        return tf.softmax(logits);
    });
    const resultData: Float32Array = softmax.dataSync() as Float32Array;
    softmax.dispose();

    // sort and get top-k
    const results = [];
    for (let i: number = 0; i < resultData.length / labels.length; i++) {
        const values: Iscored[] = [];
        resultData.slice(labels.length * i, labels.length * (i + 1)).forEach((score: number, index: number) => {
            values.push({ index, score });
        });
        values.sort((a: Iscored, b: Iscored) => b.score - a.score);
        results.push(values.slice(0, 3).map((value: Iscored) => {
            const label: string = labels[value.index];
            return { score: value.score, label };
        }));
    }

    ...
});

WebWorker経由で結果を得る

で、上記のようなのを実際に画像を入力して試してみると。 TensorFlow.jsはデフォルトで WebGL backend を利用して高速に計算してくれるのだけど、 どうしても初回の呼び出し時には 2000-3000ms ほどかかってしまう。 以下のような理由らしい。

  1. Why is the predict() method for inference so much slower on the first call than the subsequent calls?

The time of first call also includes the compilation time of WebGL shader programs for the model. After the first call the shader programs are cached, which makes the subsequent calls much faster. You can warm up the cache by calling the predict method with an all zero inputs, right after the completion of the model loading.

https://github.com/tensorflow/tfjs-converter/blob/master/README.md#faq

また、これは自分の使い方が悪いのかもしれないけど 1回の入力に使う画像数が 76枚くらいを超えると極端にそれが遅くなる… どっかでメモリ使い過ぎてるのかな。要調査。 [75, 96, 96, 3] くらいまでの入力なら大丈夫だけど それより多くなると 5000-6000ms とか一気に遅くなる。

ともかく、例えば実行ボタンを押してから結果を表示させようとすると 数秒かかってしまうことがあるわけで その間 UIが止まってしまう。それは出来れば避けたい。

というわけで 計算のロジックを Web Workers に移すことを考えた。

worker-loader

今回はサーバサイドとフロントエンドを分離して開発していたので、できれば単一のJSでまとめてしまいたい。 ってことで webpack の worker-loader を利用。 これを使って { inline: true } を指定することで別ファイルでWorker用のJSを用意しなくてもイイカンジにやってくれるようだ。

$ npm install worker-loader

webpack.config.js は以下のような感じ。

module.exports = {
    entry: './src/index.tsx',
    module: {
        rules: [
            {
                test: /\.?worker\.ts$/,
                use: {
                    loader: 'worker-loader',
                    options: { inline: true }
                }
            },
            {
                test: /\.tsx?$/,
                use: 'ts-loader',
                exclude: /node_modules/
            }
        ]
    },
    resolve: {
        extensions: ['.tsx', '.ts', '.js']
    },
    ...
};

ts-loader と併用して使う場合は rules の順番に注意しないといけないようだ。

worker.ts に先程のような処理を以下のように書ける。

import * as tf from "@tensorflow/tfjs";
import { loadFrozenModel } from "@tensorflow/tfjs-converter";

const ctx: Worker = self as any;

...

loadFrozenModel(
    MODEL_URL, WEIGHTS_URL,
).then((model: tf.FrozenModel) => {
    ...

    ctx.addEventListener("message", (message: MessageEvent) => {
        ...

        ctx.postMessage(...);
    });
});

export default null as any;

入力データを addEventListener() で待ち受け、結果を postMessage() で返せば良い。

使う側(UI側)は、この Worker に対して postMessage() で入力画像を投げて addEventListener() で結果を待ち受ければ良い。 とはいえ連続で投げてしまうことも出来てしまうので、投げたものに対する結果を正しく得る必要がある。 単一のWorkerインスタンスを持つSingletonクラスのようなものを使って Promise を返す関数を提供するようにしてみた。

import Worker from "./worker";

export interface IpredictResult {
    score: number;
    label: string;
}

interface Iresponse {
    key: string;
    results: IpredictResult[][];
}

export default class WorkerProxy {
    public static predict(inputs: ImageData[]): Promise<IpredictResult[][]> {
        const worker: Worker = WorkerProxy.getInstance().worker;
        const key: string = Math.random().toString(36).slice(-8);
        return new Promise<IpredictResult[][]>((resolve) => {
            const listener = (ev: MessageEvent) => {
                const data: Iresponse = ev.data;
                if (data.key === key) {
                    resolve(data.results);
                    worker.removeEventListener("message", listener);
                }
            };
            worker.addEventListener("message", listener);
            worker.postMessage({ key, inputs });
        });
    }
    private static instance: WorkerProxy;
    private static getInstance(): WorkerProxy {
        if (!this.instance) {
            WorkerProxy.instance = new WorkerProxy();
        }
        return WorkerProxy.instance;
    }
    private worker: Worker;
    private constructor() {
        this.worker = new Worker();
    }
}

これによって、UI側では複数の画像も非同期にリクエストして結果を得ることができる。

import WorkerProxy, { IpredictResult } from "./worker-proxy";

...

const inputs: ImageData[] = ...
inputs.forEach((data: ImageData, i: number) => {
    WorkerProxy.predict([data]).then((results: IpredictResult[][]) => {
        ...
    });
});     

これで、冒頭のTweetみたいな感じにUIを止めずに バックグラウンドでのWebWorkerによる計算を使って推論結果を順次表示していける。

問題点

…ということで実現できたけど、これ、すごい遅いぞ…? 1枚の画像に対する処理でも 300ms前後かかるし、複数枚をまとめて渡すと線形に処理時間が倍増していく。

なんと、 WebWorker 内でのTensorFlow.jsの計算は、GPUを使ってくれないらしい。 確認したら WebWorker 内では tf.getBackend() の結果が cpu になっていた。

今現在、まさに進行中の話のようで 近いうちに解決してくれるのかもしれない。

warm up と 分割?

と、こういう内容を調べながらこの記事を書いていて思ったけど、結局 webgl backend を使えれば、「ある程度の決まったサイズの入力なら 初回以外は高速に処理できる」ということが分かったので、

  • [10, 96, 96, 3] くらいの入力を受け取るよう固定
    • 1個の入力([1, 96, 96, 3]) に対しては残り [9, 96, 96, 3] のダミーデータを連結して埋める
    • 11個以上の入力は分割して処理する
  • modelのload直後にはやっぱり空の [10, 96, 96, 3] を入力してwarm upして それから使うようにする

というようにすれば、この程度の計算なら もしかしてわざわざWebWorkerを使って非同期に処理しようとしなくても十分に高速に(UIを止めることなく)結果を得ることが出来るのでは… と思ったのでした。

結論

ぜんぜんわかってねぇ

追記

試しに WebWorker を使わずにフォアグラウンドで webgl を使って固定長入力に分割して warm up して処理時間を計測してみた。 入力画像は 81。

input [1 * 96 * 96 * 3] * 81 loop => warm up: 2088ms, execute: 1520ms
input [2 * 96 * 96 * 3] * 41 loop => warm up: 2086ms, execute: 960ms
input [3 * 96 * 96 * 3] * 27 loop => warm up: 2093ms, execute: 761ms
input [4 * 96 * 96 * 3] * 21 loop => warm up: 2135ms, execute: 640ms
input [5 * 96 * 96 * 3] * 17 loop => warm up: 2149ms, execute: 557ms
input [6 * 96 * 96 * 3] * 14 loop => warm up: 2270ms, execute: 527ms
input [7 * 96 * 96 * 3] * 12 loop => warm up: 2299ms, execute: 496ms
input [8 * 96 * 96 * 3] * 11 loop => warm up: 2146ms, execute: 444ms
input [9 * 96 * 96 * 3] *  9 loop => warm up: 2271ms, execute: 394ms
input [10 * 96 * 96 * 3] * 9 loop => warm up: 2204ms, execute: 452ms
input [11 * 96 * 96 * 3] * 8 loop => warm up: 2256ms, execute: 381ms
input [12 * 96 * 96 * 3] * 7 loop => warm up: 2203ms, execute: 399ms
input [14 * 96 * 96 * 3] * 6 loop => warm up: 2373ms, execute: 359ms
input [17 * 96 * 96 * 3] * 5 loop => warm up: 2310ms, execute: 332ms
input [21 * 96 * 96 * 3] * 4 loop => warm up: 2383ms, execute: 328ms
input [27 * 96 * 96 * 3] * 3 loop => warm up: 2410ms, execute: 316ms
input [42 * 96 * 96 * 3] * 2 loop => warm up: 2410ms, execute: 395ms

なるほど適切な数に分割して処理することで 初回実行のWarmUpではどうしても約2秒ほど止まってしまうけど、それ以降はかなり高速に計算が終わる感じにはなる。