20
20
import static junit .framework .TestCase .assertNotNull ;
21
21
22
22
import com .google .cloud .testing .junit4 .MultipleAttemptsRule ;
23
- import java .io .ByteArrayOutputStream ;
24
23
import java .io .IOException ;
25
- import java .io .PrintStream ;
26
24
import java .util .List ;
27
25
import java .util .OptionalInt ;
28
- import org .junit .After ;
29
- import org .junit .Before ;
30
26
import org .junit .BeforeClass ;
31
27
import org .junit .Rule ;
32
28
import org .junit .Test ;
@@ -35,9 +31,6 @@ public class PredictTextEmbeddingsSampleTest {
35
31
@ Rule public final MultipleAttemptsRule multipleAttemptsRule = new MultipleAttemptsRule (3 );
36
32
private static final String APIS_ENDPOINT = "us-central1-aiplatform.googleapis.com:443" ;
37
33
private static final String PROJECT = System .getenv ("UCAIP_PROJECT_ID" );
38
- private ByteArrayOutputStream bout ;
39
- private PrintStream out ;
40
- private PrintStream originalPrintStream ;
41
34
42
35
private static void requireEnvVar (String varName ) {
43
36
String errorMessage =
@@ -51,40 +44,30 @@ public static void checkRequirements() {
51
44
requireEnvVar ("UCAIP_PROJECT_ID" );
52
45
}
53
46
54
- @ Before
55
- public void setUp () {
56
- bout = new ByteArrayOutputStream ();
57
- out = new PrintStream (bout );
58
- originalPrintStream = System .out ;
59
- System .setOut (out );
60
- }
61
-
62
- @ After
63
- public void tearDown () {
64
- System .out .flush ();
65
- System .setOut (originalPrintStream );
66
- }
67
-
68
47
@ Test
69
48
public void testPredictTextEmbeddings () throws IOException {
70
49
List <String > texts =
71
50
List .of ("banana bread?" , "banana muffin?" , "banana?" , "recipe?" , "muffin recipe?" );
72
- PredictTextEmbeddingsSample .predictTextEmbeddings (
73
- APIS_ENDPOINT , PROJECT , "textembedding-gecko@003" , texts , "RETRIEVAL_DOCUMENT" );
74
- assertThat (bout .toString ()).contains ("Got predict response" );
51
+ List <List <Float >> embeddings =
52
+ PredictTextEmbeddingsSample .predictTextEmbeddings (
53
+ APIS_ENDPOINT , PROJECT , "textembedding-gecko@003" , texts , "RETRIEVAL_DOCUMENT" );
54
+ assertThat (embeddings .size ()).isEqualTo (texts .size ());
55
+ assertThat (embeddings .get (0 ).size ()).isEqualTo (768 );
75
56
}
76
57
77
58
@ Test
78
59
public void testPredictTextEmbeddingsPreview () throws IOException {
79
60
List <String > texts =
80
61
List .of ("banana bread?" , "banana muffin?" , "banana?" , "recipe?" , "muffin recipe?" );
81
- PredictTextEmbeddingsSamplePreview .predictTextEmbeddings (
82
- APIS_ENDPOINT ,
83
- PROJECT ,
84
- "text-embedding-preview-0409" ,
85
- texts ,
86
- "QUESTION_ANSWERING" ,
87
- OptionalInt .of (256 ));
88
- assertThat (bout .toString ()).contains ("Got predict response" );
62
+ List <List <Float >> embeddings =
63
+ PredictTextEmbeddingsSamplePreview .predictTextEmbeddings (
64
+ APIS_ENDPOINT ,
65
+ PROJECT ,
66
+ "text-embedding-preview-0409" ,
67
+ texts ,
68
+ "QUESTION_ANSWERING" ,
69
+ OptionalInt .of (5 ));
70
+ assertThat (embeddings .size ()).isEqualTo (texts .size ());
71
+ assertThat (embeddings .get (0 ).size ()).isEqualTo (5 );
89
72
}
90
73
}
0 commit comments