CCCマーケティング TECH Labの Tech Blog

TECH Labスタッフによる格闘記録やマーケティング界隈についての記事など

【完結編?】DCGANを使って"お寿司"の画像を生成できるように色々試した話。

こんにちは、技術開発ユニットの三浦です。

今年度に入り、新しく始めたことに「オンライン英会話」があります。朝仕事を始める前に、短い時間ですが外国の方と話してみています。英会話にチャレンジするのが人生初の経験なので、最初は緊張してちょっとしんどかったのですが、最近は少しずつ慣れてきました。習慣化出来るようにすることが、今の目標です。

さて、以前GANs(Generative adversarial networks:敵対的生成ネットワーク)の記事を書き、DCGANが上手くいかない話を紹介しました。

techblog.cccmk.co.jp

あれから少しアップデートがありましたので、今回の記事はその内容について、ご紹介させていただきます。

直近の生成"お寿司"画像

直近、DCGANで生成した"お寿司"画像はこんな感じです。

直近の成果です!

・・・いかがでしょうか。細かいところはグチャっとしているものの、なんとなく"料理"っぽい雰囲気、出てきてると思います。ちなみにこのデータの元になっている、学習用の画像をいくつか表示してみます。

学習用データセットからのサンプル(The Food-101 Data Set より)

このように、"お寿司"という同じカテゴリの中に、色々なタイプの"お寿司"が含まれていることが分かります。生成されたデータは、それらがごちゃまぜになっているような印象を受けます。その中でも、たとえば下のような画像はなんとなく「あの"お寿司を"表現しているのかな?」といった印象を持ちました。

お品書きっぽく

少なくとも前回の生成画像よりは「何らかの意味」を持った画像が生成できるようになったと思います。ただ、ここに至るまで、けっこう色々な試行錯誤が必要でした。特に実験環境の強化, データ前処理の追加, 学習方法の見直しネットワーク構造の見直し, 学習状況の可視化といった作業が生成画像の改善につながったと思います。

以下、もう少し詳しく説明します。

実行した改善策

実験環境の強化

前回は「Google Colab Pro」で学習処理を実行していたのですが、CCCマーケティングが分析環境として導入している、Azure Databricksに今回の実験環境を移行しました。実験に必要となるコンピュータリソースの作成や環境構築、ストレージへのアクセス、そしてnotebookの開発実行環境がAzure Databricks上に用意されています。コンピュータリソースはGPUを搭載したものも作成することが出来ます。

Azure Databricksの画面

プライベートでちょっとした実験を行うのであればGoogle Colabが最適なのですが、今回のテーマであるDCGANは「ちょっとした実験」で扱うには少し重たいテーマです。実際、1日以上継続して学習処理を行う必要もありました。

実験環境の強化を行う前と後で、同じ処理を10epoch実行した時の累積実行時間は以下のようになりました。

10epoch実行後の累積実行時間(秒)の比較

かなりの時間を圧縮できたことが分かります。これにより、トライ&エラーを何度も行うことが出来るようになりました。

データ前処理の追加

今回使用しているデータセットfood101は、1クラスあたりの画像枚数が1,000枚であることに気が付き、これでは学習データとして不足しているのかも、と考えました。そこでデータにバラエティを持たせるため、ランダムに変形処理を加え、データを増強する手法「Data Augmentation」を取り入れました。

以下のように、KerasのSequentialモデルで前処理用レイヤを作りました。このレイヤを通過すると、ランダムで「水平方向の反転」「画像の回転」「画像の拡大」といった変形が施されます。(notebookで表示する際は、Rescalingの部分を変更し、各ピクセルの値が0~1に収まるようにしています。)

preprocess = tf.keras.Sequential([
  tf.keras.layers.Resizing(IMG_SIZE, IMG_SIZE),
  tf.keras.layers.Rescaling(scale=1./127.5, offset=-1),
  tf.keras.layers.RandomFlip("horizontal"),
  tf.keras.layers.RandomRotation(0.2),
  tf.keras.layers.RandomZoom(.2, .2),
])

# Datasetの各要素のimageにpreprocessを適用する
train_data = data.map(lambda x:{'image':preprocess(x['image']), 'label':x['label']})

変形後の画像は以下のようになります。同じデータを2回、前処理レイヤに通した結果です。

Data Augmentation

微妙にそれぞれの画像に変化が加えられています。

他にも変形処理は色々あるのですが、大きな変形を加えると生成画像の質に悪影響を与えてしまうのでは、と考え、最小限の変形に留めました。

学習方法の見直し

