Skip to content

Commit 583de9f

Browse files
authored
feat(generative_ai): add model evaluation sample (GoogleCloudPlatform#11122)
* feat(generative_ai): add model evaluation sample * set credentials required by pipeline components
1 parent 93429e0 commit 583de9f

File tree

2 files changed

+93
-0
lines changed

2 files changed

+93
-0
lines changed

generative_ai/evaluate_model.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# [START aiplatform_evaluate_model]
16+
17+
from google.auth import default
18+
import vertexai
19+
from vertexai.preview.language_models import (
20+
EvaluationTextClassificationSpec,
21+
TextGenerationModel,
22+
)
23+
24+
# Set credentials for the pipeline components used in the evaluation task
25+
credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
26+
27+
28+
def evaluate_model(
29+
project_id: str,
30+
location: str,
31+
) -> object:
32+
"""Evaluate the performance of a generative AI model."""
33+
34+
vertexai.init(project=project_id, location=location, credentials=credentials)
35+
36+
# Create a reference to a generative AI model
37+
model = TextGenerationModel.from_pretrained("text-bison@001")
38+
39+
# Define the evaluation specification for a text classification task
40+
task_spec = EvaluationTextClassificationSpec(
41+
ground_truth_data=[
42+
"gs://cloud-samples-data/ai-platform/generative_ai/llm_classification_bp_input_prompts_with_ground_truth.jsonl"
43+
],
44+
class_names=["nature", "news", "sports", "health", "startups"],
45+
target_column_name="ground_truth",
46+
)
47+
48+
# Evaluate the model
49+
eval_metrics = model.evaluate(task_spec=task_spec)
50+
print(eval_metrics)
51+
52+
return eval_metrics
53+
54+
55+
# [END aiplatform_evaluate_model]
56+
57+
58+
if __name__ == "__main__":
59+
evaluate_model()

generative_ai/evaluate_model_test.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
17+
import backoff
18+
from google.api_core.exceptions import ResourceExhausted
19+
20+
import evaluate_model
21+
22+
23+
_PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT")
24+
_LOCATION = "us-central1"
25+
26+
27+
@backoff.on_exception(backoff.expo, ResourceExhausted, max_time=10)
28+
def test_evaluate_model() -> None:
29+
eval_metrics = evaluate_model.evaluate_model(
30+
_PROJECT_ID,
31+
_LOCATION,
32+
)
33+
34+
assert hasattr(eval_metrics, "auRoc")

0 commit comments

Comments
 (0)