Skip to content

Commit 9855e23

Browse files
Add code sample for generative-ai/function-calling-chat that uses the chat modality (GoogleCloudPlatform#11257)
* Initial commit of Gemini function calling chat code sample
* Update test name
* Show example of getting function name, arguments, and calling API
* 🦉 Updates from OwlBot post-processor — see https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md
* Avoid flake8 errors for (intentionally) unused function call arguments
* Simplify sample code and arrays for context
* Fix type annotation for function output
* Rerun CI tests
* Update type annotation of function and test keywords

Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
1 parent cfdbb24 commit 9855e23

File tree

3 files changed

+187
-1
lines changed

3 files changed

+187
-1
lines changed

generative_ai/function_calling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
)
2424

2525

26-
def generate_function_call(prompt: str, project_id: str, location: str) -> str:
26+
def generate_function_call(prompt: str, project_id: str, location: str) -> tuple:
2727
# Initialize Vertex AI
2828
vertexai.init(project=project_id, location=location)
2929

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# [START aiplatform_gemini_function_calling_chat]
16+
import vertexai
17+
from vertexai.generative_models import (
18+
FunctionDeclaration,
19+
GenerativeModel,
20+
Part,
21+
Tool,
22+
)
23+
24+
25+
def generate_function_call_chat(project_id: str, location: str) -> tuple:
    """Demonstrate multi-turn Gemini function calling over a chat session.

    Sends two user prompts. For each turn the model is expected to request a
    function call; we answer it with synthetic API data so the model can
    produce a natural-language summary grounded in that data.

    Args:
        project_id: Google Cloud project that hosts the Vertex AI resources.
        location: Vertex AI region, e.g. "us-central1".

    Returns:
        A ``(prompts, summaries)`` tuple of parallel lists: the user prompts
        sent and the model's text summary for each conversation turn.
    """
    prompts = []
    summaries = []

    # Initialize Vertex AI
    vertexai.init(project=project_id, location=location)

    # Specify a function declaration and parameters for an API request
    get_product_info_func = FunctionDeclaration(
        name="get_product_sku",
        description="Get the SKU for a product",
        # Function parameters are specified in OpenAPI JSON schema format
        parameters={
            "type": "object",
            "properties": {
                "product_name": {"type": "string", "description": "Product name"}
            },
        },
    )

    # Specify another function declaration and parameters for an API request
    get_store_location_func = FunctionDeclaration(
        name="get_store_location",
        description="Get the location of the closest store",
        # Function parameters are specified in OpenAPI JSON schema format
        parameters={
            "type": "object",
            "properties": {"location": {"type": "string", "description": "Location"}},
        },
    )

    # Define a tool that includes the above functions
    retail_tool = Tool(
        function_declarations=[
            get_product_info_func,
            get_store_location_func,
        ],
    )

    # Initialize Gemini model; temperature 0 keeps responses deterministic
    # enough for the keyword checks in the accompanying test.
    model = GenerativeModel(
        "gemini-1.0-pro", generation_config={"temperature": 0}, tools=[retail_tool]
    )

    # Start a chat session
    chat = model.start_chat()

    # Send a prompt for the first conversation turn that should invoke the
    # get_product_sku function
    prompt = "Do you have the Pixel 8 Pro in stock?"
    response = chat.send_message(prompt)
    prompts.append(prompt)

    # Check which function the model responded with, and make an API call to
    # an external system
    function_call = response.candidates[0].content.parts[0].function_call
    if function_call.name == "get_product_sku":
        # Extract the argument to use in your API call, as in:
        # requests.post(product_api_url, data={"product_name": product_name})
        product_name = function_call.args["product_name"]  # noqa: F841

    # In this example, we'll use synthetic data to simulate a response payload
    # from an external API
    api_response = {"sku": "GA04834-US", "in_stock": "yes"}

    # Return the API response to Gemini so it can generate a model response or
    # request another function call
    response = chat.send_message(
        Part.from_function_response(
            name="get_product_sku",
            response={
                "content": api_response,
            },
        ),
    )

    # Extract the text from the summary response
    summary = response.candidates[0].content.parts[0].text
    summaries.append(summary)

    # Send a prompt for the second conversation turn that should invoke the
    # get_store_location function
    prompt = "Is there a store in Mountain View, CA that I can visit to try it out?"
    response = chat.send_message(prompt)
    prompts.append(prompt)

    # Check which function the model responded with, and make an API call to
    # an external system
    function_call = response.candidates[0].content.parts[0].function_call
    if function_call.name == "get_store_location":
        # Extract the argument to use in your API call, as in:
        # requests.post(store_api_url, data={"location": location})
        location = function_call.args["location"]  # noqa: F841

    # In this example, we'll use synthetic data to simulate a response payload
    # from an external API
    api_response = {"store": "2000 N Shoreline Blvd, Mountain View, CA 94043, US"}

    # Return the API response to Gemini so it can generate a model response or
    # request another function call
    response = chat.send_message(
        Part.from_function_response(
            name="get_store_location",
            response={
                "content": api_response,
            },
        ),
    )

    # Extract the text from the summary response
    summary = response.candidates[0].content.parts[0].text
    summaries.append(summary)

    return prompts, summaries


# [END aiplatform_gemini_function_calling_chat]
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
17+
import backoff
18+
from google.api_core.exceptions import ResourceExhausted
19+
20+
import function_calling_chat
21+
22+
23+
_PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT")
_LOCATION = "us-central1"


summaries_expected = [
    "Pixel 8 Pro",
    "stock",
    "store",
    "2000 N Shoreline Blvd",
    "Mountain View",
]


@backoff.on_exception(backoff.expo, ResourceExhausted, max_time=10)
def test_function_calling_chat() -> None:
    """Smoke-test the chat function-calling sample end to end.

    Runs both conversation turns against the live model and checks that every
    expected keyword appears somewhere in the generated summaries. Retries
    with exponential backoff when the project hits its quota.
    """
    _, summaries = function_calling_chat.generate_function_call_chat(
        project_id=_PROJECT_ID,
        location=_LOCATION,
    )
    combined = str(summaries)
    for keyword in summaries_expected:
        assert keyword in combined

0 commit comments

Comments
 (0)