Skip to content

Commit bfcf355

Browse files
feat: (Gen AI): Add sample for Gemini parallel function calling example (GoogleCloudPlatform#12637)
* Add Python sample for Gemini parallel function calling * Update model version * rename test file to follow standard naming practice * move function_calling tests into a single file * give up to the pylint import-order unconfigured check * refresh CI checks * update region name to follow up the standard schema * update region tag name after a discussion
1 parent 5aab168 commit bfcf355

File tree

3 files changed

+154
-61
lines changed

3 files changed

+154
-61
lines changed

generative_ai/function_calling/examples_function_calling_test.py

Lines changed: 0 additions & 51 deletions
This file was deleted.
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import os
15+
16+
from vertexai.generative_models import ChatSession
17+
18+
PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT")
19+
20+
21+
def parallel_function_calling_example() -> ChatSession:
22+
# [START generativeaionvertexai_function_calling_generate_parallel_calls]
23+
import vertexai
24+
from vertexai.generative_models import (
25+
FunctionDeclaration,
26+
GenerativeModel,
27+
Part,
28+
Tool,
29+
)
30+
31+
# TODO(developer): Update & uncomment below line
32+
# PROJECT_ID = "your-project-id"
33+
34+
# Initialize Vertex AI
35+
vertexai.init(project=PROJECT_ID, location="us-central1")
36+
37+
# Specify a function declaration and parameters for an API request
38+
function_name = "get_current_weather"
39+
get_current_weather_func = FunctionDeclaration(
40+
name=function_name,
41+
description="Get the current weather in a given location",
42+
parameters={
43+
"type": "object",
44+
"properties": {
45+
"location": {
46+
"type": "string",
47+
"description": "The location to whih to get the weather. Can be a city name, a city name and state, or a zip code. Examples: 'San Francisco', 'San Francisco, CA', '95616', etc."
48+
},
49+
},
50+
},
51+
)
52+
53+
# In this example, we'll use synthetic data to simulate a response payload from an external API
54+
def mock_weather_api_service(location: str) -> str:
55+
temperature = 25 if location == "San Francisco" else 35
56+
return f"""{{ "location": "{location}", "temperature": {temperature}, "unit": "C" }}"""
57+
58+
# Define a tool that includes the above function
59+
tools = Tool(
60+
function_declarations=[get_current_weather_func],
61+
)
62+
63+
# Initialize Gemini model
64+
model = GenerativeModel(
65+
model_name="gemini-1.5-pro-002",
66+
tools=[tools],
67+
)
68+
69+
# Start a chat session
70+
chat = model.start_chat()
71+
response = chat.send_message("Get weather details in New Delhi and San Francisco?")
72+
73+
function_calls = response.candidates[0].function_calls
74+
print("Suggested finction calls:\n", function_calls)
75+
76+
if function_calls:
77+
api_responses = []
78+
for func in function_calls:
79+
if func.name == function_name:
80+
api_responses.append({
81+
"content": mock_weather_api_service(location=func.args["location"])
82+
})
83+
84+
# Return the API response to Gemini
85+
response = chat.send_message(
86+
[
87+
Part.from_function_response(
88+
name="get_current_weather",
89+
response=api_responses[0],
90+
),
91+
Part.from_function_response(
92+
name="get_current_weather",
93+
response=api_responses[1],
94+
),
95+
],
96+
)
97+
98+
print(response.text)
99+
# Example response:
100+
# The current weather in New Delhi is 35°C. The current weather in San Francisco is 25°C.
101+
102+
# [END generativeaionvertexai_function_calling_generate_parallel_calls]
103+
return response
104+
105+
106+
if __name__ == "__main__":
107+
parallel_function_calling_example()

generative_ai/function_calling/chat_function_calling_test.py renamed to generative_ai/function_calling/test_function_calling.py

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,23 +11,46 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
import backoff
14+
import advanced_example
1515

16+
import backoff
1617

18+
import basic_example
1719
import chat_example
18-
1920
import chat_function_calling_basic
2021
import chat_function_calling_config
2122

2223
from google.api_core.exceptions import ResourceExhausted
2324

24-
summaries_expected = [
25-
"Pixel 8 Pro",
26-
"stock",
27-
"store",
28-
"2000 N Shoreline Blvd",
29-
"Mountain View",
30-
]
25+
import parallel_function_calling_example
26+
27+
28+
@backoff.on_exception(backoff.expo, ResourceExhausted, max_time=10)
29+
def test_function_calling() -> None:
30+
response = basic_example.generate_function_call()
31+
32+
expected_summary = [
33+
"Boston",
34+
]
35+
expected_responses = [
36+
"candidates",
37+
"content",
38+
"role",
39+
"model",
40+
"parts",
41+
"Boston",
42+
]
43+
assert all(x in str(response.text) for x in expected_summary)
44+
assert all(x in str(response) for x in expected_responses)
45+
46+
47+
@backoff.on_exception(backoff.expo, ResourceExhausted, max_time=10)
48+
def test_function_calling_advanced_function_selection() -> None:
49+
response = advanced_example.generate_function_call_advanced()
50+
assert (
51+
"Pixel 8 Pro 128GB"
52+
in response.candidates[0].function_calls[0].args["product_name"]
53+
)
3154

3255

3356
@backoff.on_exception(backoff.expo, ResourceExhausted, max_time=10)
@@ -48,4 +71,18 @@ def test_function_calling_chat() -> None:
4871

4972
assert chat
5073
assert chat.history
51-
assert any(x in str(chat.history) for x in summaries_expected)
74+
75+
expected_summaries = [
76+
"Pixel 8 Pro",
77+
"stock",
78+
"store",
79+
"2000 N Shoreline Blvd",
80+
"Mountain View",
81+
]
82+
assert any(x in str(chat.history) for x in expected_summaries)
83+
84+
85+
@backoff.on_exception(backoff.expo, ResourceExhausted, max_time=10)
86+
def test_parallel_function_calling() -> None:
87+
response = parallel_function_calling_example.parallel_function_calling_example()
88+
assert response is not None

0 commit comments

Comments
 (0)