Skip to content

Commit 61c9d9e

Browse files
feat: Text-embedding model tuning demo. (GoogleCloudPlatform#9272)
* feat: Text-embedding model tuning demo. * feat: Text-embedding model tuning demo. * feat: Text-embedding model tuning demo. * feat: Text-embedding model tuning demo. * feat: Text-embedding model tuning demo. * feat: Text-embedding model tuning demo. * feat: Text-Embedding Model Tuning demo. * feat: Text-Embedding Model Tuning demo. * feat: Text-Embedding Model Tuning demo. * feat: Text-Embedding Model Tuning demo. * feat: Text-Embedding Model Tuning demo. * feat: Text-Embedding Model Tuning demo. * feat: Text-Embedding Model Tuning demo.
1 parent 0d0f15d commit 61c9d9e

File tree

3 files changed

+284
-0
lines changed

3 files changed

+284
-0
lines changed

aiplatform/pom.xml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,5 +76,18 @@
7676
<groupId>com.google.cloud</groupId>
7777
<artifactId>google-cloud-bigquery</artifactId>
7878
</dependency>
79+
<!-- Used for retry logic in EmbeddingModelTuningSampleTest.java -->
80+
<dependency>
81+
<groupId>io.github.resilience4j</groupId>
82+
<artifactId>resilience4j-core</artifactId>
83+
<version>1.7.1</version>
84+
<scope>test</scope>
85+
</dependency>
86+
<dependency>
87+
<groupId>io.github.resilience4j</groupId>
88+
<artifactId>resilience4j-retry</artifactId>
89+
<version>1.7.1</version>
90+
<scope>test</scope>
91+
</dependency>
7992
</dependencies>
8093
</project>
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package aiplatform;
18+
19+
// [START aiplatform_sdk_embedding_model_tuning]
20+
import com.google.cloud.aiplatform.v1.CreatePipelineJobRequest;
21+
import com.google.cloud.aiplatform.v1.LocationName;
22+
import com.google.cloud.aiplatform.v1.PipelineJob;
23+
import com.google.cloud.aiplatform.v1.PipelineJob.RuntimeConfig;
24+
import com.google.cloud.aiplatform.v1.PipelineServiceClient;
25+
import com.google.cloud.aiplatform.v1.PipelineServiceSettings;
26+
import com.google.protobuf.Value;
27+
import java.io.IOException;
28+
import java.util.Map;
29+
import java.util.regex.Matcher;
30+
import java.util.regex.Pattern;
31+
32+
// [END aiplatform_sdk_embedding_model_tuning]
33+
34+
public class EmbeddingModelTuningSample {
35+
public static void main(String[] args) throws IOException {
36+
// [START aiplatform_sdk_embedding_model_tuning]
37+
// TODO(developer): Replace these variables before running this sample.
38+
String apiEndpoint = "us-central1-aiplatform.googleapis.com:443";
39+
String project = "PROJECT";
40+
String baseModelVersionId = "BASE_MODEL_VERSION_ID";
41+
String taskType = "DEFAULT";
42+
String pipelineJobDisplayName = "PIPELINE_JOB_DISPLAY_NAME";
43+
String outputDir = "OUTPUT_DIR";
44+
String queriesPath = "QUERIES";
45+
String corpusPath = "CORPUS";
46+
String trainLabelPath = "TRAIN_LABEL";
47+
String testLabelPath = "TEST_LABEL";
48+
int batchSize = 128;
49+
int iterations = 1000;
50+
51+
createEmbeddingModelTuningPipelineJob(
52+
apiEndpoint,
53+
project,
54+
baseModelVersionId,
55+
taskType,
56+
pipelineJobDisplayName,
57+
outputDir,
58+
queriesPath,
59+
corpusPath,
60+
trainLabelPath,
61+
testLabelPath,
62+
batchSize,
63+
iterations);
64+
// [END aiplatform_sdk_embedding_model_tuning]
65+
}
66+
67+
// [START aiplatform_sdk_embedding_model_tuning]
68+
public static PipelineJob createEmbeddingModelTuningPipelineJob(
69+
String apiEndpoint,
70+
String project,
71+
String baseModelVersionId,
72+
String taskType,
73+
String pipelineJobDisplayName,
74+
String outputDir,
75+
String queriesPath,
76+
String corpusPath,
77+
String trainLabelPath,
78+
String testLabelPath,
79+
int batchSize,
80+
int iterations)
81+
throws IOException {
82+
Matcher matcher = Pattern.compile("^(?<Location>\\w+-\\w+)").matcher(apiEndpoint);
83+
String location = matcher.matches() ? matcher.group("Location") : "us-central1";
84+
String templateUri =
85+
"https://us-kfp.pkg.dev/ml-pipeline/llm-text-embedding/tune-text-embedding-model/v1.1.2";
86+
PipelineServiceSettings settings =
87+
PipelineServiceSettings.newBuilder().setEndpoint(apiEndpoint).build();
88+
try (PipelineServiceClient client = PipelineServiceClient.create(settings)) {
89+
Map<String, Value> parameterValues =
90+
Map.of(
91+
"project", valueOf(project),
92+
"base_model_version_id", valueOf(baseModelVersionId),
93+
"task_type", valueOf(taskType),
94+
"location", valueOf(location),
95+
"queries_path", valueOf(queriesPath),
96+
"corpus_path", valueOf(corpusPath),
97+
"train_label_path", valueOf(trainLabelPath),
98+
"test_label_path", valueOf(testLabelPath),
99+
"batch_size", valueOf(batchSize),
100+
"iterations", valueOf(iterations));
101+
PipelineJob pipelineJob =
102+
PipelineJob.newBuilder()
103+
.setTemplateUri(templateUri)
104+
.setDisplayName(pipelineJobDisplayName)
105+
.setRuntimeConfig(
106+
RuntimeConfig.newBuilder()
107+
.setGcsOutputDirectory(outputDir)
108+
.putAllParameterValues(parameterValues)
109+
.build())
110+
.build();
111+
CreatePipelineJobRequest request =
112+
CreatePipelineJobRequest.newBuilder()
113+
.setParent(LocationName.of(project, location).toString())
114+
.setPipelineJob(pipelineJob)
115+
.build();
116+
return client.createPipelineJob(request);
117+
}
118+
}
119+
120+
private static Value valueOf(String s) {
121+
return Value.newBuilder().setStringValue(s).build();
122+
}
123+
124+
private static Value valueOf(int n) {
125+
return Value.newBuilder().setNumberValue(n).build();
126+
}
127+
// [END aiplatform_sdk_embedding_model_tuning]
128+
}
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package aiplatform;
18+
19+
import static com.google.common.truth.Truth.assertThat;
20+
import static java.util.stream.Collectors.toList;
21+
import static junit.framework.TestCase.assertNotNull;
22+
23+
import com.google.api.gax.longrunning.OperationFuture;
24+
import com.google.cloud.aiplatform.v1.CancelPipelineJobRequest;
25+
import com.google.cloud.aiplatform.v1.DeleteOperationMetadata;
26+
import com.google.cloud.aiplatform.v1.PipelineJob;
27+
import com.google.cloud.aiplatform.v1.PipelineServiceClient;
28+
import com.google.cloud.aiplatform.v1.PipelineServiceSettings;
29+
import com.google.cloud.aiplatform.v1.PipelineState;
30+
import com.google.cloud.testing.junit4.MultipleAttemptsRule;
31+
import com.google.protobuf.Empty;
32+
import io.github.resilience4j.retry.Retry;
33+
import io.github.resilience4j.retry.RetryConfig;
34+
import io.github.resilience4j.retry.RetryRegistry;
35+
import io.vavr.CheckedRunnable;
36+
import java.io.IOException;
37+
import java.time.Duration;
38+
import java.util.LinkedList;
39+
import java.util.List;
40+
import java.util.Queue;
41+
import java.util.concurrent.TimeUnit;
42+
import java.util.concurrent.TimeoutException;
43+
import org.junit.AfterClass;
44+
import org.junit.BeforeClass;
45+
import org.junit.Rule;
46+
import org.junit.Test;
47+
import org.junit.runner.RunWith;
48+
import org.junit.runners.JUnit4;
49+
50+
@RunWith(JUnit4.class)
51+
public class EmbeddingModelTuningSampleTest {
52+
@Rule public final MultipleAttemptsRule multipleAttemptsRule = new MultipleAttemptsRule(3);
53+
54+
private static final String API_ENDPOINT = "us-central1-aiplatform.googleapis.com:443";
55+
private static final String PROJECT = System.getenv("UCAIP_PROJECT_ID");
56+
private static final String BASE_MODEL_VERSION_ID = "textembedding-gecko@003";
57+
private static final String TASK_TYPE = "DEFAULT";
58+
private static final String JOB_DISPLAY_NAME = "embedding-customization-pipeline-sample";
59+
private static final String QUERIES =
60+
"gs://embedding-customization-pipeline/dataset/queries.jsonl";
61+
private static final String CORPUS = "gs://embedding-customization-pipeline/dataset/corpus.jsonl";
62+
private static final String TRAIN_LABEL =
63+
"gs://embedding-customization-pipeline/dataset/train.tsv";
64+
private static final String TEST_LABEL = "gs://embedding-customization-pipeline/dataset/test.tsv";
65+
private static final String OUTPUT_DIR =
66+
"gs://ucaip-samples-us-central1/training_pipeline_output";
67+
private static final int BATCH_SIZE = 50;
68+
private static final int ITERATIONS = 300;
69+
70+
private static Queue<String> JobNames = new LinkedList<String>();
71+
private static final RetryConfig RETRY_CONFIG =
72+
RetryConfig.custom()
73+
.maxAttempts(30)
74+
.waitDuration(Duration.ofSeconds(6))
75+
.retryExceptions(TimeoutException.class)
76+
.failAfterMaxAttempts(false)
77+
.build();
78+
private static final RetryRegistry RETRY_REGISTRY = RetryRegistry.of(RETRY_CONFIG);
79+
80+
private static void requireEnvVar(String varName) {
81+
String errorMessage = String.format("Test requires environment variable '%s'.", varName);
82+
assertNotNull(errorMessage, System.getenv(varName));
83+
}
84+
85+
@BeforeClass
86+
public static void checkRequirements() {
87+
requireEnvVar("GOOGLE_APPLICATION_CREDENTIALS");
88+
requireEnvVar("UCAIP_PROJECT_ID");
89+
}
90+
91+
@AfterClass
92+
public static void tearDown() throws Throwable {
93+
PipelineServiceSettings settings =
94+
PipelineServiceSettings.newBuilder().setEndpoint(API_ENDPOINT).build();
95+
try (PipelineServiceClient client = PipelineServiceClient.create(settings)) {
96+
List<CancelPipelineJobRequest> requests =
97+
JobNames.stream()
98+
.map(n -> CancelPipelineJobRequest.newBuilder().setName(n).build())
99+
.collect(toList());
100+
CheckedRunnable runnable =
101+
Retry.decorateCheckedRunnable(
102+
RETRY_REGISTRY.retry("delete-pipeline-jobs", RETRY_CONFIG),
103+
() -> {
104+
List<OperationFuture<Empty, DeleteOperationMetadata>> deletions =
105+
requests.stream()
106+
.map(
107+
req -> {
108+
client.cancelPipelineJobCallable().futureCall(req);
109+
return client.deletePipelineJobAsync(req.getName());
110+
})
111+
.collect(toList());
112+
for (OperationFuture<Empty, DeleteOperationMetadata> d : deletions) {
113+
d.get(0, TimeUnit.SECONDS);
114+
}
115+
});
116+
try {
117+
runnable.run();
118+
} catch (TimeoutException e) {
119+
// Do nothing.
120+
}
121+
}
122+
}
123+
124+
@Test
125+
public void createPipelineJobEmbeddingModelTuningSample() throws IOException {
126+
PipelineJob job =
127+
EmbeddingModelTuningSample.createEmbeddingModelTuningPipelineJob(
128+
API_ENDPOINT,
129+
PROJECT,
130+
BASE_MODEL_VERSION_ID,
131+
TASK_TYPE,
132+
JOB_DISPLAY_NAME,
133+
OUTPUT_DIR,
134+
QUERIES,
135+
CORPUS,
136+
TRAIN_LABEL,
137+
TEST_LABEL,
138+
BATCH_SIZE,
139+
ITERATIONS);
140+
assertThat(job.getState()).isNotEqualTo(PipelineState.PIPELINE_STATE_FAILED);
141+
JobNames.add(job.getName());
142+
}
143+
}

0 commit comments

Comments
 (0)