From c9a955e85dc463cd8e4aec9bf5722202650845f7 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Tue, 21 May 2024 06:29:57 -0700 Subject: [PATCH 01/16] Add genai.protos Change-Id: I21cfada033c6ffbed7a20e117e61582fde925f61 --- docs/build_docs.py | 82 +----- google/generativeai/answer.py | 66 ++--- google/generativeai/discuss.py | 103 ++++---- google/generativeai/embedding.py | 28 +-- google/generativeai/files.py | 8 +- google/generativeai/generative_models.py | 72 +++--- google/generativeai/models.py | 50 ++-- google/generativeai/operations.py | 12 +- google/generativeai/permission.py | 6 +- google/generativeai/responder.py | 74 +++--- google/generativeai/retriever.py | 42 ++-- google/generativeai/text.py | 48 ++-- google/generativeai/types/answer_types.py | 4 +- google/generativeai/types/citation_types.py | 6 +- google/generativeai/types/content_types.py | 138 +++++----- google/generativeai/types/discuss_types.py | 24 +- google/generativeai/types/file_types.py | 20 +- google/generativeai/types/generation_types.py | 60 ++--- google/generativeai/types/model_types.py | 42 ++-- .../generativeai/types/palm_safety_types.py | 134 +++++----- google/generativeai/types/permission_types.py | 60 ++--- google/generativeai/types/retriever_types.py | 236 +++++++++--------- google/generativeai/types/safety_types.py | 106 ++++---- tests/test_answer.py | 112 ++++----- tests/test_client.py | 4 +- tests/test_content.py | 146 +++++------ tests/test_discuss.py | 72 +++--- tests/test_discuss_async.py | 22 +- tests/test_embedding.py | 20 +- tests/test_embedding_async.py | 20 +- tests/test_generation.py | 140 +++++------ tests/test_generative_models.py | 118 ++++----- tests/test_generative_models_async.py | 32 +-- tests/test_helpers.py | 12 +- tests/test_models.py | 118 ++++----- tests/test_operations.py | 18 +- tests/test_permission.py | 70 +++--- tests/test_permission_async.py | 70 +++--- tests/test_responder.py | 58 ++--- tests/test_retriever.py | 140 +++++------ tests/test_retriever_async.py | 134 +++++----- tests/test_text.py | 94 +++---- 42 files changed, 1375 insertions(+), 1446 deletions(-) diff --git a/docs/build_docs.py b/docs/build_docs.py index eaa6a1ba4..280738700 100644 --- a/docs/build_docs.py +++ b/docs/build_docs.py @@ -44,77 +44,13 @@ # For showing the conditional imports and types in `content_types.py` # grpc must be imported first. typing.TYPE_CHECKING = True -from google import generativeai as palm +from google import generativeai as genai from tensorflow_docs.api_generator import generate_lib -from tensorflow_docs.api_generator import public_api import yaml -glm.__doc__ = """\ -This package, `google.ai.generativelanguage`, is a low-level auto-generated client library for the PaLM API. - -```posix-terminal -pip install google.ai.generativelanguage -``` - -It is built using the same tooling as Google Cloud client libraries, and will be quite familiar if you've used -those before. - -While we encourage Python users to access the PaLM API using the `google.generativeai` package (aka `palm`), -this lower level package is also available. - -Each method in the PaLM API is connected to one of the client classes. Pass your API-key to the class' `client_options` -when initializing a client: - -``` -from google.ai import generativelanguage as glm - -client = glm.DiscussServiceClient( - client_options={'api_key':'YOUR_API_KEY'}) -``` - -To call the api, pass an appropriate request-proto-object. 
For the `DiscussServiceClient.generate_message` pass -a `generativelanguage.GenerateMessageRequest` instance: - -``` -request = glm.GenerateMessageRequest( - model='models/chat-bison-001', - prompt=glm.MessagePrompt( - messages=[glm.Message(content='Hello!')])) - -client.generate_message(request) -``` -``` -candidates { - author: "1" - content: "Hello! How can I help you today?" -} -... -``` - -For simplicity: - -* The API methods also accept key-word arguments. -* Anywhere you might pass a proto-object, the library will also accept simple python structures. - -So the following is equivalent to the previous example: - -``` -client.generate_message( - model='models/chat-bison-001', - prompt={'messages':[{'content':'Hello!'}]}) -``` -``` -candidates { - author: "1" - content: "Hello! How can I help you today?" -} -... -``` -""" - HERE = pathlib.Path(__file__).parent PROJECT_SHORT_NAME = "genai" @@ -143,21 +79,11 @@ class MyFilter: def __init__(self, base_dirs): self.filter_base_dirs = public_api.FilterBaseDirs(base_dirs) - def drop_staticmethods(self, parent, children): - parent = dict(parent.__dict__) - for name, value in children: - if not isinstance(parent.get(name, None), staticmethod): - yield name, value - def __call__(self, path, parent, children): if any("generativelanguage" in part for part in path) or "generativeai" in path: children = self.filter_base_dirs(path, parent, children) children = public_api.explicit_package_contents_filter(path, parent, children) - if any("generativelanguage" in part for part in path): - if "ServiceClient" in path[-1] or "ServiceAsyncClient" in path[-1]: - children = list(self.drop_staticmethods(parent, children)) - return children @@ -188,11 +114,11 @@ def gen_api_docs(): """ ) - doc_generator = MyDocGenerator( + doc_generator = generate_lib.DocGenerator( root_title=PROJECT_FULL_NAME, - py_modules=[("google", google)], + py_modules=[("google.generativeai", genai)], base_dir=( - pathlib.Path(palm.__file__).parent, + pathlib.Path(genai.__file__).parent, pathlib.Path(glm.__file__).parent.parent, ), code_url_prefix=( diff --git a/google/generativeai/answer.py b/google/generativeai/answer.py index 1b419be57..eefe6e68d 100644 --- a/google/generativeai/answer.py +++ b/google/generativeai/answer.py @@ -20,7 +20,7 @@ from typing import Any, Iterable, Union, Mapping, Optional from typing_extensions import TypedDict -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai.client import ( get_default_generative_client, @@ -35,7 +35,7 @@ DEFAULT_ANSWER_MODEL = "models/aqa" -AnswerStyle = glm.GenerateAnswerRequest.AnswerStyle +AnswerStyle = protos.GenerateAnswerRequest.AnswerStyle AnswerStyleOptions = Union[int, str, AnswerStyle] @@ -66,28 +66,28 @@ def to_answer_style(x: AnswerStyleOptions) -> AnswerStyle: GroundingPassageOptions = ( - Union[glm.GroundingPassage, tuple[str, content_types.ContentType], content_types.ContentType], + Union[protos.GroundingPassage, tuple[str, content_types.ContentType], content_types.ContentType], ) GroundingPassagesOptions = Union[ - glm.GroundingPassages, + protos.GroundingPassages, Iterable[GroundingPassageOptions], Mapping[str, content_types.ContentType], ] -def _make_grounding_passages(source: GroundingPassagesOptions) -> glm.GroundingPassages: +def _make_grounding_passages(source: GroundingPassagesOptions) -> protos.GroundingPassages: """ - Converts the `source` into a `glm.GroundingPassage`. 
A `GroundingPassages` contains a list of - `glm.GroundingPassage` objects, which each contain a `glm.Contant` and a string `id`. + Converts the `source` into a `protos.GroundingPassage`. A `GroundingPassages` contains a list of + `protos.GroundingPassage` objects, which each contain a `protos.Contant` and a string `id`. Args: - source: `Content` or a `GroundingPassagesOptions` that will be converted to glm.GroundingPassages. + source: `Content` or a `GroundingPassagesOptions` that will be converted to protos.GroundingPassages. Return: - `glm.GroundingPassages` to be passed into `glm.GenerateAnswer`. + `protos.GroundingPassages` to be passed into `protos.GenerateAnswer`. """ - if isinstance(source, glm.GroundingPassages): + if isinstance(source, protos.GroundingPassages): return source if not isinstance(source, Iterable): @@ -100,7 +100,7 @@ def _make_grounding_passages(source: GroundingPassagesOptions) -> glm.GroundingP source = source.items() for n, data in enumerate(source): - if isinstance(data, glm.GroundingPassage): + if isinstance(data, protos.GroundingPassage): passages.append(data) elif isinstance(data, tuple): id, content = data # tuple must have exactly 2 items. @@ -108,11 +108,11 @@ def _make_grounding_passages(source: GroundingPassagesOptions) -> glm.GroundingP else: passages.append({"id": str(n), "content": content_types.to_content(data)}) - return glm.GroundingPassages(passages=passages) + return protos.GroundingPassages(passages=passages) SourceNameType = Union[ - str, retriever_types.Corpus, glm.Corpus, retriever_types.Document, glm.Document + str, retriever_types.Corpus, protos.Corpus, retriever_types.Document, protos.Document ] @@ -127,7 +127,7 @@ class SemanticRetrieverConfigDict(TypedDict): SemanticRetrieverConfigOptions = Union[ SourceNameType, SemanticRetrieverConfigDict, - glm.SemanticRetrieverConfig, + protos.SemanticRetrieverConfig, ] @@ -135,7 +135,7 @@ def _maybe_get_source_name(source) -> str | None: if isinstance(source, str): return source elif isinstance( - source, (retriever_types.Corpus, glm.Corpus, retriever_types.Document, glm.Document) + source, (retriever_types.Corpus, protos.Corpus, retriever_types.Document, protos.Document) ): return source.name else: @@ -145,8 +145,8 @@ def _maybe_get_source_name(source) -> str | None: def _make_semantic_retriever_config( source: SemanticRetrieverConfigOptions, query: content_types.ContentsType, -) -> glm.SemanticRetrieverConfig: - if isinstance(source, glm.SemanticRetrieverConfig): +) -> protos.SemanticRetrieverConfig: + if isinstance(source, protos.SemanticRetrieverConfig): return source name = _maybe_get_source_name(source) @@ -156,7 +156,7 @@ def _make_semantic_retriever_config( source["source"] = _maybe_get_source_name(source["source"]) else: raise TypeError( - "Could create a `glm.SemanticRetrieverConfig` from:\n" + "Could create a `protos.SemanticRetrieverConfig` from:\n" f" type: {type(source)}\n" f" value: {source}" ) @@ -166,7 +166,7 @@ def _make_semantic_retriever_config( elif isinstance(source["query"], str): source["query"] = content_types.to_content(source["query"]) - return glm.SemanticRetrieverConfig(source) + return protos.SemanticRetrieverConfig(source) def _make_generate_answer_request( @@ -178,9 +178,9 @@ def _make_generate_answer_request( answer_style: AnswerStyle | None = None, safety_settings: safety_types.SafetySettingOptions | None = None, temperature: float | None = None, -) -> glm.GenerateAnswerRequest: +) -> protos.GenerateAnswerRequest: """ - constructs a glm.GenerateAnswerRequest object 
by organizing the input parameters for the API call to generate a grounded answer from the model. + constructs a protos.GenerateAnswerRequest object by organizing the input parameters for the API call to generate a grounded answer from the model. Args: model: Name of the model used to generate the grounded response. @@ -188,16 +188,16 @@ def _make_generate_answer_request( single question to answer. For multi-turn queries, this is a repeated field that contains conversation history and the last `Content` in the list containing the question. inline_passages: Grounding passages (a list of `Content`-like objects or `(id, content)` pairs, - or a `glm.GroundingPassages`) to send inline with the request. Exclusive with `semantic_retreiver`, + or a `protos.GroundingPassages`) to send inline with the request. Exclusive with `semantic_retreiver`, one must be set, but not both. - semantic_retriever: A Corpus, Document, or `glm.SemanticRetrieverConfig` to use for grounding. Exclusive with + semantic_retriever: A Corpus, Document, or `protos.SemanticRetrieverConfig` to use for grounding. Exclusive with `inline_passages`, one must be set, but not both. answer_style: Style for grounded answers. safety_settings: Safety settings for generated output. temperature: The temperature for randomness in the output. Returns: - Call for glm.GenerateAnswerRequest(). + Call for protos.GenerateAnswerRequest(). """ model = model_types.make_model_name(model) @@ -222,7 +222,7 @@ def _make_generate_answer_request( if answer_style: answer_style = to_answer_style(answer_style) - return glm.GenerateAnswerRequest( + return protos.GenerateAnswerRequest( model=model, contents=contents, inline_passages=inline_passages, @@ -242,7 +242,7 @@ def generate_answer( answer_style: AnswerStyle | None = None, safety_settings: safety_types.SafetySettingOptions | None = None, temperature: float | None = None, - client: glm.GenerativeServiceClient | None = None, + client: protos.GenerativeServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ): """ @@ -272,14 +272,14 @@ def generate_answer( contents: The question to be answered by the model, grounded in the provided source. inline_passages: Grounding passages (a list of `Content`-like objects or (id, content) pairs, - or a `glm.GroundingPassages`) to send inline with the request. Exclusive with `semantic_retreiver`, + or a `protos.GroundingPassages`) to send inline with the request. Exclusive with `semantic_retreiver`, one must be set, but not both. - semantic_retriever: A Corpus, Document, or `glm.SemanticRetrieverConfig` to use for grounding. Exclusive with + semantic_retriever: A Corpus, Document, or `protos.SemanticRetrieverConfig` to use for grounding. Exclusive with `inline_passages`, one must be set, but not both. answer_style: Style in which the grounded answer should be returned. safety_settings: Safety settings for generated output. Defaults to None. temperature: Controls the randomness of the output. - client: If you're not relying on a default client, you pass a `glm.TextServiceClient` instead. + client: If you're not relying on a default client, you pass a `protos.TextServiceClient` instead. request_options: Options for the request. 
Returns: @@ -315,7 +315,7 @@ async def generate_answer_async( answer_style: AnswerStyle | None = None, safety_settings: safety_types.SafetySettingOptions | None = None, temperature: float | None = None, - client: glm.GenerativeServiceClient | None = None, + client: protos.GenerativeServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ): """ @@ -326,14 +326,14 @@ async def generate_answer_async( contents: The question to be answered by the model, grounded in the provided source. inline_passages: Grounding passages (a list of `Content`-like objects or (id, content) pairs, - or a `glm.GroundingPassages`) to send inline with the request. Exclusive with `semantic_retreiver`, + or a `protos.GroundingPassages`) to send inline with the request. Exclusive with `semantic_retreiver`, one must be set, but not both. - semantic_retriever: A Corpus, Document, or `glm.SemanticRetrieverConfig` to use for grounding. Exclusive with + semantic_retriever: A Corpus, Document, or `protos.SemanticRetrieverConfig` to use for grounding. Exclusive with `inline_passages`, one must be set, but not both. answer_style: Style in which the grounded answer should be returned. safety_settings: Safety settings for generated output. Defaults to None. temperature: Controls the randomness of the output. - client: If you're not relying on a default client, you pass a `glm.TextServiceClient` instead. + client: If you're not relying on a default client, you pass a `protos.TextServiceClient` instead. Returns: A `types.Answer` containing the model's text answer response. diff --git a/google/generativeai/discuss.py b/google/generativeai/discuss.py index 35611ae69..4cea08421 100644 --- a/google/generativeai/discuss.py +++ b/google/generativeai/discuss.py @@ -18,37 +18,38 @@ import sys import textwrap -from typing import Any, Iterable, List, Optional, Union +from typing import Iterable, List import google.ai.generativelanguage as glm from google.generativeai.client import get_default_discuss_client from google.generativeai.client import get_default_discuss_async_client from google.generativeai import string_utils +from google.generativeai import protos from google.generativeai.types import discuss_types from google.generativeai.types import helper_types from google.generativeai.types import model_types from google.generativeai.types import palm_safety_types -def _make_message(content: discuss_types.MessageOptions) -> glm.Message: - """Creates a `glm.Message` object from the provided content.""" - if isinstance(content, glm.Message): +def _make_message(content: discuss_types.MessageOptions) -> protos.Message: + """Creates a `protos.Message` object from the provided content.""" + if isinstance(content, protos.Message): return content if isinstance(content, str): - return glm.Message(content=content) + return protos.Message(content=content) else: - return glm.Message(content) + return protos.Message(content) def _make_messages( messages: discuss_types.MessagesOptions, -) -> List[glm.Message]: +) -> List[protos.Message]: """ - Creates a list of `glm.Message` objects from the provided messages. + Creates a list of `protos.Message` objects from the provided messages. This function takes a variety of message content inputs, such as strings, dictionaries, - or `glm.Message` objects, and creates a list of `glm.Message` objects. It ensures that + or `protos.Message` objects, and creates a list of `protos.Message` objects. It ensures that the authors of the messages alternate appropriately. 
If authors are not provided, default authors are assigned based on their position in the list. @@ -56,9 +57,9 @@ def _make_messages( messages: The messages to convert. Returns: - A list of `glm.Message` objects with alternating authors. + A list of `protos.Message` objects with alternating authors. """ - if isinstance(messages, (str, dict, glm.Message)): + if isinstance(messages, (str, dict, protos.Message)): messages = [_make_message(messages)] else: messages = [_make_message(message) for message in messages] @@ -89,39 +90,39 @@ def _make_messages( return messages -def _make_example(item: discuss_types.ExampleOptions) -> glm.Example: - """Creates a `glm.Example` object from the provided item.""" - if isinstance(item, glm.Example): +def _make_example(item: discuss_types.ExampleOptions) -> protos.Example: + """Creates a `protos.Example` object from the provided item.""" + if isinstance(item, protos.Example): return item if isinstance(item, dict): item = item.copy() item["input"] = _make_message(item["input"]) item["output"] = _make_message(item["output"]) - return glm.Example(item) + return protos.Example(item) if isinstance(item, Iterable): input, output = list(item) - return glm.Example(input=_make_message(input), output=_make_message(output)) + return protos.Example(input=_make_message(input), output=_make_message(output)) # try anyway - return glm.Example(item) + return protos.Example(item) def _make_examples_from_flat( examples: List[discuss_types.MessageOptions], -) -> List[glm.Example]: +) -> List[protos.Example]: """ - Creates a list of `glm.Example` objects from a list of message options. + Creates a list of `protos.Example` objects from a list of message options. This function takes a list of `discuss_types.MessageOptions` and pairs them into - `glm.Example` objects. The input examples must be in pairs to create valid examples. + `protos.Example` objects. The input examples must be in pairs to create valid examples. Args: examples: The list of `discuss_types.MessageOptions`. Returns: - A list of `glm.Example objects` created by pairing up the provided messages. + A list of `protos.Example objects` created by pairing up the provided messages. Raises: ValueError: If the provided list of examples is not of even length. @@ -141,7 +142,7 @@ def _make_examples_from_flat( pair.append(msg) if n % 2 == 0: continue - primer = glm.Example( + primer = protos.Example( input=pair[0], output=pair[1], ) @@ -152,21 +153,21 @@ def _make_examples_from_flat( def _make_examples( examples: discuss_types.ExamplesOptions, -) -> List[glm.Example]: +) -> List[protos.Example]: """ - Creates a list of `glm.Example` objects from the provided examples. + Creates a list of `protos.Example` objects from the provided examples. This function takes various types of example content inputs and creates a list - of `glm.Example` objects. It handles the conversion of different input types and ensures + of `protos.Example` objects. It handles the conversion of different input types and ensures the appropriate structure for creating valid examples. Args: examples: The examples to convert. Returns: - A list of `glm.Example` objects created from the provided examples. + A list of `protos.Example` objects created from the provided examples. 
""" - if isinstance(examples, glm.Example): + if isinstance(examples, protos.Example): return [examples] if isinstance(examples, dict): @@ -204,11 +205,11 @@ def _make_message_prompt_dict( context: str | None = None, examples: discuss_types.ExamplesOptions | None = None, messages: discuss_types.MessagesOptions | None = None, -) -> glm.MessagePrompt: +) -> protos.MessagePrompt: """ - Creates a `glm.MessagePrompt` object from the provided prompt components. + Creates a `protos.MessagePrompt` object from the provided prompt components. - This function constructs a `glm.MessagePrompt` object using the provided `context`, `examples`, + This function constructs a `protos.MessagePrompt` object using the provided `context`, `examples`, or `messages`. It ensures the proper structure and handling of the input components. Either pass a `prompt` or it's component `context`, `examples`, `messages`. @@ -220,7 +221,7 @@ def _make_message_prompt_dict( messages: The messages for the prompt. Returns: - A `glm.MessagePrompt` object created from the provided prompt components. + A `protos.MessagePrompt` object created from the provided prompt components. """ if prompt is None: prompt = dict( @@ -235,7 +236,7 @@ def _make_message_prompt_dict( "You can't set `prompt`, and its fields `(context, examples, messages)`" " at the same time" ) - if isinstance(prompt, glm.MessagePrompt): + if isinstance(prompt, protos.MessagePrompt): return prompt elif isinstance(prompt, dict): # Always check dict before Iterable. pass @@ -265,12 +266,12 @@ def _make_message_prompt( context: str | None = None, examples: discuss_types.ExamplesOptions | None = None, messages: discuss_types.MessagesOptions | None = None, -) -> glm.MessagePrompt: - """Creates a `glm.MessagePrompt` object from the provided prompt components.""" +) -> protos.MessagePrompt: + """Creates a `protos.MessagePrompt` object from the provided prompt components.""" prompt = _make_message_prompt_dict( prompt=prompt, context=context, examples=examples, messages=messages ) - return glm.MessagePrompt(prompt) + return protos.MessagePrompt(prompt) def _make_generate_message_request( @@ -284,15 +285,15 @@ def _make_generate_message_request( top_p: float | None = None, top_k: float | None = None, prompt: discuss_types.MessagePromptOptions | None = None, -) -> glm.GenerateMessageRequest: - """Creates a `glm.GenerateMessageRequest` object for generating messages.""" +) -> protos.GenerateMessageRequest: + """Creates a `protos.GenerateMessageRequest` object for generating messages.""" model = model_types.make_model_name(model) prompt = _make_message_prompt( prompt=prompt, context=context, examples=examples, messages=messages ) - return glm.GenerateMessageRequest( + return protos.GenerateMessageRequest( model=model, prompt=prompt, temperature=temperature, @@ -316,7 +317,7 @@ def chat( top_p: float | None = None, top_k: float | None = None, prompt: discuss_types.MessagePromptOptions | None = None, - client: glm.DiscussServiceClient | None = None, + client: protos.DiscussServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> discuss_types.ChatResponse: """Calls the API and returns a `types.ChatResponse` containing the response. @@ -383,7 +384,7 @@ def chat( prompt: You may pass a `types.MessagePromptOptions` **instead** of a setting `context`/`examples`/`messages`, but not both. client: If you're not relying on the default client, you pass a - `glm.DiscussServiceClient` instead. + `protos.DiscussServiceClient` instead. 
request_options: Options for the request. Returns: @@ -416,7 +417,7 @@ async def chat_async( top_p: float | None = None, top_k: float | None = None, prompt: discuss_types.MessagePromptOptions | None = None, - client: glm.DiscussServiceAsyncClient | None = None, + client: protos.DiscussServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> discuss_types.ChatResponse: request = _make_generate_message_request( @@ -446,7 +447,7 @@ async def chat_async( @string_utils.set_doc(discuss_types.ChatResponse.__doc__) @dataclasses.dataclass(**DATACLASS_KWARGS, init=False) class ChatResponse(discuss_types.ChatResponse): - _client: glm.DiscussServiceClient | None = dataclasses.field(default=lambda: None, repr=False) + _client: protos.DiscussServiceClient | None = dataclasses.field(default=lambda: None, repr=False) def __init__(self, **kwargs): for key, value in kwargs.items(): @@ -495,7 +496,7 @@ def reply( async def reply_async( self, message: discuss_types.MessageOptions ) -> discuss_types.ChatResponse: - if isinstance(self._client, glm.DiscussServiceClient): + if isinstance(self._client, protos.DiscussServiceClient): raise TypeError( f"reply_async can't be called on a non-async client, use reply instead." ) @@ -509,9 +510,9 @@ async def reply_async( def _build_chat_response( - request: glm.GenerateMessageRequest, - response: glm.GenerateMessageResponse, - client: glm.DiscussServiceClient | glm.DiscussServiceAsyncClient, + request: protos.GenerateMessageRequest, + response: protos.GenerateMessageResponse, + client: protos.DiscussServiceClient | protos.DiscussServiceAsyncClient, ) -> ChatResponse: request = type(request).to_dict(request) prompt = request.pop("prompt") @@ -536,8 +537,8 @@ def _build_chat_response( def _generate_response( - request: glm.GenerateMessageRequest, - client: glm.DiscussServiceClient | None = None, + request: protos.GenerateMessageRequest, + client: protos.DiscussServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> ChatResponse: if request_options is None: @@ -552,8 +553,8 @@ def _generate_response( async def _generate_response_async( - request: glm.GenerateMessageRequest, - client: glm.DiscussServiceAsyncClient | None = None, + request: protos.GenerateMessageRequest, + client: protos.DiscussServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> ChatResponse: if request_options is None: @@ -574,7 +575,7 @@ def count_message_tokens( examples: discuss_types.ExamplesOptions | None = None, messages: discuss_types.MessagesOptions | None = None, model: model_types.AnyModelNameOptions = DEFAULT_DISCUSS_MODEL, - client: glm.DiscussServiceAsyncClient | None = None, + client: protos.DiscussServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> discuss_types.TokenCount: model = model_types.make_model_name(model) diff --git a/google/generativeai/embedding.py b/google/generativeai/embedding.py index 14fff1737..b62e6b450 100644 --- a/google/generativeai/embedding.py +++ b/google/generativeai/embedding.py @@ -17,7 +17,7 @@ import itertools from typing import Any, Iterable, overload, TypeVar, Union, Mapping -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai.client import get_default_generative_client from google.generativeai.client import get_default_generative_async_client @@ -30,7 +30,7 @@ DEFAULT_EMB_MODEL = "models/embedding-001" 
EMBEDDING_MAX_BATCH_SIZE = 100 -EmbeddingTaskType = glm.TaskType +EmbeddingTaskType = protos.TaskType EmbeddingTaskTypeOptions = Union[int, str, EmbeddingTaskType] @@ -101,7 +101,7 @@ def embed_content( task_type: EmbeddingTaskTypeOptions | None = None, title: str | None = None, output_dimensionality: int | None = None, - client: glm.GenerativeServiceClient | None = None, + client: protos.GenerativeServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.EmbeddingDict: ... @@ -113,7 +113,7 @@ def embed_content( task_type: EmbeddingTaskTypeOptions | None = None, title: str | None = None, output_dimensionality: int | None = None, - client: glm.GenerativeServiceClient | None = None, + client: protos.GenerativeServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.BatchEmbeddingDict: ... @@ -124,7 +124,7 @@ def embed_content( task_type: EmbeddingTaskTypeOptions | None = None, title: str | None = None, output_dimensionality: int | None = None, - client: glm.GenerativeServiceClient = None, + client: protos.GenerativeServiceClient = None, request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict: """Calls the API to create embeddings for content passed in. @@ -179,7 +179,7 @@ def embed_content( if isinstance(content, Iterable) and not isinstance(content, (str, Mapping)): result = {"embedding": []} requests = ( - glm.EmbedContentRequest( + protos.EmbedContentRequest( model=model, content=content_types.to_content(c), task_type=task_type, @@ -189,7 +189,7 @@ def embed_content( for c in content ) for batch in _batched(requests, EMBEDDING_MAX_BATCH_SIZE): - embedding_request = glm.BatchEmbedContentsRequest(model=model, requests=batch) + embedding_request = protos.BatchEmbedContentsRequest(model=model, requests=batch) embedding_response = client.batch_embed_contents( embedding_request, **request_options, @@ -198,7 +198,7 @@ def embed_content( result["embedding"].extend(e["values"] for e in embedding_dict["embeddings"]) return result else: - embedding_request = glm.EmbedContentRequest( + embedding_request = protos.EmbedContentRequest( model=model, content=content_types.to_content(content), task_type=task_type, @@ -221,7 +221,7 @@ async def embed_content_async( task_type: EmbeddingTaskTypeOptions | None = None, title: str | None = None, output_dimensionality: int | None = None, - client: glm.GenerativeServiceAsyncClient | None = None, + client: protos.GenerativeServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.EmbeddingDict: ... @@ -233,7 +233,7 @@ async def embed_content_async( task_type: EmbeddingTaskTypeOptions | None = None, title: str | None = None, output_dimensionality: int | None = None, - client: glm.GenerativeServiceAsyncClient | None = None, + client: protos.GenerativeServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.BatchEmbeddingDict: ... 
@@ -244,7 +244,7 @@ async def embed_content_async( task_type: EmbeddingTaskTypeOptions | None = None, title: str | None = None, output_dimensionality: int | None = None, - client: glm.GenerativeServiceAsyncClient = None, + client: protos.GenerativeServiceAsyncClient = None, request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict: """The async version of `genai.embed_content`.""" @@ -270,7 +270,7 @@ async def embed_content_async( if isinstance(content, Iterable) and not isinstance(content, (str, Mapping)): result = {"embedding": []} requests = ( - glm.EmbedContentRequest( + protos.EmbedContentRequest( model=model, content=content_types.to_content(c), task_type=task_type, @@ -280,7 +280,7 @@ async def embed_content_async( for c in content ) for batch in _batched(requests, EMBEDDING_MAX_BATCH_SIZE): - embedding_request = glm.BatchEmbedContentsRequest(model=model, requests=batch) + embedding_request = protos.BatchEmbedContentsRequest(model=model, requests=batch) embedding_response = await client.batch_embed_contents( embedding_request, **request_options, @@ -289,7 +289,7 @@ async def embed_content_async( result["embedding"].extend(e["values"] for e in embedding_dict["embeddings"]) return result else: - embedding_request = glm.EmbedContentRequest( + embedding_request = protos.EmbedContentRequest( model=model, content=content_types.to_content(content), task_type=task_type, diff --git a/google/generativeai/files.py b/google/generativeai/files.py index 13535c47f..03dce3466 100644 --- a/google/generativeai/files.py +++ b/google/generativeai/files.py @@ -19,7 +19,7 @@ import mimetypes from typing import Iterable import logging -import google.ai.generativelanguage as glm +from google.generativeai import protos from itertools import islice from google.generativeai.types import file_types @@ -75,7 +75,7 @@ def upload_file( def list_files(page_size=100) -> Iterable[file_types.File]: client = get_default_file_client() - response = client.list_files(glm.ListFilesRequest(page_size=page_size)) + response = client.list_files(protos.ListFilesRequest(page_size=page_size)) for proto in response: yield file_types.File(proto) @@ -86,8 +86,8 @@ def get_file(name) -> file_types.File: def delete_file(name): - if isinstance(name, (file_types.File, glm.File)): + if isinstance(name, (file_types.File, protos.File)): name = name.name - request = glm.DeleteFileRequest(name=name) + request = protos.DeleteFileRequest(name=name) client = get_default_file_client() client.delete_file(request=request) diff --git a/google/generativeai/generative_models.py b/google/generativeai/generative_models.py index 86b87ee90..cea5b444e 100644 --- a/google/generativeai/generative_models.py +++ b/google/generativeai/generative_models.py @@ -13,7 +13,7 @@ import google.api_core.exceptions -from google.ai import generativelanguage as glm +from google.generativeai import protos from google.generativeai import client from google.generativeai.types import content_types from google.generativeai.types import generation_types @@ -127,8 +127,8 @@ def _prepare_request( safety_settings: safety_types.SafetySettingOptions | None = None, tools: content_types.FunctionLibraryType | None, tool_config: content_types.ToolConfigType | None, - ) -> glm.GenerateContentRequest: - """Creates a `glm.GenerateContentRequest` from raw inputs.""" + ) -> protos.GenerateContentRequest: + """Creates a `protos.GenerateContentRequest` from raw inputs.""" if not contents: raise TypeError("contents must not 
be empty") @@ -152,7 +152,7 @@ def _prepare_request( merged_ss.update(safety_settings) merged_ss = safety_types.normalize_safety_settings(merged_ss) - return glm.GenerateContentRequest( + return protos.GenerateContentRequest( model=self._model_name, contents=contents, generation_config=merged_gc, @@ -214,25 +214,25 @@ def generate_content( ### Input type flexibility - While the underlying API strictly expects a `list[glm.Content]` objects, this method + While the underlying API strictly expects a `list[protos.Content]` objects, this method will convert the user input into the correct type. The hierarchy of types that can be converted is below. Any of these objects can be passed as an equivalent `dict`. - * `Iterable[glm.Content]` - * `glm.Content` - * `Iterable[glm.Part]` - * `glm.Part` - * `str`, `Image`, or `glm.Blob` + * `Iterable[protos.Content]` + * `protos.Content` + * `Iterable[protos.Part]` + * `protos.Part` + * `str`, `Image`, or `protos.Blob` - In an `Iterable[glm.Content]` each `content` is a separate message. - But note that an `Iterable[glm.Part]` is taken as the parts of a single message. + In an `Iterable[protos.Content]` each `content` is a separate message. + But note that an `Iterable[protos.Part]` is taken as the parts of a single message. Arguments: contents: The contents serving as the model's prompt. generation_config: Overrides for the model's generation config. safety_settings: Overrides for the model's safety settings. stream: If True, yield response chunks as they are generated. - tools: `glm.Tools` more info coming soon. + tools: `protos.Tools` more info coming soon. request_options: Options for the request. """ request = self._prepare_request( @@ -327,14 +327,14 @@ def count_tokens( tools: content_types.FunctionLibraryType | None = None, tool_config: content_types.ToolConfigType | None = None, request_options: helper_types.RequestOptionsType | None = None, - ) -> glm.CountTokensResponse: + ) -> protos.CountTokensResponse: if request_options is None: request_options = {} if self._client is None: self._client = client.get_default_generative_client() - request = glm.CountTokensRequest( + request = protos.CountTokensRequest( model=self.model_name, generate_content_request=self._prepare_request( contents=contents, @@ -354,14 +354,14 @@ async def count_tokens_async( tools: content_types.FunctionLibraryType | None = None, tool_config: content_types.ToolConfigType | None = None, request_options: helper_types.RequestOptionsType | None = None, - ) -> glm.CountTokensResponse: + ) -> protos.CountTokensResponse: if request_options is None: request_options = {} if self._async_client is None: self._async_client = client.get_default_generative_async_client() - request = glm.CountTokensRequest( + request = protos.CountTokensRequest( model=self.model_name, generate_content_request=self._prepare_request( contents=contents, @@ -387,7 +387,7 @@ def start_chat( >>> response = chat.send_message("Hello?") Arguments: - history: An iterable of `glm.Content` objects, or equivalents to initialize the session. + history: An iterable of `protos.Content` objects, or equivalents to initialize the session. 
""" if self._generation_config.get("candidate_count", 1) > 1: raise ValueError("Can't chat with `candidate_count > 1`") @@ -427,8 +427,8 @@ def __init__( enable_automatic_function_calling: bool = False, ): self.model: GenerativeModel = model - self._history: list[glm.Content] = content_types.to_contents(history) - self._last_sent: glm.Content | None = None + self._history: list[protos.Content] = content_types.to_contents(history) + self._last_sent: protos.Content | None = None self._last_received: generation_types.BaseGenerateContentResponse | None = None self.enable_automatic_function_calling = enable_automatic_function_calling @@ -525,13 +525,13 @@ def _check_response(self, *, response, stream): if not stream: if response.candidates[0].finish_reason not in ( - glm.Candidate.FinishReason.FINISH_REASON_UNSPECIFIED, - glm.Candidate.FinishReason.STOP, - glm.Candidate.FinishReason.MAX_TOKENS, + protos.Candidate.FinishReason.FINISH_REASON_UNSPECIFIED, + protos.Candidate.FinishReason.STOP, + protos.Candidate.FinishReason.MAX_TOKENS, ): raise generation_types.StopCandidateException(response.candidates[0]) - def _get_function_calls(self, response) -> list[glm.FunctionCall]: + def _get_function_calls(self, response) -> list[protos.FunctionCall]: candidates = response.candidates if len(candidates) != 1: raise ValueError( @@ -543,14 +543,14 @@ def _get_function_calls(self, response) -> list[glm.FunctionCall]: def _handle_afc( self, *, response, history, generation_config, safety_settings, stream, tools_lib - ) -> tuple[list[glm.Content], glm.Content, generation_types.BaseGenerateContentResponse]: + ) -> tuple[list[protos.Content], protos.Content, generation_types.BaseGenerateContentResponse]: while function_calls := self._get_function_calls(response): if not all(callable(tools_lib[fc]) for fc in function_calls): break history.append(response.candidates[0].content) - function_response_parts: list[glm.Part] = [] + function_response_parts: list[protos.Part] = [] for fc in function_calls: fr = tools_lib(fc) assert fr is not None, ( @@ -559,7 +559,7 @@ def _handle_afc( ) function_response_parts.append(fr) - send = glm.Content(role=self._USER_ROLE, parts=function_response_parts) + send = protos.Content(role=self._USER_ROLE, parts=function_response_parts) history.append(send) response = self.model.generate_content( @@ -634,14 +634,14 @@ async def send_message_async( async def _handle_afc_async( self, *, response, history, generation_config, safety_settings, stream, tools_lib - ) -> tuple[list[glm.Content], glm.Content, generation_types.BaseGenerateContentResponse]: + ) -> tuple[list[protos.Content], protos.Content, generation_types.BaseGenerateContentResponse]: while function_calls := self._get_function_calls(response): if not all(callable(tools_lib[fc]) for fc in function_calls): break history.append(response.candidates[0].content) - function_response_parts: list[glm.Part] = [] + function_response_parts: list[protos.Part] = [] for fc in function_calls: fr = tools_lib(fc) assert fr is not None, ( @@ -650,7 +650,7 @@ async def _handle_afc_async( ) function_response_parts.append(fr) - send = glm.Content(role=self._USER_ROLE, parts=function_response_parts) + send = protos.Content(role=self._USER_ROLE, parts=function_response_parts) history.append(send) response = await self.model.generate_content_async( @@ -673,7 +673,7 @@ def __copy__(self): history=list(self.history), ) - def rewind(self) -> tuple[glm.Content, glm.Content]: + def rewind(self) -> tuple[protos.Content, protos.Content]: """Removes the last 
request/response pair from the chat history.""" if self._last_received is None: result = self._history.pop(-2), self._history.pop() @@ -690,16 +690,16 @@ def last(self) -> generation_types.BaseGenerateContentResponse | None: return self._last_received @property - def history(self) -> list[glm.Content]: + def history(self) -> list[protos.Content]: """The chat history.""" last = self._last_received if last is None: return self._history if last.candidates[0].finish_reason not in ( - glm.Candidate.FinishReason.FINISH_REASON_UNSPECIFIED, - glm.Candidate.FinishReason.STOP, - glm.Candidate.FinishReason.MAX_TOKENS, + protos.Candidate.FinishReason.FINISH_REASON_UNSPECIFIED, + protos.Candidate.FinishReason.STOP, + protos.Candidate.FinishReason.MAX_TOKENS, ): error = generation_types.StopCandidateException(last.candidates[0]) last._error = error @@ -737,7 +737,7 @@ def __repr__(self) -> str: _model = str(self.model).replace("\n", "\n" + " " * 4) def content_repr(x): - return f"glm.Content({_dict_repr.repr(type(x).to_dict(x))})" + return f"protos.Content({_dict_repr.repr(type(x).to_dict(x))})" try: history = list(self.history) diff --git a/google/generativeai/models.py b/google/generativeai/models.py index f25be57c6..3d914dfa7 100644 --- a/google/generativeai/models.py +++ b/google/generativeai/models.py @@ -17,7 +17,7 @@ import typing from typing import Any, Literal -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import operations from google.generativeai.client import get_default_model_client from google.generativeai.types import model_types @@ -137,7 +137,7 @@ def get_tuned_model( def get_base_model_name( - model: model_types.AnyModelNameOptions, client: glm.ModelServiceClient | None = None + model: model_types.AnyModelNameOptions, client: protos.ModelServiceClient | None = None ): if isinstance(model, str): if model.startswith("tunedModels/"): @@ -149,9 +149,9 @@ def get_base_model_name( base_model = model.base_model elif isinstance(model, model_types.Model): base_model = model.name - elif isinstance(model, glm.Model): + elif isinstance(model, protos.Model): base_model = model.name - elif isinstance(model, glm.TunedModel): + elif isinstance(model, protos.TunedModel): base_model = getattr(model, "base_model", None) if not base_model: base_model = model.tuned_model_source.base_model @@ -164,7 +164,7 @@ def get_base_model_name( def list_models( *, page_size: int | None = 50, - client: glm.ModelServiceClient | None = None, + client: protos.ModelServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> model_types.ModelsIterable: """Lists available models. @@ -177,7 +177,7 @@ def list_models( Args: page_size: How many `types.Models` to fetch per page (api call). - client: You may pass a `glm.ModelServiceClient` instead of using the default client. + client: You may pass a `protos.ModelServiceClient` instead of using the default client. request_options: Options for the request. Yields: @@ -198,7 +198,7 @@ def list_models( def list_tuned_models( *, page_size: int | None = 50, - client: glm.ModelServiceClient | None = None, + client: protos.ModelServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> model_types.TunedModelsIterable: """Lists available models. @@ -211,7 +211,7 @@ def list_tuned_models( Args: page_size: How many `types.Models` to fetch per page (api call). - client: You may pass a `glm.ModelServiceClient` instead of using the default client. 
+ client: You may pass a `protos.ModelServiceClient` instead of using the default client. request_options: Options for the request. Yields: @@ -246,7 +246,7 @@ def create_tuned_model( learning_rate: float | None = None, input_key: str = "text_input", output_key: str = "output", - client: glm.ModelServiceClient | None = None, + client: protos.ModelServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> operations.CreateTunedModelOperation: """Launches a tuning job to create a TunedModel. @@ -273,9 +273,9 @@ def create_tuned_model( Args: source_model: The name of the model to tune. training_data: The dataset to tune the model on. This must be either: - * A `glm.Dataset`, or + * A `protos.Dataset`, or * An `Iterable` of: - *`glm.TuningExample`, + *`protos.TuningExample`, * `{'text_input': text_input, 'output': output}` dicts * `(text_input, output)` tuples. * A `Mapping` of `Iterable[str]` - use `input_key` and `output_key` to choose which @@ -328,17 +328,17 @@ def create_tuned_model( training_data, input_key=input_key, output_key=output_key ) - hyperparameters = glm.Hyperparameters( + hyperparameters = protos.Hyperparameters( epoch_count=epoch_count, batch_size=batch_size, learning_rate=learning_rate, ) - tuning_task = glm.TuningTask( + tuning_task = protos.TuningTask( training_data=training_data, hyperparameters=hyperparameters, ) - tuned_model = glm.TunedModel( + tuned_model = protos.TunedModel( **source_model, display_name=display_name, description=description, @@ -357,10 +357,10 @@ def create_tuned_model( @typing.overload def update_tuned_model( - tuned_model: glm.TunedModel, + tuned_model: protos.TunedModel, updates: None = None, *, - client: glm.ModelServiceClient | None = None, + client: protos.ModelServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> model_types.TunedModel: pass @@ -371,17 +371,17 @@ def update_tuned_model( tuned_model: str, updates: dict[str, Any], *, - client: glm.ModelServiceClient | None = None, + client: protos.ModelServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> model_types.TunedModel: pass def update_tuned_model( - tuned_model: str | glm.TunedModel, + tuned_model: str | protos.TunedModel, updates: dict[str, Any] | None = None, *, - client: glm.ModelServiceClient | None = None, + client: protos.ModelServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> model_types.TunedModel: """Push updates to the tuned model. Only certain attributes are updatable.""" @@ -408,10 +408,10 @@ def update_tuned_model( field_mask.paths.append(path) for path, value in updates.items(): _apply_update(tuned_model, path, value) - elif isinstance(tuned_model, glm.TunedModel): + elif isinstance(tuned_model, protos.TunedModel): if updates is not None: raise ValueError( - "When calling `update_tuned_model(tuned_model:glm.TunedModel, updates=None)`," + "When calling `update_tuned_model(tuned_model:protos.TunedModel, updates=None)`," "`updates` must not be set." ) @@ -420,12 +420,12 @@ def update_tuned_model( field_mask = protobuf_helpers.field_mask(was._pb, tuned_model._pb) else: raise TypeError( - "For `update_tuned_model(tuned_model:dict|glm.TunedModel)`," - f"`tuned_model` must be a `dict` or a `glm.TunedModel`. Got a: `{type(tuned_model)}`" + "For `update_tuned_model(tuned_model:dict|protos.TunedModel)`," + f"`tuned_model` must be a `dict` or a `protos.TunedModel`. 
Got a: `{type(tuned_model)}`" ) result = client.update_tuned_model( - glm.UpdateTunedModelRequest(tuned_model=tuned_model, update_mask=field_mask), + protos.UpdateTunedModelRequest(tuned_model=tuned_model, update_mask=field_mask), **request_options, ) return model_types.decode_tuned_model(result) @@ -440,7 +440,7 @@ def _apply_update(thing, path, value): def delete_tuned_model( tuned_model: model_types.TunedModelNameOptions, - client: glm.ModelServiceClient | None = None, + client: protos.ModelServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> None: if request_options is None: diff --git a/google/generativeai/operations.py b/google/generativeai/operations.py index d492a9dee..67fb42139 100644 --- a/google/generativeai/operations.py +++ b/google/generativeai/operations.py @@ -17,7 +17,7 @@ import functools from typing import Iterator -from google.ai import generativelanguage as glm +from google.generativeai import protos from google.generativeai import client as client_lib from google.generativeai.types import model_types @@ -71,8 +71,8 @@ def from_proto(cls, proto, client): cls=CreateTunedModelOperation, operation=proto, operations_client=client, - result_type=glm.TunedModel, - metadata_type=glm.CreateTunedModelMetadata, + result_type=protos.TunedModel, + metadata_type=protos.CreateTunedModelMetadata, ) @classmethod @@ -107,14 +107,14 @@ def update(self): """Refresh the current statuses in metadata/result/error""" self._refresh_and_update() - def wait_bar(self, **kwargs) -> Iterator[glm.CreateTunedModelMetadata]: + def wait_bar(self, **kwargs) -> Iterator[protos.CreateTunedModelMetadata]: """A tqdm wait bar, yields `Operation` statuses until complete. Args: **kwargs: passed through to `tqdm.auto.tqdm(..., **kwargs)` Yields: - Operation statuses as `glm.CreateTunedModelMetadata` objects. + Operation statuses as `protos.CreateTunedModelMetadata` objects. 
""" bar = tqdm.tqdm(total=self.metadata.total_steps, initial=0, **kwargs) @@ -127,7 +127,7 @@ def wait_bar(self, **kwargs) -> Iterator[glm.CreateTunedModelMetadata]: bar.update(self.metadata.completed_steps - bar.n) return self.result() - def set_result(self, result: glm.TunedModel): + def set_result(self, result: protos.TunedModel): result = model_types.decode_tuned_model(result) super().set_result(result) diff --git a/google/generativeai/permission.py b/google/generativeai/permission.py index b502f9a60..b4672a607 100644 --- a/google/generativeai/permission.py +++ b/google/generativeai/permission.py @@ -16,7 +16,7 @@ from typing import Callable -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai.types import permission_types from google.generativeai.types import retriever_types @@ -123,7 +123,7 @@ def _construct_name( def get_permission( name: str | None = None, *, - client: glm.PermissionServiceClient | None = None, + client: protos.PermissionServiceClient | None = None, resource_name: str | None = None, permission_id: str | int | None = None, resource_type: str | None = None, @@ -152,7 +152,7 @@ def get_permission( async def get_permission_async( name: str | None = None, *, - client: glm.PermissionServiceAsyncClient | None = None, + client: protos.PermissionServiceAsyncClient | None = None, resource_name: str | None = None, permission_id: str | int | None = None, resource_type: str | None = None, diff --git a/google/generativeai/responder.py b/google/generativeai/responder.py index 238e7e13a..612923b03 100644 --- a/google/generativeai/responder.py +++ b/google/generativeai/responder.py @@ -22,9 +22,9 @@ import pydantic -from google.ai import generativelanguage as glm +from google.generativeai import protos -Type = glm.Type +Type = protos.Type TypeOptions = Union[int, str, Type] @@ -186,8 +186,8 @@ def _rename_schema_fields(schema: dict[str, Any]): class FunctionDeclaration: def __init__(self, *, name: str, description: str, parameters: dict[str, Any] | None = None): - """A class wrapping a `glm.FunctionDeclaration`, describes a function for `genai.GenerativeModel`'s `tools`.""" - self._proto = glm.FunctionDeclaration( + """A class wrapping a `protos.FunctionDeclaration`, describes a function for `genai.GenerativeModel`'s `tools`.""" + self._proto = protos.FunctionDeclaration( name=name, description=description, parameters=_rename_schema_fields(parameters) ) @@ -200,7 +200,7 @@ def description(self) -> str: return self._proto.description @property - def parameters(self) -> glm.Schema: + def parameters(self) -> protos.Schema: return self._proto.parameters @classmethod @@ -209,7 +209,7 @@ def from_proto(cls, proto) -> FunctionDeclaration: self._proto = proto return self - def to_proto(self) -> glm.FunctionDeclaration: + def to_proto(self) -> protos.FunctionDeclaration: return self._proto @staticmethod @@ -255,16 +255,16 @@ def __init__( super().__init__(name=name, description=description, parameters=parameters) self.function = function - def __call__(self, fc: glm.FunctionCall) -> glm.FunctionResponse: + def __call__(self, fc: protos.FunctionCall) -> protos.FunctionResponse: result = self.function(**fc.args) if not isinstance(result, dict): result = {"result": result} - return glm.FunctionResponse(name=fc.name, response=result) + return protos.FunctionResponse(name=fc.name, response=result) FunctionDeclarationType = Union[ FunctionDeclaration, - glm.FunctionDeclaration, + protos.FunctionDeclaration, dict[str, Any], Callable[..., 
Any], ] @@ -272,8 +272,8 @@ def __call__(self, fc: glm.FunctionCall) -> glm.FunctionResponse: def _make_function_declaration( fun: FunctionDeclarationType, -) -> FunctionDeclaration | glm.FunctionDeclaration: - if isinstance(fun, (FunctionDeclaration, glm.FunctionDeclaration)): +) -> FunctionDeclaration | protos.FunctionDeclaration: + if isinstance(fun, (FunctionDeclaration, protos.FunctionDeclaration)): return fun elif isinstance(fun, dict): if "function" in fun: @@ -289,15 +289,15 @@ def _make_function_declaration( ) -def _encode_fd(fd: FunctionDeclaration | glm.FunctionDeclaration) -> glm.FunctionDeclaration: - if isinstance(fd, glm.FunctionDeclaration): +def _encode_fd(fd: FunctionDeclaration | protos.FunctionDeclaration) -> protos.FunctionDeclaration: + if isinstance(fd, protos.FunctionDeclaration): return fd return fd.to_proto() class Tool: - """A wrapper for `glm.Tool`, Contains a collection of related `FunctionDeclaration` objects.""" + """A wrapper for `protos.Tool`, Contains a collection of related `FunctionDeclaration` objects.""" def __init__(self, function_declarations: Iterable[FunctionDeclarationType]): # The main path doesn't use this but is seems useful. @@ -309,23 +309,23 @@ def __init__(self, function_declarations: Iterable[FunctionDeclarationType]): raise ValueError("") self._index[fd.name] = fd - self._proto = glm.Tool( + self._proto = protos.Tool( function_declarations=[_encode_fd(fd) for fd in self._function_declarations] ) @property - def function_declarations(self) -> list[FunctionDeclaration | glm.FunctionDeclaration]: + def function_declarations(self) -> list[FunctionDeclaration | protos.FunctionDeclaration]: return self._function_declarations def __getitem__( - self, name: str | glm.FunctionCall - ) -> FunctionDeclaration | glm.FunctionDeclaration: + self, name: str | protos.FunctionCall + ) -> FunctionDeclaration | protos.FunctionDeclaration: if not isinstance(name, str): name = name.name return self._index[name] - def __call__(self, fc: glm.FunctionCall) -> glm.FunctionResponse | None: + def __call__(self, fc: protos.FunctionCall) -> protos.FunctionResponse | None: declaration = self[fc] if not callable(declaration): return None @@ -341,21 +341,21 @@ class ToolDict(TypedDict): ToolType = Union[ - Tool, glm.Tool, ToolDict, Iterable[FunctionDeclarationType], FunctionDeclarationType + Tool, protos.Tool, ToolDict, Iterable[FunctionDeclarationType], FunctionDeclarationType ] def _make_tool(tool: ToolType) -> Tool: if isinstance(tool, Tool): return tool - elif isinstance(tool, glm.Tool): + elif isinstance(tool, protos.Tool): return Tool(function_declarations=tool.function_declarations) elif isinstance(tool, dict): if "function_declarations" in tool: return Tool(**tool) else: fd = tool - return Tool(function_declarations=[glm.FunctionDeclaration(**fd)]) + return Tool(function_declarations=[protos.FunctionDeclaration(**fd)]) elif isinstance(tool, Iterable): return Tool(function_declarations=tool) else: @@ -386,20 +386,20 @@ def __init__(self, tools: Iterable[ToolType]): self._index[declaration.name] = declaration def __getitem__( - self, name: str | glm.FunctionCall - ) -> FunctionDeclaration | glm.FunctionDeclaration: + self, name: str | protos.FunctionCall + ) -> FunctionDeclaration | protos.FunctionDeclaration: if not isinstance(name, str): name = name.name return self._index[name] - def __call__(self, fc: glm.FunctionCall) -> glm.Part | None: + def __call__(self, fc: protos.FunctionCall) -> protos.Part | None: declaration = self[fc] if not callable(declaration): 
return None response = declaration(fc) - return glm.Part(function_response=response) + return protos.Part(function_response=response) def to_proto(self): return [tool.to_proto() for tool in self._tools] @@ -432,7 +432,7 @@ def to_function_library(lib: FunctionLibraryType | None) -> FunctionLibrary | No return FunctionLibrary(tools=lib) -FunctionCallingMode = glm.FunctionCallingConfig.Mode +FunctionCallingMode = protos.FunctionCallingConfig.Mode # fmt: off _FUNCTION_CALLING_MODE = { @@ -468,12 +468,12 @@ class FunctionCallingConfigDict(TypedDict): FunctionCallingConfigType = Union[ - FunctionCallingModeType, FunctionCallingConfigDict, glm.FunctionCallingConfig + FunctionCallingModeType, FunctionCallingConfigDict, protos.FunctionCallingConfig ] -def to_function_calling_config(obj: FunctionCallingConfigType) -> glm.FunctionCallingConfig: - if isinstance(obj, glm.FunctionCallingConfig): +def to_function_calling_config(obj: FunctionCallingConfigType) -> protos.FunctionCallingConfig: + if isinstance(obj, protos.FunctionCallingConfig): return obj elif isinstance(obj, (FunctionCallingMode, str, int)): obj = {"mode": to_function_calling_mode(obj)} @@ -483,29 +483,29 @@ def to_function_calling_config(obj: FunctionCallingConfigType) -> glm.FunctionCa obj["mode"] = to_function_calling_mode(mode) else: raise TypeError( - f"Could not convert input to `glm.FunctionCallingConfig`: \n'" f" type: {type(obj)}\n", + f"Could not convert input to `protos.FunctionCallingConfig`: \n'" f" type: {type(obj)}\n", obj, ) - return glm.FunctionCallingConfig(obj) + return protos.FunctionCallingConfig(obj) class ToolConfigDict: function_calling_config: FunctionCallingConfigType -ToolConfigType = Union[ToolConfigDict, glm.ToolConfig] +ToolConfigType = Union[ToolConfigDict, protos.ToolConfig] -def to_tool_config(obj: ToolConfigType) -> glm.ToolConfig: - if isinstance(obj, glm.ToolConfig): +def to_tool_config(obj: ToolConfigType) -> protos.ToolConfig: + if isinstance(obj, protos.ToolConfig): return obj elif isinstance(obj, dict): fcc = obj.pop("function_calling_config") fcc = to_function_calling_config(fcc) obj["function_calling_config"] = fcc - return glm.ToolConfig(**obj) + return protos.ToolConfig(**obj) else: raise TypeError( - f"Could not convert input to `glm.ToolConfig`: \n'" f" type: {type(obj)}\n", obj + f"Could not convert input to `protos.ToolConfig`: \n'" f" type: {type(obj)}\n", obj ) diff --git a/google/generativeai/retriever.py b/google/generativeai/retriever.py index 190a222a6..0b9e83a05 100644 --- a/google/generativeai/retriever.py +++ b/google/generativeai/retriever.py @@ -19,7 +19,7 @@ import dataclasses from typing import Any, AsyncIterable, Iterable, Optional -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai.client import get_default_retriever_client from google.generativeai.client import get_default_retriever_async_client @@ -31,7 +31,7 @@ def create_corpus( name: str | None = None, display_name: str | None = None, - client: glm.RetrieverServiceClient | None = None, + client: protos.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> retriever_types.Corpus: """ @@ -60,13 +60,13 @@ def create_corpus( client = get_default_retriever_client() if name is None: - corpus = glm.Corpus(display_name=display_name) + corpus = protos.Corpus(display_name=display_name) elif retriever_types.valid_name(name): - corpus = glm.Corpus(name=f"corpora/{name}", display_name=display_name) + corpus = 
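The config coercion helpers above accept several loose input forms. A small sketch, assuming `protos.FunctionCallingConfig.Mode` exposes an `AUTO` member and that `to_function_calling_mode` accepts the enum directly (neither detail is shown in this hunk):

```python
from google.generativeai import protos
from google.generativeai import responder

# A bare mode is wrapped into {"mode": ...} and then into the proto.
fc_config = responder.to_function_calling_config(protos.FunctionCallingConfig.Mode.AUTO)
assert isinstance(fc_config, protos.FunctionCallingConfig)

# A dict with a `function_calling_config` key becomes a full ToolConfig.
tool_config = responder.to_tool_config({"function_calling_config": fc_config})
assert isinstance(tool_config, protos.ToolConfig)
```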
protos.Corpus(name=f"corpora/{name}", display_name=display_name) else: raise ValueError(retriever_types.NAME_ERROR_MSG.format(length=len(name), name=name)) - request = glm.CreateCorpusRequest(corpus=corpus) + request = protos.CreateCorpusRequest(corpus=corpus) response = client.create_corpus(request, **request_options) response = type(response).to_dict(response) idecode_time(response, "create_time") @@ -78,7 +78,7 @@ def create_corpus( async def create_corpus_async( name: str | None = None, display_name: str | None = None, - client: glm.RetrieverServiceAsyncClient | None = None, + client: protos.RetrieverServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> retriever_types.Corpus: """This is the async version of `retriever.create_corpus`.""" @@ -89,13 +89,13 @@ async def create_corpus_async( client = get_default_retriever_async_client() if name is None: - corpus = glm.Corpus(display_name=display_name) + corpus = protos.Corpus(display_name=display_name) elif retriever_types.valid_name(name): - corpus = glm.Corpus(name=f"corpora/{name}", display_name=display_name) + corpus = protos.Corpus(name=f"corpora/{name}", display_name=display_name) else: raise ValueError(retriever_types.NAME_ERROR_MSG.format(length=len(name), name=name)) - request = glm.CreateCorpusRequest(corpus=corpus) + request = protos.CreateCorpusRequest(corpus=corpus) response = await client.create_corpus(request, **request_options) response = type(response).to_dict(response) idecode_time(response, "create_time") @@ -106,7 +106,7 @@ async def create_corpus_async( def get_corpus( name: str, - client: glm.RetrieverServiceClient | None = None, + client: protos.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> retriever_types.Corpus: # fmt: skip """ @@ -128,7 +128,7 @@ def get_corpus( if "/" not in name: name = "corpora/" + name - request = glm.GetCorpusRequest(name=name) + request = protos.GetCorpusRequest(name=name) response = client.get_corpus(request, **request_options) response = type(response).to_dict(response) idecode_time(response, "create_time") @@ -139,7 +139,7 @@ def get_corpus( async def get_corpus_async( name: str, - client: glm.RetrieverServiceAsyncClient | None = None, + client: protos.RetrieverServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> retriever_types.Corpus: # fmt: skip """This is the async version of `retriever.get_corpus`.""" @@ -152,7 +152,7 @@ async def get_corpus_async( if "/" not in name: name = "corpora/" + name - request = glm.GetCorpusRequest(name=name) + request = protos.GetCorpusRequest(name=name) response = await client.get_corpus(request, **request_options) response = type(response).to_dict(response) idecode_time(response, "create_time") @@ -164,7 +164,7 @@ async def get_corpus_async( def delete_corpus( name: str, force: bool = False, - client: glm.RetrieverServiceClient | None = None, + client: protos.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ): # fmt: skip """ @@ -184,14 +184,14 @@ def delete_corpus( if "/" not in name: name = "corpora/" + name - request = glm.DeleteCorpusRequest(name=name, force=force) + request = protos.DeleteCorpusRequest(name=name, force=force) client.delete_corpus(request, **request_options) async def delete_corpus_async( name: str, force: bool = False, - client: glm.RetrieverServiceAsyncClient | None = None, + client: protos.RetrieverServiceAsyncClient | 
None = None, request_options: helper_types.RequestOptionsType | None = None, ): # fmt: skip """This is the async version of `retriever.delete_corpus`.""" @@ -204,14 +204,14 @@ async def delete_corpus_async( if "/" not in name: name = "corpora/" + name - request = glm.DeleteCorpusRequest(name=name, force=force) + request = protos.DeleteCorpusRequest(name=name, force=force) await client.delete_corpus(request, **request_options) def list_corpora( *, page_size: Optional[int] = None, - client: glm.RetrieverServiceClient | None = None, + client: protos.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> Iterable[retriever_types.Corpus]: """ @@ -231,7 +231,7 @@ def list_corpora( if client is None: client = get_default_retriever_client() - request = glm.ListCorporaRequest(page_size=page_size) + request = protos.ListCorporaRequest(page_size=page_size) for corpus in client.list_corpora(request, **request_options): corpus = type(corpus).to_dict(corpus) idecode_time(corpus, "create_time") @@ -242,7 +242,7 @@ def list_corpora( async def list_corpora_async( *, page_size: Optional[int] = None, - client: glm.RetrieverServiceClient | None = None, + client: protos.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> AsyncIterable[retriever_types.Corpus]: """This is the async version of `retriever.list_corpora`.""" @@ -252,7 +252,7 @@ async def list_corpora_async( if client is None: client = get_default_retriever_async_client() - request = glm.ListCorporaRequest(page_size=page_size) + request = protos.ListCorporaRequest(page_size=page_size) async for corpus in await client.list_corpora(request, **request_options): corpus = type(corpus).to_dict(corpus) idecode_time(corpus, "create_time") diff --git a/google/generativeai/text.py b/google/generativeai/text.py index bb5ec4bdd..3c7cf612b 100644 --- a/google/generativeai/text.py +++ b/google/generativeai/text.py @@ -19,7 +19,7 @@ import itertools from typing import Any, Iterable, overload, TypeVar -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai.client import get_default_text_client from google.generativeai import string_utils @@ -52,23 +52,23 @@ def _batched(iterable: Iterable[T], n: int) -> Iterable[list[T]]: yield batch -def _make_text_prompt(prompt: str | dict[str, str]) -> glm.TextPrompt: +def _make_text_prompt(prompt: str | dict[str, str]) -> protos.TextPrompt: """ - Creates a `glm.TextPrompt` object based on the provided prompt input. + Creates a `protos.TextPrompt` object based on the provided prompt input. Args: prompt: The prompt input, either a string or a dictionary. Returns: - glm.TextPrompt: A TextPrompt object containing the prompt text. + protos.TextPrompt: A TextPrompt object containing the prompt text. Raises: TypeError: If the provided prompt is neither a string nor a dictionary. 
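Taken together, the retriever helpers above form the usual corpus lifecycle. A usage sketch, assuming an API key has already been configured via `genai.configure`; the display name is illustrative.

```python
from google.generativeai import retriever

# Create, fetch, list, and delete a corpus for semantic retrieval.
corpus = retriever.create_corpus(display_name="docs demo")
print(corpus.name)  # server-assigned, e.g. "corpora/..."

same_corpus = retriever.get_corpus(corpus.name)

for c in retriever.list_corpora(page_size=10):
    print(c.name, c.display_name)

# `force` is forwarded to DeleteCorpusRequest(force=...).
retriever.delete_corpus(corpus.name, force=True)
```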
""" if isinstance(prompt, str): - return glm.TextPrompt(text=prompt) + return protos.TextPrompt(text=prompt) elif isinstance(prompt, dict): - return glm.TextPrompt(prompt) + return protos.TextPrompt(prompt) else: TypeError("Expected string or dictionary for text prompt.") @@ -84,11 +84,11 @@ def _make_generate_text_request( top_k: int | None = None, safety_settings: palm_safety_types.SafetySettingOptions | None = None, stop_sequences: str | Iterable[str] | None = None, -) -> glm.GenerateTextRequest: +) -> protos.GenerateTextRequest: """ - Creates a `glm.GenerateTextRequest` object based on the provided parameters. + Creates a `protos.GenerateTextRequest` object based on the provided parameters. - This function generates a `glm.GenerateTextRequest` object with the specified + This function generates a `protos.GenerateTextRequest` object with the specified parameters. It prepares the input parameters and creates a request that can be used for generating text using the chosen model. @@ -105,7 +105,7 @@ def _make_generate_text_request( or iterable of strings. Defaults to None. Returns: - `glm.GenerateTextRequest`: A `GenerateTextRequest` object configured with the specified parameters. + `protos.GenerateTextRequest`: A `GenerateTextRequest` object configured with the specified parameters. """ model = model_types.make_model_name(model) prompt = _make_text_prompt(prompt=prompt) @@ -115,7 +115,7 @@ def _make_generate_text_request( if stop_sequences: stop_sequences = list(stop_sequences) - return glm.GenerateTextRequest( + return protos.GenerateTextRequest( model=model, prompt=prompt, temperature=temperature, @@ -139,7 +139,7 @@ def generate_text( top_k: float | None = None, safety_settings: palm_safety_types.SafetySettingOptions | None = None, stop_sequences: str | Iterable[str] | None = None, - client: glm.TextServiceClient | None = None, + client: protos.TextServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.Completion: """Calls the API and returns a `types.Completion` containing the response. @@ -180,7 +180,7 @@ def generate_text( stop_sequences: A set of up to 5 character sequences that will stop output generation. If specified, the API will stop at the first appearance of a stop sequence. The stop sequence will not be included as part of the response. - client: If you're not relying on a default client, you pass a `glm.TextServiceClient` instead. + client: If you're not relying on a default client, you pass a `protos.TextServiceClient` instead. request_options: Options for the request. Returns: @@ -214,12 +214,12 @@ def __init__(self, **kwargs): def _generate_response( - request: glm.GenerateTextRequest, - client: glm.TextServiceClient = None, + request: protos.GenerateTextRequest, + client: protos.TextServiceClient = None, request_options: helper_types.RequestOptionsType | None = None, ) -> Completion: """ - Generates a response using the provided `glm.GenerateTextRequest` and client. + Generates a response using the provided `protos.GenerateTextRequest` and client. Args: request: The text generation request. 
@@ -251,7 +251,7 @@ def _generate_response( def count_text_tokens( model: model_types.AnyModelNameOptions, prompt: str, - client: glm.TextServiceClient | None = None, + client: protos.TextServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.TokenCount: base_model = models.get_base_model_name(model) @@ -263,7 +263,7 @@ def count_text_tokens( client = get_default_text_client() result = client.count_text_tokens( - glm.CountTextTokensRequest(model=base_model, prompt={"text": prompt}), + protos.CountTextTokensRequest(model=base_model, prompt={"text": prompt}), **request_options, ) @@ -274,7 +274,7 @@ def count_text_tokens( def generate_embeddings( model: model_types.BaseModelNameOptions, text: str, - client: glm.TextServiceClient = None, + client: protos.TextServiceClient = None, request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.EmbeddingDict: ... @@ -283,7 +283,7 @@ def generate_embeddings( def generate_embeddings( model: model_types.BaseModelNameOptions, text: Sequence[str], - client: glm.TextServiceClient = None, + client: protos.TextServiceClient = None, request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.BatchEmbeddingDict: ... @@ -291,7 +291,7 @@ def generate_embeddings( def generate_embeddings( model: model_types.BaseModelNameOptions, text: str | Sequence[str], - client: glm.TextServiceClient = None, + client: protos.TextServiceClient = None, request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict: """Calls the API to create an embedding for the text passed in. @@ -302,7 +302,7 @@ def generate_embeddings( text: Free-form input text given to the model. Given a string, the model will generate an embedding based on the input text. - client: If you're not relying on a default client, you pass a `glm.TextServiceClient` instead. + client: If you're not relying on a default client, you pass a `protos.TextServiceClient` instead. request_options: Options for the request. @@ -318,7 +318,7 @@ def generate_embeddings( client = get_default_text_client() if isinstance(text, str): - embedding_request = glm.EmbedTextRequest(model=model, text=text) + embedding_request = protos.EmbedTextRequest(model=model, text=text) embedding_response = client.embed_text( embedding_request, **request_options, @@ -329,7 +329,7 @@ def generate_embeddings( result = {"embedding": []} for batch in _batched(text, EMBEDDING_MAX_BATCH_SIZE): # TODO(markdaoust): This could use an option for returning an iterator or wait-bar. 
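The token-count and embedding helpers above share the same client plumbing. A sketch assuming a configured API key; the model names are illustrative, and the single-string vs. list-of-strings split follows the overloads shown here.

```python
from google.generativeai import text

tokens = text.count_text_tokens(model="models/text-bison-001", prompt="Hello world")
print(tokens)

# One string -> one embedding dict; a sequence is sent in batches of
# EMBEDDING_MAX_BATCH_SIZE and returns a batched "embedding" list.
single = text.generate_embeddings(model="models/embedding-gecko-001", text="Hello")
batch = text.generate_embeddings(model="models/embedding-gecko-001", text=["Hello", "world"])
print(len(single["embedding"]), len(batch["embedding"]))
```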
- embedding_request = glm.BatchEmbedTextRequest(model=model, texts=batch) + embedding_request = protos.BatchEmbedTextRequest(model=model, texts=batch) embedding_response = client.batch_embed_text( embedding_request, **request_options, diff --git a/google/generativeai/types/answer_types.py b/google/generativeai/types/answer_types.py index 18bd11d62..143a578a4 100644 --- a/google/generativeai/types/answer_types.py +++ b/google/generativeai/types/answer_types.py @@ -16,11 +16,11 @@ from typing import Union -import google.ai.generativelanguage as glm +from google.generativeai import protos __all__ = ["Answer"] -FinishReason = glm.Candidate.FinishReason +FinishReason = protos.Candidate.FinishReason FinishReasonOptions = Union[int, str, FinishReason] diff --git a/google/generativeai/types/citation_types.py b/google/generativeai/types/citation_types.py index ae857c35b..9f169703f 100644 --- a/google/generativeai/types/citation_types.py +++ b/google/generativeai/types/citation_types.py @@ -17,7 +17,7 @@ from typing_extensions import TypedDict -from google.ai import generativelanguage as glm +from google.generativeai import protos from google.generativeai import string_utils @@ -33,10 +33,10 @@ class CitationSourceDict(TypedDict): uri: str | None license: str | None - __doc__ = string_utils.strip_oneof(glm.CitationSource.__doc__) + __doc__ = string_utils.strip_oneof(protos.CitationSource.__doc__) class CitationMetadataDict(TypedDict): citation_sources: List[CitationSourceDict | None] - __doc__ = string_utils.strip_oneof(glm.CitationMetadata.__doc__) + __doc__ = string_utils.strip_oneof(protos.CitationMetadata.__doc__) diff --git a/google/generativeai/types/content_types.py b/google/generativeai/types/content_types.py index ce72dddbc..afa33e34e 100644 --- a/google/generativeai/types/content_types.py +++ b/google/generativeai/types/content_types.py @@ -26,7 +26,7 @@ import pydantic from google.generativeai.types import file_types -from google.ai import generativelanguage as glm +from google.generativeai import protos if typing.TYPE_CHECKING: import PIL.Image @@ -80,10 +80,10 @@ def pil_to_blob(img): mime_type = "image/jpeg" bytesio.seek(0) data = bytesio.read() - return glm.Blob(mime_type=mime_type, data=data) + return protos.Blob(mime_type=mime_type, data=data) -def image_to_blob(image) -> glm.Blob: +def image_to_blob(image) -> protos.Blob: if PIL is not None: if isinstance(image, PIL.Image.Image): return pil_to_blob(image) @@ -101,7 +101,7 @@ def image_to_blob(image) -> glm.Blob: if mime_type is None: mime_type = "image/unknown" - return glm.Blob(mime_type=mime_type, data=image.data) + return protos.Blob(mime_type=mime_type, data=image.data) raise TypeError( "Could not convert image. 
expected an `Image` type" @@ -116,23 +116,23 @@ class BlobDict(TypedDict): data: bytes -def _convert_dict(d: Mapping) -> glm.Content | glm.Part | glm.Blob: +def _convert_dict(d: Mapping) -> protos.Content | protos.Part | protos.Blob: if is_content_dict(d): content = dict(d) if isinstance(parts := content["parts"], str): content["parts"] = [parts] content["parts"] = [to_part(part) for part in content["parts"]] - return glm.Content(content) + return protos.Content(content) elif is_part_dict(d): part = dict(d) if "inline_data" in part: part["inline_data"] = to_blob(part["inline_data"]) if "file_data" in part: part["file_data"] = file_types.to_file_data(part["file_data"]) - return glm.Part(part) + return protos.Part(part) elif is_blob_dict(d): blob = d - return glm.Blob(blob) + return protos.Blob(blob) else: raise KeyError( "Could not recognize the intended type of the `dict`. " @@ -149,17 +149,17 @@ def is_blob_dict(d): if typing.TYPE_CHECKING: BlobType = Union[ - glm.Blob, BlobDict, PIL.Image.Image, IPython.display.Image + protos.Blob, BlobDict, PIL.Image.Image, IPython.display.Image ] # Any for the images else: - BlobType = Union[glm.Blob, BlobDict, Any] + BlobType = Union[protos.Blob, BlobDict, Any] -def to_blob(blob: BlobType) -> glm.Blob: +def to_blob(blob: BlobType) -> protos.Blob: if isinstance(blob, Mapping): blob = _convert_dict(blob) - if isinstance(blob, glm.Blob): + if isinstance(blob, protos.Blob): return blob elif isinstance(blob, IMAGE_TYPES): return image_to_blob(blob) @@ -183,12 +183,12 @@ class PartDict(TypedDict): # When you need a `Part` accept a part object, part-dict, blob or string PartType = Union[ - glm.Part, + protos.Part, PartDict, BlobType, str, - glm.FunctionCall, - glm.FunctionResponse, + protos.FunctionCall, + protos.FunctionResponse, file_types.FileDataType, ] @@ -207,22 +207,22 @@ def to_part(part: PartType): if isinstance(part, Mapping): part = _convert_dict(part) - if isinstance(part, glm.Part): + if isinstance(part, protos.Part): return part elif isinstance(part, str): - return glm.Part(text=part) - elif isinstance(part, glm.FileData): - return glm.Part(file_data=part) - elif isinstance(part, (glm.File, file_types.File)): - return glm.Part(file_data=file_types.to_file_data(part)) - elif isinstance(part, glm.FunctionCall): - return glm.Part(function_call=part) - elif isinstance(part, glm.FunctionResponse): - return glm.Part(function_response=part) + return protos.Part(text=part) + elif isinstance(part, protos.FileData): + return protos.Part(file_data=part) + elif isinstance(part, (protos.File, file_types.File)): + return protos.Part(file_data=file_types.to_file_data(part)) + elif isinstance(part, protos.FunctionCall): + return protos.Part(function_call=part) + elif isinstance(part, protos.FunctionResponse): + return protos.Part(function_response=part) else: # Maybe it can be turned into a blob? - return glm.Part(inline_data=to_blob(part)) + return protos.Part(inline_data=to_blob(part)) class ContentDict(TypedDict): @@ -236,10 +236,10 @@ def is_content_dict(d): # When you need a message accept a `Content` object or dict, a list of parts, # or a single part -ContentType = Union[glm.Content, ContentDict, Iterable[PartType], PartType] +ContentType = Union[protos.Content, ContentDict, Iterable[PartType], PartType] # For generate_content, we're not guessing roles for [[parts],[parts],[parts]] yet. 
-StrictContentType = Union[glm.Content, ContentDict] +StrictContentType = Union[protos.Content, ContentDict] def to_content(content: ContentType): @@ -249,24 +249,24 @@ def to_content(content: ContentType): if isinstance(content, Mapping): content = _convert_dict(content) - if isinstance(content, glm.Content): + if isinstance(content, protos.Content): return content elif isinstance(content, Iterable) and not isinstance(content, str): - return glm.Content(parts=[to_part(part) for part in content]) + return protos.Content(parts=[to_part(part) for part in content]) else: # Maybe this is a Part? - return glm.Content(parts=[to_part(content)]) + return protos.Content(parts=[to_part(content)]) def strict_to_content(content: StrictContentType): if isinstance(content, Mapping): content = _convert_dict(content) - if isinstance(content, glm.Content): + if isinstance(content, protos.Content): return content else: raise TypeError( - "Expected a `glm.Content` or a `dict(parts=...)`.\n" + "Expected a `protos.Content` or a `dict(parts=...)`.\n" f"Got type: {type(content)}\n" f"Value: {content}\n" ) @@ -275,7 +275,7 @@ def strict_to_content(content: StrictContentType): ContentsType = Union[ContentType, Iterable[StrictContentType], None] -def to_contents(contents: ContentsType) -> list[glm.Content]: +def to_contents(contents: ContentsType) -> list[protos.Content]: if contents is None: return [] @@ -502,8 +502,8 @@ def _rename_schema_fields(schema): class FunctionDeclaration: def __init__(self, *, name: str, description: str, parameters: dict[str, Any] | None = None): - """A class wrapping a `glm.FunctionDeclaration`, describes a function for `genai.GenerativeModel`'s `tools`.""" - self._proto = glm.FunctionDeclaration( + """A class wrapping a `protos.FunctionDeclaration`, describes a function for `genai.GenerativeModel`'s `tools`.""" + self._proto = protos.FunctionDeclaration( name=name, description=description, parameters=_rename_schema_fields(parameters) ) @@ -516,7 +516,7 @@ def description(self) -> str: return self._proto.description @property - def parameters(self) -> glm.Schema: + def parameters(self) -> protos.Schema: return self._proto.parameters @classmethod @@ -525,7 +525,7 @@ def from_proto(cls, proto) -> FunctionDeclaration: self._proto = proto return self - def to_proto(self) -> glm.FunctionDeclaration: + def to_proto(self) -> protos.FunctionDeclaration: return self._proto @staticmethod @@ -571,16 +571,16 @@ def __init__( super().__init__(name=name, description=description, parameters=parameters) self.function = function - def __call__(self, fc: glm.FunctionCall) -> glm.FunctionResponse: + def __call__(self, fc: protos.FunctionCall) -> protos.FunctionResponse: result = self.function(**fc.args) if not isinstance(result, dict): result = {"result": result} - return glm.FunctionResponse(name=fc.name, response=result) + return protos.FunctionResponse(name=fc.name, response=result) FunctionDeclarationType = Union[ FunctionDeclaration, - glm.FunctionDeclaration, + protos.FunctionDeclaration, dict[str, Any], Callable[..., Any], ] @@ -588,8 +588,8 @@ def __call__(self, fc: glm.FunctionCall) -> glm.FunctionResponse: def _make_function_declaration( fun: FunctionDeclarationType, -) -> FunctionDeclaration | glm.FunctionDeclaration: - if isinstance(fun, (FunctionDeclaration, glm.FunctionDeclaration)): +) -> FunctionDeclaration | protos.FunctionDeclaration: + if isinstance(fun, (FunctionDeclaration, protos.FunctionDeclaration)): return fun elif isinstance(fun, dict): if "function" in fun: @@ -605,15 +605,15 
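The content coercion helpers above are pure conversions, so they can be exercised without an API call. A small sketch of the accepted shapes:

```python
from google.generativeai import protos
from google.generativeai.types import content_types

c1 = content_types.to_content("Hello there")                               # bare string
c2 = content_types.to_content({"role": "user", "parts": ["Hello there"]})  # content dict
c3 = content_types.to_content([protos.Part(text="Hello"), "there"])        # iterable of parts

assert all(isinstance(c, protos.Content) for c in (c1, c2, c3))
print(c2.role, [p.text for p in c2.parts])
```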
@@ def _make_function_declaration( ) -def _encode_fd(fd: FunctionDeclaration | glm.FunctionDeclaration) -> glm.FunctionDeclaration: - if isinstance(fd, glm.FunctionDeclaration): +def _encode_fd(fd: FunctionDeclaration | protos.FunctionDeclaration) -> protos.FunctionDeclaration: + if isinstance(fd, protos.FunctionDeclaration): return fd return fd.to_proto() class Tool: - """A wrapper for `glm.Tool`, Contains a collection of related `FunctionDeclaration` objects.""" + """A wrapper for `protos.Tool`, Contains a collection of related `FunctionDeclaration` objects.""" def __init__(self, function_declarations: Iterable[FunctionDeclarationType]): # The main path doesn't use this but is seems useful. @@ -625,23 +625,23 @@ def __init__(self, function_declarations: Iterable[FunctionDeclarationType]): raise ValueError("") self._index[fd.name] = fd - self._proto = glm.Tool( + self._proto = protos.Tool( function_declarations=[_encode_fd(fd) for fd in self._function_declarations] ) @property - def function_declarations(self) -> list[FunctionDeclaration | glm.FunctionDeclaration]: + def function_declarations(self) -> list[FunctionDeclaration | protos.FunctionDeclaration]: return self._function_declarations def __getitem__( - self, name: str | glm.FunctionCall - ) -> FunctionDeclaration | glm.FunctionDeclaration: + self, name: str | protos.FunctionCall + ) -> FunctionDeclaration | protos.FunctionDeclaration: if not isinstance(name, str): name = name.name return self._index[name] - def __call__(self, fc: glm.FunctionCall) -> glm.FunctionResponse | None: + def __call__(self, fc: protos.FunctionCall) -> protos.FunctionResponse | None: declaration = self[fc] if not callable(declaration): return None @@ -657,21 +657,21 @@ class ToolDict(TypedDict): ToolType = Union[ - Tool, glm.Tool, ToolDict, Iterable[FunctionDeclarationType], FunctionDeclarationType + Tool, protos.Tool, ToolDict, Iterable[FunctionDeclarationType], FunctionDeclarationType ] def _make_tool(tool: ToolType) -> Tool: if isinstance(tool, Tool): return tool - elif isinstance(tool, glm.Tool): + elif isinstance(tool, protos.Tool): return Tool(function_declarations=tool.function_declarations) elif isinstance(tool, dict): if "function_declarations" in tool: return Tool(**tool) else: fd = tool - return Tool(function_declarations=[glm.FunctionDeclaration(**fd)]) + return Tool(function_declarations=[protos.FunctionDeclaration(**fd)]) elif isinstance(tool, Iterable): return Tool(function_declarations=tool) else: @@ -702,20 +702,20 @@ def __init__(self, tools: Iterable[ToolType]): self._index[declaration.name] = declaration def __getitem__( - self, name: str | glm.FunctionCall - ) -> FunctionDeclaration | glm.FunctionDeclaration: + self, name: str | protos.FunctionCall + ) -> FunctionDeclaration | protos.FunctionDeclaration: if not isinstance(name, str): name = name.name return self._index[name] - def __call__(self, fc: glm.FunctionCall) -> glm.Part | None: + def __call__(self, fc: protos.FunctionCall) -> protos.Part | None: declaration = self[fc] if not callable(declaration): return None response = declaration(fc) - return glm.Part(function_response=response) + return protos.Part(function_response=response) def to_proto(self): return [tool.to_proto() for tool in self._tools] @@ -748,7 +748,7 @@ def to_function_library(lib: FunctionLibraryType | None) -> FunctionLibrary | No return FunctionLibrary(tools=lib) -FunctionCallingMode = glm.FunctionCallingConfig.Mode +FunctionCallingMode = protos.FunctionCallingConfig.Mode # fmt: off _FUNCTION_CALLING_MODE = { @@ 
-784,12 +784,12 @@ class FunctionCallingConfigDict(TypedDict): FunctionCallingConfigType = Union[ - FunctionCallingModeType, FunctionCallingConfigDict, glm.FunctionCallingConfig + FunctionCallingModeType, FunctionCallingConfigDict, protos.FunctionCallingConfig ] -def to_function_calling_config(obj: FunctionCallingConfigType) -> glm.FunctionCallingConfig: - if isinstance(obj, glm.FunctionCallingConfig): +def to_function_calling_config(obj: FunctionCallingConfigType) -> protos.FunctionCallingConfig: + if isinstance(obj, protos.FunctionCallingConfig): return obj elif isinstance(obj, (FunctionCallingMode, str, int)): obj = {"mode": to_function_calling_mode(obj)} @@ -799,29 +799,29 @@ def to_function_calling_config(obj: FunctionCallingConfigType) -> glm.FunctionCa obj["mode"] = to_function_calling_mode(mode) else: raise TypeError( - f"Could not convert input to `glm.FunctionCallingConfig`: \n'" f" type: {type(obj)}\n", + f"Could not convert input to `protos.FunctionCallingConfig`: \n'" f" type: {type(obj)}\n", obj, ) - return glm.FunctionCallingConfig(obj) + return protos.FunctionCallingConfig(obj) class ToolConfigDict: function_calling_config: FunctionCallingConfigType -ToolConfigType = Union[ToolConfigDict, glm.ToolConfig] +ToolConfigType = Union[ToolConfigDict, protos.ToolConfig] -def to_tool_config(obj: ToolConfigType) -> glm.ToolConfig: - if isinstance(obj, glm.ToolConfig): +def to_tool_config(obj: ToolConfigType) -> protos.ToolConfig: + if isinstance(obj, protos.ToolConfig): return obj elif isinstance(obj, dict): fcc = obj.pop("function_calling_config") fcc = to_function_calling_config(fcc) obj["function_calling_config"] = fcc - return glm.ToolConfig(**obj) + return protos.ToolConfig(**obj) else: raise TypeError( - f"Could not convert input to `glm.ToolConfig`: \n'" f" type: {type(obj)}\n", obj + f"Could not convert input to `protos.ToolConfig`: \n'" f" type: {type(obj)}\n", obj ) diff --git a/google/generativeai/types/discuss_types.py b/google/generativeai/types/discuss_types.py index fa777d1d1..a538da65c 100644 --- a/google/generativeai/types/discuss_types.py +++ b/google/generativeai/types/discuss_types.py @@ -19,7 +19,7 @@ from typing import Any, Dict, Union, Iterable, Optional, Tuple, List from typing_extensions import TypedDict -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import string_utils from google.generativeai.types import palm_safety_types @@ -46,15 +46,15 @@ class TokenCount(TypedDict): class MessageDict(TypedDict): - """A dict representation of a `glm.Message`.""" + """A dict representation of a `protos.Message`.""" author: str content: str citation_metadata: Optional[citation_types.CitationMetadataDict] -MessageOptions = Union[str, MessageDict, glm.Message] -MESSAGE_OPTIONS = (str, dict, glm.Message) +MessageOptions = Union[str, MessageDict, protos.Message] +MESSAGE_OPTIONS = (str, dict, protos.Message) MessagesOptions = Union[ MessageOptions, @@ -64,7 +64,7 @@ class MessageDict(TypedDict): class ExampleDict(TypedDict): - """A dict representation of a `glm.Example`.""" + """A dict representation of a `protos.Example`.""" input: MessageOptions output: MessageOptions @@ -74,14 +74,14 @@ class ExampleDict(TypedDict): Tuple[MessageOptions, MessageOptions], Iterable[MessageOptions], ExampleDict, - glm.Example, + protos.Example, ] -EXAMPLE_OPTIONS = (glm.Example, dict, Iterable) +EXAMPLE_OPTIONS = (protos.Example, dict, Iterable) ExamplesOptions = Union[ExampleOptions, Iterable[ExampleOptions]] class 
MessagePromptDict(TypedDict, total=False): - """A dict representation of a `glm.MessagePrompt`.""" + """A dict representation of a `protos.MessagePrompt`.""" context: str examples: ExamplesOptions @@ -90,16 +90,16 @@ class MessagePromptDict(TypedDict, total=False): MessagePromptOptions = Union[ str, - glm.Message, - Iterable[Union[str, glm.Message]], + protos.Message, + Iterable[Union[str, protos.Message]], MessagePromptDict, - glm.MessagePrompt, + protos.MessagePrompt, ] MESSAGE_PROMPT_KEYS = {"context", "examples", "messages"} class ResponseDict(TypedDict): - """A dict representation of a `glm.GenerateMessageResponse`.""" + """A dict representation of a `protos.GenerateMessageResponse`.""" messages: List[MessageDict] candidates: List[MessageDict] diff --git a/google/generativeai/types/file_types.py b/google/generativeai/types/file_types.py index 46b0f37b9..45af92872 100644 --- a/google/generativeai/types/file_types.py +++ b/google/generativeai/types/file_types.py @@ -20,14 +20,14 @@ from google.generativeai.client import get_default_file_client -import google.ai.generativelanguage as glm +from google.generativeai import protos class File: - def __init__(self, proto: glm.File | File | dict): + def __init__(self, proto: protos.File | File | dict): if isinstance(proto, File): proto = proto.to_proto() - self._proto = glm.File(proto) + self._proto = protos.File(proto) def to_proto(self): return self._proto @@ -69,7 +69,7 @@ def uri(self) -> str: return self._proto.uri @property - def state(self) -> glm.File.State: + def state(self) -> protos.File.State: return self._proto.state def delete(self): @@ -82,26 +82,26 @@ class FileDataDict(TypedDict): file_uri: str -FileDataType = Union[FileDataDict, glm.FileData, glm.File, File] +FileDataType = Union[FileDataDict, protos.FileData, protos.File, File] def to_file_data(file_data: FileDataType): if isinstance(file_data, dict): if "file_uri" in file_data: - file_data = glm.FileData(file_data) + file_data = protos.FileData(file_data) else: - file_data = glm.File(file_data) + file_data = protos.File(file_data) if isinstance(file_data, File): file_data = file_data.to_proto() - if isinstance(file_data, glm.File): - file_data = glm.FileData( + if isinstance(file_data, protos.File): + file_data = protos.FileData( mime_type=file_data.mime_type, file_uri=file_data.uri, ) - if isinstance(file_data, glm.FileData): + if isinstance(file_data, protos.FileData): return file_data else: raise TypeError(f"Could not convert a {type(file_data)} to `FileData`") diff --git a/google/generativeai/types/generation_types.py b/google/generativeai/types/generation_types.py index f0c9de4c7..c3a3e1d0a 100644 --- a/google/generativeai/types/generation_types.py +++ b/google/generativeai/types/generation_types.py @@ -30,7 +30,7 @@ import google.protobuf.json_format import google.api_core.exceptions -from google.ai import generativelanguage as glm +from google.generativeai import protos from google.generativeai import string_utils from google.generativeai.types import content_types from google.generativeai.responder import _rename_schema_fields @@ -85,7 +85,7 @@ class GenerationConfigDict(TypedDict, total=False): max_output_tokens: int temperature: float response_mime_type: str - response_schema: glm.Schema | Mapping[str, Any] # fmt: off + response_schema: protos.Schema | Mapping[str, Any] # fmt: off @dataclasses.dataclass @@ -165,19 +165,19 @@ class GenerationConfig: top_p: float | None = None top_k: int | None = None response_mime_type: str | None = None - response_schema: 
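`to_file_data` above accepts dicts, `File` wrappers, and the raw protos. A quick offline sketch with an illustrative URI:

```python
from google.generativeai import protos
from google.generativeai.types import file_types

fd = file_types.to_file_data(
    {"mime_type": "video/mp4", "file_uri": "https://example.test/files/abc"}
)
assert isinstance(fd, protos.FileData)
```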
glm.Schema | Mapping[str, Any] | None = None + response_schema: protos.Schema | Mapping[str, Any] | None = None -GenerationConfigType = Union[glm.GenerationConfig, GenerationConfigDict, GenerationConfig] +GenerationConfigType = Union[protos.GenerationConfig, GenerationConfigDict, GenerationConfig] def _normalize_schema(generation_config): - # Convert response_schema to glm.Schema for request + # Convert response_schema to protos.Schema for request response_schema = generation_config.get("response_schema", None) if response_schema is None: return - if isinstance(response_schema, glm.Schema): + if isinstance(response_schema, protos.Schema): return if isinstance(response_schema, type): @@ -191,13 +191,13 @@ def _normalize_schema(generation_config): response_schema = content_types._schema_for_class(response_schema) response_schema = _rename_schema_fields(response_schema) - generation_config["response_schema"] = glm.Schema(response_schema) + generation_config["response_schema"] = protos.Schema(response_schema) def to_generation_config_dict(generation_config: GenerationConfigType): if generation_config is None: return {} - elif isinstance(generation_config, glm.GenerationConfig): + elif isinstance(generation_config, protos.GenerationConfig): schema = generation_config.response_schema generation_config = type(generation_config).to_dict( generation_config @@ -221,14 +221,14 @@ def to_generation_config_dict(generation_config: GenerationConfigType): def _join_citation_metadatas( - citation_metadatas: Iterable[glm.CitationMetadata], + citation_metadatas: Iterable[protos.CitationMetadata], ): citation_metadatas = list(citation_metadatas) return citation_metadatas[-1] def _join_safety_ratings_lists( - safety_ratings_lists: Iterable[list[glm.SafetyRating]], + safety_ratings_lists: Iterable[list[protos.SafetyRating]], ): ratings = {} blocked = collections.defaultdict(list) @@ -243,13 +243,13 @@ def _join_safety_ratings_lists( safety_list = [] for (category, probability), blocked in zip(ratings.items(), blocked.values()): safety_list.append( - glm.SafetyRating(category=category, probability=probability, blocked=blocked) + protos.SafetyRating(category=category, probability=probability, blocked=blocked) ) return safety_list -def _join_contents(contents: Iterable[glm.Content]): +def _join_contents(contents: Iterable[protos.Content]): contents = tuple(contents) roles = [c.role for c in contents if c.role] if roles: @@ -271,22 +271,22 @@ def _join_contents(contents: Iterable[glm.Content]): merged_parts.append(part) continue - merged_part = glm.Part(merged_parts[-1]) + merged_part = protos.Part(merged_parts[-1]) merged_part.text += part.text merged_parts[-1] = merged_part - return glm.Content( + return protos.Content( role=role, parts=merged_parts, ) -def _join_candidates(candidates: Iterable[glm.Candidate]): +def _join_candidates(candidates: Iterable[protos.Candidate]): candidates = tuple(candidates) index = candidates[0].index # These should all be the same. 
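`to_generation_config_dict` above accepts the dataclass, a plain dict, or the proto. A sketch showing the three forms normalizing to the same keys; the dict and dataclass branches are elided in this hunk but implied by `GenerationConfigType`.

```python
from google.generativeai import protos
from google.generativeai.types import generation_types

as_dataclass = generation_types.GenerationConfig(temperature=0.2, max_output_tokens=256)
as_dict = {"temperature": 0.2, "max_output_tokens": 256}
as_proto = protos.GenerationConfig(temperature=0.2, max_output_tokens=256)

for config in (as_dataclass, as_dict, as_proto):
    print(generation_types.to_generation_config_dict(config))
```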
- return glm.Candidate( + return protos.Candidate( index=index, content=_join_contents([c.content for c in candidates]), finish_reason=candidates[-1].finish_reason, @@ -296,7 +296,7 @@ def _join_candidates(candidates: Iterable[glm.Candidate]): ) -def _join_candidate_lists(candidate_lists: Iterable[list[glm.Candidate]]): +def _join_candidate_lists(candidate_lists: Iterable[list[protos.Candidate]]): # Assuming that is a candidate ends, it is no longer returned in the list of # candidates and that's why candidates have an index candidates = collections.defaultdict(list) @@ -312,15 +312,15 @@ def _join_candidate_lists(candidate_lists: Iterable[list[glm.Candidate]]): def _join_prompt_feedbacks( - prompt_feedbacks: Iterable[glm.GenerateContentResponse.PromptFeedback], + prompt_feedbacks: Iterable[protos.GenerateContentResponse.PromptFeedback], ): # Always return the first prompt feedback. return next(iter(prompt_feedbacks)) -def _join_chunks(chunks: Iterable[glm.GenerateContentResponse]): +def _join_chunks(chunks: Iterable[protos.GenerateContentResponse]): chunks = tuple(chunks) - return glm.GenerateContentResponse( + return protos.GenerateContentResponse( candidates=_join_candidate_lists(c.candidates for c in chunks), prompt_feedback=_join_prompt_feedbacks(c.prompt_feedback for c in chunks), usage_metadata=chunks[-1].usage_metadata, @@ -338,11 +338,11 @@ def __init__( done: bool, iterator: ( None - | Iterable[glm.GenerateContentResponse] - | AsyncIterable[glm.GenerateContentResponse] + | Iterable[protos.GenerateContentResponse] + | AsyncIterable[protos.GenerateContentResponse] ), - result: glm.GenerateContentResponse, - chunks: Iterable[glm.GenerateContentResponse] | None = None, + result: protos.GenerateContentResponse, + chunks: Iterable[protos.GenerateContentResponse] | None = None, ): self._done = done self._iterator = iterator @@ -443,7 +443,7 @@ def __str__(self) -> str: as_dict = self.to_dict() json_str = json.dumps(as_dict, indent=2) - _result = f"glm.GenerateContentResponse({json_str})" + _result = f"protos.GenerateContentResponse({json_str})" _result = _result.replace("\n", "\n ") if self._error: @@ -481,7 +481,7 @@ def rewrite_stream_error(): GENERATE_CONTENT_RESPONSE_DOC = """Instances of this class manage the response of the `generate_content` method. These are returned by `GenerativeModel.generate_content` and `ChatSession.send_message`. - This object is based on the low level `glm.GenerateContentResponse` class which just has `prompt_feedback` + This object is based on the low level `protos.GenerateContentResponse` class which just has `prompt_feedback` and `candidates` attributes. This class adds several quick accessors for common use cases. The same object type is returned for both `stream=True/False`. 
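The chunk-joining helpers above exist so a streamed response can be iterated and still resolved into one final object. A consumption sketch, assuming a configured API key, an illustrative model name, and the `.text` quick accessor (not shown in this hunk):

```python
import google.generativeai as genai

model = genai.GenerativeModel("gemini-1.5-flash")
response = model.generate_content("Tell me a short story.", stream=True)

# Each iteration yields a wrapped protos.GenerateContentResponse chunk;
# the accumulated chunks are merged by _join_chunks when the stream resolves.
for chunk in response:
    print(chunk.text, end="")
```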
@@ -510,7 +510,7 @@ def rewrite_stream_error(): @string_utils.set_doc(GENERATE_CONTENT_RESPONSE_DOC) class GenerateContentResponse(BaseGenerateContentResponse): @classmethod - def from_iterator(cls, iterator: Iterable[glm.GenerateContentResponse]): + def from_iterator(cls, iterator: Iterable[protos.GenerateContentResponse]): iterator = iter(iterator) with rewrite_stream_error(): response = next(iterator) @@ -522,7 +522,7 @@ def from_iterator(cls, iterator: Iterable[glm.GenerateContentResponse]): ) @classmethod - def from_response(cls, response: glm.GenerateContentResponse): + def from_response(cls, response: protos.GenerateContentResponse): return cls( done=True, iterator=None, @@ -577,7 +577,7 @@ def resolve(self): @string_utils.set_doc(ASYNC_GENERATE_CONTENT_RESPONSE_DOC) class AsyncGenerateContentResponse(BaseGenerateContentResponse): @classmethod - async def from_aiterator(cls, iterator: AsyncIterable[glm.GenerateContentResponse]): + async def from_aiterator(cls, iterator: AsyncIterable[protos.GenerateContentResponse]): iterator = aiter(iterator) # type: ignore with rewrite_stream_error(): response = await anext(iterator) # type: ignore @@ -589,7 +589,7 @@ async def from_aiterator(cls, iterator: AsyncIterable[glm.GenerateContentRespons ) @classmethod - def from_response(cls, response: glm.GenerateContentResponse): + def from_response(cls, response: protos.GenerateContentResponse): return cls( done=True, iterator=None, diff --git a/google/generativeai/types/model_types.py b/google/generativeai/types/model_types.py index 0f85acfe8..34213a723 100644 --- a/google/generativeai/types/model_types.py +++ b/google/generativeai/types/model_types.py @@ -28,7 +28,7 @@ import urllib.request from typing_extensions import TypedDict -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai.types import permission_types from google.generativeai import string_utils @@ -44,7 +44,7 @@ "TunedModelState", ] -TunedModelState = glm.TunedModel.State +TunedModelState = protos.TunedModel.State TunedModelStateOptions = Union[None, str, int, TunedModelState] @@ -91,7 +91,7 @@ def to_tuned_model_state(x: TunedModelStateOptions) -> TunedModelState: @string_utils.prettyprint @dataclasses.dataclass class Model: - """A dataclass representation of a `glm.Model`. + """A dataclass representation of a `protos.Model`. Attributes: name: The resource name of the `Model`. 
Format: `models/{model}` with a `{model}` naming @@ -140,8 +140,8 @@ def idecode_time(parent: dict["str", Any], name: str): parent[name] = dt -def decode_tuned_model(tuned_model: glm.TunedModel | dict["str", Any]) -> TunedModel: - if isinstance(tuned_model, glm.TunedModel): +def decode_tuned_model(tuned_model: protos.TunedModel | dict["str", Any]) -> TunedModel: + if isinstance(tuned_model, protos.TunedModel): tuned_model = type(tuned_model).to_dict(tuned_model) # pytype: disable=attribute-error tuned_model["state"] = to_tuned_model_state(tuned_model.pop("state", None)) @@ -180,7 +180,7 @@ def decode_tuned_model(tuned_model: glm.TunedModel | dict["str", Any]) -> TunedM @string_utils.prettyprint @dataclasses.dataclass class TunedModel: - """A dataclass representation of a `glm.TunedModel`.""" + """A dataclass representation of a `protos.TunedModel`.""" name: str | None = None source_model: str | None = None @@ -214,13 +214,13 @@ class TuningExampleDict(TypedDict): output: str -TuningExampleOptions = Union[TuningExampleDict, glm.TuningExample, tuple[str, str], list[str]] +TuningExampleOptions = Union[TuningExampleDict, protos.TuningExample, tuple[str, str], list[str]] # TODO(markdaoust): gs:// URLS? File-type argument for files without extension? TuningDataOptions = Union[ pathlib.Path, str, - glm.Dataset, + protos.Dataset, Mapping[str, Iterable[str]], Iterable[TuningExampleOptions], ] @@ -228,8 +228,8 @@ class TuningExampleDict(TypedDict): def encode_tuning_data( data: TuningDataOptions, input_key="text_input", output_key="output" -) -> glm.Dataset: - if isinstance(data, glm.Dataset): +) -> protos.Dataset: + if isinstance(data, protos.Dataset): return data if isinstance(data, str): @@ -295,8 +295,8 @@ def _convert_dict(data, input_key, output_key): raise KeyError(f'output_key is "{output_key}", but data has keys: {sorted(data.keys())}') for i, o in zip(inputs, outputs): - new_data.append(glm.TuningExample({"text_input": str(i), "output": str(o)})) - return glm.Dataset(examples=glm.TuningExamples(examples=new_data)) + new_data.append(protos.TuningExample({"text_input": str(i), "output": str(o)})) + return protos.Dataset(examples=protos.TuningExamples(examples=new_data)) def _convert_iterable(data, input_key, output_key): @@ -304,17 +304,17 @@ def _convert_iterable(data, input_key, output_key): for example in data: example = encode_tuning_example(example, input_key, output_key) new_data.append(example) - return glm.Dataset(examples=glm.TuningExamples(examples=new_data)) + return protos.Dataset(examples=protos.TuningExamples(examples=new_data)) def encode_tuning_example(example: TuningExampleOptions, input_key, output_key): - if isinstance(example, glm.TuningExample): + if isinstance(example, protos.TuningExample): return example elif isinstance(example, (tuple, list)): a, b = example - example = glm.TuningExample(text_input=a, output=b) + example = protos.TuningExample(text_input=a, output=b) else: # dict - example = glm.TuningExample(text_input=example[input_key], output=example[output_key]) + example = protos.TuningExample(text_input=example[input_key], output=example[output_key]) return example @@ -335,14 +335,14 @@ class Hyperparameters: learning_rate: float = 0.0 -BaseModelNameOptions = Union[str, Model, glm.Model] -TunedModelNameOptions = Union[str, TunedModel, glm.TunedModel] -AnyModelNameOptions = Union[str, Model, glm.Model, TunedModel, glm.TunedModel] +BaseModelNameOptions = Union[str, Model, protos.Model] +TunedModelNameOptions = Union[str, TunedModel, protos.TunedModel] 
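`encode_tuning_data` above accepts either a mapping of columns or an iterable of example pairs. An offline sketch; the Mapping/Iterable dispatch is elided in this hunk but implied by `TuningDataOptions`.

```python
from google.generativeai.types import model_types

columns = model_types.encode_tuning_data({"text_input": ["1", "2"], "output": ["2", "4"]})
pairs = model_types.encode_tuning_data([("1", "2"), ("2", "4")])

# Both forms produce a protos.Dataset wrapping TuningExamples.
print(len(columns.examples.examples), len(pairs.examples.examples))
```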
+AnyModelNameOptions = Union[str, Model, protos.Model, TunedModel, protos.TunedModel] ModelNameOptions = AnyModelNameOptions def make_model_name(name: AnyModelNameOptions): - if isinstance(name, (Model, glm.Model, TunedModel, glm.TunedModel)): + if isinstance(name, (Model, protos.Model, TunedModel, protos.TunedModel)): name = name.name # pytype: disable=attribute-error elif isinstance(name, str): name = name @@ -362,7 +362,7 @@ def make_model_name(name: AnyModelNameOptions): @string_utils.prettyprint @dataclasses.dataclass class TokenCount: - """A dataclass representation of a `glm.TokenCountResponse`. + """A dataclass representation of a `protos.TokenCountResponse`. Attributes: token_count: The number of tokens returned by the model's tokenizer for the `input_text`. diff --git a/google/generativeai/types/palm_safety_types.py b/google/generativeai/types/palm_safety_types.py index 9fb88cd67..0ab85e1b2 100644 --- a/google/generativeai/types/palm_safety_types.py +++ b/google/generativeai/types/palm_safety_types.py @@ -23,7 +23,7 @@ from typing_extensions import TypedDict -from google.ai import generativelanguage as glm +from google.generativeai import protos from google.generativeai import string_utils @@ -39,9 +39,9 @@ ] # These are basic python enums, it's okay to expose them -HarmProbability = glm.SafetyRating.HarmProbability -HarmBlockThreshold = glm.SafetySetting.HarmBlockThreshold -BlockedReason = glm.ContentFilter.BlockedReason +HarmProbability = protos.SafetyRating.HarmProbability +HarmBlockThreshold = protos.SafetySetting.HarmBlockThreshold +BlockedReason = protos.ContentFilter.BlockedReason class HarmCategory: @@ -49,70 +49,70 @@ class HarmCategory: Harm Categories supported by the palm-family models """ - HARM_CATEGORY_UNSPECIFIED = glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED.value - HARM_CATEGORY_DEROGATORY = glm.HarmCategory.HARM_CATEGORY_DEROGATORY.value - HARM_CATEGORY_TOXICITY = glm.HarmCategory.HARM_CATEGORY_TOXICITY.value - HARM_CATEGORY_VIOLENCE = glm.HarmCategory.HARM_CATEGORY_VIOLENCE.value - HARM_CATEGORY_SEXUAL = glm.HarmCategory.HARM_CATEGORY_SEXUAL.value - HARM_CATEGORY_MEDICAL = glm.HarmCategory.HARM_CATEGORY_MEDICAL.value - HARM_CATEGORY_DANGEROUS = glm.HarmCategory.HARM_CATEGORY_DANGEROUS.value + HARM_CATEGORY_UNSPECIFIED = protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED.value + HARM_CATEGORY_DEROGATORY = protos.HarmCategory.HARM_CATEGORY_DEROGATORY.value + HARM_CATEGORY_TOXICITY = protos.HarmCategory.HARM_CATEGORY_TOXICITY.value + HARM_CATEGORY_VIOLENCE = protos.HarmCategory.HARM_CATEGORY_VIOLENCE.value + HARM_CATEGORY_SEXUAL = protos.HarmCategory.HARM_CATEGORY_SEXUAL.value + HARM_CATEGORY_MEDICAL = protos.HarmCategory.HARM_CATEGORY_MEDICAL.value + HARM_CATEGORY_DANGEROUS = protos.HarmCategory.HARM_CATEGORY_DANGEROUS.value HarmCategoryOptions = Union[str, int, HarmCategory] # fmt: off -_HARM_CATEGORIES: Dict[HarmCategoryOptions, glm.HarmCategory] = { - glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED: glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, - HarmCategory.HARM_CATEGORY_UNSPECIFIED: glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, - 0: glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, - "harm_category_unspecified": glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, - "unspecified": glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, - - glm.HarmCategory.HARM_CATEGORY_DEROGATORY: glm.HarmCategory.HARM_CATEGORY_DEROGATORY, - HarmCategory.HARM_CATEGORY_DEROGATORY: glm.HarmCategory.HARM_CATEGORY_DEROGATORY, - 1: glm.HarmCategory.HARM_CATEGORY_DEROGATORY, - "harm_category_derogatory": 
glm.HarmCategory.HARM_CATEGORY_DEROGATORY, - "derogatory": glm.HarmCategory.HARM_CATEGORY_DEROGATORY, - - glm.HarmCategory.HARM_CATEGORY_TOXICITY: glm.HarmCategory.HARM_CATEGORY_TOXICITY, - HarmCategory.HARM_CATEGORY_TOXICITY: glm.HarmCategory.HARM_CATEGORY_TOXICITY, - 2: glm.HarmCategory.HARM_CATEGORY_TOXICITY, - "harm_category_toxicity": glm.HarmCategory.HARM_CATEGORY_TOXICITY, - "toxicity": glm.HarmCategory.HARM_CATEGORY_TOXICITY, - "toxic": glm.HarmCategory.HARM_CATEGORY_TOXICITY, - - glm.HarmCategory.HARM_CATEGORY_VIOLENCE: glm.HarmCategory.HARM_CATEGORY_VIOLENCE, - HarmCategory.HARM_CATEGORY_VIOLENCE: glm.HarmCategory.HARM_CATEGORY_VIOLENCE, - 3: glm.HarmCategory.HARM_CATEGORY_VIOLENCE, - "harm_category_violence": glm.HarmCategory.HARM_CATEGORY_VIOLENCE, - "violence": glm.HarmCategory.HARM_CATEGORY_VIOLENCE, - "violent": glm.HarmCategory.HARM_CATEGORY_VIOLENCE, - - glm.HarmCategory.HARM_CATEGORY_SEXUAL: glm.HarmCategory.HARM_CATEGORY_SEXUAL, - HarmCategory.HARM_CATEGORY_SEXUAL: glm.HarmCategory.HARM_CATEGORY_SEXUAL, - 4: glm.HarmCategory.HARM_CATEGORY_SEXUAL, - "harm_category_sexual": glm.HarmCategory.HARM_CATEGORY_SEXUAL, - "sexual": glm.HarmCategory.HARM_CATEGORY_SEXUAL, - "sex": glm.HarmCategory.HARM_CATEGORY_SEXUAL, - - glm.HarmCategory.HARM_CATEGORY_MEDICAL: glm.HarmCategory.HARM_CATEGORY_MEDICAL, - HarmCategory.HARM_CATEGORY_MEDICAL: glm.HarmCategory.HARM_CATEGORY_MEDICAL, - 5: glm.HarmCategory.HARM_CATEGORY_MEDICAL, - "harm_category_medical": glm.HarmCategory.HARM_CATEGORY_MEDICAL, - "medical": glm.HarmCategory.HARM_CATEGORY_MEDICAL, - "med": glm.HarmCategory.HARM_CATEGORY_MEDICAL, - - glm.HarmCategory.HARM_CATEGORY_DANGEROUS: glm.HarmCategory.HARM_CATEGORY_DANGEROUS, - HarmCategory.HARM_CATEGORY_DANGEROUS: glm.HarmCategory.HARM_CATEGORY_DANGEROUS, - 6: glm.HarmCategory.HARM_CATEGORY_DANGEROUS, - "harm_category_dangerous": glm.HarmCategory.HARM_CATEGORY_DANGEROUS, - "dangerous": glm.HarmCategory.HARM_CATEGORY_DANGEROUS, - "danger": glm.HarmCategory.HARM_CATEGORY_DANGEROUS, +_HARM_CATEGORIES: Dict[HarmCategoryOptions, protos.HarmCategory] = { + protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED: protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + HarmCategory.HARM_CATEGORY_UNSPECIFIED: protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + 0: protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + "harm_category_unspecified": protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + "unspecified": protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + + protos.HarmCategory.HARM_CATEGORY_DEROGATORY: protos.HarmCategory.HARM_CATEGORY_DEROGATORY, + HarmCategory.HARM_CATEGORY_DEROGATORY: protos.HarmCategory.HARM_CATEGORY_DEROGATORY, + 1: protos.HarmCategory.HARM_CATEGORY_DEROGATORY, + "harm_category_derogatory": protos.HarmCategory.HARM_CATEGORY_DEROGATORY, + "derogatory": protos.HarmCategory.HARM_CATEGORY_DEROGATORY, + + protos.HarmCategory.HARM_CATEGORY_TOXICITY: protos.HarmCategory.HARM_CATEGORY_TOXICITY, + HarmCategory.HARM_CATEGORY_TOXICITY: protos.HarmCategory.HARM_CATEGORY_TOXICITY, + 2: protos.HarmCategory.HARM_CATEGORY_TOXICITY, + "harm_category_toxicity": protos.HarmCategory.HARM_CATEGORY_TOXICITY, + "toxicity": protos.HarmCategory.HARM_CATEGORY_TOXICITY, + "toxic": protos.HarmCategory.HARM_CATEGORY_TOXICITY, + + protos.HarmCategory.HARM_CATEGORY_VIOLENCE: protos.HarmCategory.HARM_CATEGORY_VIOLENCE, + HarmCategory.HARM_CATEGORY_VIOLENCE: protos.HarmCategory.HARM_CATEGORY_VIOLENCE, + 3: protos.HarmCategory.HARM_CATEGORY_VIOLENCE, + "harm_category_violence": protos.HarmCategory.HARM_CATEGORY_VIOLENCE, + 
"violence": protos.HarmCategory.HARM_CATEGORY_VIOLENCE, + "violent": protos.HarmCategory.HARM_CATEGORY_VIOLENCE, + + protos.HarmCategory.HARM_CATEGORY_SEXUAL: protos.HarmCategory.HARM_CATEGORY_SEXUAL, + HarmCategory.HARM_CATEGORY_SEXUAL: protos.HarmCategory.HARM_CATEGORY_SEXUAL, + 4: protos.HarmCategory.HARM_CATEGORY_SEXUAL, + "harm_category_sexual": protos.HarmCategory.HARM_CATEGORY_SEXUAL, + "sexual": protos.HarmCategory.HARM_CATEGORY_SEXUAL, + "sex": protos.HarmCategory.HARM_CATEGORY_SEXUAL, + + protos.HarmCategory.HARM_CATEGORY_MEDICAL: protos.HarmCategory.HARM_CATEGORY_MEDICAL, + HarmCategory.HARM_CATEGORY_MEDICAL: protos.HarmCategory.HARM_CATEGORY_MEDICAL, + 5: protos.HarmCategory.HARM_CATEGORY_MEDICAL, + "harm_category_medical": protos.HarmCategory.HARM_CATEGORY_MEDICAL, + "medical": protos.HarmCategory.HARM_CATEGORY_MEDICAL, + "med": protos.HarmCategory.HARM_CATEGORY_MEDICAL, + + protos.HarmCategory.HARM_CATEGORY_DANGEROUS: protos.HarmCategory.HARM_CATEGORY_DANGEROUS, + HarmCategory.HARM_CATEGORY_DANGEROUS: protos.HarmCategory.HARM_CATEGORY_DANGEROUS, + 6: protos.HarmCategory.HARM_CATEGORY_DANGEROUS, + "harm_category_dangerous": protos.HarmCategory.HARM_CATEGORY_DANGEROUS, + "dangerous": protos.HarmCategory.HARM_CATEGORY_DANGEROUS, + "danger": protos.HarmCategory.HARM_CATEGORY_DANGEROUS, } # fmt: on -def to_harm_category(x: HarmCategoryOptions) -> glm.HarmCategory: +def to_harm_category(x: HarmCategoryOptions) -> protos.HarmCategory: if isinstance(x, str): x = x.lower() return _HARM_CATEGORIES[x] @@ -161,7 +161,7 @@ class ContentFilterDict(TypedDict): reason: BlockedReason message: str - __doc__ = string_utils.strip_oneof(glm.ContentFilter.__doc__) + __doc__ = string_utils.strip_oneof(protos.ContentFilter.__doc__) def convert_filters_to_enums( @@ -177,15 +177,15 @@ def convert_filters_to_enums( class SafetyRatingDict(TypedDict): - category: glm.HarmCategory + category: protos.HarmCategory probability: HarmProbability - __doc__ = string_utils.strip_oneof(glm.SafetyRating.__doc__) + __doc__ = string_utils.strip_oneof(protos.SafetyRating.__doc__) def convert_rating_to_enum(rating: dict) -> SafetyRatingDict: return { - "category": glm.HarmCategory(rating["category"]), + "category": protos.HarmCategory(rating["category"]), "probability": HarmProbability(rating["probability"]), } @@ -198,10 +198,10 @@ def convert_ratings_to_enum(ratings: Iterable[dict]) -> List[SafetyRatingDict]: class SafetySettingDict(TypedDict): - category: glm.HarmCategory + category: protos.HarmCategory threshold: HarmBlockThreshold - __doc__ = string_utils.strip_oneof(glm.SafetySetting.__doc__) + __doc__ = string_utils.strip_oneof(protos.SafetySetting.__doc__) class LooseSafetySettingDict(TypedDict): @@ -251,7 +251,7 @@ def normalize_safety_settings( def convert_setting_to_enum(setting: dict) -> SafetySettingDict: return { - "category": glm.HarmCategory(setting["category"]), + "category": protos.HarmCategory(setting["category"]), "threshold": HarmBlockThreshold(setting["threshold"]), } @@ -260,7 +260,7 @@ class SafetyFeedbackDict(TypedDict): rating: SafetyRatingDict setting: SafetySettingDict - __doc__ = string_utils.strip_oneof(glm.SafetyFeedback.__doc__) + __doc__ = string_utils.strip_oneof(protos.SafetyFeedback.__doc__) def convert_safety_feedback_to_enums( diff --git a/google/generativeai/types/permission_types.py b/google/generativeai/types/permission_types.py index db1867695..f03657e53 100644 --- a/google/generativeai/types/permission_types.py +++ b/google/generativeai/types/permission_types.py @@ -18,7 +18,7 
@@ from typing import Optional, Union, Any, Iterable, AsyncIterable import re -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.protobuf import field_mask_pb2 @@ -28,8 +28,8 @@ from google.generativeai import string_utils -GranteeType = glm.Permission.GranteeType -Role = glm.Permission.Role +GranteeType = protos.Permission.GranteeType +Role = protos.Permission.Role GranteeTypeOptions = Union[str, int, GranteeType] RoleOptions = Union[str, int, Role] @@ -101,26 +101,26 @@ class Permission: def delete( self, - client: glm.PermissionServiceClient | None = None, + client: protos.PermissionServiceClient | None = None, ) -> None: """ Delete permission (self). """ if client is None: client = get_default_permission_client() - delete_request = glm.DeletePermissionRequest(name=self.name) + delete_request = protos.DeletePermissionRequest(name=self.name) client.delete_permission(request=delete_request) async def delete_async( self, - client: glm.PermissionServiceAsyncClient | None = None, + client: protos.PermissionServiceAsyncClient | None = None, ) -> None: """ This is the async version of `Permission.delete`. """ if client is None: client = get_default_permission_async_client() - delete_request = glm.DeletePermissionRequest(name=self.name) + delete_request = protos.DeletePermissionRequest(name=self.name) await client.delete_permission(request=delete_request) # TODO (magashe): Add a method to validate update value. As of now only `role` is supported as a mask path @@ -133,7 +133,7 @@ def _apply_update(self, path, value): def update( self, updates: dict[str, Any], - client: glm.PermissionServiceClient | None = None, + client: protos.PermissionServiceClient | None = None, ) -> Permission: """ Update a list of fields for a specified permission. @@ -161,7 +161,7 @@ def update( for path, value in updates.items(): self._apply_update(path, value) - update_request = glm.UpdatePermissionRequest( + update_request = protos.UpdatePermissionRequest( permission=self._to_proto(), update_mask=field_mask ) client.update_permission(request=update_request) @@ -170,7 +170,7 @@ def update( async def update_async( self, updates: dict[str, Any], - client: glm.PermissionServiceAsyncClient | None = None, + client: protos.PermissionServiceAsyncClient | None = None, ) -> Permission: """ This is the async version of `Permission.update`. @@ -191,14 +191,14 @@ async def update_async( for path, value in updates.items(): self._apply_update(path, value) - update_request = glm.UpdatePermissionRequest( + update_request = protos.UpdatePermissionRequest( permission=self._to_proto(), update_mask=field_mask ) await client.update_permission(request=update_request) return self - def _to_proto(self) -> glm.Permission: - return glm.Permission( + def _to_proto(self) -> protos.Permission: + return protos.Permission( name=self.name, role=self.role, grantee_type=self.grantee_type, @@ -212,7 +212,7 @@ def to_dict(self) -> dict[str, Any]: def get( cls, name: str, - client: glm.PermissionServiceClient | None = None, + client: protos.PermissionServiceClient | None = None, ) -> Permission: """ Get information about a specific permission. 
@@ -225,7 +225,7 @@ def get( """ if client is None: client = get_default_permission_client() - get_perm_request = glm.GetPermissionRequest(name=name) + get_perm_request = protos.GetPermissionRequest(name=name) get_perm_response = client.get_permission(request=get_perm_request) get_perm_response = type(get_perm_response).to_dict(get_perm_response) return cls(**get_perm_response) @@ -234,14 +234,14 @@ def get( async def get_async( cls, name: str, - client: glm.PermissionServiceAsyncClient | None = None, + client: protos.PermissionServiceAsyncClient | None = None, ) -> Permission: """ This is the async version of `Permission.get`. """ if client is None: client = get_default_permission_async_client() - get_perm_request = glm.GetPermissionRequest(name=name) + get_perm_request = protos.GetPermissionRequest(name=name) get_perm_response = await client.get_permission(request=get_perm_request) get_perm_response = type(get_perm_response).to_dict(get_perm_response) return cls(**get_perm_response) @@ -263,7 +263,7 @@ def _make_create_permission_request( role: RoleOptions, grantee_type: Optional[GranteeTypeOptions] = None, email_address: Optional[str] = None, - ) -> glm.CreatePermissionRequest: + ) -> protos.CreatePermissionRequest: role = to_role(role) if grantee_type: @@ -279,12 +279,12 @@ def _make_create_permission_request( f"`email_address` must be specified unless `grantee_type` is set to `EVERYONE`." ) - permission = glm.Permission( + permission = protos.Permission( role=role, grantee_type=grantee_type, email_address=email_address, ) - return glm.CreatePermissionRequest( + return protos.CreatePermissionRequest( parent=self.parent, permission=permission, ) @@ -294,7 +294,7 @@ def create( role: RoleOptions, grantee_type: Optional[GranteeTypeOptions] = None, email_address: Optional[str] = None, - client: glm.PermissionServiceClient | None = None, + client: protos.PermissionServiceClient | None = None, ) -> Permission: """ Create a new permission on a resource (self). @@ -327,7 +327,7 @@ async def create_async( role: RoleOptions, grantee_type: Optional[GranteeTypeOptions] = None, email_address: Optional[str] = None, - client: glm.PermissionServiceAsyncClient | None = None, + client: protos.PermissionServiceAsyncClient | None = None, ) -> Permission: """ This is the async version of `PermissionAdapter.create_permission`. @@ -345,7 +345,7 @@ async def create_async( def list( self, page_size: Optional[int] = None, - client: glm.PermissionServiceClient | None = None, + client: protos.PermissionServiceClient | None = None, ) -> Iterable[Permission]: """ List `Permission`s enforced on a resource (self). @@ -360,7 +360,7 @@ def list( if client is None: client = get_default_permission_client() - request = glm.ListPermissionsRequest( + request = protos.ListPermissionsRequest( parent=self.parent, page_size=page_size # pytype: disable=attribute-error ) for permission in client.list_permissions(request): @@ -370,7 +370,7 @@ def list( async def list_async( self, page_size: Optional[int] = None, - client: glm.PermissionServiceAsyncClient | None = None, + client: protos.PermissionServiceAsyncClient | None = None, ) -> AsyncIterable[Permission]: """ This is the async version of `PermissionAdapter.list_permissions`. 
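Note: the rename in these hunks is mechanical — every proto message type previously reached through the `google.ai.generativelanguage` alias `glm` is now reached through `google.generativeai.protos`; the message classes themselves are unchanged. A minimal sketch of the pattern (not part of the diff), reusing the `GetPermissionRequest` construction from the hunk above; the resource name is a hypothetical placeholder:

from google.generativeai import protos

# Hypothetical permission resource name, for illustration only.
name = "tunedModels/example-model/permissions/example-permission"

# Same proto message as before; only the import path changed:
#   glm.GetPermissionRequest(name=name)  ->  protos.GetPermissionRequest(name=name)
request = protos.GetPermissionRequest(name=name)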
@@ -378,7 +378,7 @@ async def list_async( if client is None: client = get_default_permission_async_client() - request = glm.ListPermissionsRequest( + request = protos.ListPermissionsRequest( parent=self.parent, page_size=page_size # pytype: disable=attribute-error ) async for permission in await client.list_permissions(request): @@ -388,7 +388,7 @@ async def list_async( def transfer_ownership( self, email_address: str, - client: glm.PermissionServiceClient | None = None, + client: protos.PermissionServiceClient | None = None, ) -> None: """ Transfer ownership of a resource (self) to a new owner. @@ -401,7 +401,7 @@ def transfer_ownership( raise NotImplementedError("Can'/t transfer_ownership for a Corpus") if client is None: client = get_default_permission_client() - transfer_request = glm.TransferOwnershipRequest( + transfer_request = protos.TransferOwnershipRequest( name=self.parent, email_address=email_address # pytype: disable=attribute-error ) return client.transfer_ownership(request=transfer_request) @@ -409,14 +409,14 @@ def transfer_ownership( async def transfer_ownership_async( self, email_address: str, - client: glm.PermissionServiceAsyncClient | None = None, + client: protos.PermissionServiceAsyncClient | None = None, ) -> None: """This is the async version of `PermissionAdapter.transfer_ownership`.""" if self.parent.startswith("corpora"): raise NotImplementedError("Can'/t transfer_ownership for a Corpus") if client is None: client = get_default_permission_async_client() - transfer_request = glm.TransferOwnershipRequest( + transfer_request = protos.TransferOwnershipRequest( name=self.parent, email_address=email_address # pytype: disable=attribute-error ) return await client.transfer_ownership(request=transfer_request) diff --git a/google/generativeai/types/retriever_types.py b/google/generativeai/types/retriever_types.py index 538d3924a..f7890fed6 100644 --- a/google/generativeai/types/retriever_types.py +++ b/google/generativeai/types/retriever_types.py @@ -21,7 +21,7 @@ from typing import Any, AsyncIterable, Optional, Union, Iterable, Mapping from typing_extensions import deprecated # type: ignore -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.protobuf import field_mask_pb2 from google.generativeai.client import get_default_retriever_client @@ -44,14 +44,14 @@ def valid_name(name): return re.match(_VALID_NAME, name) and len(name) < 40 -Operator = glm.Condition.Operator -State = glm.Chunk.State +Operator = protos.Condition.Operator +State = protos.Chunk.State OperatorOptions = Union[str, int, Operator] StateOptions = Union[str, int, State] ChunkOptions = Union[ - glm.Chunk, + protos.Chunk, str, tuple[str, str], tuple[str, str, Any], @@ -59,17 +59,17 @@ def valid_name(name): ] # fmt: no BatchCreateChunkOptions = Union[ - glm.BatchCreateChunksRequest, + protos.BatchCreateChunksRequest, Mapping[str, str], Mapping[str, tuple[str, str]], Iterable[ChunkOptions], ] # fmt: no -UpdateChunkOptions = Union[glm.UpdateChunkRequest, Mapping[str, Any], tuple[str, Any]] +UpdateChunkOptions = Union[protos.UpdateChunkRequest, Mapping[str, Any], tuple[str, Any]] -BatchUpdateChunksOptions = Union[glm.BatchUpdateChunksRequest, Iterable[UpdateChunkOptions]] +BatchUpdateChunksOptions = Union[protos.BatchUpdateChunksRequest, Iterable[UpdateChunkOptions]] -BatchDeleteChunkOptions = Union[list[glm.DeleteChunkRequest], Iterable[str]] +BatchDeleteChunkOptions = Union[list[protos.DeleteChunkRequest], Iterable[str]] _OPERATOR: dict[OperatorOptions, Operator] = 
{ Operator.OPERATOR_UNSPECIFIED: Operator.OPERATOR_UNSPECIFIED, @@ -163,10 +163,10 @@ def _to_proto(self): ) kwargs["operation"] = c.operation - condition = glm.Condition(**kwargs) + condition = protos.Condition(**kwargs) conditions.append(condition) - return glm.MetadataFilter(key=self.key, conditions=conditions) + return protos.MetadataFilter(key=self.key, conditions=conditions) @string_utils.prettyprint @@ -188,10 +188,10 @@ def _to_proto(self): kwargs["string_value"] = self.value elif isinstance(self.value, Iterable): if isinstance(self.value, Mapping): - # If already converted to a glm.StringList, get the values + # If already converted to a protos.StringList, get the values kwargs["string_list_value"] = self.value else: - kwargs["string_list_value"] = glm.StringList(values=self.value) + kwargs["string_list_value"] = protos.StringList(values=self.value) elif isinstance(self.value, (int, float)): kwargs["numeric_value"] = float(self.value) else: @@ -199,7 +199,7 @@ def _to_proto(self): f"The value for a custom_metadata specification must be either a list of string values, a string, or an integer/float, but got {self.value}." ) - return glm.CustomMetadata(key=self.key, **kwargs) + return protos.CustomMetadata(key=self.key, **kwargs) @classmethod def _from_dict(cls, cm): @@ -217,14 +217,14 @@ def _to_dict(self): return type(proto).to_dict(proto) -CustomMetadataOptions = Union[CustomMetadata, glm.CustomMetadata, dict] +CustomMetadataOptions = Union[CustomMetadata, protos.CustomMetadata, dict] def make_custom_metadata(cm: CustomMetadataOptions) -> CustomMetadata: if isinstance(cm, CustomMetadata): return cm - if isinstance(cm, glm.CustomMetadata): + if isinstance(cm, protos.CustomMetadata): cm = type(cm).to_dict(cm) if isinstance(cm, dict): @@ -262,7 +262,7 @@ def create_document( name: str | None = None, display_name: str | None = None, custom_metadata: Iterable[CustomMetadata] | None = None, - client: glm.RetrieverServiceClient | None = None, + client: protos.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> Document: """ @@ -294,9 +294,9 @@ def create_document( c_data.append(cm._to_proto()) if name is None: - document = glm.Document(display_name=display_name, custom_metadata=c_data) + document = protos.Document(display_name=display_name, custom_metadata=c_data) elif valid_name(name): - document = glm.Document( + document = protos.Document( name=f"{self.name}/documents/{name}", display_name=display_name, custom_metadata=c_data, @@ -304,7 +304,7 @@ def create_document( else: raise ValueError(NAME_ERROR_MSG.format(length=len(name), name=name)) - request = glm.CreateDocumentRequest(parent=self.name, document=document) + request = protos.CreateDocumentRequest(parent=self.name, document=document) response = client.create_document(request, **request_options) return decode_document(response) @@ -313,7 +313,7 @@ async def create_document_async( name: str | None = None, display_name: str | None = None, custom_metadata: Iterable[CustomMetadata] | None = None, - client: glm.RetrieverServiceAsyncClient | None = None, + client: protos.RetrieverServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> Document: """This is the async version of `Corpus.create_document`.""" @@ -330,9 +330,9 @@ async def create_document_async( c_data.append(cm._to_proto()) if name is None: - document = glm.Document(display_name=display_name, custom_metadata=c_data) + document = protos.Document(display_name=display_name, 
custom_metadata=c_data) elif valid_name(name): - document = glm.Document( + document = protos.Document( name=f"{self.name}/documents/{name}", display_name=display_name, custom_metadata=c_data, @@ -340,14 +340,14 @@ async def create_document_async( else: raise ValueError(NAME_ERROR_MSG.format(length=len(name), name=name)) - request = glm.CreateDocumentRequest(parent=self.name, document=document) + request = protos.CreateDocumentRequest(parent=self.name, document=document) response = await client.create_document(request, **request_options) return decode_document(response) def get_document( self, name: str, - client: glm.RetrieverServiceClient | None = None, + client: protos.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> Document: """ @@ -369,14 +369,14 @@ def get_document( if "/" not in name: name = f"{self.name}/documents/{name}" - request = glm.GetDocumentRequest(name=name) + request = protos.GetDocumentRequest(name=name) response = client.get_document(request, **request_options) return decode_document(response) async def get_document_async( self, name: str, - client: glm.RetrieverServiceAsyncClient | None = None, + client: protos.RetrieverServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> Document: """This is the async version of `Corpus.get_document`.""" @@ -389,7 +389,7 @@ async def get_document_async( if "/" not in name: name = f"{self.name}/documents/{name}" - request = glm.GetDocumentRequest(name=name) + request = protos.GetDocumentRequest(name=name) response = await client.get_document(request, **request_options) return decode_document(response) @@ -402,7 +402,7 @@ def _apply_update(self, path, value): def update( self, updates: dict[str, Any], - client: glm.RetrieverServiceClient | None = None, + client: protos.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ): """ @@ -433,14 +433,14 @@ def update( for path, value in updates.items(): self._apply_update(path, value) - request = glm.UpdateCorpusRequest(corpus=self.to_dict(), update_mask=field_mask) + request = protos.UpdateCorpusRequest(corpus=self.to_dict(), update_mask=field_mask) client.update_corpus(request, **request_options) return self async def update_async( self, updates: dict[str, Any], - client: glm.RetrieverServiceAsyncClient | None = None, + client: protos.RetrieverServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ): """This is the async version of `Corpus.update`.""" @@ -462,7 +462,7 @@ async def update_async( for path, value in updates.items(): self._apply_update(path, value) - request = glm.UpdateCorpusRequest(corpus=self.to_dict(), update_mask=field_mask) + request = protos.UpdateCorpusRequest(corpus=self.to_dict(), update_mask=field_mask) await client.update_corpus(request, **request_options) return self @@ -471,7 +471,7 @@ def query( query: str, metadata_filters: Iterable[MetadataFilter] | None = None, results_count: int | None = None, - client: glm.RetrieverServiceClient | None = None, + client: protos.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> Iterable[RelevantChunk]: """ @@ -501,7 +501,7 @@ def query( for mf in metadata_filters: m_f_.append(mf._to_proto()) - request = glm.QueryCorpusRequest( + request = protos.QueryCorpusRequest( name=self.name, query=query, metadata_filters=m_f_, @@ -525,7 +525,7 @@ async def query_async( query: str, 
metadata_filters: Iterable[MetadataFilter] | None = None, results_count: int | None = None, - client: glm.RetrieverServiceAsyncClient | None = None, + client: protos.RetrieverServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> Iterable[RelevantChunk]: """This is the async version of `Corpus.query`.""" @@ -544,7 +544,7 @@ async def query_async( for mf in metadata_filters: m_f_.append(mf._to_proto()) - request = glm.QueryCorpusRequest( + request = protos.QueryCorpusRequest( name=self.name, query=query, metadata_filters=m_f_, @@ -567,7 +567,7 @@ def delete_document( self, name: str, force: bool = False, - client: glm.RetrieverServiceClient | None = None, + client: protos.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ): """ @@ -587,14 +587,14 @@ def delete_document( if "/" not in name: name = f"{self.name}/documents/{name}" - request = glm.DeleteDocumentRequest(name=name, force=bool(force)) + request = protos.DeleteDocumentRequest(name=name, force=bool(force)) client.delete_document(request, **request_options) async def delete_document_async( self, name: str, force: bool = False, - client: glm.RetrieverServiceAsyncClient | None = None, + client: protos.RetrieverServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ): """This is the async version of `Corpus.delete_document`.""" @@ -607,13 +607,13 @@ async def delete_document_async( if "/" not in name: name = f"{self.name}/documents/{name}" - request = glm.DeleteDocumentRequest(name=name, force=bool(force)) + request = protos.DeleteDocumentRequest(name=name, force=bool(force)) await client.delete_document(request, **request_options) def list_documents( self, page_size: int | None = None, - client: glm.RetrieverServiceClient | None = None, + client: protos.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> Iterable[Document]: """ @@ -633,7 +633,7 @@ def list_documents( if client is None: client = get_default_retriever_client() - request = glm.ListDocumentsRequest( + request = protos.ListDocumentsRequest( parent=self.name, page_size=page_size, ) @@ -643,7 +643,7 @@ def list_documents( async def list_documents_async( self, page_size: int | None = None, - client: glm.RetrieverServiceAsyncClient | None = None, + client: protos.RetrieverServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> AsyncIterable[Document]: """This is the async version of `Corpus.list_documents`.""" @@ -653,7 +653,7 @@ async def list_documents_async( if client is None: client = get_default_retriever_async_client() - request = glm.ListDocumentsRequest( + request = protos.ListDocumentsRequest( parent=self.name, page_size=page_size, ) @@ -670,7 +670,7 @@ def create_permission( role: permission_types.RoleOptions, grantee_type: Optional[permission_types.GranteeTypeOptions] = None, email_address: Optional[str] = None, - client: glm.PermissionServiceClient | None = None, + client: protos.PermissionServiceClient | None = None, ) -> permission_types.Permission: return self.permissions.create( role=role, grantee_type=grantee_type, email_address=email_address, client=client @@ -685,7 +685,7 @@ async def create_permission_async( role: permission_types.RoleOptions, grantee_type: Optional[permission_types.GranteeTypeOptions] = None, email_address: Optional[str] = None, - client: glm.PermissionServiceAsyncClient | None = None, + client: 
protos.PermissionServiceAsyncClient | None = None, ) -> permission_types.Permission: return await self.permissions.create_async( role=role, grantee_type=grantee_type, email_address=email_address, client=client @@ -698,7 +698,7 @@ async def create_permission_async( def list_permissions( self, page_size: Optional[int] = None, - client: glm.PermissionServiceClient | None = None, + client: protos.PermissionServiceClient | None = None, ) -> Iterable[permission_types.Permission]: return self.permissions.list(page_size=page_size, client=client) @@ -709,7 +709,7 @@ def list_permissions( async def list_permissions_async( self, page_size: Optional[int] = None, - client: glm.PermissionServiceAsyncClient | None = None, + client: protos.PermissionServiceAsyncClient | None = None, ) -> AsyncIterable[permission_types.Permission]: return self.permissions.list_async(page_size=page_size, client=client) @@ -745,7 +745,7 @@ def create_chunk( data: str | ChunkData, name: str | None = None, custom_metadata: Iterable[CustomMetadata] | None = None, - client: glm.RetrieverServiceClient | None = None, + client: protos.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> Chunk: """ @@ -785,15 +785,15 @@ def create_chunk( chunk_name = name if isinstance(data, str): - chunk = glm.Chunk(name=chunk_name, data={"string_value": data}, custom_metadata=c_data) + chunk = protos.Chunk(name=chunk_name, data={"string_value": data}, custom_metadata=c_data) else: - chunk = glm.Chunk( + chunk = protos.Chunk( name=chunk_name, data={"string_value": data.string_value}, custom_metadata=c_data, ) - request = glm.CreateChunkRequest(parent=self.name, chunk=chunk) + request = protos.CreateChunkRequest(parent=self.name, chunk=chunk) response = client.create_chunk(request, **request_options) return decode_chunk(response) @@ -802,7 +802,7 @@ async def create_chunk_async( data: str | ChunkData, name: str | None = None, custom_metadata: Iterable[CustomMetadata] | None = None, - client: glm.RetrieverServiceAsyncClient | None = None, + client: protos.RetrieverServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> Chunk: """This is the async version of `Document.create_chunk`.""" @@ -827,24 +827,24 @@ async def create_chunk_async( chunk_name = name if isinstance(data, str): - chunk = glm.Chunk(name=chunk_name, data={"string_value": data}, custom_metadata=c_data) + chunk = protos.Chunk(name=chunk_name, data={"string_value": data}, custom_metadata=c_data) else: - chunk = glm.Chunk( + chunk = protos.Chunk( name=chunk_name, data={"string_value": data.string_value}, custom_metadata=c_data, ) - request = glm.CreateChunkRequest(parent=self.name, chunk=chunk) + request = protos.CreateChunkRequest(parent=self.name, chunk=chunk) response = await client.create_chunk(request, **request_options) return decode_chunk(response) - def _make_chunk(self, chunk: ChunkOptions) -> glm.Chunk: + def _make_chunk(self, chunk: ChunkOptions) -> protos.Chunk: # del self - if isinstance(chunk, glm.Chunk): - return glm.Chunk(chunk) + if isinstance(chunk, protos.Chunk): + return protos.Chunk(chunk) elif isinstance(chunk, str): - return glm.Chunk(data={"string_value": chunk}) + return protos.Chunk(data={"string_value": chunk}) elif isinstance(chunk, tuple): if len(chunk) == 2: name, data = chunk # pytype: disable=bad-unpacking @@ -857,7 +857,7 @@ def _make_chunk(self, chunk: ChunkOptions) -> glm.Chunk: f"value: {chunk}" ) - return glm.Chunk( + return protos.Chunk( name=name, 
data={"string_value": data}, custom_metadata=custom_metadata, @@ -866,7 +866,7 @@ def _make_chunk(self, chunk: ChunkOptions) -> glm.Chunk: if isinstance(chunk["data"], str): chunk = dict(chunk) chunk["data"] = {"string_value": chunk["data"]} - return glm.Chunk(chunk) + return protos.Chunk(chunk) else: raise TypeError( f"Could not convert instance of `{type(chunk)}` chunk:" f"value: {chunk}" @@ -874,8 +874,8 @@ def _make_chunk(self, chunk: ChunkOptions) -> glm.Chunk: def _make_batch_create_chunk_request( self, chunks: BatchCreateChunkOptions - ) -> glm.BatchCreateChunksRequest: - if isinstance(chunks, glm.BatchCreateChunksRequest): + ) -> protos.BatchCreateChunksRequest: + if isinstance(chunks, protos.BatchCreateChunksRequest): return chunks if isinstance(chunks, Mapping): @@ -894,14 +894,14 @@ def _make_batch_create_chunk_request( chunk.name = f"{self.name}/chunks/{chunk.name}" - requests.append(glm.CreateChunkRequest(parent=self.name, chunk=chunk)) + requests.append(protos.CreateChunkRequest(parent=self.name, chunk=chunk)) - return glm.BatchCreateChunksRequest(parent=self.name, requests=requests) + return protos.BatchCreateChunksRequest(parent=self.name, requests=requests) def batch_create_chunks( self, chunks: BatchCreateChunkOptions, - client: glm.RetrieverServiceClient | None = None, + client: protos.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ): """ @@ -927,7 +927,7 @@ def batch_create_chunks( async def batch_create_chunks_async( self, chunks: BatchCreateChunkOptions, - client: glm.RetrieverServiceAsyncClient | None = None, + client: protos.RetrieverServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ): """This is the async version of `Document.batch_create_chunk`.""" @@ -944,7 +944,7 @@ async def batch_create_chunks_async( def get_chunk( self, name: str, - client: glm.RetrieverServiceClient | None = None, + client: protos.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ): """ @@ -966,14 +966,14 @@ def get_chunk( if "/" not in name: name = f"{self.name}/chunks/{name}" - request = glm.GetChunkRequest(name=name) + request = protos.GetChunkRequest(name=name) response = client.get_chunk(request, **request_options) return decode_chunk(response) async def get_chunk_async( self, name: str, - client: glm.RetrieverServiceAsyncClient | None = None, + client: protos.RetrieverServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ): """This is the async version of `Document.get_chunk`.""" @@ -986,14 +986,14 @@ async def get_chunk_async( if "/" not in name: name = f"{self.name}/chunks/{name}" - request = glm.GetChunkRequest(name=name) + request = protos.GetChunkRequest(name=name) response = await client.get_chunk(request, **request_options) return decode_chunk(response) def list_chunks( self, page_size: int | None = None, - client: glm.RetrieverServiceClient | None = None, + client: protos.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> Iterable[Chunk]: """ @@ -1012,14 +1012,14 @@ def list_chunks( if client is None: client = get_default_retriever_client() - request = glm.ListChunksRequest(parent=self.name, page_size=page_size) + request = protos.ListChunksRequest(parent=self.name, page_size=page_size) for chunk in client.list_chunks(request, **request_options): yield decode_chunk(chunk) async def list_chunks_async( self, page_size: int | None 
= None, - client: glm.RetrieverServiceClient | None = None, + client: protos.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> AsyncIterable[Chunk]: """This is the async version of `Document.list_chunks`.""" @@ -1029,7 +1029,7 @@ async def list_chunks_async( if client is None: client = get_default_retriever_async_client() - request = glm.ListChunksRequest(parent=self.name, page_size=page_size) + request = protos.ListChunksRequest(parent=self.name, page_size=page_size) async for chunk in await client.list_chunks(request, **request_options): yield decode_chunk(chunk) @@ -1038,7 +1038,7 @@ def query( query: str, metadata_filters: Iterable[MetadataFilter] | None = None, results_count: int | None = None, - client: glm.RetrieverServiceClient | None = None, + client: protos.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> list[RelevantChunk]: """ @@ -1067,7 +1067,7 @@ def query( for mf in metadata_filters: m_f_.append(mf._to_proto()) - request = glm.QueryDocumentRequest( + request = protos.QueryDocumentRequest( name=self.name, query=query, metadata_filters=m_f_, @@ -1091,7 +1091,7 @@ async def query_async( query: str, metadata_filters: Iterable[MetadataFilter] | None = None, results_count: int | None = None, - client: glm.RetrieverServiceAsyncClient | None = None, + client: protos.RetrieverServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> list[RelevantChunk]: """This is the async version of `Document.query`.""" @@ -1110,7 +1110,7 @@ async def query_async( for mf in metadata_filters: m_f_.append(mf._to_proto()) - request = glm.QueryDocumentRequest( + request = protos.QueryDocumentRequest( name=self.name, query=query, metadata_filters=m_f_, @@ -1138,7 +1138,7 @@ def _apply_update(self, path, value): def update( self, updates: dict[str, Any], - client: glm.RetrieverServiceClient | None = None, + client: protos.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ): """ @@ -1168,14 +1168,14 @@ def update( for path, value in updates.items(): self._apply_update(path, value) - request = glm.UpdateDocumentRequest(document=self.to_dict(), update_mask=field_mask) + request = protos.UpdateDocumentRequest(document=self.to_dict(), update_mask=field_mask) client.update_document(request, **request_options) return self async def update_async( self, updates: dict[str, Any], - client: glm.RetrieverServiceAsyncClient | None = None, + client: protos.RetrieverServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ): """This is the async version of `Document.update`.""" @@ -1196,14 +1196,14 @@ async def update_async( for path, value in updates.items(): self._apply_update(path, value) - request = glm.UpdateDocumentRequest(document=self.to_dict(), update_mask=field_mask) + request = protos.UpdateDocumentRequest(document=self.to_dict(), update_mask=field_mask) await client.update_document(request, **request_options) return self def batch_update_chunks( self, chunks: BatchUpdateChunksOptions, - client: glm.RetrieverServiceClient | None = None, + client: protos.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ): """ @@ -1222,7 +1222,7 @@ def batch_update_chunks( if client is None: client = get_default_retriever_client() - if isinstance(chunks, glm.BatchUpdateChunksRequest): + if isinstance(chunks, 
protos.BatchUpdateChunksRequest): response = client.batch_update_chunks(chunks) response = type(response).to_dict(response) return response @@ -1255,15 +1255,15 @@ def batch_update_chunks( for path, value in updates.items(): chunk_to_update._apply_update(path, value) _requests.append( - glm.UpdateChunkRequest(chunk=chunk_to_update.to_dict(), update_mask=field_mask) + protos.UpdateChunkRequest(chunk=chunk_to_update.to_dict(), update_mask=field_mask) ) - request = glm.BatchUpdateChunksRequest(parent=self.name, requests=_requests) + request = protos.BatchUpdateChunksRequest(parent=self.name, requests=_requests) response = client.batch_update_chunks(request, **request_options) response = type(response).to_dict(response) return response if isinstance(chunks, Iterable) and not isinstance(chunks, Mapping): for chunk in chunks: - if isinstance(chunk, glm.UpdateChunkRequest): + if isinstance(chunk, protos.UpdateChunkRequest): _requests.append(chunk) elif isinstance(chunk, tuple): # First element is name of chunk, second element contains updates @@ -1289,10 +1289,10 @@ def batch_update_chunks( ) else: raise TypeError( - "The `chunks` parameter must be a list of glm.UpdateChunkRequests," + "The `chunks` parameter must be a list of protos.UpdateChunkRequests," "dictionaries, or tuples of dictionaries." ) - request = glm.BatchUpdateChunksRequest(parent=self.name, requests=_requests) + request = protos.BatchUpdateChunksRequest(parent=self.name, requests=_requests) response = client.batch_update_chunks(request, **request_options) response = type(response).to_dict(response) return response @@ -1300,7 +1300,7 @@ def batch_update_chunks( async def batch_update_chunks_async( self, chunks: BatchUpdateChunksOptions, - client: glm.RetrieverServiceAsyncClient | None = None, + client: protos.RetrieverServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ): """This is the async version of `Document.batch_update_chunks`.""" @@ -1310,7 +1310,7 @@ async def batch_update_chunks_async( if client is None: client = get_default_retriever_async_client() - if isinstance(chunks, glm.BatchUpdateChunksRequest): + if isinstance(chunks, protos.BatchUpdateChunksRequest): response = client.batch_update_chunks(chunks) response = type(response).to_dict(response) return response @@ -1343,15 +1343,15 @@ async def batch_update_chunks_async( for path, value in updates.items(): chunk_to_update._apply_update(path, value) _requests.append( - glm.UpdateChunkRequest(chunk=chunk_to_update.to_dict(), update_mask=field_mask) + protos.UpdateChunkRequest(chunk=chunk_to_update.to_dict(), update_mask=field_mask) ) - request = glm.BatchUpdateChunksRequest(parent=self.name, requests=_requests) + request = protos.BatchUpdateChunksRequest(parent=self.name, requests=_requests) response = await client.batch_update_chunks(request, **request_options) response = type(response).to_dict(response) return response if isinstance(chunks, Iterable) and not isinstance(chunks, Mapping): for chunk in chunks: - if isinstance(chunk, glm.UpdateChunkRequest): + if isinstance(chunk, protos.UpdateChunkRequest): _requests.append(chunk) elif isinstance(chunk, tuple): # First element is name of chunk, second element contains updates @@ -1377,10 +1377,10 @@ async def batch_update_chunks_async( ) else: raise TypeError( - "The `chunks` parameter must be a list of glm.UpdateChunkRequests," + "The `chunks` parameter must be a list of protos.UpdateChunkRequests," "dictionaries, or tuples of dictionaries." 
) - request = glm.BatchUpdateChunksRequest(parent=self.name, requests=_requests) + request = protos.BatchUpdateChunksRequest(parent=self.name, requests=_requests) response = await client.batch_update_chunks(request, **request_options) response = type(response).to_dict(response) return response @@ -1388,7 +1388,7 @@ async def batch_update_chunks_async( def delete_chunk( self, name: str, - client: glm.RetrieverServiceClient | None = None, + client: protos.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, # fmt: {} ): """ @@ -1407,13 +1407,13 @@ def delete_chunk( if "/" not in name: name = f"{self.name}/chunks/{name}" - request = glm.DeleteChunkRequest(name=name) + request = protos.DeleteChunkRequest(name=name) client.delete_chunk(request, **request_options) async def delete_chunk_async( self, name: str, - client: glm.RetrieverServiceAsyncClient | None = None, + client: protos.RetrieverServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, # fmt: {} ): """This is the async version of `Document.delete_chunk`.""" @@ -1426,13 +1426,13 @@ async def delete_chunk_async( if "/" not in name: name = f"{self.name}/chunks/{name}" - request = glm.DeleteChunkRequest(name=name) + request = protos.DeleteChunkRequest(name=name) await client.delete_chunk(request, **request_options) def batch_delete_chunks( self, chunks: BatchDeleteChunkOptions, - client: glm.RetrieverServiceClient | None = None, + client: protos.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ): """ @@ -1448,24 +1448,24 @@ def batch_delete_chunks( if client is None: client = get_default_retriever_client() - if all(isinstance(x, glm.DeleteChunkRequest) for x in chunks): - request = glm.BatchDeleteChunksRequest(parent=self.name, requests=chunks) + if all(isinstance(x, protos.DeleteChunkRequest) for x in chunks): + request = protos.BatchDeleteChunksRequest(parent=self.name, requests=chunks) client.batch_delete_chunks(request, **request_options) elif isinstance(chunks, Iterable): _request_list = [] for chunk_name in chunks: - _request_list.append(glm.DeleteChunkRequest(name=chunk_name)) - request = glm.BatchDeleteChunksRequest(parent=self.name, requests=_request_list) + _request_list.append(protos.DeleteChunkRequest(name=chunk_name)) + request = protos.BatchDeleteChunksRequest(parent=self.name, requests=_request_list) client.batch_delete_chunks(request, **request_options) else: raise ValueError( - "To delete chunks, you must pass in either the names of the chunks as an iterable, or multiple `glm.DeleteChunkRequest`s." + "To delete chunks, you must pass in either the names of the chunks as an iterable, or multiple `protos.DeleteChunkRequest`s." 
) async def batch_delete_chunks_async( self, chunks: BatchDeleteChunkOptions, - client: glm.RetrieverServiceAsyncClient | None = None, + client: protos.RetrieverServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ): """This is the async version of `Document.batch_delete_chunks`.""" @@ -1475,18 +1475,18 @@ async def batch_delete_chunks_async( if client is None: client = get_default_retriever_async_client() - if all(isinstance(x, glm.DeleteChunkRequest) for x in chunks): - request = glm.BatchDeleteChunksRequest(parent=self.name, requests=chunks) + if all(isinstance(x, protos.DeleteChunkRequest) for x in chunks): + request = protos.BatchDeleteChunksRequest(parent=self.name, requests=chunks) await client.batch_delete_chunks(request, **request_options) elif isinstance(chunks, Iterable): _request_list = [] for chunk_name in chunks: - _request_list.append(glm.DeleteChunkRequest(name=chunk_name)) - request = glm.BatchDeleteChunksRequest(parent=self.name, requests=_request_list) + _request_list.append(protos.DeleteChunkRequest(name=chunk_name)) + request = protos.BatchDeleteChunksRequest(parent=self.name, requests=_request_list) await client.batch_delete_chunks(request, **request_options) else: raise ValueError( - "To delete chunks, you must pass in either the names of the chunks as an iterable, or multiple `glm.DeleteChunkRequest`s." + "To delete chunks, you must pass in either the names of the chunks as an iterable, or multiple `protos.DeleteChunkRequest`s." ) def to_dict(self) -> dict[str, Any]: @@ -1498,7 +1498,7 @@ def to_dict(self) -> dict[str, Any]: return result -def decode_chunk(chunk: glm.Chunk) -> Chunk: +def decode_chunk(chunk: protos.Chunk) -> Chunk: chunk = type(chunk).to_dict(chunk) idecode_time(chunk, "create_time") idecode_time(chunk, "update_time") @@ -1571,7 +1571,7 @@ def _apply_update(self, path, value): def update( self, updates: dict[str, Any], - client: glm.RetrieverServiceClient | None = None, + client: protos.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ): """ @@ -1612,7 +1612,7 @@ def update( field_mask.paths.append(path) for path, value in updates.items(): self._apply_update(path, value) - request = glm.UpdateChunkRequest(chunk=self.to_dict(), update_mask=field_mask) + request = protos.UpdateChunkRequest(chunk=self.to_dict(), update_mask=field_mask) client.update_chunk(request, **request_options) return self @@ -1620,7 +1620,7 @@ def update( async def update_async( self, updates: dict[str, Any], - client: glm.RetrieverServiceAsyncClient | None = None, + client: protos.RetrieverServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ): """This is the async version of `Chunk.update`.""" @@ -1652,7 +1652,7 @@ async def update_async( field_mask.paths.append(path) for path, value in updates.items(): self._apply_update(path, value) - request = glm.UpdateChunkRequest(chunk=self.to_dict(), update_mask=field_mask) + request = protos.UpdateChunkRequest(chunk=self.to_dict(), update_mask=field_mask) await client.update_chunk(request, **request_options) return self diff --git a/google/generativeai/types/safety_types.py b/google/generativeai/types/safety_types.py index 85e57c8f6..8e07ca51d 100644 --- a/google/generativeai/types/safety_types.py +++ b/google/generativeai/types/safety_types.py @@ -23,7 +23,7 @@ from typing_extensions import TypedDict -from google.ai import generativelanguage as glm +from google.generativeai import protos from 
google.generativeai import string_utils @@ -39,9 +39,9 @@ ] # These are basic python enums, it's okay to expose them -HarmProbability = glm.SafetyRating.HarmProbability -HarmBlockThreshold = glm.SafetySetting.HarmBlockThreshold -BlockedReason = glm.ContentFilter.BlockedReason +HarmProbability = protos.SafetyRating.HarmProbability +HarmBlockThreshold = protos.SafetySetting.HarmBlockThreshold +BlockedReason = protos.ContentFilter.BlockedReason import proto @@ -51,57 +51,57 @@ class HarmCategory(proto.Enum): Harm Categories supported by the gemini-family model """ - HARM_CATEGORY_UNSPECIFIED = glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED.value - HARM_CATEGORY_HARASSMENT = glm.HarmCategory.HARM_CATEGORY_HARASSMENT.value - HARM_CATEGORY_HATE_SPEECH = glm.HarmCategory.HARM_CATEGORY_HATE_SPEECH.value - HARM_CATEGORY_SEXUALLY_EXPLICIT = glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT.value - HARM_CATEGORY_DANGEROUS_CONTENT = glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT.value + HARM_CATEGORY_UNSPECIFIED = protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED.value + HARM_CATEGORY_HARASSMENT = protos.HarmCategory.HARM_CATEGORY_HARASSMENT.value + HARM_CATEGORY_HATE_SPEECH = protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH.value + HARM_CATEGORY_SEXUALLY_EXPLICIT = protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT.value + HARM_CATEGORY_DANGEROUS_CONTENT = protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT.value HarmCategoryOptions = Union[str, int, HarmCategory] # fmt: off -_HARM_CATEGORIES: Dict[HarmCategoryOptions, glm.HarmCategory] = { - glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED: glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, - HarmCategory.HARM_CATEGORY_UNSPECIFIED: glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, - 0: glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, - "harm_category_unspecified": glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, - "unspecified": glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, +_HARM_CATEGORIES: Dict[HarmCategoryOptions, protos.HarmCategory] = { + protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED: protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + HarmCategory.HARM_CATEGORY_UNSPECIFIED: protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + 0: protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + "harm_category_unspecified": protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + "unspecified": protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED, - 7: glm.HarmCategory.HARM_CATEGORY_HARASSMENT, - glm.HarmCategory.HARM_CATEGORY_HARASSMENT: glm.HarmCategory.HARM_CATEGORY_HARASSMENT, - HarmCategory.HARM_CATEGORY_HARASSMENT: glm.HarmCategory.HARM_CATEGORY_HARASSMENT, - "harm_category_harassment": glm.HarmCategory.HARM_CATEGORY_HARASSMENT, - "harassment": glm.HarmCategory.HARM_CATEGORY_HARASSMENT, - - 8: glm.HarmCategory.HARM_CATEGORY_HATE_SPEECH, - glm.HarmCategory.HARM_CATEGORY_HATE_SPEECH: glm.HarmCategory.HARM_CATEGORY_HATE_SPEECH, - HarmCategory.HARM_CATEGORY_HATE_SPEECH: glm.HarmCategory.HARM_CATEGORY_HATE_SPEECH, - 'harm_category_hate_speech': glm.HarmCategory.HARM_CATEGORY_HATE_SPEECH, - 'hate_speech': glm.HarmCategory.HARM_CATEGORY_HATE_SPEECH, - 'hate': glm.HarmCategory.HARM_CATEGORY_HATE_SPEECH, - - 9: glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - "harm_category_sexually_explicit": glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - "harm_category_sexual": glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - 
"sexually_explicit": glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - "sexual": glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - "sex": glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - - 10: glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - "harm_category_dangerous_content": glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - "harm_category_dangerous": glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - "dangerous": glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - "danger": glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + 7: protos.HarmCategory.HARM_CATEGORY_HARASSMENT, + protos.HarmCategory.HARM_CATEGORY_HARASSMENT: protos.HarmCategory.HARM_CATEGORY_HARASSMENT, + HarmCategory.HARM_CATEGORY_HARASSMENT: protos.HarmCategory.HARM_CATEGORY_HARASSMENT, + "harm_category_harassment": protos.HarmCategory.HARM_CATEGORY_HARASSMENT, + "harassment": protos.HarmCategory.HARM_CATEGORY_HARASSMENT, + + 8: protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH: protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + HarmCategory.HARM_CATEGORY_HATE_SPEECH: protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + 'harm_category_hate_speech': protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + 'hate_speech': protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + 'hate': protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + + 9: protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + "harm_category_sexually_explicit": protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + "harm_category_sexual": protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + "sexually_explicit": protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + "sexual": protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + "sex": protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + + 10: protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + "harm_category_dangerous_content": protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + "harm_category_dangerous": protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + "dangerous": protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + "danger": protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, } # fmt: on -def to_harm_category(x: HarmCategoryOptions) -> glm.HarmCategory: +def to_harm_category(x: HarmCategoryOptions) -> protos.HarmCategory: if isinstance(x, str): x = x.lower() return _HARM_CATEGORIES[x] @@ -150,7 +150,7 @@ class ContentFilterDict(TypedDict): reason: BlockedReason message: str - __doc__ = string_utils.strip_oneof(glm.ContentFilter.__doc__) + __doc__ = string_utils.strip_oneof(protos.ContentFilter.__doc__) def convert_filters_to_enums( @@ -166,15 +166,15 @@ def convert_filters_to_enums( class SafetyRatingDict(TypedDict): - category: glm.HarmCategory + category: protos.HarmCategory probability: HarmProbability - __doc__ = string_utils.strip_oneof(glm.SafetyRating.__doc__) + __doc__ = string_utils.strip_oneof(protos.SafetyRating.__doc__) def 
convert_rating_to_enum(rating: dict) -> SafetyRatingDict: return { - "category": glm.HarmCategory(rating["category"]), + "category": protos.HarmCategory(rating["category"]), "probability": HarmProbability(rating["probability"]), } @@ -187,10 +187,10 @@ def convert_ratings_to_enum(ratings: Iterable[dict]) -> List[SafetyRatingDict]: class SafetySettingDict(TypedDict): - category: glm.HarmCategory + category: protos.HarmCategory threshold: HarmBlockThreshold - __doc__ = string_utils.strip_oneof(glm.SafetySetting.__doc__) + __doc__ = string_utils.strip_oneof(protos.SafetySetting.__doc__) class LooseSafetySettingDict(TypedDict): @@ -240,7 +240,7 @@ def normalize_safety_settings( def convert_setting_to_enum(setting: dict) -> SafetySettingDict: return { - "category": glm.HarmCategory(setting["category"]), + "category": protos.HarmCategory(setting["category"]), "threshold": HarmBlockThreshold(setting["threshold"]), } @@ -249,7 +249,7 @@ class SafetyFeedbackDict(TypedDict): rating: SafetyRatingDict setting: SafetySettingDict - __doc__ = string_utils.strip_oneof(glm.SafetyFeedback.__doc__) + __doc__ = string_utils.strip_oneof(protos.SafetyFeedback.__doc__) def convert_safety_feedback_to_enums( diff --git a/tests/test_answer.py b/tests/test_answer.py index 4128567f4..25f824e86 100644 --- a/tests/test_answer.py +++ b/tests/test_answer.py @@ -18,7 +18,7 @@ import unittest import unittest.mock as mock -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import answer from google.generativeai import types as genai_types @@ -47,14 +47,14 @@ def add_client_method(f): @add_client_method def generate_answer( - request: glm.GenerateAnswerRequest, + request: protos.GenerateAnswerRequest, **kwargs, - ) -> glm.GenerateAnswerResponse: + ) -> protos.GenerateAnswerResponse: self.observed_requests.append(request) - return glm.GenerateAnswerResponse( - answer=glm.Candidate( + return protos.GenerateAnswerResponse( + answer=protos.Candidate( index=1, - content=(glm.Content(parts=[glm.Part(text="Demo answer.")])), + content=(protos.Content(parts=[protos.Part(text="Demo answer.")])), ), answerable_probability=0.500, ) @@ -62,17 +62,17 @@ def generate_answer( def test_make_grounding_passages_mixed_types(self): inline_passages = [ "I am a chicken", - glm.Content(parts=[glm.Part(text="I am a bird.")]), - glm.Content(parts=[glm.Part(text="I can fly!")]), + protos.Content(parts=[protos.Part(text="I am a bird.")]), + protos.Content(parts=[protos.Part(text="I can fly!")]), ] x = answer._make_grounding_passages(inline_passages) - self.assertIsInstance(x, glm.GroundingPassages) + self.assertIsInstance(x, protos.GroundingPassages) self.assertEqual( - glm.GroundingPassages( + protos.GroundingPassages( passages=[ - {"id": "0", "content": glm.Content(parts=[glm.Part(text="I am a chicken")])}, - {"id": "1", "content": glm.Content(parts=[glm.Part(text="I am a bird.")])}, - {"id": "2", "content": glm.Content(parts=[glm.Part(text="I can fly!")])}, + {"id": "0", "content": protos.Content(parts=[protos.Part(text="I am a chicken")])}, + {"id": "1", "content": protos.Content(parts=[protos.Part(text="I am a bird.")])}, + {"id": "2", "content": protos.Content(parts=[protos.Part(text="I can fly!")])}, ] ), x, @@ -82,23 +82,23 @@ def test_make_grounding_passages_mixed_types(self): [ dict( testcase_name="grounding_passage", - inline_passages=glm.GroundingPassages( + inline_passages=protos.GroundingPassages( passages=[ { "id": "0", - "content": glm.Content(parts=[glm.Part(text="I am a 
chicken")]), + "content": protos.Content(parts=[protos.Part(text="I am a chicken")]), }, - {"id": "1", "content": glm.Content(parts=[glm.Part(text="I am a bird.")])}, - {"id": "2", "content": glm.Content(parts=[glm.Part(text="I can fly!")])}, + {"id": "1", "content": protos.Content(parts=[protos.Part(text="I am a bird.")])}, + {"id": "2", "content": protos.Content(parts=[protos.Part(text="I can fly!")])}, ] ), ), dict( testcase_name="content_object", inline_passages=[ - glm.Content(parts=[glm.Part(text="I am a chicken")]), - glm.Content(parts=[glm.Part(text="I am a bird.")]), - glm.Content(parts=[glm.Part(text="I can fly!")]), + protos.Content(parts=[protos.Part(text="I am a chicken")]), + protos.Content(parts=[protos.Part(text="I am a bird.")]), + protos.Content(parts=[protos.Part(text="I can fly!")]), ], ), dict( @@ -109,13 +109,13 @@ def test_make_grounding_passages_mixed_types(self): ) def test_make_grounding_passages(self, inline_passages): x = answer._make_grounding_passages(inline_passages) - self.assertIsInstance(x, glm.GroundingPassages) + self.assertIsInstance(x, protos.GroundingPassages) self.assertEqual( - glm.GroundingPassages( + protos.GroundingPassages( passages=[ - {"id": "0", "content": glm.Content(parts=[glm.Part(text="I am a chicken")])}, - {"id": "1", "content": glm.Content(parts=[glm.Part(text="I am a bird.")])}, - {"id": "2", "content": glm.Content(parts=[glm.Part(text="I can fly!")])}, + {"id": "0", "content": protos.Content(parts=[protos.Part(text="I am a chicken")])}, + {"id": "1", "content": protos.Content(parts=[protos.Part(text="I am a bird.")])}, + {"id": "2", "content": protos.Content(parts=[protos.Part(text="I can fly!")])}, ] ), x, @@ -133,27 +133,27 @@ def test_make_grounding_passages(self, inline_passages): dict( testcase_name="list_of_grounding_passages", inline_passages=[ - glm.GroundingPassage( - id="4", content=glm.Content(parts=[glm.Part(text="I am a chicken")]) + protos.GroundingPassage( + id="4", content=protos.Content(parts=[protos.Part(text="I am a chicken")]) ), - glm.GroundingPassage( - id="5", content=glm.Content(parts=[glm.Part(text="I am a bird.")]) + protos.GroundingPassage( + id="5", content=protos.Content(parts=[protos.Part(text="I am a bird.")]) ), - glm.GroundingPassage( - id="6", content=glm.Content(parts=[glm.Part(text="I can fly!")]) + protos.GroundingPassage( + id="6", content=protos.Content(parts=[protos.Part(text="I can fly!")]) ), ], ), ) def test_make_grounding_passages_different_id(self, inline_passages): x = answer._make_grounding_passages(inline_passages) - self.assertIsInstance(x, glm.GroundingPassages) + self.assertIsInstance(x, protos.GroundingPassages) self.assertEqual( - glm.GroundingPassages( + protos.GroundingPassages( passages=[ - {"id": "4", "content": glm.Content(parts=[glm.Part(text="I am a chicken")])}, - {"id": "5", "content": glm.Content(parts=[glm.Part(text="I am a bird.")])}, - {"id": "6", "content": glm.Content(parts=[glm.Part(text="I can fly!")])}, + {"id": "4", "content": protos.Content(parts=[protos.Part(text="I am a chicken")])}, + {"id": "5", "content": protos.Content(parts=[protos.Part(text="I am a bird.")])}, + {"id": "6", "content": protos.Content(parts=[protos.Part(text="I can fly!")])}, ] ), x, @@ -167,16 +167,16 @@ def test_make_grounding_passages_key_strings(self): } x = answer._make_grounding_passages(inline_passages) - self.assertIsInstance(x, glm.GroundingPassages) + self.assertIsInstance(x, protos.GroundingPassages) self.assertEqual( - glm.GroundingPassages( + protos.GroundingPassages( 
passages=[ { "id": "first", - "content": glm.Content(parts=[glm.Part(text="I am a chicken")]), + "content": protos.Content(parts=[protos.Part(text="I am a chicken")]), }, - {"id": "second", "content": glm.Content(parts=[glm.Part(text="I am a bird.")])}, - {"id": "third", "content": glm.Content(parts=[glm.Part(text="I can fly!")])}, + {"id": "second", "content": protos.Content(parts=[protos.Part(text="I am a bird.")])}, + {"id": "third", "content": protos.Content(parts=[protos.Part(text="I can fly!")])}, ] ), x, @@ -184,14 +184,14 @@ def test_make_grounding_passages_key_strings(self): def test_generate_answer_request(self): # Should be a list of contents to use to_contents() function. - contents = [glm.Content(parts=[glm.Part(text="I have wings.")])] + contents = [protos.Content(parts=[protos.Part(text="I have wings.")])] inline_passages = ["I am a chicken", "I am a bird.", "I can fly!"] - grounding_passages = glm.GroundingPassages( + grounding_passages = protos.GroundingPassages( passages=[ - {"id": "0", "content": glm.Content(parts=[glm.Part(text="I am a chicken")])}, - {"id": "1", "content": glm.Content(parts=[glm.Part(text="I am a bird.")])}, - {"id": "2", "content": glm.Content(parts=[glm.Part(text="I can fly!")])}, + {"id": "0", "content": protos.Content(parts=[protos.Part(text="I am a chicken")])}, + {"id": "1", "content": protos.Content(parts=[protos.Part(text="I am a bird.")])}, + {"id": "2", "content": protos.Content(parts=[protos.Part(text="I can fly!")])}, ] ) @@ -200,7 +200,7 @@ def test_generate_answer_request(self): ) self.assertEqual( - glm.GenerateAnswerRequest( + protos.GenerateAnswerRequest( model=DEFAULT_ANSWER_MODEL, contents=contents, inline_passages=grounding_passages ), x, @@ -208,13 +208,13 @@ def test_generate_answer_request(self): def test_generate_answer(self): # Test handling return value of generate_answer(). 
- contents = [glm.Content(parts=[glm.Part(text="I have wings.")])] + contents = [protos.Content(parts=[protos.Part(text="I have wings.")])] - grounding_passages = glm.GroundingPassages( + grounding_passages = protos.GroundingPassages( passages=[ - {"id": "0", "content": glm.Content(parts=[glm.Part(text="I am a chicken")])}, - {"id": "1", "content": glm.Content(parts=[glm.Part(text="I am a bird.")])}, - {"id": "2", "content": glm.Content(parts=[glm.Part(text="I can fly!")])}, + {"id": "0", "content": protos.Content(parts=[protos.Part(text="I am a chicken")])}, + {"id": "1", "content": protos.Content(parts=[protos.Part(text="I am a bird.")])}, + {"id": "2", "content": protos.Content(parts=[protos.Part(text="I can fly!")])}, ] ) @@ -225,13 +225,13 @@ def test_generate_answer(self): answer_style="ABSTRACTIVE", ) - self.assertIsInstance(a, glm.GenerateAnswerResponse) + self.assertIsInstance(a, protos.GenerateAnswerResponse) self.assertEqual( a, - glm.GenerateAnswerResponse( - answer=glm.Candidate( + protos.GenerateAnswerResponse( + answer=protos.Candidate( index=1, - content=(glm.Content(parts=[glm.Part(text="Demo answer.")])), + content=(protos.Content(parts=[protos.Part(text="Demo answer.")])), ), answerable_probability=0.500, ), diff --git a/tests/test_client.py b/tests/test_client.py index 34a0f9fc3..84c3e83b4 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -4,8 +4,10 @@ from absl.testing import absltest from absl.testing import parameterized -from google.api_core import client_options import google.ai.generativelanguage as glm + +from google.api_core import client_options +from google.generativeai import protos from google.generativeai import client diff --git a/tests/test_content.py b/tests/test_content.py index 5f22b93a1..3829ebc86 100644 --- a/tests/test_content.py +++ b/tests/test_content.py @@ -19,7 +19,7 @@ from absl.testing import absltest from absl.testing import parameterized -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai.types import content_types import IPython.display import PIL.Image @@ -71,7 +71,7 @@ class UnitTests(parameterized.TestCase): ) def test_png_to_blob(self, image): blob = content_types.image_to_blob(image) - self.assertIsInstance(blob, glm.Blob) + self.assertIsInstance(blob, protos.Blob) self.assertEqual(blob.mime_type, "image/png") self.assertStartsWith(blob.data, b"\x89PNG") @@ -81,29 +81,29 @@ def test_png_to_blob(self, image): ) def test_jpg_to_blob(self, image): blob = content_types.image_to_blob(image) - self.assertIsInstance(blob, glm.Blob) + self.assertIsInstance(blob, protos.Blob) self.assertEqual(blob.mime_type, "image/jpeg") self.assertStartsWith(blob.data, b"\xff\xd8\xff\xe0\x00\x10JFIF") @parameterized.named_parameters( ["BlobDict", {"mime_type": "image/png", "data": TEST_PNG_DATA}], - ["glm.Blob", glm.Blob(mime_type="image/png", data=TEST_PNG_DATA)], + ["protos.Blob", protos.Blob(mime_type="image/png", data=TEST_PNG_DATA)], ["Image", IPython.display.Image(filename=TEST_PNG_PATH)], ) def test_to_blob(self, example): blob = content_types.to_blob(example) - self.assertIsInstance(blob, glm.Blob) + self.assertIsInstance(blob, protos.Blob) self.assertEqual(blob.mime_type, "image/png") self.assertStartsWith(blob.data, b"\x89PNG") @parameterized.named_parameters( ["dict", {"text": "Hello world!"}], - ["glm.Part", glm.Part(text="Hello world!")], + ["protos.Part", protos.Part(text="Hello world!")], ["str", "Hello world!"], ) def test_to_part(self, example): part = 
content_types.to_part(example) - self.assertIsInstance(part, glm.Part) + self.assertIsInstance(part, protos.Part) self.assertEqual(part.text, "Hello world!") @parameterized.named_parameters( @@ -116,12 +116,12 @@ def test_to_part(self, example): ) def test_img_to_part(self, example): blob = content_types.to_part(example).inline_data - self.assertIsInstance(blob, glm.Blob) + self.assertIsInstance(blob, protos.Blob) self.assertEqual(blob.mime_type, "image/png") self.assertStartsWith(blob.data, b"\x89PNG") @parameterized.named_parameters( - ["glm.Content", glm.Content(parts=[{"text": "Hello world!"}])], + ["protos.Content", protos.Content(parts=[{"text": "Hello world!"}])], ["ContentDict", {"parts": [{"text": "Hello world!"}]}], ["ContentDict-str", {"parts": ["Hello world!"]}], ["list[parts]", [{"text": "Hello world!"}]], @@ -135,7 +135,7 @@ def test_to_content(self, example): part = content.parts[0] self.assertLen(content.parts, 1) - self.assertIsInstance(part, glm.Part) + self.assertIsInstance(part, protos.Part) self.assertEqual(part.text, "Hello world!") @parameterized.named_parameters( @@ -147,12 +147,12 @@ def test_img_to_content(self, example): content = content_types.to_content(example) blob = content.parts[0].inline_data self.assertLen(content.parts, 1) - self.assertIsInstance(blob, glm.Blob) + self.assertIsInstance(blob, protos.Blob) self.assertEqual(blob.mime_type, "image/png") self.assertStartsWith(blob.data, b"\x89PNG") @parameterized.named_parameters( - ["glm.Content", glm.Content(parts=[{"text": "Hello world!"}])], + ["protos.Content", protos.Content(parts=[{"text": "Hello world!"}])], ["ContentDict", {"parts": [{"text": "Hello world!"}]}], ["ContentDict-str", {"parts": ["Hello world!"]}], ) @@ -161,7 +161,7 @@ def test_strict_to_content(self, example): part = content.parts[0] self.assertLen(content.parts, 1) - self.assertIsInstance(part, glm.Part) + self.assertIsInstance(part, protos.Part) self.assertEqual(part.text, "Hello world!") @parameterized.named_parameters( @@ -176,7 +176,7 @@ def test_strict_to_contents_fails(self, examples): content_types.strict_to_content(examples) @parameterized.named_parameters( - ["glm.Content", [glm.Content(parts=[{"text": "Hello world!"}])]], + ["protos.Content", [protos.Content(parts=[{"text": "Hello world!"}])]], ["ContentDict", [{"parts": [{"text": "Hello world!"}]}]], ["ContentDict-unwraped", [{"parts": ["Hello world!"]}]], ["ContentDict+str-part", [{"parts": "Hello world!"}]], @@ -188,7 +188,7 @@ def test_to_contents(self, example): self.assertLen(contents, 1) self.assertLen(contents[0].parts, 1) - self.assertIsInstance(part, glm.Part) + self.assertIsInstance(part, protos.Part) self.assertEqual(part.text, "Hello world!") def test_dict_to_content_fails(self): @@ -209,7 +209,7 @@ def test_img_to_contents(self, example): self.assertLen(contents, 1) self.assertLen(contents[0].parts, 1) - self.assertIsInstance(blob, glm.Blob) + self.assertIsInstance(blob, protos.Blob) self.assertEqual(blob.mime_type, "image/png") self.assertStartsWith(blob.data, b"\x89PNG") @@ -217,9 +217,9 @@ def test_img_to_contents(self, example): [ "FunctionLibrary", content_types.FunctionLibrary( - tools=glm.Tool( + tools=protos.Tool( function_declarations=[ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time." 
) ] @@ -231,7 +231,7 @@ def test_img_to_contents(self, example): [ content_types.Tool( function_declarations=[ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time." ) ] @@ -239,11 +239,11 @@ def test_img_to_contents(self, example): ], ], [ - "IterableTool-glm.Tool", + "IterableTool-protos.Tool", [ - glm.Tool( + protos.Tool( function_declarations=[ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time.", ) @@ -268,7 +268,7 @@ def test_img_to_contents(self, example): "IterableTool-IterableFD", [ [ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time.", ) @@ -278,7 +278,7 @@ def test_img_to_contents(self, example): [ "IterableTool-FD", [ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time.", ) @@ -288,17 +288,17 @@ def test_img_to_contents(self, example): "Tool", content_types.Tool( function_declarations=[ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time." ) ] ), ], [ - "glm.Tool", - glm.Tool( + "protos.Tool", + protos.Tool( function_declarations=[ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time." ) ] @@ -350,8 +350,8 @@ def test_img_to_contents(self, example): ), ], [ - "glm.FD", - glm.FunctionDeclaration( + "protos.FD", + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time." ), ], @@ -391,83 +391,83 @@ def b(): self.assertLen(tools[0].function_declarations, 2) @parameterized.named_parameters( - ["int", int, glm.Schema(type=glm.Type.INTEGER)], - ["float", float, glm.Schema(type=glm.Type.NUMBER)], - ["str", str, glm.Schema(type=glm.Type.STRING)], - ["nullable_str", Union[str, None], glm.Schema(type=glm.Type.STRING, nullable=True)], + ["int", int, protos.Schema(type=protos.Type.INTEGER)], + ["float", float, protos.Schema(type=protos.Type.NUMBER)], + ["str", str, protos.Schema(type=protos.Type.STRING)], + ["nullable_str", Union[str, None], protos.Schema(type=protos.Type.STRING, nullable=True)], [ "list", list[str], - glm.Schema( - type=glm.Type.ARRAY, - items=glm.Schema(type=glm.Type.STRING), + protos.Schema( + type=protos.Type.ARRAY, + items=protos.Schema(type=protos.Type.STRING), ), ], [ "list-list-int", list[list[int]], - glm.Schema( - type=glm.Type.ARRAY, - items=glm.Schema( - glm.Schema( - type=glm.Type.ARRAY, - items=glm.Schema(type=glm.Type.INTEGER), + protos.Schema( + type=protos.Type.ARRAY, + items=protos.Schema( + protos.Schema( + type=protos.Type.ARRAY, + items=protos.Schema(type=protos.Type.INTEGER), ), ), ), ], - ["dict", dict, glm.Schema(type=glm.Type.OBJECT)], - ["dict-str-any", dict[str, Any], glm.Schema(type=glm.Type.OBJECT)], + ["dict", dict, protos.Schema(type=protos.Type.OBJECT)], + ["dict-str-any", dict[str, Any], protos.Schema(type=protos.Type.OBJECT)], [ "dataclass", ADataClass, - glm.Schema( - type=glm.Type.OBJECT, - properties={"a": {"type_": glm.Type.INTEGER}}, + protos.Schema( + type=protos.Type.OBJECT, + properties={"a": {"type_": protos.Type.INTEGER}}, ), ], [ "nullable_dataclass", Union[ADataClass, None], - glm.Schema( - type=glm.Type.OBJECT, + protos.Schema( + type=protos.Type.OBJECT, nullable=True, - properties={"a": {"type_": glm.Type.INTEGER}}, + properties={"a": {"type_": 
protos.Type.INTEGER}}, ), ], [ "list_of_dataclass", list[ADataClass], - glm.Schema( + protos.Schema( type="ARRAY", - items=glm.Schema( - type=glm.Type.OBJECT, - properties={"a": {"type_": glm.Type.INTEGER}}, + items=protos.Schema( + type=protos.Type.OBJECT, + properties={"a": {"type_": protos.Type.INTEGER}}, ), ), ], [ "dataclass_with_nullable", ADataClassWithNullable, - glm.Schema( - type=glm.Type.OBJECT, - properties={"a": {"type_": glm.Type.INTEGER, "nullable": True}}, + protos.Schema( + type=protos.Type.OBJECT, + properties={"a": {"type_": protos.Type.INTEGER, "nullable": True}}, ), ], [ "dataclass_with_list", ADataClassWithList, - glm.Schema( - type=glm.Type.OBJECT, + protos.Schema( + type=protos.Type.OBJECT, properties={"a": {"type_": "ARRAY", "items": {"type_": "INTEGER"}}}, ), ], [ "list_of_dataclass_with_list", list[ADataClassWithList], - glm.Schema( - items=glm.Schema( - type=glm.Type.OBJECT, + protos.Schema( + items=protos.Schema( + type=protos.Type.OBJECT, properties={"a": {"type_": "ARRAY", "items": {"type_": "INTEGER"}}}, ), type="ARRAY", @@ -476,31 +476,31 @@ def b(): [ "list_of_nullable", list[Union[int, None]], - glm.Schema( + protos.Schema( type="ARRAY", - items={"type_": glm.Type.INTEGER, "nullable": True}, + items={"type_": protos.Type.INTEGER, "nullable": True}, ), ], [ "TypedDict", ATypedDict, - glm.Schema( - type=glm.Type.OBJECT, + protos.Schema( + type=protos.Type.OBJECT, properties={ - "a": {"type_": glm.Type.INTEGER}, + "a": {"type_": protos.Type.INTEGER}, }, ), ], [ "nested", Nested, - glm.Schema( - type=glm.Type.OBJECT, + protos.Schema( + type=protos.Type.OBJECT, properties={ - "x": glm.Schema( - type=glm.Type.OBJECT, + "x": protos.Schema( + type=protos.Type.OBJECT, properties={ - "a": {"type_": glm.Type.INTEGER}, + "a": {"type_": protos.Type.INTEGER}, }, ), }, diff --git a/tests/test_discuss.py b/tests/test_discuss.py index 7db0a63d8..e7411bcb6 100644 --- a/tests/test_discuss.py +++ b/tests/test_discuss.py @@ -16,7 +16,7 @@ import unittest.mock -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import discuss from google.generativeai import client @@ -37,18 +37,18 @@ def setUp(self): self.observed_request = None - self.mock_response = glm.GenerateMessageResponse( + self.mock_response = protos.GenerateMessageResponse( candidates=[ - glm.Message(content="a", author="1"), - glm.Message(content="b", author="1"), - glm.Message(content="c", author="1"), + protos.Message(content="a", author="1"), + protos.Message(content="b", author="1"), + protos.Message(content="c", author="1"), ], ) def fake_generate_message( - request: glm.GenerateMessageRequest, + request: protos.GenerateMessageRequest, **kwargs, - ) -> glm.GenerateMessageResponse: + ) -> protos.GenerateMessageResponse: self.observed_request = request response = copy.copy(self.mock_response) response.messages = request.prompt.messages @@ -60,22 +60,22 @@ def fake_generate_message( ["string", "Hello", ""], ["dict", {"content": "Hello"}, ""], ["dict_author", {"content": "Hello", "author": "me"}, "me"], - ["proto", glm.Message(content="Hello"), ""], - ["proto_author", glm.Message(content="Hello", author="me"), "me"], + ["proto", protos.Message(content="Hello"), ""], + ["proto_author", protos.Message(content="Hello", author="me"), "me"], ) def test_make_message(self, message, author): x = discuss._make_message(message) - self.assertIsInstance(x, glm.Message) + self.assertIsInstance(x, protos.Message) self.assertEqual("Hello", x.content) 
self.assertEqual(author, x.author) @parameterized.named_parameters( ["string", "Hello", ["Hello"]], ["dict", {"content": "Hello"}, ["Hello"]], - ["proto", glm.Message(content="Hello"), ["Hello"]], + ["proto", protos.Message(content="Hello"), ["Hello"]], [ "list", - ["hello0", {"content": "hello1"}, glm.Message(content="hello2")], + ["hello0", {"content": "hello1"}, protos.Message(content="hello2")], ["hello0", "hello1", "hello2"], ], ) @@ -90,15 +90,15 @@ def test_make_messages(self, messages, expected_contents): ["dict", {"input": "hello", "output": "goodbye"}], [ "proto", - glm.Example( - input=glm.Message(content="hello"), - output=glm.Message(content="goodbye"), + protos.Example( + input=protos.Message(content="hello"), + output=protos.Message(content="goodbye"), ), ], ) def test_make_example(self, example): x = discuss._make_example(example) - self.assertIsInstance(x, glm.Example) + self.assertIsInstance(x, protos.Example) self.assertEqual("hello", x.input.content) self.assertEqual("goodbye", x.output.content) return @@ -110,7 +110,7 @@ def test_make_example(self, example): "Hi", {"content": "Hello!"}, "what's your name?", - glm.Message(content="Dave, what's yours"), + protos.Message(content="Dave, what's yours"), ], ], [ @@ -145,15 +145,15 @@ def test_make_examples_from_example(self): @parameterized.named_parameters( ["str", "hello"], - ["message", glm.Message(content="hello")], + ["message", protos.Message(content="hello")], ["messages", ["hello"]], ["dict", {"messages": "hello"}], ["dict2", {"messages": ["hello"]}], - ["proto", glm.MessagePrompt(messages=[glm.Message(content="hello")])], + ["proto", protos.MessagePrompt(messages=[protos.Message(content="hello")])], ) def test_make_message_prompt_from_messages(self, prompt): x = discuss._make_message_prompt(prompt) - self.assertIsInstance(x, glm.MessagePrompt) + self.assertIsInstance(x, protos.MessagePrompt) self.assertEqual(x.messages[0].content, "hello") return @@ -181,15 +181,15 @@ def test_make_message_prompt_from_messages(self, prompt): [ "proto", [ - glm.MessagePrompt( + protos.MessagePrompt( context="you are a cat", examples=[ - glm.Example( - input=glm.Message(content="are you hungry?"), - output=glm.Message(content="meow!"), + protos.Example( + input=protos.Message(content="are you hungry?"), + output=protos.Message(content="meow!"), ) ], - messages=[glm.Message(content="hello")], + messages=[protos.Message(content="hello")], ) ], {}, @@ -197,7 +197,7 @@ def test_make_message_prompt_from_messages(self, prompt): ) def test_make_message_prompt_from_prompt(self, args, kwargs): x = discuss._make_message_prompt(*args, **kwargs) - self.assertIsInstance(x, glm.MessagePrompt) + self.assertIsInstance(x, protos.MessagePrompt) self.assertEqual(x.context, "you are a cat") self.assertEqual(x.examples[0].input.content, "are you hungry?") self.assertEqual(x.examples[0].output.content, "meow!") @@ -229,8 +229,8 @@ def test_make_generate_message_request_nested( } ) - self.assertIsInstance(request0, glm.GenerateMessageRequest) - self.assertIsInstance(request1, glm.GenerateMessageRequest) + self.assertIsInstance(request0, protos.GenerateMessageRequest) + self.assertIsInstance(request1, protos.GenerateMessageRequest) self.assertEqual(request0, request1) @parameterized.parameters( @@ -285,11 +285,11 @@ def test_reply(self, kwargs): response = response.reply("again") def test_receive_and_reply_with_filters(self): - self.mock_response = mock_response = glm.GenerateMessageResponse( - candidates=[glm.Message(content="a", author="1")], + 
self.mock_response = mock_response = protos.GenerateMessageResponse( + candidates=[protos.Message(content="a", author="1")], filters=[ - glm.ContentFilter(reason=palm_safety_types.BlockedReason.SAFETY, message="unsafe"), - glm.ContentFilter(reason=palm_safety_types.BlockedReason.OTHER), + protos.ContentFilter(reason=palm_safety_types.BlockedReason.SAFETY, message="unsafe"), + protos.ContentFilter(reason=palm_safety_types.BlockedReason.OTHER), ], ) response = discuss.chat(messages="do filters work?") @@ -300,10 +300,10 @@ def test_receive_and_reply_with_filters(self): self.assertEqual(filters[0]["reason"], palm_safety_types.BlockedReason.SAFETY) self.assertEqual(filters[0]["message"], "unsafe") - self.mock_response = glm.GenerateMessageResponse( - candidates=[glm.Message(content="a", author="1")], + self.mock_response = protos.GenerateMessageResponse( + candidates=[protos.Message(content="a", author="1")], filters=[ - glm.ContentFilter(reason=palm_safety_types.BlockedReason.BLOCKED_REASON_UNSPECIFIED) + protos.ContentFilter(reason=palm_safety_types.BlockedReason.BLOCKED_REASON_UNSPECIFIED) ], ) @@ -317,7 +317,7 @@ def test_receive_and_reply_with_filters(self): ) def test_chat_citations(self): - self.mock_response = mock_response = glm.GenerateMessageResponse( + self.mock_response = mock_response = protos.GenerateMessageResponse( candidates=[ { "content": "Hello google!", diff --git a/tests/test_discuss_async.py b/tests/test_discuss_async.py index 7e1f7947c..d35d03525 100644 --- a/tests/test_discuss_async.py +++ b/tests/test_discuss_async.py @@ -17,7 +17,7 @@ from typing import Any import unittest -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import discuss from absl.testing import absltest @@ -31,14 +31,14 @@ async def test_chat_async(self): observed_request = None async def fake_generate_message( - request: glm.GenerateMessageRequest, + request: protos.GenerateMessageRequest, **kwargs, - ) -> glm.GenerateMessageResponse: + ) -> protos.GenerateMessageResponse: nonlocal observed_request observed_request = request - return glm.GenerateMessageResponse( + return protos.GenerateMessageResponse( candidates=[ - glm.Message( + protos.Message( author="1", content="Why did the chicken cross the road?", ) @@ -59,17 +59,17 @@ async def fake_generate_message( self.assertEqual( observed_request, - glm.GenerateMessageRequest( + protos.GenerateMessageRequest( model="models/bard", - prompt=glm.MessagePrompt( + prompt=protos.MessagePrompt( context="Example Prompt", examples=[ - glm.Example( - input=glm.Message(content="Example from human"), - output=glm.Message(content="Example response from AI"), + protos.Example( + input=protos.Message(content="Example from human"), + output=protos.Message(content="Example response from AI"), ) ], - messages=[glm.Message(author="0", content="Tell me a joke")], + messages=[protos.Message(author="0", content="Tell me a joke")], ), temperature=0.75, candidate_count=1, diff --git a/tests/test_embedding.py b/tests/test_embedding.py index 5f6aa8d89..921ad46a6 100644 --- a/tests/test_embedding.py +++ b/tests/test_embedding.py @@ -18,7 +18,7 @@ import unittest import unittest.mock as mock -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import embedding @@ -45,20 +45,20 @@ def add_client_method(f): @add_client_method def embed_content( - request: glm.EmbedContentRequest, + request: protos.EmbedContentRequest, **kwargs, - ) -> glm.EmbedContentResponse: + ) -> 
protos.EmbedContentResponse: self.observed_requests.append(request) - return glm.EmbedContentResponse(embedding=glm.ContentEmbedding(values=[1, 2, 3])) + return protos.EmbedContentResponse(embedding=protos.ContentEmbedding(values=[1, 2, 3])) @add_client_method def batch_embed_contents( - request: glm.BatchEmbedContentsRequest, + request: protos.BatchEmbedContentsRequest, **kwargs, - ) -> glm.BatchEmbedContentsResponse: + ) -> protos.BatchEmbedContentsResponse: self.observed_requests.append(request) - return glm.BatchEmbedContentsResponse( - embeddings=[glm.ContentEmbedding(values=[1, 2, 3])] * len(request.requests) + return protos.BatchEmbedContentsResponse( + embeddings=[protos.ContentEmbedding(values=[1, 2, 3])] * len(request.requests) ) def test_embed_content(self): @@ -68,8 +68,8 @@ def test_embed_content(self): self.assertIsInstance(emb, dict) self.assertEqual( self.observed_requests[-1], - glm.EmbedContentRequest( - model=DEFAULT_EMB_MODEL, content=glm.Content(parts=[glm.Part(text="What are you?")]) + protos.EmbedContentRequest( + model=DEFAULT_EMB_MODEL, content=protos.Content(parts=[protos.Part(text="What are you?")]) ), ) self.assertIsInstance(emb["embedding"][0], float) diff --git a/tests/test_embedding_async.py b/tests/test_embedding_async.py index d4ca16c08..6e8887bb9 100644 --- a/tests/test_embedding_async.py +++ b/tests/test_embedding_async.py @@ -18,7 +18,7 @@ import unittest import unittest.mock as mock -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import embedding @@ -44,20 +44,20 @@ def add_client_method(f): @add_client_method async def embed_content( - request: glm.EmbedContentRequest, + request: protos.EmbedContentRequest, **kwargs, - ) -> glm.EmbedContentResponse: + ) -> protos.EmbedContentResponse: self.observed_requests.append(request) - return glm.EmbedContentResponse(embedding=glm.ContentEmbedding(values=[1, 2, 3])) + return protos.EmbedContentResponse(embedding=protos.ContentEmbedding(values=[1, 2, 3])) @add_client_method async def batch_embed_contents( - request: glm.BatchEmbedContentsRequest, + request: protos.BatchEmbedContentsRequest, **kwargs, - ) -> glm.BatchEmbedContentsResponse: + ) -> protos.BatchEmbedContentsResponse: self.observed_requests.append(request) - return glm.BatchEmbedContentsResponse( - embeddings=[glm.ContentEmbedding(values=[1, 2, 3])] * len(request.requests) + return protos.BatchEmbedContentsResponse( + embeddings=[protos.ContentEmbedding(values=[1, 2, 3])] * len(request.requests) ) async def test_embed_content_async(self): @@ -67,8 +67,8 @@ async def test_embed_content_async(self): self.assertIsInstance(emb, dict) self.assertEqual( self.observed_requests[-1], - glm.EmbedContentRequest( - model=DEFAULT_EMB_MODEL, content=glm.Content(parts=[glm.Part(text="What are you?")]) + protos.EmbedContentRequest( + model=DEFAULT_EMB_MODEL, content=protos.Content(parts=[protos.Part(text="What are you?")]) ), ) self.assertIsInstance(emb["embedding"][0], float) diff --git a/tests/test_generation.py b/tests/test_generation.py index 82beac16b..1b50badaf 100644 --- a/tests/test_generation.py +++ b/tests/test_generation.py @@ -5,7 +5,7 @@ from absl.testing import absltest from absl.testing import parameterized -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai.types import generation_types @@ -24,9 +24,9 @@ class Person(TypedDict): class UnitTests(parameterized.TestCase): @parameterized.named_parameters( [ - "glm.GenerationConfig", - 
glm.GenerationConfig( - temperature=0.1, stop_sequences=["end"], response_schema=glm.Schema(type="STRING") + "protos.GenerationConfig", + protos.GenerationConfig( + temperature=0.1, stop_sequences=["end"], response_schema=protos.Schema(type="STRING") ), ], [ @@ -48,15 +48,15 @@ def test_to_generation_config(self, config): def test_join_citation_metadatas(self): citations = [ - glm.CitationMetadata( + protos.CitationMetadata( citation_sources=[ - glm.CitationSource(start_index=3, end_index=21, uri="https://google.com"), + protos.CitationSource(start_index=3, end_index=21, uri="https://google.com"), ] ), - glm.CitationMetadata( + protos.CitationMetadata( citation_sources=[ - glm.CitationSource(start_index=3, end_index=33, uri="https://google.com"), - glm.CitationSource(start_index=55, end_index=92, uri="https://google.com"), + protos.CitationSource(start_index=3, end_index=33, uri="https://google.com"), + protos.CitationSource(start_index=55, end_index=92, uri="https://google.com"), ] ), ] @@ -74,14 +74,14 @@ def test_join_citation_metadatas(self): def test_join_safety_ratings_list(self): ratings = [ [ - glm.SafetyRating(category="HARM_CATEGORY_DANGEROUS", probability="LOW"), - glm.SafetyRating(category="HARM_CATEGORY_MEDICAL", probability="LOW"), - glm.SafetyRating(category="HARM_CATEGORY_SEXUAL", probability="MEDIUM"), + protos.SafetyRating(category="HARM_CATEGORY_DANGEROUS", probability="LOW"), + protos.SafetyRating(category="HARM_CATEGORY_MEDICAL", probability="LOW"), + protos.SafetyRating(category="HARM_CATEGORY_SEXUAL", probability="MEDIUM"), ], [ - glm.SafetyRating(category="HARM_CATEGORY_DEROGATORY", probability="LOW"), - glm.SafetyRating(category="HARM_CATEGORY_SEXUAL", probability="LOW"), - glm.SafetyRating( + protos.SafetyRating(category="HARM_CATEGORY_DEROGATORY", probability="LOW"), + protos.SafetyRating(category="HARM_CATEGORY_SEXUAL", probability="LOW"), + protos.SafetyRating( category="HARM_CATEGORY_DANGEROUS", probability="HIGH", blocked=True, @@ -101,14 +101,14 @@ def test_join_safety_ratings_list(self): def test_join_contents(self): contents = [ - glm.Content(role="assistant", parts=[glm.Part(text="Tell me a story about a ")]), - glm.Content( + protos.Content(role="assistant", parts=[protos.Part(text="Tell me a story about a ")]), + protos.Content( role="assistant", - parts=[glm.Part(text="magic backpack that looks like this: ")], + parts=[protos.Part(text="magic backpack that looks like this: ")], ), - glm.Content( + protos.Content( role="assistant", - parts=[glm.Part(inline_data=glm.Blob(mime_type="image/png", data=b"DATA!"))], + parts=[protos.Part(inline_data=protos.Blob(mime_type="image/png", data=b"DATA!"))], ), ] result = generation_types._join_contents(contents) @@ -126,7 +126,7 @@ def test_many_join_contents(self): import string contents = [ - glm.Content(role="assistant", parts=[glm.Part(text=a)]) for a in string.ascii_lowercase + protos.Content(role="assistant", parts=[protos.Part(text=a)]) for a in string.ascii_lowercase ] result = generation_types._join_contents(contents) @@ -139,41 +139,41 @@ def test_many_join_contents(self): def test_join_candidates(self): candidates = [ - glm.Candidate( + protos.Candidate( index=0, - content=glm.Content( + content=protos.Content( role="assistant", - parts=[glm.Part(text="Tell me a story about a ")], + parts=[protos.Part(text="Tell me a story about a ")], ), - citation_metadata=glm.CitationMetadata( + citation_metadata=protos.CitationMetadata( citation_sources=[ - glm.CitationSource(start_index=55, end_index=85, 
uri="https://google.com"), + protos.CitationSource(start_index=55, end_index=85, uri="https://google.com"), ] ), ), - glm.Candidate( + protos.Candidate( index=0, - content=glm.Content( + content=protos.Content( role="assistant", - parts=[glm.Part(text="magic backpack that looks like this: ")], + parts=[protos.Part(text="magic backpack that looks like this: ")], ), - citation_metadata=glm.CitationMetadata( + citation_metadata=protos.CitationMetadata( citation_sources=[ - glm.CitationSource(start_index=55, end_index=92, uri="https://google.com"), - glm.CitationSource(start_index=3, end_index=21, uri="https://google.com"), + protos.CitationSource(start_index=55, end_index=92, uri="https://google.com"), + protos.CitationSource(start_index=3, end_index=21, uri="https://google.com"), ] ), ), - glm.Candidate( + protos.Candidate( index=0, - content=glm.Content( + content=protos.Content( role="assistant", - parts=[glm.Part(inline_data=glm.Blob(mime_type="image/png", data=b"DATA!"))], + parts=[protos.Part(inline_data=protos.Blob(mime_type="image/png", data=b"DATA!"))], ), - citation_metadata=glm.CitationMetadata( + citation_metadata=protos.CitationMetadata( citation_sources=[ - glm.CitationSource(start_index=55, end_index=92, uri="https://google.com"), - glm.CitationSource(start_index=3, end_index=21, uri="https://google.com"), + protos.CitationSource(start_index=55, end_index=92, uri="https://google.com"), + protos.CitationSource(start_index=3, end_index=21, uri="https://google.com"), ] ), finish_reason="STOP", @@ -213,17 +213,17 @@ def test_join_candidates(self): def test_join_prompt_feedbacks(self): feedbacks = [ - glm.GenerateContentResponse.PromptFeedback( + protos.GenerateContentResponse.PromptFeedback( block_reason="SAFETY", safety_ratings=[ - glm.SafetyRating(category="HARM_CATEGORY_DANGEROUS", probability="LOW"), + protos.SafetyRating(category="HARM_CATEGORY_DANGEROUS", probability="LOW"), ], ), - glm.GenerateContentResponse.PromptFeedback(), - glm.GenerateContentResponse.PromptFeedback(), - glm.GenerateContentResponse.PromptFeedback( + protos.GenerateContentResponse.PromptFeedback(), + protos.GenerateContentResponse.PromptFeedback(), + protos.GenerateContentResponse.PromptFeedback( safety_ratings=[ - glm.SafetyRating(category="HARM_CATEGORY_MEDICAL", probability="HIGH"), + protos.SafetyRating(category="HARM_CATEGORY_MEDICAL", probability="HIGH"), ] ), ] @@ -396,23 +396,23 @@ def test_join_prompt_feedbacks(self): ] def test_join_candidates(self): - candidate_lists = [[glm.Candidate(c) for c in cl] for cl in self.CANDIDATE_LISTS] + candidate_lists = [[protos.Candidate(c) for c in cl] for cl in self.CANDIDATE_LISTS] result = generation_types._join_candidate_lists(candidate_lists) self.assertEqual(self.MERGED_CANDIDATES, [type(r).to_dict(r) for r in result]) def test_join_chunks(self): - chunks = [glm.GenerateContentResponse(candidates=cl) for cl in self.CANDIDATE_LISTS] + chunks = [protos.GenerateContentResponse(candidates=cl) for cl in self.CANDIDATE_LISTS] - chunks[0].prompt_feedback = glm.GenerateContentResponse.PromptFeedback( + chunks[0].prompt_feedback = protos.GenerateContentResponse.PromptFeedback( block_reason="SAFETY", safety_ratings=[ - glm.SafetyRating(category="HARM_CATEGORY_DANGEROUS", probability="LOW"), + protos.SafetyRating(category="HARM_CATEGORY_DANGEROUS", probability="LOW"), ], ) result = generation_types._join_chunks(chunks) - expected = glm.GenerateContentResponse( + expected = protos.GenerateContentResponse( { "candidates": self.MERGED_CANDIDATES, "prompt_feedback": { 
@@ -431,7 +431,7 @@ def test_join_chunks(self): self.assertEqual(type(expected).to_dict(expected), type(result).to_dict(expected)) def test_generate_content_response_iterator_end_to_end(self): - chunks = [glm.GenerateContentResponse(candidates=cl) for cl in self.CANDIDATE_LISTS] + chunks = [protos.GenerateContentResponse(candidates=cl) for cl in self.CANDIDATE_LISTS] merged = generation_types._join_chunks(chunks) response = generation_types.GenerateContentResponse.from_iterator(iter(chunks)) @@ -453,7 +453,7 @@ def test_generate_content_response_iterator_end_to_end(self): def test_generate_content_response_multiple_iterators(self): chunks = [ - glm.GenerateContentResponse({"candidates": [{"content": {"parts": [{"text": a}]}}]}) + protos.GenerateContentResponse({"candidates": [{"content": {"parts": [{"text": a}]}}]}) for a in string.ascii_lowercase ] response = generation_types.GenerateContentResponse.from_iterator(iter(chunks)) @@ -483,7 +483,7 @@ def test_generate_content_response_multiple_iterators(self): def test_generate_content_response_resolve(self): chunks = [ - glm.GenerateContentResponse({"candidates": [{"content": {"parts": [{"text": a}]}}]}) + protos.GenerateContentResponse({"candidates": [{"content": {"parts": [{"text": a}]}}]}) for a in "abcd" ] response = generation_types.GenerateContentResponse.from_iterator(iter(chunks)) @@ -497,7 +497,7 @@ def test_generate_content_response_resolve(self): self.assertEqual(response.candidates[0].content.parts[0].text, "abcd") def test_generate_content_response_from_response(self): - raw_response = glm.GenerateContentResponse( + raw_response = protos.GenerateContentResponse( {"candidates": [{"content": {"parts": [{"text": "Hello world!"}]}}]} ) response = generation_types.GenerateContentResponse.from_response(raw_response) @@ -511,7 +511,7 @@ def test_generate_content_response_from_response(self): ) def test_repr_for_generate_content_response_from_response(self): - raw_response = glm.GenerateContentResponse( + raw_response = protos.GenerateContentResponse( {"candidates": [{"content": {"parts": [{"text": "Hello world!"}]}}]} ) response = generation_types.GenerateContentResponse.from_response(raw_response) @@ -523,7 +523,7 @@ def test_repr_for_generate_content_response_from_response(self): GenerateContentResponse( done=True, iterator=None, - result=glm.GenerateContentResponse({ + result=protos.GenerateContentResponse({ "candidates": [ { "content": { @@ -547,7 +547,7 @@ def test_repr_for_generate_content_response_from_response(self): def test_repr_for_generate_content_response_from_iterator(self): chunks = [ - glm.GenerateContentResponse({"candidates": [{"content": {"parts": [{"text": a}]}}]}) + protos.GenerateContentResponse({"candidates": [{"content": {"parts": [{"text": a}]}}]}) for a in "abcd" ] response = generation_types.GenerateContentResponse.from_iterator(iter(chunks)) @@ -559,7 +559,7 @@ def test_repr_for_generate_content_response_from_iterator(self): GenerateContentResponse( done=False, iterator=, - result=glm.GenerateContentResponse({ + result=protos.GenerateContentResponse({ "candidates": [ { "content": { @@ -583,35 +583,35 @@ def test_repr_for_generate_content_response_from_iterator(self): @parameterized.named_parameters( [ - "glm.Schema", - glm.Schema(type="STRING"), - glm.Schema(type="STRING"), + "protos.Schema", + protos.Schema(type="STRING"), + protos.Schema(type="STRING"), ], [ "SchemaDict", {"type": "STRING"}, - glm.Schema(type="STRING"), + protos.Schema(type="STRING"), ], [ "str", str, - glm.Schema(type="STRING"), + 
protos.Schema(type="STRING"), ], - ["list_of_str", list[str], glm.Schema(type="ARRAY", items=glm.Schema(type="STRING"))], + ["list_of_str", list[str], protos.Schema(type="ARRAY", items=protos.Schema(type="STRING"))], [ "fancy", Person, - glm.Schema( + protos.Schema( type="OBJECT", properties=dict( - name=glm.Schema(type="STRING"), - favorite_color=glm.Schema(type="STRING"), - birthday=glm.Schema( + name=protos.Schema(type="STRING"), + favorite_color=protos.Schema(type="STRING"), + birthday=protos.Schema( type="OBJECT", properties=dict( - day=glm.Schema(type="INTEGER"), - month=glm.Schema(type="INTEGER"), - year=glm.Schema(type="INTEGER"), + day=protos.Schema(type="INTEGER"), + month=protos.Schema(type="INTEGER"), + year=protos.Schema(type="INTEGER"), ), ), ), diff --git a/tests/test_generative_models.py b/tests/test_generative_models.py index 6fabd59e9..2832c55c5 100644 --- a/tests/test_generative_models.py +++ b/tests/test_generative_models.py @@ -7,7 +7,7 @@ import unittest.mock from absl.testing import absltest from absl.testing import parameterized -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import client as client_lib from google.generativeai import generative_models from google.generativeai.types import content_types @@ -21,16 +21,16 @@ TEST_IMAGE_DATA = TEST_IMAGE_PATH.read_bytes() -def simple_part(text: str) -> glm.Content: - return glm.Content({"parts": [{"text": text}]}) +def simple_part(text: str) -> protos.Content: + return protos.Content({"parts": [{"text": text}]}) -def iter_part(texts: Iterable[str]) -> glm.Content: - return glm.Content({"parts": [{"text": t} for t in texts]}) +def iter_part(texts: Iterable[str]) -> protos.Content: + return protos.Content({"parts": [{"text": t} for t in texts]}) -def simple_response(text: str) -> glm.GenerateContentResponse: - return glm.GenerateContentResponse({"candidates": [{"content": simple_part(text)}]}) +def simple_response(text: str) -> protos.GenerateContentResponse: + return protos.GenerateContentResponse({"candidates": [{"content": simple_part(text)}]}) class CUJTests(parameterized.TestCase): @@ -51,28 +51,28 @@ def add_client_method(f): @add_client_method def generate_content( - request: glm.GenerateContentRequest, + request: protos.GenerateContentRequest, **kwargs, - ) -> glm.GenerateContentResponse: - self.assertIsInstance(request, glm.GenerateContentRequest) + ) -> protos.GenerateContentResponse: + self.assertIsInstance(request, protos.GenerateContentRequest) self.observed_requests.append(request) response = self.responses["generate_content"].pop(0) return response @add_client_method def stream_generate_content( - request: glm.GetModelRequest, + request: protos.GetModelRequest, **kwargs, - ) -> Iterable[glm.GenerateContentResponse]: + ) -> Iterable[protos.GenerateContentResponse]: self.observed_requests.append(request) response = self.responses["stream_generate_content"].pop(0) return response @add_client_method def count_tokens( - request: glm.CountTokensRequest, + request: protos.CountTokensRequest, **kwargs, - ) -> Iterable[glm.GenerateContentResponse]: + ) -> Iterable[protos.GenerateContentResponse]: self.observed_requests.append(request) response = self.responses["count_tokens"].pop(0) return response @@ -129,9 +129,9 @@ def test_image(self, content): generation_types.GenerationConfig(temperature=0.5), ], [ - "glm", - glm.GenerationConfig(temperature=0.0), - glm.GenerationConfig(temperature=0.5), + "protos", + protos.GenerationConfig(temperature=0.0), + 
protos.GenerationConfig(temperature=0.5), ], ) def test_generation_config_overwrite(self, config1, config2): @@ -155,8 +155,8 @@ def test_generation_config_overwrite(self, config1, config2): "list-dict", [ dict( - category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS, - threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + category=protos.HarmCategory.HARM_CATEGORY_DANGEROUS, + threshold=protos.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, ), ], [ @@ -166,15 +166,15 @@ def test_generation_config_overwrite(self, config1, config2): [ "object", [ - glm.SafetySetting( - category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS, - threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + protos.SafetySetting( + category=protos.HarmCategory.HARM_CATEGORY_DANGEROUS, + threshold=protos.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, ), ], [ - glm.SafetySetting( - category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS, - threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + protos.SafetySetting( + category=protos.HarmCategory.HARM_CATEGORY_DANGEROUS, + threshold=protos.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, ), ], ], @@ -191,21 +191,21 @@ def test_safety_overwrite(self, safe1, safe2): _ = model.generate_content("hello") self.assertEqual( self.observed_requests[-1].safety_settings[0].category, - glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, ) self.assertEqual( self.observed_requests[-1].safety_settings[0].threshold, - glm.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + protos.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, ) _ = model.generate_content("hello", safety_settings={"danger": "high"}) self.assertEqual( self.observed_requests[-1].safety_settings[0].category, - glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, ) self.assertEqual( self.observed_requests[-1].safety_settings[0].threshold, - glm.SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH, + protos.SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH, ) def test_stream_basic(self): @@ -239,7 +239,7 @@ def test_stream_lookahead(self): def test_stream_prompt_feedback_blocked(self): chunks = [ - glm.GenerateContentResponse( + protos.GenerateContentResponse( { "prompt_feedback": {"block_reason": "SAFETY"}, } @@ -252,7 +252,7 @@ def test_stream_prompt_feedback_blocked(self): self.assertEqual( response.prompt_feedback.block_reason, - glm.GenerateContentResponse.PromptFeedback.BlockReason.SAFETY, + protos.GenerateContentResponse.PromptFeedback.BlockReason.SAFETY, ) with self.assertRaises(generation_types.BlockedPromptException): @@ -261,20 +261,20 @@ def test_stream_prompt_feedback_blocked(self): def test_stream_prompt_feedback_not_blocked(self): chunks = [ - glm.GenerateContentResponse( + protos.GenerateContentResponse( { "prompt_feedback": { "safety_ratings": [ { - "category": glm.HarmCategory.HARM_CATEGORY_DANGEROUS, - "probability": glm.SafetyRating.HarmProbability.NEGLIGIBLE, + "category": protos.HarmCategory.HARM_CATEGORY_DANGEROUS, + "probability": protos.SafetyRating.HarmProbability.NEGLIGIBLE, } ] }, "candidates": [{"content": {"parts": [{"text": "first"}]}}], } ), - glm.GenerateContentResponse( + protos.GenerateContentResponse( { "candidates": [{"content": {"parts": [{"text": " second"}]}}], } @@ -287,7 +287,7 @@ def test_stream_prompt_feedback_not_blocked(self): self.assertEqual( response.prompt_feedback.safety_ratings[0].category, - 
glm.HarmCategory.HARM_CATEGORY_DANGEROUS, + protos.HarmCategory.HARM_CATEGORY_DANGEROUS, ) text = "".join(chunk.text for chunk in response) @@ -520,7 +520,7 @@ def no_throw(): def test_chat_prompt_blocked(self): self.responses["generate_content"] = [ - glm.GenerateContentResponse( + protos.GenerateContentResponse( { "prompt_feedback": {"block_reason": "SAFETY"}, } @@ -538,7 +538,7 @@ def test_chat_prompt_blocked(self): def test_chat_candidate_blocked(self): # I feel like chat needs a .last so you can look at the partial results. self.responses["generate_content"] = [ - glm.GenerateContentResponse( + protos.GenerateContentResponse( { "candidates": [{"finish_reason": "SAFETY"}], } @@ -558,7 +558,7 @@ def test_chat_streaming_unexpected_stop(self): simple_response("a"), simple_response("b"), simple_response("c"), - glm.GenerateContentResponse( + protos.GenerateContentResponse( { "candidates": [{"finish_reason": "SAFETY"}], } @@ -645,9 +645,9 @@ def test_tools(self): }, ), dict( - testcase_name="test_glm_FunctionCallingConfig", + testcase_name="test_protos.FunctionCallingConfig", tool_config={ - "function_calling_config": glm.FunctionCallingConfig( + "function_calling_config": protos.FunctionCallingConfig( mode=content_types.FunctionCallingMode.AUTO ) }, @@ -674,9 +674,9 @@ def test_tools(self): }, ), dict( - testcase_name="test_glm_ToolConfig", - tool_config=glm.ToolConfig( - function_calling_config=glm.FunctionCallingConfig( + testcase_name="test_protos.ToolConfig", + tool_config=protos.ToolConfig( + function_calling_config=protos.FunctionCallingConfig( mode=content_types.FunctionCallingMode.NONE ) ), @@ -734,7 +734,7 @@ def test_system_instruction(self, instruction, expected_instr): ["contents", [{"role": "user", "parts": ["hello"]}]], ) def test_count_tokens_smoke(self, contents): - self.responses["count_tokens"] = [glm.CountTokensResponse(total_tokens=7)] + self.responses["count_tokens"] = [protos.CountTokensResponse(total_tokens=7)] model = generative_models.GenerativeModel("gemini-pro-vision") response = model.count_tokens(contents) self.assertEqual(type(response).to_dict(response), {"total_tokens": 7}) @@ -801,7 +801,7 @@ def test_repr_for_unary_non_streamed_response(self): GenerateContentResponse( done=True, iterator=None, - result=glm.GenerateContentResponse({ + result=protos.GenerateContentResponse({ "candidates": [ { "content": { @@ -839,7 +839,7 @@ def test_repr_for_streaming_start_to_finish(self): GenerateContentResponse( done=False, iterator=, - result=glm.GenerateContentResponse({ + result=protos.GenerateContentResponse({ "candidates": [ { "content": { @@ -869,7 +869,7 @@ def test_repr_for_streaming_start_to_finish(self): GenerateContentResponse( done=False, iterator=, - result=glm.GenerateContentResponse({ + result=protos.GenerateContentResponse({ "candidates": [ { "content": { @@ -912,7 +912,7 @@ def test_repr_for_streaming_start_to_finish(self): GenerateContentResponse( done=False, iterator=, - result=glm.GenerateContentResponse({ + result=protos.GenerateContentResponse({ "candidates": [ { "content": { @@ -950,7 +950,7 @@ def test_repr_for_streaming_start_to_finish(self): def test_repr_error_info_for_stream_prompt_feedback_blocked(self): # response._error => BlockedPromptException chunks = [ - glm.GenerateContentResponse( + protos.GenerateContentResponse( { "prompt_feedback": {"block_reason": "SAFETY"}, } @@ -968,7 +968,7 @@ def test_repr_error_info_for_stream_prompt_feedback_blocked(self): GenerateContentResponse( done=False, iterator=, - result=glm.GenerateContentResponse({ 
+ result=protos.GenerateContentResponse({ "prompt_feedback": { "block_reason": 1, "safety_ratings": [] @@ -1020,7 +1020,7 @@ def no_throw(): GenerateContentResponse( done=True, iterator=None, - result=glm.GenerateContentResponse({ + result=protos.GenerateContentResponse({ "candidates": [ { "content": { @@ -1064,7 +1064,7 @@ def test_repr_error_info_for_chat_streaming_unexpected_stop(self): simple_response("a"), simple_response("b"), simple_response("c"), - glm.GenerateContentResponse( + protos.GenerateContentResponse( { "candidates": [{"finish_reason": "SAFETY"}], } @@ -1093,7 +1093,7 @@ def test_repr_error_info_for_chat_streaming_unexpected_stop(self): GenerateContentResponse( done=True, iterator=None, - result=glm.GenerateContentResponse({ + result=protos.GenerateContentResponse({ "candidates": [ { "content": { @@ -1169,7 +1169,7 @@ def test_repr_for_multi_turn_chat(self): tools=None, system_instruction=None, ), - history=[glm.Content({'parts': [{'text': 'I really like fantasy books.'}], 'role': 'user'}), glm.Content({'parts': [{'text': 'first'}], 'role': 'model'}), glm.Content({'parts': [{'text': 'I also like this image.'}, {'inline_data': {'data': 'iVBORw0KGgoA...AAElFTkSuQmCC', 'mime_type': 'image/png'}}], 'role': 'user'}), glm.Content({'parts': [{'text': 'second'}], 'role': 'model'}), glm.Content({'parts': [{'text': 'What things do I like?.'}], 'role': 'user'}), glm.Content({'parts': [{'text': 'third'}], 'role': 'model'})] + history=[protos.Content({'parts': [{'text': 'I really like fantasy books.'}], 'role': 'user'}), protos.Content({'parts': [{'text': 'first'}], 'role': 'model'}), protos.Content({'parts': [{'text': 'I also like this image.'}, {'inline_data': {'data': 'iVBORw0KGgoA...AAElFTkSuQmCC', 'mime_type': 'image/png'}}], 'role': 'user'}), protos.Content({'parts': [{'text': 'second'}], 'role': 'model'}), protos.Content({'parts': [{'text': 'What things do I like?.'}], 'role': 'user'}), protos.Content({'parts': [{'text': 'third'}], 'role': 'model'})] )""" ) self.assertEqual(expected, result) @@ -1197,7 +1197,7 @@ def test_repr_for_incomplete_streaming_chat(self): tools=None, system_instruction=None, ), - history=[glm.Content({'parts': [{'text': 'I really like fantasy books.'}], 'role': 'user'}), ] + history=[protos.Content({'parts': [{'text': 'I really like fantasy books.'}], 'role': 'user'}), ] )""" ) self.assertEqual(expected, result) @@ -1213,7 +1213,7 @@ def test_repr_for_broken_streaming_chat(self): for chunk in [ simple_response("first"), # FinishReason.SAFETY = 3 - glm.GenerateContentResponse( + protos.GenerateContentResponse( { "candidates": [ {"finish_reason": 3, "content": {"parts": [{"text": "second"}]}} @@ -1241,7 +1241,7 @@ def test_repr_for_broken_streaming_chat(self): tools=None, system_instruction=None, ), - history=[glm.Content({'parts': [{'text': 'I really like fantasy books.'}], 'role': 'user'}), ] + history=[protos.Content({'parts': [{'text': 'I really like fantasy books.'}], 'role': 'user'}), ] )""" ) self.assertEqual(expected, result) @@ -1256,7 +1256,7 @@ def test_count_tokens_called_with_request_options(self): request = unittest.mock.ANY request_options = {"timeout": 120} - self.responses["count_tokens"] = [glm.CountTokensResponse(total_tokens=7)] + self.responses["count_tokens"] = [protos.CountTokensResponse(total_tokens=7)] model = generative_models.GenerativeModel("gemini-pro-vision") model.count_tokens([{"role": "user", "parts": ["hello"]}], request_options=request_options) diff --git a/tests/test_generative_models_async.py 
b/tests/test_generative_models_async.py index 2c465d1d3..b5babda1e 100644 --- a/tests/test_generative_models_async.py +++ b/tests/test_generative_models_async.py @@ -24,14 +24,14 @@ from google.generativeai import client as client_lib from google.generativeai import generative_models from google.generativeai.types import content_types -import google.ai.generativelanguage as glm +from google.generativeai import protos from absl.testing import absltest from absl.testing import parameterized -def simple_response(text: str) -> glm.GenerateContentResponse: - return glm.GenerateContentResponse({"candidates": [{"content": {"parts": [{"text": text}]}}]}) +def simple_response(text: str) -> protos.GenerateContentResponse: + return protos.GenerateContentResponse({"candidates": [{"content": {"parts": [{"text": text}]}}]}) class AsyncTests(parameterized.TestCase, unittest.IsolatedAsyncioTestCase): @@ -50,28 +50,28 @@ def add_client_method(f): @add_client_method async def generate_content( - request: glm.GenerateContentRequest, + request: protos.GenerateContentRequest, **kwargs, - ) -> glm.GenerateContentResponse: - self.assertIsInstance(request, glm.GenerateContentRequest) + ) -> protos.GenerateContentResponse: + self.assertIsInstance(request, protos.GenerateContentRequest) self.observed_requests.append(request) response = self.responses["generate_content"].pop(0) return response @add_client_method async def stream_generate_content( - request: glm.GetModelRequest, + request: protos.GetModelRequest, **kwargs, - ) -> Iterable[glm.GenerateContentResponse]: + ) -> Iterable[protos.GenerateContentResponse]: self.observed_requests.append(request) response = self.responses["stream_generate_content"].pop(0) return response @add_client_method async def count_tokens( - request: glm.CountTokensRequest, + request: protos.CountTokensRequest, **kwargs, - ) -> Iterable[glm.GenerateContentResponse]: + ) -> Iterable[protos.GenerateContentResponse]: self.observed_requests.append(request) response = self.responses["count_tokens"].pop(0) return response @@ -140,9 +140,9 @@ async def responses(): }, ), dict( - testcase_name="test_glm_FunctionCallingConfig", + testcase_name="test_protos.FunctionCallingConfig", tool_config={ - "function_calling_config": glm.FunctionCallingConfig( + "function_calling_config": protos.FunctionCallingConfig( mode=content_types.FunctionCallingMode.AUTO ) }, @@ -169,9 +169,9 @@ async def responses(): }, ), dict( - testcase_name="test_glm_ToolConfig", - tool_config=glm.ToolConfig( - function_calling_config=glm.FunctionCallingConfig( + testcase_name="test_protos.ToolConfig", + tool_config=protos.ToolConfig( + function_calling_config=protos.FunctionCallingConfig( mode=content_types.FunctionCallingMode.NONE ) ), @@ -211,7 +211,7 @@ async def test_tool_config(self, tool_config, expected_tool_config): ["contents", [{"role": "user", "parts": ["hello"]}]], ) async def test_count_tokens_smoke(self, contents): - self.responses["count_tokens"] = [glm.CountTokensResponse(total_tokens=7)] + self.responses["count_tokens"] = [protos.CountTokensResponse(total_tokens=7)] model = generative_models.GenerativeModel("gemini-pro-vision") response = await model.count_tokens_async(contents) self.assertEqual(type(response).to_dict(response), {"total_tokens": 7}) diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 0c2de7f29..f060caf88 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -19,7 +19,7 @@ from absl.testing import parameterized -import google.ai.generativelanguage as glm +from 
google.generativeai import protos from google.generativeai import client from google.generativeai import models @@ -35,15 +35,15 @@ def __init__(self, test): def get_model( self, - request: Union[glm.GetModelRequest, None] = None, + request: Union[protos.GetModelRequest, None] = None, *, name=None, timeout=None, retry=None - ) -> glm.Model: + ) -> protos.Model: if request is None: - request = glm.GetModelRequest(name=name) - self.test.assertIsInstance(request, glm.GetModelRequest) + request = protos.GetModelRequest(name=name) + self.test.assertIsInstance(request, protos.GetModelRequest) self.test.observed_requests.append(request) self.test.observed_timeout.append(timeout) self.test.observed_retry.append(retry) @@ -75,7 +75,7 @@ def setUp(self): ], ) def test_get_model(self, request_options, expected_timeout, expected_retry): - self.responses = {"get_model": glm.Model(name="models/fake-bison-001")} + self.responses = {"get_model": protos.Model(name="models/fake-bison-001")} _ = models.get_model("models/fake-bison-001", request_options=request_options) diff --git a/tests/test_models.py b/tests/test_models.py index f39ed3a2c..c591c7f71 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -25,7 +25,7 @@ from absl.testing import absltest from absl.testing import parameterized -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.api_core import operation from google.generativeai import models @@ -45,7 +45,7 @@ def setUp(self): client._client_manager.clients["model"] = self.client # TODO(markdaoust): Check if typechecking works better if wee define this as a - # subclass of `glm.ModelServiceClient`, would pyi files for `glm` help? + # subclass of `protos.ModelServiceClient`, would pyi files for `protos. help? 
def add_client_method(f): name = f.__name__ setattr(self.client, name, f) @@ -55,63 +55,63 @@ def add_client_method(f): self.responses = {} @add_client_method - def get_model(request: Union[glm.GetModelRequest, None] = None, *, name=None) -> glm.Model: + def get_model(request: Union[protos.GetModelRequest, None] = None, *, name=None) -> protos.Model: if request is None: - request = glm.GetModelRequest(name=name) - self.assertIsInstance(request, glm.GetModelRequest) + request = protos.GetModelRequest(name=name) + self.assertIsInstance(request, protos.GetModelRequest) self.observed_requests.append(request) response = copy.copy(self.responses["get_model"]) return response @add_client_method def get_tuned_model( - request: Union[glm.GetTunedModelRequest, None] = None, + request: Union[protos.GetTunedModelRequest, None] = None, *, name=None, **kwargs, - ) -> glm.TunedModel: + ) -> protos.TunedModel: if request is None: - request = glm.GetTunedModelRequest(name=name) - self.assertIsInstance(request, glm.GetTunedModelRequest) + request = protos.GetTunedModelRequest(name=name) + self.assertIsInstance(request, protos.GetTunedModelRequest) self.observed_requests.append(request) response = copy.copy(self.responses["get_tuned_model"]) return response @add_client_method def list_models( - request: Union[glm.ListModelsRequest, None] = None, + request: Union[protos.ListModelsRequest, None] = None, *, page_size=None, page_token=None, **kwargs, - ) -> glm.ListModelsResponse: + ) -> protos.ListModelsResponse: if request is None: - request = glm.ListModelsRequest(page_size=page_size, page_token=page_token) - self.assertIsInstance(request, glm.ListModelsRequest) + request = protos.ListModelsRequest(page_size=page_size, page_token=page_token) + self.assertIsInstance(request, protos.ListModelsRequest) self.observed_requests.append(request) response = self.responses["list_models"] return (item for item in response) @add_client_method def list_tuned_models( - request: glm.ListTunedModelsRequest = None, + request: protos.ListTunedModelsRequest = None, *, page_size=None, page_token=None, **kwargs, - ) -> Iterable[glm.TunedModel]: + ) -> Iterable[protos.TunedModel]: if request is None: - request = glm.ListTunedModelsRequest(page_size=page_size, page_token=page_token) - self.assertIsInstance(request, glm.ListTunedModelsRequest) + request = protos.ListTunedModelsRequest(page_size=page_size, page_token=page_token) + self.assertIsInstance(request, protos.ListTunedModelsRequest) self.observed_requests.append(request) response = self.responses["list_tuned_models"] return (item for item in response) @add_client_method def update_tuned_model( - request: glm.UpdateTunedModelRequest, + request: protos.UpdateTunedModelRequest, **kwargs, - ) -> glm.TunedModel: + ) -> protos.TunedModel: self.observed_requests.append(request) response = self.responses.get("update_tuned_model", None) if response is None: @@ -120,7 +120,7 @@ def update_tuned_model( @add_client_method def delete_tuned_model(name): - request = glm.DeleteTunedModelRequest(name=name) + request = protos.DeleteTunedModelRequest(name=name) self.observed_requests.append(request) response = True return response @@ -130,26 +130,26 @@ def create_tuned_model( request, **kwargs, ): - request = glm.CreateTunedModelRequest(request) + request = protos.CreateTunedModelRequest(request) self.observed_requests.append(request) return self.responses["create_tuned_model"] def test_decode_tuned_model_time_round_trip(self): example_dt = datetime.datetime(2000, 1, 2, 3, 4, 5, 600_000, 
pytz.UTC) - tuned_model = glm.TunedModel(name="tunedModels/house-mouse-001", create_time=example_dt) + tuned_model = protos.TunedModel(name="tunedModels/house-mouse-001", create_time=example_dt) tuned_model = model_types.decode_tuned_model(tuned_model) self.assertEqual(tuned_model.create_time, example_dt) @parameterized.named_parameters( ["simple", "models/fake-bison-001"], ["simple-tuned", "tunedModels/my-pig-001"], - ["model-instance", glm.Model(name="models/fake-bison-001")], - ["tuned-model-instance", glm.TunedModel(name="tunedModels/my-pig-001")], + ["model-instance", protos.Model(name="models/fake-bison-001")], + ["tuned-model-instance", protos.TunedModel(name="tunedModels/my-pig-001")], ) def test_get_model(self, name): self.responses = { - "get_model": glm.Model(name="models/fake-bison-001"), - "get_tuned_model": glm.TunedModel(name="tunedModels/my-pig-001"), + "get_model": protos.Model(name="models/fake-bison-001"), + "get_tuned_model": protos.TunedModel(name="tunedModels/my-pig-001"), } model = models.get_model(name) @@ -160,7 +160,7 @@ def test_get_model(self, name): @parameterized.named_parameters( ["simple", "mystery-bison-001"], - ["model-instance", glm.Model(name="how?-bison-001")], + ["model-instance", protos.Model(name="how?-bison-001")], ) def test_fail_with_unscoped_model_name(self, name): with self.assertRaises(ValueError): @@ -170,9 +170,9 @@ def test_list_models(self): # The low level lib wraps the response in an iterable, so this is a fair test. self.responses = { "list_models": [ - glm.Model(name="models/fake-bison-001"), - glm.Model(name="models/fake-bison-002"), - glm.Model(name="models/fake-bison-003"), + protos.Model(name="models/fake-bison-001"), + protos.Model(name="models/fake-bison-002"), + protos.Model(name="models/fake-bison-003"), ] } @@ -185,9 +185,9 @@ def test_list_tuned_models(self): self.responses = { # The low level lib wraps the response in an iterable, so this is a fair test. "list_tuned_models": [ - glm.TunedModel(name="tunedModels/my-pig-001"), - glm.TunedModel(name="tunedModels/my-pig-002"), - glm.TunedModel(name="tunedModels/my-pig-003"), + protos.TunedModel(name="tunedModels/my-pig-001"), + protos.TunedModel(name="tunedModels/my-pig-002"), + protos.TunedModel(name="tunedModels/my-pig-003"), ] } found_models = list(models.list_tuned_models()) @@ -197,8 +197,8 @@ def test_list_tuned_models(self): @parameterized.named_parameters( [ - "edited-glm-model", - glm.TunedModel( + "edited-protos.model", + protos.TunedModel( name="tunedModels/my-pig-001", description="Trained on my data", ), @@ -211,7 +211,7 @@ def test_list_tuned_models(self): ], ) def test_update_tuned_model_basics(self, tuned_model, updates): - self.responses["get_tuned_model"] = glm.TunedModel(name="tunedModels/my-pig-001") + self.responses["get_tuned_model"] = protos.TunedModel(name="tunedModels/my-pig-001") # No self.responses['update_tuned_model'] the mock just returns the input. 
updated_model = models.update_tuned_model(tuned_model, updates) updated_model.description = "Trained on my data" @@ -227,7 +227,7 @@ def test_update_tuned_model_basics(self, tuned_model, updates): ], ) def test_update_tuned_model_nested_fields(self, updates): - self.responses["get_tuned_model"] = glm.TunedModel( + self.responses["get_tuned_model"] = protos.TunedModel( name="tunedModels/my-pig-001", base_model="models/dance-monkey-007" ) @@ -250,8 +250,8 @@ def test_update_tuned_model_nested_fields(self, updates): @parameterized.named_parameters( ["name", "tunedModels/bipedal-pangolin-223"], [ - "glm.TunedModel", - glm.TunedModel(name="tunedModels/bipedal-pangolin-223"), + "protos.TunedModel", + protos.TunedModel(name="tunedModels/bipedal-pangolin-223"), ], [ "models.TunedModel", @@ -275,23 +275,23 @@ def test_decode_micros(self, time_str, micros): self.assertEqual(time["time"].microsecond, micros) def test_decode_tuned_model(self): - out_fields = glm.TunedModel( - state=glm.TunedModel.State.CREATING, + out_fields = protos.TunedModel( + state=protos.TunedModel.State.CREATING, create_time="2000-01-01T01:01:01.0Z", update_time="2001-01-01T01:01:01.0Z", - tuning_task=glm.TuningTask( - hyperparameters=glm.Hyperparameters( + tuning_task=protos.TuningTask( + hyperparameters=protos.Hyperparameters( batch_size=72, epoch_count=1, learning_rate=0.1 ), start_time="2002-01-01T01:01:01.0Z", complete_time="2003-01-01T01:01:01.0Z", snapshots=[ - glm.TuningSnapshot( + protos.TuningSnapshot( step=1, epoch=1, compute_time="2004-01-01T01:01:01.0Z", ), - glm.TuningSnapshot( + protos.TuningSnapshot( step=2, epoch=1, compute_time="2005-01-01T01:01:01.0Z", @@ -301,7 +301,7 @@ def test_decode_tuned_model(self): ) decoded = model_types.decode_tuned_model(out_fields) - self.assertEqual(decoded.state, glm.TunedModel.State.CREATING) + self.assertEqual(decoded.state, protos.TunedModel.State.CREATING) self.assertEqual(decoded.create_time.year, 2000) self.assertEqual(decoded.update_time.year, 2001) self.assertIsInstance(decoded.tuning_task.hyperparameters, model_types.Hyperparameters) @@ -314,10 +314,10 @@ def test_decode_tuned_model(self): self.assertEqual(decoded.tuning_task.snapshots[1]["compute_time"].year, 2005) @parameterized.named_parameters( - ["simple", glm.TunedModel(base_model="models/swim-fish-000")], + ["simple", protos.TunedModel(base_model="models/swim-fish-000")], [ "nested", - glm.TunedModel( + protos.TunedModel( tuned_model_source={ "tuned_model": "tunedModels/hidden-fish-55", "base_model": "models/swim-fish-000", @@ -341,7 +341,7 @@ def test_smoke_create_tuned_model(self): training_data=[ ("in", "out"), {"text_input": "in", "output": "out"}, - glm.TuningExample(text_input="in", output="out"), + protos.TuningExample(text_input="in", output="out"), ], ) req = self.observed_requests[-1] @@ -351,10 +351,10 @@ def test_smoke_create_tuned_model(self): self.assertLen(req.tuned_model.tuning_task.training_data.examples.examples, 3) @parameterized.named_parameters( - ["simple", glm.TunedModel(base_model="models/swim-fish-000")], + ["simple", protos.TunedModel(base_model="models/swim-fish-000")], [ "nested", - glm.TunedModel( + protos.TunedModel( tuned_model_source={ "tuned_model": "tunedModels/hidden-fish-55", "base_model": "models/swim-fish-000", @@ -380,9 +380,9 @@ def test_create_tuned_model_on_tuned_model(self, tuned_source): @parameterized.named_parameters( [ - "glm", - glm.Dataset( - examples=glm.TuningExamples( + "protos", + protos.Dataset( + examples=protos.TuningExamples( examples=[ {"text_input": "a", 
"output": "1"}, {"text_input": "b", "output": "2"}, @@ -396,7 +396,7 @@ def test_create_tuned_model_on_tuned_model(self, tuned_source): [ ("a", "1"), {"text_input": "b", "output": "2"}, - glm.TuningExample({"text_input": "c", "output": "3"}), + protos.TuningExample({"text_input": "c", "output": "3"}), ], ], ["dict", {"text_input": ["a", "b", "c"], "output": ["1", "2", "3"]}], @@ -445,8 +445,8 @@ def test_create_tuned_model_on_tuned_model(self, tuned_source): def test_create_dataset(self, data, ik="text_input", ok="output"): ds = model_types.encode_tuning_data(data, input_key=ik, output_key=ok) - expect = glm.Dataset( - examples=glm.TuningExamples( + expect = protos.Dataset( + examples=protos.TuningExamples( examples=[ {"text_input": "a", "output": "1"}, {"text_input": "b", "output": "2"}, @@ -502,7 +502,7 @@ def test_update_tuned_model_called_with_request_options(self): self.client.update_tuned_model = unittest.mock.MagicMock() request = unittest.mock.ANY request_options = {"timeout": 120} - self.responses["get_tuned_model"] = glm.TunedModel(name="tunedModels/") + self.responses["get_tuned_model"] = protos.TunedModel(name="tunedModels/") try: models.update_tuned_model( @@ -534,7 +534,7 @@ def test_create_tuned_model_called_with_request_options(self): training_data=[ ("in", "out"), {"text_input": "in", "output": "out"}, - glm.TuningExample(text_input="in", output="out"), + protos.TuningExample(text_input="in", output="out"), ], request_options=request_options, ) diff --git a/tests/test_operations.py b/tests/test_operations.py index 80262db88..6529b77e5 100644 --- a/tests/test_operations.py +++ b/tests/test_operations.py @@ -16,7 +16,7 @@ from contextlib import redirect_stderr import io -import google.ai.generativelanguage as glm +from google.generativeai import protos import google.protobuf.any_pb2 import google.generativeai.operations as genai_operation @@ -41,7 +41,7 @@ def test_end_to_end(self): # `Any` takes a type name and a serialized proto. metadata = google.protobuf.any_pb2.Any( type_url=self.metadata_type, - value=glm.CreateTunedModelMetadata(tuned_model=name)._pb.SerializeToString(), + value=protos.CreateTunedModelMetadata(tuned_model=name)._pb.SerializeToString(), ) # Initially the `Operation` is not `done`, so it only gives a metadata. @@ -58,7 +58,7 @@ def test_end_to_end(self): metadata=metadata, response=google.protobuf.any_pb2.Any( type_url=self.result_type, - value=glm.TunedModel(name=name)._pb.SerializeToString(), + value=protos.TunedModel(name=name)._pb.SerializeToString(), ), ) @@ -72,8 +72,8 @@ def refresh(*_, **__): operation=initial_pb, refresh=refresh, cancel=lambda: print(f"cancel!"), - result_type=glm.TunedModel, - metadata_type=glm.CreateTunedModelMetadata, + result_type=protos.TunedModel, + metadata_type=protos.CreateTunedModelMetadata, ) # Use our wrapper instead. 
@@ -99,7 +99,7 @@ def gen_operations(): def make_metadata(completed_steps): return google.protobuf.any_pb2.Any( type_url=self.metadata_type, - value=glm.CreateTunedModelMetadata( + value=protos.CreateTunedModelMetadata( tuned_model=name, total_steps=total_steps, completed_steps=completed_steps, @@ -122,7 +122,7 @@ def make_metadata(completed_steps): metadata=make_metadata(total_steps), response=google.protobuf.any_pb2.Any( type_url=self.result_type, - value=glm.TunedModel(name=name)._pb.SerializeToString(), + value=protos.TunedModel(name=name)._pb.SerializeToString(), ), ) @@ -142,8 +142,8 @@ def refresh(*_, **__): operation=initial_pb, refresh=refresh, cancel=None, - result_type=glm.TunedModel, - metadata_type=glm.CreateTunedModelMetadata, + result_type=protos.TunedModel, + metadata_type=protos.CreateTunedModelMetadata, ) # Use our wrapper instead. diff --git a/tests/test_permission.py b/tests/test_permission.py index 55ad7a2f0..66b396977 100644 --- a/tests/test_permission.py +++ b/tests/test_permission.py @@ -17,7 +17,7 @@ import unittest import unittest.mock as mock -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import retriever from google.generativeai import permission @@ -50,11 +50,11 @@ def add_client_method(f): @add_client_method def create_corpus( - request: glm.CreateCorpusRequest, + request: protos.CreateCorpusRequest, **kwargs, - ) -> glm.Corpus: + ) -> protos.Corpus: self.observed_requests.append(request) - return glm.Corpus( + return protos.Corpus( name="corpora/demo-corpus", display_name="demo-corpus", create_time="2000-01-01T01:01:01.123456Z", @@ -63,24 +63,24 @@ def create_corpus( @add_client_method def get_tuned_model( - request: Optional[glm.GetTunedModelRequest] = None, + request: Optional[protos.GetTunedModelRequest] = None, *, name=None, **kwargs, - ) -> glm.TunedModel: + ) -> protos.TunedModel: if request is None: - request = glm.GetTunedModelRequest(name=name) - self.assertIsInstance(request, glm.GetTunedModelRequest) + request = protos.GetTunedModelRequest(name=name) + self.assertIsInstance(request, protos.GetTunedModelRequest) self.observed_requests.append(request) response = copy.copy(self.responses["get_tuned_model"]) return response @add_client_method def create_permission( - request: glm.CreatePermissionRequest, - ) -> glm.Permission: + request: protos.CreatePermissionRequest, + ) -> protos.Permission: self.observed_requests.append(request) - return glm.Permission( + return protos.Permission( name="corpora/demo-corpus/permissions/123456789", role=permission_services.to_role("writer"), grantee_type=permission_services.to_grantee_type("everyone"), @@ -88,17 +88,17 @@ def create_permission( @add_client_method def delete_permission( - request: glm.DeletePermissionRequest, + request: protos.DeletePermissionRequest, ) -> None: self.observed_requests.append(request) return None @add_client_method def get_permission( - request: glm.GetPermissionRequest, - ) -> glm.Permission: + request: protos.GetPermissionRequest, + ) -> protos.Permission: self.observed_requests.append(request) - return glm.Permission( + return protos.Permission( name="corpora/demo-corpus/permissions/123456789", role=permission_services.to_role("writer"), grantee_type=permission_services.to_grantee_type("everyone"), @@ -106,16 +106,16 @@ def get_permission( @add_client_method def list_permissions( - request: glm.ListPermissionsRequest, - ) -> glm.ListPermissionsResponse: + request: protos.ListPermissionsRequest, + ) -> 
protos.ListPermissionsResponse: self.observed_requests.append(request) return [ - glm.Permission( + protos.Permission( name="corpora/demo-corpus/permissions/123456789", role=permission_services.to_role("writer"), grantee_type=permission_services.to_grantee_type("everyone"), ), - glm.Permission( + protos.Permission( name="corpora/demo-corpus/permissions/987654321", role=permission_services.to_role("reader"), grantee_type=permission_services.to_grantee_type("everyone"), @@ -125,10 +125,10 @@ def list_permissions( @add_client_method def update_permission( - request: glm.UpdatePermissionRequest, - ) -> glm.Permission: + request: protos.UpdatePermissionRequest, + ) -> protos.Permission: self.observed_requests.append(request) - return glm.Permission( + return protos.Permission( name="corpora/demo-corpus/permissions/123456789", role=permission_services.to_role("reader"), grantee_type=permission_services.to_grantee_type("everyone"), @@ -136,16 +136,16 @@ def update_permission( @add_client_method def transfer_ownership( - request: glm.TransferOwnershipRequest, - ) -> glm.TransferOwnershipResponse: + request: protos.TransferOwnershipRequest, + ) -> protos.TransferOwnershipResponse: self.observed_requests.append(request) - return glm.TransferOwnershipResponse() + return protos.TransferOwnershipResponse() def test_create_permission_success(self): x = retriever.create_corpus("demo-corpus") perm = x.permissions.create(role="writer", grantee_type="everyone", email_address=None) self.assertIsInstance(perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.CreatePermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.CreatePermissionRequest) def test_create_permission_failure_email_set_when_grantee_type_is_everyone(self): x = retriever.create_corpus("demo-corpus") @@ -161,14 +161,14 @@ def test_delete_permission(self): x = retriever.create_corpus("demo-corpus") perm = x.permissions.create("writer", "everyone") perm.delete() - self.assertIsInstance(self.observed_requests[-1], glm.DeletePermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.DeletePermissionRequest) def test_get_permission_with_full_name(self): x = retriever.create_corpus("demo-corpus") perm = x.permissions.create("writer", "everyone") fetch_perm = permission.get_permission(name=perm.name) self.assertIsInstance(fetch_perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.GetPermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.GetPermissionRequest) self.assertEqual(fetch_perm, perm) def test_get_permission_with_resource_name_and_id_1(self): @@ -178,7 +178,7 @@ def test_get_permission_with_resource_name_and_id_1(self): resource_name="corpora/demo-corpus", permission_id=123456789 ) self.assertIsInstance(fetch_perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.GetPermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.GetPermissionRequest) self.assertEqual(fetch_perm, perm) def test_get_permission_with_resource_name_name_and_id_2(self): @@ -186,14 +186,14 @@ def test_get_permission_with_resource_name_name_and_id_2(self): resource_name="tunedModels/demo-corpus", permission_id=123456789 ) self.assertIsInstance(fetch_perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.GetPermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.GetPermissionRequest) def 
test_get_permission_with_resource_type(self): fetch_perm = permission.get_permission( resource_name="demo-model", permission_id=123456789, resource_type="tunedModels" ) self.assertIsInstance(fetch_perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.GetPermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.GetPermissionRequest) @parameterized.named_parameters( dict( @@ -257,14 +257,14 @@ def test_list_permission(self): self.assertEqual(perms[1].email_address, "_") for perm in perms: self.assertIsInstance(perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.ListPermissionsRequest) + self.assertIsInstance(self.observed_requests[-1], protos.ListPermissionsRequest) def test_update_permission_success(self): x = retriever.create_corpus("demo-corpus") perm = x.permissions.create("writer", "everyone") updated_perm = perm.update({"role": permission_services.to_role("reader")}) self.assertIsInstance(updated_perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.UpdatePermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.UpdatePermissionRequest) def test_update_permission_failure_restricted_update_path(self): x = retriever.create_corpus("demo-corpus") @@ -275,12 +275,12 @@ def test_update_permission_failure_restricted_update_path(self): ) def test_transfer_ownership(self): - self.responses["get_tuned_model"] = glm.TunedModel( + self.responses["get_tuned_model"] = protos.TunedModel( name="tunedModels/fake-pig-001", base_model="models/dance-monkey-007" ) x = models.get_tuned_model("tunedModels/fake-pig-001") response = x.permissions.transfer_ownership(email_address="_") - self.assertIsInstance(self.observed_requests[-1], glm.TransferOwnershipRequest) + self.assertIsInstance(self.observed_requests[-1], protos.TransferOwnershipRequest) def test_transfer_ownership_on_corpora(self): x = retriever.create_corpus("demo-corpus") diff --git a/tests/test_permission_async.py b/tests/test_permission_async.py index 165039122..ddc9c22a2 100644 --- a/tests/test_permission_async.py +++ b/tests/test_permission_async.py @@ -17,7 +17,7 @@ import unittest import unittest.mock as mock -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import retriever from google.generativeai import permission @@ -49,11 +49,11 @@ def add_client_method(f): @add_client_method async def create_corpus( - request: glm.CreateCorpusRequest, + request: protos.CreateCorpusRequest, **kwargs, - ) -> glm.Corpus: + ) -> protos.Corpus: self.observed_requests.append(request) - return glm.Corpus( + return protos.Corpus( name="corpora/demo-corpus", display_name="demo-corpus", create_time="2000-01-01T01:01:01.123456Z", @@ -62,24 +62,24 @@ async def create_corpus( @add_client_method def get_tuned_model( - request: Optional[glm.GetTunedModelRequest] = None, + request: Optional[protos.GetTunedModelRequest] = None, *, name=None, **kwargs, - ) -> glm.TunedModel: + ) -> protos.TunedModel: if request is None: - request = glm.GetTunedModelRequest(name=name) - self.assertIsInstance(request, glm.GetTunedModelRequest) + request = protos.GetTunedModelRequest(name=name) + self.assertIsInstance(request, protos.GetTunedModelRequest) self.observed_requests.append(request) response = copy.copy(self.responses["get_tuned_model"]) return response @add_client_method async def create_permission( - request: glm.CreatePermissionRequest, - ) -> glm.Permission: 
+ request: protos.CreatePermissionRequest, + ) -> protos.Permission: self.observed_requests.append(request) - return glm.Permission( + return protos.Permission( name="corpora/demo-corpus/permissions/123456789", role=permission_services.to_role("writer"), grantee_type=permission_services.to_grantee_type("everyone"), @@ -87,17 +87,17 @@ async def create_permission( @add_client_method async def delete_permission( - request: glm.DeletePermissionRequest, + request: protos.DeletePermissionRequest, ) -> None: self.observed_requests.append(request) return None @add_client_method async def get_permission( - request: glm.GetPermissionRequest, - ) -> glm.Permission: + request: protos.GetPermissionRequest, + ) -> protos.Permission: self.observed_requests.append(request) - return glm.Permission( + return protos.Permission( name="corpora/demo-corpus/permissions/123456789", role=permission_services.to_role("writer"), grantee_type=permission_services.to_grantee_type("everyone"), @@ -105,17 +105,17 @@ async def get_permission( @add_client_method async def list_permissions( - request: glm.ListPermissionsRequest, - ) -> glm.ListPermissionsResponse: + request: protos.ListPermissionsRequest, + ) -> protos.ListPermissionsResponse: self.observed_requests.append(request) async def results(): - yield glm.Permission( + yield protos.Permission( name="corpora/demo-corpus/permissions/123456789", role=permission_services.to_role("writer"), grantee_type=permission_services.to_grantee_type("everyone"), ) - yield glm.Permission( + yield protos.Permission( name="corpora/demo-corpus/permissions/987654321", role=permission_services.to_role("reader"), grantee_type=permission_services.to_grantee_type("everyone"), @@ -126,10 +126,10 @@ async def results(): @add_client_method async def update_permission( - request: glm.UpdatePermissionRequest, - ) -> glm.Permission: + request: protos.UpdatePermissionRequest, + ) -> protos.Permission: self.observed_requests.append(request) - return glm.Permission( + return protos.Permission( name="corpora/demo-corpus/permissions/123456789", role=permission_services.to_role("reader"), grantee_type=permission_services.to_grantee_type("everyone"), @@ -137,10 +137,10 @@ async def update_permission( @add_client_method async def transfer_ownership( - request: glm.TransferOwnershipRequest, - ) -> glm.TransferOwnershipResponse: + request: protos.TransferOwnershipRequest, + ) -> protos.TransferOwnershipResponse: self.observed_requests.append(request) - return glm.TransferOwnershipResponse() + return protos.TransferOwnershipResponse() async def test_create_permission_success(self): x = await retriever.create_corpus_async("demo-corpus") @@ -148,7 +148,7 @@ async def test_create_permission_success(self): role="writer", grantee_type="everyone", email_address=None ) self.assertIsInstance(perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.CreatePermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.CreatePermissionRequest) async def test_create_permission_failure_email_set_when_grantee_type_is_everyone(self): x = await retriever.create_corpus_async("demo-corpus") @@ -168,14 +168,14 @@ async def test_delete_permission(self): x = await retriever.create_corpus_async("demo-corpus") perm = await x.permissions.create_async("writer", "everyone") await perm.delete_async() - self.assertIsInstance(self.observed_requests[-1], glm.DeletePermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.DeletePermissionRequest) async def 
test_get_permission_with_full_name(self): x = await retriever.create_corpus_async("demo-corpus") perm = await x.permissions.create_async("writer", "everyone") fetch_perm = await permission.get_permission_async(name=perm.name) self.assertIsInstance(fetch_perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.GetPermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.GetPermissionRequest) self.assertEqual(fetch_perm, perm) async def test_get_permission_with_resource_name_and_id_1(self): @@ -185,7 +185,7 @@ async def test_get_permission_with_resource_name_and_id_1(self): resource_name="corpora/demo-corpus", permission_id=123456789 ) self.assertIsInstance(fetch_perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.GetPermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.GetPermissionRequest) self.assertEqual(fetch_perm, perm) async def test_get_permission_with_resource_name_name_and_id_2(self): @@ -193,14 +193,14 @@ async def test_get_permission_with_resource_name_name_and_id_2(self): resource_name="tunedModels/demo-corpus", permission_id=123456789 ) self.assertIsInstance(fetch_perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.GetPermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.GetPermissionRequest) async def test_get_permission_with_resource_type(self): fetch_perm = await permission.get_permission_async( resource_name="demo-model", permission_id=123456789, resource_type="tunedModels" ) self.assertIsInstance(fetch_perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.GetPermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.GetPermissionRequest) @parameterized.named_parameters( dict( @@ -264,14 +264,14 @@ async def test_list_permission(self): self.assertEqual(perms[1].email_address, "_") for perm in perms: self.assertIsInstance(perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.ListPermissionsRequest) + self.assertIsInstance(self.observed_requests[-1], protos.ListPermissionsRequest) async def test_update_permission_success(self): x = await retriever.create_corpus_async("demo-corpus") perm = await x.permissions.create_async("writer", "everyone") updated_perm = await perm.update_async({"role": permission_services.to_role("reader")}) self.assertIsInstance(updated_perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.UpdatePermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.UpdatePermissionRequest) async def test_update_permission_failure_restricted_update_path(self): x = await retriever.create_corpus_async("demo-corpus") @@ -282,12 +282,12 @@ async def test_update_permission_failure_restricted_update_path(self): ) async def test_transfer_ownership(self): - self.responses["get_tuned_model"] = glm.TunedModel( + self.responses["get_tuned_model"] = protos.TunedModel( name="tunedModels/fake-pig-001", base_model="models/dance-monkey-007" ) x = models.get_tuned_model("tunedModels/fake-pig-001") response = await x.permissions.transfer_ownership_async(email_address="_") - self.assertIsInstance(self.observed_requests[-1], glm.TransferOwnershipRequest) + self.assertIsInstance(self.observed_requests[-1], protos.TransferOwnershipRequest) async def test_transfer_ownership_on_corpora(self): x = await retriever.create_corpus_async("demo-corpus") 
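
The permission tests above follow the same substitution as the rest of this patch: the direct `google.ai.generativelanguage as glm` import is dropped in favour of the `google.generativeai.protos` re-export, and every proto class is referenced through it. A minimal before/after sketch of the pattern, using the `GetPermissionRequest` type these tests construct; the resource name is the demo value used by the mocks.

```
# Before this change, tests reached for the low-level package directly:
#   import google.ai.generativelanguage as glm
#   request = glm.GetPermissionRequest(name="corpora/demo-corpus/permissions/123456789")

# After, the same request proto is reachable through the re-export:
from google.generativeai import protos

request = protos.GetPermissionRequest(
    name="corpora/demo-corpus/permissions/123456789",
)
```

Because only the import and attribute names change, the mocked request and response shapes in these tests stay exactly as they were, which is why the diffs are a uniform rename.
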
diff --git a/tests/test_responder.py b/tests/test_responder.py index 4eb310815..c075fc65a 100644 --- a/tests/test_responder.py +++ b/tests/test_responder.py @@ -17,7 +17,7 @@ from absl.testing import absltest from absl.testing import parameterized -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import responder import IPython.display import PIL.Image @@ -42,9 +42,9 @@ class UnitTests(parameterized.TestCase): [ "FunctionLibrary", responder.FunctionLibrary( - tools=glm.Tool( + tools=protos.Tool( function_declarations=[ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time." ) ] @@ -56,7 +56,7 @@ class UnitTests(parameterized.TestCase): [ responder.Tool( function_declarations=[ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time." ) ] @@ -64,11 +64,11 @@ class UnitTests(parameterized.TestCase): ], ], [ - "IterableTool-glm.Tool", + "IterableTool-protos.Tool", [ - glm.Tool( + protos.Tool( function_declarations=[ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time.", ) @@ -93,7 +93,7 @@ class UnitTests(parameterized.TestCase): "IterableTool-IterableFD", [ [ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time.", ) @@ -103,7 +103,7 @@ class UnitTests(parameterized.TestCase): [ "IterableTool-FD", [ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time.", ) @@ -113,17 +113,17 @@ class UnitTests(parameterized.TestCase): "Tool", responder.Tool( function_declarations=[ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time." ) ] ), ], [ - "glm.Tool", - glm.Tool( + "protos.Tool", + protos.Tool( function_declarations=[ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time." ) ] @@ -175,8 +175,8 @@ class UnitTests(parameterized.TestCase): ), ], [ - "glm.FD", - glm.FunctionDeclaration( + "protos.FD", + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time." 
), ], @@ -216,32 +216,32 @@ def b(): self.assertLen(tools[0].function_declarations, 2) @parameterized.named_parameters( - ["int", int, glm.Schema(type=glm.Type.INTEGER)], - ["float", float, glm.Schema(type=glm.Type.NUMBER)], - ["str", str, glm.Schema(type=glm.Type.STRING)], + ["int", int, protos.Schema(type=protos.Type.INTEGER)], + ["float", float, protos.Schema(type=protos.Type.NUMBER)], + ["str", str, protos.Schema(type=protos.Type.STRING)], [ "list", list[str], - glm.Schema( - type=glm.Type.ARRAY, - items=glm.Schema(type=glm.Type.STRING), + protos.Schema( + type=protos.Type.ARRAY, + items=protos.Schema(type=protos.Type.STRING), ), ], [ "list-list-int", list[list[int]], - glm.Schema( - type=glm.Type.ARRAY, - items=glm.Schema( - glm.Schema( - type=glm.Type.ARRAY, - items=glm.Schema(type=glm.Type.INTEGER), + protos.Schema( + type=protos.Type.ARRAY, + items=protos.Schema( + protos.Schema( + type=protos.Type.ARRAY, + items=protos.Schema(type=protos.Type.INTEGER), ), ), ), ], - ["dict", dict, glm.Schema(type=glm.Type.OBJECT)], - ["dict-str-any", dict[str, Any], glm.Schema(type=glm.Type.OBJECT)], + ["dict", dict, protos.Schema(type=protos.Type.OBJECT)], + ["dict-str-any", dict[str, Any], protos.Schema(type=protos.Type.OBJECT)], ) def test_auto_schema(self, annotation, expected): def fun(a: annotation): diff --git a/tests/test_retriever.py b/tests/test_retriever.py index 910183789..bce9a402b 100644 --- a/tests/test_retriever.py +++ b/tests/test_retriever.py @@ -16,7 +16,7 @@ import unittest import unittest.mock as mock -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import retriever from google.generativeai import client @@ -42,11 +42,11 @@ def add_client_method(f): @add_client_method def create_corpus( - request: glm.CreateCorpusRequest, + request: protos.CreateCorpusRequest, **kwargs, - ) -> glm.Corpus: + ) -> protos.Corpus: self.observed_requests.append(request) - return glm.Corpus( + return protos.Corpus( name="corpora/demo_corpus", display_name="demo-corpus", create_time="2000-01-01T01:01:01.123456Z", @@ -55,11 +55,11 @@ def create_corpus( @add_client_method def get_corpus( - request: glm.GetCorpusRequest, + request: protos.GetCorpusRequest, **kwargs, - ) -> glm.Corpus: + ) -> protos.Corpus: self.observed_requests.append(request) - return glm.Corpus( + return protos.Corpus( name="corpora/demo_corpus", display_name="demo-corpus", create_time="2000-01-01T01:01:01.123456Z", @@ -68,11 +68,11 @@ def get_corpus( @add_client_method def update_corpus( - request: glm.UpdateCorpusRequest, + request: protos.UpdateCorpusRequest, **kwargs, - ) -> glm.Corpus: + ) -> protos.Corpus: self.observed_requests.append(request) - return glm.Corpus( + return protos.Corpus( name="corpora/demo-corpus", display_name="demo-corpus-1", create_time="2000-01-01T01:01:01.123456Z", @@ -81,18 +81,18 @@ def update_corpus( @add_client_method def list_corpora( - request: glm.ListCorporaRequest, + request: protos.ListCorporaRequest, **kwargs, - ) -> glm.ListCorporaResponse: + ) -> protos.ListCorporaResponse: self.observed_requests.append(request) return [ - glm.Corpus( + protos.Corpus( name="corpora/demo_corpus-1", display_name="demo-corpus-1", create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", ), - glm.Corpus( + protos.Corpus( name="corpora/demo-corpus-2", display_name="demo-corpus-2", create_time="2000-01-01T01:01:01.123456Z", @@ -102,15 +102,15 @@ def list_corpora( @add_client_method def query_corpus( - request: 
glm.QueryCorpusRequest, + request: protos.QueryCorpusRequest, **kwargs, - ) -> glm.QueryCorpusResponse: + ) -> protos.QueryCorpusResponse: self.observed_requests.append(request) - return glm.QueryCorpusResponse( + return protos.QueryCorpusResponse( relevant_chunks=[ - glm.RelevantChunk( + protos.RelevantChunk( chunk_relevance_score=0.08, - chunk=glm.Chunk( + chunk=protos.Chunk( name="corpora/demo-corpus/documents/demo-doc/chunks/demo-chunk", data={"string_value": "This is a demo chunk."}, custom_metadata=[], @@ -124,18 +124,18 @@ def query_corpus( @add_client_method def delete_corpus( - request: glm.DeleteCorpusRequest, + request: protos.DeleteCorpusRequest, **kwargs, ) -> None: self.observed_requests.append(request) @add_client_method def create_document( - request: glm.CreateDocumentRequest, + request: protos.CreateDocumentRequest, **kwargs, ) -> retriever_service.Document: self.observed_requests.append(request) - return glm.Document( + return protos.Document( name="corpora/demo-corpus/documents/demo-doc", display_name="demo-doc", create_time="2000-01-01T01:01:01.123456Z", @@ -144,11 +144,11 @@ def create_document( @add_client_method def get_document( - request: glm.GetDocumentRequest, + request: protos.GetDocumentRequest, **kwargs, ) -> retriever_service.Document: self.observed_requests.append(request) - return glm.Document( + return protos.Document( name="corpora/demo-corpus/documents/demo_doc", display_name="demo-doc", create_time="2000-01-01T01:01:01.123456Z", @@ -157,11 +157,11 @@ def get_document( @add_client_method def update_document( - request: glm.UpdateDocumentRequest, + request: protos.UpdateDocumentRequest, **kwargs, - ) -> glm.Document: + ) -> protos.Document: self.observed_requests.append(request) - return glm.Document( + return protos.Document( name="corpora/demo-corpus/documents/demo_doc", display_name="demo-doc-1", create_time="2000-01-01T01:01:01.123456Z", @@ -170,18 +170,18 @@ def update_document( @add_client_method def list_documents( - request: glm.ListDocumentsRequest, + request: protos.ListDocumentsRequest, **kwargs, - ) -> glm.ListDocumentsResponse: + ) -> protos.ListDocumentsResponse: self.observed_requests.append(request) return [ - glm.Document( + protos.Document( name="corpora/demo-corpus/documents/demo_doc_1", display_name="demo-doc-1", create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", ), - glm.Document( + protos.Document( name="corpora/demo-corpus/documents/demo_doc_2", display_name="demo-doc-2", create_time="2000-01-01T01:01:01.123456Z", @@ -191,22 +191,22 @@ def list_documents( @add_client_method def delete_document( - request: glm.DeleteDocumentRequest, + request: protos.DeleteDocumentRequest, **kwargs, ) -> None: self.observed_requests.append(request) @add_client_method def query_document( - request: glm.QueryDocumentRequest, + request: protos.QueryDocumentRequest, **kwargs, - ) -> glm.QueryDocumentResponse: + ) -> protos.QueryDocumentResponse: self.observed_requests.append(request) - return glm.QueryDocumentResponse( + return protos.QueryDocumentResponse( relevant_chunks=[ - glm.RelevantChunk( + protos.RelevantChunk( chunk_relevance_score=0.08, - chunk=glm.Chunk( + chunk=protos.Chunk( name="demo-chunk", data={"string_value": "This is a demo chunk."}, custom_metadata=[], @@ -220,11 +220,11 @@ def query_document( @add_client_method def create_chunk( - request: glm.CreateChunkRequest, + request: protos.CreateChunkRequest, **kwargs, ) -> retriever_service.Chunk: self.observed_requests.append(request) - return glm.Chunk( 
+ return protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is a demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -233,19 +233,19 @@ def create_chunk( @add_client_method def batch_create_chunks( - request: glm.BatchCreateChunksRequest, + request: protos.BatchCreateChunksRequest, **kwargs, - ) -> glm.BatchCreateChunksResponse: + ) -> protos.BatchCreateChunksResponse: self.observed_requests.append(request) - return glm.BatchCreateChunksResponse( + return protos.BatchCreateChunksResponse( chunks=[ - glm.Chunk( + protos.Chunk( name="corpora/demo-corpus/documents/demo-doc/chunks/dc", data={"string_value": "This is a demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", ), - glm.Chunk( + protos.Chunk( name="corpora/demo-corpus/documents/demo-doc/chunks/dc1", data={"string_value": "This is another demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -256,11 +256,11 @@ def batch_create_chunks( @add_client_method def get_chunk( - request: glm.GetChunkRequest, + request: protos.GetChunkRequest, **kwargs, ) -> retriever_service.Chunk: self.observed_requests.append(request) - return glm.Chunk( + return protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is a demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -269,18 +269,18 @@ def get_chunk( @add_client_method def list_chunks( - request: glm.ListChunksRequest, + request: protos.ListChunksRequest, **kwargs, - ) -> glm.ListChunksResponse: + ) -> protos.ListChunksResponse: self.observed_requests.append(request) return [ - glm.Chunk( + protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is a demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", ), - glm.Chunk( + protos.Chunk( name="corpora/demo-corpus/documents/demo-doc/chunks/demo-chunk-1", data={"string_value": "This is another demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -290,17 +290,17 @@ def list_chunks( @add_client_method def update_chunk( - request: glm.UpdateChunkRequest, + request: protos.UpdateChunkRequest, **kwargs, - ) -> glm.Chunk: + ) -> protos.Chunk: self.observed_requests.append(request) - return glm.Chunk( + return protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is an updated demo chunk."}, custom_metadata=[ - glm.CustomMetadata( + protos.CustomMetadata( key="tags", - string_list_value=glm.StringList( + string_list_value=protos.StringList( values=["Google For Developers", "Project IDX", "Blog", "Announcement"] ), ) @@ -311,19 +311,19 @@ def update_chunk( @add_client_method def batch_update_chunks( - request: glm.BatchUpdateChunksRequest, + request: protos.BatchUpdateChunksRequest, **kwargs, - ) -> glm.BatchUpdateChunksResponse: + ) -> protos.BatchUpdateChunksResponse: self.observed_requests.append(request) - return glm.BatchUpdateChunksResponse( + return protos.BatchUpdateChunksResponse( chunks=[ - glm.Chunk( + protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is an updated chunk."}, create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", ), - glm.Chunk( + protos.Chunk( name="corpora/demo-corpus/documents/demo-doc/chunks/demo-chunk-1", data={"string_value": "This is another updated chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -334,14 
+334,14 @@ def batch_update_chunks( @add_client_method def delete_chunk( - request: glm.DeleteChunkRequest, + request: protos.DeleteChunkRequest, **kwargs, ) -> None: self.observed_requests.append(request) @add_client_method def batch_delete_chunks( - request: glm.BatchDeleteChunksRequest, + request: protos.BatchDeleteChunksRequest, **kwargs, ) -> None: self.observed_requests.append(request) @@ -366,7 +366,7 @@ def test_get_corpus(self, name="demo-corpus"): def test_update_corpus(self): demo_corpus = retriever.create_corpus(name="demo-corpus") update_request = demo_corpus.update(updates={"display_name": "demo-corpus_1"}) - self.assertIsInstance(self.observed_requests[-1], glm.UpdateCorpusRequest) + self.assertIsInstance(self.observed_requests[-1], protos.UpdateCorpusRequest) self.assertEqual("demo-corpus_1", demo_corpus.display_name) def test_list_corpora(self): @@ -402,7 +402,7 @@ def test_delete_corpus(self): demo_corpus = retriever.create_corpus(name="demo-corpus") demo_document = demo_corpus.create_document(name="demo-doc") delete_request = retriever.delete_corpus(name="corpora/demo_corpus", force=True) - self.assertIsInstance(self.observed_requests[-1], glm.DeleteCorpusRequest) + self.assertIsInstance(self.observed_requests[-1], protos.DeleteCorpusRequest) def test_create_document(self, display_name="demo-doc"): demo_corpus = retriever.create_corpus(name="demo-corpus") @@ -433,7 +433,7 @@ def test_delete_document(self): demo_document = demo_corpus.create_document(name="demo-doc") demo_doc2 = demo_corpus.create_document(name="demo-doc-2") delete_request = demo_corpus.delete_document(name="corpora/demo-corpus/documents/demo_doc") - self.assertIsInstance(self.observed_requests[-1], glm.DeleteDocumentRequest) + self.assertIsInstance(self.observed_requests[-1], protos.DeleteDocumentRequest) def test_list_documents(self): demo_corpus = retriever.create_corpus(name="demo-corpus") @@ -521,7 +521,7 @@ def test_batch_create_chunks(self, chunks): demo_corpus = retriever.create_corpus(name="demo-corpus") demo_document = demo_corpus.create_document(name="demo-doc") chunks = demo_document.batch_create_chunks(chunks=chunks) - self.assertIsInstance(self.observed_requests[-1], glm.BatchCreateChunksRequest) + self.assertIsInstance(self.observed_requests[-1], protos.BatchCreateChunksRequest) self.assertEqual("This is a demo chunk.", chunks[0].data.string_value) self.assertEqual("This is another demo chunk.", chunks[1].data.string_value) @@ -548,7 +548,7 @@ def test_list_chunks(self): ) list_req = list(demo_document.list_chunks()) - self.assertIsInstance(self.observed_requests[-1], glm.ListChunksRequest) + self.assertIsInstance(self.observed_requests[-1], protos.ListChunksRequest) self.assertLen(list_req, 2) def test_update_chunk(self): @@ -615,7 +615,7 @@ def test_batch_update_chunks_data_structures(self, updates): data="This is another demo chunk.", ) update_request = demo_document.batch_update_chunks(chunks=updates) - self.assertIsInstance(self.observed_requests[-1], glm.BatchUpdateChunksRequest) + self.assertIsInstance(self.observed_requests[-1], protos.BatchUpdateChunksRequest) self.assertEqual( "This is an updated chunk.", update_request["chunks"][0]["data"]["string_value"] ) @@ -631,7 +631,7 @@ def test_delete_chunk(self): data="This is a demo chunk.", ) delete_request = demo_document.delete_chunk(name="demo-chunk") - self.assertIsInstance(self.observed_requests[-1], glm.DeleteChunkRequest) + self.assertIsInstance(self.observed_requests[-1], protos.DeleteChunkRequest) def 
test_batch_delete_chunks(self): demo_corpus = retriever.create_corpus(name="demo-corpus") @@ -645,7 +645,7 @@ def test_batch_delete_chunks(self): data="This is another demo chunk.", ) delete_request = demo_document.batch_delete_chunks(chunks=[x.name, y.name]) - self.assertIsInstance(self.observed_requests[-1], glm.BatchDeleteChunksRequest) + self.assertIsInstance(self.observed_requests[-1], protos.BatchDeleteChunksRequest) @parameterized.parameters( {"method": "create_corpus"}, diff --git a/tests/test_retriever_async.py b/tests/test_retriever_async.py index b764c23b2..bb0c862d1 100644 --- a/tests/test_retriever_async.py +++ b/tests/test_retriever_async.py @@ -19,7 +19,7 @@ import unittest import unittest.mock as mock -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import retriever from google.generativeai import client as client_lib @@ -44,11 +44,11 @@ def add_client_method(f): @add_client_method async def create_corpus( - request: glm.CreateCorpusRequest, + request: protos.CreateCorpusRequest, **kwargs, - ) -> glm.Corpus: + ) -> protos.Corpus: self.observed_requests.append(request) - return glm.Corpus( + return protos.Corpus( name="corpora/demo-corpus", display_name="demo-corpus", create_time="2000-01-01T01:01:01.123456Z", @@ -57,11 +57,11 @@ async def create_corpus( @add_client_method async def get_corpus( - request: glm.GetCorpusRequest, + request: protos.GetCorpusRequest, **kwargs, - ) -> glm.Corpus: + ) -> protos.Corpus: self.observed_requests.append(request) - return glm.Corpus( + return protos.Corpus( name="corpora/demo-corpus", display_name="demo-corpus", create_time="2000-01-01T01:01:01.123456Z", @@ -70,11 +70,11 @@ async def get_corpus( @add_client_method async def update_corpus( - request: glm.UpdateCorpusRequest, + request: protos.UpdateCorpusRequest, **kwargs, - ) -> glm.Corpus: + ) -> protos.Corpus: self.observed_requests.append(request) - return glm.Corpus( + return protos.Corpus( name="corpora/demo-corpus", display_name="demo-corpus-1", create_time="2000-01-01T01:01:01.123456Z", @@ -83,19 +83,19 @@ async def update_corpus( @add_client_method async def list_corpora( - request: glm.ListCorporaRequest, + request: protos.ListCorporaRequest, **kwargs, - ) -> glm.ListCorporaResponse: + ) -> protos.ListCorporaResponse: self.observed_requests.append(request) async def results(): - yield glm.Corpus( + yield protos.Corpus( name="corpora/demo-corpus-1", display_name="demo-corpus-1", create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", ) - yield glm.Corpus( + yield protos.Corpus( name="corpora/demo-corpus_2", display_name="demo-corpus-2", create_time="2000-01-01T01:01:01.123456Z", @@ -106,15 +106,15 @@ async def results(): @add_client_method async def query_corpus( - request: glm.QueryCorpusRequest, + request: protos.QueryCorpusRequest, **kwargs, - ) -> glm.QueryCorpusResponse: + ) -> protos.QueryCorpusResponse: self.observed_requests.append(request) - return glm.QueryCorpusResponse( + return protos.QueryCorpusResponse( relevant_chunks=[ - glm.RelevantChunk( + protos.RelevantChunk( chunk_relevance_score=0.08, - chunk=glm.Chunk( + chunk=protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is a demo chunk."}, custom_metadata=[], @@ -128,18 +128,18 @@ async def query_corpus( @add_client_method async def delete_corpus( - request: glm.DeleteCorpusRequest, + request: protos.DeleteCorpusRequest, **kwargs, ) -> None: 
self.observed_requests.append(request) @add_client_method async def create_document( - request: glm.CreateDocumentRequest, + request: protos.CreateDocumentRequest, **kwargs, ) -> retriever_service.Document: self.observed_requests.append(request) - return glm.Document( + return protos.Document( name="corpora/demo-corpus/documents/demo-doc", display_name="demo-doc", create_time="2000-01-01T01:01:01.123456Z", @@ -148,11 +148,11 @@ async def create_document( @add_client_method async def get_document( - request: glm.GetDocumentRequest, + request: protos.GetDocumentRequest, **kwargs, ) -> retriever_service.Document: self.observed_requests.append(request) - return glm.Document( + return protos.Document( name="corpora/demo-corpus/documents/demo-doc", display_name="demo-doc", create_time="2000-01-01T01:01:01.123456Z", @@ -161,11 +161,11 @@ async def get_document( @add_client_method async def update_document( - request: glm.UpdateDocumentRequest, + request: protos.UpdateDocumentRequest, **kwargs, - ) -> glm.Document: + ) -> protos.Document: self.observed_requests.append(request) - return glm.Document( + return protos.Document( name="corpora/demo-corpus/documents/demo-doc", display_name="demo-doc-1", create_time="2000-01-01T01:01:01.123456Z", @@ -174,19 +174,19 @@ async def update_document( @add_client_method async def list_documents( - request: glm.ListDocumentsRequest, + request: protos.ListDocumentsRequest, **kwargs, - ) -> glm.ListDocumentsResponse: + ) -> protos.ListDocumentsResponse: self.observed_requests.append(request) async def results(): - yield glm.Document( + yield protos.Document( name="corpora/demo-corpus/documents/dem-doc_1", display_name="demo-doc-1", create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", ) - yield glm.Document( + yield protos.Document( name="corpora/demo-corpus/documents/dem-doc_2", display_name="demo-doc_2", create_time="2000-01-01T01:01:01.123456Z", @@ -197,22 +197,22 @@ async def results(): @add_client_method async def delete_document( - request: glm.DeleteDocumentRequest, + request: protos.DeleteDocumentRequest, **kwargs, ) -> None: self.observed_requests.append(request) @add_client_method async def query_document( - request: glm.QueryDocumentRequest, + request: protos.QueryDocumentRequest, **kwargs, - ) -> glm.QueryDocumentResponse: + ) -> protos.QueryDocumentResponse: self.observed_requests.append(request) - return glm.QueryDocumentResponse( + return protos.QueryDocumentResponse( relevant_chunks=[ - glm.RelevantChunk( + protos.RelevantChunk( chunk_relevance_score=0.08, - chunk=glm.Chunk( + chunk=protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is a demo chunk."}, custom_metadata=[], @@ -226,11 +226,11 @@ async def query_document( @add_client_method async def create_chunk( - request: glm.CreateChunkRequest, + request: protos.CreateChunkRequest, **kwargs, ) -> retriever_service.Chunk: self.observed_requests.append(request) - return glm.Chunk( + return protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is a demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -239,19 +239,19 @@ async def create_chunk( @add_client_method async def batch_create_chunks( - request: glm.BatchCreateChunksRequest, + request: protos.BatchCreateChunksRequest, **kwargs, - ) -> glm.BatchCreateChunksResponse: + ) -> protos.BatchCreateChunksResponse: self.observed_requests.append(request) - return glm.BatchCreateChunksResponse( + return 
protos.BatchCreateChunksResponse( chunks=[ - glm.Chunk( + protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/dc", data={"string_value": "This is a demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", ), - glm.Chunk( + protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/dc1", data={"string_value": "This is another demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -262,11 +262,11 @@ async def batch_create_chunks( @add_client_method async def get_chunk( - request: glm.GetChunkRequest, + request: protos.GetChunkRequest, **kwargs, ) -> retriever_service.Chunk: self.observed_requests.append(request) - return glm.Chunk( + return protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is a demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -275,19 +275,19 @@ async def get_chunk( @add_client_method async def list_chunks( - request: glm.ListChunksRequest, + request: protos.ListChunksRequest, **kwargs, - ) -> glm.ListChunksResponse: + ) -> protos.ListChunksResponse: self.observed_requests.append(request) async def results(): - yield glm.Chunk( + yield protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is a demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", ) - yield glm.Chunk( + yield protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk-1", data={"string_value": "This is another demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -298,11 +298,11 @@ async def results(): @add_client_method async def update_chunk( - request: glm.UpdateChunkRequest, + request: protos.UpdateChunkRequest, **kwargs, - ) -> glm.Chunk: + ) -> protos.Chunk: self.observed_requests.append(request) - return glm.Chunk( + return protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is an updated demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -311,19 +311,19 @@ async def update_chunk( @add_client_method async def batch_update_chunks( - request: glm.BatchUpdateChunksRequest, + request: protos.BatchUpdateChunksRequest, **kwargs, - ) -> glm.BatchUpdateChunksResponse: + ) -> protos.BatchUpdateChunksResponse: self.observed_requests.append(request) - return glm.BatchUpdateChunksResponse( + return protos.BatchUpdateChunksResponse( chunks=[ - glm.Chunk( + protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is an updated chunk."}, create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", ), - glm.Chunk( + protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk-1", data={"string_value": "This is another updated chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -334,14 +334,14 @@ async def batch_update_chunks( @add_client_method async def delete_chunk( - request: glm.DeleteChunkRequest, + request: protos.DeleteChunkRequest, **kwargs, ) -> None: self.observed_requests.append(request) @add_client_method async def batch_delete_chunks( - request: glm.BatchDeleteChunksRequest, + request: protos.BatchDeleteChunksRequest, **kwargs, ) -> None: self.observed_requests.append(request) @@ -398,7 +398,7 @@ async def test_delete_corpus(self): demo_corpus = await retriever.create_corpus_async(name="demo-corpus") demo_document = await demo_corpus.create_document_async(name="demo-doc") delete_request = await 
retriever.delete_corpus_async(name="corpora/demo-corpus", force=True) - self.assertIsInstance(self.observed_requests[-1], glm.DeleteCorpusRequest) + self.assertIsInstance(self.observed_requests[-1], protos.DeleteCorpusRequest) async def test_create_document(self, display_name="demo-doc"): demo_corpus = await retriever.create_corpus_async(name="demo-corpus") @@ -425,7 +425,7 @@ async def test_delete_document(self): delete_request = await demo_corpus.delete_document_async( name="corpora/demo-corpus/documents/demo-doc" ) - self.assertIsInstance(self.observed_requests[-1], glm.DeleteDocumentRequest) + self.assertIsInstance(self.observed_requests[-1], protos.DeleteDocumentRequest) async def test_list_documents(self): demo_corpus = await retriever.create_corpus_async(name="demo-corpus") @@ -513,7 +513,7 @@ async def test_batch_create_chunks(self, chunks): demo_corpus = await retriever.create_corpus_async(name="demo-corpus") demo_document = await demo_corpus.create_document_async(name="demo-doc") chunks = await demo_document.batch_create_chunks_async(chunks=chunks) - self.assertIsInstance(self.observed_requests[-1], glm.BatchCreateChunksRequest) + self.assertIsInstance(self.observed_requests[-1], protos.BatchCreateChunksRequest) self.assertEqual("This is a demo chunk.", chunks[0].data.string_value) self.assertEqual("This is another demo chunk.", chunks[1].data.string_value) @@ -541,7 +541,7 @@ async def test_list_chunks(self): chunks = [] async for chunk in demo_document.list_chunks_async(): chunks.append(chunk) - self.assertIsInstance(self.observed_requests[-1], glm.ListChunksRequest) + self.assertIsInstance(self.observed_requests[-1], protos.ListChunksRequest) self.assertLen(chunks, 2) async def test_update_chunk(self): @@ -597,7 +597,7 @@ async def test_batch_update_chunks_data_structures(self, updates): data="This is another demo chunk.", ) update_request = await demo_document.batch_update_chunks_async(chunks=updates) - self.assertIsInstance(self.observed_requests[-1], glm.BatchUpdateChunksRequest) + self.assertIsInstance(self.observed_requests[-1], protos.BatchUpdateChunksRequest) self.assertEqual( "This is an updated chunk.", update_request["chunks"][0]["data"]["string_value"] ) @@ -615,7 +615,7 @@ async def test_delete_chunk(self): delete_request = await demo_document.delete_chunk_async( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk" ) - self.assertIsInstance(self.observed_requests[-1], glm.DeleteChunkRequest) + self.assertIsInstance(self.observed_requests[-1], protos.DeleteChunkRequest) async def test_batch_delete_chunks(self): demo_corpus = await retriever.create_corpus_async(name="demo-corpus") @@ -629,7 +629,7 @@ async def test_batch_delete_chunks(self): data="This is another demo chunk.", ) delete_request = await demo_document.batch_delete_chunks_async(chunks=[x.name, y.name]) - self.assertIsInstance(self.observed_requests[-1], glm.BatchDeleteChunksRequest) + self.assertIsInstance(self.observed_requests[-1], protos.BatchDeleteChunksRequest) async def test_get_corpus_called_with_request_options(self): self.client.get_corpus = unittest.mock.AsyncMock() diff --git a/tests/test_text.py b/tests/test_text.py index 5dcda93b9..50c1b5539 100644 --- a/tests/test_text.py +++ b/tests/test_text.py @@ -18,7 +18,7 @@ import unittest import unittest.mock as mock -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import text as text_service from google.generativeai import client @@ -46,42 +46,42 @@ def add_client_method(f): 
@add_client_method def generate_text( - request: glm.GenerateTextRequest, + request: protos.GenerateTextRequest, **kwargs, - ) -> glm.GenerateTextResponse: + ) -> protos.GenerateTextResponse: self.observed_requests.append(request) return self.responses["generate_text"] @add_client_method def embed_text( - request: glm.EmbedTextRequest, + request: protos.EmbedTextRequest, **kwargs, - ) -> glm.EmbedTextResponse: + ) -> protos.EmbedTextResponse: self.observed_requests.append(request) return self.responses["embed_text"] @add_client_method def batch_embed_text( - request: glm.EmbedTextRequest, + request: protos.EmbedTextRequest, **kwargs, - ) -> glm.EmbedTextResponse: + ) -> protos.EmbedTextResponse: self.observed_requests.append(request) - return glm.BatchEmbedTextResponse( - embeddings=[glm.Embedding(value=[1, 2, 3])] * len(request.texts) + return protos.BatchEmbedTextResponse( + embeddings=[protos.Embedding(value=[1, 2, 3])] * len(request.texts) ) @add_client_method def count_text_tokens( - request: glm.CountTextTokensRequest, + request: protos.CountTextTokensRequest, **kwargs, - ) -> glm.CountTextTokensResponse: + ) -> protos.CountTextTokensResponse: self.observed_requests.append(request) return self.responses["count_text_tokens"] @add_client_method - def get_tuned_model(name) -> glm.TunedModel: - request = glm.GetTunedModelRequest(name=name) + def get_tuned_model(name) -> protos.TunedModel: + request = protos.GetTunedModelRequest(name=name) self.observed_requests.append(request) response = copy.copy(self.responses["get_tuned_model"]) return response @@ -93,7 +93,7 @@ def get_tuned_model(name) -> glm.TunedModel: ) def test_make_prompt(self, prompt): x = text_service._make_text_prompt(prompt) - self.assertIsInstance(x, glm.TextPrompt) + self.assertIsInstance(x, protos.TextPrompt) self.assertEqual("Hello how are", x.text) @parameterized.named_parameters( @@ -104,7 +104,7 @@ def test_make_prompt(self, prompt): def test_make_generate_text_request(self, prompt): x = text_service._make_generate_text_request(model="models/chat-bison-001", prompt=prompt) self.assertEqual("models/chat-bison-001", x.model) - self.assertIsInstance(x, glm.GenerateTextRequest) + self.assertIsInstance(x, protos.GenerateTextRequest) @parameterized.named_parameters( [ @@ -116,14 +116,14 @@ def test_make_generate_text_request(self, prompt): ] ) def test_generate_embeddings(self, model, text): - self.responses["embed_text"] = glm.EmbedTextResponse( - embedding=glm.Embedding(value=[1, 2, 3]) + self.responses["embed_text"] = protos.EmbedTextResponse( + embedding=protos.Embedding(value=[1, 2, 3]) ) emb = text_service.generate_embeddings(model=model, text=text) self.assertIsInstance(emb, dict) - self.assertEqual(self.observed_requests[-1], glm.EmbedTextRequest(model=model, text=text)) + self.assertEqual(self.observed_requests[-1], protos.EmbedTextRequest(model=model, text=text)) self.assertIsInstance(emb["embedding"][0], float) @parameterized.named_parameters( @@ -191,11 +191,11 @@ def test_generate_embeddings_batch(self, model, text): ] ) def test_generate_response(self, *, prompt, **kwargs): - self.responses["generate_text"] = glm.GenerateTextResponse( + self.responses["generate_text"] = protos.GenerateTextResponse( candidates=[ - glm.TextCompletion(output=" road?"), - glm.TextCompletion(output=" bridge?"), - glm.TextCompletion(output=" river?"), + protos.TextCompletion(output=" road?"), + protos.TextCompletion(output=" bridge?"), + protos.TextCompletion(output=" river?"), ] ) @@ -203,8 +203,8 @@ def 
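The `test_text.py` hunks above all rely on the same stub pattern: each client method is replaced by a function that appends the incoming request proto to `observed_requests` and returns a canned response from `self.responses`. A self-contained sketch of that pattern follows; `FakeClient` and `canned` are made-up names used only for illustration, not identifiers from the patch.

```python
# Illustrative sketch of the record-and-replay stub pattern used by the tests above.
from dataclasses import dataclass, field
from typing import Any


@dataclass
class FakeClient:
    canned: dict[str, Any]  # canned responses keyed by method name
    observed_requests: list = field(default_factory=list)

    def generate_text(self, request, **kwargs):
        # Record the request so a test can assert on exactly what was sent,
        # then return the pre-seeded response, as the stubs above do.
        self.observed_requests.append(request)
        return self.canned["generate_text"]


client = FakeClient(canned={"generate_text": "fake-response"})
client.generate_text({"prompt": "Hello"})
assert client.observed_requests[-1] == {"prompt": "Hello"}
```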
test_generate_response(self, *, prompt, **kwargs): self.assertEqual( self.observed_requests[-1], - glm.GenerateTextRequest( - model="models/text-bison-001", prompt=glm.TextPrompt(text=prompt), **kwargs + protos.GenerateTextRequest( + model="models/text-bison-001", prompt=protos.TextPrompt(text=prompt), **kwargs ), ) @@ -220,20 +220,20 @@ def test_generate_response(self, *, prompt, **kwargs): ) def test_stop_string(self): - self.responses["generate_text"] = glm.GenerateTextResponse( + self.responses["generate_text"] = protos.GenerateTextResponse( candidates=[ - glm.TextCompletion(output="Hello world?"), - glm.TextCompletion(output="Hell!"), - glm.TextCompletion(output="I'm going to stop"), + protos.TextCompletion(output="Hello world?"), + protos.TextCompletion(output="Hell!"), + protos.TextCompletion(output="I'm going to stop"), ] ) complete = text_service.generate_text(prompt="Hello", stop_sequences="stop") self.assertEqual( self.observed_requests[-1], - glm.GenerateTextRequest( + protos.GenerateTextRequest( model="models/text-bison-001", - prompt=glm.TextPrompt(text="Hello"), + prompt=protos.TextPrompt(text="Hello"), stop_sequences=["stop"], ), ) @@ -282,9 +282,9 @@ def test_stop_string(self): ] ) def test_safety_settings(self, safety_settings): - self.responses["generate_text"] = glm.GenerateTextResponse( + self.responses["generate_text"] = protos.GenerateTextResponse( candidates=[ - glm.TextCompletion(output="No"), + protos.TextCompletion(output="No"), ] ) # This test really just checks that the safety_settings get converted to a proto. @@ -298,7 +298,7 @@ def test_safety_settings(self, safety_settings): ) def test_filters(self): - self.responses["generate_text"] = glm.GenerateTextResponse( + self.responses["generate_text"] = protos.GenerateTextResponse( candidates=[{"output": "hello"}], filters=[ { @@ -313,7 +313,7 @@ def test_filters(self): self.assertEqual(response.filters[0]["reason"], palm_safety_types.BlockedReason.SAFETY) def test_safety_feedback(self): - self.responses["generate_text"] = glm.GenerateTextResponse( + self.responses["generate_text"] = protos.GenerateTextResponse( candidates=[{"output": "hello"}], safety_feedback=[ { @@ -341,7 +341,7 @@ def test_safety_feedback(self): self.assertIsInstance( response.safety_feedback[0]["setting"]["category"], - glm.HarmCategory, + protos.HarmCategory, ) self.assertEqual( response.safety_feedback[0]["setting"]["category"], @@ -349,7 +349,7 @@ def test_safety_feedback(self): ) def test_candidate_safety_feedback(self): - self.responses["generate_text"] = glm.GenerateTextResponse( + self.responses["generate_text"] = protos.GenerateTextResponse( candidates=[ { "output": "hello", @@ -370,7 +370,7 @@ def test_candidate_safety_feedback(self): result = text_service.generate_text(prompt="Write a story from the ER.") self.assertIsInstance( result.candidates[0]["safety_ratings"][0]["category"], - glm.HarmCategory, + protos.HarmCategory, ) self.assertEqual( result.candidates[0]["safety_ratings"][0]["category"], @@ -387,7 +387,7 @@ def test_candidate_safety_feedback(self): ) def test_candidate_citations(self): - self.responses["generate_text"] = glm.GenerateTextResponse( + self.responses["generate_text"] = protos.GenerateTextResponse( candidates=[ { "output": "Hello Google!", @@ -434,21 +434,21 @@ def test_candidate_citations(self): ), ), dict( - testcase_name="glm_model", - model=glm.Model( + testcase_name="protos.model", + model=protos.Model( name="models/text-bison-001", ), ), dict( - testcase_name="glm_tuned_model", - model=glm.TunedModel( + 
testcase_name="protos.tuned_model", + model=protos.TunedModel( name="tunedModels/bipedal-pangolin-001", base_model="models/text-bison-001", ), ), dict( - testcase_name="glm_tuned_model_nested", - model=glm.TunedModel( + testcase_name="protos.tuned_model_nested", + model=protos.TunedModel( name="tunedModels/bipedal-pangolin-002", tuned_model_source={ "tuned_model": "tunedModels/bipedal-pangolin-002", @@ -459,10 +459,10 @@ def test_candidate_citations(self): ] ) def test_count_message_tokens(self, model): - self.responses["get_tuned_model"] = glm.TunedModel( + self.responses["get_tuned_model"] = protos.TunedModel( name="tunedModels/bipedal-pangolin-001", base_model="models/text-bison-001" ) - self.responses["count_text_tokens"] = glm.CountTextTokensResponse(token_count=7) + self.responses["count_text_tokens"] = protos.CountTextTokensResponse(token_count=7) response = text_service.count_text_tokens(model, "Tell me a story about a magic backpack.") self.assertEqual({"token_count": 7}, response) @@ -472,7 +472,7 @@ def test_count_message_tokens(self, model): self.assertLen(self.observed_requests, 2) self.assertEqual( self.observed_requests[0], - glm.GetTunedModelRequest(name="tunedModels/bipedal-pangolin-001"), + protos.GetTunedModelRequest(name="tunedModels/bipedal-pangolin-001"), ) def test_count_text_tokens_called_with_request_options(self): From fcae45ba49dff6d4f603e5c58082ea9a40d1f172 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Tue, 21 May 2024 06:35:45 -0700 Subject: [PATCH 02/16] Add genai.protos Change-Id: I9c8473d4ca1a0e92489f145a18ef1abd29af22b3 --- google/generativeai/protos.py | 78 +++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 google/generativeai/protos.py diff --git a/google/generativeai/protos.py b/google/generativeai/protos.py new file mode 100644 index 000000000..fd6a96541 --- /dev/null +++ b/google/generativeai/protos.py @@ -0,0 +1,78 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This module publishes the proto classes from `google.ai.generativelanguage`. + +`google.ai.generativelanguage` is a low-level auto-generated client library for the Gemini API. + +It is built using the same tooling as Google Cloud client libraries, and will be quite familiar if you've used +those before. + +While we encourage Python users to access the Geini API using the `google.generativeai` package (aka `palm`), +the lower level package is also available. + +Each method in the Gemini API is connected to one of the client classes. Pass your API-key to the class' `client_options` +when initializing a client: + +``` +from google.generativeai import protos + +client = protos.DiscussServiceClient( + client_options={'api_key':'YOUR_API_KEY'}) +``` + +To call the api, pass an appropriate request-proto-object. 
For the `DiscussServiceClient.generate_message` pass +a `generativelanguage.GenerateMessageRequest` instance: + +``` +request = protos.GenerateMessageRequest( + model='models/chat-bison-001', + prompt=protos.MessagePrompt( + messages=[protos.Message(content='Hello!')])) + +client.generate_message(request) +``` +``` +candidates { + author: "1" + content: "Hello! How can I help you today?" +} +... +``` + +For simplicity: + +* The API methods also accept key-word arguments. +* Anywhere you might pass a proto-object, the library will also accept simple python structures. + +So the following is equivalent to the previous example: + +``` +client.generate_message( + model='models/chat-bison-001', + prompt={'messages':[{'content':'Hello!'}]}) +``` +``` +candidates { + author: "1" + content: "Hello! How can I help you today?" +} +... +``` + +""" + +from google.ai.generativelanguage_v1beta.types import * +from google.ai.generativelanguage_v1beta.types import __all__ From f581c0d04bf4ff53e5b376a232ad1feeef91343c Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Tue, 21 May 2024 11:06:52 -0700 Subject: [PATCH 03/16] test_protos.py Change-Id: I576080fb80cf9dc9345d8bb2178eb4b9ac59ce97 --- tests/test_protos.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 tests/test_protos.py diff --git a/tests/test_protos.py b/tests/test_protos.py new file mode 100644 index 000000000..82ffa3831 --- /dev/null +++ b/tests/test_protos.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
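As a quick illustration of what the new `google.generativeai.protos` module added above re-exports: the proto message classes can be constructed directly and converted to plain dicts with the proto-plus `to_dict` helper, the same helper the library code in this series uses (`type(request).to_dict(request)`). This is a sketch assuming the package with this patch applied is installed; the exact dict contents may vary with proto-plus defaults.

```python
# Sketch: build a Content proto from the re-exported types and round-trip it to a dict.
from google.generativeai import protos

msg = protos.Content(role="user", parts=[protos.Part(text="Hello!")])

# proto-plus messages expose to_dict on the class, taking the instance:
as_dict = type(msg).to_dict(msg)
print(as_dict)  # -> a plain dict like {'role': 'user', 'parts': [{'text': 'Hello!'}], ...}
```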
+import pathlib + +from absl.testing import parameterized +import google.ai.generativelanguage as glm +from google.generativeai import responder + +ROOT = pathlib.Path(__file__).parent.parent + + +class UnitTests(parameterized.TestCase): + def test_check_glm_imports(self): + for fpath in ROOT.rglob("*.py"): + if fpath.name in ["client.py", "discuss.py", "test_protos.py", "test_client.py", "build_docs.py"]: + continue + + content = fpath.read_text() + self.assertNotRegex(content, 'import google.ai.generativelanguage|from google.ai import generativelanguage', + msg=f'generativelanguage found in {fpath}') From 9bb1cf02cbc43513200768af4aabd66cb7458614 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Tue, 21 May 2024 12:16:55 -0700 Subject: [PATCH 04/16] fix docs + format Change-Id: I5f9aa3f8e3ae780e5cec2078d3eb153157b195fe --- docs/build_docs.py | 51 +------------------- google/generativeai/__init__.py | 1 + google/generativeai/answer.py | 4 +- google/generativeai/discuss.py | 4 +- google/generativeai/responder.py | 3 +- google/generativeai/types/content_types.py | 3 +- google/generativeai/types/retriever_types.py | 16 ++++-- tests/test_answer.py | 50 +++++++++++++++---- tests/test_discuss.py | 8 ++- tests/test_embedding.py | 3 +- tests/test_embedding_async.py | 3 +- tests/test_generation.py | 31 +++++++++--- tests/test_generative_models_async.py | 4 +- tests/test_models.py | 4 +- tests/test_protos.py | 15 ++++-- tests/test_text.py | 4 +- 16 files changed, 119 insertions(+), 85 deletions(-) diff --git a/docs/build_docs.py b/docs/build_docs.py index 280738700..012cd3441 100644 --- a/docs/build_docs.py +++ b/docs/build_docs.py @@ -46,8 +46,8 @@ typing.TYPE_CHECKING = True from google import generativeai as genai - from tensorflow_docs.api_generator import generate_lib +from tensorflow_docs.api_generator import public_api import yaml @@ -75,33 +75,6 @@ ) -class MyFilter: - def __init__(self, base_dirs): - self.filter_base_dirs = public_api.FilterBaseDirs(base_dirs) - - def __call__(self, path, parent, children): - if any("generativelanguage" in part for part in path) or "generativeai" in path: - children = self.filter_base_dirs(path, parent, children) - children = public_api.explicit_package_contents_filter(path, parent, children) - - return children - - -class MyDocGenerator(generate_lib.DocGenerator): - def make_default_filters(self): - return [ - # filter the api. - public_api.FailIfNestedTooDeep(10), - public_api.filter_module_all, - public_api.add_proto_fields, - public_api.filter_private_symbols, - MyFilter(self._base_dir), # Replaces: public_api.FilterBaseDirs(self._base_dir), - public_api.FilterPrivateMap(self._private_map), - public_api.filter_doc_controls_skip, - public_api.ignore_typing, - ] - - def gen_api_docs(): """Generates api docs for the generative-ai package.""" for name in dir(google): @@ -127,32 +100,12 @@ def gen_api_docs(): ), search_hints=_SEARCH_HINTS.value, site_path=_SITE_PATH.value, - callbacks=[], + callbacks=[public_api.explicit_package_contents_filter], ) out_path = pathlib.Path(_OUTPUT_DIR.value) doc_generator.build(out_path) - # Fixup the toc file. 
- toc_path = out_path / "google/_toc.yaml" - toc = yaml.safe_load(toc_path.read_text()) - assert toc["toc"][0]["title"] == "google" - toc["toc"] = toc["toc"][1:] - toc["toc"][0]["title"] = "google.ai.generativelanguage" - toc["toc"][0]["section"] = toc["toc"][0]["section"][1]["section"] - toc["toc"][0], toc["toc"][1] = toc["toc"][1], toc["toc"][0] - toc_path.write_text(yaml.dump(toc)) - - # remove some dummy files and redirect them to `api/` - (out_path / "google.md").unlink() - (out_path / "google/ai.md").unlink() - redirects_path = out_path / "_redirects.yaml" - redirects = {"redirects": []} - redirects["redirects"].insert(0, {"from": "/api/python/google/ai", "to": "/api/"}) - redirects["redirects"].insert(0, {"from": "/api/python/google", "to": "/api/"}) - redirects["redirects"].insert(0, {"from": "/api/python", "to": "/api/"}) - redirects_path.write_text(yaml.dump(redirects)) - # clear `oneof` junk from proto pages for fpath in out_path.rglob("*.md"): old_content = fpath.read_text() diff --git a/google/generativeai/__init__.py b/google/generativeai/__init__.py index 53383a1b3..b060def8d 100644 --- a/google/generativeai/__init__.py +++ b/google/generativeai/__init__.py @@ -42,6 +42,7 @@ from google.generativeai import version +from google.generativeai import protos from google.generativeai import types from google.generativeai.types import GenerationConfig diff --git a/google/generativeai/answer.py b/google/generativeai/answer.py index eefe6e68d..25b45d95a 100644 --- a/google/generativeai/answer.py +++ b/google/generativeai/answer.py @@ -66,7 +66,9 @@ def to_answer_style(x: AnswerStyleOptions) -> AnswerStyle: GroundingPassageOptions = ( - Union[protos.GroundingPassage, tuple[str, content_types.ContentType], content_types.ContentType], + Union[ + protos.GroundingPassage, tuple[str, content_types.ContentType], content_types.ContentType + ], ) GroundingPassagesOptions = Union[ diff --git a/google/generativeai/discuss.py b/google/generativeai/discuss.py index 4cea08421..172aa76e4 100644 --- a/google/generativeai/discuss.py +++ b/google/generativeai/discuss.py @@ -447,7 +447,9 @@ async def chat_async( @string_utils.set_doc(discuss_types.ChatResponse.__doc__) @dataclasses.dataclass(**DATACLASS_KWARGS, init=False) class ChatResponse(discuss_types.ChatResponse): - _client: protos.DiscussServiceClient | None = dataclasses.field(default=lambda: None, repr=False) + _client: protos.DiscussServiceClient | None = dataclasses.field( + default=lambda: None, repr=False + ) def __init__(self, **kwargs): for key, value in kwargs.items(): diff --git a/google/generativeai/responder.py b/google/generativeai/responder.py index 612923b03..6d3a562f6 100644 --- a/google/generativeai/responder.py +++ b/google/generativeai/responder.py @@ -483,7 +483,8 @@ def to_function_calling_config(obj: FunctionCallingConfigType) -> protos.Functio obj["mode"] = to_function_calling_mode(mode) else: raise TypeError( - f"Could not convert input to `protos.FunctionCallingConfig`: \n'" f" type: {type(obj)}\n", + f"Could not convert input to `protos.FunctionCallingConfig`: \n'" + f" type: {type(obj)}\n", obj, ) diff --git a/google/generativeai/types/content_types.py b/google/generativeai/types/content_types.py index afa33e34e..00e0b7066 100644 --- a/google/generativeai/types/content_types.py +++ b/google/generativeai/types/content_types.py @@ -799,7 +799,8 @@ def to_function_calling_config(obj: FunctionCallingConfigType) -> protos.Functio obj["mode"] = to_function_calling_mode(mode) else: raise TypeError( - f"Could not convert input 
to `protos.FunctionCallingConfig`: \n'" f" type: {type(obj)}\n", + f"Could not convert input to `protos.FunctionCallingConfig`: \n'" + f" type: {type(obj)}\n", obj, ) diff --git a/google/generativeai/types/retriever_types.py b/google/generativeai/types/retriever_types.py index f7890fed6..3af5e4081 100644 --- a/google/generativeai/types/retriever_types.py +++ b/google/generativeai/types/retriever_types.py @@ -785,7 +785,9 @@ def create_chunk( chunk_name = name if isinstance(data, str): - chunk = protos.Chunk(name=chunk_name, data={"string_value": data}, custom_metadata=c_data) + chunk = protos.Chunk( + name=chunk_name, data={"string_value": data}, custom_metadata=c_data + ) else: chunk = protos.Chunk( name=chunk_name, @@ -827,7 +829,9 @@ async def create_chunk_async( chunk_name = name if isinstance(data, str): - chunk = protos.Chunk(name=chunk_name, data={"string_value": data}, custom_metadata=c_data) + chunk = protos.Chunk( + name=chunk_name, data={"string_value": data}, custom_metadata=c_data + ) else: chunk = protos.Chunk( name=chunk_name, @@ -1255,7 +1259,9 @@ def batch_update_chunks( for path, value in updates.items(): chunk_to_update._apply_update(path, value) _requests.append( - protos.UpdateChunkRequest(chunk=chunk_to_update.to_dict(), update_mask=field_mask) + protos.UpdateChunkRequest( + chunk=chunk_to_update.to_dict(), update_mask=field_mask + ) ) request = protos.BatchUpdateChunksRequest(parent=self.name, requests=_requests) response = client.batch_update_chunks(request, **request_options) @@ -1343,7 +1349,9 @@ async def batch_update_chunks_async( for path, value in updates.items(): chunk_to_update._apply_update(path, value) _requests.append( - protos.UpdateChunkRequest(chunk=chunk_to_update.to_dict(), update_mask=field_mask) + protos.UpdateChunkRequest( + chunk=chunk_to_update.to_dict(), update_mask=field_mask + ) ) request = protos.BatchUpdateChunksRequest(parent=self.name, requests=_requests) response = await client.batch_update_chunks(request, **request_options) diff --git a/tests/test_answer.py b/tests/test_answer.py index 25f824e86..2669b207c 100644 --- a/tests/test_answer.py +++ b/tests/test_answer.py @@ -70,8 +70,14 @@ def test_make_grounding_passages_mixed_types(self): self.assertEqual( protos.GroundingPassages( passages=[ - {"id": "0", "content": protos.Content(parts=[protos.Part(text="I am a chicken")])}, - {"id": "1", "content": protos.Content(parts=[protos.Part(text="I am a bird.")])}, + { + "id": "0", + "content": protos.Content(parts=[protos.Part(text="I am a chicken")]), + }, + { + "id": "1", + "content": protos.Content(parts=[protos.Part(text="I am a bird.")]), + }, {"id": "2", "content": protos.Content(parts=[protos.Part(text="I can fly!")])}, ] ), @@ -88,8 +94,14 @@ def test_make_grounding_passages_mixed_types(self): "id": "0", "content": protos.Content(parts=[protos.Part(text="I am a chicken")]), }, - {"id": "1", "content": protos.Content(parts=[protos.Part(text="I am a bird.")])}, - {"id": "2", "content": protos.Content(parts=[protos.Part(text="I can fly!")])}, + { + "id": "1", + "content": protos.Content(parts=[protos.Part(text="I am a bird.")]), + }, + { + "id": "2", + "content": protos.Content(parts=[protos.Part(text="I can fly!")]), + }, ] ), ), @@ -113,8 +125,14 @@ def test_make_grounding_passages(self, inline_passages): self.assertEqual( protos.GroundingPassages( passages=[ - {"id": "0", "content": protos.Content(parts=[protos.Part(text="I am a chicken")])}, - {"id": "1", "content": protos.Content(parts=[protos.Part(text="I am a bird.")])}, + { + "id": 
"0", + "content": protos.Content(parts=[protos.Part(text="I am a chicken")]), + }, + { + "id": "1", + "content": protos.Content(parts=[protos.Part(text="I am a bird.")]), + }, {"id": "2", "content": protos.Content(parts=[protos.Part(text="I can fly!")])}, ] ), @@ -151,8 +169,14 @@ def test_make_grounding_passages_different_id(self, inline_passages): self.assertEqual( protos.GroundingPassages( passages=[ - {"id": "4", "content": protos.Content(parts=[protos.Part(text="I am a chicken")])}, - {"id": "5", "content": protos.Content(parts=[protos.Part(text="I am a bird.")])}, + { + "id": "4", + "content": protos.Content(parts=[protos.Part(text="I am a chicken")]), + }, + { + "id": "5", + "content": protos.Content(parts=[protos.Part(text="I am a bird.")]), + }, {"id": "6", "content": protos.Content(parts=[protos.Part(text="I can fly!")])}, ] ), @@ -175,8 +199,14 @@ def test_make_grounding_passages_key_strings(self): "id": "first", "content": protos.Content(parts=[protos.Part(text="I am a chicken")]), }, - {"id": "second", "content": protos.Content(parts=[protos.Part(text="I am a bird.")])}, - {"id": "third", "content": protos.Content(parts=[protos.Part(text="I can fly!")])}, + { + "id": "second", + "content": protos.Content(parts=[protos.Part(text="I am a bird.")]), + }, + { + "id": "third", + "content": protos.Content(parts=[protos.Part(text="I can fly!")]), + }, ] ), x, diff --git a/tests/test_discuss.py b/tests/test_discuss.py index e7411bcb6..4e54cf754 100644 --- a/tests/test_discuss.py +++ b/tests/test_discuss.py @@ -288,7 +288,9 @@ def test_receive_and_reply_with_filters(self): self.mock_response = mock_response = protos.GenerateMessageResponse( candidates=[protos.Message(content="a", author="1")], filters=[ - protos.ContentFilter(reason=palm_safety_types.BlockedReason.SAFETY, message="unsafe"), + protos.ContentFilter( + reason=palm_safety_types.BlockedReason.SAFETY, message="unsafe" + ), protos.ContentFilter(reason=palm_safety_types.BlockedReason.OTHER), ], ) @@ -303,7 +305,9 @@ def test_receive_and_reply_with_filters(self): self.mock_response = protos.GenerateMessageResponse( candidates=[protos.Message(content="a", author="1")], filters=[ - protos.ContentFilter(reason=palm_safety_types.BlockedReason.BLOCKED_REASON_UNSPECIFIED) + protos.ContentFilter( + reason=palm_safety_types.BlockedReason.BLOCKED_REASON_UNSPECIFIED + ) ], ) diff --git a/tests/test_embedding.py b/tests/test_embedding.py index 921ad46a6..a208a4743 100644 --- a/tests/test_embedding.py +++ b/tests/test_embedding.py @@ -69,7 +69,8 @@ def test_embed_content(self): self.assertEqual( self.observed_requests[-1], protos.EmbedContentRequest( - model=DEFAULT_EMB_MODEL, content=protos.Content(parts=[protos.Part(text="What are you?")]) + model=DEFAULT_EMB_MODEL, + content=protos.Content(parts=[protos.Part(text="What are you?")]), ), ) self.assertIsInstance(emb["embedding"][0], float) diff --git a/tests/test_embedding_async.py b/tests/test_embedding_async.py index 6e8887bb9..367cf7ded 100644 --- a/tests/test_embedding_async.py +++ b/tests/test_embedding_async.py @@ -68,7 +68,8 @@ async def test_embed_content_async(self): self.assertEqual( self.observed_requests[-1], protos.EmbedContentRequest( - model=DEFAULT_EMB_MODEL, content=protos.Content(parts=[protos.Part(text="What are you?")]) + model=DEFAULT_EMB_MODEL, + content=protos.Content(parts=[protos.Part(text="What are you?")]), ), ) self.assertIsInstance(emb["embedding"][0], float) diff --git a/tests/test_generation.py b/tests/test_generation.py index 1b50badaf..dd8ed87f7 100644 --- 
a/tests/test_generation.py +++ b/tests/test_generation.py @@ -26,7 +26,9 @@ class UnitTests(parameterized.TestCase): [ "protos.GenerationConfig", protos.GenerationConfig( - temperature=0.1, stop_sequences=["end"], response_schema=protos.Schema(type="STRING") + temperature=0.1, + stop_sequences=["end"], + response_schema=protos.Schema(type="STRING"), ), ], [ @@ -126,7 +128,8 @@ def test_many_join_contents(self): import string contents = [ - protos.Content(role="assistant", parts=[protos.Part(text=a)]) for a in string.ascii_lowercase + protos.Content(role="assistant", parts=[protos.Part(text=a)]) + for a in string.ascii_lowercase ] result = generation_types._join_contents(contents) @@ -147,7 +150,9 @@ def test_join_candidates(self): ), citation_metadata=protos.CitationMetadata( citation_sources=[ - protos.CitationSource(start_index=55, end_index=85, uri="https://google.com"), + protos.CitationSource( + start_index=55, end_index=85, uri="https://google.com" + ), ] ), ), @@ -159,8 +164,12 @@ def test_join_candidates(self): ), citation_metadata=protos.CitationMetadata( citation_sources=[ - protos.CitationSource(start_index=55, end_index=92, uri="https://google.com"), - protos.CitationSource(start_index=3, end_index=21, uri="https://google.com"), + protos.CitationSource( + start_index=55, end_index=92, uri="https://google.com" + ), + protos.CitationSource( + start_index=3, end_index=21, uri="https://google.com" + ), ] ), ), @@ -168,12 +177,18 @@ def test_join_candidates(self): index=0, content=protos.Content( role="assistant", - parts=[protos.Part(inline_data=protos.Blob(mime_type="image/png", data=b"DATA!"))], + parts=[ + protos.Part(inline_data=protos.Blob(mime_type="image/png", data=b"DATA!")) + ], ), citation_metadata=protos.CitationMetadata( citation_sources=[ - protos.CitationSource(start_index=55, end_index=92, uri="https://google.com"), - protos.CitationSource(start_index=3, end_index=21, uri="https://google.com"), + protos.CitationSource( + start_index=55, end_index=92, uri="https://google.com" + ), + protos.CitationSource( + start_index=3, end_index=21, uri="https://google.com" + ), ] ), finish_reason="STOP", diff --git a/tests/test_generative_models_async.py b/tests/test_generative_models_async.py index b5babda1e..03055ffb3 100644 --- a/tests/test_generative_models_async.py +++ b/tests/test_generative_models_async.py @@ -31,7 +31,9 @@ def simple_response(text: str) -> protos.GenerateContentResponse: - return protos.GenerateContentResponse({"candidates": [{"content": {"parts": [{"text": text}]}}]}) + return protos.GenerateContentResponse( + {"candidates": [{"content": {"parts": [{"text": text}]}}]} + ) class AsyncTests(parameterized.TestCase, unittest.IsolatedAsyncioTestCase): diff --git a/tests/test_models.py b/tests/test_models.py index c591c7f71..67d533df4 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -55,7 +55,9 @@ def add_client_method(f): self.responses = {} @add_client_method - def get_model(request: Union[protos.GetModelRequest, None] = None, *, name=None) -> protos.Model: + def get_model( + request: Union[protos.GetModelRequest, None] = None, *, name=None + ) -> protos.Model: if request is None: request = protos.GetModelRequest(name=name) self.assertIsInstance(request, protos.GetModelRequest) diff --git a/tests/test_protos.py b/tests/test_protos.py index 82ffa3831..f0d0497b5 100644 --- a/tests/test_protos.py +++ b/tests/test_protos.py @@ -24,9 +24,18 @@ class UnitTests(parameterized.TestCase): def test_check_glm_imports(self): for fpath in ROOT.rglob("*.py"): 
- if fpath.name in ["client.py", "discuss.py", "test_protos.py", "test_client.py", "build_docs.py"]: + if fpath.name in [ + "client.py", + "discuss.py", + "test_protos.py", + "test_client.py", + "build_docs.py", + ]: continue content = fpath.read_text() - self.assertNotRegex(content, 'import google.ai.generativelanguage|from google.ai import generativelanguage', - msg=f'generativelanguage found in {fpath}') + self.assertNotRegex( + content, + "import google.ai.generativelanguage|from google.ai import generativelanguage", + msg=f"generativelanguage found in {fpath}", + ) diff --git a/tests/test_text.py b/tests/test_text.py index 50c1b5539..795c3dfcd 100644 --- a/tests/test_text.py +++ b/tests/test_text.py @@ -123,7 +123,9 @@ def test_generate_embeddings(self, model, text): emb = text_service.generate_embeddings(model=model, text=text) self.assertIsInstance(emb, dict) - self.assertEqual(self.observed_requests[-1], protos.EmbedTextRequest(model=model, text=text)) + self.assertEqual( + self.observed_requests[-1], protos.EmbedTextRequest(model=model, text=text) + ) self.assertIsInstance(emb["embedding"][0], float) @parameterized.named_parameters( From d2ebda1054e62ec0c3ba86d3eb291c96862a8716 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Tue, 21 May 2024 12:26:02 -0700 Subject: [PATCH 05/16] fix merge Change-Id: I17014791d966d797b481bca17df69558b23a9a1a --- google/generativeai/generative_models.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/google/generativeai/generative_models.py b/google/generativeai/generative_models.py index e2be3ca60..c3b34de42 100644 --- a/google/generativeai/generative_models.py +++ b/google/generativeai/generative_models.py @@ -128,10 +128,7 @@ def _prepare_request( tools: content_types.FunctionLibraryType | None, tool_config: content_types.ToolConfigType | None, ) -> protos.GenerateContentRequest: - """Creates a `glm.GenerateContentRequest` from raw inputs.""" - if not contents: - raise TypeError("contents must not be empty") - + """Creates a `protos.GenerateContentRequest` from raw inputs.""" tools_lib = self._get_tools_lib(tools) if tools_lib is not None: tools_lib = tools_lib.to_proto() From 8aa737d7ea15ac67175ad7f989bfb06e2561979a Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Tue, 21 May 2024 12:32:25 -0700 Subject: [PATCH 06/16] format Change-Id: I51d30f6568640456bcf28db2bd338a58a82346de --- tests/test_generative_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_generative_models.py b/tests/test_generative_models.py index cc21e1295..4392cb184 100644 --- a/tests/test_generative_models.py +++ b/tests/test_generative_models.py @@ -24,11 +24,11 @@ def simple_part(text: str) -> protos.Content: return protos.Content({"parts": [{"text": text}]}) + def noop(x: int): return x - def iter_part(texts: Iterable[str]) -> protos.Content: return protos.Content({"parts": [{"text": t} for t in texts]}) From c06b62c57f3694db661b05241fb8bf3733c4065f Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Thu, 23 May 2024 13:21:05 -0700 Subject: [PATCH 07/16] Fix client references Change-Id: I4899231706c9624a0f189b22b6f70aeeb4cbea29 --- google/generativeai/answer.py | 9 ++- google/generativeai/discuss.py | 18 ++--- google/generativeai/embedding.py | 13 ++-- google/generativeai/models.py | 22 +++--- google/generativeai/permission.py | 5 +- google/generativeai/protos.py | 2 +- google/generativeai/retriever.py | 16 ++-- google/generativeai/text.py | 20 ++--- google/generativeai/types/permission_types.py | 24 +++--- 
google/generativeai/types/retriever_types.py | 73 ++++++++++--------- tests/test_models.py | 2 +- tests/test_protos.py | 2 +- 12 files changed, 107 insertions(+), 99 deletions(-) diff --git a/google/generativeai/answer.py b/google/generativeai/answer.py index 25b45d95a..0d12fa929 100644 --- a/google/generativeai/answer.py +++ b/google/generativeai/answer.py @@ -20,6 +20,7 @@ from typing import Any, Iterable, Union, Mapping, Optional from typing_extensions import TypedDict +import google.ai.generativelanguage as glm from google.generativeai import protos from google.generativeai.client import ( @@ -244,7 +245,7 @@ def generate_answer( answer_style: AnswerStyle | None = None, safety_settings: safety_types.SafetySettingOptions | None = None, temperature: float | None = None, - client: protos.GenerativeServiceClient | None = None, + client: glm.GenerativeServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ): """ @@ -281,7 +282,7 @@ def generate_answer( answer_style: Style in which the grounded answer should be returned. safety_settings: Safety settings for generated output. Defaults to None. temperature: Controls the randomness of the output. - client: If you're not relying on a default client, you pass a `protos.TextServiceClient` instead. + client: If you're not relying on a default client, you pass a `glm.TextServiceClient` instead. request_options: Options for the request. Returns: @@ -317,7 +318,7 @@ async def generate_answer_async( answer_style: AnswerStyle | None = None, safety_settings: safety_types.SafetySettingOptions | None = None, temperature: float | None = None, - client: protos.GenerativeServiceClient | None = None, + client: glm.GenerativeServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ): """ @@ -335,7 +336,7 @@ async def generate_answer_async( answer_style: Style in which the grounded answer should be returned. safety_settings: Safety settings for generated output. Defaults to None. temperature: Controls the randomness of the output. - client: If you're not relying on a default client, you pass a `protos.TextServiceClient` instead. + client: If you're not relying on a default client, you pass a `glm.TextServiceClient` instead. Returns: A `types.Answer` containing the model's text answer response. diff --git a/google/generativeai/discuss.py b/google/generativeai/discuss.py index 172aa76e4..a2fe5cbd2 100644 --- a/google/generativeai/discuss.py +++ b/google/generativeai/discuss.py @@ -317,7 +317,7 @@ def chat( top_p: float | None = None, top_k: float | None = None, prompt: discuss_types.MessagePromptOptions | None = None, - client: protos.DiscussServiceClient | None = None, + client: glm.DiscussServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> discuss_types.ChatResponse: """Calls the API and returns a `types.ChatResponse` containing the response. @@ -384,7 +384,7 @@ def chat( prompt: You may pass a `types.MessagePromptOptions` **instead** of a setting `context`/`examples`/`messages`, but not both. client: If you're not relying on the default client, you pass a - `protos.DiscussServiceClient` instead. + `glm.DiscussServiceClient` instead. request_options: Options for the request. 
Returns: @@ -417,7 +417,7 @@ async def chat_async( top_p: float | None = None, top_k: float | None = None, prompt: discuss_types.MessagePromptOptions | None = None, - client: protos.DiscussServiceAsyncClient | None = None, + client: glm.DiscussServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> discuss_types.ChatResponse: request = _make_generate_message_request( @@ -447,7 +447,7 @@ async def chat_async( @string_utils.set_doc(discuss_types.ChatResponse.__doc__) @dataclasses.dataclass(**DATACLASS_KWARGS, init=False) class ChatResponse(discuss_types.ChatResponse): - _client: protos.DiscussServiceClient | None = dataclasses.field( + _client: glm.DiscussServiceClient | None = dataclasses.field( default=lambda: None, repr=False ) @@ -498,7 +498,7 @@ def reply( async def reply_async( self, message: discuss_types.MessageOptions ) -> discuss_types.ChatResponse: - if isinstance(self._client, protos.DiscussServiceClient): + if isinstance(self._client, glm.DiscussServiceClient): raise TypeError( f"reply_async can't be called on a non-async client, use reply instead." ) @@ -514,7 +514,7 @@ async def reply_async( def _build_chat_response( request: protos.GenerateMessageRequest, response: protos.GenerateMessageResponse, - client: protos.DiscussServiceClient | protos.DiscussServiceAsyncClient, + client: glm.DiscussServiceClient | protos.DiscussServiceAsyncClient, ) -> ChatResponse: request = type(request).to_dict(request) prompt = request.pop("prompt") @@ -540,7 +540,7 @@ def _build_chat_response( def _generate_response( request: protos.GenerateMessageRequest, - client: protos.DiscussServiceClient | None = None, + client: glm.DiscussServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> ChatResponse: if request_options is None: @@ -556,7 +556,7 @@ def _generate_response( async def _generate_response_async( request: protos.GenerateMessageRequest, - client: protos.DiscussServiceAsyncClient | None = None, + client: glm.DiscussServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> ChatResponse: if request_options is None: @@ -577,7 +577,7 @@ def count_message_tokens( examples: discuss_types.ExamplesOptions | None = None, messages: discuss_types.MessagesOptions | None = None, model: model_types.AnyModelNameOptions = DEFAULT_DISCUSS_MODEL, - client: protos.DiscussServiceAsyncClient | None = None, + client: glm.DiscussServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> discuss_types.TokenCount: model = model_types.make_model_name(model) diff --git a/google/generativeai/embedding.py b/google/generativeai/embedding.py index b62e6b450..fcf0cf3d8 100644 --- a/google/generativeai/embedding.py +++ b/google/generativeai/embedding.py @@ -17,6 +17,7 @@ import itertools from typing import Any, Iterable, overload, TypeVar, Union, Mapping +import google.ai.generativelanguage as glm from google.generativeai import protos from google.generativeai.client import get_default_generative_client @@ -101,7 +102,7 @@ def embed_content( task_type: EmbeddingTaskTypeOptions | None = None, title: str | None = None, output_dimensionality: int | None = None, - client: protos.GenerativeServiceClient | None = None, + client: glm.GenerativeServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.EmbeddingDict: ... 
@@ -113,7 +114,7 @@ def embed_content( task_type: EmbeddingTaskTypeOptions | None = None, title: str | None = None, output_dimensionality: int | None = None, - client: protos.GenerativeServiceClient | None = None, + client: glm.GenerativeServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.BatchEmbeddingDict: ... @@ -124,7 +125,7 @@ def embed_content( task_type: EmbeddingTaskTypeOptions | None = None, title: str | None = None, output_dimensionality: int | None = None, - client: protos.GenerativeServiceClient = None, + client: glm.GenerativeServiceClient = None, request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict: """Calls the API to create embeddings for content passed in. @@ -221,7 +222,7 @@ async def embed_content_async( task_type: EmbeddingTaskTypeOptions | None = None, title: str | None = None, output_dimensionality: int | None = None, - client: protos.GenerativeServiceAsyncClient | None = None, + client: glm.GenerativeServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.EmbeddingDict: ... @@ -233,7 +234,7 @@ async def embed_content_async( task_type: EmbeddingTaskTypeOptions | None = None, title: str | None = None, output_dimensionality: int | None = None, - client: protos.GenerativeServiceAsyncClient | None = None, + client: glm.GenerativeServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.BatchEmbeddingDict: ... @@ -244,7 +245,7 @@ async def embed_content_async( task_type: EmbeddingTaskTypeOptions | None = None, title: str | None = None, output_dimensionality: int | None = None, - client: protos.GenerativeServiceAsyncClient = None, + client: glm.GenerativeServiceAsyncClient = None, request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict: """The async version of `genai.embed_content`.""" diff --git a/google/generativeai/models.py b/google/generativeai/models.py index 3d914dfa7..e4453c1ca 100644 --- a/google/generativeai/models.py +++ b/google/generativeai/models.py @@ -17,6 +17,8 @@ import typing from typing import Any, Literal +import google.ai.generativelanguage as glm + from google.generativeai import protos from google.generativeai import operations from google.generativeai.client import get_default_model_client @@ -137,7 +139,7 @@ def get_tuned_model( def get_base_model_name( - model: model_types.AnyModelNameOptions, client: protos.ModelServiceClient | None = None + model: model_types.AnyModelNameOptions, client: glm.ModelServiceClient | None = None ): if isinstance(model, str): if model.startswith("tunedModels/"): @@ -164,7 +166,7 @@ def get_base_model_name( def list_models( *, page_size: int | None = 50, - client: protos.ModelServiceClient | None = None, + client: glm.ModelServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> model_types.ModelsIterable: """Lists available models. @@ -177,7 +179,7 @@ def list_models( Args: page_size: How many `types.Models` to fetch per page (api call). - client: You may pass a `protos.ModelServiceClient` instead of using the default client. + client: You may pass a `glm.ModelServiceClient` instead of using the default client. request_options: Options for the request. 
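The `models.py` hunks above keep the explicit service-client parameter typed as `glm.ModelServiceClient` while the request/response types move to `genai.protos`. A sketch of supplying your own client instead of the default one is below; the `client_options`/`api_key` pattern follows the module docstring earlier in this series and is illustrative, not a recommendation from the patch.

```python
# Sketch: passing an explicit low-level client into the high-level API.
# ModelServiceClient comes from the auto-generated google.ai.generativelanguage
# package; list_models accepts it via the `client=` parameter described above.
import google.ai.generativelanguage as glm
import google.generativeai as genai

client = glm.ModelServiceClient(client_options={"api_key": "YOUR_API_KEY"})

for model in genai.list_models(page_size=10, client=client):
    print(model.name)
```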
Yields: @@ -198,7 +200,7 @@ def list_models( def list_tuned_models( *, page_size: int | None = 50, - client: protos.ModelServiceClient | None = None, + client: glm.ModelServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> model_types.TunedModelsIterable: """Lists available models. @@ -211,7 +213,7 @@ def list_tuned_models( Args: page_size: How many `types.Models` to fetch per page (api call). - client: You may pass a `protos.ModelServiceClient` instead of using the default client. + client: You may pass a `glm.ModelServiceClient` instead of using the default client. request_options: Options for the request. Yields: @@ -246,7 +248,7 @@ def create_tuned_model( learning_rate: float | None = None, input_key: str = "text_input", output_key: str = "output", - client: protos.ModelServiceClient | None = None, + client: glm.ModelServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> operations.CreateTunedModelOperation: """Launches a tuning job to create a TunedModel. @@ -360,7 +362,7 @@ def update_tuned_model( tuned_model: protos.TunedModel, updates: None = None, *, - client: protos.ModelServiceClient | None = None, + client: glm.ModelServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> model_types.TunedModel: pass @@ -371,7 +373,7 @@ def update_tuned_model( tuned_model: str, updates: dict[str, Any], *, - client: protos.ModelServiceClient | None = None, + client: glm.ModelServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> model_types.TunedModel: pass @@ -381,7 +383,7 @@ def update_tuned_model( tuned_model: str | protos.TunedModel, updates: dict[str, Any] | None = None, *, - client: protos.ModelServiceClient | None = None, + client: glm.ModelServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> model_types.TunedModel: """Push updates to the tuned model. 
Only certain attributes are updatable.""" @@ -440,7 +442,7 @@ def _apply_update(thing, path, value): def delete_tuned_model( tuned_model: model_types.TunedModelNameOptions, - client: protos.ModelServiceClient | None = None, + client: glm.ModelServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> None: if request_options is None: diff --git a/google/generativeai/permission.py b/google/generativeai/permission.py index b4672a607..faf31faaf 100644 --- a/google/generativeai/permission.py +++ b/google/generativeai/permission.py @@ -16,6 +16,7 @@ from typing import Callable +import google.ai.generativelanguage as glm from google.generativeai import protos from google.generativeai.types import permission_types @@ -123,7 +124,7 @@ def _construct_name( def get_permission( name: str | None = None, *, - client: protos.PermissionServiceClient | None = None, + client: glm.PermissionServiceClient | None = None, resource_name: str | None = None, permission_id: str | int | None = None, resource_type: str | None = None, @@ -152,7 +153,7 @@ def get_permission( async def get_permission_async( name: str | None = None, *, - client: protos.PermissionServiceAsyncClient | None = None, + client: glm.PermissionServiceAsyncClient | None = None, resource_name: str | None = None, permission_id: str | int | None = None, resource_type: str | None = None, diff --git a/google/generativeai/protos.py b/google/generativeai/protos.py index fd6a96541..8912d88d1 100644 --- a/google/generativeai/protos.py +++ b/google/generativeai/protos.py @@ -29,7 +29,7 @@ ``` from google.generativeai import protos -client = protos.DiscussServiceClient( +client = glm.DiscussServiceClient( client_options={'api_key':'YOUR_API_KEY'}) ``` diff --git a/google/generativeai/retriever.py b/google/generativeai/retriever.py index 0b9e83a05..eda36f93b 100644 --- a/google/generativeai/retriever.py +++ b/google/generativeai/retriever.py @@ -31,7 +31,7 @@ def create_corpus( name: str | None = None, display_name: str | None = None, - client: protos.RetrieverServiceClient | None = None, + client: glm.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> retriever_types.Corpus: """ @@ -78,7 +78,7 @@ def create_corpus( async def create_corpus_async( name: str | None = None, display_name: str | None = None, - client: protos.RetrieverServiceAsyncClient | None = None, + client: glm.RetrieverServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> retriever_types.Corpus: """This is the async version of `retriever.create_corpus`.""" @@ -106,7 +106,7 @@ async def create_corpus_async( def get_corpus( name: str, - client: protos.RetrieverServiceClient | None = None, + client: glm.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> retriever_types.Corpus: # fmt: skip """ @@ -139,7 +139,7 @@ def get_corpus( async def get_corpus_async( name: str, - client: protos.RetrieverServiceAsyncClient | None = None, + client: glm.RetrieverServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> retriever_types.Corpus: # fmt: skip """This is the async version of `retriever.get_corpus`.""" @@ -164,7 +164,7 @@ async def get_corpus_async( def delete_corpus( name: str, force: bool = False, - client: protos.RetrieverServiceClient | None = None, + client: glm.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None 
= None, ): # fmt: skip """ @@ -191,7 +191,7 @@ def delete_corpus( async def delete_corpus_async( name: str, force: bool = False, - client: protos.RetrieverServiceAsyncClient | None = None, + client: glm.RetrieverServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ): # fmt: skip """This is the async version of `retriever.delete_corpus`.""" @@ -211,7 +211,7 @@ async def delete_corpus_async( def list_corpora( *, page_size: Optional[int] = None, - client: protos.RetrieverServiceClient | None = None, + client: glm.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> Iterable[retriever_types.Corpus]: """ @@ -242,7 +242,7 @@ def list_corpora( async def list_corpora_async( *, page_size: Optional[int] = None, - client: protos.RetrieverServiceClient | None = None, + client: glm.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> AsyncIterable[retriever_types.Corpus]: """This is the async version of `retriever.list_corpora`.""" diff --git a/google/generativeai/text.py b/google/generativeai/text.py index 3c7cf612b..83cc13aa6 100644 --- a/google/generativeai/text.py +++ b/google/generativeai/text.py @@ -19,6 +19,8 @@ import itertools from typing import Any, Iterable, overload, TypeVar +import google.ai.generativelanguage as glm + from google.generativeai import protos from google.generativeai.client import get_default_text_client @@ -139,7 +141,7 @@ def generate_text( top_k: float | None = None, safety_settings: palm_safety_types.SafetySettingOptions | None = None, stop_sequences: str | Iterable[str] | None = None, - client: protos.TextServiceClient | None = None, + client: glm.TextServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.Completion: """Calls the API and returns a `types.Completion` containing the response. @@ -180,7 +182,7 @@ def generate_text( stop_sequences: A set of up to 5 character sequences that will stop output generation. If specified, the API will stop at the first appearance of a stop sequence. The stop sequence will not be included as part of the response. - client: If you're not relying on a default client, you pass a `protos.TextServiceClient` instead. + client: If you're not relying on a default client, you pass a `glm.TextServiceClient` instead. request_options: Options for the request. Returns: @@ -215,11 +217,11 @@ def __init__(self, **kwargs): def _generate_response( request: protos.GenerateTextRequest, - client: protos.TextServiceClient = None, + client: glm.TextServiceClient = None, request_options: helper_types.RequestOptionsType | None = None, ) -> Completion: """ - Generates a response using the provided `protos.GenerateTextRequest` and client. + Generates a response using the provided `glm.GenerateTextRequest` and client. Args: request: The text generation request. 
@@ -251,7 +253,7 @@ def _generate_response( def count_text_tokens( model: model_types.AnyModelNameOptions, prompt: str, - client: protos.TextServiceClient | None = None, + client: glm.TextServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.TokenCount: base_model = models.get_base_model_name(model) @@ -274,7 +276,7 @@ def count_text_tokens( def generate_embeddings( model: model_types.BaseModelNameOptions, text: str, - client: protos.TextServiceClient = None, + client: glm.TextServiceClient = None, request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.EmbeddingDict: ... @@ -283,7 +285,7 @@ def generate_embeddings( def generate_embeddings( model: model_types.BaseModelNameOptions, text: Sequence[str], - client: protos.TextServiceClient = None, + client: glm.TextServiceClient = None, request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.BatchEmbeddingDict: ... @@ -291,7 +293,7 @@ def generate_embeddings( def generate_embeddings( model: model_types.BaseModelNameOptions, text: str | Sequence[str], - client: protos.TextServiceClient = None, + client: glm.TextServiceClient = None, request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict: """Calls the API to create an embedding for the text passed in. @@ -302,7 +304,7 @@ def generate_embeddings( text: Free-form input text given to the model. Given a string, the model will generate an embedding based on the input text. - client: If you're not relying on a default client, you pass a `protos.TextServiceClient` instead. + client: If you're not relying on a default client, you pass a `glm.TextServiceClient` instead. request_options: Options for the request. diff --git a/google/generativeai/types/permission_types.py b/google/generativeai/types/permission_types.py index f03657e53..86fd44ac9 100644 --- a/google/generativeai/types/permission_types.py +++ b/google/generativeai/types/permission_types.py @@ -101,7 +101,7 @@ class Permission: def delete( self, - client: protos.PermissionServiceClient | None = None, + client: glm.PermissionServiceClient | None = None, ) -> None: """ Delete permission (self). @@ -113,7 +113,7 @@ def delete( async def delete_async( self, - client: protos.PermissionServiceAsyncClient | None = None, + client: glm.PermissionServiceAsyncClient | None = None, ) -> None: """ This is the async version of `Permission.delete`. @@ -133,7 +133,7 @@ def _apply_update(self, path, value): def update( self, updates: dict[str, Any], - client: protos.PermissionServiceClient | None = None, + client: glm.PermissionServiceClient | None = None, ) -> Permission: """ Update a list of fields for a specified permission. @@ -170,7 +170,7 @@ def update( async def update_async( self, updates: dict[str, Any], - client: protos.PermissionServiceAsyncClient | None = None, + client: glm.PermissionServiceAsyncClient | None = None, ) -> Permission: """ This is the async version of `Permission.update`. @@ -212,7 +212,7 @@ def to_dict(self) -> dict[str, Any]: def get( cls, name: str, - client: protos.PermissionServiceClient | None = None, + client: glm.PermissionServiceClient | None = None, ) -> Permission: """ Get information about a specific permission. 
@@ -234,7 +234,7 @@ def get( async def get_async( cls, name: str, - client: protos.PermissionServiceAsyncClient | None = None, + client: glm.PermissionServiceAsyncClient | None = None, ) -> Permission: """ This is the async version of `Permission.get`. @@ -294,7 +294,7 @@ def create( role: RoleOptions, grantee_type: Optional[GranteeTypeOptions] = None, email_address: Optional[str] = None, - client: protos.PermissionServiceClient | None = None, + client: glm.PermissionServiceClient | None = None, ) -> Permission: """ Create a new permission on a resource (self). @@ -327,7 +327,7 @@ async def create_async( role: RoleOptions, grantee_type: Optional[GranteeTypeOptions] = None, email_address: Optional[str] = None, - client: protos.PermissionServiceAsyncClient | None = None, + client: glm.PermissionServiceAsyncClient | None = None, ) -> Permission: """ This is the async version of `PermissionAdapter.create_permission`. @@ -345,7 +345,7 @@ async def create_async( def list( self, page_size: Optional[int] = None, - client: protos.PermissionServiceClient | None = None, + client: glm.PermissionServiceClient | None = None, ) -> Iterable[Permission]: """ List `Permission`s enforced on a resource (self). @@ -370,7 +370,7 @@ def list( async def list_async( self, page_size: Optional[int] = None, - client: protos.PermissionServiceAsyncClient | None = None, + client: glm.PermissionServiceAsyncClient | None = None, ) -> AsyncIterable[Permission]: """ This is the async version of `PermissionAdapter.list_permissions`. @@ -388,7 +388,7 @@ async def list_async( def transfer_ownership( self, email_address: str, - client: protos.PermissionServiceClient | None = None, + client: glm.PermissionServiceClient | None = None, ) -> None: """ Transfer ownership of a resource (self) to a new owner. 
@@ -409,7 +409,7 @@ def transfer_ownership( async def transfer_ownership_async( self, email_address: str, - client: protos.PermissionServiceAsyncClient | None = None, + client: glm.PermissionServiceAsyncClient | None = None, ) -> None: """This is the async version of `PermissionAdapter.transfer_ownership`.""" if self.parent.startswith("corpora"): diff --git a/google/generativeai/types/retriever_types.py b/google/generativeai/types/retriever_types.py index 3af5e4081..cb41b0f76 100644 --- a/google/generativeai/types/retriever_types.py +++ b/google/generativeai/types/retriever_types.py @@ -21,6 +21,7 @@ from typing import Any, AsyncIterable, Optional, Union, Iterable, Mapping from typing_extensions import deprecated # type: ignore +import google.ai.generativelanguage as glm from google.generativeai import protos from google.protobuf import field_mask_pb2 @@ -262,7 +263,7 @@ def create_document( name: str | None = None, display_name: str | None = None, custom_metadata: Iterable[CustomMetadata] | None = None, - client: protos.RetrieverServiceClient | None = None, + client: glm.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> Document: """ @@ -313,7 +314,7 @@ async def create_document_async( name: str | None = None, display_name: str | None = None, custom_metadata: Iterable[CustomMetadata] | None = None, - client: protos.RetrieverServiceAsyncClient | None = None, + client: glm.RetrieverServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> Document: """This is the async version of `Corpus.create_document`.""" @@ -347,7 +348,7 @@ async def create_document_async( def get_document( self, name: str, - client: protos.RetrieverServiceClient | None = None, + client: glm.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> Document: """ @@ -376,7 +377,7 @@ def get_document( async def get_document_async( self, name: str, - client: protos.RetrieverServiceAsyncClient | None = None, + client: glm.RetrieverServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> Document: """This is the async version of `Corpus.get_document`.""" @@ -402,7 +403,7 @@ def _apply_update(self, path, value): def update( self, updates: dict[str, Any], - client: protos.RetrieverServiceClient | None = None, + client: glm.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ): """ @@ -440,7 +441,7 @@ def update( async def update_async( self, updates: dict[str, Any], - client: protos.RetrieverServiceAsyncClient | None = None, + client: glm.RetrieverServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ): """This is the async version of `Corpus.update`.""" @@ -471,7 +472,7 @@ def query( query: str, metadata_filters: Iterable[MetadataFilter] | None = None, results_count: int | None = None, - client: protos.RetrieverServiceClient | None = None, + client: glm.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> Iterable[RelevantChunk]: """ @@ -525,7 +526,7 @@ async def query_async( query: str, metadata_filters: Iterable[MetadataFilter] | None = None, results_count: int | None = None, - client: protos.RetrieverServiceAsyncClient | None = None, + client: glm.RetrieverServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> Iterable[RelevantChunk]: 
"""This is the async version of `Corpus.query`.""" @@ -567,7 +568,7 @@ def delete_document( self, name: str, force: bool = False, - client: protos.RetrieverServiceClient | None = None, + client: glm.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ): """ @@ -594,7 +595,7 @@ async def delete_document_async( self, name: str, force: bool = False, - client: protos.RetrieverServiceAsyncClient | None = None, + client: glm.RetrieverServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ): """This is the async version of `Corpus.delete_document`.""" @@ -613,7 +614,7 @@ async def delete_document_async( def list_documents( self, page_size: int | None = None, - client: protos.RetrieverServiceClient | None = None, + client: glm.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> Iterable[Document]: """ @@ -643,7 +644,7 @@ def list_documents( async def list_documents_async( self, page_size: int | None = None, - client: protos.RetrieverServiceAsyncClient | None = None, + client: glm.RetrieverServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> AsyncIterable[Document]: """This is the async version of `Corpus.list_documents`.""" @@ -670,7 +671,7 @@ def create_permission( role: permission_types.RoleOptions, grantee_type: Optional[permission_types.GranteeTypeOptions] = None, email_address: Optional[str] = None, - client: protos.PermissionServiceClient | None = None, + client: glm.PermissionServiceClient | None = None, ) -> permission_types.Permission: return self.permissions.create( role=role, grantee_type=grantee_type, email_address=email_address, client=client @@ -685,7 +686,7 @@ async def create_permission_async( role: permission_types.RoleOptions, grantee_type: Optional[permission_types.GranteeTypeOptions] = None, email_address: Optional[str] = None, - client: protos.PermissionServiceAsyncClient | None = None, + client: glm.PermissionServiceAsyncClient | None = None, ) -> permission_types.Permission: return await self.permissions.create_async( role=role, grantee_type=grantee_type, email_address=email_address, client=client @@ -698,7 +699,7 @@ async def create_permission_async( def list_permissions( self, page_size: Optional[int] = None, - client: protos.PermissionServiceClient | None = None, + client: glm.PermissionServiceClient | None = None, ) -> Iterable[permission_types.Permission]: return self.permissions.list(page_size=page_size, client=client) @@ -709,7 +710,7 @@ def list_permissions( async def list_permissions_async( self, page_size: Optional[int] = None, - client: protos.PermissionServiceAsyncClient | None = None, + client: glm.PermissionServiceAsyncClient | None = None, ) -> AsyncIterable[permission_types.Permission]: return self.permissions.list_async(page_size=page_size, client=client) @@ -745,7 +746,7 @@ def create_chunk( data: str | ChunkData, name: str | None = None, custom_metadata: Iterable[CustomMetadata] | None = None, - client: protos.RetrieverServiceClient | None = None, + client: glm.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> Chunk: """ @@ -804,7 +805,7 @@ async def create_chunk_async( data: str | ChunkData, name: str | None = None, custom_metadata: Iterable[CustomMetadata] | None = None, - client: protos.RetrieverServiceAsyncClient | None = None, + client: glm.RetrieverServiceAsyncClient | None = None, request_options: 
helper_types.RequestOptionsType | None = None, ) -> Chunk: """This is the async version of `Document.create_chunk`.""" @@ -905,7 +906,7 @@ def _make_batch_create_chunk_request( def batch_create_chunks( self, chunks: BatchCreateChunkOptions, - client: protos.RetrieverServiceClient | None = None, + client: glm.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ): """ @@ -931,7 +932,7 @@ def batch_create_chunks( async def batch_create_chunks_async( self, chunks: BatchCreateChunkOptions, - client: protos.RetrieverServiceAsyncClient | None = None, + client: glm.RetrieverServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ): """This is the async version of `Document.batch_create_chunk`.""" @@ -948,7 +949,7 @@ async def batch_create_chunks_async( def get_chunk( self, name: str, - client: protos.RetrieverServiceClient | None = None, + client: glm.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ): """ @@ -977,7 +978,7 @@ def get_chunk( async def get_chunk_async( self, name: str, - client: protos.RetrieverServiceAsyncClient | None = None, + client: glm.RetrieverServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ): """This is the async version of `Document.get_chunk`.""" @@ -997,7 +998,7 @@ async def get_chunk_async( def list_chunks( self, page_size: int | None = None, - client: protos.RetrieverServiceClient | None = None, + client: glm.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> Iterable[Chunk]: """ @@ -1023,7 +1024,7 @@ def list_chunks( async def list_chunks_async( self, page_size: int | None = None, - client: protos.RetrieverServiceClient | None = None, + client: glm.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> AsyncIterable[Chunk]: """This is the async version of `Document.list_chunks`.""" @@ -1042,7 +1043,7 @@ def query( query: str, metadata_filters: Iterable[MetadataFilter] | None = None, results_count: int | None = None, - client: protos.RetrieverServiceClient | None = None, + client: glm.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> list[RelevantChunk]: """ @@ -1095,7 +1096,7 @@ async def query_async( query: str, metadata_filters: Iterable[MetadataFilter] | None = None, results_count: int | None = None, - client: protos.RetrieverServiceAsyncClient | None = None, + client: glm.RetrieverServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> list[RelevantChunk]: """This is the async version of `Document.query`.""" @@ -1142,7 +1143,7 @@ def _apply_update(self, path, value): def update( self, updates: dict[str, Any], - client: protos.RetrieverServiceClient | None = None, + client: glm.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ): """ @@ -1179,7 +1180,7 @@ def update( async def update_async( self, updates: dict[str, Any], - client: protos.RetrieverServiceAsyncClient | None = None, + client: glm.RetrieverServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ): """This is the async version of `Document.update`.""" @@ -1207,7 +1208,7 @@ async def update_async( def batch_update_chunks( self, chunks: BatchUpdateChunksOptions, - client: protos.RetrieverServiceClient | None = 
None, + client: glm.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ): """ @@ -1306,7 +1307,7 @@ def batch_update_chunks( async def batch_update_chunks_async( self, chunks: BatchUpdateChunksOptions, - client: protos.RetrieverServiceAsyncClient | None = None, + client: glm.RetrieverServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ): """This is the async version of `Document.batch_update_chunks`.""" @@ -1396,7 +1397,7 @@ async def batch_update_chunks_async( def delete_chunk( self, name: str, - client: protos.RetrieverServiceClient | None = None, + client: glm.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, # fmt: {} ): """ @@ -1421,7 +1422,7 @@ def delete_chunk( async def delete_chunk_async( self, name: str, - client: protos.RetrieverServiceAsyncClient | None = None, + client: glm.RetrieverServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, # fmt: {} ): """This is the async version of `Document.delete_chunk`.""" @@ -1440,7 +1441,7 @@ async def delete_chunk_async( def batch_delete_chunks( self, chunks: BatchDeleteChunkOptions, - client: protos.RetrieverServiceClient | None = None, + client: glm.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ): """ @@ -1473,7 +1474,7 @@ def batch_delete_chunks( async def batch_delete_chunks_async( self, chunks: BatchDeleteChunkOptions, - client: protos.RetrieverServiceAsyncClient | None = None, + client: glm.RetrieverServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ): """This is the async version of `Document.batch_delete_chunks`.""" @@ -1579,7 +1580,7 @@ def _apply_update(self, path, value): def update( self, updates: dict[str, Any], - client: protos.RetrieverServiceClient | None = None, + client: glm.RetrieverServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ): """ @@ -1628,7 +1629,7 @@ def update( async def update_async( self, updates: dict[str, Any], - client: protos.RetrieverServiceAsyncClient | None = None, + client: glm.RetrieverServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ): """This is the async version of `Chunk.update`.""" diff --git a/tests/test_models.py b/tests/test_models.py index 67d533df4..23f80913a 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -45,7 +45,7 @@ def setUp(self): client._client_manager.clients["model"] = self.client # TODO(markdaoust): Check if typechecking works better if wee define this as a - # subclass of `protos.ModelServiceClient`, would pyi files for `protos. help? + # subclass of `glm.ModelServiceClient`, would pyi files for `glm`. help? 
def add_client_method(f): name = f.__name__ setattr(self.client, name, f) diff --git a/tests/test_protos.py b/tests/test_protos.py index f0d0497b5..1a31b5c9c 100644 --- a/tests/test_protos.py +++ b/tests/test_protos.py @@ -36,6 +36,6 @@ def test_check_glm_imports(self): content = fpath.read_text() self.assertNotRegex( content, - "import google.ai.generativelanguage|from google.ai import generativelanguage", + "import google\.ai\.generativelanguage|from google\.ai import generativelanguage", msg=f"generativelanguage found in {fpath}", ) From a2704b73c203607ec962b28b5a3c84e3444e08ce Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Thu, 23 May 2024 15:18:38 -0700 Subject: [PATCH 08/16] Fix tests Change-Id: I8a636fb634fd079a892cb99170a12c0613887ccf --- google/generativeai/client.py | 3 +- google/generativeai/discuss.py | 4 +- google/generativeai/models.py | 6 +- google/generativeai/permission.py | 1 - google/generativeai/text.py | 2 +- google/generativeai/types/file_types.py | 4 +- google/generativeai/types/retriever_types.py | 21 +--- google/generativeai/types/safety_types.py | 2 +- tests/test_files.py | 27 ++--- tests/test_generative_models.py | 106 +++---------------- tests/test_protos.py | 23 ++-- tests/test_safety.py | 14 +-- 12 files changed, 61 insertions(+), 152 deletions(-) diff --git a/google/generativeai/client.py b/google/generativeai/client.py index d969889d0..40c2bdcaf 100644 --- a/google/generativeai/client.py +++ b/google/generativeai/client.py @@ -10,6 +10,7 @@ import httplib2 import google.ai.generativelanguage as glm +import google.generativeai.protos as protos from google.auth import credentials as ga_credentials from google.auth import exceptions as ga_exceptions @@ -76,7 +77,7 @@ def create_file( name: str | None = None, display_name: str | None = None, resumable: bool = True, - ) -> glm.File: + ) -> protos.File: if self._discovery_api is None: self._setup_discovery_api() diff --git a/google/generativeai/discuss.py b/google/generativeai/discuss.py index 273d520e4..448347b41 100644 --- a/google/generativeai/discuss.py +++ b/google/generativeai/discuss.py @@ -451,9 +451,7 @@ async def chat_async( @string_utils.set_doc(discuss_types.ChatResponse.__doc__) @dataclasses.dataclass(**DATACLASS_KWARGS, init=False) class ChatResponse(discuss_types.ChatResponse): - _client: glm.DiscussServiceClient | None = dataclasses.field( - default=lambda: None, repr=False - ) + _client: glm.DiscussServiceClient | None = dataclasses.field(default=lambda: None, repr=False) def __init__(self, **kwargs): for key, value in kwargs.items(): diff --git a/google/generativeai/models.py b/google/generativeai/models.py index f25357a90..9ba0745c1 100644 --- a/google/generativeai/models.py +++ b/google/generativeai/models.py @@ -166,7 +166,7 @@ def get_base_model_name( else: raise TypeError( f"Invalid model: The provided model '{model}' is not recognized or supported. " - "Supported types are: str, model_types.TunedModel, model_types.Model, glm.Model, and glm.TunedModel." + "Supported types are: str, model_types.TunedModel, model_types.Model, protos.Model, and protos.TunedModel." ) return base_model @@ -432,8 +432,8 @@ def update_tuned_model( field_mask = protobuf_helpers.field_mask(was._pb, tuned_model._pb) else: raise TypeError( - "Invalid argument type: In the function `update_tuned_model(tuned_model:dict|glm.TunedModel)`, the " - f"`tuned_model` argument must be of type `dict` or `glm.TunedModel`. Received type: {type(tuned_model).__name__}." 
+ "Invalid argument type: In the function `update_tuned_model(tuned_model:dict|protos.TunedModel)`, the " + f"`tuned_model` argument must be of type `dict` or `protos.TunedModel`. Received type: {type(tuned_model).__name__}." ) result = client.update_tuned_model( diff --git a/google/generativeai/permission.py b/google/generativeai/permission.py index b83cdc8a1..b2b7c15e1 100644 --- a/google/generativeai/permission.py +++ b/google/generativeai/permission.py @@ -17,7 +17,6 @@ from typing import Callable import google.ai.generativelanguage as glm -from google.generativeai import protos from google.generativeai.types import permission_types from google.generativeai.types import retriever_types diff --git a/google/generativeai/text.py b/google/generativeai/text.py index 991138498..2a6267661 100644 --- a/google/generativeai/text.py +++ b/google/generativeai/text.py @@ -223,7 +223,7 @@ def _generate_response( request_options: helper_types.RequestOptionsType | None = None, ) -> Completion: """ - Generates a response using the provided `glm.GenerateTextRequest` and client. + Generates a response using the provided `protos.GenerateTextRequest` and client. Args: request: The text generation request. diff --git a/google/generativeai/types/file_types.py b/google/generativeai/types/file_types.py index 3acaf12d3..ef251e296 100644 --- a/google/generativeai/types/file_types.py +++ b/google/generativeai/types/file_types.py @@ -30,7 +30,7 @@ def __init__(self, proto: protos.File | File | dict): proto = proto.to_proto() self._proto = protos.File(proto) - def to_proto(self) -> glm.File: + def to_proto(self) -> protos.File: return self._proto @property @@ -74,7 +74,7 @@ def state(self) -> protos.File.State: return self._proto.state @property - def video_metadata(self) -> glm.VideoMetadata: + def video_metadata(self) -> protos.VideoMetadata: return self._proto.video_metadata @property diff --git a/google/generativeai/types/retriever_types.py b/google/generativeai/types/retriever_types.py index a26e5cab9..9931ee58d 100644 --- a/google/generativeai/types/retriever_types.py +++ b/google/generativeai/types/retriever_types.py @@ -201,7 +201,6 @@ def _to_proto(self): ) return protos.CustomMetadata(key=self.key, **kwargs) - @classmethod def _from_dict(cls, cm): key = cm["key"] @@ -1402,12 +1401,8 @@ async def batch_update_chunks_async( ) else: raise TypeError( -<<<<<<< HEAD - "The `chunks` parameter must be a list of protos.UpdateChunkRequests," + "Invalid input: The 'chunks' parameter must be a list of 'protos.UpdateChunkRequests', " "dictionaries, or tuples of dictionaries." -======= - "Invalid input: The 'chunks' parameter must be a list of 'glm.UpdateChunkRequests', dictionaries, or tuples of dictionaries." ->>>>>>> main ) request = protos.BatchUpdateChunksRequest(parent=self.name, requests=_requests) response = await client.batch_update_chunks(request, **request_options) @@ -1488,11 +1483,8 @@ def batch_delete_chunks( client.batch_delete_chunks(request, **request_options) else: raise ValueError( -<<<<<<< HEAD - "To delete chunks, you must pass in either the names of the chunks as an iterable, or multiple `protos.DeleteChunkRequest`s." -======= - "Invalid operation: To delete chunks, you must pass in either the names of the chunks as an iterable, or multiple 'glm.DeleteChunkRequest's." ->>>>>>> main + "Invalid operation: To delete chunks, you must pass in either the names of the chunks as an iterable, " + "or multiple 'protos.DeleteChunkRequest's." 
) async def batch_delete_chunks_async( @@ -1519,11 +1511,8 @@ async def batch_delete_chunks_async( await client.batch_delete_chunks(request, **request_options) else: raise ValueError( -<<<<<<< HEAD - "To delete chunks, you must pass in either the names of the chunks as an iterable, or multiple `protos.DeleteChunkRequest`s." -======= - "Invalid operation: To delete chunks, you must pass in either the names of the chunks as an iterable, or multiple 'glm.DeleteChunkRequest's." ->>>>>>> main + "Invalid operation: To delete chunks, you must pass in either the names of the chunks as an iterable, " + "or multiple 'protos.DeleteChunkRequest's." ) def to_dict(self) -> dict[str, Any]: diff --git a/google/generativeai/types/safety_types.py b/google/generativeai/types/safety_types.py index 4bf350818..74da06e45 100644 --- a/google/generativeai/types/safety_types.py +++ b/google/generativeai/types/safety_types.py @@ -225,7 +225,7 @@ def to_easy_safety_dict(settings: SafetySettingOptions) -> EasySafetySettingDict else: # Iterable result = {} for setting in settings: - if isinstance(setting, glm.SafetySetting): + if isinstance(setting, protos.SafetySetting): result[to_harm_category(setting.category)] = to_block_threshold(setting.threshold) elif isinstance(setting, dict): result[to_harm_category(setting["category"])] = to_block_threshold( diff --git a/tests/test_files.py b/tests/test_files.py index 333ec1e2a..7d9139450 100644 --- a/tests/test_files.py +++ b/tests/test_files.py @@ -22,10 +22,10 @@ import pathlib import google -import google.ai.generativelanguage as glm import google.generativeai as genai from google.generativeai import client as client_lib +from google.generativeai import protos from absl.testing import parameterized @@ -43,7 +43,7 @@ def create_file( name: Union[str, None] = None, display_name: Union[str, None] = None, resumable: bool = True, - ) -> glm.File: + ) -> protos.File: self.observed_requests.append( dict( path=path, @@ -57,24 +57,24 @@ def create_file( def get_file( self, - request: glm.GetFileRequest, + request: protos.GetFileRequest, **kwargs, - ) -> glm.File: + ) -> protos.File: self.observed_requests.append(request) return self.responses["get_file"].pop(0) def list_files( self, - request: glm.ListFilesRequest, + request: protos.ListFilesRequest, **kwargs, - ) -> Iterable[glm.File]: + ) -> Iterable[protos.File]: self.observed_requests.append(request) for f in self.responses["list_files"].pop(0): yield f def delete_file( self, - request: glm.DeleteFileRequest, + request: protos.DeleteFileRequest, **kwargs, ): self.observed_requests.append(request) @@ -97,7 +97,7 @@ def responses(self): def test_video_metadata(self): self.responses["create_file"].append( - glm.File( + protos.File( uri="https://test", state="ACTIVE", video_metadata=dict(video_duration=datetime.timedelta(seconds=30)), @@ -108,7 +108,8 @@ def test_video_metadata(self): f = genai.upload_file(path="dummy") self.assertEqual(google.rpc.status_pb2.Status(code=7, message="ok?"), f.error) self.assertEqual( - glm.VideoMetadata(dict(video_duration=datetime.timedelta(seconds=30))), f.video_metadata + protos.VideoMetadata(dict(video_duration=datetime.timedelta(seconds=30))), + f.video_metadata, ) @parameterized.named_parameters( @@ -123,11 +124,11 @@ def test_video_metadata(self): ), dict( testcase_name="FileData", - file_data=glm.FileData(file_uri="https://test_uri"), + file_data=protos.FileData(file_uri="https://test_uri"), ), dict( - testcase_name="glm.File", - file_data=glm.File(uri="https://test_uri"), + 
testcase_name="protos.File", + file_data=protos.File(uri="https://test_uri"), ), dict( testcase_name="file_types.File", @@ -137,4 +138,4 @@ def test_video_metadata(self): ) def test_to_file_data(self, file_data): file_data = file_types.to_file_data(file_data) - self.assertEqual(glm.FileData(file_uri="https://test_uri"), file_data) + self.assertEqual(protos.FileData(file_uri="https://test_uri"), file_data) diff --git a/tests/test_generative_models.py b/tests/test_generative_models.py index 4d9e4120f..0ece77e94 100644 --- a/tests/test_generative_models.py +++ b/tests/test_generative_models.py @@ -48,10 +48,10 @@ def __init__(self, test): def generate_content( self, - request: glm.GenerateContentRequest, + request: protos.GenerateContentRequest, **kwargs, - ) -> glm.GenerateContentResponse: - self.test.assertIsInstance(request, glm.GenerateContentRequest) + ) -> protos.GenerateContentResponse: + self.test.assertIsInstance(request, protos.GenerateContentRequest) self.observed_requests.append(request) self.observed_kwargs.append(kwargs) response = self.responses["generate_content"].pop(0) @@ -59,9 +59,9 @@ def generate_content( def stream_generate_content( self, - request: glm.GetModelRequest, + request: protos.GetModelRequest, **kwargs, - ) -> Iterable[glm.GenerateContentResponse]: + ) -> Iterable[protos.GenerateContentResponse]: self.observed_requests.append(request) self.observed_kwargs.append(kwargs) response = self.responses["stream_generate_content"].pop(0) @@ -69,9 +69,9 @@ def stream_generate_content( def count_tokens( self, - request: glm.CountTokensRequest, + request: protos.CountTokensRequest, **kwargs, - ) -> Iterable[glm.GenerateContentResponse]: + ) -> Iterable[protos.GenerateContentResponse]: self.observed_requests.append(request) self.observed_kwargs.append(kwargs) response = self.responses["count_tokens"].pop(0) @@ -97,45 +97,6 @@ def setUp(self): self.client = MockGenerativeServiceClient(self) client_lib._client_manager.clients["generative"] = self.client -<<<<<<< HEAD - def add_client_method(f): - name = f.__name__ - setattr(self.client, name, f) - return f - - self.observed_requests = [] - self.responses = collections.defaultdict(list) - - @add_client_method - def generate_content( - request: protos.GenerateContentRequest, - **kwargs, - ) -> protos.GenerateContentResponse: - self.assertIsInstance(request, protos.GenerateContentRequest) - self.observed_requests.append(request) - response = self.responses["generate_content"].pop(0) - return response - - @add_client_method - def stream_generate_content( - request: protos.GetModelRequest, - **kwargs, - ) -> Iterable[protos.GenerateContentResponse]: - self.observed_requests.append(request) - response = self.responses["stream_generate_content"].pop(0) - return response - - @add_client_method - def count_tokens( - request: protos.CountTokensRequest, - **kwargs, - ) -> Iterable[protos.GenerateContentResponse]: - self.observed_requests.append(request) - response = self.responses["count_tokens"].pop(0) - return response - -======= ->>>>>>> main def test_hello(self): # Generate text from text prompt model = generative_models.GenerativeModel(model_name="gemini-pro") @@ -215,13 +176,8 @@ def test_generation_config_overwrite(self, config1, config2): "list-dict", [ dict( -<<<<<<< HEAD - category=protos.HarmCategory.HARM_CATEGORY_DANGEROUS, + category=protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold=protos.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, -======= - 
category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, ->>>>>>> main ), ], [ @@ -231,27 +187,15 @@ def test_generation_config_overwrite(self, config1, config2): [ "object", [ -<<<<<<< HEAD protos.SafetySetting( - category=protos.HarmCategory.HARM_CATEGORY_DANGEROUS, + category=protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold=protos.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, ), ], [ protos.SafetySetting( - category=protos.HarmCategory.HARM_CATEGORY_DANGEROUS, - threshold=protos.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, -======= - glm.SafetySetting( - category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, - ), - ], - [ - glm.SafetySetting( - category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH, ->>>>>>> main + category=protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=protos.SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH, ), ], ], @@ -270,40 +214,22 @@ def test_safety_overwrite(self, safe1, safe2): danger = [ s for s in self.observed_requests[-1].safety_settings - if s.category == glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT + if s.category == protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT ] self.assertEqual( -<<<<<<< HEAD - self.observed_requests[-1].safety_settings[0].category, - protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - ) - self.assertEqual( - self.observed_requests[-1].safety_settings[0].threshold, - protos.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, -======= danger[0].threshold, - glm.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, ->>>>>>> main + protos.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, ) _ = model.generate_content("hello", safety_settings=safe2) danger = [ s for s in self.observed_requests[-1].safety_settings - if s.category == glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT + if s.category == protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT ] self.assertEqual( -<<<<<<< HEAD - self.observed_requests[-1].safety_settings[0].category, - protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - ) - self.assertEqual( - self.observed_requests[-1].safety_settings[0].threshold, - protos.SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH, -======= danger[0].threshold, - glm.SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH, ->>>>>>> main + protos.SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH, ) def test_stream_basic(self): @@ -1308,7 +1234,7 @@ def test_count_tokens_called_with_request_options(self): def test_chat_with_request_options(self): self.responses["generate_content"].append( - glm.GenerateContentResponse( + protos.GenerateContentResponse( { "candidates": [{"finish_reason": "STOP"}], } diff --git a/tests/test_protos.py b/tests/test_protos.py index 1a31b5c9c..179403958 100644 --- a/tests/test_protos.py +++ b/tests/test_protos.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import pathlib +import re from absl.testing import parameterized import google.ai.generativelanguage as glm @@ -24,18 +25,12 @@ class UnitTests(parameterized.TestCase): def test_check_glm_imports(self): for fpath in ROOT.rglob("*.py"): - if fpath.name in [ - "client.py", - "discuss.py", - "test_protos.py", - "test_client.py", - "build_docs.py", - ]: - continue - content = fpath.read_text() - self.assertNotRegex( - content, - "import google\.ai\.generativelanguage|from google\.ai import generativelanguage", - msg=f"generativelanguage found in {fpath}", - ) + for match in re.findall("glm\.\w+", content): + if "__" in match: + continue + self.assertIn( + "Client", + match, + msg=f"Bad `glm.` usage, use `genai.protos` instead,\n in {fpath}", + ) diff --git a/tests/test_safety.py b/tests/test_safety.py index f3efc4aca..2ac8aca46 100644 --- a/tests/test_safety.py +++ b/tests/test_safety.py @@ -15,26 +15,26 @@ from absl.testing import absltest from absl.testing import parameterized -import google.ai.generativelanguage as glm from google.generativeai.types import safety_types +from google.generativeai import protos class SafetyTests(parameterized.TestCase): """Tests are in order with the design doc.""" @parameterized.named_parameters( - ["block_threshold", glm.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE], + ["block_threshold", protos.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE], ["block_threshold2", "medium"], ["block_threshold3", 2], ["dict", {"danger": "medium"}], ["dict2", {"danger": 2}], - ["dict3", {"danger": glm.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE}], + ["dict3", {"danger": protos.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE}], [ "list-dict", [ dict( - category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + category=protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=protos.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, ), ], ], @@ -48,8 +48,8 @@ class SafetyTests(parameterized.TestCase): def test_safety_overwrite(self, setting): setting = safety_types.to_easy_safety_dict(setting) self.assertEqual( - setting[glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT], - glm.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + setting[protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT], + protos.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, ) From f3c697309656f3dc5cd256af81ee436b18e93e79 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Thu, 23 May 2024 15:23:07 -0700 Subject: [PATCH 09/16] add import Change-Id: I517171389801ef249cd478f98798181da83bef69 --- google/generativeai/types/permission_types.py | 1 + 1 file changed, 1 insertion(+) diff --git a/google/generativeai/types/permission_types.py b/google/generativeai/types/permission_types.py index 49c356509..1df831db0 100644 --- a/google/generativeai/types/permission_types.py +++ b/google/generativeai/types/permission_types.py @@ -18,6 +18,7 @@ from typing import Optional, Union, Any, Iterable, AsyncIterable import re +import google.ai.generativelanguage as glm from google.generativeai import protos from google.protobuf import field_mask_pb2 From eee354343f8c18cf449ae9de977ccb9bdf94abac Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Thu, 23 May 2024 15:40:33 -0700 Subject: [PATCH 10/16] fix import Change-Id: I8921c0caaa9b902ebde682ead31a2444298c2c9c --- google/generativeai/retriever.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git 
a/google/generativeai/retriever.py b/google/generativeai/retriever.py index 18f5cc060..53c90140a 100644 --- a/google/generativeai/retriever.py +++ b/google/generativeai/retriever.py @@ -14,11 +14,10 @@ # limitations under the License. from __future__ import annotations -import re -import string -import dataclasses -from typing import Any, AsyncIterable, Iterable, Optional +from typing import AsyncIterable, Iterable, Optional + +import google.ai.generativelanguage as glm from google.generativeai import protos from google.generativeai.client import get_default_retriever_client From 2ea05670f48c0472944c0993c1e276bbcbcc2339 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Thu, 23 May 2024 16:56:44 -0700 Subject: [PATCH 11/16] Update docstring Change-Id: I1f6b3b9b9521baa8812a908431bf58c623860733 --- google/generativeai/protos.py | 77 +++++++++++++++-------------------- 1 file changed, 32 insertions(+), 45 deletions(-) diff --git a/google/generativeai/protos.py b/google/generativeai/protos.py index 8912d88d1..60d84d867 100644 --- a/google/generativeai/protos.py +++ b/google/generativeai/protos.py @@ -13,64 +13,51 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -This module publishes the proto classes from `google.ai.generativelanguage`. +This module publishes the ProtoBuffer "Message" classes used by the API. -`google.ai.generativelanguage` is a low-level auto-generated client library for the Gemini API. +ProtoBufers are Google API's serilization format. They are strongly typed and efficient. -It is built using the same tooling as Google Cloud client libraries, and will be quite familiar if you've used -those before. +The `genai` SDK tries to be permissive about what objects it will accept from a user, but in the end +the SDK always converts input to an apropriate Proto Message object to send as the request. -While we encourage Python users to access the Geini API using the `google.generativeai` package (aka `palm`), -the lower level package is also available. +If you have any uncertianty about what the API may accept or return, these classes provide the +complete/unambiguous answer. They come from the `google-ai-generativelanguage` package which is +generated from a snapshot of the API definition. -Each method in the Gemini API is connected to one of the client classes. Pass your API-key to the class' `client_options` -when initializing a client: +>>> from google.generativeai import protos +>>> import inspect +>>> print(inspect.getsource(protos.GenerateContentRequest)) -``` -from google.generativeai import protos +Proto classes can have "oneof" fields. Use `in` to check which `oneof` field is set. -client = glm.DiscussServiceClient( - client_options={'api_key':'YOUR_API_KEY'}) -``` +>>> p = protos.Part(text='hello') +>>> 'text' in p +True +>>> p.inline_data = {'mime_type':'image/png', 'data': b'PNG'} +>>> type(p.inline_data) is protos.Blob +True +>>> 'inline_data' in p +False +>>> 'text' in p +True -To call the api, pass an appropriate request-proto-object. 
For the `DiscussServiceClient.generate_message` pass -a `generativelanguage.GenerateMessageRequest` instance: +Instances of all Message classes can be converted a JSON compatible dict with the following construct: -``` -request = protos.GenerateMessageRequest( - model='models/chat-bison-001', - prompt=protos.MessagePrompt( - messages=[protos.Message(content='Hello!')])) +>>> p_dict = type(p).to_dict(p) +>>> p_dict +{'inline_data': {'mime_type': 'image/png', 'data': 'UE5H'}} -client.generate_message(request) -``` -``` -candidates { - author: "1" - content: "Hello! How can I help you today?" -} -... -``` +Bytes are base64 encoded. -For simplicity: +Note when converting that `to_dict` accepts additional arguments, the -* The API methods also accept key-word arguments. -* Anywhere you might pass a proto-object, the library will also accept simple python structures. +- `use_integers_for_enums:bool = True`, Set it to `False` to replace enum int values with their string + names in the output +- ` including_default_value_fields:bool = True`, Set it to `False` to reduce the verbosity of the output. -So the following is equivalent to the previous example: +Additional arguments are described in the docstring: -``` -client.generate_message( - model='models/chat-bison-001', - prompt={'messages':[{'content':'Hello!'}]}) -``` -``` -candidates { - author: "1" - content: "Hello! How can I help you today?" -} -... -``` +>>> help(proto.Part.to_dict) """ From 37b5399b125553e49d107c965e41e8a922bebb5a Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Fri, 24 May 2024 15:42:58 -0700 Subject: [PATCH 12/16] spelling Change-Id: I0421a35687ed14b1a5ca3b496cafd91514c4de92 --- google/generativeai/protos.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/google/generativeai/protos.py b/google/generativeai/protos.py index 60d84d867..70423158d 100644 --- a/google/generativeai/protos.py +++ b/google/generativeai/protos.py @@ -18,9 +18,9 @@ ProtoBufers are Google API's serilization format. They are strongly typed and efficient. The `genai` SDK tries to be permissive about what objects it will accept from a user, but in the end -the SDK always converts input to an apropriate Proto Message object to send as the request. +the SDK always converts input to an appropriate Proto Message object to send as the request. -If you have any uncertianty about what the API may accept or return, these classes provide the +If you have any uncertainty about what the API may accept or return, these classes provide the complete/unambiguous answer. They come from the `google-ai-generativelanguage` package which is generated from a snapshot of the API definition. @@ -41,15 +41,23 @@ >>> 'text' in p True -Instances of all Message classes can be converted a JSON compatible dict with the following construct: +Instances of all Message classes can be converted into JSON compatible dictionaries with the following construct +(Bytes are base64 encoded): >>> p_dict = type(p).to_dict(p) >>> p_dict {'inline_data': {'mime_type': 'image/png', 'data': 'UE5H'}} -Bytes are base64 encoded. 
+A compatible dict can ve converted to an instance of a Message class by passing it as the first argument to the +constructor: -Note when converting that `to_dict` accepts additional arguments, the +>>> p = protos.Part(p_dict) +inline_data { + mime_type: "image/png" + data: "PNG" +} + +Note when converting that `to_dict` accepts additional arguments: - `use_integers_for_enums:bool = True`, Set it to `False` to replace enum int values with their string names in the output @@ -58,7 +66,6 @@ Additional arguments are described in the docstring: >>> help(proto.Part.to_dict) - """ from google.ai.generativelanguage_v1beta.types import * From 7c6d9279b4f8979d80df2cb8c13d375fbd0eac19 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Fri, 24 May 2024 16:04:14 -0700 Subject: [PATCH 13/16] remove unused imports Change-Id: Ifc791796e36668eb473fd0fffea4833b1a062188 --- tests/test_protos.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_protos.py b/tests/test_protos.py index 179403958..b1b477b86 100644 --- a/tests/test_protos.py +++ b/tests/test_protos.py @@ -16,8 +16,6 @@ import re from absl.testing import parameterized -import google.ai.generativelanguage as glm -from google.generativeai import responder ROOT = pathlib.Path(__file__).parent.parent From c884dc2afca90a76375b7c32e676d9b0986a1bbe Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Tue, 28 May 2024 09:31:03 -0700 Subject: [PATCH 14/16] Resolve review coments. Change-Id: Ieb900190f42e883337028ae25da3be819507db4a --- google/generativeai/protos.py | 6 +++--- tests/test_protos.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/google/generativeai/protos.py b/google/generativeai/protos.py index 70423158d..9bc47a6b4 100644 --- a/google/generativeai/protos.py +++ b/google/generativeai/protos.py @@ -26,7 +26,7 @@ >>> from google.generativeai import protos >>> import inspect ->>> print(inspect.getsource(protos.GenerateContentRequest)) +>>> print(inspect.getsource(protos.Part)) Proto classes can have "oneof" fields. Use `in` to check which `oneof` field is set. @@ -37,9 +37,9 @@ >>> type(p.inline_data) is protos.Blob True >>> 'inline_data' in p -False ->>> 'text' in p True +>>> 'text' in p +False Instances of all Message classes can be converted into JSON compatible dictionaries with the following construct (Bytes are base64 encoded): diff --git a/tests/test_protos.py b/tests/test_protos.py index b1b477b86..1b59b0c6e 100644 --- a/tests/test_protos.py +++ b/tests/test_protos.py @@ -23,10 +23,10 @@ class UnitTests(parameterized.TestCase): def test_check_glm_imports(self): for fpath in ROOT.rglob("*.py"): + if fpath.name == "build_docs.py": + continue content = fpath.read_text() for match in re.findall("glm\.\w+", content): - if "__" in match: - continue self.assertIn( "Client", match, From 4a49640a1b9fd73f176a252b90ca579534fccf9c Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Tue, 28 May 2024 09:36:53 -0700 Subject: [PATCH 15/16] Update docstring. Change-Id: I805473f9aaeb04e922a9f66bb5f40716d42fb738 --- google/generativeai/protos.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/google/generativeai/protos.py b/google/generativeai/protos.py index 9bc47a6b4..92a314643 100644 --- a/google/generativeai/protos.py +++ b/google/generativeai/protos.py @@ -13,12 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -This module publishes the ProtoBuffer "Message" classes used by the API. 
+This module provides low level access to theProtoBuffer "Message" classes used by the API. + +**For typical usage of this SDK you do not need to use any of these classes.** ProtoBufers are Google API's serilization format. They are strongly typed and efficient. The `genai` SDK tries to be permissive about what objects it will accept from a user, but in the end -the SDK always converts input to an appropriate Proto Message object to send as the request. +the SDK always converts input to an appropriate Proto Message object to send as the request. Each API request +has a `*Request` and `*Response` Message defined here. If you have any uncertainty about what the API may accept or return, these classes provide the complete/unambiguous answer. They come from the `google-ai-generativelanguage` package which is @@ -48,7 +51,7 @@ >>> p_dict {'inline_data': {'mime_type': 'image/png', 'data': 'UE5H'}} -A compatible dict can ve converted to an instance of a Message class by passing it as the first argument to the +A compatible dict can be converted to an instance of a Message class by passing it as the first argument to the constructor: >>> p = protos.Part(p_dict) From b134c69be78e87a49e5268ad97c39abccda325bf Mon Sep 17 00:00:00 2001 From: Mark McDonald Date: Thu, 30 May 2024 11:34:22 +0800 Subject: [PATCH 16/16] Fix typo --- google/generativeai/protos.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/generativeai/protos.py b/google/generativeai/protos.py index 92a314643..010396c75 100644 --- a/google/generativeai/protos.py +++ b/google/generativeai/protos.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -This module provides low level access to theProtoBuffer "Message" classes used by the API. +This module provides low level access to the ProtoBuffer "Message" classes used by the API. **For typical usage of this SDK you do not need to use any of these classes.**
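For reference, a minimal, self-contained sketch of the `genai.protos` usage that the docstring above describes (illustrative only; it assumes the `google-generativeai` package from this change series, and uses only the `protos.Part`, `in`, `to_dict`, and dict-construction behaviors shown in the docstring examples):

```
# Minimal sketch based on the docstring above; assumes `google-generativeai`
# from this change series (which adds the `google.generativeai.protos` module).
from google.generativeai import protos

# Message classes accept plain Python structures and keyword arguments.
p = protos.Part(text="hello")
assert "text" in p  # `in` reports which oneof field is currently set.

p.inline_data = {"mime_type": "image/png", "data": b"PNG"}
assert "inline_data" in p and "text" not in p

# Convert to a JSON-compatible dict (bytes are base64 encoded) and back.
p_dict = type(p).to_dict(p)
p2 = protos.Part(p_dict)
assert "inline_data" in p2
```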