Prediction API入門(後編)

  • このエントリーをはてなブックマークに追加

ご注意

この記事は 2015年3月31日 に書かれたものです。内容が古い可能性がありますのでご注意ください。

今回はPrediction API on GAE/J

みなさん、こんにちは。Prediction API入門もついに後編となりました。

前編ではPrediction API概要、中編ではAPIs Explorerを使った実行方法を紹介しましたが、今回はClient Library for Javaを使って、GAE/JでPredictoin APIを利用するデモプログラムを作成しました。本記事では、デモプログラムのご紹介とPrediction APIを使っているロジック部分のコードの解説を行います。

ここでは、開発環境として、Eclipse 4.3(Kepler) + Google Pluginを前提とします。

前編、中編の記事はこちら
Prediction API入門(前編)
Prediction API入門(中編)

デモプログラムについて

以下のURLにアクセスすると、デモプログラムを閲覧することができます。

https://z-hida-predictionapi.appspot.com/#/prediction-result

このプログラムでは、GAEやGCPに関するTweetの内容を、Prediction APIを用いて、Good(良い評価をしている)、Neutral(中立)、Bad(悪い評価をしている)の3つに分類した結果を表示しています。ただし、学習データがいい加減なので、結果の精度はイマイチですが、あしからずご了承ください。

ちなみに、予測結果はDatastore上に保存してあり、上記URLにアクセスしたときは、Datastoreから値を取り出して表示しています。

GCPプロジェクトの設定

今回、紹介するコードが動くためには、いくつか準備が必要となりますので、その手順を説明します。

Prediction APIを使う準備

中編の記事を参考に、学習データの用意やPrediction APIの有効化を行ってください。

サービスアカウントの作成

デモプログラムではサービスアカウントを作成して、そのアカウントでAPIを叩く方法を採用しています。ここでは、サービスアカウントの作成方法を説明します。

Google Developers Consoleでプロジェクトを選択し、左のメニューの「認証情報」をクリックしてください。次に、「新しいクライアントIDを作成」をボタンをクリックしてください。

predictionapi3-credentials1

すると、ダイアログが表示されるので、「サービスアカウント」を選択し、「クライアントIDを作成」ボタンをクリックしてください。

predictionapi3-credentials2

これにより、秘密鍵(拡張子が.p12となっているファイル)がダウンロードされ、ダイアログに秘密鍵のパスワードが表示されます。「OK」ボタンをクリックしてください。

predictionapi3-credentials3

ダイアログが閉じて、作成されたサービスアカウントの情報が表示されます。

predictionapi3-credentials4

作成された秘密鍵、秘密鍵のパスワード、サービスアカウントのメールアドレスは後で必要となります。

Eclipseプロジェクトの作成とライブラリの追加

まず、GAEプロジェクトを作成し、WEB-INFフォルダ直下に、サービスアカウントの秘密鍵をコピーしてください。

次にClient Libraryをビルドパスに追加します。NavigatorまたはPackege Explorerで、プロジェクトを右クリックし、Google > Add Google APIs…を選択してください。ダイアログが表示されるので、検索窓にPredictionと入力して絞りこみを行い、Prediction API v1.6を選択して、「OK」ボタンをクリックしてください。

predictionapi3-eclipse1

PredictionLogicクラスの作成

さて、ここからは、今回作成したサンプルの中で、Prediction APIを叩く肝となる部分のコードの紹介と解説を行っていきます。ここでは、PredictionLogicというクラスを作成し、その中に、Prediction APIを呼び出すメソッドを実装していきます。

import文、クラス宣言、定数の宣言

PROJECT_ID, ACCOUNT_ID, PRIVATE_KEY_PATH, P12_PASSWORDの値は適宜、読み替えてください。

