StyleGAN2による画像生成をCPU環境/TensorFlow.jsで動かす

memo.sugyan.com

の続き。

ようやくTensorFlow.jsを使ってブラウザ上で動かせるようになったので、そのためにやったことメモ。

(まだまだ画像の質とかパフォーマンスの問題とかは色々ある)

CPU環境で動かす

最終的にはTensorFlow.jsでブラウザ上で動かすことが目標だったので別にCPUで動かせる必要は無かったのだけど、どうもGPU環境でしか動かない特殊なOpなどもあってそれが変換後のモデルで実行時にエラーを引き起こす原因だったりするため、まずはCPU環境でも安定して動くようにするのが確実なようだ。

前回記事にも書いた通り、 StyleGAN2 の学習の過程で出力される .pkl ファイルは、CPU環境では読み込むことも出来ない。

import pickle
from dnnlib import tflib

tflib.init_tf()
with open('network.pkl', 'rb') as fp:
    _, _, Gs = pickle.load(fp, encoding='latin1')

Gs.print_layers()
...

RuntimeError: NVCC returned an error. See below for full command line and output log:

nvcc "/Users/sugyan/.ghq/github.com/NVlabs/stylegan2/dnnlib/tflib/ops/fused_bias_act.cu" ...

/bin/sh: nvcc: command not found

そりゃCUDAが入っていない環境ではこうなる。 ので、無理矢理にコードを書き換えてCUDAを使わない reference implementation を利用するようにする。 特に upfirdn_2d の方は幾つかの場所から呼ばれるので、デフォルト引数をまとめて書き換えるのも良いけど impl_dict を書き換えて reference implementation を強制してしまうのがラク

