敵対的生成ネットワークをまじめに勉強して実装してみた
敵対的生成ネットワーク(Generative Adversarial Network, GAN)を使うことになったので、いろいろ勉強しました。せっかくなので、備忘録として勉強したことをまとめておきたいと思います。これから、GANを勉強する方の参考にでもなればと思います。
GANとは
Generative Adversarial Network(日本語では、敵対的生成ネットワーク)と呼ばれるディープラーニングを用いた生成モデルの一種です。
GeneratorとDiscriminatorという二つのニューラルネットから構成されており、二つを敵対させながら学習を進めることで、高精度なデータを生成することができるようになります。
よく、説明される例を挙げると、偽札屋と警察の関係に似ています。偽札屋(=Generator)は警察(=Discriminator)が本物と誤判別する偽札を作る。警察は、与えられたお札が本物の紙幣なのか、偽札なのかを見分けられるように学習する。警察の判別能力が向上すると、偽札屋はより本物そっくりな偽札を作らなければならず、偽札屋が本物そっくりな偽札を作れるようになると、警察はより高度な判別能力が必要となります。これを繰り返すと、偽札屋は徐々により高度で本物そっくりな偽札を作る方法を学ぶことができます。
このように、GeneratorはDiscriminatorを騙せるように、Discriminatorは与えられたデータが本物か偽物かを正しく見分けられるようにとお互いが敵対することで、最終的にGeneratorが本物そっくりな高精度なデータを生成できるようになるというのがGANの考え方です。
上図がGANの概念図です。潜在変数(VAEなどを勉強すればわかりやすいが、ここでは乱数と思ってOK)をGeneratorに入力し、偽物データを生成。DiscriminatorはGeneratorが生成した偽物データか本物のデータを受け取り、それが本物なのか、偽物なのかを判別、判別結果を出力します。
データとして画像を扱う場合は、GeneratorとDiscriminatorは大体、畳み込みニューラルネットワーク(CNN)がベースとなっています。
GANの学習
GANの損失関数は以下の通りです。
Dは0~1を返すDiscriminatorモデルの出力、GはGeneratorの出力(偽物生成データ)、zは潜在変数、xは本物データです。P_dataとP_zは確率分布でxとzがそれぞれの確率分布から生成されるとしています(詳しくは後述の「GANの数学的な考え方」)。
GANの学習アルゴリズムは以下の通り(GANの論文より引用)です。
GANの数学的な考え方
ここでは、GANの学習についてもう少し踏み込むために、GANの数学的な位置づけについて述べます。
前述したとおり、GANはディープラーニングを用いた生成モデルの一種です。ここでいう生成モデルとは、「データの生成過程を数理的にモデル化したもの」ということです。そして、数理モデルは確率分布によって表されます。つまり、何らかのデータは何らかの確率分布によって生成されると言えます。
簡単な例としてサイコロを考えます。サイコロを振った時に出る数字は、1~6どの面も1/6の確率です。これはつまりサイコロは、1/6の確率で1の目が、1/6の確率で2の目が...といったように、1~6すべての事象が1/6の確率である一様分布を確率分布とする生成モデルであると言えます。このように何らかのデータは何らかの確率分布によって生成されるといえることがわかると思います。
当然、Generatorはデータを生成するため、Generatorは暗黙的に何らかの確率分布を持っていると言えます。もちろん、学習に使う本物のデータ群も何らかの確率分布を持っていると言えます(上記の損失関数でいうところのP_data)。この本物データ群の確率分布にGeneratorの確率分布をできるだけ近づけていくというのが、GANの数学的な考え方です。
ここで損失関数を展開して計算します。
展開したところで、Gを固定してDの最適解について考えます。
ミニマックス問題なので、Dの最適解はVを最大化します。Vを最大化するには積分の中身を最大化すれば良いです。P_tとP_gは確率なので0~1、Dも0~1なので、積分の中身の式は、要は下のような上に凸なグラフになります。上凸なグラフの最大点が確率の割合でずれるだけと考えれば、Dで微分してやればいいです。
微分してやると
これが0になるDが最適解なので、式を整理してやると最適解は
となります。
次にこれを代入してGの最適解を考えると、
となり、最適なDのもとでのGの最適化は、Jensen-Shannonダイバージェンスの最適化に相当するということになります。
Jensen-Shannonダイバージェンスを軽く説明すると、2つの分布が距離としてどれだけ異なるかを測る尺度です。そして、最小値は2つの分布が等しい時で0になります。
つまり、Gの最適解ではVを最小化するのでPtとPgは等しく、V=-log4となります。よって、損失関数のGの最小化はPgをPdataに近づけることに相当し、本物データ群の確率分布にGeneratorの確率分布をできるだけ近づけていくことになるのです!
以下、イメージです(GANの論文より引用)。
ちなみに、Dの最適解
は、0.5となり本物と偽物を五分五分でしか見分けられないということになります。
実装してみた
以上のことをDCGANとして実装してみました。
GANの論文自体は、ネットワークの構成について一切述べていないので、DCGANを使いました。DCGANはGenerator、DiscriminatorをそれぞれFCNで実装したものです。
実装コードはこちら
データセットはgoogle-images-downloadを使い、Google画像検索で出てくるマリオの画像を使いました。ダウンロードした画像から個人目線で、マリオだけが写っている画像を手作業で抽出しました。その後、256*256pxに変換しています。
学習結果
個人的にいい感じで生成されていたのを抽出してみました。
初めてGANをやったので驚きです!!
もう少し綺麗な画像を出して欲しいけど、まあとりあえずここまでで 笑
参考
http://mizti.hatenablog.com/entry/2016/12/10/224426
https://qiita.com/triwave33?page=2
https://elix-tech.github.io/ja/2017/02/06/gan.html
https://qiita.com/kenmatsu4/items/b029d697e9995d93aa24
https://qiita.com/FrontPark/items/a591ad680dbb2589d4ba
https://www.slideshare.net/masa_s/gan-83975514
https://qiita.com/kzkadc/items/f49718dc8aedbe8a1bee