CCCマーケティング TECH Labの Tech Blog

TECH Labスタッフによる格闘記録やマーケティング界隈についての記事など

深層学習フレームワーク「PyTorch Lightning」に入門してみました!

こんにちは、技術開発の三浦です。

今年はセミの鳴き声が聞こえないな、と思っていたのですが、最近になって聞こえるようになってきました。自分の中ではセミの鳴き声がしたら夏が始まったんだなと感じます。

今回は深層学習のフレームワークPyTorch Lightningを使ってみたお話をしたいと思います。PyTorch Lightningを使って簡単な画像分類モデルを学習させるコードを作ってみましたので、コードを通じながらどんなフレームワークなのかをご紹介出来ればと思います。

PyTorch Lightning

PyTorch Lightningは同じく深層学習のフレームワークPyTorchのモデルを、PyTorchよりも簡潔に、分かりやすいコードで構築することが出来るフレームワークです。私は深層学習のフレームワークはTensorFlowをよく使ってきたのですが、最近PyTorchを勉強したいな、と思う機会が増えてきました。

その大きなきっかけとなったのが、以前このブログでも書いたことがあるHugging Faceです。

huggingface.co

AIコミュニティHugging Faceでは様々なタスクのモデルが提供されていて、それらをHugging Face Transformersというライブラリで利用することが出来ます。自然言語処理、特にTransformerのモデルが色々あって、「こんなことも出来るんだ」とサイトを見ているだけでとても勉強になります。

さて、このHugging Faceでは様々な深層学習のフレームワーク向けにモデルが提供されているのですが、今のところPyTorchの方がTensorFlowよりも多い印象があります。今後どうなるかは分かりませんが、以前から一つのフレームワークに依存しすぎるのはよくないなと考えており、これを機会にPyTorchの勉強をしよう、と思い至りました。そしてPyTorchの情報を色々探していたところ、目に留まったのがPyTorch Lightningでした。

www.pytorchlightning.ai

PyTorch Lightningのサイトを訪れて目にした「You do the research. Lightning will do everything else.」というフレーズがとても印象に残りました。最近マルチGPUやマルチワーカーでの学習方法を色々試していて、これにかなり苦戦していたので・・・。

PyTorch Lightningではこのフレーズが示すように、モデルの構造を考える部分とモデルの学習方法を考える部分がかなり明確に分けられているので、マルチGPU対応する場合にもモデルの構造には手を加えず、簡単な変更でマルチGPUに対応出来そうです。今の自分のニーズにぴったりとはまったので、Pytorch Lightningに入門してみよう!と思い立ちました。

PyTorch Lightningの使い方

ここからはPyTorch Lightningの使い方について学んだことを、深層学習フレームワークにおける「Hello, World!」に該当すると個人的に考えているCIFAR-10データセットの画像分類モデルの構築を通じて説明していきたいと思います。やることは大きく以下の3つに分けることが出来ます。

  • データを管理するLightningDataModuleを作る
  • モデルの構造と振る舞いを管理するLightningModuleを作る
  • モデルの学習を管理するTrainerを作る

これらの3つのモジュールが独立していて、いずれかの変更が他に影響を及ぼさないようになっています。

データを管理するLightningDataModuleを作る

PyTorchではデータそのものを表現するDatasetと、Datasetからデータをサンプリングし処理に渡すDataLoaderによってデータを取り扱います。PyTorch Lightningではさらにそれら2つをまとめたLightningDataModuleという抽象クラスがあり、この抽象クラスのメソッドを実装することで、DatasetDataLoader、その2つを結び付けることが出来るようになっています。

LightningDataModuleの以下のメソッドを実装します。

  • prepare_data
    データのダウンロードなどの処理を行います。メンバ変数の設定などは、ここでは行いません。

  • setup
    データの加工はこのメソッドの中で行います。stageという引数を受け取りますが、この引数の値によってfit(学習)やtest(テスト)といった、何の用途でデータが呼び出されたのかを判断し、それぞれに応じたデータを準備することが出来ます。

  • train_dataloader/val_dataloader/test_dataloader
    それぞれ学習用、検証用、テスト用のDataLoaderを返すメソッドです。これを実装しておけば、Pytorch Lightning側で該当するフェーズで適切なDataLoaderを呼び出してくれます。

LightningDataModuleを継承した、CIFAR10DataModuleを以下のように実装してみました。

from torchvision.transforms.transforms import RandomHorizontalFlip
import pytorch_lightning as pl
from torch.utils.data import random_split, DataLoader

from torchvision.datasets import CIFAR10
from torchvision import transforms

