Skip to content

Commit cee3e73

Browse files
authored
feat: add code samples for distillation (GoogleCloudPlatform#11156)
* feat: add code samples for distillation * fix test * fix test
1 parent 555b9d6 commit cee3e73

File tree

2 files changed

+177
-0
lines changed

2 files changed

+177
-0
lines changed

generative_ai/distillation.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# [START aiplatform_sdk_distillation]
16+
from __future__ import annotations
17+
18+
19+
from typing import Optional
20+
21+
22+
from google.auth import default
23+
import vertexai
24+
from vertexai.preview.language_models import TextGenerationModel, TuningEvaluationSpec
25+
26+
27+
credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
28+
29+
30+
def distill_model(
31+
project_id: str,
32+
location: str,
33+
dataset: str,
34+
teacher_model: str,
35+
train_steps: int = 300,
36+
evaluation_dataset: Optional[str] = None,
37+
) -> None:
38+
"""Distill a new model.
39+
40+
Args:
41+
project_id: GCP Project ID, used to initialize vertexai
42+
location: GCP Region, used to initialize vertexai
43+
dataset: GCS URI of jsonl file.
44+
teacher_model: Name of the teacher model.
45+
train_steps: Number of training steps to use when tuning the model.
46+
evaluation_dataset: GCS URI of jsonl file of evaluation data.
47+
"""
48+
vertexai.init(project=project_id, location=location, credentials=credentials)
49+
50+
eval_spec = TuningEvaluationSpec(evaluation_data=evaluation_dataset)
51+
52+
student_model = TextGenerationModel.from_pretrained("text-bison@002")
53+
student_model.distill_from(
54+
teacher_model=teacher_model,
55+
dataset=dataset,
56+
# Optional:
57+
train_steps=train_steps,
58+
tuning_job_location="europe-west4",
59+
tuned_model_location=location,
60+
tuning_evaluation_spec=eval_spec,
61+
)
62+
63+
print(student_model._job.status)
64+
return student_model
65+
66+
67+
if __name__ == "__main__":
68+
distill_model()
69+
# [END aiplatform_sdk_distillation]

generative_ai/distillation_test.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import uuid
17+
18+
from google.cloud import aiplatform
19+
from google.cloud import storage
20+
from google.cloud.aiplatform.compat.types import pipeline_state
21+
import pytest
22+
from vertexai.preview.language_models import TextGenerationModel
23+
24+
import distillation
25+
26+
_PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT")
27+
_LOCATION = "us-central1"
28+
_BUCKET = os.environ["CLOUD_STORAGE_BUCKET"]
29+
30+
31+
def get_model_display_name(tuned_model: TextGenerationModel) -> str:
32+
language_model_tuning_job = tuned_model._job
33+
pipeline_job = language_model_tuning_job._job
34+
return dict(pipeline_job._gca_resource.runtime_config.parameter_values)[
35+
"model_display_name"
36+
]
37+
38+
39+
def upload_to_gcs(bucket: str, name: str, data: str) -> None:
40+
client = storage.Client()
41+
bucket = client.get_bucket(bucket)
42+
blob = bucket.blob(name)
43+
blob.upload_from_string(data)
44+
45+
46+
def download_from_gcs(bucket: str, name: str) -> str:
47+
client = storage.Client()
48+
bucket = client.get_bucket(bucket)
49+
blob = bucket.blob(name)
50+
data = blob.download_as_bytes()
51+
return "\n".join(data.decode().splitlines()[:10])
52+
53+
54+
def delete_from_gcs(bucket: str, name: str) -> None:
55+
client = storage.Client()
56+
bucket = client.get_bucket(bucket)
57+
blob = bucket.blob(name)
58+
blob.delete()
59+
60+
61+
@pytest.fixture(scope="function")
62+
def training_data_filename() -> str:
63+
temp_filename = f"{uuid.uuid4()}.jsonl"
64+
data = download_from_gcs(
65+
"cloud-samples-data", "ai-platform/generative_ai/headline_classification.jsonl"
66+
)
67+
upload_to_gcs(_BUCKET, temp_filename, data)
68+
try:
69+
yield f"gs://{_BUCKET}/{temp_filename}"
70+
finally:
71+
delete_from_gcs(_BUCKET, temp_filename)
72+
73+
74+
def teardown_model(
75+
tuned_model: TextGenerationModel, training_data_filename: str
76+
) -> None:
77+
for tuned_model_name in tuned_model.list_tuned_model_names():
78+
model_registry = aiplatform.models.ModelRegistry(model=tuned_model_name)
79+
if (
80+
training_data_filename
81+
in model_registry.get_version_info("1").model_display_name
82+
):
83+
display_name = model_registry.get_version_info("1").model_display_name
84+
for endpoint in aiplatform.Endpoint.list():
85+
for _ in endpoint.list_models():
86+
if endpoint.display_name == display_name:
87+
endpoint.undeploy_all()
88+
endpoint.delete()
89+
aiplatform.Model(model_registry.model_resource_name).delete()
90+
91+
92+
def test_distill_model(training_data_filename: str) -> None:
93+
"""Takes approx. 60 minutes."""
94+
student_model = distillation.distill_model(
95+
dataset=training_data_filename,
96+
teacher_model="text-unicorn@001",
97+
project_id=_PROJECT_ID,
98+
location=_LOCATION,
99+
train_steps=1,
100+
evaluation_dataset=training_data_filename,
101+
)
102+
try:
103+
assert (
104+
student_model._job.status
105+
== pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
106+
)
107+
finally:
108+
teardown_model(student_model, training_data_filename)

0 commit comments

Comments
 (0)