diff --git a/dnnlib/tflib/ops/fused_bias_act.py b/dnnlib/tflib/ops/fused_bias_act.py
index 52f6bfd..c294277 100755
--- a/dnnlib/tflib/ops/fused_bias_act.py
+++ b/dnnlib/tflib/ops/fused_bias_act.py
@@ -63,7 +63,7 @@ def fused_bias_act(x, b=None, axis=1, act='linear', alpha=None, gain=None, impl=

     impl_dict = {
         'ref':  _fused_bias_act_ref,
-        'cuda': _fused_bias_act_cuda,
+        'cuda': _fused_bias_act_ref,
     }
     return impl_dict[impl](x=x, b=b, axis=axis, act=act, alpha=alpha, gain=gain)

diff --git a/dnnlib/tflib/ops/upfirdn_2d.py b/dnnlib/tflib/ops/upfirdn_2d.py
index fd23777..1df2935 100755
--- a/dnnlib/tflib/ops/upfirdn_2d.py
+++ b/dnnlib/tflib/ops/upfirdn_2d.py
@@ -57,7 +57,7 @@ def upfirdn_2d(x, k, upx=1, upy=1, downx=1, downy=1, padx0=0, padx1=0, pady0=0,

     impl_dict = {
         'ref':  _upfirdn_2d_ref,
-        'cuda': _upfirdn_2d_cuda,
+        'cuda': _upfirdn_2d_ref,
     }
     return impl_dict[impl](x=x, k=k, upx=upx, upy=upy, downx=downx, downy=downy, padx0=padx0, padx1=padx1, pady0=pady0, pady1=pady1)

これで 引数が impl='cuda' で指定されてこようが関係なく ref の方が使われるようになる。

こうすると .pkl を読み込むのは問題なくできるようになる。 ただここから実際に計算を実行しようとすると問題が起きるわけで。

Gs.run()GPU device を必要とするので output tensor を指定して tf.Session.run() してみる。

import pickle
import numpy as np
import tensorflow as tf
from dnnlib import tflib

tflib.init_tf()
with open('network.pkl', 'rb') as fp:
    _, _, Gs = pickle.load(fp, encoding='latin1')

graph = tf.get_default_graph()
inputs = graph.get_tensor_by_name(f'Gs/{Gs.input_names[0]}:0')
outputs = graph.get_tensor_by_name(f'Gs/{Gs.output_names[0]}:0')

rnd = np.random.RandomState(0)
z = rnd.randn(1, *Gs.input_shape[1:])
print(tf.get_default_session().run(outputs, feed_dict={inputs: z}))
2020-02-04 23:26:46.471886: E tensorflow/core/common_runtime/executor.cc:642] Executor failed to create kernel. Invalid argument: Conv2DCustomBackpropInputOp only supports NHWC.
         [[{{node Gs/G_synthesis/8x8/Conv0_up/conv2d_transpose}}]]

...

StyleGAN2のモデルはGPU環境で学習する前提で作られているので、その環境に最適化された処理が幾つかある。そのうちの一つが NCHW のdata formatを使っていること。 この形でmodelが作られていると、CPU環境では NHWC にしか対応していないので計算を実行することが出来ないようだ。

1. Graphを書き換える

ということで困った、という記事を書いたところ、以下のようなフィードバックをいただいた。

なるほどー、構築されたGraphを舐めていって inputsoperation を書き換えることで NCHWNHWC に変換する方法があるのか…!

ということで上記を参考にしながら自分で書いてみた。

tflib.init_tf()
with open('network.pkl', 'rb') as fp:
    _, _, Gs = pickle.load(fp, encoding='latin1')

graph = tf.get_default_graph()

target_ops = []
for op in graph.get_operations():
    if not op.name.startswith('Gs/'):
        continue
    if 'data_format' in op.node_def.attr and op.node_def.attr['data_format'].s == b'NCHW':
        target_ops.append(op)

まずは Gs/ 以下の、 'NCHW' という data_format attribute を持つ operation をすべて抽出する。これらが書き換える対象となる。

for target_op in target_ops:
    print(f'op: {target_op.name} ({target_op.type})')
op: Gs/G_synthesis/4x4/Conv/Conv2D (Conv2D)
op: Gs/G_synthesis/4x4/ToRGB/Conv2D (Conv2D)
op: Gs/G_synthesis/8x8/Conv0_up/conv2d_transpose (Conv2DBackpropInput)
op: Gs/G_synthesis/8x8/Conv0_up/Conv2D (Conv2D)
op: Gs/G_synthesis/8x8/Conv1/Conv2D (Conv2D)
op: Gs/G_synthesis/8x8/Upsample/Conv2D (Conv2D)
op: Gs/G_synthesis/8x8/ToRGB/Conv2D (Conv2D)
op: Gs/G_synthesis/16x16/Conv0_up/conv2d_transpose (Conv2DBackpropInput)
op: Gs/G_synthesis/16x16/Conv0_up/Conv2D (Conv2D)

...

それらの operation に対し、まずは inputs のshapeを変換していく。

for target_op in target_ops:
    # Input tensors
    if target_op.type == 'Conv2D':
        inputs = [
            tf.transpose(
                target_op.inputs[0],
                [0, 2, 3, 1],
                name=f'{target_op.name}_input_transpose'),
            target_op.inputs[1]
        ]
    elif target_op.type == 'Conv2DBackpropInput':
        inputs = [
            tf.gather(
                target_op.inputs[0],
                [0, 2, 3, 1],
                name=f'{target_op.name}_output_shape_transpose'),
            target_op.inputs[1],
            tf.transpose(
                target_op.inputs[2],
                [0, 2, 3, 1],
                name=f'{target_op.name}_value_transpose')
        ]

実際に見てみると分かるが、こうして 'NCHW' な op を抽出してみると すべて typeConv2DConv2DBackpropInput のどちらかしかない。 少なくとも StyleGAN2 では、この2つのtypeに対してそれぞれ対応するだけで良い、ということになる。

Conv2D の場合、 inputs0 番目に入力のTensorが入ってくる。これが例えば (?, 1, 11, 11) だったりして (N, C, H, W) に対応する。 ので、 tf.transpose[0, 2, 3, 1] を指定することでこの入力Tensor(N, H, W, C) に変換することが出来る。 1 番目の入力は filter の値のようで、これはdata formatに依存しないのでこのまま使えば良い。

Conv2DBackpropInput の場合はもう少し厄介。どうやら入力は output_shape, filter, value という順番で来るらしい。 1 番目は Conv2D の場合と同様そのまま使って 2 番目が入力Tensorなので やはり同様に [0, 2, 3, 1] で transpose してやる。 そして 0 番目が その出力結果のshapeをどうするか指定するという役割のようで、そのshapeを示す (4,)Tensorが入ってくる。 これは例えば [1, 512, 9, 9] といった形の値で やはり NCHW ならその形に指定するわけだけど、ここではこの op を NHWC に変えてやりたいので、この output_shape も書き換えてやらないと その後の出力の型が合わなくなってしまう。 tf.gather を使ってこの output_shape の中身の順番を入れ替える。

これが出来たら次は attributes。

    # Attributes
    attrs = {}
    for k, v in target_op.node_def.attr.items():
        if k == 'data_format':
            continue
        if target_op.type == 'Conv2DBackpropInput' and k == 'strides':
            strides = v.list.i
            attrs[k] = tf.AttrValue(list=tf.AttrValue.ListValue(i=[
                strides[0],
                strides[2],
                strides[3],
                strides[1],
            ]))
        else:
            attrs[k] = v

各 operation は入力値とは別に? attributes というものを持っているようで、これも書き換えてやる必要がある。

data_formatNCHW である、という情報はここに含まれているので、変換する際にはこれを捨ててしまうことで defaultの NHWC にすることが出来る。

もう一つ Conv2DBackpropInput の場合に変更する必要があるのが strides の値で、これも NCHW のときと NHWC のときで扱いが変わるものらしい。

この strides と前述の output_shape については upfirdn_2d.upsample_conv_2d() に分岐が書かれている。

upfirdn_2d.py 抜粋:

    # Determine data dimensions.
    if data_format == 'NCHW':
        stride = [1, 1, factor, factor]
        output_shape = [_shape(x, 0), outC, (_shape(x, 2) - 1) * factor + convH, (_shape(x, 3) - 1) * factor + convW]
        num_groups = _shape(x, 1) // inC
    else:
        stride = [1, factor, factor, 1]
        output_shape = [_shape(x, 0), (_shape(x, 1) - 1) * factor + convH, (_shape(x, 2) - 1) * factor + convW, outC]
        num_groups = _shape(x, 3) // inC

というわけで この strides attributes の値も [0, 2, 3, 1] の順に並び換えたものを用意する。

ここまで準備できたらいよいよ operation の置き換え。

    # New operations
    new_op = graph.create_op(op_type=target_op.type, inputs=inputs, name=f'{target_op.name}_nhwc', attrs=attrs)
    output = tf.transpose(new_op.outputs[0], [0, 3, 1, 2], name=f'{new_op.name}_output')

    # Update connections
    ops = [op for op in graph.get_operations() if target_op.outputs[0] in op.inputs]
    for op in ops:
        for i, input_tensor in enumerate(op.inputs):
            if input_tensor.name == target_op.outputs[0].name:
                op._update_input(i, output)

元の operation と同じ type, attrs を持つ新しい operation を作成する。入力は NHWC に transpose したもの。 ということは この operation の outputs も NHWC になっているので、その後の計算に支障が出ないよう この出力は NCHW に戻しておいてやる必要がある。ので outputs を今度は [0, 3, 1, 2] でtransposeする。

そして、元々の outputs を受け取っていた 次の operation たちの入力を この新しい operation の outputs に置き換えてやる。

PrevOp ---> NCHW ---------> Op(NCHW) -----------NCHW-----> NextOp
        |                                              |
         -> NCHW to NHWC -> Op(NHWC) -> NHWC to NCHW --

元々上段の流れだけだったものに対して、下段のルートを付け足した形になる。

最後に、元々あった NCHW の operation を graph から消しておく。

# Delete old nodes
graph_def = graph.as_graph_def()
for target_op in target_ops:
    graph_def.node.remove(target_op.node_def)

これで、一度 SavedModel に graph を保存してみよう。

inputs = graph.get_tensor_by_name(f'Gs/{Gs.input_names[0]}:0')
outputs = graph.get_tensor_by_name(f'Gs/{Gs.output_names[0]}:0')

tf.compat.v1.enable_resource_variables()
tf.compat.v1.saved_model.simple_save(
    tf.get_default_session(),
    './savedmodel',
    {'inputs': inputs},
    {'outputs': outputs},
)

これを load して実行してみると…

import tensorflow as tf

model = tf.compat.v2.saved_model.load('./savedmodel')
generate = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]