class CIFAR10DataModule(pl.LightningDataModule):
  def __init__(self, data_dir='./'):
    super().__init__()
    self.data_dir = data_dir

    # 水平方向の反転を、DataAugmentationとして設定
    self.train_transform = transforms.Compose(
        [
          transforms.RandomHorizontalFlip(p=0.5),
          transforms.ToTensor(),
          transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]
    )
    self.predict_transform = transforms.Compose(
        [
          transforms.ToTensor(),
          transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]
    )
    # CIFAR-10のラベル
    self.classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
  
  def prepare_data(self):
    '''
    データをダウンロードしたりする処理
    '''
    CIFAR10(self.data_dir, train=True, download=True) # For Training
    CIFAR10(self.data_dir, train=False, download=True) # For Testing
  

  def setup(self, stage=None):
    '''
    dataloader用にダウンロードしたdatasetを加工する
    '''

    if stage == 'fit' or stage is None:
      # 学習時
      all_train_data = CIFAR10(
          self.data_dir, 
          train=True, 
          transform=self.train_transform, 
          download=False    
      )
      #学習用:40,000 検証用:10,000
      self.train_data, self.valid_data = random_split(all_train_data, [40000, 10000])
    if stage == 'test' or stage is None:
      # テスト時
      self.test_data = CIFAR10(
          self.data_dir,
          train=False,
          transform=self.predict_transform,
          download=False
      )

  def train_dataloader(self):
    return DataLoader(self.train_data, batch_size=64, shuffle=True, drop_last=True, pin_memory=True, num_workers=2)
  
  def val_dataloader(self):
    return DataLoader(self.valid_data, batch_size=64, drop_last=True, pin_memory=True, num_workers=2)
  
  def test_dataloader(self):
    return DataLoader(self.test_data, batch_size=64, drop_last=True, pin_memory=True, num_workers=2)

モデルの構造と振る舞いを管理するLightningModuleを作る

一番盛りだくさんなパートです。PyTorch LightningのドキュメントにはLightningModuleについて、「A lightning module defines a system not just a model.」と記述されているのですが、実際にこのイメージに近く、モデルを定義するだけでなく、その振る舞いやモデルをどう使うのかまでをLightningModuleで定義していきます。

LightningDataModuleと同じように、抽象クラスLightningModuleのメソッドを実装していくことで、モデルの定義から振る舞いまでを決めていくことが出来ます。ケースバイケースだと思いますが、今回実装したのは以下のメソッドです。

  • forward
    モデルの入力と出力を定義します。モデルを構築する作業は、このメソッドの中で行います。
  • training_step
    学習ステップで実行される処理を定義します。今回はforwardを呼び出し、その値と正解の値とのlossaccuracyを計算して返す、という処理を実装しました。ここで計算したlossの値が後に定義するoptimizerによって最小化されます。
  • validation_step
    特に指定がなければ、1epoch分のtraining_stepが終わると検証フェーズが実行されます。その時にこのメソッドの処理が実行されます。
  • test_step
    モデル学習後のテストフェーズで実行する処理を、このメソッドの中に記述します。
  • configure_optimizers
    Adamなどのoptimizerを設定し、返すメソッドです。今回は1つのoptimizerですが、GANsなどで複数のoptimizerを使う場合でもこのメソッドで対応出来るようです。

LightningDataModuleにはlogというメソッドが用意されていて、これにkeyとvalueの形でloggingしたい値を渡すと、デフォルトでTensorBoard形式でディレクトリにログを書き出してくれます。パスを指定していなければ、./lightning_logsに書き出されます。

LightningModuleを継承した、CIFAR-10を分類するCIFAR10Classifierを以下のように実装してみました。

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torchmetrics

class CIFAR10Classifier(pl.LightningModule):
  def __init__(self, learning_rate=0.01):

    super().__init__()

    self.learning_rate = learning_rate
    # モデルで使用するレイヤなど
    self.conv1 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3,padding='same')
    self.relu1 = nn.ReLU()
    self.pool = nn.MaxPool2d(kernel_size=2)
    self.conv2 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3,padding='same')
    self.relu2 = nn.ReLU()
    self.linear = nn.Linear(in_features=16 * 16 * 6, out_features=100)
    self.final = nn.Linear(in_features=100, out_features=10)
    self.loss = nn.CrossEntropyLoss()

    #accuracy測定用
    self.train_top1_acc = torchmetrics.Accuracy(top_k=1)
    self.train_top3_acc = torchmetrics.Accuracy(top_k=3)
    self.valid_top1_acc = torchmetrics.Accuracy(top_k=1)
    self.valid_top3_acc = torchmetrics.Accuracy(top_k=3)

  def forward(self, input):
    output = self.conv1(input)
    output = self.relu1(output)
    output = self.pool(output)
    output = self.conv2(output)
    output = self.relu2(output)
    output = output.view(-1, 16 * 16 * 6)
    output = self.linear(output)
    return self.final(output)

  def training_step(self, batch, batch_idx):
    x, y = batch
    pred = self(x) # self.forwardが呼ばれる
    loss = self.loss(pred, y)
    top1_acc = self.train_top1_acc(pred, y)
    top3_acc = self.train_top3_acc(pred, y)
    self.log('train_loss', loss)
    self.log('train_top1_acc', top1_acc)
    self.log('train_top3_acc', top3_acc)
    return {'loss':loss, 'train_top1_accuracy':top1_acc, 'train_top3_accuracy':top3_acc}

  def configure_optimizers(self):
    return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
  
  def validation_step(self, batch, batch_idx):
    x, y = batch
    pred = self(x)
    loss = self.loss(pred, y)
    top1_acc = self.valid_top1_acc(pred, y)
    top3_acc = self.valid_top3_acc(pred, y)
    self.log('val_loss', loss)
    self.log('val_top1_acc', top1_acc)
    self.log('val_top3_acc', top3_acc)
  
  def test_step(self, batch, batch_idx):
    x, y = batch
    pred = self(x)
    top1_acc = self.valid_top1_acc(pred, y)
    top3_acc = self.valid_top3_acc(pred, y)
    self.log('test_top1_acc', top1_acc)
    self.log('test_top3_acc', top3_acc)