色々な方法を試し、その都度学習の様子を学習曲線で見ていると、「discriminatorが強すぎて、generatorが育っていないのでは?」と考えるようになりました。

学習曲線

上記の学習曲線のように、discriminatorgenerator双方のlossが一定の値に収束してしまうと、以降どれだけ処理を回してもgeneratorが生成する画像が前回のような一様な画像になってしまいます。

どうしてdiscriminatorが強くなってしまうのか。その原因は恐らく学習処理にあるのでは、と考えました。前回失敗した時は、1回の学習ループの中で以下のような処理を行っていました。それぞれbatch_sizeのデータで実行しています。

  1. リアル画像とラベル(リアル)を使って、discriminatorを学習
  2. フェイク画像とラベル(フェイク)を使って、discriminatorを学習
  3. フェイク画像とラベル(リアル)を使って、generatorを学習

つまり、discriminatorの方が1回多く学習処理を行っていました。ここで強弱が生まれたと考え、generatorを1ループの中で2回、学習させるように変更しました。

また、batch_sizeを64から128に変更しました。この変更については明確な理由はなく、この値を採用しているという意見が、私が調べた範囲では多かったからです。

ネットワーク構造の見直し

DCGANの学習が進まない原因を調べていると、Stack Overflowで以下のような質問が見つかりました。

stackoverflow.com

DCGANの論文通りにモデルを作ったのに、モデルの学習が一向に進まないという内容で、まさに私が直面している問題と同じようです。最終的にdiscriminatorからBatchNormalizationレイヤを取り除くと上手くいった、と書かれていて、私の場合も同様に取り除くことで、少なくともlossや生成画像が一定のものしか出力されないという事態は脱することが出来ました。

他にも調整を行い、ネットワークの構造やハイパーパラメータの設定などを最終的に以下のようにしました。

  • discriminatorからBatchNormalizationレイヤを全て取り除く
  • 各種weightを平均0, 標準偏差0.02の正規分布による乱数で初期化
  • generatorBatchNormalizationレイヤのmomentumパラメータの値を0.9に変更(Kerasのデフォルトは0.99)
  • discriminatorDropoutレイヤを追加
  • Adam Optimizerのlearning_lateを2e-4に、beta_1を0.5に変更

学習状況の可視化

処理の途中で1epochごとのlossaccuracyの値を確認したり、モデルがどのような画像を生成するのかを確認できるよう、TensorBoardを利用しました。TensorBoardでは以下のように学習途中で各値をグラフで確認したり、生成画像を埋め込んだplotを見ることが出来ます。

学習曲線の表示

生成データの確認

この結果を見て、あまり芳しくない傾向が見られたら早めに処理を中断し、調整にうつるようにしました。

生成画像の品質を上げるためには・・・?

以上のように改善を施した結果、確かに"料理"っぽい、色々な画像が生成出来るようになりました。しかし、生成される"お寿司"の画像の質は、まだ良いとは言えません。

この原因として考えているのが、先ほどお見せしたように、データセットに含まれる"お寿司"の画像の種類が多く複雑なので、フェイク画像が上手く作れないのではないか、ということです。データを見ていると握り寿司や巻き寿司、カリフォルニアロール、さらには人物まで含まれていて、その中から共通する潜在的な特徴を抽出することは難しそうです。

またDCGANは安定した学習が難しいと言われているように、同じ設定を使っているにも関わらず、あるデータでは学習が進んだのに他のデータではまったく学習が進まない、といったこともありました。

さらにearly_stoppingのような機械的に学習を止める指標もないため、「どこまで学習させたらいいんだろう」という点も、実際に生成される画像を見ながら判断しました。

この辺りは後継のGANsのシリーズで改善が図られているようで、今後は別のGANsを使ってどれだけ効率的に、高品質な画像が生成できるかを見ていきたいと思います。

まとめ

ということで、DCGANの実験について、前回から色々と対策を取ることで、"料理"っぽい画像を生成出来るところまでは改善することが出来ました。GANsはとても面白い領域で、実験をしながらモデルが次はどんな画像を生成するんだろう、とワクワクしながら観察していました

先にも述べたように、GANsは新しいタイプのものがいくつも発表されているので、それらを使うとどんなデータが生成できるのか、これからも継続して見ていきたいと思います!

枝豆の画像も作ってみました

最後に・・・

参考までに、今回作ったKerasでDCGANを実装したコードを掲載させて頂きます。少しでも同じようにお困りの方の参考になれば幸いです。

