TensorFlowのtf.contrib.layers.batch_normの注意点

TensorFlowのtf.contrib.layers.batch_normは他のtf.contrib.layersの関数(conv2dfully_connectedなど)にnormalizer_fnパラメーターで渡すことができ、Batch Normalizationを含めひとつの関数呼び出しにまとめることができるのでコードがスッキリして便利です。

output = tf.contrib.layers.conv2d(
    inputs,
    num_outputs=128,
    kernel_size=4,
    stride=2,
    normalizer_fn=tf.contrib.layers.batch_norm,
    normalizer_params=params,
    trainable=True)

しかし、このbatch_normやtf.contrib.layersの関数にはいくつか重要な注意点があります。

batch_normがtf.GraphKeys.UPDATE_OPSに登録したopsの更新処理を忘れないようにする

batch_normのmoving_meanやmoving_varianceの更新に関するopsはtf.GraphKeys.TRAINABLE_VARIABLESではなくtf.GraphKeys.UPDATE_OPSに登録されます。このUPDATE_OPSに登録されたopsは、以下のようにoptimizerなどが走る前に明示的に更新する必要があります。

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
  train_op = optimizer.minimize(loss)

ドキュメントにも以下のような注意書きがありますが、青色のセクションに書かれているので、NoteというよりはFYIという風に見えて見逃しやすいので気をつけましょう。

Note : when training, the moving_mean and moving_variance need to be updated. By default the update ops are placed in tf.GraphKeys.UPDATE_OPS, so they need to be added as a dependency to the train_op. For example: (コード例に続く…)

tf.contrib.layers.batch_normのis_trainingパラメーターの設定を忘れないようにする

他のtf.contrib.layersの関数と違い、tf.contrib.layers.batch_normにはtrainableだけでなくis_trainingパラメーターもあります。

batch_norm(
    inputs,
    is_training=True,
    trainable=True,
    ...)

trainableはbatch_normの内部にあるvariablesをGraphKeys.TRAINABLE_VARIABLESに登録するかどうかを制御するためのboolパラメーターなのに対して、is_trainingはmoving_meanやmoving_varianceの挙動に関するboolパラメーターです。どちらもデフォルトでTrueですが、学習以外の時は明示的にFalseを設定する必要があります。特にis_trainingがTrueのままの場合、同じ入力に対してbatch_normが毎回違う出力をしてしまい、モデルの再現性がなくなる場合があるので注意が必要です。

tf.contrib.layersの関数にnormalizer_fn=tf.contrib.layers.batch_normを指定する場合はnormalizer_paramsも設定する

冒頭にも紹介したように、tf.contrib.layersの関数はnormalizer_fnパラメーターにbatch_normを指定することで、レイヤーの計算後にBatch Normalizationを適用してくれます。しかし、normalizer_fnはデフォルトでinputsパラメーターだけ指定して呼び出され、他のパラメーターはデフォルト引数が使われます。

# 期待した通りに動かない例
# conv2d自体にtrainable=Falseを指定してもbatch_normには渡されず、
# batch_normはデフォルト引数で呼ばれる(is_training=True, trainable=True)
is_training = False
output = tf.contrib.layers.conv2d(
    inputs,
    num_outputs=128,
    kernel_size=4,
    stride=2,
    normalizer_fn=tf.contrib.layers.batch_norm,
    trainable=is_training)

この問題は、normalizer_fnを指定したtf.contrib.layersの関数のnormalizer_paramsパラメーターを使うことで解決できます。

# 正しく動く例
# batch_normにis_training=Falseとtrainable=Falseが渡される
is_training = False
normalizer_params = {'is_training': is_training, 'trainable': is_training}
output = tf.contrib.layers.conv2d(
    inputs,
    num_outputs=128,
    kernel_size=4,
    stride=2,
    normalizer_fn=tf.contrib.layers.batch_norm,
    # batch_norm用のパラメーターを指定する
    normalizer_params=normalizer_params,
    trainable=is_training)

この挙動はドキュメントからだと分かりにくいので注意しましょう。