[java]
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.UnrecoverableKeyException;
import java.security.cert.CertificateException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import com.google.api.client.extensions.appengine.http.UrlFetchTransport;
import com.google.api.client.googleapis.auth.oauth2.GoogleCredential;
import com.google.api.client.googleapis.batch.BatchRequest;
import com.google.api.client.googleapis.batch.json.JsonBatchCallback;
import com.google.api.client.googleapis.json.GoogleJsonError;
import com.google.api.client.http.HttpHeaders;
import com.google.api.client.http.HttpTransport;
import com.google.api.client.json.jackson2.JacksonFactory;
import com.google.api.services.prediction.Prediction;
import com.google.api.services.prediction.PredictionScopes;
import com.google.api.services.prediction.model.Input;
import com.google.api.services.prediction.model.Input.InputInput;
import com.google.api.services.prediction.model.Insert;
import com.google.api.services.prediction.model.Output;

public class PredictionLogic {
private static final String PROJECT_ID = "GCPプロジェクトのID";
private static final String ACCOUNT_ID ="サービスアカウントのメールアドレス";
private static final String PRIVATE_KEY_PATH = "WEB-INF/p12ファイル";
private static final String P12_PASSWORD = "p12ファイルのパスワード";

private static final String APPLICATION_NAME = "GAE_PREDCTION_EXAMPLE";
private static final String KEY_STYLE = "PKCS12";
private static final String KEY_ALIAS = "privateKey";

private static final HttpTransport HTTP_TRANSPORT = new UrlFetchTransport();
private static final JacksonFactory JSON_FACTORY = JacksonFactory.getDefaultInstance();
//(中略)
}
[/java]

サービスオブジェクトを生成するメソッド

まず、このクラスに、サービスアカウントを使って、Prediction APIを呼び出すためのサービスクラス( com.google.api.services.prediction.Prediction)のインスタンスを生成するメソッドを実装します。

[java]
@SuppressWarnings("resource")
private PrivateKey loadPrivateKey() throws KeyStoreException, NoSuchAlgorithmException,
CertificateException, IOException, UnrecoverableKeyException {
File file = new File(PRIVATE_KEY_PATH);
InputStream is = new FileInputStream(file);
KeyStore keystore = KeyStore.getInstance(KEY_STYLE);
keystore.load(is, P12_PASSWORD.toCharArray());
return (PrivateKey) keystore.getKey(KEY_ALIAS, P12_PASSWORD.toCharArray());
}

private Prediction createPredictionService() throws UnrecoverableKeyException, KeyStoreException,
NoSuchAlgorithmException, CertificateException, IOException {
GoogleCredential credential =
new GoogleCredential.Builder()
.setTransport(HTTP_TRANSPORT)
.setJsonFactory(JSON_FACTORY)
.setServiceAccountId(ACCOUNT_ID)
.setServiceAccountScopes(
Arrays.asList(PredictionScopes.PREDICTION, PredictionScopes.DEVSTORAGE_FULL_CONTROL))
.setServiceAccountPrivateKey(laodPrivateKey())
.build();

return new com.google.api.services.prediction.Prediction.Builder(
HTTP_TRANSPORT,
JSON_FACTORY,
credential).setApplicationName(APPLICATION_NAME).build();
}
[/java]

loadPrivateKeyメソッドは秘密鍵ファイルを読み込んで、PrivateKey型で返すメソッドで、createPredictionServiceメソッドの中で呼ばれます。

createPredictionServiceメソッドでは、まず、GoogleCredentialオブジェクトを生成しています。このオブジェクトは、文字通り、認証に必要な情報(サービスアカウントのID、秘密鍵、使用するAPIのスコープなど)を保持します。このオブジェクトをPrediction.Builderクラスに渡して、buildメソッドを呼ぶことで、Predicitonオブジェクトを生成します。

Insert

Prediction APIのinsertメソッドを実行するメソッドを追加します。

[java]
public void insert(String modelId, String csvPath) throws IOException,
UnrecoverableKeyException, KeyStoreException, NoSuchAlgorithmException, CertificateException {
Prediction prediction = createPredictionService();

Insert insert = new Insert();
insert.setId(modelId); //好きなmodelIdを指定
insert.setStorageDataLocation(csvPath); // csvPathは”バケット名/ファイルパス”という形式
prediction.trainedmodels().insert(PROJECT_ID, insert).execute();
}
[/java]

この例では、GCS上のCSVファイルを学習データとして、モデルを生成しています。Insertオブジェクトを生成し、idと学習データのパスを指定をセットした後、学習処理を実行しています。