import tensorflow_datasets as tfds
import tensorflow as tf
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Input, Dense, Reshape, Conv2D, Conv2DTranspose, Flatten, Dropout, ReLU, LeakyReLU, BatchNormalization, Activation 
from tensorflow.keras.activations import tanh
from tensorflow.keras.initializers import random_normal
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import TensorBoard
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

import io
import os
import time

# データの読み込み処理
data = tfds.load('food101',data_dir='', download=False)
t_data = data['train']
v_data = data['validation']
data = t_data.concatenate(v_data).prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
data = data.filter(lambda x:x['label']==95)

# 前処理レイヤの作成
IMG_SIZE = 64

# 実際に学習データに適用する方
preprocess = tf.keras.Sequential([
  tf.keras.layers.Resizing(IMG_SIZE, IMG_SIZE),
  tf.keras.layers.Rescaling(scale=1./127.5, offset=-1),
  tf.keras.layers.RandomFlip("horizontal"),
  tf.keras.layers.RandomRotation(0.2),
  tf.keras.layers.RandomZoom(.2, .2),
])

# notebookで変形処理を確認する用途に使う
# (Scalingだけ異なる)
check_preprocess = tf.keras.Sequential([
  tf.keras.layers.Resizing(IMG_SIZE, IMG_SIZE),
  tf.keras.layers.Rescaling(scale=1./255.),
  tf.keras.layers.RandomFlip("horizontal"),
  tf.keras.layers.RandomRotation(0.2),
  tf.keras.layers.RandomZoom(.2, .2),
  #tf.keras.layers.Rescaling(scale=1./127.5, offset=-1)
])

# 学習データ
train_data = data.map(lambda x:{'image':preprocess(x['image']), 'label':x['label']})

# Data Augmentation確認用
check_data = data.map(lambda x:{'image':check_preprocess(x['image'],training=True), 'label':x['label']})

# Data Augmentationの様子を確認する用
row = 5
col = 5
plt.figure(figsize=(20,20))

num = 0

for data in check_data:
  data = data['image']
  if num == row * col:
    break
  
  num += 1
  plt.subplot(row, col, num)
  plt.imshow(data)
  plt.axis('off')

# generator構築用の関数
momentum = 0.9
def build_generator():
    input = Input(shape=(z_dim))
    x = Dense(4 * 4 * 1024)(input)
    x = Reshape((4 ,4 ,1024))(x)

    #8 * 8 * 512
    x = Conv2DTranspose(
            filters=512, 
            kernel_size=4,
            kernel_initializer=random_normal(mean=0., stddev=0.02),
            strides=2, 
            padding='same',
            use_bias=False
        )(x)
    x = BatchNormalization(momentum=momentum)(x)
    x = ReLU()(x)

    #16 * 16 * 256
    x = Conv2DTranspose(
            filters=256, 
            kernel_size=4, 
            kernel_initializer=random_normal(mean=0., stddev=0.02),
            strides=2, 
            padding='same',
            use_bias=False
        )(x)
    x = BatchNormalization(momentum=momentum)(x)
    x = ReLU()(x)

    #32 * 32 * 128
    x = Conv2DTranspose(
            filters=128, 
            kernel_size=4, 
            kernel_initializer=random_normal(mean=0., stddev=0.02),
            strides=2, 
            padding='same',
            use_bias=False
        )(x)
    x = BatchNormalization(momentum=momentum)(x)
    x = ReLU()(x)

    #64 * 64 * 3
    x = Conv2DTranspose(
            filters=3, 
            kernel_size=4, 
            kernel_initializer=random_normal(mean=0., stddev=0.02),
            strides=2, 
            padding='same',
            use_bias=False
        )(x)
    output = Activation(tanh)(x)
    model = Model(inputs=input, outputs=output)
    return model

