Skip to content

Commit fc01e7e

Browse files
sararobcopybara-github
authored andcommitted
feat: GenAI SDK client - Add experimental prompt_management module with create_version and get methods
PiperOrigin-RevId: 795536822
1 parent a9db14f commit fc01e7e

File tree

8 files changed

+4076
-166
lines changed

8 files changed

+4076
-166
lines changed

tests/unit/vertexai/genai/replays/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,8 +184,8 @@ def client(use_vertex, replays_prefix, http_options, request):
184184
os.path.dirname(__file__),
185185
"credentials.json",
186186
)
187-
os.environ["GOOGLE_CLOUD_PROJECT"] = "project-id"
188-
os.environ["GOOGLE_CLOUD_LOCATION"] = "location"
187+
os.environ["GOOGLE_CLOUD_PROJECT"] = "vertex-sdk-dev"
188+
os.environ["GOOGLE_CLOUD_LOCATION"] = "us-central1"
189189
os.environ["VAPO_CONFIG_PATH"] = "gs://dummy-test/dummy-config.json"
190190
os.environ["VAPO_SERVICE_ACCOUNT_PROJECT_NUMBER"] = "1234567890"
191191
os.environ["GCS_BUCKET"] = "test-bucket"
Lines changed: 339 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,339 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
# pylint: disable=protected-access,bad-continuation,missing-function-docstring
16+
17+
from tests.unit.vertexai.genai.replays import pytest_helper
18+
from vertexai._genai import types
19+
from google.genai import types as genai_types
20+
21+
22+
TEST_PROMPT_DATASET_ID = "8005484238453342208"
23+
TEST_VARIABLES = [
24+
{"name": genai_types.Part(text="Alice")},
25+
{"name": genai_types.Part(text="Bob")},
26+
]
27+
TEST_PROMPT = types.Prompt(
28+
prompt_data=types.PromptData(
29+
contents=[
30+
genai_types.Content(
31+
role="user",
32+
parts=[genai_types.Part(text="Hello, {name}! How are you?")],
33+
)
34+
],
35+
safety_settings=[
36+
genai_types.SafetySetting(
37+
category="HARM_CATEGORY_DANGEROUS_CONTENT",
38+
threshold="BLOCK_MEDIUM_AND_ABOVE",
39+
method="SEVERITY",
40+
),
41+
],
42+
generation_config=genai_types.GenerationConfig(temperature=0.1),
43+
system_instruction=genai_types.Content(
44+
parts=[genai_types.Part(text="Please answer in a short sentence.")]
45+
),
46+
),
47+
variables=TEST_VARIABLES,
48+
model_name="gemini-2.0-flash-001",
49+
)
50+
TEST_CONFIG = types.CreatePromptConfig(
51+
prompt_display_name="my_prompt",
52+
version_display_name="my_version",
53+
)
54+
55+
56+
def test_create_dataset(client):
57+
create_dataset_operation = client.prompt_management._create_dataset_resource(
58+
config=types.CreateDatasetConfig(should_return_http_response=True),
59+
name="projects/vertex-sdk-dev/locations/us-central1",
60+
display_name="test display name",
61+
metadata_schema_uri="gs://google-cloud-aiplatform/schema/dataset/metadata/text_prompt_1.0.0.yaml",
62+
metadata={
63+
"promptType": "freeform",
64+
"promptApiSchema": {
65+
"multimodalPrompt": {
66+
"promptMessage": {
67+
"contents": [
68+
{
69+
"role": "user",
70+
"parts": [{"text": "Hello, {name}! How are you?"}],
71+
}
72+
],
73+
"safety_settings": [
74+
{
75+
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
76+
"threshold": "BLOCK_MEDIUM_AND_ABOVE",
77+
"method": "SEVERITY",
78+
}
79+
],
80+
"generation_config": {"temperature": 0.1},
81+
"model": "projects/vertex-sdk-dev/locations/us-central1/publishers/google/models/gemini-2.0-flash-001",
82+
"system_instruction": {
83+
"role": "user",
84+
"parts": [{"text": "Please answer in a short sentence."}],
85+
},
86+
}
87+
},
88+
"apiSchemaVersion": "1.0.0",
89+
"executions": [
90+
{
91+
"arguments": {
92+
"name": {"partList": {"parts": [{"text": "Alice"}]}}
93+
}
94+
},
95+
{"arguments": {"name": {"partList": {"parts": [{"text": "Bob"}]}}}},
96+
],
97+
},
98+
},
99+
model_reference="gemini-2.0-flash-001",
100+
)
101+
assert isinstance(create_dataset_operation, types.CreateDatasetOperationMetadata)
102+
assert create_dataset_operation.sdk_http_response.body is not None
103+
104+
105+
def test_create_dataset_version(client):
106+
dataset_version_resource = (
107+
client.prompt_management._create_dataset_version_resource(
108+
dataset_name=TEST_PROMPT_DATASET_ID,
109+
display_name="my new version yay",
110+
)
111+
)
112+
assert isinstance(
113+
dataset_version_resource, types.CreateDatasetVersionOperationMetadata
114+
)
115+
116+
117+
def test_create_version_e2e(client):
118+
prompt_resource = client.prompt_management.create_version(
119+
prompt=TEST_PROMPT,
120+
config=TEST_CONFIG,
121+
)
122+
assert isinstance(prompt_resource, types.Prompt)
123+
assert isinstance(prompt_resource.dataset, types.Dataset)
124+
125+
# Test local prompt resource is the same after calling get()
126+
retrieved_prompt = client.prompt_management.get(prompt_id=prompt_resource.prompt_id)
127+
assert (
128+
retrieved_prompt.prompt_data.system_instruction
129+
== prompt_resource.prompt_data.system_instruction
130+
)
131+
assert retrieved_prompt.variables[0]["name"].text == TEST_VARIABLES[0]["name"].text
132+
assert (
133+
retrieved_prompt.prompt_data.generation_config.temperature
134+
== prompt_resource.prompt_data.generation_config.temperature
135+
)
136+
assert (
137+
retrieved_prompt.prompt_data.safety_settings
138+
== prompt_resource.prompt_data.safety_settings
139+
)
140+
assert retrieved_prompt.model_name == prompt_resource.model_name
141+
assert (
142+
retrieved_prompt.prompt_data.generation_config.response_schema
143+
== prompt_resource.prompt_data.generation_config.response_schema
144+
)
145+
146+
# Test calling create_version again uses dataset from local Prompt resource.
147+
prompt_resource_2 = client.prompt_management.create_version(
148+
prompt=TEST_PROMPT,
149+
config=types.CreatePromptConfig(
150+
version_display_name="my_version",
151+
),
152+
)
153+
assert prompt_resource_2.dataset.name == prompt_resource.dataset.name
154+
155+
156+
def test_create_version_in_existing_dataset(client):
157+
prompt_resource = client.prompt_management.create_version(
158+
prompt=TEST_PROMPT,
159+
config=types.CreatePromptConfig(
160+
prompt_id=TEST_PROMPT_DATASET_ID,
161+
prompt_display_name=TEST_CONFIG.prompt_display_name,
162+
version_display_name="my_version_existing_dataset",
163+
),
164+
)
165+
assert isinstance(prompt_resource, types.Prompt)
166+
assert isinstance(prompt_resource.dataset, types.Dataset)
167+
assert isinstance(prompt_resource.dataset_version, types.DatasetVersion)
168+
assert prompt_resource.dataset.name.endswith(TEST_PROMPT_DATASET_ID)
169+
170+
171+
def test_create_version_with_version_name(client):
172+
version_name = "a_new_version_yay"
173+
prompt_resource = client.prompt_management.create_version(
174+
prompt=TEST_PROMPT,
175+
config=types.CreatePromptConfig(
176+
version_display_name=version_name,
177+
),
178+
)
179+
assert isinstance(prompt_resource, types.Prompt)
180+
assert isinstance(prompt_resource.dataset, types.Dataset)
181+
assert isinstance(prompt_resource.dataset_version, types.DatasetVersion)
182+
assert prompt_resource.dataset_version.display_name == version_name
183+
184+
185+
def test_create_version_with_inline_data(client):
186+
version_name = "prompt with inline data"
187+
188+
pink_square_blob = genai_types.Part(
189+
inline_data=(
190+
genai_types.Blob(
191+
display_name="pink_square.ppm",
192+
data=b"P3\n1 1\n255\n255 192 203\n",
193+
mime_type="image/x-portable-pixmap",
194+
)
195+
)
196+
)
197+
198+
prompt_resource = client.prompt_management.create_version(
199+
prompt=types.Prompt(
200+
prompt_data=types.PromptData(
201+
contents=[
202+
genai_types.Content(
203+
parts=[
204+
pink_square_blob,
205+
genai_types.Part(text="What is this image?"),
206+
]
207+
)
208+
],
209+
system_instruction=genai_types.Content(
210+
parts=[
211+
genai_types.Part(
212+
text="Answer in the style of Taylor Swift lyrics."
213+
)
214+
]
215+
),
216+
generation_config=genai_types.GenerationConfig(temperature=0.1),
217+
safety_settings=[
218+
genai_types.SafetySetting(
219+
category="HARM_CATEGORY_DANGEROUS_CONTENT",
220+
threshold="BLOCK_MEDIUM_AND_ABOVE",
221+
method="SEVERITY",
222+
)
223+
],
224+
),
225+
model_name="gemini-2.0-flash-001",
226+
),
227+
config=types.CreatePromptConfig(
228+
prompt_display_name=TEST_CONFIG.prompt_display_name,
229+
version_display_name=version_name,
230+
),
231+
)
232+
assert isinstance(prompt_resource, types.Prompt)
233+
assert isinstance(prompt_resource.dataset, types.Dataset)
234+
assert isinstance(prompt_resource.dataset_version, types.DatasetVersion)
235+
assert prompt_resource.dataset_version.display_name == version_name
236+
237+
# Confirm inline data is preserved when we retrieve the prompt.
238+
retrieved_prompt = client.prompt_management.get(
239+
prompt_id=prompt_resource.prompt_id,
240+
)
241+
assert (
242+
retrieved_prompt.prompt_data.contents[0].parts[0].inline_data.data
243+
== pink_square_blob.inline_data.data
244+
)
245+
assert (
246+
retrieved_prompt.prompt_data.contents[0].parts[0].inline_data.display_name
247+
== pink_square_blob.inline_data.display_name
248+
)
249+
250+
251+
def test_create_version_with_file_data(client):
252+
version_name = "prompt with file data"
253+
254+
audio_file_part = genai_types.Part(
255+
file_data=genai_types.FileData(
256+
file_uri="https://generativelanguage.googleapis.com/v1beta/files/57w3vpfomj71",
257+
mime_type="video/mp4",
258+
),
259+
)
260+
261+
prompt_resource = client.prompt_management.create_version(
262+
prompt=types.Prompt(
263+
prompt_data=types.PromptData(
264+
contents=[
265+
genai_types.Content(
266+
parts=[
267+
audio_file_part,
268+
genai_types.Part(text="What is this recording about?"),
269+
]
270+
)
271+
],
272+
system_instruction=genai_types.Content(
273+
parts=[
274+
genai_types.Part(
275+
text="Answer in the style of Taylor Swift lyrics."
276+
)
277+
]
278+
),
279+
generation_config=genai_types.GenerationConfig(temperature=0.1),
280+
safety_settings=[
281+
genai_types.SafetySetting(
282+
category="HARM_CATEGORY_DANGEROUS_CONTENT",
283+
threshold="BLOCK_MEDIUM_AND_ABOVE",
284+
method="SEVERITY",
285+
)
286+
],
287+
),
288+
model_name="gemini-2.0-flash-001",
289+
),
290+
config=types.CreatePromptConfig(
291+
version_display_name=version_name,
292+
prompt_display_name="my prompt with file data",
293+
),
294+
)
295+
assert isinstance(prompt_resource, types.Prompt)
296+
assert isinstance(prompt_resource.dataset, types.Dataset)
297+
assert isinstance(prompt_resource.dataset_version, types.DatasetVersion)
298+
assert prompt_resource.dataset_version.display_name == version_name
299+
300+
# Confirm file data is preserved when we retrieve the prompt.
301+
retrieved_prompt = client.prompt_management.get(
302+
prompt_id=prompt_resource.prompt_id,
303+
)
304+
assert (
305+
retrieved_prompt.prompt_data.contents[0].parts[0].file_data.file_uri
306+
== audio_file_part.file_data.file_uri
307+
)
308+
assert (
309+
retrieved_prompt.prompt_data.contents[0].parts[0].file_data.display_name
310+
== audio_file_part.file_data.display_name
311+
)
312+
313+
314+
def test_prompt_id_overrides_local_prompt(client):
315+
prompt_resource = client.prompt_management.create_version(
316+
prompt=types.Prompt(**TEST_PROMPT.model_dump()),
317+
config=types.CreatePromptConfig(
318+
prompt_id="2966871049100066816",
319+
prompt_display_name=TEST_CONFIG.prompt_display_name,
320+
),
321+
)
322+
323+
# Passing in prompt_id should override prompt_resource.prompt_id
324+
new_prompt_resource = client.prompt_management.create_version(
325+
prompt=types.Prompt(**TEST_PROMPT.model_dump()),
326+
config=types.CreatePromptConfig(
327+
prompt_id=TEST_PROMPT_DATASET_ID,
328+
prompt_display_name=TEST_CONFIG.prompt_display_name,
329+
),
330+
)
331+
assert new_prompt_resource.prompt_id == TEST_PROMPT_DATASET_ID
332+
assert new_prompt_resource.prompt_id != prompt_resource.prompt_id
333+
334+
335+
pytestmark = pytest_helper.setup(
336+
file=__file__,
337+
globals_for_file=globals(),
338+
test_method="prompt_management.create_version",
339+
)
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
# pylint: disable=protected-access,bad-continuation,missing-function-docstring
16+
17+
from tests.unit.vertexai.genai.replays import pytest_helper
18+
from vertexai._genai import types
19+
20+
21+
def test_get_dataset_operation(client):
22+
dataset_operation = client.prompt_management._get_dataset_operation(
23+
config=types.GetDatasetOperationConfig(should_return_http_response=True),
24+
dataset_id="6550997480673116160",
25+
operation_id="5108504762664353792",
26+
)
27+
assert dataset_operation.sdk_http_response.body is not None
28+
29+
30+
pytestmark = pytest_helper.setup(
31+
file=__file__,
32+
globals_for_file=globals(),
33+
test_method="prompt_management._get_dataset_operation",
34+
)

0 commit comments

Comments
 (0)