AI・機械学習・ディープラーニング

 

TensorFlowモデルをdeeplearn.jsに移植する(翻訳)

こんにちは、荒井(@yutakarai)です。

NOTE

本記事は、deeplearn.jsのサイトのPort TensorFlow modelsを翻訳(適宜意訳)したものです。誤り等あればご指摘いただけたら幸いです。

TensorFlowモデルをdeeplearn.jsに移植する

このチュートリアルでは、TensorFlowモデルをトレーニングしてdeeplearn.jsに移植する方法を示します。このチュートリアルで使用されているコードと必要なリソースはすべてdemos/mnistに格納されています。

MNISTデータセットの手書き数字を予測する、完全結合ニューラルネットワーク(fully connected neural network)を使用します。このコードは公式のTensorFlow MNISTチュートリアルからforkされています。

注意

deeplearn.js repo のベース・ディレクトリを $BASE として参照します。

まず最初に、deeplearn.jsリポジトリをクローンし、TensorFlowがインストールされていることを確認します。$BASEに移動(cd)して次のコマンドを実行し、モデルをトレーニングします。

python demos/mnist/fully_connected_feed.py

トレーニングには約1分かかり、/tmp/tensorflow/mnist/tensorflow/mnist/logs/fully_connected_feed/にモデルチェックポイントが格納されます。

次に、TensorFlowチェックポイントからdeeplearn.jsに重み(weight)を移植する必要があります。これを行うスクリプトを提供しています。 $BASEディレクトリから実行します。

python scripts/dump_checkpoint_vars.py --output_dir=demos/mnist/ --checkpoint_file=/tmp/tensorflow/mnist/logs/fully_connected_feed/model.ckpt-1999

このスクリプトは、demos/mnistディレクトリに一連のファイル(variableごとに1つのファイルと、manifest.json)を保存します。manifest.jsonは、変数名をファイルとその形状にマップする単純なディクショナリーです。

{
  ...,
  "hidden1/weights": {
    "filename": "hidden1_weights",
    "shape": [784, 128]
  },
  ...
}

コーディングを開始する前に、$BASEディレクトリから静的なHTTPサーバーを起動する必要があります。

npm run prep
./node_modules/.bin/http-server
>> Starting up http-server, serving ./
>> Available on:
>>   http://127.0.0.1:8080
>> Hit CTRL-C to stop the server

ブラウザでhttp://localhost:8080/demos/mnist/manifest.jsonにアクセスして、HTTP経由でmanifest.jsonにアクセスできることを確認してください。

これで、deeplearn.jsコードを書く準備が整いました。

注意

TypeScriptで記述する場合は、コードをJavaScriptにコンパイルして、静的HTTPサーバー経由で提供するようにしてください。

重み(weight)を読むには、CheckpointLoaderを作成し、manifestファイルを指し示す必要があります。次に、変数名をNDArraysにマップするディクショナリーを返すloader.getAllVariables()を呼び出します。これで、モデルを書く準備が整いました。以下は、CheckpointLoaderの使用方法を示す抜粋になります。

import {CheckpointLoader, Graph} from 'deeplearnjs';
// manifest.json is in the same dir as index.html.
const varLoader = new CheckpointLoader('.');
varLoader.getAllVariables().then(vars => {
  // Write your model here.
  const g = new Graph();
  const input = g.placeholder('input', [784]);
  const hidden1W = g.constant(vars['hidden1/weights']);
  const hidden1B = g.constant(vars['hidden1/biases']);
  const hidden1 = g.relu(g.add(g.matmul(input, hidden1W), hidden1B));
  ...
  ...
  const math = new NDArrayMathGPU();
  const sess = new Session(g, math);
  math.scope(() => {
    const result = sess.eval(...);
    console.log(result.getValues());
  });
});

完全なモデルコードの詳細については、demos/mnist/mnist.tsを参照してください。このデモでは、3つの異なるAPIを使用してMNISTモデルを正確に実装しています。

  • buildModelGraphAPI()は、TensorFlow APIを模倣したGraph APIを使用して、フィードとフェッチを遅延実行(lazy execution)します。ユーザーは、入力データ以外のGPU関連のメモリリークを心配する必要はありません。
  • buildModelLayerAPI()は、Graph APIをKeraレイヤAPIを模倣するGraph.layersと組み合わせて使用​​します。
  • buildModelMathAPI()は、Math APIを使用します。これはdeeplearn.jsの最も低いレベルのAPIであり、ユーザに最も多くの機能を与えます。数学コマンドはnumpyのようにすぐに実行されます。mathコマンドはmath.scope()に含まれ、中間のmathコマンドで作成されたNDArraysが自動的にクリーンアップされます。

このmnistデモを実行するために、変更されたときにタイプコピーコードを見て再コンパイルするwatch-demoスクリプトがあります。さらに、スクリプトは、静的なhtml/jsファイルを提供する8080上の単純なHTTPサーバーを実行します。watch-demoを実行する前に、8080ポートを解放するために、チュートリアルで前述したHTTPサーバーを終了させて​​ください。次に、$BASEから web app デモのエントリ・ポイントのdemos/mnist/mnist.tsを指すように、watch-demoを実行します。

./scripts/watch-demo demos/mnist/mnist.ts
>> Starting up http-server, serving ./
>> Available on:
>>   http://127.0.0.1:8080
>>   http://192.168.1.5:8080
>> Hit CTRL-C to stop the server
>> 1410084 bytes written to demos/mnist/bundle.js (0.91 seconds) at 5:17:45 PM

http://localhost:8080/demos/mnist/にアクセスすると、demos/mnist/sample_data.jsonに保存されているテストイメージを使用して測定された〜90%のテスト精度を示す簡単なページが表示されます。

この記事が気に入ったら
いいね ! しよう

Twitter で

関連記事

  1. AI・機械学習・ディープラーニング

    Appleの機械学習フレームワーク「Core ML」「Vision」

    こんにちは、荒井(@yutakarai)です。機械学習、流行ってい…

  2. AI・機械学習・ディープラーニング

    TensorFlow for C をインストールしよう

    こんにちは、荒井(@yutakarai)です。本記事は、Tenso…

  3. AI・機械学習・ディープラーニング

    システム障害を予測するAIプラットフォーム Blue Matador

    こんにちは、荒井です。海外のスタートアップの事例を通して、新しいビ…

IT技術と経営をつなぐコンサルティング

最近の記事

  1. ITビジネス戦略・ビジネスモデル研究

    マーケティング自動化サービス – MaaxMarket
  2. RPA・業務自動化・業務改善

    コンバージョン率最適化とは?
  3. AI・機械学習・ディープラーニング

    AIボットと会話するだけでウェブサイトをデザインできる – Fire…
  4. AI・機械学習・ディープラーニング

    深層学習ライブラリKerasの勉強に役立つコミュニティサイトまとめ(英語サイト)…
  5. AI・機械学習・ディープラーニング

    顧客サポート社員のクローンを作るAIチャットボットサービス Hyphen
PAGE TOP