Skip to content

Commit 971a830

Browse files
authored
Add a generic BatchPrediction sample (GoogleCloudPlatform#1774)
1 parent 2802af1 commit 971a830

File tree

5 files changed

+83
-3
lines changed

5 files changed

+83
-3
lines changed

automl/cloud-client/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ small section of code to print out the `metadata` field.
6363
* [Deploy Model](src/main/java/com/example/automl/DeployModel.java) - Not supported by Translation
6464
* [Uneploy Model](src/main/java/com/example/automl/UndeployModel.java) - Not supported by Translation
6565

66+
### Batch Prediction
67+
* [Batch Predict](src/main/java/com/example/automl/BatchPredict.java) - Supported by: Natural Language Entity Extraction, Vision Classification, and Vision Object Detection.
6668

6769
### Operation Management
6870
* [List Operation Statuses](src/main/java/com/example/automl/ListOperationStatus.java)
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
/*
2+
* Copyright 2019 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.example.automl;
18+
19+
// [START automl_batch_predict]
20+
import com.google.api.gax.longrunning.OperationFuture;
21+
import com.google.cloud.automl.v1.BatchPredictInputConfig;
22+
import com.google.cloud.automl.v1.BatchPredictOutputConfig;
23+
import com.google.cloud.automl.v1.BatchPredictRequest;
24+
import com.google.cloud.automl.v1.BatchPredictResult;
25+
import com.google.cloud.automl.v1.GcsDestination;
26+
import com.google.cloud.automl.v1.GcsSource;
27+
import com.google.cloud.automl.v1.ModelName;
28+
import com.google.cloud.automl.v1.OperationMetadata;
29+
import com.google.cloud.automl.v1.PredictionServiceClient;
30+
31+
import java.io.IOException;
32+
import java.util.concurrent.ExecutionException;
33+
34+
class BatchPredict {
35+
36+
static void batchPredict() throws IOException, ExecutionException, InterruptedException {
37+
// TODO(developer): Replace these variables before running the sample.
38+
String projectId = "YOUR_PROJECT_ID";
39+
String modelId = "YOUR_MODEL_ID";
40+
String inputUri = "gs://YOUR_BUCKET_ID/path_to_your_input_csv_or_jsonl";
41+
String outputUri = "gs://YOUR_BUCKET_ID/path_to_save_results/";
42+
batchPredict(projectId, modelId, inputUri, outputUri);
43+
}
44+
45+
static void batchPredict(String projectId, String modelId, String inputUri, String outputUri)
46+
throws IOException, ExecutionException, InterruptedException {
47+
// Initialize client that will be used to send requests. This client only needs to be created
48+
// once, and can be reused for multiple requests. After completing all of your requests, call
49+
// the "close" method on the client to safely clean up any remaining background resources.
50+
try (PredictionServiceClient client = PredictionServiceClient.create()) {
51+
// Get the full path of the model.
52+
ModelName name = ModelName.of(projectId, "us-central1", modelId);
53+
GcsSource gcsSource = GcsSource.newBuilder().addInputUris(inputUri).build();
54+
BatchPredictInputConfig inputConfig =
55+
BatchPredictInputConfig.newBuilder().setGcsSource(gcsSource).build();
56+
GcsDestination gcsDestination =
57+
GcsDestination.newBuilder().setOutputUriPrefix(outputUri).build();
58+
BatchPredictOutputConfig outputConfig =
59+
BatchPredictOutputConfig.newBuilder().setGcsDestination(gcsDestination).build();
60+
BatchPredictRequest request =
61+
BatchPredictRequest.newBuilder()
62+
.setName(name.toString())
63+
.setInputConfig(inputConfig)
64+
.setOutputConfig(outputConfig)
65+
// [0.0-1.0] Only produce results higher than this value
66+
.putParams("score_threshold", "0.8")
67+
.build();
68+
69+
OperationFuture<BatchPredictResult, OperationMetadata> future =
70+
client.batchPredictAsync(request);
71+
72+
System.out.println("Waiting for operation to complete...");
73+
BatchPredictResult response = future.get();
74+
System.out.println("Batch Prediction results saved to specified Cloud Storage bucket.");
75+
}
76+
}
77+
}
78+
// [END automl_batch_predict]

automl/cloud-client/src/test/java/com/example/automl/LanguageEntityExtractionPredictIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ public void testBatchPredict() throws IOException, ExecutionException, Interrupt
8686
String inputUri = String.format("gs://%s/entity_extraction/input.jsonl", BUCKET_ID);
8787
String outputUri = String.format("gs://%s/TEST_BATCH_PREDICT/", BUCKET_ID);
8888
// Act
89-
LanguageBatchPredict.batchPredict(PROJECT_ID, modelId, inputUri, outputUri);
89+
BatchPredict.batchPredict(PROJECT_ID, modelId, inputUri, outputUri);
9090

9191
// Assert
9292
String got = bout.toString();

automl/cloud-client/src/test/java/com/example/automl/VisionClassificationPredictIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ public void testBatchPredict() throws IOException, ExecutionException, Interrupt
8686
String inputUri = String.format("gs://%s/batch_predict_test.csv", BUCKET_ID);
8787
String outputUri = String.format("gs://%s/TEST_BATCH_PREDICT/", BUCKET_ID);
8888
// Act
89-
VisionBatchPredict.batchPredict(PROJECT_ID, modelId, inputUri, outputUri);
89+
BatchPredict.batchPredict(PROJECT_ID, modelId, inputUri, outputUri);
9090

9191
// Assert
9292
String got = bout.toString();

automl/cloud-client/src/test/java/com/example/automl/VisionObjectDetectionPredictIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ public void testBatchPredict() throws IOException, ExecutionException, Interrupt
8888
String.format("gs://%s/vision_object_detection_batch_predict_test.csv", BUCKET_ID);
8989
String outputUri = String.format("gs://%s/TEST_BATCH_PREDICT/", BUCKET_ID);
9090
// Act
91-
VisionBatchPredict.batchPredict(PROJECT_ID, modelId, inputUri, outputUri);
91+
BatchPredict.batchPredict(PROJECT_ID, modelId, inputUri, outputUri);
9292

9393
// Assert
9494
String got = bout.toString();

0 commit comments

Comments
 (0)