Skip to content

Commit bb3b34f

Browse files
fix(generative-ai): Add tests for function_calling code snippet (GoogleCloudPlatform#11019)
* Add tests for function_calling.py * Add test file * Update test * Move main construct out of region tag * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * Update google-cloud-aiplatform to 1.38.0 * Update sample region tag --------- Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
1 parent 74bcfbd commit bb3b34f

File tree

2 files changed

+76
-31
lines changed

2 files changed

+76
-31
lines changed

generative_ai/function_calling.py

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -12,45 +12,48 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
# [START aiplatform_function_calling]
15+
# [START aiplatform_gemini_function_calling]
1616
from vertexai.preview.generative_models import (
1717
FunctionDeclaration,
1818
GenerativeModel,
1919
Tool,
2020
)
2121

22-
# Load the Vertex AI Gemini API to use function calling
23-
model = GenerativeModel("gemini-pro")
24-
25-
# Specify a function declaration and parameters for an API request
26-
get_current_weather_func = FunctionDeclaration(
27-
name="get_current_weather",
28-
description="Get the current weather in a given location",
29-
# Function parameters are specified in OpenAPI JSON schema format
30-
parameters={
31-
"type": "object",
32-
"properties": {"location": {"type": "string", "description": "Location"}},
33-
},
34-
)
3522

36-
# Define a tool that includes the above get_current_weather_func
37-
weather_tool = Tool(
38-
function_declarations=[get_current_weather_func],
39-
)
23+
def generate_function_call(prompt: str) -> str:
24+
# Load the Vertex AI Gemini API to use function calling
25+
model = GenerativeModel("gemini-pro")
4026

41-
# Prompt to ask the model about weather, which will invoke the Tool
42-
prompt = "What is the weather like in Boston?"
27+
# Specify a function declaration and parameters for an API request
28+
get_current_weather_func = FunctionDeclaration(
29+
name="get_current_weather",
30+
description="Get the current weather in a given location",
31+
# Function parameters are specified in OpenAPI JSON schema format
32+
parameters={
33+
"type": "object",
34+
"properties": {"location": {"type": "string", "description": "Location"}},
35+
},
36+
)
37+
38+
# Define a tool that includes the above get_current_weather_func
39+
weather_tool = Tool(
40+
function_declarations=[get_current_weather_func],
41+
)
42+
43+
# Prompt to ask the model about weather, which will invoke the Tool
44+
prompt = prompt
45+
46+
# Instruct the model to generate content using the Tool that you just created:
47+
response = model.generate_content(
48+
prompt,
49+
generation_config={"temperature": 0},
50+
tools=[weather_tool],
51+
)
52+
53+
return str(response)
4354

44-
# Instruct the model to generate content using the Tool that you just created:
45-
response = model.generate_content(
46-
prompt,
47-
generation_config={"temperature": 0},
48-
tools=[weather_tool],
49-
)
5055

51-
# Print the entire response
52-
print(response)
56+
# [END aiplatform_gemini_function_calling]
5357

54-
# Print the part of the response that contains info about the function call
55-
print(response.candidates[0].content.parts[0].function_call)
56-
# [END aiplatform_function_calling]
58+
if __name__ == "__main__":
59+
generate_function_call("What is the weather like in Boston?")
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Copyright 2023 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
17+
import backoff
18+
from google.api_core.exceptions import ResourceExhausted
19+
20+
import function_calling
21+
22+
23+
_PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT")
24+
_LOCATION = "us-central1"
25+
26+
27+
function_expected_responses = [
28+
"function_call",
29+
"get_current_weather",
30+
"args",
31+
"fields",
32+
"location",
33+
"Boston",
34+
]
35+
36+
37+
@backoff.on_exception(backoff.expo, ResourceExhausted, max_time=10)
38+
def test_interview() -> None:
39+
content = function_calling.generate_function_call(
40+
prompt="What is the weather like in Boston?"
41+
)
42+
assert all(x in content for x in function_expected_responses)

0 commit comments

Comments
 (0)