モデルの学習を管理するTrainerを作る

LightningDataModuleLightningModuleを結びつけ、学習/テスト/推論を実行するTrainerクラスのオブジェクトを作ります。このTrainerにはどのデバイスで実行するのか(cpu/gpu)、それをいくつ使用するのか(devices)などを指定することが出来て、恐らくここを変えることでマルチGPUにも対応出来るのだと思います。こちらはまだ試していないので、今度実際に変えて動かしてみようと思います。

また、Trainerには学習時に呼び出すcallbacksを指定することが出来ます。epoch終了時にモデルファイルを書き出したり、数ステップでモニタリング対象の指標の改善が見られなければ学習を打ち切りにする処理などを、簡単に呼び出すことが出来ます。今回はモデルファイルの書き出しを行うModelCheckpointと、検証データに対するlossをモニタリング対象にしたEaryStoppingを試してみました。

from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning import Trainer

callback_checkpoint = ModelCheckpoint(dirpath='./checkpoints', save_top_k=1, monitor="val_loss")
callback_earlystopping = EarlyStopping(monitor="val_loss", patience=3)

cifar10 = CIFAR10DataModule()
cifar_classifier = CIFAR10Classifier(learning_rate=0.001)

trainer = Trainer(
    max_epochs=100,
    accelerator='gpu', 
    devices=1, 
    callbacks=[
               callback_checkpoint, 
               callback_earlystopping
               ]
)

trainer.fit(cifar_classifier, cifar10)

実行すると、以下のように学習処理が開始されます。

学習時の様子

プログレスバーなども表示されました。

LightningModuleはTensorBoard用のログを書き出してくれるので、以下のコマンドをnotebookのセル内で実行すると、TensorBoardで学習曲線を確認できます。

%reload_ext tensorboard
%tensorboard --logdir=lightning_logs/

TensorBoard

実行する度にversionを切ってくれるようです。

モデルの学習が終わったら、テストデータによりテストを行ってみます。テストフェーズでは、LightningDataModuletest_dataloaderと、LightningModuletest_stepが実行されます。

trainer.test(cifar_classifier,cifar10)

テストの実行結果

最後にモデルによる推論結果を取得してみます。Trainerpredictメソッドを実行すると、モデルの推論結果が得られます。LightningModulepredict_stepを独自に実装していなければ、forwardメソッドが実行されます。

test_sample = next(iter(cifar10.test_dataloader()))

#trainer.predictではforwardが呼ばれる
predict = trainer.predict(cifar_classifier,test_sample[0])

#top3の予測取得
pred_labels = torch.topk(predict[0],3)
print([cifar10.classes[x.numpy()] for x in pred_labels.indices[0]])

対象の画像と、それに対する予測ラベルのトップ3は「猫、カエル、犬」になりました。

テスト画像に対する推論

こちらの画像に対する正解ラベルは「」なので、モデルは正しく機能しているようです。

以上、PyTorch Lightningを使った簡単な画像分類モデル構築の流れをご紹介しました。

まとめ

今回はPyTorchをベースにした深層学習フレームワークPyTorch Lightningに入門した話をご紹介しました。実際に触ってみると、フレームワーク側で明確なテンプレートが用意されていて、それに従ってコードを実装していけばいつの間にか深層学習の一連の流れが出来ている、というところがとても面白いと感じました。このテンプレートをチームの共通認識にしておけば、お互いのコードを理解するハードルもだいぶ下がるように思います。

深層学習は色々なタスクに柔軟に対応出来る強みから、様々なモデルの構造や学習方法が現在存在しています。それらをPyTorch Lightningでどこまでカバー出来るのか、もっと知りたいと思いました。今度はGANsの実装を試してみようと思います。