predict

Prediction APIのpredictメソッドを実行し、1件の予測を行うメソッドを追加します。

[java]
public Output predict(String modelId, String text)
throws UnrecoverableKeyException, KeyStoreException, NoSuchAlgorithmException,
CertificateException, IOException {
Prediction prediction = createPredictionService();

InputInput inputInput = new InputInput();
inputInput.setCsvInstance(Collections.<Object> singletonList(text)); //予測対象のデータをセット
Input input = new Input();
input.setInput(inputInput);

return prediction.trainedmodels().predict(PROJECT_ID, modelId, input).execute();
}
[/java]

InputInputオブジェクトに、予測したいデータをセットし、さらにそれをInputオブジェクトにセットした後、予測処理を実行しています。

今回の予測対象データの構造は、テキスト形式1カラムのみですので、予測したいテキストだけを含むListをsetCsvInstanceでセットしています。学習データが複数カラムの場合は、予測対象データのListの順序が学習データと同じになるように気をつけてください。

予測結果はOutputクラスのオブジェクトとして格納されます。このクラスには、APIs Explorerでpredictを実行して確認できるJSON(中編を参照)に対応した形で、結果を保持しています。例えば、今回のようなラベル予測の場合、getOutputLabel()メソッドやgetOutputMulti()で、予測結果を取得できます。数値を予測する回帰予測の場合は、getOutputValue()で予測値を取得できます。

上記の例は、1件の対象データに対して、予測するものでした。次は、複数件の予測を1回のバッチ処理で行うメソッドを追加します。PredictionCallbackという内部クラスを定義し、そのクラスのメンバ変数resultに、1件1件に対する予測結果のリストを保持させていることに注意してください。

[java]
private static class PredictionCallback extends JsonBatchCallback<Output> {
private List<Output> results = new ArrayList<>();

@Override
public void onSuccess(Output o, HttpHeaders responseHeaders) throws IOException {
results.add(o);
}
@Override
public void onFailure(GoogleJsonError e, HttpHeaders responseHeaders) throws IOException {
results.add(null);
}

public List<Output> getResults() {
return results;
}
}

private List<Output> predictInSingleBatch(String modelId, List<String> texts)
throws UnrecoverableKeyException, KeyStoreException, NoSuchAlgorithmException,
CertificateException, IOException {
Prediction prediction = createPredictionService();
BatchRequest batch = prediction.batch();
PredictionCallback callback = new PredictionCallback();
for (String text : texts) {
InputInput inputInput = new InputInput();
inputInput.setCsvInstance(Collections.<Object> singletonList(text));
Input input = new Input();
input.setInput(inputInput);
prediction.trainedmodels().predict(PROJECT_ID, modelId, input).queue(batch, callback);
}

batch.execute(); //この処理を実行後、callbackに結果が格納される
return callback.getResults();
}
[/java]

上記の例では、batch.execute()を実行し終わると、変数callbackに予測結果が格納されます。1件1件の予測結果について、PrectionCallbackクラスのonSuccess()メソッドまたはonFailure()メソッドの処理が実行され、resultsに予測結果がaddされるというわけです。

なお、この例では、onFailure()メソッドでは、とりあえずresultsにnullを追加していますが、実際に使うときは、エラーメッセージを格納する変数を定義して、そこにエラーメッセージをセットするなどの処理を実装すると良いと思います。

まとめ

以上の解説を見ていただいてClient Libraryを用いることで、簡単にPrediction APIを利用できることがわかったかと思います。今回、ご紹介したコードではinsertとpredictしか使っていませんが、他のメソッドも同じくらい簡単です。

面白いアイデアがある方は、早速、Prediction APIを使ったWEBアプリケーションを作ってはいかが?

前編、中編の記事はこちら
Prediction API入門(前編)
Prediction API入門(中編)

  • このエントリーをはてなブックマークに追加

Google のクラウドサービスについてもっと詳しく知りたい、直接話が聞いてみたいという方のために、クラウドエースでは無料相談会を実施しております。お申し込みは下記ボタンより承っておりますので、この機会にぜひ弊社をご利用いただければと思います。

無料相談会のお申込みはこちら