# discriminator構築用の関数
drop_out = 0.4
alpha = 0.2
img_size = (64, 64, 3)
def build_discriminator():
    input = Input(shape=img_size)
    #64 * 64 * 3 → 32 * 32 * 128
    x = Conv2D(
        filters=128,
        kernel_size=4,
        kernel_initializer=random_normal(mean=0., stddev=0.02),
        strides=2,
        padding='same',
        use_bias=False
    )(input)
    x = LeakyReLU(alpha=alpha)(x)
    #x = ReLU()(x)
    x = Dropout(drop_out)(x)
    #32 * 32 * 128 → 16 * 16 * 256
    x = Conv2D(
        filters=256,
        kernel_size=4,
        kernel_initializer=random_normal(mean=0., stddev=0.02),
        strides=2,
        padding='same',
        use_bias=False
    )(x)
    x = Dropout(drop_out)(x)
    #x = BatchNormalization(momentum=momentum)(x)
    x = LeakyReLU(alpha=alpha)(x)

    #16*16*256 → 8*8*512
    x = Conv2D(
        filters=512,
        kernel_size=4,
        kernel_initializer=random_normal(mean=0., stddev=0.02),
        strides=2,
        padding='same',
        use_bias=False
    )(x)
    x = Dropout(drop_out)(x)
    #x = BatchNormalization(momentum=momentum)(x)
    x = LeakyReLU(alpha=alpha)(x)
    
    #8*8*512 → 4*4*1024
    x = Conv2D(
        filters=1024,
        kernel_size=4,
        kernel_initializer=random_normal(mean=0., stddev=0.02),
        strides=2,
        padding='same',
        use_bias=False
    )(x)
    x = Dropout(drop_out)(x)
    #x = BatchNormalization(momentum=momentum)(x)
    x = LeakyReLU(alpha=alpha)(x)
    
    x = Flatten()(x)
    output = Dense(1, activation='sigmoid',kernel_initializer=random_normal(mean=0., stddev=0.02),)(x)
    model = Model(inputs=input, outputs=output)
    return model

# モデルの構築
z_dim = 100

#discriminator
discriminator = build_discriminator()
discriminator.compile(
    loss='binary_crossentropy',
    optimizer=Adam(learning_rate=2e-4, beta_1=0.5),
    metrics=['accuracy']   
)

#gan
generator = build_generator()
z = Input(z_dim)
gen_img = generator(z)
discriminator.trainable = False
gan_output = discriminator(gen_img)
gan = Model(inputs=z, outputs=gan_output)
gan.compile(
    loss='binary_crossentropy',
    optimizer=Adam(learning_rate=2e-4, beta_1=0.5),
)

# tensorboardの起動
%load_ext tensorboard
!rm -rf logs_sushi
%tensorboard --logdir logs_sushi

# 学習処理

# 生成画像を並べたグラフを画像に変換する 
def plot_to_image(figure):
  buf = io.BytesIO()
  plt.savefig(buf, format='png')
  plt.close(figure)
  buf.seek(0)
  image = tf.image.decode_png(buf.getvalue(), channels=4)
  image = tf.expand_dims(image, 0)
  return image

# 生成画像を並べたグラフを生成する
def write_gen_image(row, col):
  fig = plt.figure(figsize=(20,20))
  num = 0
  while num < row * col:
    num += 1
    plt.subplot(row, col, num)
    input_z = np.random.normal(0, 1, (batch_size, z_dim))
    gen_img = generator.predict(input_z)
    gen_img = (127.5 * gen_img + 127.5).astype(int)
    plt.imshow(gen_img[0])
    plt.axis('off')
  return fig

epochs = 10000
batch_size = 128

log_dir = './logs_sushi'
plot_dir = './logs_sushi/plot'

file_writer = tf.summary.create_file_writer(plot_dir)
tensorboard = TensorBoard(log_dir=log_dir, histogram_freq=0, write_graph=True)

iterations = []
losses = []
acces = []

tfboard_logs = {}
tensorboard.set_model(gan)

train_data = train_data.shuffle(900, reshuffle_each_iteration=True).batch(batch_size,drop_remainder=True)
real_label = np.ones((batch_size, 1))
fake_label = np.zeros((batch_size, 1))
fix_input_z = np.random.normal(0, 1, (1, z_dim))

for epoch in range(epochs):

  for real_img in train_data:
    real_img = real_img['image']
    input_z = np.random.normal(0, 1, (batch_size, z_dim))
    fake_img = generator.predict(input_z)

    # discriminatorの学習
    d_loss_real = discriminator.train_on_batch(real_img, real_label)
    d_loss_fake = discriminator.train_on_batch(fake_img, fake_label)
    d_loss = (d_loss_real[0] + d_loss_fake[0]) * 0.5
    d_acc = (d_loss_real[1] + d_loss_fake[1]) * 0.5
    
    # generatorの学習
    g_loss = 0
    for _ in range(2):
      input_z = np.random.normal(0, 1, (batch_size, z_dim))
      g_loss += gan.train_on_batch(input_z, real_label)
    g_loss *= 0.5
  
  # tensorboardへの書き込み
  fig = write_gen_image(5, 5)
  tfboard_logs = {'d_loss':d_loss, 'd_acc':d_acc, 'g_loss':g_loss}
  tensorboard.on_epoch_end(epoch, tfboard_logs)

  with file_writer.as_default():
    tf.summary.image("generate_data", plot_to_image(fig), step=epoch)