GitHubじゃ!Pythonじゃ!

GitHubからPython関係の優良リポジトリを探したかったのじゃー、でも英語は出来ないから日本語で読むのじゃー、英語社会世知辛いのじゃー

MrGemy95

Tensorflow-Project-Template – テンソルフロープロジェクトテンプレートアーキテクチャのベストプラクティス

投稿日:

テンソルフロープロジェクトテンプレートアーキテクチャのベストプラクティス。

Tensorflowプロジェクトテンプレート

ディープラーニングプロジェクトには、シンプルでよくデザインされた構造が不可欠です。テンソルフロープロジェクトでは、たくさんの練習をしてから、 フォルダー構造優れたOOPデザインの シンプルさベストプラクティスを組み合わせたテンソルフローテンプレートが用意さ れています 主な考え方は、テンソルフロープロジェクトを開始するたびに行うことが多いので、この共有されたものをすべてラップすると、新しいテンソルフロープロジェクトを開始するたびにコアアイデアを変更するのに役立ちます。

したがって、ここでは簡単なテンソルフローテンプレートを使用して、メインプロジェクトに素早くアクセスし、コア(モデル、トレーニングなど)に集中するのに役立ちます。

目次

一言で言えば

要約すると、ここではこのテンプレートを使用する方法です。たとえば 、VGGモデルを実装したいと仮定して、次のようにする必要があります。

  • モデルフォルダでは、 “base_model”クラスを継承するVGGという名前のクラスを作成します
    class VGGModel(BaseModel):
        def __init__(self, config):
            super(VGGModel, self).__init__(config)
            #call the build_model and init_saver functions.
            self.build_model() 
            self.init_saver() 
  • vggモデルを実装するこれらの2つの関数 “build_model”と、テンソルフローセーバーを定義する “init_saver”をオーバーライドしてから、それらをinitalizerで呼び出します。
     def build_model(self):
        # here you build the tensorflow graph of any model you want and also define the loss.
        pass
            
     def init_saver(self):
        # here you initalize the tensorflow saver that will be used in saving the checkpoints.
        self.saver = tf.train.Saver(max_to_keep=self.config.max_to_keep)
  • トレーナーフォルダで、 “base_train”クラスを継承するVGGトレーナーを作成します
    class VGGTrainer(BaseTrain):
        def __init__(self, sess, model, data, config, logger):
            super(VGGTrainer, self).__init__(sess, model, data, config, logger)
  • これら2つの関数 “train_step”、 “train_epoch”をオーバーライドして、トレーニングプロセスのロジックを記述します
    def train_epoch(self):
        """
       implement the logic of epoch:
       -loop on the number of iterations in the config and call the train step
       -add any summaries you want using the summary
        """
        pass

    def train_step(self):
        """
       implement the logic of the train step
       - run the tensorflow session
       - return any metrics you need to summarize
       """
        pass
  • メインファイルでは、 “Model”、 “Logger”、 “Data_Generator”、 “Trainer”、およびconfigというオブジェクトのセッションとインスタンスを作成します
    sess = tf.Session()
    # create instance of the model you want
    model = VGGModel(config)
    # create your data generator
    data = DataGenerator(config)
    # create tensorboard logger
    logger = Logger(sess, config)
  • これらすべてのオブジェクトをトレーナーオブジェクトに渡し、 “trainer.train()”を呼び出してトレーニングを開始します。
    trainer = VGGTrainer(sess, model, data, config, logger)

    # here you train your model
    trainer.train()

モデルファイルとトレーナーフォルダーにテンプレートファイルと簡単な例があります。これは、最初のモデルを簡単に試してみる方法を示しています。

詳細に

プロジェクトのアーキテクチャ

フォルダ構造

├──  base
│   ├── base_model.py   - this file contains the abstract class of the model.
│   └── base_train.py   - this file contains the abstract class of the trainer.
│
│
├── model               - this folder contains any model of your project.
│   └── example_model.py
│
│
├── trainer             - this folder contains trainers of your project.
│   └── example_trainer.py
│   
├──  mains              - here's the main(s) of your project (you may need more than one main).
│    └── example_main.py  - here's an example of main that is responsible for the whole pipeline.

│  
├──  data _loader  
│    └── data_generator.py  - here's the data_generator that is responsible for all data handling.
│ 
└── utils
     ├── logger.py
     └── any_other_utils_you_need

メインコンポーネント

モデル


  • ベースモデル

    ベースモデルは、あなたが作成したモデルによって継承されなければならない抽象クラスです。このモデルの背後にあるアイデアは、すべてのモデル間で共有されていることです。 ベースモデルに含まれるもの:

    • 保存 – チェックポイントをデスクに保存するこの機能。
    • ロード – デスクからチェックポイントを読み込むためのこの機能。
    • Cur_epoch、Global_step counters –これらの変数は、現在のエポックとグローバルステップを追跡します。
    • Init_Saverチェックポイントの保存とロードに使用されるセーバーを初期化する抽象関数。 注意 :実装するモデルでこの関数をオーバーライドします。
    • Build_modelモデルを定義するための抽象関数です。 :実装したいモデルでこの関数をオーバーライドします。
  • あなたのモデル

    ここにモデルを実装する場所があります。 だからあなたは:

    • モデルクラスを作成し、base_modelクラスを継承する
    • あなたが望むテンソルフローモデルを書く “build_model”をオーバーライドする
    • テンソルフローセーバーを作成してチェックポイントの保存と読み込みに使用する “init_save”をオーバーライドする
    • イニシャライザの “build_model”と “init_saver”を呼び出します。

トレーナー


  • ベーストレーナー

    ベーストレーナーは、トレーニングプロセスをラップする抽象クラスです。

  • あなたのトレーナー

    あなたのトレーナーに実装する必要があるものは次のとおりです。

    1. トレーナークラスを作成し、base_trainerクラスを継承します。
    2. これらの2つの関数 “train_step”、 “train_epoch”をオーバーライドして、各ステップと各エポックのトレーニングプロセスを実装します。

データローダ

このクラスは、すべてのデータ処理と処理を担当し、トレーナーが使用できる簡単なインターフェースを提供します。

ロガー

このクラスはテンソルボードサマリーを担当し、トレーナーで要約したいすべてのテンソルフロー変数の辞書を作成し、この辞書をlogger.summarize()に渡します。

構成

私は設定方法としてJsonを使い、それを解析するので、必要なすべての設定を書き、 “utils / config / process_config”を使って解析し、この設定オブジェクトを他のすべてのオブジェクトに渡します。

メイン

ここで、前のすべての部分を結合します。

  1. 設定ファイルを解析します。
  2. テンソルフローセッションを作成します。
  3. “Model”、 “Data_Generator”、 “Logger”のインスタンスを作成し、すべての設定を解析します。
  4. 「トレーナー」のインスタンスを作成し、以前のすべてのオブジェクトをそのインスタンスに渡します。
  5. これで、 “Trainer.train()”を呼び出してモデルをトレーニングすることができます

今後の仕事

  • データローダ部分を新しいテンソルフローデータセットAPIに置き換えます。

貢献する

どんな種類の増強や貢献も歓迎されます。

謝辞

私の同僚Mo’men Abdelrazekがこの作業に貢献してくれてくれてありがとう。 Mohamed Zahranに感謝の意を表します。 Awesome Tensorflowにレポを含めてくれて、Jtoyに感謝します。







-MrGemy95
-, , , , , , , , , ,

執筆者: