敵対的生成ネットワークをまじめに勉強して実装してみた2 ~CycleGAN~

前回、概念から数学的なことも含めてGANについて説明しました。 今回はCycleGANにフォーカスを当てます。 f:id:IKEP:20191022103139p:plain CycleGANの論文より

CycleGANとは

画像から画像への変換が可能なGANの一種です。 こういった画像to画像の変換にはpix2pixのように、データセット画像がペアになっているのが普通です。 「入力画像に対して画像変換を行うとこういう画像になる」という風に、入力画像群と変換目標画像群をペアにして与えます。

しかし下の画像のようなペアではないものでも、CycleGANは画像間のスタイル変換が行えます。 なので、上の画像のように現実的にペアの画像を用意することが難しいものが対象でも画像変換が可能です。 (馬とシマウマを全く同じ構図、背景、姿勢で写真を取るのはほぼ無理ですよね?)

f:id:IKEP:20191022104223p:plain:w300

工夫点

ペアの画像の場合は入力画像と変換目標の画像間のピクセルごとの差分を利用する(MSEやMAE)などで、学習を進めることができます。 しかし、今回はペアでないものを対象とするための工夫がなされています。 それは、2つのGenerator(G, F)を用意し、Gで変換したものをFで変換した時に、元の画像に戻るかという設計です。 ドメインXの画像群の画像xをGによってドメインYの画像群のような画像y^に変換。 その後、y^をFでXへと変換した画像x^が、xとどれだけ異なっているか利用しています (数学的にいうと、X→Yの写像GとY→Xの写像Fがあり、Xの画像xを入力した時、x-F(G(x))の値を小さくする、ということ)。 これが、Cycle Consistency Lossと言われています。

f:id:IKEP:20191022104715p:plain

損失計算

損失関数は以下のようになっています。

f:id:IKEP:20191022105557p:plain

で、Cycle Consistency Lossがどれかというと、上式の3項目です。 内容は以下

f:id:IKEP:20191022105753p:plain

画像ドメインX, YそれぞれのCycle Consistency Lossの和です。 L1ノルムなので、ただピクセル間の差分を見ているだけですね。

そして、残りの1, 2項目は以下のようになっています。

f:id:IKEP:20191022110058p:plain

GeneratorがDiscriminatorを騙せたかどうかということです。 これは、普通のGANとAdversarial Lossと同じなので、説明は割愛します。

λは、Cycle Consistency LossとAdversarial Lossの割合を決めるハイパーパラメータで、論文の実装では10で実装したとなっています。

実装してみた

てことで、内容は理解したのでChainerを使って実装してみました。 実装したコードはこちら。 データセットは公開されているものの2つ(馬↔︎シマウマりんご↔︎みかん)と、オリジナルで作成したものを利用しました。 オリジナルは、ポケモンモンスターボール↔︎マスターボールの変換です。 google-imagedownloadを使って、Google画像検索で出てくる画像を使いました。(そこから明らかに違う画像は人力排除...)

f:id:IKEP:20191022111131p:plain

f:id:IKEP:20191022111109p:plain

f:id:IKEP:20191022111142p:plain

学習結果

馬↔︎シマウマ

f:id:IKEP:20191022111335p:plain

りんご↔︎みかん

f:id:IKEP:20191022111447p:plain

モンスターボール↔︎マスターボール

f:id:IKEP:20191022111540p:plain

モンスターボール↔︎マスターボールでいい感じのやつを集めてみました↓。 特に左上がMマークが作成されてるので、ちょっとびっくりしました。

f:id:IKEP:20191022111705p:plain f:id:IKEP:20191022111714p:plain f:id:IKEP:20191022111726p:plain f:id:IKEP:20191022111745p:plain

感想

ペアデータなしでここまでの画像変換ができるのは、かなり感動です。 論文の画像を見る限りではもう少しクオリティが高そうだったので調べると、今回実装したコードは学習安定化のテクニックを入れていなかったので、それが原因かなと思います。 でも、楽しむ分にはなかなかな精度が出てるのでいいんじゃないでしょうか(^ ^)

参考

CycleGAN - Qiita

"CycleGAN"の論文解説と"GAN"の補足

[DL輪読会]Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks