*

Prediction API入門(後編)

公開日: : 投稿者: GAE, GCP

今回はPrediction API on GAE/J

みなさん、こんにちは。Prediction API入門もついに後編となりました。

前編ではPrediction API概要、中編ではAPIs Explorerを使った実行方法を紹介しましたが、今回はClient Library for Javaを使って、GAE/JでPredictoin APIを利用するデモプログラムを作成しました。本記事では、デモプログラムのご紹介とPrediction APIを使っているロジック部分のコードの解説を行います。

ここでは、開発環境として、Eclipse 4.3(Kepler) + Google Pluginを前提とします。

前編、中編の記事はこちら
Prediction API入門(前編)
Prediction API入門(中編)

デモプログラムについて

以下のURLにアクセスすると、デモプログラムを閲覧することができます。

https://z-hida-predictionapi.appspot.com/#/prediction-result

このプログラムでは、GAEやGCPに関するTweetの内容を、Prediction APIを用いて、Good(良い評価をしている)、Neutral(中立)、Bad(悪い評価をしている)の3つに分類した結果を表示しています。ただし、学習データがいい加減なので、結果の精度はイマイチですが、あしからずご了承ください。

ちなみに、予測結果はDatastore上に保存してあり、上記URLにアクセスしたときは、Datastoreから値を取り出して表示しています。

GCPプロジェクトの設定

今回、紹介するコードが動くためには、いくつか準備が必要となりますので、その手順を説明します。

Prediction APIを使う準備

中編の記事を参考に、学習データの用意やPrediction APIの有効化を行ってください。

サービスアカウントの作成

デモプログラムではサービスアカウントを作成して、そのアカウントでAPIを叩く方法を採用しています。ここでは、サービスアカウントの作成方法を説明します。

Google Developers Consoleでプロジェクトを選択し、左のメニューの「認証情報」をクリックしてください。次に、「新しいクライアントIDを作成」をボタンをクリックしてください。

predictionapi3-credentials1

すると、ダイアログが表示されるので、「サービスアカウント」を選択し、「クライアントIDを作成」ボタンをクリックしてください。

predictionapi3-credentials2

これにより、秘密鍵(拡張子が.p12となっているファイル)がダウンロードされ、ダイアログに秘密鍵のパスワードが表示されます。「OK」ボタンをクリックしてください。

predictionapi3-credentials3

ダイアログが閉じて、作成されたサービスアカウントの情報が表示されます。

predictionapi3-credentials4

作成された秘密鍵、秘密鍵のパスワード、サービスアカウントのメールアドレスは後で必要となります。

Eclipseプロジェクトの作成とライブラリの追加

まず、GAEプロジェクトを作成し、WEB-INFフォルダ直下に、サービスアカウントの秘密鍵をコピーしてください。

次にClient Libraryをビルドパスに追加します。NavigatorまたはPackege Explorerで、プロジェクトを右クリックし、Google > Add Google APIs…を選択してください。ダイアログが表示されるので、検索窓にPredictionと入力して絞りこみを行い、Prediction API v1.6を選択して、「OK」ボタンをクリックしてください。

predictionapi3-eclipse1

PredictionLogicクラスの作成

さて、ここからは、今回作成したサンプルの中で、Prediction APIを叩く肝となる部分のコードの紹介と解説を行っていきます。ここでは、PredictionLogicというクラスを作成し、その中に、Prediction APIを呼び出すメソッドを実装していきます。

import文、クラス宣言、定数の宣言

PROJECT_ID, ACCOUNT_ID, PRIVATE_KEY_PATH, P12_PASSWORDの値は適宜、読み替えてください。

[java]
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.UnrecoverableKeyException;
import java.security.cert.CertificateException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import com.google.api.client.extensions.appengine.http.UrlFetchTransport;
import com.google.api.client.googleapis.auth.oauth2.GoogleCredential;
import com.google.api.client.googleapis.batch.BatchRequest;
import com.google.api.client.googleapis.batch.json.JsonBatchCallback;
import com.google.api.client.googleapis.json.GoogleJsonError;
import com.google.api.client.http.HttpHeaders;
import com.google.api.client.http.HttpTransport;
import com.google.api.client.json.jackson2.JacksonFactory;
import com.google.api.services.prediction.Prediction;
import com.google.api.services.prediction.PredictionScopes;
import com.google.api.services.prediction.model.Input;
import com.google.api.services.prediction.model.Input.InputInput;
import com.google.api.services.prediction.model.Insert;
import com.google.api.services.prediction.model.Output;

public class PredictionLogic {
private static final String PROJECT_ID = "GCPプロジェクトのID";
private static final String ACCOUNT_ID ="サービスアカウントのメールアドレス";
private static final String PRIVATE_KEY_PATH = "WEB-INF/p12ファイル";
private static final String P12_PASSWORD = "p12ファイルのパスワード";

private static final String APPLICATION_NAME = "GAE_PREDCTION_EXAMPLE";
private static final String KEY_STYLE = "PKCS12";
private static final String KEY_ALIAS = "privateKey";

private static final HttpTransport HTTP_TRANSPORT = new UrlFetchTransport();
private static final JacksonFactory JSON_FACTORY = JacksonFactory.getDefaultInstance();
//(中略)
}
[/java]

サービスオブジェクトを生成するメソッド

まず、このクラスに、サービスアカウントを使って、Prediction APIを呼び出すためのサービスクラス( com.google.api.services.prediction.Prediction)のインスタンスを生成するメソッドを実装します。

[java]
@SuppressWarnings("resource")
private PrivateKey loadPrivateKey() throws KeyStoreException, NoSuchAlgorithmException,
CertificateException, IOException, UnrecoverableKeyException {
File file = new File(PRIVATE_KEY_PATH);
InputStream is = new FileInputStream(file);
KeyStore keystore = KeyStore.getInstance(KEY_STYLE);
keystore.load(is, P12_PASSWORD.toCharArray());
return (PrivateKey) keystore.getKey(KEY_ALIAS, P12_PASSWORD.toCharArray());
}

private Prediction createPredictionService() throws UnrecoverableKeyException, KeyStoreException,
NoSuchAlgorithmException, CertificateException, IOException {
GoogleCredential credential =
new GoogleCredential.Builder()
.setTransport(HTTP_TRANSPORT)
.setJsonFactory(JSON_FACTORY)
.setServiceAccountId(ACCOUNT_ID)
.setServiceAccountScopes(
Arrays.asList(PredictionScopes.PREDICTION, PredictionScopes.DEVSTORAGE_FULL_CONTROL))
.setServiceAccountPrivateKey(laodPrivateKey())
.build();

return new com.google.api.services.prediction.Prediction.Builder(
HTTP_TRANSPORT,
JSON_FACTORY,
credential).setApplicationName(APPLICATION_NAME).build();
}
[/java]

loadPrivateKeyメソッドは秘密鍵ファイルを読み込んで、PrivateKey型で返すメソッドで、createPredictionServiceメソッドの中で呼ばれます。

createPredictionServiceメソッドでは、まず、GoogleCredentialオブジェクトを生成しています。このオブジェクトは、文字通り、認証に必要な情報(サービスアカウントのID、秘密鍵、使用するAPIのスコープなど)を保持します。このオブジェクトをPrediction.Builderクラスに渡して、buildメソッドを呼ぶことで、Predicitonオブジェクトを生成します。

Insert

Prediction APIのinsertメソッドを実行するメソッドを追加します。

[java]
public void insert(String modelId, String csvPath) throws IOException,
UnrecoverableKeyException, KeyStoreException, NoSuchAlgorithmException, CertificateException {
Prediction prediction = createPredictionService();

Insert insert = new Insert();
insert.setId(modelId); //好きなmodelIdを指定
insert.setStorageDataLocation(csvPath); // csvPathは”バケット名/ファイルパス”という形式
prediction.trainedmodels().insert(PROJECT_ID, insert).execute();
}
[/java]

この例では、GCS上のCSVファイルを学習データとして、モデルを生成しています。Insertオブジェクトを生成し、idと学習データのパスを指定をセットした後、学習処理を実行しています。

predict

Prediction APIのpredictメソッドを実行し、1件の予測を行うメソッドを追加します。

[java]
public Output predict(String modelId, String text)
throws UnrecoverableKeyException, KeyStoreException, NoSuchAlgorithmException,
CertificateException, IOException {
Prediction prediction = createPredictionService();

InputInput inputInput = new InputInput();
inputInput.setCsvInstance(Collections.<Object> singletonList(text)); //予測対象のデータをセット
Input input = new Input();
input.setInput(inputInput);

return prediction.trainedmodels().predict(PROJECT_ID, modelId, input).execute();
}
[/java]

InputInputオブジェクトに、予測したいデータをセットし、さらにそれをInputオブジェクトにセットした後、予測処理を実行しています。

今回の予測対象データの構造は、テキスト形式1カラムのみですので、予測したいテキストだけを含むListをsetCsvInstanceでセットしています。学習データが複数カラムの場合は、予測対象データのListの順序が学習データと同じになるように気をつけてください。

