Skip to content

Commit e01cc42

Browse files
fix: Elastic Text-Embedding demo. (GoogleCloudPlatform#9280)
1 parent c54658b commit e01cc42

File tree

3 files changed

+42
-43
lines changed

3 files changed

+42
-43
lines changed

aiplatform/src/main/java/aiplatform/PredictTextEmbeddingsSample.java

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
package aiplatform;
1818

1919
// [START aiplatform_sdk_embedding]
20+
import static java.util.stream.Collectors.toList;
21+
2022
import com.google.cloud.aiplatform.v1beta1.EndpointName;
2123
import com.google.cloud.aiplatform.v1beta1.PredictRequest;
2224
import com.google.cloud.aiplatform.v1beta1.PredictResponse;
@@ -25,6 +27,7 @@
2527
import com.google.protobuf.Struct;
2628
import com.google.protobuf.Value;
2729
import java.io.IOException;
30+
import java.util.ArrayList;
2831
import java.util.List;
2932
import java.util.regex.Matcher;
3033
import java.util.regex.Pattern;
@@ -46,12 +49,8 @@ public static void main(String[] args) throws IOException {
4649
}
4750

4851
// Gets text embeddings from a pretrained, foundational model.
49-
public static void predictTextEmbeddings(
50-
String endpoint,
51-
String project,
52-
String model,
53-
List<String> texts,
54-
String task)
52+
public static List<List<Float>> predictTextEmbeddings(
53+
String endpoint, String project, String model, List<String> texts, String task)
5554
throws IOException {
5655
PredictionServiceSettings settings =
5756
PredictionServiceSettings.newBuilder().setEndpoint(endpoint).build();
@@ -74,10 +73,17 @@ public static void predictTextEmbeddings(
7473
.build()));
7574
}
7675
PredictResponse response = client.predict(request.build());
77-
System.out.println("Got predict response:\n");
76+
List<List<Float>> floats = new ArrayList<>();
7877
for (Value prediction : response.getPredictionsList()) {
79-
System.out.format("Got prediction: %s\n", prediction);
78+
Value embeddings = prediction.getStructValue().getFieldsOrThrow("embeddings");
79+
Value values = embeddings.getStructValue().getFieldsOrThrow("values");
80+
floats.add(
81+
values.getListValue().getValuesList().stream()
82+
.map(Value::getNumberValue)
83+
.map(Double::floatValue)
84+
.collect(toList()));
8085
}
86+
return floats;
8187
}
8288
}
8389

aiplatform/src/main/java/aiplatform/PredictTextEmbeddingsSamplePreview.java

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
package aiplatform;
1818

1919
// [START generativeaionvertexai_sdk_embedding]
20+
import static java.util.stream.Collectors.toList;
21+
2022
import com.google.cloud.aiplatform.v1beta1.EndpointName;
2123
import com.google.cloud.aiplatform.v1beta1.PredictRequest;
2224
import com.google.cloud.aiplatform.v1beta1.PredictResponse;
@@ -25,6 +27,7 @@
2527
import com.google.protobuf.Struct;
2628
import com.google.protobuf.Value;
2729
import java.io.IOException;
30+
import java.util.ArrayList;
2831
import java.util.List;
2932
import java.util.OptionalInt;
3033
import java.util.regex.Matcher;
@@ -49,7 +52,7 @@ public static void main(String[] args) throws IOException {
4952
}
5053

5154
// Gets text embeddings from a pretrained, foundational model.
52-
public static void predictTextEmbeddings(
55+
public static List<List<Float>> predictTextEmbeddings(
5356
String endpoint,
5457
String project,
5558
String model,
@@ -86,10 +89,17 @@ public static void predictTextEmbeddings(
8689
.build()));
8790
}
8891
PredictResponse response = client.predict(request.build());
89-
System.out.println("Got predict response:\n");
92+
List<List<Float>> floats = new ArrayList<>();
9093
for (Value prediction : response.getPredictionsList()) {
91-
System.out.format("Got prediction: %s\n", prediction);
94+
Value embeddings = prediction.getStructValue().getFieldsOrThrow("embeddings");
95+
Value values = embeddings.getStructValue().getFieldsOrThrow("values");
96+
floats.add(
97+
values.getListValue().getValuesList().stream()
98+
.map(Value::getNumberValue)
99+
.map(Double::floatValue)
100+
.collect(toList()));
92101
}
102+
return floats;
93103
}
94104
}
95105

aiplatform/src/test/java/aiplatform/PredictTextEmbeddingsSampleTest.java

Lines changed: 15 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,9 @@
2020
import static junit.framework.TestCase.assertNotNull;
2121

2222
import com.google.cloud.testing.junit4.MultipleAttemptsRule;
23-
import java.io.ByteArrayOutputStream;
2423
import java.io.IOException;
25-
import java.io.PrintStream;
2624
import java.util.List;
2725
import java.util.OptionalInt;
28-
import org.junit.After;
29-
import org.junit.Before;
3026
import org.junit.BeforeClass;
3127
import org.junit.Rule;
3228
import org.junit.Test;
@@ -35,9 +31,6 @@ public class PredictTextEmbeddingsSampleTest {
3531
@Rule public final MultipleAttemptsRule multipleAttemptsRule = new MultipleAttemptsRule(3);
3632
private static final String APIS_ENDPOINT = "us-central1-aiplatform.googleapis.com:443";
3733
private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
38-
private ByteArrayOutputStream bout;
39-
private PrintStream out;
40-
private PrintStream originalPrintStream;
4134

4235
private static void requireEnvVar(String varName) {
4336
String errorMessage =
@@ -51,40 +44,30 @@ public static void checkRequirements() {
5144
requireEnvVar("UCAIP_PROJECT_ID");
5245
}
5346

54-
@Before
55-
public void setUp() {
56-
bout = new ByteArrayOutputStream();
57-
out = new PrintStream(bout);
58-
originalPrintStream = System.out;
59-
System.setOut(out);
60-
}
61-
62-
@After
63-
public void tearDown() {
64-
System.out.flush();
65-
System.setOut(originalPrintStream);
66-
}
67-
6847
@Test
6948
public void testPredictTextEmbeddings() throws IOException {
7049
List<String> texts =
7150
List.of("banana bread?", "banana muffin?", "banana?", "recipe?", "muffin recipe?");
72-
PredictTextEmbeddingsSample.predictTextEmbeddings(
73-
APIS_ENDPOINT, PROJECT, "textembedding-gecko@003", texts, "RETRIEVAL_DOCUMENT");
74-
assertThat(bout.toString()).contains("Got predict response");
51+
List<List<Float>> embeddings =
52+
PredictTextEmbeddingsSample.predictTextEmbeddings(
53+
APIS_ENDPOINT, PROJECT, "textembedding-gecko@003", texts, "RETRIEVAL_DOCUMENT");
54+
assertThat(embeddings.size()).isEqualTo(texts.size());
55+
assertThat(embeddings.get(0).size()).isEqualTo(768);
7556
}
7657

7758
@Test
7859
public void testPredictTextEmbeddingsPreview() throws IOException {
7960
List<String> texts =
8061
List.of("banana bread?", "banana muffin?", "banana?", "recipe?", "muffin recipe?");
81-
PredictTextEmbeddingsSamplePreview.predictTextEmbeddings(
82-
APIS_ENDPOINT,
83-
PROJECT,
84-
"text-embedding-preview-0409",
85-
texts,
86-
"QUESTION_ANSWERING",
87-
OptionalInt.of(256));
88-
assertThat(bout.toString()).contains("Got predict response");
62+
List<List<Float>> embeddings =
63+
PredictTextEmbeddingsSamplePreview.predictTextEmbeddings(
64+
APIS_ENDPOINT,
65+
PROJECT,
66+
"text-embedding-preview-0409",
67+
texts,
68+
"QUESTION_ANSWERING",
69+
OptionalInt.of(5));
70+
assertThat(embeddings.size()).isEqualTo(texts.size());
71+
assertThat(embeddings.get(0).size()).isEqualTo(5);
8972
}
9073
}

0 commit comments

Comments
 (0)