z = tf.random.normal([1, 512])
outputs = generate(inputs=z)['outputs']
with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.initializers.tables_initializer())
    print(sess.run(outputs))
2020-02-04 23:24:38.004217: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2020-02-04 23:24:38.014420: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7fbdbabd65a0 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2020-02-04 23:24:38.014438: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
[[[[0.65079993 0.7005807  0.71635807 ... 0.70904994 0.672469
    0.65049875]
   [0.7084702  0.73016286 0.7354313  ... 0.74900943 0.7347464
    0.7381906 ]
   [0.71451813 0.7372785  0.73747885 ... 0.75236607 0.7619566
    0.7417464 ]
   ...

何らかの数値が出力された!

これはやはり NCHW(1, 3, 256, 256) のような形で来ているので、RGBの画像として Pillow などで扱うにはやはり NHWC に transpose したりといった処理は必要になる。

import tensorflow as tf
from PIL import Image

model = tf.compat.v2.saved_model.load('./savedmodel')
generate = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]

z = tf.random.normal([1, 512])
outputs = generate(inputs=z)['outputs']
outputs = tf.transpose(outputs, [0, 2, 3, 1])
outputs = tf.saturate_cast((outputs + 1.0) * 127.5, tf.uint8)
with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.initializers.tables_initializer())
    img = Image.fromarray(sess.run(outputs)[0])
img.save('output.png')

f:id:sugyan:20200204233514p:plain

CPU環境でも生成が出来た! やったー!!

2. Modelを構築し直す

標準的な畳み込みを使ったネットワークであれば ここまでの変換で問題なく動くようになるかもしれない。が、StyleGAN2の場合はまだちょっとだけ問題が残っている。

単一の画像生成なら上述のように出来るけど、 batch_size1 より大きくすると、またエラーになってしまう。

z = tf.random.normal([2, 512])
...

tensorflow.python.framework.errors_impl.UnimplementedError: [_Derived_]{{function_node __inference_pruned_14657}} {{function_node __inference_pruned_14657}} The Conv2D op currently does not support grouped convolutions on the CPU. A grouped convolution was attempted to be run because the input depth of 1024 does not match the filter input depth of 512
         [[{{node Gs/G_synthesis/4x4/Conv/Conv2D_nhwc}}]]
         [[StatefulPartitionedCall_1]]

grouped convolution というものが使われていて、これがまた CPU環境ではまだサポートされていないものだった。

このへんかな?

TensorFlow 1.14 あたりから入ったもののようだ。 普通は filter のshapeは [filter_height, filter_width, in_channels, out_channels] と なっていて、inputchannelfilter2 番目の次元と等しくなければならないが、cuDNNによって filter.shape[2] の倍数であればまとめて計算できるようになる、という機能… なのかな。(よく分かっていない)

これがどうやら networks_stylegan2.modulated_conv2d_layer() の中で fused_modconv=True のときにそういった処理をするようになっているらしい。

networks_stylegan2.py 抜粋:

    # Reshape/scale input.
    if fused_modconv:
        x = tf.reshape(x, [1, -1, x.shape[2], x.shape[3]]) # Fused => reshape minibatch to convolution groups.
        w = tf.reshape(tf.transpose(ww, [1, 2, 3, 0, 4]), [ww.shape[1], ww.shape[2], ww.shape[3], -1])
    else:
        x *= tf.cast(s[:, :, np.newaxis, np.newaxis], x.dtype) # [BIhw] Not fused => scale input activations.

    # Convolution with optional up/downsampling.
    if up:
        x = upsample_conv_2d(x, tf.cast(w, x.dtype), data_format='NCHW', k=resample_kernel)
    elif down:
        x = conv_downsample_2d(x, tf.cast(w, x.dtype), data_format='NCHW', k=resample_kernel)
    else:
        x = tf.nn.conv2d(x, tf.cast(w, x.dtype), data_format='NCHW', strides=[1,1,1,1], padding='SAME')

    # Reshape/scale output.
    if fused_modconv:
        x = tf.reshape(x, [-1, fmaps, x.shape[2], x.shape[3]]) # Fused => reshape convolution groups back to minibatch.

up/downsampling の処理をかける前に xw をreshapeして、その結果を後でまたreshapeして元に戻す、といった形になっているようだ。

これは単一の operation の書き換えでは対応できない…。


ということで、結局 graph の書き換えだけでは無理そうなので 「重みだけ再利用してModelはCPUでも動かせる形に構築し直す」という方針を取ることにした。

まずは .pkl をloadした後、 checkpoint 形式で 変数の値だけを保存する。

import pickle
import tensorflow as tf
from dnnlib import tflib

tflib.init_tf()
with open('network.pkl', 'rb') as fp:
    _, _, Gs = pickle.load(fp, encoding='latin1')

saver = tf.compat.v1.train.Saver(Gs.vars)
saver.save(tf.get_default_session(), './ckpt/network')
$ ls -l ./ckpt
total 240176
-rw-r--r--  1 sugyan  staff         71 Feb  4 23:45 checkpoint
-rw-r--r--  1 sugyan  staff  120838348 Feb  4 23:45 network.data-00000-of-00001
-rw-r--r--  1 sugyan  staff       4898 Feb  4 23:45 network.index
-rw-r--r--  1 sugyan  staff    2114457 Feb  4 23:45 network.meta

で、 fused_modconv=False になるような設定で Generator を作成する。 どうせ graph を構築しなおすことになるのだし、前述したような NCHW -> NHWC の書き換えもコード上でやってしまおう。

Gs/ に関係するところだけなら 以下の3箇所の tf.nn.conv2dtf.nn.conv2d_transpose の入出力まわりを書き換えてやれば大丈夫だ。

diff --git a/dnnlib/tflib/ops/upfirdn_2d.py b/dnnlib/tflib/ops/upfirdn_2d.py
index fd23777..26cc573 100755
--- a/dnnlib/tflib/ops/upfirdn_2d.py
+++ b/dnnlib/tflib/ops/upfirdn_2d.py
@@ -93,7 +93,9 @@ def _upfirdn_2d_ref(x, k, upx, upy, downx, downy, padx0, padx1, pady0, pady1):
     x = tf.transpose(x, [0, 3, 1, 2])
     x = tf.reshape(x, [-1, 1, inH * upy + pady0 + pady1, inW * upx + padx0 + padx1])
     w = tf.constant(k[::-1, ::-1, np.newaxis, np.newaxis], dtype=x.dtype)
-    x = tf.nn.conv2d(x, w, strides=[1,1,1,1], padding='VALID', data_format='NCHW')
+    x = tf.transpose(x, [0, 2, 3, 1])
+    x = tf.nn.conv2d(x, w, strides=[1,1,1,1], padding='VALID', data_format='NHWC')
+    x = tf.transpose(x, [0, 3, 1, 2])
     x = tf.reshape(x, [-1, minorDim, inH * upy + pady0 + pady1 - kernelH + 1, inW * upx + padx0 + padx1 - kernelW + 1])
     x = tf.transpose(x, [0, 2, 3, 1])

@@ -288,7 +290,11 @@ def upsample_conv_2d(x, w, k=None, factor=2, gain=1, data_format='NCHW', impl='c
     w = tf.reshape(w, [convH, convW, -1, num_groups * inC])

     # Execute.
-    x = tf.nn.conv2d_transpose(x, w, output_shape=output_shape, strides=stride, padding='VALID', data_format=data_format)
+    x = tf.transpose(x, [0, 2, 3, 1])
+    stride = [1, factor, factor, 1]
+    output_shape = [_shape(x, 0), (_shape(x, 1) - 1) * factor + convH, (_shape(x, 2) - 1) * factor + convW, outC]
+    x = tf.nn.conv2d_transpose(x, w, output_shape=output_shape, strides=stride, padding='VALID', data_format='NHWC')
+    x = tf.transpose(x, [0, 3, 1, 2])
     return _simple_upfirdn_2d(x, k, pad0=(p+1)//2+factor-1, pad1=p//2+1, data_format=data_format, impl=impl)

 #----------------------------------------------------------------------------
diff --git a/training/networks_stylegan2.py b/training/networks_stylegan2.py
index 6c96fc1..8fe2979 100755
--- a/training/networks_stylegan2.py
+++ b/training/networks_stylegan2.py
@@ -117,7 +117,9 @@ def modulated_conv2d_layer(x, y, fmaps, kernel, up=False, down=False, demodulate
     elif down:
         x = conv_downsample_2d(x, tf.cast(w, x.dtype), data_format='NCHW', k=resample_kernel)
     else:
-        x = tf.nn.conv2d(x, tf.cast(w, x.dtype), data_format='NCHW', strides=[1,1,1,1], padding='SAME')
+        x = tf.transpose(x, [0, 2, 3, 1])
+        x = tf.nn.conv2d(x, tf.cast(w, x.dtype), data_format='NHWC', strides=[1,1,1,1], padding='SAME')
+        x = tf.transpose(x, [0, 3, 1, 2])

     # Reshape/scale output.
     if fused_modconv:

で、 Generator を作成し、保存しておいた変数を checkpoint ファイルからloadする。

import tensorflow as tf
from dnnlib import tflib
from dnnlib import EasyDict

tflib.init_tf()
G_args = EasyDict(func_name='training.networks_stylegan2.G_main')
G_args.fused_modconv = False
G = tflib.Network(
    'G',
    num_channels=3,
    resolution=256,
    **G_args)
Gs = G.clone('Gs')

saver = tf.compat.v1.train.Saver(Gs.vars)
saver.restore(tf.get_default_session(), 'ckpt/network')

graph = tf.get_default_graph()
inputs = graph.get_tensor_by_name(f'Gs/{Gs.input_names[0]}:0')
outputs = graph.get_tensor_by_name(f'Gs/{Gs.output_names[0]}:0')
tf.compat.v1.enable_resource_variables()
tf.compat.v1.saved_model.simple_save(
    tf.get_default_session(),
    './savedmodel',
    {'inputs': inputs},
    {'outputs': outputs}
)

これで今度は NCHW の operation も fused_modconv による grouped convolution も含まない SavedModel が出力された、はず。

再度 batch_size > 1 で生成をしてみよう。

model = tf.compat.v2.saved_model.load('./savedmodel')
generate = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]

z = tf.random.normal([3, 512])
outputs = generate(inputs=z)['outputs']
outputs = tf.transpose(outputs, [0, 2, 3, 1])
outputs = tf.saturate_cast((outputs + 1.0) * 127.5, tf.uint8)
with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.initializers.tables_initializer())
    img = Image.fromarray(np.concatenate(sess.run(outputs), axis=1))
img.save('output.png')

f:id:sugyan:20200205232314p:plain

ちゃんと3つ画像が生成された! やったー!!

TensorFlow.jsで動かす

さて、ここまで出来ているのならば特殊な operation なども使っていないはずだろうし TensorFlow.js の GraphModel に変換できるだろう、と思うわけです。

$ tensorflowjs_converter \
    --input_format=tf_saved_model \
    --output_format=tfjs_graph_model \
    --signature_name=serving_default \
    --saved_model_tags=serve \
    --skip_op_check \
    ./savedmodel ./tfjs

やはり RandomStandardNormal の operation はサポートされていないので これがある場合は --skip_op_check optionが必要になる。

ちなみに random入力を使っている場所は randomize_noise というパラメータで制御されていて、 Generator の作成時にこれを fused_modconv と同様に G_args.randomize_noise = False と指定しておくとこの operation は使われなくなって、 --skip_op_check でskipする必要が無くなる。

ransomize_noise がどれくらい出力の品質に影響あるかちょっとよく分かっていないけど、問題ないようなら False にしておけば計算の負荷も減るし良いかもしれない。

ともかく、変換した GraphModel を読み込んで実行してみると…

const url = '.../tfjs/model.json'
const randomNormal = (node) => {
    return tf.randomNormal(node.inputs[0].dataSync())
};
tf.registerOp('RandomStandardNormal', randomNormal);
tf.loadGraphModel(url).then((model) => {
    tf.tidy(() => {
        const z = tf.randomNormal([1, 512])
        const results = model.execute(z)
        console.log(results.shape)
    })
}).catch((err) => {
    console.error(err);
})
webgl_util.ts:110 Uncaught Error: Failed to compile fragment shader.

やはり webgl backend ではエラーが出てしまう。 試しに tf.setBackend('cpu') にしてみると、ものすごい時間はかかるが 一応実行は可能なようだ。しかし現実的ではない。

何故 webgl backend では上手くいかないのか… とひたすら地道に探っていたところ、一つ 特殊な箇所を発見した。

upfirdn_2d の reference implementation で、Upsampleするために Rank 6Tensorreshape してから pad を行っている場所がある。

upfirdn_2d.py 抜粋:

def _upfirdn_2d_ref(x, k, upx, upy, downx, downy, padx0, padx1, pady0, pady1):
    """Slow reference implementation of `upfirdn_2d()` using standard TensorFlow ops."""

    ...

    # Upsample (insert zeros).
    x = tf.reshape(x, [-1, inH, 1, inW, 1, minorDim])
    x = tf.pad(x, [[0, 0], [0, 0], [0, upy - 1], [0, 0], [0, upx - 1], [0, 0]])
    x = tf.reshape(x, [-1, inH * upy, inW * upx, minorDim])

    ...

この関数の処理の詳細は理解できていないけど、ともかくこの箇所では 0 で padding することにより Tensor のサイズを大きくしていることだけは分かる。

ところで TensorFlow.js の tf.padAPI reference を見てみると…

Also available are stricter rank-specific methods with the same signature as this method that assert that paddings is of given length.

  • tf.pad1d
  • tf.pad2d
  • tf.pad3d
  • tf.pad4d

…これ、Rank 5 以上のものには対応していないのでは!?

軽く tensorflow/tfjs のコードを見てみたがちょっと分からず… しかしまぁ Rank 4 までしか対応していない、というのは実に有り得る気がする。

ということで 前述の Rank 6Tensor に対する pad の回避するよう処理を書き換えてみることにした。

とはいえ data_format のときのように transpose すれば良いというものでもなさそうだし ちょっとどうすれば良いか分からない…。

が、幸い ここで Upsample するためのパラメータ upy, upx はどうやら 1 もしくは 2 の値でしか渡されてこないらしい、ということが分かった。 1 のときは結局 pad すべき shape は 0 になるので、何もせずに skip してしまえば良い。 2 のときだけ どうにか2箇所だけ shape を増やしてあげる必要がある…。

いや、待てよこれは 1列分だけ 0 padding するだけなのだから、同じ shape の zeros Tensor を後ろから concat してやれば同じ意味になるのでは? と思いついた。

diff --git a/dnnlib/tflib/ops/upfirdn_2d.py b/dnnlib/tflib/ops/upfirdn_2d.py
index fd23777..49d11ed 100755
--- a/dnnlib/tflib/ops/upfirdn_2d.py
+++ b/dnnlib/tflib/ops/upfirdn_2d.py
@@ -82,7 +82,10 @@ def _upfirdn_2d_ref(x, k, upx, upy, downx, downy, padx0, padx1, pady0, pady1):

     # Upsample (insert zeros).
     x = tf.reshape(x, [-1, inH, 1, inW, 1, minorDim])
-    x = tf.pad(x, [[0, 0], [0, 0], [0, upy - 1], [0, 0], [0, upx - 1], [0, 0]])
+    if upy == 2:
+        x = tf.concat([x, tf.zeros(tf.shape(x))], axis=2)
+    if upx == 2:
+        x = tf.concat([x, tf.zeros(tf.shape(x))], axis=4)
     x = tf.reshape(x, [-1, inH * upy, inW * upx, minorDim])

     # Pad (crop if negative).

このように書き換える。

これでまた Generator を作成しなおして SavedModel に保存して GraphModel に変換して…。

f:id:sugyan:20200206005109p:plain

TensorFlow.js でも 動いた!! やったー!!

やはり webgl backend でのエラーは tf.pad が Rank 4 までしか対応していなかったことが原因だったようだ…。いやーまさかこんな方法で解決できるとは。。

実際のところ、 TensorFlow.js での生成を 256x256 サイズで試した感じでは 最初の実行時に 10秒弱かかるが、その後は 数十 ms くらいで計算できてそう。そこから計算結果のデータを取得して、という部分で 3000 ms くらいかかってしまっているが…。

とりあえず 計算はWebWorkerに任せる などして、描画とかUIのところだけ作っていけば、 誰でもブラウザ上で画像生成を試せるようになる、かも…!? という希望は見えた。

ここまでの変更は一応 ここに残しておく。