予測結果はOutputクラスのオブジェクトとして格納されます。このクラスには、APIs Explorerでpredictを実行して確認できるJSON(中編を参照)に対応した形で、結果を保持しています。例えば、今回のようなラベル予測の場合、getOutputLabel()メソッドやgetOutputMulti()で、予測結果を取得できます。数値を予測する回帰予測の場合は、getOutputValue()で予測値を取得できます。

上記の例は、1件の対象データに対して、予測するものでした。次は、複数件の予測を1回のバッチ処理で行うメソッドを追加します。PredictionCallbackという内部クラスを定義し、そのクラスのメンバ変数resultに、1件1件に対する予測結果のリストを保持させていることに注意してください。

[java]
private static class PredictionCallback extends JsonBatchCallback<Output> {
private List<Output> results = new ArrayList<>();

@Override
public void onSuccess(Output o, HttpHeaders responseHeaders) throws IOException {
results.add(o);
}
@Override
public void onFailure(GoogleJsonError e, HttpHeaders responseHeaders) throws IOException {
results.add(null);
}

public List<Output> getResults() {
return results;
}
}

private List<Output> predictInSingleBatch(String modelId, List<String> texts)
throws UnrecoverableKeyException, KeyStoreException, NoSuchAlgorithmException,
CertificateException, IOException {
Prediction prediction = createPredictionService();
BatchRequest batch = prediction.batch();
PredictionCallback callback = new PredictionCallback();
for (String text : texts) {
InputInput inputInput = new InputInput();
inputInput.setCsvInstance(Collections.<Object> singletonList(text));
Input input = new Input();
input.setInput(inputInput);
prediction.trainedmodels().predict(PROJECT_ID, modelId, input).queue(batch, callback);
}

batch.execute(); //この処理を実行後、callbackに結果が格納される
return callback.getResults();
}
[/java]

上記の例では、batch.execute()を実行し終わると、変数callbackに予測結果が格納されます。1件1件の予測結果について、PrectionCallbackクラスのonSuccess()メソッドまたはonFailure()メソッドの処理が実行され、resultsに予測結果がaddされるというわけです。

なお、この例では、onFailure()メソッドでは、とりあえずresultsにnullを追加していますが、実際に使うときは、エラーメッセージを格納する変数を定義して、そこにエラーメッセージをセットするなどの処理を実装すると良いと思います。

まとめ

以上の解説を見ていただいてClient Libraryを用いることで、簡単にPrediction APIを利用できることがわかったかと思います。今回、ご紹介したコードではinsertとpredictしか使っていませんが、他のメソッドも同じくらい簡単です。

面白いアイデアがある方は、早速、Prediction APIを使ったWEBアプリケーションを作ってはいかが?

前編、中編の記事はこちら
Prediction API入門(前編)
Prediction API入門(中編)

この記事を書いた人

hida
雑食エンジニア/にわかSF者

関連記事

Prediction API入門(中編)

Prediction API入門実行編 みなさん、こんにちは。 前編では、Google Pr

記事を読む

機械学習最前線!Cloud Machine Learning を始めてみた!

最近、機械学習・深層学習というワードをよく耳にするようになりましたね。機械学習は知っているけど、

記事を読む

GCP×Zencoderで始める動画トランスコーディング

はじめに 以前apps-gcpでは、GCP(GoogleCloudPlatform)と連携可能なA

記事を読む

1つのエンティティにプロパティをいくつまで作れるか

1つのエンティティにプロパティをいくつまで作れるか 1つのエンティティにプロパティをいくつまで

記事を読む

2015/09/16 GCE vs AWS ベンチマーク

2015/09/16 GCE vs AWS ベンチマーク 本シリーズでは定期的に

記事を読む

2014/12/18 GCE vs AWS ベンチマーク

2014/12/18 GCE vs AWS ベンチマーク 本シリーズでは定期的にGCEとEC2のベ

記事を読む

Search API詳細解説 Part4「Search API 詳細 検索性能編」

Search API詳細解説シリーズ タイトル Part1Search API 概要説明

記事を読む

GCEはどこまでディスクを拡張できるのか?手順は?

GCEではディスク容量を自由に設定することができますが、どこまで容量が増やすことができるのか

記事を読む

Search API詳細解説 Part5「Search API 詳細 反映速度編」

Search API詳細解説シリーズ タイトル Part1Search API 概要説明

記事を読む

Google App Engine コンソールの小技

みなさんご存知、GAEのコンソールにはさまざまな機能があります。 あなたがデプロイしたアプリケ

記事を読む

PAGE TOP ↑