Github: https://github.com/tkarras/progressive_growing_of_gans
改善された品質、安定性、およびバリエーションのためのGANの漸進的な成長
– ICLR 2018論文の公式TensorFlow実装
Tero Karras (NVIDIA)、 Timo Aila (NVIDIA)、 Samuli Laine (NVIDIA)、 Jaakko Lehtinen (NVIDIA and Aalto University)
Picture:乱数ジェネレータによって夢見ていた2人の想像上の有名人。
抽象:
我々は、生成的な対立ネットワークのための新しいトレーニング方法を説明する。 重要なアイデアは、ジェネレータとディスクリミネータの両方を徐々に成長させることです。低解像度から始めて、訓練が進むにつれてますます細かい詳細をモデル化する新しいレイヤーを追加します。 これはトレーニングのスピードアップとそれを大幅に安定化させ、前例のない品質の画像、例えば1024²のCelebA画像を生成することを可能にします。 我々はまた、生成された画像の変化を増加させる簡単な方法を提案し、教師なしのCIFAR10で記録開始スコア8.80を達成する。 さらに、ジェネレータとディスクリミネータの間の不健全な競争を阻止するために重要ないくつかの実装の詳細について説明します。 最後に、GANの結果を評価するための新しいメトリックを、画質とバリエーションの両方で提案します。 追加貢献として、CelebAデータセットのより高品質なバージョンを構築します。
リソース
- ペーパー(NVIDIA調査)
- 紙(arXiv)
- 結果ビデオ(YouTube)
- 追加資料(Googleドライブ)
- ICLR 2018ポスター(
karras2018iclr-poster.pdf
) - ICLR 2018スライド(
karras2018iclr-slides.pptx
) - 代表画像(
images/representative-images
) - 高品質のビデオクリップ(
videos/high-quality-video-clips
) - 各データセット(
images/100k-generated-images
)の非キュレーション画像の巨大なコレクション - 各データセット(
videos/one-hour-of-random-interpolations
)のランダム補間の広範なビデオは、 - 事前訓練されたネットワーク(
networks/tensorflow-version
) - 事前に訓練されたネットワークをインポートするための最小サンプルスクリプト(
networks/tensorflow-version/example_import_script
) - CelebA-HQデータセット(
datasets/celeba-hq-deltas
)を再構築するために必要なデータファイル - トレーニングログと進行状況スナップショットの例(
networks/tensorflow-version/example_training_runs
)
- ICLR 2018ポスター(
ソースコードを含むすべての素材は、クリエイティブコモンズのCC BY-NC 4.0ライセンスの下で、非商用利用のために自由に利用可能にされています。 私たちの論文のタイトルと著者リストに言及して適切なクレジットを与えている限り、自分の作品の中のどの素材も自由に使用してください。
バージョン
ソースコードには2つの異なるバージョンがあります。 TensorFlowのバージョンはより新しいもので、より洗練されています。私たちのテクニックを試してみたり、構築したり、新しいデータセットに適用したい場合は、一般的に開始点としてお勧めします。 一方、 オリジナルのTheanoバージョンは 、私たちの論文に示されているすべての結果を生成するために使用したものです。 CIFAR-10、MNIST-RGB、CelebAなどのベンチマークデータセットの正確な結果を再現したい場合にのみ使用することをお勧めします。
主な相違点を次の表にまとめます。
特徴 | TensorFlowバージョン | オリジナルのTheanoバージョン |
---|---|---|
ブランチ | マスター (このブランチ) | オリジナル – テアノ – バージョン |
マルチGPUサポート | はい | いいえ |
FP16混在精度サポート | はい | いいえ |
パフォーマンス | 高い | 低い |
CelebA-HQのトレーニング時間 | 2日間(8 GPU) 2週間(1 GPU) |
1-2ヶ月 |
Repro CelebA-HQの結果 | はい – 非常に近い | はい – 同一 |
Repro LSUNの結果 | はい – 非常に近い | はい – 同一 |
Repro CIFAR-10の結果 | いいえ | はい – 同一 |
Repro MNISTモード回復 | いいえ | はい – 同一 |
再切断アブレーション試験(表1) | いいえ | はい – 同一 |
データセット形式 | TFRecords | HDF5 |
下位互換性 | ネットワークをインポートできます テアノと訓練された |
N / A |
コード品質 | 合理的 | やや厄介な |
コードステータス | アクティブな使用 | もう維持されない |
システム要求
- LinuxとWindowsの両方がサポートされていますが、パフォーマンスと互換性の理由から、Linuxを強くお勧めします。
- 1.31.3またはそれ以降のnumpyの64ビットPython 3.6インストール。 Anaconda3をお勧めします。
- 16GBのDRAMを搭載した1つ以上のハイエンドNVIDIA PascalまたはVolta GPU。 8つのTesla V100 GPUを搭載したNVIDIA DGX-1をお勧めします。
- NVIDIAドライバ391.25以降、CUDAツールキット9.0以降、cuDNN 7.1.2以降。
-
requirements-pip.txt
リストされているその他のPythonパッケージ
事前に訓練されたネットワークのインポートと使用
Googleドライブにあるトレーニング済みのネットワークはすべて、トレーニングスクリプトで作成されたネットワークと同様に、Python PKLファイルとして保存されます。 (1)プログレッシブGANコードリポジトリを含むディレクトリは、PYTHONPATH環境変数に含める必要があります。(2) tf.Session()
オブジェクトには、 tf.Session()
2つの条件が満たされている必要があります。事前に作成され、デフォルトとして設定されています。 各PKLファイルには、 tfutil.Network
3つのインスタンスが含まれています。
# Import official CelebA-HQ networks.
with open('karras2018iclr-celebahq-1024x1024.pkl', 'rb') as file:
G, D, Gs = pickle.load(file)
# G = Instantaneous snapshot of the generator, mainly useful for resuming a previous training run.
# D = Instantaneous snapshot of the discriminator, mainly useful for resuming a previous training run.
# Gs = Long-term average of the generator, yielding higher-quality results than the instantaneous snapshot.
また、TensorFlowバージョン(ミニバッチ識別、バッチ標準化など)によってネイティブにサポートされていない機能を使用しない限り、Theano実装を使用して生成されたネットワークをインポートすることもできます。 ただし、Theanoネットワークのインポートを有効にするには、 misc.load_pkl()
代わりにpickle.load()
使用する必要があります。
# Import Theano versions of the official CelebA-HQ networks.
import misc
G, D, Gs = misc.load_pkl('200-celebahq-1024x1024/network-final.pkl')
ネットワークをインポートしたら、 Gs.run()
を呼び出して潜在ベクトルの画像セットを生成するか、 Gs.get_output_for()
を呼び出してより大きなTensorFlow式にジェネレータネットワークを組み込むことができます。 詳細については、Googleドライブのサンプルスクリプトをご覧ください。 指示:
- プログレッシブGANコードリポジトリを取り出し、PYTHONPATH環境変数に追加します。
- 必要なPythonパッケージを
pip install -r requirements-pip.txt
-
networks/tensorflow-version/example_import_script
からimport_example.py
をダウンロードしてimport_example.py
-
networks/tensorflow-version
からkarras2018iclr-celebahq-1024x1024.pkl
ダウンロードし、スクリプトと同じディレクトリに置きます。 -
python import_example.py
スクリプトを実行する - すべてがうまくいけば、スクリプトは
networks/tensorflow-version/example_import_script
正確に一致する10個のPNG画像(img0.png
–img9.png
)を生成する必要がありnetworks/tensorflow-version/example_import_script
。
トレーニングのためのデータセットの準備
プログレッシブGANコードリポジトリには、この論文で使用したデータセットのビット正確なレプリカを再作成するためのコマンドラインツールが含まれています。 このツールは、データセットを操作するためのさまざまなユーティリティも提供します。
usage: dataset_tool.py [-h] <command> ...
display Display images in dataset.
extract Extract images from dataset.
compare Compare two datasets.
create_mnist Create dataset for MNIST.
create_mnistrgb Create dataset for MNIST-RGB.
create_cifar10 Create dataset for CIFAR-10.
create_cifar100 Create dataset for CIFAR-100.
create_svhn Create dataset for SVHN.
create_lsun Create dataset for single LSUN category.
create_celeba Create dataset for CelebA.
create_celebahq Create dataset for CelebA-HQ.
create_from_images Create dataset from a directory full of images.
create_from_hdf5 Create dataset from legacy HDF5 archive.
Type "dataset_tool.py <command> -h" for more information.
データセットは、効率的なストリーミングを可能にするために、いくつかの解像度で同じ画像データを含むディレクトリによって表されます。 分解能ごとに別々の*.tfrecords
ファイルがあり、データセットにラベルが含まれている場合は、それらも別のファイルに保存されます。
> python dataset_tool.py create_cifar10 datasets/cifar10 ~/downloads/cifar10
> ls -la datasets/cifar10
drwxr-xr-x 2 user user 7 Feb 21 10:07 .
drwxrwxr-x 10 user user 62 Apr 3 15:10 ..
-rw-r--r-- 1 user user 4900000 Feb 19 13:17 cifar10-r02.tfrecords
-rw-r--r-- 1 user user 12350000 Feb 19 13:17 cifar10-r03.tfrecords
-rw-r--r-- 1 user user 41150000 Feb 19 13:17 cifar10-r04.tfrecords
-rw-r--r-- 1 user user 156350000 Feb 19 13:17 cifar10-r05.tfrecords
-rw-r--r-- 1 user user 2000080 Feb 19 13:17 cifar10-rxx.labels
create_*
コマンドは、指定されたデータセットの標準バージョンを入力として受け取り、対応する*.tfrecords
ファイルを出力として生成します。 さらに、 create_celebahq
コマンドは、元のCelebAデータセットに関するデルタを表すデータファイルのセットを必要とします。 これらのデルタ(27.6GB)は、 datasets/celeba-hq-deltas
からダウンロードできます。
モジュールのバージョンに関する注意 :データセットコマンドの中には、特定バージョンのPythonモジュールやシステムライブラリ(例:pillow、libjpeg)が必要なものがあります。バージョンが一致しないとエラーになります。 エラーメッセージに注意してください。これらの特定のバージョンをインストールする以外の方法でコマンドを実行する方法はありません 。
トレーニングネットワーク
必要なデータセットが設定されたら、自分のネットワークをトレーニングすることができます。 一般的な手順は次のとおりです。
-
config.py
を編集して、特定の行のコメントを解除/編集して、データセットとトレーニング設定を指定します。 -
python train.py
トレーニングスクリプトを実行します。 - 結果は、
config.result_dir
下に新しく作成されたサブディレクトリにconfig.result_dir
れconfig.result_dir
- 収束するために数日間(または数週間)待ってから、結果を分析します。
デフォルトでconfig.py
は、CelebA-HQ用の1024×1024ネットワークを単一GPUを使用してトレーニングするように設定されています。 これは最高のNVIDIA GPUでも約2週間かかると予想されています。 より速いトレーニングを可能にする鍵は、複数のGPUを採用したり、低解像度のデータセットを使用することです。 このため、 config.py
は、一般的に使用されるデータセットの複数の例と、マルチGPUトレーニングのための一連の「設定プリセット」が含まれています。 すべてのプリセットはCelebA-HQとほぼ同じ画質を得ることが期待されていますが、トレーニング時間は大幅に異なる場合があります。
-
preset-v1-1gpu
:論文に示されているCelebA-HQとLSUN結果を生成するために使用された元の設定。 NVIDIA Tesla V100で約1ヶ月かかると予想されています。 -
preset-v2-1gpu
:元のものよりもかなり高速に収束する最適化preset-v2-1gpu
。 1xV100で約2週間かかると予想されます。 -
preset-v2-2gpus
:2 GPUの最適化preset-v2-2gpus
。 2xV100で約1週間かかります。 -
preset-v2-4gpus
:4つのGPUのための最適化preset-v2-4gpus
。 4xV100で約3日間かかります。 -
preset-v2-8gpus
:8 GPUの最適化preset-v2-8gpus
。 8xV100で約2日間かかります。
参考までに、CelebA-HQの各設定プリセットの予想出力は、 networks/tensorflow-version/example_training_runs
他の注目すべき設定オプション:
-
fp16
: FP16混合精度トレーニングを有効にして、トレーニング時間をさらに短縮します。 実際のスピードアップはGPUアーキテクチャとcuDNNのバージョンに大きく依存しており、将来的にはかなり増加すると予想されます。 -
BENCHMARK
:決断を素早く繰り返して、生の訓練のパフォーマンスを測定します。 -
BENCHMARK0
:BENCHMARK0
同じですが、最高解像度のみを使用します。 -
syn1024rgb
:黒い画像だけからなる合成の1024×1024データセット。 ベンチマークに役立ちます。 -
VERBOSE
:イメージとネットワークのスナップショットを非常に頻繁に保存して、デバッグを容易にします。 -
GRAPH
andHIST
:TensorBoardレポートに追加データを含めます。
結果の分析
トレーニング結果は、いくつかの方法で分析できます。
- 手動検査 :トレーニングスクリプトは、ランダムに生成されたイメージのスナップショットを
fakes*.png
定期的に保存し、log.txt
全体的な進捗状況を報告します。 - TensorBoard :トレーニングスクリプトは、テンソル
tensorboard --logdir <result_subdir>
でTensorBoardで視覚化できる*.tfevents
ファイル内のさまざまな実行統計を*.tfevents
ます。 - 画像と動画の生成 :
config.py
の最後には、ユーティリティスクリプト(generate_*
)を起動するためのいくつかの事前定義された設定があります。 例えば:-
010-pgan-celebahq-preset-v1-1gpu-fp32
と題された進行中のトレーニングがあり、最新のスナップショットのランダムな補間のビデオを生成したいとします。 -
config.py
のgenerate_interpolation_video
行のコメントを外し、run_id=10
置き換えて、python train.py
を実行しpython train.py
- スクリプトは自動的に最新のネットワークスナップショットを探し、単一のMP4ファイルを含む新しい結果ディレクトリを作成します。
-
- 品質メトリクス :前の例と同様に、
config.py
は、既存のトレーニングの実行時にさまざまな品質メトリック(スライスされたワッサースタイン距離、フレーズ開始距離など)を計算するための事前定義された設定が含まれています。 メトリックは、ネットワークスナップショットごとに連続して計算され、元の結果ディレクトリのmetric-*.txt
に格納されmetric-*.txt
。