diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index e51ac7205..9415df2a8 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -62,17 +62,41 @@ This "editable" mode lets you edit the source without needing to reinstall the p ### Testing -Use the builtin unittest package: +To ensure the integrity of the codebase, we have a suite of tests located in the `generative-ai-python/tests` directory. +You can run all these tests using Python's built-in `unittest` module or the `pytest` library. + +For `unittest`, open a terminal and navigate to the root directory of the project. Then, execute the following command: + +``` +python -m unittest discover -s tests + +# or more simply +python -m unittest ``` - python -m unittest + +Alternatively, if you prefer using `pytest`, you can install it using pip: + ``` +pip install pytest +``` + +Then, run the tests with the following command: + +``` +pytest tests + +# or more simply +pytest +``` + Or to debug, use: ```commandline +pip install nose2 + nose2 --debugger -``` ### Type checking diff --git a/docs/build_docs.py b/docs/build_docs.py index eaa6a1ba4..012cd3441 100644 --- a/docs/build_docs.py +++ b/docs/build_docs.py @@ -44,77 +44,13 @@ # For showing the conditional imports and types in `content_types.py` # grpc must be imported first. typing.TYPE_CHECKING = True -from google import generativeai as palm - +from google import generativeai as genai from tensorflow_docs.api_generator import generate_lib from tensorflow_docs.api_generator import public_api import yaml -glm.__doc__ = """\ -This package, `google.ai.generativelanguage`, is a low-level auto-generated client library for the PaLM API. - -```posix-terminal -pip install google.ai.generativelanguage -``` - -It is built using the same tooling as Google Cloud client libraries, and will be quite familiar if you've used -those before. - -While we encourage Python users to access the PaLM API using the `google.generativeai` package (aka `palm`), -this lower level package is also available. - -Each method in the PaLM API is connected to one of the client classes. Pass your API-key to the class' `client_options` -when initializing a client: - -``` -from google.ai import generativelanguage as glm - -client = glm.DiscussServiceClient( - client_options={'api_key':'YOUR_API_KEY'}) -``` - -To call the api, pass an appropriate request-proto-object. For the `DiscussServiceClient.generate_message` pass -a `generativelanguage.GenerateMessageRequest` instance: - -``` -request = glm.GenerateMessageRequest( - model='models/chat-bison-001', - prompt=glm.MessagePrompt( - messages=[glm.Message(content='Hello!')])) - -client.generate_message(request) -``` -``` -candidates { - author: "1" - content: "Hello! How can I help you today?" -} -... -``` - -For simplicity: - -* The API methods also accept key-word arguments. -* Anywhere you might pass a proto-object, the library will also accept simple python structures. - -So the following is equivalent to the previous example: - -``` -client.generate_message( - model='models/chat-bison-001', - prompt={'messages':[{'content':'Hello!'}]}) -``` -``` -candidates { - author: "1" - content: "Hello! How can I help you today?" -} -... 
-``` -""" - HERE = pathlib.Path(__file__).parent PROJECT_SHORT_NAME = "genai" @@ -139,43 +75,6 @@ ) -class MyFilter: - def __init__(self, base_dirs): - self.filter_base_dirs = public_api.FilterBaseDirs(base_dirs) - - def drop_staticmethods(self, parent, children): - parent = dict(parent.__dict__) - for name, value in children: - if not isinstance(parent.get(name, None), staticmethod): - yield name, value - - def __call__(self, path, parent, children): - if any("generativelanguage" in part for part in path) or "generativeai" in path: - children = self.filter_base_dirs(path, parent, children) - children = public_api.explicit_package_contents_filter(path, parent, children) - - if any("generativelanguage" in part for part in path): - if "ServiceClient" in path[-1] or "ServiceAsyncClient" in path[-1]: - children = list(self.drop_staticmethods(parent, children)) - - return children - - -class MyDocGenerator(generate_lib.DocGenerator): - def make_default_filters(self): - return [ - # filter the api. - public_api.FailIfNestedTooDeep(10), - public_api.filter_module_all, - public_api.add_proto_fields, - public_api.filter_private_symbols, - MyFilter(self._base_dir), # Replaces: public_api.FilterBaseDirs(self._base_dir), - public_api.FilterPrivateMap(self._private_map), - public_api.filter_doc_controls_skip, - public_api.ignore_typing, - ] - - def gen_api_docs(): """Generates api docs for the generative-ai package.""" for name in dir(google): @@ -188,11 +87,11 @@ def gen_api_docs(): """ ) - doc_generator = MyDocGenerator( + doc_generator = generate_lib.DocGenerator( root_title=PROJECT_FULL_NAME, - py_modules=[("google", google)], + py_modules=[("google.generativeai", genai)], base_dir=( - pathlib.Path(palm.__file__).parent, + pathlib.Path(genai.__file__).parent, pathlib.Path(glm.__file__).parent.parent, ), code_url_prefix=( @@ -201,32 +100,12 @@ def gen_api_docs(): ), search_hints=_SEARCH_HINTS.value, site_path=_SITE_PATH.value, - callbacks=[], + callbacks=[public_api.explicit_package_contents_filter], ) out_path = pathlib.Path(_OUTPUT_DIR.value) doc_generator.build(out_path) - # Fixup the toc file. 
- toc_path = out_path / "google/_toc.yaml" - toc = yaml.safe_load(toc_path.read_text()) - assert toc["toc"][0]["title"] == "google" - toc["toc"] = toc["toc"][1:] - toc["toc"][0]["title"] = "google.ai.generativelanguage" - toc["toc"][0]["section"] = toc["toc"][0]["section"][1]["section"] - toc["toc"][0], toc["toc"][1] = toc["toc"][1], toc["toc"][0] - toc_path.write_text(yaml.dump(toc)) - - # remove some dummy files and redirect them to `api/` - (out_path / "google.md").unlink() - (out_path / "google/ai.md").unlink() - redirects_path = out_path / "_redirects.yaml" - redirects = {"redirects": []} - redirects["redirects"].insert(0, {"from": "/api/python/google/ai", "to": "/api/"}) - redirects["redirects"].insert(0, {"from": "/api/python/google", "to": "/api/"}) - redirects["redirects"].insert(0, {"from": "/api/python", "to": "/api/"}) - redirects_path.write_text(yaml.dump(redirects)) - # clear `oneof` junk from proto pages for fpath in out_path.rglob("*.md"): old_content = fpath.read_text() diff --git a/google/generativeai/__init__.py b/google/generativeai/__init__.py index 53383a1b3..2b93fc1ce 100644 --- a/google/generativeai/__init__.py +++ b/google/generativeai/__init__.py @@ -30,8 +30,8 @@ genai.configure(api_key=os.environ['API_KEY']) -model = genai.GenerativeModel(name='gemini-pro') -response = model.generate_content('Please summarise this document: ...') +model = genai.GenerativeModel(name='gemini-1.5-flash') +response = model.generate_content('Teach me about how an LLM works') print(response.text) ``` @@ -42,6 +42,7 @@ from google.generativeai import version +from google.generativeai import protos from google.generativeai import types from google.generativeai.types import GenerationConfig diff --git a/google/generativeai/answer.py b/google/generativeai/answer.py index f17a82a17..4bfabbf23 100644 --- a/google/generativeai/answer.py +++ b/google/generativeai/answer.py @@ -21,23 +21,22 @@ from typing_extensions import TypedDict import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai.client import ( get_default_generative_client, get_default_generative_async_client, ) -from google.generativeai import string_utils from google.generativeai.types import model_types -from google.generativeai import models +from google.generativeai.types import helper_types from google.generativeai.types import safety_types from google.generativeai.types import content_types -from google.generativeai.types import answer_types from google.generativeai.types import retriever_types from google.generativeai.types.retriever_types import MetadataFilter DEFAULT_ANSWER_MODEL = "models/aqa" -AnswerStyle = glm.GenerateAnswerRequest.AnswerStyle +AnswerStyle = protos.GenerateAnswerRequest.AnswerStyle AnswerStyleOptions = Union[int, str, AnswerStyle] @@ -68,33 +67,35 @@ def to_answer_style(x: AnswerStyleOptions) -> AnswerStyle: GroundingPassageOptions = ( - Union[glm.GroundingPassage, tuple[str, content_types.ContentType], content_types.ContentType], + Union[ + protos.GroundingPassage, tuple[str, content_types.ContentType], content_types.ContentType + ], ) GroundingPassagesOptions = Union[ - glm.GroundingPassages, + protos.GroundingPassages, Iterable[GroundingPassageOptions], Mapping[str, content_types.ContentType], ] -def _make_grounding_passages(source: GroundingPassagesOptions) -> glm.GroundingPassages: +def _make_grounding_passages(source: GroundingPassagesOptions) -> protos.GroundingPassages: """ - Converts the `source` into a `glm.GroundingPassage`. 
A `GroundingPassages` contains a list of - `glm.GroundingPassage` objects, which each contain a `glm.Contant` and a string `id`. + Converts the `source` into a `protos.GroundingPassage`. A `GroundingPassages` contains a list of + `protos.GroundingPassage` objects, which each contain a `protos.Contant` and a string `id`. Args: - source: `Content` or a `GroundingPassagesOptions` that will be converted to glm.GroundingPassages. + source: `Content` or a `GroundingPassagesOptions` that will be converted to protos.GroundingPassages. Return: - `glm.GroundingPassages` to be passed into `glm.GenerateAnswer`. + `protos.GroundingPassages` to be passed into `protos.GenerateAnswer`. """ - if isinstance(source, glm.GroundingPassages): + if isinstance(source, protos.GroundingPassages): return source if not isinstance(source, Iterable): raise TypeError( - f"`source` must be a valid `GroundingPassagesOptions` type object got a: `{type(source)}`." + f"Invalid input: The 'source' argument must be an instance of 'GroundingPassagesOptions'. Received a '{type(source).__name__}' object instead." ) passages = [] @@ -102,7 +103,7 @@ def _make_grounding_passages(source: GroundingPassagesOptions) -> glm.GroundingP source = source.items() for n, data in enumerate(source): - if isinstance(data, glm.GroundingPassage): + if isinstance(data, protos.GroundingPassage): passages.append(data) elif isinstance(data, tuple): id, content = data # tuple must have exactly 2 items. @@ -110,11 +111,11 @@ def _make_grounding_passages(source: GroundingPassagesOptions) -> glm.GroundingP else: passages.append({"id": str(n), "content": content_types.to_content(data)}) - return glm.GroundingPassages(passages=passages) + return protos.GroundingPassages(passages=passages) SourceNameType = Union[ - str, retriever_types.Corpus, glm.Corpus, retriever_types.Document, glm.Document + str, retriever_types.Corpus, protos.Corpus, retriever_types.Document, protos.Document ] @@ -129,7 +130,7 @@ class SemanticRetrieverConfigDict(TypedDict): SemanticRetrieverConfigOptions = Union[ SourceNameType, SemanticRetrieverConfigDict, - glm.SemanticRetrieverConfig, + protos.SemanticRetrieverConfig, ] @@ -137,7 +138,7 @@ def _maybe_get_source_name(source) -> str | None: if isinstance(source, str): return source elif isinstance( - source, (retriever_types.Corpus, glm.Corpus, retriever_types.Document, glm.Document) + source, (retriever_types.Corpus, protos.Corpus, retriever_types.Document, protos.Document) ): return source.name else: @@ -147,8 +148,8 @@ def _maybe_get_source_name(source) -> str | None: def _make_semantic_retriever_config( source: SemanticRetrieverConfigOptions, query: content_types.ContentsType, -) -> glm.SemanticRetrieverConfig: - if isinstance(source, glm.SemanticRetrieverConfig): +) -> protos.SemanticRetrieverConfig: + if isinstance(source, protos.SemanticRetrieverConfig): return source name = _maybe_get_source_name(source) @@ -158,9 +159,9 @@ def _make_semantic_retriever_config( source["source"] = _maybe_get_source_name(source["source"]) else: raise TypeError( - "Could create a `glm.SemanticRetrieverConfig` from:\n" - f" type: {type(source)}\n" - f" value: {source}" + f"Invalid input: Failed to create a 'protos.SemanticRetrieverConfig' from the provided source. 
" + f"Received type: {type(source).__name__}, " + f"Received value: {source}" ) if source["query"] is None: @@ -168,7 +169,7 @@ def _make_semantic_retriever_config( elif isinstance(source["query"], str): source["query"] = content_types.to_content(source["query"]) - return glm.SemanticRetrieverConfig(source) + return protos.SemanticRetrieverConfig(source) def _make_generate_answer_request( @@ -180,9 +181,9 @@ def _make_generate_answer_request( answer_style: AnswerStyle | None = None, safety_settings: safety_types.SafetySettingOptions | None = None, temperature: float | None = None, -) -> glm.GenerateAnswerRequest: +) -> protos.GenerateAnswerRequest: """ - Calls the API to generate a grounded answer from the model. + constructs a protos.GenerateAnswerRequest object by organizing the input parameters for the API call to generate a grounded answer from the model. Args: model: Name of the model used to generate the grounded response. @@ -190,43 +191,43 @@ def _make_generate_answer_request( single question to answer. For multi-turn queries, this is a repeated field that contains conversation history and the last `Content` in the list containing the question. inline_passages: Grounding passages (a list of `Content`-like objects or `(id, content)` pairs, - or a `glm.GroundingPassages`) to send inline with the request. Exclusive with `semantic_retreiver`, + or a `protos.GroundingPassages`) to send inline with the request. Exclusive with `semantic_retreiver`, one must be set, but not both. - semantic_retriever: A Corpus, Document, or `glm.SemanticRetrieverConfig` to use for grounding. Exclusive with + semantic_retriever: A Corpus, Document, or `protos.SemanticRetrieverConfig` to use for grounding. Exclusive with `inline_passages`, one must be set, but not both. answer_style: Style for grounded answers. safety_settings: Safety settings for generated output. temperature: The temperature for randomness in the output. Returns: - Call for glm.GenerateAnswerRequest(). + Call for protos.GenerateAnswerRequest(). """ model = model_types.make_model_name(model) contents = content_types.to_contents(contents) if safety_settings: - safety_settings = safety_types.normalize_safety_settings( - safety_settings, harm_category_set="new" - ) + safety_settings = safety_types.normalize_safety_settings(safety_settings) if inline_passages is not None and semantic_retriever is not None: raise ValueError( - "Either `inline_passages` or `semantic_retriever_config` must be set, not both." + f"Invalid configuration: Please set either 'inline_passages' or 'semantic_retriever_config', but not both. " + f"Received for inline_passages: {inline_passages}, and for semantic_retriever: {semantic_retriever}." ) elif inline_passages is not None: inline_passages = _make_grounding_passages(inline_passages) elif semantic_retriever is not None: semantic_retriever = _make_semantic_retriever_config(semantic_retriever, contents[-1]) else: - TypeError( - f"The source must be either an `inline_passages` xor `semantic_retriever_config`, but both are `None`" + raise TypeError( + f"Invalid configuration: Either 'inline_passages' or 'semantic_retriever_config' must be provided, but currently both are 'None'. " + f"Received for inline_passages: {inline_passages}, and for semantic_retriever: {semantic_retriever}." 
) if answer_style: answer_style = to_answer_style(answer_style) - return glm.GenerateAnswerRequest( + return protos.GenerateAnswerRequest( model=model, contents=contents, inline_passages=inline_passages, @@ -247,10 +248,9 @@ def generate_answer( safety_settings: safety_types.SafetySettingOptions | None = None, temperature: float | None = None, client: glm.GenerativeServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): - """ - Calls the GenerateAnswer API and returns a `types.Answer` containing the response. + """Calls the GenerateAnswer API and returns a `types.Answer` containing the response. You can pass a literal list of text chunks: @@ -276,9 +276,9 @@ def generate_answer( contents: The question to be answered by the model, grounded in the provided source. inline_passages: Grounding passages (a list of `Content`-like objects or (id, content) pairs, - or a `glm.GroundingPassages`) to send inline with the request. Exclusive with `semantic_retreiver`, + or a `protos.GroundingPassages`) to send inline with the request. Exclusive with `semantic_retreiver`, one must be set, but not both. - semantic_retriever: A Corpus, Document, or `glm.SemanticRetrieverConfig` to use for grounding. Exclusive with + semantic_retriever: A Corpus, Document, or `protos.SemanticRetrieverConfig` to use for grounding. Exclusive with `inline_passages`, one must be set, but not both. answer_style: Style in which the grounded answer should be returned. safety_settings: Safety settings for generated output. Defaults to None. @@ -320,7 +320,7 @@ async def generate_answer_async( safety_settings: safety_types.SafetySettingOptions | None = None, temperature: float | None = None, client: glm.GenerativeServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """ Calls the API and returns a `types.Answer` containing the answer. @@ -330,9 +330,9 @@ async def generate_answer_async( contents: The question to be answered by the model, grounded in the provided source. inline_passages: Grounding passages (a list of `Content`-like objects or (id, content) pairs, - or a `glm.GroundingPassages`) to send inline with the request. Exclusive with `semantic_retreiver`, + or a `protos.GroundingPassages`) to send inline with the request. Exclusive with `semantic_retreiver`, one must be set, but not both. - semantic_retriever: A Corpus, Document, or `glm.SemanticRetrieverConfig` to use for grounding. Exclusive with + semantic_retriever: A Corpus, Document, or `protos.SemanticRetrieverConfig` to use for grounding. Exclusive with `inline_passages`, one must be set, but not both. answer_style: Style in which the grounded answer should be returned. safety_settings: Safety settings for generated output. Defaults to None. 
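
The `answer.py` changes above rewire the grounded-answer helpers from `glm.*` to `protos.*` request types without changing the public call pattern. Below is a minimal, hypothetical sketch of that pattern, assuming an API key is already configured; the passages, question, and answer style are illustrative only.

```python
import google.generativeai as genai
from google.generativeai import answer as genai_answer

genai.configure(api_key="YOUR_API_KEY")  # assumed placeholder key

# Ground the model on a literal list of text chunks (inline passages).
response = genai_answer.generate_answer(
    model="models/aqa",  # DEFAULT_ANSWER_MODEL in this module
    contents="Which mountain is the tallest above sea level?",
    inline_passages=[
        "Mount Everest is Earth's highest mountain above sea level.",
        "K2 is the second-highest mountain on Earth.",
    ],
    answer_style="ABSTRACTIVE",  # converted through to_answer_style()
)
print(response)
```

Note that exactly one of `inline_passages` or `semantic_retriever` must be set; passing both (or neither) raises the errors added in this hunk.
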
diff --git a/google/generativeai/client.py b/google/generativeai/client.py index e8e91ae7e..40c2bdcaf 100644 --- a/google/generativeai/client.py +++ b/google/generativeai/client.py @@ -1,17 +1,20 @@ from __future__ import annotations import os +import contextlib import dataclasses import pathlib -import re import types from typing import Any, cast from collections.abc import Sequence import httplib2 import google.ai.generativelanguage as glm +import google.generativeai.protos as protos from google.auth import credentials as ga_credentials +from google.auth import exceptions as ga_exceptions +from google import auth from google.api_core import client_options as client_options_lib from google.api_core import gapic_v1 from google.api_core import operations_v1 @@ -30,6 +33,18 @@ GENAI_API_DISCOVERY_URL = "https://generativelanguage.googleapis.com/$discovery/rest" +@contextlib.contextmanager +def patch_colab_gce_credentials(): + get_gce = auth._default._get_gce_credentials + if "COLAB_RELEASE_TAG" in os.environ: + auth._default._get_gce_credentials = lambda *args, **kwargs: (None, None) + + try: + yield + finally: + auth._default._get_gce_credentials = get_gce + + class FileServiceClient(glm.FileServiceClient): def __init__(self, *args, **kwargs): self._discovery_api = None @@ -38,7 +53,9 @@ def __init__(self, *args, **kwargs): def _setup_discovery_api(self): api_key = self._client_options.api_key if api_key is None: - raise ValueError("Uploading to the File API requires an API key.") + raise ValueError( + "Invalid operation: Uploading to the File API requires an API key. Please provide a valid API key." + ) request = googleapiclient.http.HttpRequest( http=httplib2.Http(), @@ -60,7 +77,7 @@ def create_file( name: str | None = None, display_name: str | None = None, resumable: bool = True, - ) -> glm.File: + ) -> protos.File: if self._discovery_api is None: self._setup_discovery_api() @@ -81,7 +98,9 @@ def create_file( class FileServiceAsyncClient(glm.FileServiceAsyncClient): async def create_file(self, *args, **kwargs): - raise NotImplementedError("`create_file` is not yet implemented for the async client.") + raise NotImplementedError( + "The `create_file` method is currently not supported for the asynchronous client." + ) @dataclasses.dataclass @@ -109,7 +128,7 @@ def configure( client_info: gapic_v1.client_info.ClientInfo | None = None, default_metadata: Sequence[tuple[str, str]] = (), ) -> None: - """Captures default client configuration. + """Initializes default client configurations using specified parameters or environment variables. If no API key has been provided (either directly, or on `client_options`) and the `GOOGLE_API_KEY` environment variable is set, it will be used as the API key. @@ -135,7 +154,9 @@ def configure( if had_api_key_value: if api_key is not None: - raise ValueError("You can't set both `api_key` and `client_options['api_key']`.") + raise ValueError( + "Invalid configuration: Please set either `api_key` or `client_options['api_key']`, but not both." + ) else: if api_key is None: # If no key is provided explicitly, attempt to load one from the @@ -183,7 +204,17 @@ def make_client(self, name): if not self.client_config: configure() - client = cls(**self.client_config) + try: + with patch_colab_gce_credentials(): + client = cls(**self.client_config) + except ga_exceptions.DefaultCredentialsError as e: + e.args = ( + "\n No API_KEY or ADC found. 
Please either:\n" + " - Set the `GOOGLE_API_KEY` environment variable.\n" + " - Manually pass the key with `genai.configure(api_key=my_api_key)`.\n" + " - Or set up Application Default Credentials, see https://ai.google.dev/gemini-api/docs/oauth for more information.", + ) + raise e if not self.default_metadata: return client @@ -328,9 +359,9 @@ def get_default_retriever_async_client() -> glm.RetrieverAsyncClient: return _client_manager.get_default_client("retriever_async") -def get_dafault_permission_client() -> glm.PermissionServiceClient: +def get_default_permission_client() -> glm.PermissionServiceClient: return _client_manager.get_default_client("permission") -def get_dafault_permission_async_client() -> glm.PermissionServiceAsyncClient: +def get_default_permission_async_client() -> glm.PermissionServiceAsyncClient: return _client_manager.get_default_client("permission_async") diff --git a/google/generativeai/discuss.py b/google/generativeai/discuss.py index 0cc342096..448347b41 100644 --- a/google/generativeai/discuss.py +++ b/google/generativeai/discuss.py @@ -18,36 +18,38 @@ import sys import textwrap -from typing import Any, Iterable, List, Optional, Union +from typing import Iterable, List import google.ai.generativelanguage as glm from google.generativeai.client import get_default_discuss_client from google.generativeai.client import get_default_discuss_async_client from google.generativeai import string_utils +from google.generativeai import protos from google.generativeai.types import discuss_types +from google.generativeai.types import helper_types from google.generativeai.types import model_types -from google.generativeai.types import safety_types +from google.generativeai.types import palm_safety_types -def _make_message(content: discuss_types.MessageOptions) -> glm.Message: - """Creates a `glm.Message` object from the provided content.""" - if isinstance(content, glm.Message): +def _make_message(content: discuss_types.MessageOptions) -> protos.Message: + """Creates a `protos.Message` object from the provided content.""" + if isinstance(content, protos.Message): return content if isinstance(content, str): - return glm.Message(content=content) + return protos.Message(content=content) else: - return glm.Message(content) + return protos.Message(content) def _make_messages( messages: discuss_types.MessagesOptions, -) -> List[glm.Message]: +) -> List[protos.Message]: """ - Creates a list of `glm.Message` objects from the provided messages. + Creates a list of `protos.Message` objects from the provided messages. This function takes a variety of message content inputs, such as strings, dictionaries, - or `glm.Message` objects, and creates a list of `glm.Message` objects. It ensures that + or `protos.Message` objects, and creates a list of `protos.Message` objects. It ensures that the authors of the messages alternate appropriately. If authors are not provided, default authors are assigned based on their position in the list. @@ -55,9 +57,9 @@ def _make_messages( messages: The messages to convert. Returns: - A list of `glm.Message` objects with alternating authors. + A list of `protos.Message` objects with alternating authors. 
""" - if isinstance(messages, (str, dict, glm.Message)): + if isinstance(messages, (str, dict, protos.Message)): messages = [_make_message(messages)] else: messages = [_make_message(message) for message in messages] @@ -68,7 +70,9 @@ def _make_messages( elif len(even_authors) == 1: even_author = even_authors.pop() else: - raise discuss_types.AuthorError("Authors are not strictly alternating") + raise discuss_types.AuthorError( + "Invalid sequence: Authors in the discussion must alternate strictly." + ) odd_authors = set(msg.author for msg in messages[1::2] if msg.author) if not odd_authors: @@ -76,7 +80,9 @@ def _make_messages( elif len(odd_authors) == 1: odd_author = odd_authors.pop() else: - raise discuss_types.AuthorError("Authors are not strictly alternating") + raise discuss_types.AuthorError( + "Invalid sequence: Authors in the discussion must alternate strictly." + ) if all(msg.author for msg in messages): return messages @@ -88,39 +94,39 @@ def _make_messages( return messages -def _make_example(item: discuss_types.ExampleOptions) -> glm.Example: - """Creates a `glm.Example` object from the provided item.""" - if isinstance(item, glm.Example): +def _make_example(item: discuss_types.ExampleOptions) -> protos.Example: + """Creates a `protos.Example` object from the provided item.""" + if isinstance(item, protos.Example): return item if isinstance(item, dict): item = item.copy() item["input"] = _make_message(item["input"]) item["output"] = _make_message(item["output"]) - return glm.Example(item) + return protos.Example(item) if isinstance(item, Iterable): input, output = list(item) - return glm.Example(input=_make_message(input), output=_make_message(output)) + return protos.Example(input=_make_message(input), output=_make_message(output)) # try anyway - return glm.Example(item) + return protos.Example(item) def _make_examples_from_flat( examples: List[discuss_types.MessageOptions], -) -> List[glm.Example]: +) -> List[protos.Example]: """ - Creates a list of `glm.Example` objects from a list of message options. + Creates a list of `protos.Example` objects from a list of message options. This function takes a list of `discuss_types.MessageOptions` and pairs them into - `glm.Example` objects. The input examples must be in pairs to create valid examples. + `protos.Example` objects. The input examples must be in pairs to create valid examples. Args: examples: The list of `discuss_types.MessageOptions`. Returns: - A list of `glm.Example objects` created by pairing up the provided messages. + A list of `protos.Example objects` created by pairing up the provided messages. Raises: ValueError: If the provided list of examples is not of even length. @@ -129,8 +135,8 @@ def _make_examples_from_flat( raise ValueError( textwrap.dedent( f"""\ - You must pass `Primer` objects, pairs of messages, or an *even* number of messages, got: - {len(examples)} messages""" + Invalid input: You must pass either `Primer` objects, pairs of messages, or an even number of messages. + Currently, {len(examples)} messages were provided, which is an odd number.""" ) ) result = [] @@ -140,7 +146,7 @@ def _make_examples_from_flat( pair.append(msg) if n % 2 == 0: continue - primer = glm.Example( + primer = protos.Example( input=pair[0], output=pair[1], ) @@ -151,21 +157,21 @@ def _make_examples_from_flat( def _make_examples( examples: discuss_types.ExamplesOptions, -) -> List[glm.Example]: +) -> List[protos.Example]: """ - Creates a list of `glm.Example` objects from the provided examples. 
+ Creates a list of `protos.Example` objects from the provided examples. This function takes various types of example content inputs and creates a list - of `glm.Example` objects. It handles the conversion of different input types and ensures + of `protos.Example` objects. It handles the conversion of different input types and ensures the appropriate structure for creating valid examples. Args: examples: The examples to convert. Returns: - A list of `glm.Example` objects created from the provided examples. + A list of `protos.Example` objects created from the provided examples. """ - if isinstance(examples, glm.Example): + if isinstance(examples, protos.Example): return [examples] if isinstance(examples, dict): @@ -185,7 +191,7 @@ def _make_examples( else: if not ("input" in first and "output" in first): raise TypeError( - "To create an `Example` from a dict you must supply both `input` and an `output` keys" + "Invalid dictionary format: To create an `Example` instance, the dictionary must contain both `input` and `output` keys." ) else: if isinstance(first, discuss_types.MESSAGE_OPTIONS): @@ -203,11 +209,11 @@ def _make_message_prompt_dict( context: str | None = None, examples: discuss_types.ExamplesOptions | None = None, messages: discuss_types.MessagesOptions | None = None, -) -> glm.MessagePrompt: +) -> protos.MessagePrompt: """ - Creates a `glm.MessagePrompt` object from the provided prompt components. + Creates a `protos.MessagePrompt` object from the provided prompt components. - This function constructs a `glm.MessagePrompt` object using the provided `context`, `examples`, + This function constructs a `protos.MessagePrompt` object using the provided `context`, `examples`, or `messages`. It ensures the proper structure and handling of the input components. Either pass a `prompt` or it's component `context`, `examples`, `messages`. @@ -219,7 +225,7 @@ def _make_message_prompt_dict( messages: The messages for the prompt. Returns: - A `glm.MessagePrompt` object created from the provided prompt components. + A `protos.MessagePrompt` object created from the provided prompt components. """ if prompt is None: prompt = dict( @@ -231,10 +237,9 @@ def _make_message_prompt_dict( flat_prompt = (context is not None) or (examples is not None) or (messages is not None) if flat_prompt: raise ValueError( - "You can't set `prompt`, and its fields `(context, examples, messages)`" - " at the same time" + "Invalid configuration: Either `prompt` or its fields `(context, examples, messages)` should be set, but not both simultaneously." ) - if isinstance(prompt, glm.MessagePrompt): + if isinstance(prompt, protos.MessagePrompt): return prompt elif isinstance(prompt, dict): # Always check dict before Iterable. pass @@ -244,7 +249,7 @@ def _make_message_prompt_dict( keys = set(prompt.keys()) if not keys.issubset(discuss_types.MESSAGE_PROMPT_KEYS): raise KeyError( - f"Found extra entries in the prompt dictionary: {keys - discuss_types.MESSAGE_PROMPT_KEYS}" + f"Invalid prompt dictionary: Extra entries found that are not recognized: {keys - discuss_types.MESSAGE_PROMPT_KEYS}. Please check the keys." 
) examples = prompt.get("examples", None) @@ -264,12 +269,12 @@ def _make_message_prompt( context: str | None = None, examples: discuss_types.ExamplesOptions | None = None, messages: discuss_types.MessagesOptions | None = None, -) -> glm.MessagePrompt: - """Creates a `glm.MessagePrompt` object from the provided prompt components.""" +) -> protos.MessagePrompt: + """Creates a `protos.MessagePrompt` object from the provided prompt components.""" prompt = _make_message_prompt_dict( prompt=prompt, context=context, examples=examples, messages=messages ) - return glm.MessagePrompt(prompt) + return protos.MessagePrompt(prompt) def _make_generate_message_request( @@ -283,15 +288,15 @@ def _make_generate_message_request( top_p: float | None = None, top_k: float | None = None, prompt: discuss_types.MessagePromptOptions | None = None, -) -> glm.GenerateMessageRequest: - """Creates a `glm.GenerateMessageRequest` object for generating messages.""" +) -> protos.GenerateMessageRequest: + """Creates a `protos.GenerateMessageRequest` object for generating messages.""" model = model_types.make_model_name(model) prompt = _make_message_prompt( prompt=prompt, context=context, examples=examples, messages=messages ) - return glm.GenerateMessageRequest( + return protos.GenerateMessageRequest( model=model, prompt=prompt, temperature=temperature, @@ -316,9 +321,9 @@ def chat( top_k: float | None = None, prompt: discuss_types.MessagePromptOptions | None = None, client: glm.DiscussServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> discuss_types.ChatResponse: - """Calls the API and returns a `types.ChatResponse` containing the response. + """Calls the API to initiate a chat with a model using provided parameters Args: model: Which model to call, as a string or a `types.Model`. @@ -416,8 +421,9 @@ async def chat_async( top_k: float | None = None, prompt: discuss_types.MessagePromptOptions | None = None, client: glm.DiscussServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> discuss_types.ChatResponse: + """Calls the API asynchronously to initiate a chat with a model using provided parameters""" request = _make_generate_message_request( model=model, context=context, @@ -469,15 +475,16 @@ def last(self, message: discuss_types.MessageOptions): def reply( self, message: discuss_types.MessageOptions, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> discuss_types.ChatResponse: if isinstance(self._client, glm.DiscussServiceAsyncClient): - raise TypeError(f"reply can't be called on an async client, use reply_async instead.") + raise TypeError( + "Invalid operation: The 'reply' method cannot be called on an asynchronous client. Please use the 'reply_async' method instead." + ) if self.last is None: raise ValueError( - "The last response from the model did not return any candidates.\n" - "Check the `.filters` attribute to see why the responses were filtered:\n" - f"{self.filters}" + f"Invalid operation: No candidates returned from the model's last response. " + f"Please inspect the '.filters' attribute to understand why responses were filtered out. 
Current filters: {self.filters}" ) request = self.to_dict() @@ -496,7 +503,7 @@ async def reply_async( ) -> discuss_types.ChatResponse: if isinstance(self._client, glm.DiscussServiceClient): raise TypeError( - f"reply_async can't be called on a non-async client, use reply instead." + "Invalid method call: `reply_async` is not supported on a non-async client. Please use the `reply` method instead." ) request = self.to_dict() request.pop("candidates") @@ -508,9 +515,9 @@ async def reply_async( def _build_chat_response( - request: glm.GenerateMessageRequest, - response: glm.GenerateMessageResponse, - client: glm.DiscussServiceClient | glm.DiscussServiceAsyncClient, + request: protos.GenerateMessageRequest, + response: protos.GenerateMessageResponse, + client: glm.DiscussServiceClient | protos.DiscussServiceAsyncClient, ) -> ChatResponse: request = type(request).to_dict(request) prompt = request.pop("prompt") @@ -521,7 +528,7 @@ def _build_chat_response( response = type(response).to_dict(response) response.pop("messages") - response["filters"] = safety_types.convert_filters_to_enums(response["filters"]) + response["filters"] = palm_safety_types.convert_filters_to_enums(response["filters"]) if response["candidates"]: last = response["candidates"][0] @@ -535,9 +542,9 @@ def _build_chat_response( def _generate_response( - request: glm.GenerateMessageRequest, + request: protos.GenerateMessageRequest, client: glm.DiscussServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> ChatResponse: if request_options is None: request_options = {} @@ -551,9 +558,9 @@ def _generate_response( async def _generate_response_async( - request: glm.GenerateMessageRequest, + request: protos.GenerateMessageRequest, client: glm.DiscussServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> ChatResponse: if request_options is None: request_options = {} @@ -574,8 +581,10 @@ def count_message_tokens( messages: discuss_types.MessagesOptions | None = None, model: model_types.AnyModelNameOptions = DEFAULT_DISCUSS_MODEL, client: glm.DiscussServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> discuss_types.TokenCount: + """Calls the API to calculate the number of tokens used in the prompt.""" + model = model_types.make_model_name(model) prompt = _make_message_prompt(prompt, context=context, examples=examples, messages=messages) diff --git a/google/generativeai/embedding.py b/google/generativeai/embedding.py index 375d5dcb4..616fa07bf 100644 --- a/google/generativeai/embedding.py +++ b/google/generativeai/embedding.py @@ -14,17 +14,16 @@ # limitations under the License. 
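
The `discuss.py` edits above keep the legacy PaLM chat helpers (`_make_messages`, `_make_examples`, `chat`, `reply`) working against the new `protos` request types. A rough sketch of the flow those helpers serve, assuming the classic `genai.chat` entry point and the `models/chat-bison-001` model named earlier in this diff:

```python
import google.generativeai as genai

genai.configure(api_key="YOUR_API_KEY")  # assumed placeholder key

response = genai.chat(
    model="models/chat-bison-001",
    context="You are a terse assistant.",
    messages=["Hello!"],          # strings are wrapped into protos.Message objects
)
print(response.last)              # text of the latest model reply

response = response.reply("And what can you help me with?")  # continues the conversation
print(response.last)
```
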
from __future__ import annotations -import dataclasses -from collections.abc import Iterable, Sequence, Mapping import itertools from typing import Any, Iterable, overload, TypeVar, Union, Mapping import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai.client import get_default_generative_client from google.generativeai.client import get_default_generative_async_client -from google.generativeai import string_utils +from google.generativeai.types import helper_types from google.generativeai.types import text_types from google.generativeai.types import model_types from google.generativeai.types import content_types @@ -32,7 +31,7 @@ DEFAULT_EMB_MODEL = "models/embedding-001" EMBEDDING_MAX_BATCH_SIZE = 100 -EmbeddingTaskType = glm.TaskType +EmbeddingTaskType = protos.TaskType EmbeddingTaskTypeOptions = Union[int, str, EmbeddingTaskType] @@ -84,7 +83,9 @@ def to_task_type(x: EmbeddingTaskTypeOptions) -> EmbeddingTaskType: def _batched(iterable: Iterable[T], n: int) -> Iterable[list[T]]: if n < 1: - raise ValueError(f"Batch size `n` must be >0, got: {n}") + raise ValueError( + f"Invalid input: The batch size 'n' must be a positive integer. You entered: {n}. Please enter a number greater than 0." + ) batch = [] for item in iterable: batch.append(item) @@ -104,7 +105,7 @@ def embed_content( title: str | None = None, output_dimensionality: int | None = None, client: glm.GenerativeServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.EmbeddingDict: ... @@ -116,7 +117,7 @@ def embed_content( title: str | None = None, output_dimensionality: int | None = None, client: glm.GenerativeServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.BatchEmbeddingDict: ... @@ -127,7 +128,7 @@ def embed_content( title: str | None = None, output_dimensionality: int | None = None, client: glm.GenerativeServiceClient = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict: """Calls the API to create embeddings for content passed in. @@ -169,11 +170,13 @@ def embed_content( if title and to_task_type(task_type) is not EmbeddingTaskType.RETRIEVAL_DOCUMENT: raise ValueError( - "If a title is specified, the task must be a retrieval document type task." + f"Invalid task type: When a title is specified, the task must be of a 'retrieval document' type. Received task type: {task_type} and title: {title}." ) if output_dimensionality and output_dimensionality < 0: - raise ValueError("`output_dimensionality` must be a non-negative integer.") + raise ValueError( + f"Invalid value: `output_dimensionality` must be a non-negative integer. Received: {output_dimensionality}." 
+ ) if task_type: task_type = to_task_type(task_type) @@ -181,7 +184,7 @@ def embed_content( if isinstance(content, Iterable) and not isinstance(content, (str, Mapping)): result = {"embedding": []} requests = ( - glm.EmbedContentRequest( + protos.EmbedContentRequest( model=model, content=content_types.to_content(c), task_type=task_type, @@ -191,7 +194,7 @@ def embed_content( for c in content ) for batch in _batched(requests, EMBEDDING_MAX_BATCH_SIZE): - embedding_request = glm.BatchEmbedContentsRequest(model=model, requests=batch) + embedding_request = protos.BatchEmbedContentsRequest(model=model, requests=batch) embedding_response = client.batch_embed_contents( embedding_request, **request_options, @@ -200,7 +203,7 @@ def embed_content( result["embedding"].extend(e["values"] for e in embedding_dict["embeddings"]) return result else: - embedding_request = glm.EmbedContentRequest( + embedding_request = protos.EmbedContentRequest( model=model, content=content_types.to_content(content), task_type=task_type, @@ -224,7 +227,7 @@ async def embed_content_async( title: str | None = None, output_dimensionality: int | None = None, client: glm.GenerativeServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.EmbeddingDict: ... @@ -236,7 +239,7 @@ async def embed_content_async( title: str | None = None, output_dimensionality: int | None = None, client: glm.GenerativeServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.BatchEmbeddingDict: ... @@ -247,9 +250,10 @@ async def embed_content_async( title: str | None = None, output_dimensionality: int | None = None, client: glm.GenerativeServiceAsyncClient = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict: - """The async version of `genai.embed_content`.""" + """Calls the API to create async embeddings for content passed in.""" + model = model_types.make_model_name(model) if request_options is None: @@ -260,11 +264,12 @@ async def embed_content_async( if title and to_task_type(task_type) is not EmbeddingTaskType.RETRIEVAL_DOCUMENT: raise ValueError( - "If a title is specified, the task must be a retrieval document type task." + f"Invalid task type: When a title is specified, the task must be of a 'retrieval document' type. Received task type: {task_type} and title: {title}." ) - if output_dimensionality and output_dimensionality < 0: - raise ValueError("`output_dimensionality` must be a non-negative integer.") + raise ValueError( + f"Invalid value: `output_dimensionality` must be a non-negative integer. Received: {output_dimensionality}." 
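
The `embedding.py` changes above build `protos.EmbedContentRequest` objects and transparently batch iterable inputs into `BatchEmbedContentsRequest`s of at most `EMBEDDING_MAX_BATCH_SIZE` (100) items. A small illustrative sketch of the public call, with placeholder text:

```python
import google.generativeai as genai

genai.configure(api_key="YOUR_API_KEY")  # assumed placeholder key

# Single piece of content -> {"embedding": [...]}
single = genai.embed_content(
    model="models/embedding-001",          # DEFAULT_EMB_MODEL in this module
    content="What is the meaning of life?",
    task_type="retrieval_document",
    title="Illustrative title",            # titles require a retrieval-document task type
)

# An iterable of contents is batched behind the scenes -> {"embedding": [[...], [...]]}
batch = genai.embed_content(
    model="models/embedding-001",
    content=["First passage.", "Second passage."],
)
print(len(single["embedding"]), len(batch["embedding"]))
```
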
+ ) if task_type: task_type = to_task_type(task_type) @@ -272,7 +277,7 @@ async def embed_content_async( if isinstance(content, Iterable) and not isinstance(content, (str, Mapping)): result = {"embedding": []} requests = ( - glm.EmbedContentRequest( + protos.EmbedContentRequest( model=model, content=content_types.to_content(c), task_type=task_type, @@ -282,7 +287,7 @@ async def embed_content_async( for c in content ) for batch in _batched(requests, EMBEDDING_MAX_BATCH_SIZE): - embedding_request = glm.BatchEmbedContentsRequest(model=model, requests=batch) + embedding_request = protos.BatchEmbedContentsRequest(model=model, requests=batch) embedding_response = await client.batch_embed_contents( embedding_request, **request_options, @@ -291,7 +296,7 @@ async def embed_content_async( result["embedding"].extend(e["values"] for e in embedding_dict["embeddings"]) return result else: - embedding_request = glm.EmbedContentRequest( + embedding_request = protos.EmbedContentRequest( model=model, content=content_types.to_content(content), task_type=task_type, diff --git a/google/generativeai/files.py b/google/generativeai/files.py index 13535c47f..4028d37f7 100644 --- a/google/generativeai/files.py +++ b/google/generativeai/files.py @@ -19,7 +19,7 @@ import mimetypes from typing import Iterable import logging -import google.ai.generativelanguage as glm +from google.generativeai import protos from itertools import islice from google.generativeai.types import file_types @@ -37,7 +37,7 @@ def upload_file( display_name: str | None = None, resumable: bool = True, ) -> file_types.File: - """Uploads a file using a supported file service. + """Calls the API to upload a file using a supported file service. Args: path: The path to the file to be uploaded. @@ -73,21 +73,24 @@ def upload_file( def list_files(page_size=100) -> Iterable[file_types.File]: + """Calls the API to list files using a supported file service.""" client = get_default_file_client() - response = client.list_files(glm.ListFilesRequest(page_size=page_size)) + response = client.list_files(protos.ListFilesRequest(page_size=page_size)) for proto in response: yield file_types.File(proto) def get_file(name) -> file_types.File: + """Calls the API to retrieve a specified file using a supported file service.""" client = get_default_file_client() return file_types.File(client.get_file(name=name)) def delete_file(name): - if isinstance(name, (file_types.File, glm.File)): + """Calls the API to permanently delete a specified file using a supported file service.""" + if isinstance(name, (file_types.File, protos.File)): name = name.name - request = glm.DeleteFileRequest(name=name) + request = protos.DeleteFileRequest(name=name) client = get_default_file_client() client.delete_file(request=request) diff --git a/google/generativeai/generative_models.py b/google/generativeai/generative_models.py index a0e7df1e2..7d69ae8f9 100644 --- a/google/generativeai/generative_models.py +++ b/google/generativeai/generative_models.py @@ -3,21 +3,19 @@ from __future__ import annotations from collections.abc import Iterable -import dataclasses import textwrap from typing import Any -from typing import Union import reprlib # pylint: disable=bad-continuation, line-too-long import google.api_core.exceptions -from google.ai import generativelanguage as glm +from google.generativeai import protos from google.generativeai import client -from google.generativeai import string_utils from google.generativeai.types import content_types from google.generativeai.types import 
generation_types +from google.generativeai.types import helper_types from google.generativeai.types import safety_types @@ -79,9 +77,7 @@ def __init__( if "/" not in model_name: model_name = "models/" + model_name self._model_name = model_name - self._safety_settings = safety_types.to_easy_safety_dict( - safety_settings, harm_category_set="new" - ) + self._safety_settings = safety_types.to_easy_safety_dict(safety_settings) self._generation_config = generation_types.to_generation_config_dict(generation_config) self._tools = content_types.to_function_library(tools) @@ -129,11 +125,8 @@ def _prepare_request( safety_settings: safety_types.SafetySettingOptions | None = None, tools: content_types.FunctionLibraryType | None, tool_config: content_types.ToolConfigType | None, - ) -> glm.GenerateContentRequest: - """Creates a `glm.GenerateContentRequest` from raw inputs.""" - if not contents: - raise TypeError("contents must not be empty") - + ) -> protos.GenerateContentRequest: + """Creates a `protos.GenerateContentRequest` from raw inputs.""" tools_lib = self._get_tools_lib(tools) if tools_lib is not None: tools_lib = tools_lib.to_proto() @@ -149,12 +142,12 @@ def _prepare_request( merged_gc = self._generation_config.copy() merged_gc.update(generation_config) - safety_settings = safety_types.to_easy_safety_dict(safety_settings, harm_category_set="new") + safety_settings = safety_types.to_easy_safety_dict(safety_settings) merged_ss = self._safety_settings.copy() merged_ss.update(safety_settings) - merged_ss = safety_types.normalize_safety_settings(merged_ss, harm_category_set="new") + merged_ss = safety_types.normalize_safety_settings(merged_ss) - return glm.GenerateContentRequest( + return protos.GenerateContentRequest( model=self._model_name, contents=contents, generation_config=merged_gc, @@ -181,7 +174,7 @@ def generate_content( stream: bool = False, tools: content_types.FunctionLibraryType | None = None, tool_config: content_types.ToolConfigType | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> generation_types.GenerateContentResponse: """A multipurpose function to generate responses from the model. @@ -216,27 +209,30 @@ def generate_content( ### Input type flexibility - While the underlying API strictly expects a `list[glm.Content]` objects, this method + While the underlying API strictly expects a `list[protos.Content]` objects, this method will convert the user input into the correct type. The hierarchy of types that can be converted is below. Any of these objects can be passed as an equivalent `dict`. - * `Iterable[glm.Content]` - * `glm.Content` - * `Iterable[glm.Part]` - * `glm.Part` - * `str`, `Image`, or `glm.Blob` + * `Iterable[protos.Content]` + * `protos.Content` + * `Iterable[protos.Part]` + * `protos.Part` + * `str`, `Image`, or `protos.Blob` - In an `Iterable[glm.Content]` each `content` is a separate message. - But note that an `Iterable[glm.Part]` is taken as the parts of a single message. + In an `Iterable[protos.Content]` each `content` is a separate message. + But note that an `Iterable[protos.Part]` is taken as the parts of a single message. Arguments: contents: The contents serving as the model's prompt. generation_config: Overrides for the model's generation config. safety_settings: Overrides for the model's safety settings. stream: If True, yield response chunks as they are generated. - tools: `glm.Tools` more info coming soon. + tools: `protos.Tools` more info coming soon. 
request_options: Options for the request. """ + if not contents: + raise TypeError("contents must not be empty") + request = self._prepare_request( contents=contents, generation_config=generation_config, @@ -267,8 +263,8 @@ def generate_content( except google.api_core.exceptions.InvalidArgument as e: if e.message.startswith("Request payload size exceeds the limit:"): e.message += ( - " Please upload your files with the File API instead." - "`f = genai.upload_file(path); m.generate_content(['tell me about this file:', f])`" + " The file size is too large. Please use the File API to upload your files instead. " + "Example: `f = genai.upload_file(path); m.generate_content(['tell me about this file:', f])`" ) raise @@ -281,9 +277,12 @@ async def generate_content_async( stream: bool = False, tools: content_types.FunctionLibraryType | None = None, tool_config: content_types.ToolConfigType | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> generation_types.AsyncGenerateContentResponse: """The async version of `GenerativeModel.generate_content`.""" + if not contents: + raise TypeError("contents must not be empty") + request = self._prepare_request( contents=contents, generation_config=generation_config, @@ -314,8 +313,8 @@ async def generate_content_async( except google.api_core.exceptions.InvalidArgument as e: if e.message.startswith("Request payload size exceeds the limit:"): e.message += ( - " Please upload your files with the File API instead." - "`f = genai.upload_file(path); m.generate_content(['tell me about this file:', f])`" + " The file size is too large. Please use the File API to upload your files instead. " + "Example: `f = genai.upload_file(path); m.generate_content(['tell me about this file:', f])`" ) raise @@ -328,15 +327,15 @@ def count_tokens( safety_settings: safety_types.SafetySettingOptions | None = None, tools: content_types.FunctionLibraryType | None = None, tool_config: content_types.ToolConfigType | None = None, - request_options: dict[str, Any] | None = None, - ) -> glm.CountTokensResponse: + request_options: helper_types.RequestOptionsType | None = None, + ) -> protos.CountTokensResponse: if request_options is None: request_options = {} if self._client is None: self._client = client.get_default_generative_client() - request = glm.CountTokensRequest( + request = protos.CountTokensRequest( model=self.model_name, generate_content_request=self._prepare_request( contents=contents, @@ -355,15 +354,15 @@ async def count_tokens_async( safety_settings: safety_types.SafetySettingOptions | None = None, tools: content_types.FunctionLibraryType | None = None, tool_config: content_types.ToolConfigType | None = None, - request_options: dict[str, Any] | None = None, - ) -> glm.CountTokensResponse: + request_options: helper_types.RequestOptionsType | None = None, + ) -> protos.CountTokensResponse: if request_options is None: request_options = {} if self._async_client is None: self._async_client = client.get_default_generative_async_client() - request = glm.CountTokensRequest( + request = protos.CountTokensRequest( model=self.model_name, generate_content_request=self._prepare_request( contents=contents, @@ -389,10 +388,12 @@ def start_chat( >>> response = chat.send_message("Hello?") Arguments: - history: An iterable of `glm.Content` objects, or equvalents to initialize the session. + history: An iterable of `protos.Content` objects, or equivalents to initialize the session. 
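
The `generate_content` docstring above lists the content types that are converted automatically (`str`, `Image`, `protos.Blob`, `protos.Part`, `protos.Content`, and iterables of these). A brief sketch of that flexibility, using a hypothetical local image file and the `count_tokens` helper changed in this same hunk:

```python
import google.generativeai as genai
import PIL.Image

genai.configure(api_key="YOUR_API_KEY")      # assumed placeholder key
model = genai.GenerativeModel("gemini-1.5-flash")

image = PIL.Image.open("chart.png")          # hypothetical local file

# A list of parts (text + image) is treated as a single multi-part message.
response = model.generate_content(["Describe this chart:", image])
print(response.text)

# The same flexible inputs work for token accounting.
print(model.count_tokens(["Describe this chart:", image]))
```
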
""" if self._generation_config.get("candidate_count", 1) > 1: - raise ValueError("Can't chat with `candidate_count > 1`") + raise ValueError( + "Invalid configuration: The chat functionality does not support `candidate_count` greater than 1." + ) return ChatSession( model=self, history=history, @@ -403,11 +404,13 @@ def start_chat( class ChatSession: """Contains an ongoing conversation with the model. - >>> model = genai.GenerativeModel(model="gemini-pro") + >>> model = genai.GenerativeModel('models/gemini-pro') >>> chat = model.start_chat() >>> response = chat.send_message("Hello") >>> print(response.text) - >>> response = chat.send_message(...) + >>> response = chat.send_message("Hello again") + >>> print(response.text) + >>> response = chat.send_message(... This `ChatSession` object collects the messages sent and received, in its `ChatSession.history` attribute. @@ -427,8 +430,8 @@ def __init__( enable_automatic_function_calling: bool = False, ): self.model: GenerativeModel = model - self._history: list[glm.Content] = content_types.to_contents(history) - self._last_sent: glm.Content | None = None + self._history: list[protos.Content] = content_types.to_contents(history) + self._last_sent: protos.Content | None = None self._last_received: generation_types.BaseGenerateContentResponse | None = None self.enable_automatic_function_calling = enable_automatic_function_calling @@ -441,12 +444,13 @@ def send_message( stream: bool = False, tools: content_types.FunctionLibraryType | None = None, tool_config: content_types.ToolConfigType | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> generation_types.GenerateContentResponse: """Sends the conversation history with the added message and returns the model's response. Appends the request and response to the conversation history. - >>> model = genai.GenerativeModel(model="gemini-pro") + >>> model = genai.GenerativeModel('models/gemini-pro') >>> chat = model.start_chat() >>> response = chat.send_message("Hello") >>> print(response.text) @@ -473,10 +477,12 @@ def send_message( safety_settings: Overrides for the model's safety settings. stream: If True, yield response chunks as they are generated. """ + if request_options is None: + request_options = {} + if self.enable_automatic_function_calling and stream: raise NotImplementedError( - "The `google.generativeai` SDK does not yet support `stream=True` with " - "`enable_automatic_function_calling=True`" + "Unsupported configuration: The `google.generativeai` SDK currently does not support the combination of `stream=True` and `enable_automatic_function_calling=True`." ) tools_lib = self.model._get_tools_lib(tools) @@ -491,7 +497,9 @@ def send_message( generation_config = generation_types.to_generation_config_dict(generation_config) if generation_config.get("candidate_count", 1) > 1: - raise ValueError("Can't chat with `candidate_count > 1`") + raise ValueError( + "Invalid configuration: The chat functionality does not support `candidate_count` greater than 1." 
+ ) response = self.model.generate_content( contents=history, @@ -500,6 +508,7 @@ def send_message( stream=stream, tools=tools_lib, tool_config=tool_config, + request_options=request_options, ) self._check_response(response=response, stream=stream) @@ -512,6 +521,7 @@ def send_message( safety_settings=safety_settings, stream=stream, tools_lib=tools_lib, + request_options=request_options, ) self._last_sent = content @@ -525,41 +535,49 @@ def _check_response(self, *, response, stream): if not stream: if response.candidates[0].finish_reason not in ( - glm.Candidate.FinishReason.FINISH_REASON_UNSPECIFIED, - glm.Candidate.FinishReason.STOP, - glm.Candidate.FinishReason.MAX_TOKENS, + protos.Candidate.FinishReason.FINISH_REASON_UNSPECIFIED, + protos.Candidate.FinishReason.STOP, + protos.Candidate.FinishReason.MAX_TOKENS, ): raise generation_types.StopCandidateException(response.candidates[0]) - def _get_function_calls(self, response) -> list[glm.FunctionCall]: + def _get_function_calls(self, response) -> list[protos.FunctionCall]: candidates = response.candidates if len(candidates) != 1: raise ValueError( - f"Automatic function calling only works with 1 candidate, got: {len(candidates)}" + f"Invalid number of candidates: Automatic function calling only works with 1 candidate, but {len(candidates)} were provided." ) parts = candidates[0].content.parts function_calls = [part.function_call for part in parts if part and "function_call" in part] return function_calls def _handle_afc( - self, *, response, history, generation_config, safety_settings, stream, tools_lib - ) -> tuple[list[glm.Content], glm.Content, generation_types.BaseGenerateContentResponse]: + self, + *, + response, + history, + generation_config, + safety_settings, + stream, + tools_lib, + request_options, + ) -> tuple[list[protos.Content], protos.Content, generation_types.BaseGenerateContentResponse]: while function_calls := self._get_function_calls(response): if not all(callable(tools_lib[fc]) for fc in function_calls): break history.append(response.candidates[0].content) - function_response_parts: list[glm.Part] = [] + function_response_parts: list[protos.Part] = [] for fc in function_calls: fr = tools_lib(fc) assert fr is not None, ( - "This should never happen, it should only return None if the declaration" - "is not callable, and that's guarded against above." + "Unexpected state: The function reference (fr) should never be None. It should only return None if the declaration " + "is not callable, which is checked earlier in the code." 
) function_response_parts.append(fr) - send = glm.Content(role=self._USER_ROLE, parts=function_response_parts) + send = protos.Content(role=self._USER_ROLE, parts=function_response_parts) history.append(send) response = self.model.generate_content( @@ -568,6 +586,7 @@ def _handle_afc( safety_settings=safety_settings, stream=stream, tools=tools_lib, + request_options=request_options, ) self._check_response(response=response, stream=stream) @@ -584,12 +603,15 @@ async def send_message_async( stream: bool = False, tools: content_types.FunctionLibraryType | None = None, tool_config: content_types.ToolConfigType | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> generation_types.AsyncGenerateContentResponse: """The async version of `ChatSession.send_message`.""" + if request_options is None: + request_options = {} + if self.enable_automatic_function_calling and stream: raise NotImplementedError( - "The `google.generativeai` SDK does not yet support `stream=True` with " - "`enable_automatic_function_calling=True`" + "Unsupported configuration: The `google.generativeai` SDK currently does not support the combination of `stream=True` and `enable_automatic_function_calling=True`." ) tools_lib = self.model._get_tools_lib(tools) @@ -604,7 +626,9 @@ async def send_message_async( generation_config = generation_types.to_generation_config_dict(generation_config) if generation_config.get("candidate_count", 1) > 1: - raise ValueError("Can't chat with `candidate_count > 1`") + raise ValueError( + "Invalid configuration: The chat functionality does not support `candidate_count` greater than 1." + ) response = await self.model.generate_content_async( contents=history, @@ -613,6 +637,7 @@ async def send_message_async( stream=stream, tools=tools_lib, tool_config=tool_config, + request_options=request_options, ) self._check_response(response=response, stream=stream) @@ -625,6 +650,7 @@ async def send_message_async( safety_settings=safety_settings, stream=stream, tools_lib=tools_lib, + request_options=request_options, ) self._last_sent = content @@ -633,24 +659,32 @@ async def send_message_async( return response async def _handle_afc_async( - self, *, response, history, generation_config, safety_settings, stream, tools_lib - ) -> tuple[list[glm.Content], glm.Content, generation_types.BaseGenerateContentResponse]: + self, + *, + response, + history, + generation_config, + safety_settings, + stream, + tools_lib, + request_options, + ) -> tuple[list[protos.Content], protos.Content, generation_types.BaseGenerateContentResponse]: while function_calls := self._get_function_calls(response): if not all(callable(tools_lib[fc]) for fc in function_calls): break history.append(response.candidates[0].content) - function_response_parts: list[glm.Part] = [] + function_response_parts: list[protos.Part] = [] for fc in function_calls: fr = tools_lib(fc) assert fr is not None, ( - "This should never happen, it should only return None if the declaration" - "is not callable, and that's guarded against above." + "Unexpected state: The function reference (fr) should never be None. It should only return None if the declaration " + "is not callable, which is checked earlier in the code." 
) function_response_parts.append(fr) - send = glm.Content(role=self._USER_ROLE, parts=function_response_parts) + send = protos.Content(role=self._USER_ROLE, parts=function_response_parts) history.append(send) response = await self.model.generate_content_async( @@ -659,6 +693,7 @@ async def _handle_afc_async( safety_settings=safety_settings, stream=stream, tools=tools_lib, + request_options=request_options, ) self._check_response(response=response, stream=stream) @@ -673,7 +708,7 @@ def __copy__(self): history=list(self.history), ) - def rewind(self) -> tuple[glm.Content, glm.Content]: + def rewind(self) -> tuple[protos.Content, protos.Content]: """Removes the last request/response pair from the chat history.""" if self._last_received is None: result = self._history.pop(-2), self._history.pop() @@ -690,29 +725,27 @@ def last(self) -> generation_types.BaseGenerateContentResponse | None: return self._last_received @property - def history(self) -> list[glm.Content]: + def history(self) -> list[protos.Content]: """The chat history.""" last = self._last_received if last is None: return self._history if last.candidates[0].finish_reason not in ( - glm.Candidate.FinishReason.FINISH_REASON_UNSPECIFIED, - glm.Candidate.FinishReason.STOP, - glm.Candidate.FinishReason.MAX_TOKENS, + protos.Candidate.FinishReason.FINISH_REASON_UNSPECIFIED, + protos.Candidate.FinishReason.STOP, + protos.Candidate.FinishReason.MAX_TOKENS, ): error = generation_types.StopCandidateException(last.candidates[0]) last._error = error if last._error is not None: raise generation_types.BrokenResponseError( - "Can not build a coherent chat history after a broken " - "streaming response " - "(See the previous Exception fro details). " - "To inspect the last response object, use `chat.last`." - "To remove the last request/response `Content` objects from the chat " - "call `last_send, last_received = chat.rewind()` and continue " - "without it." + "Unable to build a coherent chat history due to a broken streaming response. " + "Refer to the previous exception for details. " + "To inspect the last response object, use `chat.last`. " + "To remove the last request/response `Content` objects from the chat, " + "call `last_send, last_received = chat.rewind()` and continue without it." 
) from last._error sent = self._last_sent @@ -737,7 +770,7 @@ def __repr__(self) -> str: _model = str(self.model).replace("\n", "\n" + " " * 4) def content_repr(x): - return f"glm.Content({_dict_repr.repr(type(x).to_dict(x))})" + return f"protos.Content({_dict_repr.repr(type(x).to_dict(x))})" try: history = list(self.history) diff --git a/google/generativeai/models.py b/google/generativeai/models.py index 7c7b8a5cf..9ba0745c1 100644 --- a/google/generativeai/models.py +++ b/google/generativeai/models.py @@ -18,9 +18,12 @@ from typing import Any, Literal import google.ai.generativelanguage as glm + +from google.generativeai import protos from google.generativeai import operations from google.generativeai.client import get_default_model_client from google.generativeai.types import model_types +from google.generativeai.types import helper_types from google.api_core import operation from google.api_core import protobuf_helpers from google.protobuf import field_mask_pb2 @@ -31,23 +34,23 @@ def get_model( name: model_types.AnyModelNameOptions, *, client=None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> model_types.Model | model_types.TunedModel: - """Given a model name, fetch the `types.Model` or `types.TunedModel` object. + """Calls the API to fetch a model by name. ``` import pprint - model = genai.get_tuned_model(model_name): + model = genai.get_model('models/gemini-pro') pprint.pprint(model) ``` Args: - name: The name of the model to fetch. + name: The name of the model to fetch. Should start with `models/` client: The client to use. request_options: Options for the request. Returns: - A `types.Model` or `types.TunedModel` object. + A `types.Model` """ name = model_types.make_model_name(name) if name.startswith("models/"): @@ -55,25 +58,27 @@ def get_model( elif name.startswith("tunedModels/"): return get_tuned_model(name, client=client, request_options=request_options) else: - raise ValueError("Model names must start with `models/` or `tunedModels/`") + raise ValueError( + f"Invalid model name: Model names must start with `models/` or `tunedModels/`. Received: {name}" + ) def get_base_model( name: model_types.BaseModelNameOptions, *, client=None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> model_types.Model: - """Get the `types.Model` for the given base model name. + """Calls the API to fetch a base model by name. ``` import pprint - model = genai.get_model('models/chat-bison-001'): + model = genai.get_base_model('models/chat-bison-001') pprint.pprint(model) ``` Args: - name: The name of the model to fetch. + name: The name of the model to fetch. Should start with `models/` client: The client to use. request_options: Options for the request. @@ -88,7 +93,9 @@ def get_base_model( name = model_types.make_model_name(name) if not name.startswith("models/"): - raise ValueError(f"Base model names must start with `models/`, got: {name}") + raise ValueError( + f"Invalid model name: Base model names must start with `models/`. Received: {name}" + ) result = client.get_model(name=name, **request_options) result = type(result).to_dict(result) @@ -99,18 +106,18 @@ def get_tuned_model( name: model_types.TunedModelNameOptions, *, client=None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> model_types.TunedModel: - """Get the `types.TunedModel` for the given tuned model name. 
+ """Calls the API to fetch a tuned model by name. ``` import pprint - model = genai.get_tuned_model('tunedModels/my-model-1234'): + model = genai.get_tuned_model('tunedModels/gemini-1.0-pro-001') pprint.pprint(model) ``` Args: - name: The name of the model to fetch. + name: The name of the model to fetch. Should start with `tunedModels/` client: The client to use. request_options: Options for the request. @@ -126,7 +133,9 @@ def get_tuned_model( name = model_types.make_model_name(name) if not name.startswith("tunedModels/"): - raise ValueError("Tuned model names must start with `tunedModels/`") + raise ValueError( + f"Invalid model name: Tuned model names must start with `tunedModels/`. Received: {name}" + ) result = client.get_tuned_model(name=name, **request_options) @@ -136,6 +145,8 @@ def get_tuned_model( def get_base_model_name( model: model_types.AnyModelNameOptions, client: glm.ModelServiceClient | None = None ): + """Calls the API to fetch the base model name of a model.""" + if isinstance(model, str): if model.startswith("tunedModels/"): model = get_model(model, client=client) @@ -146,14 +157,17 @@ def get_base_model_name( base_model = model.base_model elif isinstance(model, model_types.Model): base_model = model.name - elif isinstance(model, glm.Model): + elif isinstance(model, protos.Model): base_model = model.name - elif isinstance(model, glm.TunedModel): + elif isinstance(model, protos.TunedModel): base_model = getattr(model, "base_model", None) if not base_model: base_model = model.tuned_model_source.base_model else: - raise TypeError(f"Cannot understand model: {model}") + raise TypeError( + f"Invalid model: The provided model '{model}' is not recognized or supported. " + "Supported types are: str, model_types.TunedModel, model_types.Model, protos.Model, and protos.TunedModel." + ) return base_model @@ -162,9 +176,9 @@ def list_models( *, page_size: int | None = 50, client: glm.ModelServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> model_types.ModelsIterable: - """Lists available models. + """Calls the API to list all available models. ``` import pprint @@ -196,9 +210,9 @@ def list_tuned_models( *, page_size: int | None = 50, client: glm.ModelServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> model_types.TunedModelsIterable: - """Lists available models. + """Calls the API to list all tuned models. ``` import pprint @@ -244,9 +258,9 @@ def create_tuned_model( input_key: str = "text_input", output_key: str = "output", client: glm.ModelServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> operations.CreateTunedModelOperation: - """Launches a tuning job to create a TunedModel. + """Calls the API to initiate a tuning process that optimizes a model for specific data, returning an operation object to track and manage the tuning progress. Since tuning a model can take significant time, this API doesn't wait for the tuning to complete. Instead, it returns a `google.api_core.operation.Operation` object that lets you check on the @@ -270,10 +284,10 @@ def create_tuned_model( Args: source_model: The name of the model to tune. training_data: The dataset to tune the model on. 
This must be either: - * A `glm.Dataset`, or + * A `protos.Dataset`, or * An `Iterable` of: - *`glm.TuningExample`, - * {'text_input': text_input, 'output': output} dicts, or + *`protos.TuningExample`, + * `{'text_input': text_input, 'output': output}` dicts * `(text_input, output)` tuples. * A `Mapping` of `Iterable[str]` - use `input_key` and `output_key` to choose which columns to use as the input/output @@ -319,23 +333,25 @@ def create_tuned_model( } } else: - ValueError(f"Not understood: `{source_model=}`") + raise ValueError( + f"Invalid model name: The provided model '{source_model}' does not match any known model patterns such as 'models/' or 'tunedModels/'" + ) training_data = model_types.encode_tuning_data( training_data, input_key=input_key, output_key=output_key ) - hyperparameters = glm.Hyperparameters( + hyperparameters = protos.Hyperparameters( epoch_count=epoch_count, batch_size=batch_size, learning_rate=learning_rate, ) - tuning_task = glm.TuningTask( + tuning_task = protos.TuningTask( training_data=training_data, hyperparameters=hyperparameters, ) - tuned_model = glm.TunedModel( + tuned_model = protos.TunedModel( **source_model, display_name=display_name, description=description, @@ -344,6 +360,7 @@ def create_tuned_model( top_k=top_k, tuning_task=tuning_task, ) + operation = client.create_tuned_model( dict(tuned_model_id=id, tuned_model=tuned_model), **request_options ) @@ -353,11 +370,11 @@ def create_tuned_model( @typing.overload def update_tuned_model( - tuned_model: glm.TunedModel, + tuned_model: protos.TunedModel, updates: None = None, *, client: glm.ModelServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> model_types.TunedModel: pass @@ -368,19 +385,20 @@ def update_tuned_model( updates: dict[str, Any], *, client: glm.ModelServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> model_types.TunedModel: pass def update_tuned_model( - tuned_model: str | glm.TunedModel, + tuned_model: str | protos.TunedModel, updates: dict[str, Any] | None = None, *, client: glm.ModelServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> model_types.TunedModel: - """Push updates to the tuned model. Only certain attributes are updatable.""" + """Calls the API to push updates to a specified tuned model where only certain attributes are updatable.""" + if request_options is None: request_options = {} @@ -391,10 +409,9 @@ def update_tuned_model( name = tuned_model if not isinstance(updates, dict): raise TypeError( - "When calling `update_tuned_model(name:str, updates: dict)`,\n" - "`updates` must be a `dict`.\n" - f"got: {type(updates)}" + f"Invalid argument type: In the function `update_tuned_model(name:str, updates: dict)`, the `updates` argument must be of type `dict`. Received type: {type(updates).__name__}."
) + tuned_model = client.get_tuned_model(name=name, **request_options) updates = flatten_update_paths(updates) @@ -403,11 +420,11 @@ def update_tuned_model( field_mask.paths.append(path) for path, value in updates.items(): _apply_update(tuned_model, path, value) - elif isinstance(tuned_model, glm.TunedModel): + elif isinstance(tuned_model, protos.TunedModel): if updates is not None: raise ValueError( - "When calling `update_tuned_model(tuned_model:glm.TunedModel, updates=None)`," - "`updates` must not be set." + "Invalid argument: When calling `update_tuned_model(tuned_model:protos.TunedModel, updates=None)`, " + "the `updates` argument must not be set." ) name = tuned_model.name @@ -415,12 +432,12 @@ def update_tuned_model( field_mask = protobuf_helpers.field_mask(was._pb, tuned_model._pb) else: raise TypeError( - "For `update_tuned_model(tuned_model:dict|glm.TunedModel)`," - f"`tuned_model` must be a `dict` or a `glm.TunedModel`. Got a: `{type(tuned_model)}`" + "Invalid argument type: In the function `update_tuned_model(tuned_model:dict|protos.TunedModel)`, the " + f"`tuned_model` argument must be of type `dict` or `protos.TunedModel`. Received type: {type(tuned_model).__name__}." ) result = client.update_tuned_model( - glm.UpdateTunedModelRequest(tuned_model=tuned_model, update_mask=field_mask), + protos.UpdateTunedModelRequest(tuned_model=tuned_model, update_mask=field_mask), **request_options, ) return model_types.decode_tuned_model(result) @@ -436,8 +453,10 @@ def _apply_update(thing, path, value): def delete_tuned_model( tuned_model: model_types.TunedModelNameOptions, client: glm.ModelServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> None: + """Calls the API to delete a specified tuned model""" + if request_options is None: request_options = {} diff --git a/google/generativeai/operations.py b/google/generativeai/operations.py index d492a9dee..52fd8a1b8 100644 --- a/google/generativeai/operations.py +++ b/google/generativeai/operations.py @@ -17,7 +17,7 @@ import functools from typing import Iterator -from google.ai import generativelanguage as glm +from google.generativeai import protos from google.generativeai import client as client_lib from google.generativeai.types import model_types @@ -27,6 +27,8 @@ def list_operations(*, client=None) -> Iterator[CreateTunedModelOperation]: + """Calls the API to list all operations""" + if client is None: client = client_lib.get_default_operations_client() @@ -41,6 +43,7 @@ def list_operations(*, client=None) -> Iterator[CreateTunedModelOperation]: def get_operation(name: str, *, client=None) -> CreateTunedModelOperation: + """Calls the API to get a specific operation""" if client is None: client = client_lib.get_default_operations_client() @@ -49,8 +52,9 @@ def get_operation(name: str, *, client=None) -> CreateTunedModelOperation: def delete_operation(name: str, *, client=None): - """Raises: - google.api_core.exceptions.MethodNotImplemented: Not implemented.""" + """Calls the API to delete a specific operation""" + + # Raises:google.api_core.exceptions.MethodNotImplemented: Not implemented. 
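Taken together with `create_tuned_model` above, these operation helpers make up the usual tuning workflow: the create call returns a `CreateTunedModelOperation`, which can be watched with `wait_bar` and resolved with `result()`. A rough sketch of that flow (the source model, dataset, and id are illustrative placeholders; it assumes `genai.configure(api_key=...)` has already been called):

```python
import google.generativeai as genai

# Hypothetical tuning job: model name, id and training data are placeholders.
operation = genai.create_tuned_model(
    source_model="models/gemini-1.0-pro-001",
    training_data=[
        {"text_input": "1", "output": "2"},
        {"text_input": "3", "output": "4"},
    ],
    id="my-increment-model",
    epoch_count=5,
)

# `wait_bar` yields `protos.CreateTunedModelMetadata` statuses until the job finishes.
for _ in operation.wait_bar():
    pass

tuned_model = operation.result()  # decoded to a `model_types.TunedModel`
print(tuned_model.name)
```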
if client is None: client = client_lib.get_default_operations_client() @@ -71,8 +75,8 @@ def from_proto(cls, proto, client): cls=CreateTunedModelOperation, operation=proto, operations_client=client, - result_type=glm.TunedModel, - metadata_type=glm.CreateTunedModelMetadata, + result_type=protos.TunedModel, + metadata_type=protos.CreateTunedModelMetadata, ) @classmethod @@ -107,14 +111,14 @@ def update(self): """Refresh the current statuses in metadata/result/error""" self._refresh_and_update() - def wait_bar(self, **kwargs) -> Iterator[glm.CreateTunedModelMetadata]: + def wait_bar(self, **kwargs) -> Iterator[protos.CreateTunedModelMetadata]: """A tqdm wait bar, yields `Operation` statuses until complete. Args: **kwargs: passed through to `tqdm.auto.tqdm(..., **kwargs)` Yields: - Operation statuses as `glm.CreateTunedModelMetadata` objects. + Operation statuses as `protos.CreateTunedModelMetadata` objects. """ bar = tqdm.tqdm(total=self.metadata.total_steps, initial=0, **kwargs) @@ -127,7 +131,7 @@ def wait_bar(self, **kwargs) -> Iterator[glm.CreateTunedModelMetadata]: bar.update(self.metadata.completed_steps - bar.n) return self.result() - def set_result(self, result: glm.TunedModel): + def set_result(self, result: protos.TunedModel): result = model_types.decode_tuned_model(result) super().set_result(result) diff --git a/google/generativeai/permission.py b/google/generativeai/permission.py index b502f9a60..b2b7c15e1 100644 --- a/google/generativeai/permission.py +++ b/google/generativeai/permission.py @@ -90,9 +90,8 @@ def _construct_name( # if name is not provided, then try to construct name via provided resource_name and permission_id. if not (resource_name and permission_id): raise ValueError( - "Either `name` or (`resource_name` and `permission_id`) must be provided." + f"Invalid arguments: Either `name` or both `resource_name` and `permission_id` must be provided. Received name: {name}, resource_name: {resource_name}, permission_id: {permission_id}." ) - if resource_type: resource_type = _to_resource_type(resource_type) else: @@ -100,8 +99,7 @@ def _construct_name( resource_path_components = resource_name.split("/") if len(resource_path_components) != 2: raise ValueError( - f"Invalid `resource_name` format. Expected format: \ - `resource_type/resource_name`. Got: `{resource_name}` instead." + f"Invalid `resource_name` format: Expected format is `resource_type/resource_name` (2 components). Received: `{resource_name}` with {len(resource_path_components)} components." ) resource_type = _to_resource_type(resource_path_components[0]) @@ -128,7 +126,7 @@ def get_permission( permission_id: str | int | None = None, resource_type: str | None = None, ) -> permission_types.Permission: - """Get information about a permission by name. + """Calls the API to retrieve detailed information about a specific permission based on resource type and permission identifiers Args: name: The name of the permission. diff --git a/google/generativeai/protos.py b/google/generativeai/protos.py new file mode 100644 index 000000000..010396c75 --- /dev/null +++ b/google/generativeai/protos.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This module provides low-level access to the Protocol Buffer "Message" classes used by the API. + +**For typical usage of this SDK you do not need to use any of these classes.** + +Protocol buffers are Google's API serialization format. They are strongly typed and efficient. + +The `genai` SDK tries to be permissive about what objects it will accept from a user, but in the end +the SDK always converts input to an appropriate Proto Message object to send as the request. Each API request +has a `*Request` and `*Response` Message defined here. + +If you have any uncertainty about what the API may accept or return, these classes provide the +complete/unambiguous answer. They come from the `google-ai-generativelanguage` package which is +generated from a snapshot of the API definition. + +>>> from google.generativeai import protos +>>> import inspect +>>> print(inspect.getsource(protos.Part)) + +Proto classes can have "oneof" fields. Use `in` to check which `oneof` field is set. + +>>> p = protos.Part(text='hello') +>>> 'text' in p +True +>>> p.inline_data = {'mime_type':'image/png', 'data': b'PNG'} +>>> type(p.inline_data) is protos.Blob +True +>>> 'inline_data' in p +True +>>> 'text' in p +False + +Instances of all Message classes can be converted into JSON-compatible dictionaries with the following construct +(Bytes are base64 encoded): + +>>> p_dict = type(p).to_dict(p) +>>> p_dict +{'inline_data': {'mime_type': 'image/png', 'data': 'UE5H'}} + +A compatible dict can be converted to an instance of a Message class by passing it as the first argument to the +constructor: + +>>> p = protos.Part(p_dict) +inline_data { + mime_type: "image/png" + data: "PNG" +} + +Note when converting that `to_dict` accepts additional arguments: + +- `use_integers_for_enums:bool = True`. Set it to `False` to replace enum int values with their string + names in the output +- `including_default_value_fields:bool = True`. Set it to `False` to reduce the verbosity of the output.
+ +Additional arguments are described in the docstring: + +>>> help(protos.Part.to_dict) +""" + +from google.ai.generativelanguage_v1beta.types import * +from google.ai.generativelanguage_v1beta.types import __all__ diff --git a/google/generativeai/responder.py b/google/generativeai/responder.py index 238e7e13a..bb85167ad 100644 --- a/google/generativeai/responder.py +++ b/google/generativeai/responder.py @@ -22,9 +22,9 @@ import pydantic -from google.ai import generativelanguage as glm +from google.generativeai import protos -Type = glm.Type +Type = protos.Type TypeOptions = Union[int, str, Type] @@ -186,8 +186,8 @@ def _rename_schema_fields(schema: dict[str, Any]): class FunctionDeclaration: def __init__(self, *, name: str, description: str, parameters: dict[str, Any] | None = None): - """A class wrapping a `glm.FunctionDeclaration`, describes a function for `genai.GenerativeModel`'s `tools`.""" - self._proto = glm.FunctionDeclaration( + """A class wrapping a `protos.FunctionDeclaration`, describes a function for `genai.GenerativeModel`'s `tools`.""" + self._proto = protos.FunctionDeclaration( name=name, description=description, parameters=_rename_schema_fields(parameters) ) @@ -200,7 +200,7 @@ def description(self) -> str: return self._proto.description @property - def parameters(self) -> glm.Schema: + def parameters(self) -> protos.Schema: return self._proto.parameters @classmethod @@ -209,7 +209,7 @@ def from_proto(cls, proto) -> FunctionDeclaration: self._proto = proto return self - def to_proto(self) -> glm.FunctionDeclaration: + def to_proto(self) -> protos.FunctionDeclaration: return self._proto @staticmethod @@ -255,16 +255,16 @@ def __init__( super().__init__(name=name, description=description, parameters=parameters) self.function = function - def __call__(self, fc: glm.FunctionCall) -> glm.FunctionResponse: + def __call__(self, fc: protos.FunctionCall) -> protos.FunctionResponse: result = self.function(**fc.args) if not isinstance(result, dict): result = {"result": result} - return glm.FunctionResponse(name=fc.name, response=result) + return protos.FunctionResponse(name=fc.name, response=result) FunctionDeclarationType = Union[ FunctionDeclaration, - glm.FunctionDeclaration, + protos.FunctionDeclaration, dict[str, Any], Callable[..., Any], ] @@ -272,8 +272,8 @@ def __call__(self, fc: glm.FunctionCall) -> glm.FunctionResponse: def _make_function_declaration( fun: FunctionDeclarationType, -) -> FunctionDeclaration | glm.FunctionDeclaration: - if isinstance(fun, (FunctionDeclaration, glm.FunctionDeclaration)): +) -> FunctionDeclaration | protos.FunctionDeclaration: + if isinstance(fun, (FunctionDeclaration, protos.FunctionDeclaration)): return fun elif isinstance(fun, dict): if "function" in fun: @@ -284,20 +284,20 @@ def _make_function_declaration( return CallableFunctionDeclaration.from_function(fun) else: raise TypeError( - "Expected an instance of `genai.FunctionDeclaraionType`. Got a:\n" f" {type(fun)=}\n", + f"Invalid argument type: Expected an instance of `genai.FunctionDeclarationType`.
Received type: {type(fun).__name__}.", fun, ) -def _encode_fd(fd: FunctionDeclaration | glm.FunctionDeclaration) -> glm.FunctionDeclaration: - if isinstance(fd, glm.FunctionDeclaration): +def _encode_fd(fd: FunctionDeclaration | protos.FunctionDeclaration) -> protos.FunctionDeclaration: + if isinstance(fd, protos.FunctionDeclaration): return fd return fd.to_proto() class Tool: - """A wrapper for `glm.Tool`, Contains a collection of related `FunctionDeclaration` objects.""" + """A wrapper for `protos.Tool`, Contains a collection of related `FunctionDeclaration` objects.""" def __init__(self, function_declarations: Iterable[FunctionDeclarationType]): # The main path doesn't use this but is seems useful. @@ -309,23 +309,23 @@ def __init__(self, function_declarations: Iterable[FunctionDeclarationType]): raise ValueError("") self._index[fd.name] = fd - self._proto = glm.Tool( + self._proto = protos.Tool( function_declarations=[_encode_fd(fd) for fd in self._function_declarations] ) @property - def function_declarations(self) -> list[FunctionDeclaration | glm.FunctionDeclaration]: + def function_declarations(self) -> list[FunctionDeclaration | protos.FunctionDeclaration]: return self._function_declarations def __getitem__( - self, name: str | glm.FunctionCall - ) -> FunctionDeclaration | glm.FunctionDeclaration: + self, name: str | protos.FunctionCall + ) -> FunctionDeclaration | protos.FunctionDeclaration: if not isinstance(name, str): name = name.name return self._index[name] - def __call__(self, fc: glm.FunctionCall) -> glm.FunctionResponse | None: + def __call__(self, fc: protos.FunctionCall) -> protos.FunctionResponse | None: declaration = self[fc] if not callable(declaration): return None @@ -341,21 +341,21 @@ class ToolDict(TypedDict): ToolType = Union[ - Tool, glm.Tool, ToolDict, Iterable[FunctionDeclarationType], FunctionDeclarationType + Tool, protos.Tool, ToolDict, Iterable[FunctionDeclarationType], FunctionDeclarationType ] def _make_tool(tool: ToolType) -> Tool: if isinstance(tool, Tool): return tool - elif isinstance(tool, glm.Tool): + elif isinstance(tool, protos.Tool): return Tool(function_declarations=tool.function_declarations) elif isinstance(tool, dict): if "function_declarations" in tool: return Tool(**tool) else: fd = tool - return Tool(function_declarations=[glm.FunctionDeclaration(**fd)]) + return Tool(function_declarations=[protos.FunctionDeclaration(**fd)]) elif isinstance(tool, Iterable): return Tool(function_declarations=tool) else: @@ -363,7 +363,7 @@ def _make_tool(tool: ToolType) -> Tool: return Tool(function_declarations=[tool]) except Exception as e: raise TypeError( - "Expected an instance of `genai.ToolType`. Got a:\n" f" {type(tool)=}", + f"Invalid argument type: Expected an instance of `genai.ToolType`. Received type: {type(tool).__name__}.", tool, ) from e @@ -380,26 +380,25 @@ def __init__(self, tools: Iterable[ToolType]): name = declaration.name if name in self._index: raise ValueError( - f"A `FunctionDeclaration` named {name} is already defined. " - "Each `FunctionDeclaration` must be uniquely named." + f"Invalid operation: A `FunctionDeclaration` named '{name}' is already defined. Each `FunctionDeclaration` must have a unique name." 
) self._index[declaration.name] = declaration def __getitem__( - self, name: str | glm.FunctionCall - ) -> FunctionDeclaration | glm.FunctionDeclaration: + self, name: str | protos.FunctionCall + ) -> FunctionDeclaration | protos.FunctionDeclaration: if not isinstance(name, str): name = name.name return self._index[name] - def __call__(self, fc: glm.FunctionCall) -> glm.Part | None: + def __call__(self, fc: protos.FunctionCall) -> protos.Part | None: declaration = self[fc] if not callable(declaration): return None response = declaration(fc) - return glm.Part(function_response=response) + return protos.Part(function_response=response) def to_proto(self): return [tool.to_proto() for tool in self._tools] @@ -432,7 +431,7 @@ def to_function_library(lib: FunctionLibraryType | None) -> FunctionLibrary | No return FunctionLibrary(tools=lib) -FunctionCallingMode = glm.FunctionCallingConfig.Mode +FunctionCallingMode = protos.FunctionCallingConfig.Mode # fmt: off _FUNCTION_CALLING_MODE = { @@ -468,12 +467,12 @@ class FunctionCallingConfigDict(TypedDict): FunctionCallingConfigType = Union[ - FunctionCallingModeType, FunctionCallingConfigDict, glm.FunctionCallingConfig + FunctionCallingModeType, FunctionCallingConfigDict, protos.FunctionCallingConfig ] -def to_function_calling_config(obj: FunctionCallingConfigType) -> glm.FunctionCallingConfig: - if isinstance(obj, glm.FunctionCallingConfig): +def to_function_calling_config(obj: FunctionCallingConfigType) -> protos.FunctionCallingConfig: + if isinstance(obj, protos.FunctionCallingConfig): return obj elif isinstance(obj, (FunctionCallingMode, str, int)): obj = {"mode": to_function_calling_mode(obj)} @@ -483,29 +482,31 @@ def to_function_calling_config(obj: FunctionCallingConfigType) -> glm.FunctionCa obj["mode"] = to_function_calling_mode(mode) else: raise TypeError( - f"Could not convert input to `glm.FunctionCallingConfig`: \n'" f" type: {type(obj)}\n", + "Invalid argument type: Could not convert input to `protos.FunctionCallingConfig`." + f" Received type: {type(obj).__name__}.", obj, ) - return glm.FunctionCallingConfig(obj) + return protos.FunctionCallingConfig(obj) class ToolConfigDict: function_calling_config: FunctionCallingConfigType -ToolConfigType = Union[ToolConfigDict, glm.ToolConfig] +ToolConfigType = Union[ToolConfigDict, protos.ToolConfig] -def to_tool_config(obj: ToolConfigType) -> glm.ToolConfig: - if isinstance(obj, glm.ToolConfig): +def to_tool_config(obj: ToolConfigType) -> protos.ToolConfig: + if isinstance(obj, protos.ToolConfig): return obj elif isinstance(obj, dict): fcc = obj.pop("function_calling_config") fcc = to_function_calling_config(fcc) obj["function_calling_config"] = fcc - return glm.ToolConfig(**obj) + return protos.ToolConfig(**obj) else: raise TypeError( - f"Could not convert input to `glm.ToolConfig`: \n'" f" type: {type(obj)}\n", obj + "Invalid argument type: Could not convert input to `protos.ToolConfig`. " + f"Received type: {type(obj).__name__}.", ) diff --git a/google/generativeai/retriever.py b/google/generativeai/retriever.py index dfd5e9026..53c90140a 100644 --- a/google/generativeai/retriever.py +++ b/google/generativeai/retriever.py @@ -14,15 +14,15 @@ # limitations under the License. 
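Stepping back from the rename for a moment, here is a sketch of how the `FunctionDeclaration`/`Tool` plumbing reworked in `responder.py` above is normally reached from the public API, via automatic function calling in a chat session (the model name and prompt are illustrative; the plain Python callable is wrapped into a `CallableFunctionDeclaration` by the code above):

```python
import google.generativeai as genai

def add(a: int, b: int) -> int:
    """Returns the sum of two integers."""
    return a + b

# The callable becomes a CallableFunctionDeclaration inside a Tool.
model = genai.GenerativeModel("models/gemini-pro", tools=[add])
chat = model.start_chat(enable_automatic_function_calling=True)

# When the model emits a FunctionCall, the SDK runs `add` and sends the
# resulting FunctionResponse part back, as in `FunctionLibrary.__call__` above.
response = chat.send_message("What is 2342 plus 7261?")
print(response.text)
```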
from __future__ import annotations -import re -import string -import dataclasses -from typing import Any, AsyncIterable, Iterable, Optional + +from typing import AsyncIterable, Iterable, Optional import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai.client import get_default_retriever_client from google.generativeai.client import get_default_retriever_async_client +from google.generativeai.types import helper_types from google.generativeai.types.model_types import idecode_time from google.generativeai.types import retriever_types @@ -31,12 +31,9 @@ def create_corpus( name: str | None = None, display_name: str | None = None, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> retriever_types.Corpus: - """ - Create a new `Corpus` in the retriever service, and return it as a `retriever_types.Corpus` instance. - - Users can specify either a name or display_name. + """Calls the API to create a new `Corpus` by specifying either a corpus resource name as an ID or a display name, and returns the created `Corpus`. Args: name: The corpus resource name (ID). The name must be alphanumeric and fewer @@ -59,13 +56,13 @@ def create_corpus( client = get_default_retriever_client() if name is None: - corpus = glm.Corpus(display_name=display_name) + corpus = protos.Corpus(display_name=display_name) elif retriever_types.valid_name(name): - corpus = glm.Corpus(name=f"corpora/{name}", display_name=display_name) + corpus = protos.Corpus(name=f"corpora/{name}", display_name=display_name) else: raise ValueError(retriever_types.NAME_ERROR_MSG.format(length=len(name), name=name)) - request = glm.CreateCorpusRequest(corpus=corpus) + request = protos.CreateCorpusRequest(corpus=corpus) response = client.create_corpus(request, **request_options) response = type(response).to_dict(response) idecode_time(response, "create_time") @@ -78,7 +75,7 @@ async def create_corpus_async( name: str | None = None, display_name: str | None = None, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> retriever_types.Corpus: """This is the async version of `retriever.create_corpus`.""" if request_options is None: @@ -88,13 +85,13 @@ async def create_corpus_async( client = get_default_retriever_async_client() if name is None: - corpus = glm.Corpus(display_name=display_name) + corpus = protos.Corpus(display_name=display_name) elif retriever_types.valid_name(name): - corpus = glm.Corpus(name=f"corpora/{name}", display_name=display_name) + corpus = protos.Corpus(name=f"corpora/{name}", display_name=display_name) else: raise ValueError(retriever_types.NAME_ERROR_MSG.format(length=len(name), name=name)) - request = glm.CreateCorpusRequest(corpus=corpus) + request = protos.CreateCorpusRequest(corpus=corpus) response = await client.create_corpus(request, **request_options) response = type(response).to_dict(response) idecode_time(response, "create_time") @@ -106,10 +103,9 @@ async def create_corpus_async( def get_corpus( name: str, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> retriever_types.Corpus: # fmt: skip - """ - Fetch a specific `Corpus` from the retriever service. 
+ """Calls the API to fetch a `Corpus` by name and returns the `Corpus`. Args: name: The `Corpus` name. @@ -127,7 +123,7 @@ def get_corpus( if "/" not in name: name = "corpora/" + name - request = glm.GetCorpusRequest(name=name) + request = protos.GetCorpusRequest(name=name) response = client.get_corpus(request, **request_options) response = type(response).to_dict(response) idecode_time(response, "create_time") @@ -139,9 +135,10 @@ def get_corpus( async def get_corpus_async( name: str, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> retriever_types.Corpus: # fmt: skip """This is the async version of `retriever.get_corpus`.""" + if request_options is None: request_options = {} @@ -151,7 +148,7 @@ async def get_corpus_async( if "/" not in name: name = "corpora/" + name - request = glm.GetCorpusRequest(name=name) + request = protos.GetCorpusRequest(name=name) response = await client.get_corpus(request, **request_options) response = type(response).to_dict(response) idecode_time(response, "create_time") @@ -164,15 +161,15 @@ def delete_corpus( name: str, force: bool = False, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): # fmt: skip - """ - Delete a `Corpus` from the service. + """Calls the API to remove a `Corpus` from the service, optionally deleting associated `Document`s and objects if the `force` parameter is set to true. Args: name: The `Corpus` name. force: If set to true, any `Document`s and objects related to this `Corpus` will also be deleted. request_options: Options for the request. + """ if request_options is None: request_options = {} @@ -183,7 +180,7 @@ def delete_corpus( if "/" not in name: name = "corpora/" + name - request = glm.DeleteCorpusRequest(name=name, force=force) + request = protos.DeleteCorpusRequest(name=name, force=force) client.delete_corpus(request, **request_options) @@ -191,7 +188,7 @@ async def delete_corpus_async( name: str, force: bool = False, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): # fmt: skip """This is the async version of `retriever.delete_corpus`.""" if request_options is None: @@ -203,7 +200,7 @@ async def delete_corpus_async( if "/" not in name: name = "corpora/" + name - request = glm.DeleteCorpusRequest(name=name, force=force) + request = protos.DeleteCorpusRequest(name=name, force=force) await client.delete_corpus(request, **request_options) @@ -211,10 +208,9 @@ def list_corpora( *, page_size: Optional[int] = None, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> Iterable[retriever_types.Corpus]: - """ - List the Corpuses you own in the service. + """Calls the API to list all `Corpora` in the service and returns a list of paginated `Corpora`. Args: page_size: Maximum number of `Corpora` to request. 
@@ -230,7 +226,7 @@ def list_corpora( if client is None: client = get_default_retriever_client() - request = glm.ListCorporaRequest(page_size=page_size) + request = protos.ListCorporaRequest(page_size=page_size) for corpus in client.list_corpora(request, **request_options): corpus = type(corpus).to_dict(corpus) idecode_time(corpus, "create_time") @@ -242,7 +238,7 @@ async def list_corpora_async( *, page_size: Optional[int] = None, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> AsyncIterable[retriever_types.Corpus]: """This is the async version of `retriever.list_corpora`.""" if request_options is None: @@ -251,7 +247,7 @@ async def list_corpora_async( if client is None: client = get_default_retriever_async_client() - request = glm.ListCorporaRequest(page_size=page_size) + request = protos.ListCorporaRequest(page_size=page_size) async for corpus in await client.list_corpora(request, **request_options): corpus = type(corpus).to_dict(corpus) idecode_time(corpus, "create_time") diff --git a/google/generativeai/text.py b/google/generativeai/text.py index 3a147f945..2a6267661 100644 --- a/google/generativeai/text.py +++ b/google/generativeai/text.py @@ -21,12 +21,15 @@ import google.ai.generativelanguage as glm +from google.generativeai import protos + from google.generativeai.client import get_default_text_client from google.generativeai import string_utils +from google.generativeai.types import helper_types from google.generativeai.types import text_types from google.generativeai.types import model_types from google.generativeai import models -from google.generativeai.types import safety_types +from google.generativeai.types import palm_safety_types DEFAULT_TEXT_MODEL = "models/text-bison-001" EMBEDDING_MAX_BATCH_SIZE = 100 @@ -51,25 +54,27 @@ def _batched(iterable: Iterable[T], n: int) -> Iterable[list[T]]: yield batch -def _make_text_prompt(prompt: str | dict[str, str]) -> glm.TextPrompt: +def _make_text_prompt(prompt: str | dict[str, str]) -> protos.TextPrompt: """ - Creates a `glm.TextPrompt` object based on the provided prompt input. + Creates a `protos.TextPrompt` object based on the provided prompt input. Args: prompt: The prompt input, either a string or a dictionary. Returns: - glm.TextPrompt: A TextPrompt object containing the prompt text. + protos.TextPrompt: A TextPrompt object containing the prompt text. Raises: TypeError: If the provided prompt is neither a string nor a dictionary. """ if isinstance(prompt, str): - return glm.TextPrompt(text=prompt) + return protos.TextPrompt(text=prompt) elif isinstance(prompt, dict): - return glm.TextPrompt(prompt) + return protos.TextPrompt(prompt) else: - TypeError("Expected string or dictionary for text prompt.") + raise TypeError( + "Invalid argument type: Expected a string or dictionary for the text prompt." + ) def _make_generate_text_request( @@ -81,13 +86,13 @@ def _make_generate_text_request( max_output_tokens: int | None = None, top_p: int | None = None, top_k: int | None = None, - safety_settings: safety_types.SafetySettingOptions | None = None, + safety_settings: palm_safety_types.SafetySettingOptions | None = None, stop_sequences: str | Iterable[str] | None = None, -) -> glm.GenerateTextRequest: +) -> protos.GenerateTextRequest: """ - Creates a `glm.GenerateTextRequest` object based on the provided parameters. + Creates a `protos.GenerateTextRequest` object based on the provided parameters. 
- This function generates a `glm.GenerateTextRequest` object with the specified + This function generates a `protos.GenerateTextRequest` object with the specified parameters. It prepares the input parameters and creates a request that can be used for generating text using the chosen model. @@ -104,19 +109,17 @@ def _make_generate_text_request( or iterable of strings. Defaults to None. Returns: - `glm.GenerateTextRequest`: A `GenerateTextRequest` object configured with the specified parameters. + `protos.GenerateTextRequest`: A `GenerateTextRequest` object configured with the specified parameters. """ model = model_types.make_model_name(model) prompt = _make_text_prompt(prompt=prompt) - safety_settings = safety_types.normalize_safety_settings( - safety_settings, harm_category_set="old" - ) + safety_settings = palm_safety_types.normalize_safety_settings(safety_settings) if isinstance(stop_sequences, str): stop_sequences = [stop_sequences] if stop_sequences: stop_sequences = list(stop_sequences) - return glm.GenerateTextRequest( + return protos.GenerateTextRequest( model=model, prompt=prompt, temperature=temperature, @@ -138,12 +141,12 @@ def generate_text( max_output_tokens: int | None = None, top_p: float | None = None, top_k: float | None = None, - safety_settings: safety_types.SafetySettingOptions | None = None, + safety_settings: palm_safety_types.SafetySettingOptions | None = None, stop_sequences: str | Iterable[str] | None = None, client: glm.TextServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.Completion: - """Calls the API and returns a `types.Completion` containing the response. + """Calls the API to generate text based on the provided prompt. Args: model: Which model to call, as a string or a `types.Model`. @@ -215,12 +218,12 @@ def __init__(self, **kwargs): def _generate_response( - request: glm.GenerateTextRequest, + request: protos.GenerateTextRequest, client: glm.TextServiceClient = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> Completion: """ - Generates a response using the provided `glm.GenerateTextRequest` and client. + Generates a response using the provided `protos.GenerateTextRequest` and client. Args: request: The text generation request. 
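For orientation, the legacy PaLM-style `generate_text` path documented above is typically driven as follows; this is only a sketch, assuming `genai.configure` has been called and that a plain dict such as `{'timeout': 120}` is an acceptable `RequestOptionsType` value:

```python
import google.generativeai as genai

completion = genai.generate_text(
    model="models/text-bison-001",
    prompt="Write a two-line poem about tokens.",
    temperature=0.7,
    max_output_tokens=128,
    request_options={"timeout": 120},  # forwarded to the underlying client call
)
print(completion.result)

# Token accounting for the same prompt.
print(genai.count_text_tokens("models/text-bison-001", "Write a two-line poem about tokens."))
```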
@@ -240,11 +243,11 @@ def _generate_response( response = client.generate_text(request, **request_options) response = type(response).to_dict(response) - response["filters"] = safety_types.convert_filters_to_enums(response["filters"]) - response["safety_feedback"] = safety_types.convert_safety_feedback_to_enums( + response["filters"] = palm_safety_types.convert_filters_to_enums(response["filters"]) + response["safety_feedback"] = palm_safety_types.convert_safety_feedback_to_enums( response["safety_feedback"] ) - response["candidates"] = safety_types.convert_candidate_enums(response["candidates"]) + response["candidates"] = palm_safety_types.convert_candidate_enums(response["candidates"]) return Completion(_client=client, **response) @@ -253,8 +256,10 @@ def count_text_tokens( model: model_types.AnyModelNameOptions, prompt: str, client: glm.TextServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.TokenCount: + """Calls the API to count the number of tokens in the text prompt.""" + base_model = models.get_base_model_name(model) if request_options is None: @@ -264,7 +269,7 @@ def count_text_tokens( client = get_default_text_client() result = client.count_text_tokens( - glm.CountTextTokensRequest(model=base_model, prompt={"text": prompt}), + protos.CountTextTokensRequest(model=base_model, prompt={"text": prompt}), **request_options, ) @@ -276,7 +281,7 @@ def generate_embeddings( model: model_types.BaseModelNameOptions, text: str, client: glm.TextServiceClient = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.EmbeddingDict: ... @@ -285,7 +290,7 @@ def generate_embeddings( model: model_types.BaseModelNameOptions, text: Sequence[str], client: glm.TextServiceClient = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.BatchEmbeddingDict: ... @@ -293,7 +298,7 @@ def generate_embeddings( model: model_types.BaseModelNameOptions, text: str | Sequence[str], client: glm.TextServiceClient = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict: """Calls the API to create an embedding for the text passed in. @@ -319,7 +324,7 @@ def generate_embeddings( client = get_default_text_client() if isinstance(text, str): - embedding_request = glm.EmbedTextRequest(model=model, text=text) + embedding_request = protos.EmbedTextRequest(model=model, text=text) embedding_response = client.embed_text( embedding_request, **request_options, @@ -330,7 +335,7 @@ def generate_embeddings( result = {"embedding": []} for batch in _batched(text, EMBEDDING_MAX_BATCH_SIZE): # TODO(markdaoust): This could use an option for returning an iterator or wait-bar. 
- embedding_request = glm.BatchEmbedTextRequest(model=model, texts=batch) + embedding_request = protos.BatchEmbedTextRequest(model=model, texts=batch) embedding_response = client.batch_embed_text( embedding_request, **request_options, diff --git a/google/generativeai/types/__init__.py b/google/generativeai/types/__init__.py index dc0a76761..21768bbe6 100644 --- a/google/generativeai/types/__init__.py +++ b/google/generativeai/types/__init__.py @@ -19,6 +19,7 @@ from google.generativeai.types.discuss_types import * from google.generativeai.types.file_types import * from google.generativeai.types.generation_types import * +from google.generativeai.types.helper_types import * from google.generativeai.types.model_types import * from google.generativeai.types.safety_types import * from google.generativeai.types.text_types import * diff --git a/google/generativeai/types/answer_types.py b/google/generativeai/types/answer_types.py index 18bd11d62..143a578a4 100644 --- a/google/generativeai/types/answer_types.py +++ b/google/generativeai/types/answer_types.py @@ -16,11 +16,11 @@ from typing import Union -import google.ai.generativelanguage as glm +from google.generativeai import protos __all__ = ["Answer"] -FinishReason = glm.Candidate.FinishReason +FinishReason = protos.Candidate.FinishReason FinishReasonOptions = Union[int, str, FinishReason] diff --git a/google/generativeai/types/citation_types.py b/google/generativeai/types/citation_types.py index ae857c35b..9f169703f 100644 --- a/google/generativeai/types/citation_types.py +++ b/google/generativeai/types/citation_types.py @@ -17,7 +17,7 @@ from typing_extensions import TypedDict -from google.ai import generativelanguage as glm +from google.generativeai import protos from google.generativeai import string_utils @@ -33,10 +33,10 @@ class CitationSourceDict(TypedDict): uri: str | None license: str | None - __doc__ = string_utils.strip_oneof(glm.CitationSource.__doc__) + __doc__ = string_utils.strip_oneof(protos.CitationSource.__doc__) class CitationMetadataDict(TypedDict): citation_sources: List[CitationSourceDict | None] - __doc__ = string_utils.strip_oneof(glm.CitationMetadata.__doc__) + __doc__ = string_utils.strip_oneof(protos.CitationMetadata.__doc__) diff --git a/google/generativeai/types/content_types.py b/google/generativeai/types/content_types.py index 67c1338bf..7e343a5c0 100644 --- a/google/generativeai/types/content_types.py +++ b/google/generativeai/types/content_types.py @@ -26,7 +26,7 @@ import pydantic from google.generativeai.types import file_types -from google.ai import generativelanguage as glm +from google.generativeai import protos if typing.TYPE_CHECKING: import PIL.Image @@ -72,7 +72,7 @@ def pil_to_blob(img): bytesio = io.BytesIO() - if isinstance(img, PIL.PngImagePlugin.PngImageFile): + if isinstance(img, PIL.PngImagePlugin.PngImageFile) or img.mode == "RGBA": img.save(bytesio, format="PNG") mime_type = "image/png" else: @@ -80,10 +80,10 @@ def pil_to_blob(img): mime_type = "image/jpeg" bytesio.seek(0) data = bytesio.read() - return glm.Blob(mime_type=mime_type, data=data) + return protos.Blob(mime_type=mime_type, data=data) -def image_to_blob(image) -> glm.Blob: +def image_to_blob(image) -> protos.Blob: if PIL is not None: if isinstance(image, PIL.Image.Image): return pil_to_blob(image) @@ -93,21 +93,20 @@ def image_to_blob(image) -> glm.Blob: name = image.filename if name is None: raise ValueError( - "Can only convert `IPython.display.Image` if " - "it is constructed from a local file (Image(filename=...))." 
+ "Conversion failed. The `IPython.display.Image` can only be converted if " + "it is constructed from a local file. Please ensure you are using the format: Image(filename='...')." ) - mime_type, _ = mimetypes.guess_type(name) if mime_type is None: mime_type = "image/unknown" - return glm.Blob(mime_type=mime_type, data=image.data) + return protos.Blob(mime_type=mime_type, data=image.data) raise TypeError( - "Could not convert image. expected an `Image` type" - "(`PIL.Image.Image` or `IPython.display.Image`).\n" - f"Got a: {type(image)}\n" - f"Value: {image}" + "Image conversion failed. The input was expected to be of type `Image` " + "(either `PIL.Image.Image` or `IPython.display.Image`).\n" + f"However, received an object of type: {type(image)}.\n" + f"Object Value: {image}" ) @@ -116,30 +115,30 @@ class BlobDict(TypedDict): data: bytes -def _convert_dict(d: Mapping) -> glm.Content | glm.Part | glm.Blob: +def _convert_dict(d: Mapping) -> protos.Content | protos.Part | protos.Blob: if is_content_dict(d): content = dict(d) if isinstance(parts := content["parts"], str): content["parts"] = [parts] content["parts"] = [to_part(part) for part in content["parts"]] - return glm.Content(content) + return protos.Content(content) elif is_part_dict(d): part = dict(d) if "inline_data" in part: part["inline_data"] = to_blob(part["inline_data"]) if "file_data" in part: - part["file_data"] = to_file_data(part["file_data"]) - return glm.Part(part) + part["file_data"] = file_types.to_file_data(part["file_data"]) + return protos.Part(part) elif is_blob_dict(d): blob = d - return glm.Blob(blob) + return protos.Blob(blob) else: raise KeyError( - "Could not recognize the intended type of the `dict`. " - "A `Content` should have a 'parts' key. " - "A `Part` should have a 'inline_data' or a 'text' key. " - "A `Blob` should have 'mime_type' and 'data' keys. " - f"Got keys: {list(d.keys())}" + "Unable to determine the intended type of the `dict`. " + "For `Content`, a 'parts' key is expected. " + "For `Part`, either an 'inline_data' or a 'text' key is expected. " + "For `Blob`, both 'mime_type' and 'data' keys are expected. 
" + f"However, the provided dictionary has the following keys: {list(d.keys())}" ) @@ -149,17 +148,17 @@ def is_blob_dict(d): if typing.TYPE_CHECKING: BlobType = Union[ - glm.Blob, BlobDict, PIL.Image.Image, IPython.display.Image + protos.Blob, BlobDict, PIL.Image.Image, IPython.display.Image ] # Any for the images else: - BlobType = Union[glm.Blob, BlobDict, Any] + BlobType = Union[protos.Blob, BlobDict, Any] -def to_blob(blob: BlobType) -> glm.Blob: +def to_blob(blob: BlobType) -> protos.Blob: if isinstance(blob, Mapping): blob = _convert_dict(blob) - if isinstance(blob, glm.Blob): + if isinstance(blob, protos.Blob): return blob elif isinstance(blob, IMAGE_TYPES): return image_to_blob(blob) @@ -176,43 +175,21 @@ def to_blob(blob: BlobType) -> glm.Blob: ) -class FileDataDict(TypedDict): - mime_type: str - file_uri: str - - -FileDataType = Union[FileDataDict, glm.FileData, file_types.File] - - -def to_file_data(file_data: FileDataType): - if isinstance(file_data, dict): - if "file_uri" in file_data: - file_data = glm.FileData(file_data) - else: - file_data = glm.File(file_data) - - if isinstance(file_data, file_types.File): - file_data = file_data.to_proto() - - if isinstance(file_data, (glm.File, file_types.File)): - file_data = glm.FileData( - mime_type=file_data.mime_type, - file_uri=file_data.uri, - ) - - if isinstance(file_data, glm.FileData): - return file_data - else: - raise TypeError(f"Could not convert a {type(file_data)} to `FileData`") - - class PartDict(TypedDict): text: str inline_data: BlobType # When you need a `Part` accept a part object, part-dict, blob or string -PartType = Union[glm.Part, PartDict, BlobType, str, glm.FunctionCall, glm.FunctionResponse] +PartType = Union[ + protos.Part, + PartDict, + BlobType, + str, + protos.FunctionCall, + protos.FunctionResponse, + file_types.FileDataType, +] def is_part_dict(d): @@ -229,22 +206,22 @@ def to_part(part: PartType): if isinstance(part, Mapping): part = _convert_dict(part) - if isinstance(part, glm.Part): + if isinstance(part, protos.Part): return part elif isinstance(part, str): - return glm.Part(text=part) - elif isinstance(part, glm.FileData): - return glm.Part(file_data=part) - elif isinstance(part, (glm.File, file_types.File)): - return glm.Part(file_data=to_file_data(part)) - elif isinstance(part, glm.FunctionCall): - return glm.Part(function_call=part) - elif isinstance(part, glm.FunctionResponse): - return glm.Part(function_response=part) + return protos.Part(text=part) + elif isinstance(part, protos.FileData): + return protos.Part(file_data=part) + elif isinstance(part, (protos.File, file_types.File)): + return protos.Part(file_data=file_types.to_file_data(part)) + elif isinstance(part, protos.FunctionCall): + return protos.Part(function_call=part) + elif isinstance(part, protos.FunctionResponse): + return protos.Part(function_response=part) else: # Maybe it can be turned into a blob? - return glm.Part(inline_data=to_blob(part)) + return protos.Part(inline_data=to_blob(part)) class ContentDict(TypedDict): @@ -258,46 +235,48 @@ def is_content_dict(d): # When you need a message accept a `Content` object or dict, a list of parts, # or a single part -ContentType = Union[glm.Content, ContentDict, Iterable[PartType], PartType] +ContentType = Union[protos.Content, ContentDict, Iterable[PartType], PartType] # For generate_content, we're not guessing roles for [[parts],[parts],[parts]] yet. 
-StrictContentType = Union[glm.Content, ContentDict] +StrictContentType = Union[protos.Content, ContentDict] def to_content(content: ContentType): if not content: - raise ValueError("content must not be empty") + raise ValueError( + "Invalid input: 'content' argument must not be empty. Please provide a non-empty value." + ) if isinstance(content, Mapping): content = _convert_dict(content) - if isinstance(content, glm.Content): + if isinstance(content, protos.Content): return content elif isinstance(content, Iterable) and not isinstance(content, str): - return glm.Content(parts=[to_part(part) for part in content]) + return protos.Content(parts=[to_part(part) for part in content]) else: # Maybe this is a Part? - return glm.Content(parts=[to_part(content)]) + return protos.Content(parts=[to_part(content)]) def strict_to_content(content: StrictContentType): if isinstance(content, Mapping): content = _convert_dict(content) - if isinstance(content, glm.Content): + if isinstance(content, protos.Content): return content else: raise TypeError( - "Expected a `glm.Content` or a `dict(parts=...)`.\n" - f"Got type: {type(content)}\n" - f"Value: {content}\n" + "Invalid input type. Expected a `protos.Content` or a `dict` with a 'parts' key.\n" + f"However, received an object of type: {type(content)}.\n" + f"Object Value: {content}" ) ContentsType = Union[ContentType, Iterable[StrictContentType], None] -def to_contents(contents: ContentsType) -> list[glm.Content]: +def to_contents(contents: ContentsType) -> list[protos.Content]: if contents is None: return [] @@ -477,14 +456,20 @@ def convert_to_nullable(schema): anyof = schema.pop("anyOf", None) if anyof is not None: if len(anyof) != 2: - raise ValueError("Type Unions are not supported (except for Optional)") + raise ValueError( + "Invalid input: Type Unions are not supported, except for `Optional` types. " + "Please provide an `Optional` type or a non-Union type." + ) a, b = anyof if a == {"type": "null"}: schema.update(b) elif b == {"type": "null"}: schema.update(a) else: - raise ValueError("Type Unions are not supported (except for Optional)") + raise ValueError( + "Invalid input: Type Unions are not supported, except for `Optional` types. " + "Please provide an `Optional` type or a non-Union type." 
+ ) schema["nullable"] = True properties = schema.get("properties", None) @@ -524,8 +509,8 @@ def _rename_schema_fields(schema): class FunctionDeclaration: def __init__(self, *, name: str, description: str, parameters: dict[str, Any] | None = None): - """A class wrapping a `glm.FunctionDeclaration`, describes a function for `genai.GenerativeModel`'s `tools`.""" - self._proto = glm.FunctionDeclaration( + """A class wrapping a `protos.FunctionDeclaration`, describes a function for `genai.GenerativeModel`'s `tools`.""" + self._proto = protos.FunctionDeclaration( name=name, description=description, parameters=_rename_schema_fields(parameters) ) @@ -538,7 +523,7 @@ def description(self) -> str: return self._proto.description @property - def parameters(self) -> glm.Schema: + def parameters(self) -> protos.Schema: return self._proto.parameters @classmethod @@ -547,7 +532,7 @@ def from_proto(cls, proto) -> FunctionDeclaration: self._proto = proto return self - def to_proto(self) -> glm.FunctionDeclaration: + def to_proto(self) -> protos.FunctionDeclaration: return self._proto @staticmethod @@ -593,16 +578,16 @@ def __init__( super().__init__(name=name, description=description, parameters=parameters) self.function = function - def __call__(self, fc: glm.FunctionCall) -> glm.FunctionResponse: + def __call__(self, fc: protos.FunctionCall) -> protos.FunctionResponse: result = self.function(**fc.args) if not isinstance(result, dict): result = {"result": result} - return glm.FunctionResponse(name=fc.name, response=result) + return protos.FunctionResponse(name=fc.name, response=result) FunctionDeclarationType = Union[ FunctionDeclaration, - glm.FunctionDeclaration, + protos.FunctionDeclaration, dict[str, Any], Callable[..., Any], ] @@ -610,8 +595,8 @@ def __call__(self, fc: glm.FunctionCall) -> glm.FunctionResponse: def _make_function_declaration( fun: FunctionDeclarationType, -) -> FunctionDeclaration | glm.FunctionDeclaration: - if isinstance(fun, (FunctionDeclaration, glm.FunctionDeclaration)): +) -> FunctionDeclaration | protos.FunctionDeclaration: + if isinstance(fun, (FunctionDeclaration, protos.FunctionDeclaration)): return fun elif isinstance(fun, dict): if "function" in fun: @@ -622,20 +607,21 @@ def _make_function_declaration( return CallableFunctionDeclaration.from_function(fun) else: raise TypeError( - "Expected an instance of `genai.FunctionDeclaraionType`. Got a:\n" f" {type(fun)=}\n", - fun, + "Invalid input type. Expected an instance of `genai.FunctionDeclarationType`.\n" + f"However, received an object of type: {type(fun)}.\n" + f"Object Value: {fun}" ) -def _encode_fd(fd: FunctionDeclaration | glm.FunctionDeclaration) -> glm.FunctionDeclaration: - if isinstance(fd, glm.FunctionDeclaration): +def _encode_fd(fd: FunctionDeclaration | protos.FunctionDeclaration) -> protos.FunctionDeclaration: + if isinstance(fd, protos.FunctionDeclaration): return fd return fd.to_proto() class Tool: - """A wrapper for `glm.Tool`, Contains a collection of related `FunctionDeclaration` objects.""" + """A wrapper for `protos.Tool`, Contains a collection of related `FunctionDeclaration` objects.""" def __init__(self, function_declarations: Iterable[FunctionDeclarationType]): # The main path doesn't use this but is seems useful. 
@@ -647,23 +633,23 @@ def __init__(self, function_declarations: Iterable[FunctionDeclarationType]): raise ValueError("") self._index[fd.name] = fd - self._proto = glm.Tool( + self._proto = protos.Tool( function_declarations=[_encode_fd(fd) for fd in self._function_declarations] ) @property - def function_declarations(self) -> list[FunctionDeclaration | glm.FunctionDeclaration]: + def function_declarations(self) -> list[FunctionDeclaration | protos.FunctionDeclaration]: return self._function_declarations def __getitem__( - self, name: str | glm.FunctionCall - ) -> FunctionDeclaration | glm.FunctionDeclaration: + self, name: str | protos.FunctionCall + ) -> FunctionDeclaration | protos.FunctionDeclaration: if not isinstance(name, str): name = name.name return self._index[name] - def __call__(self, fc: glm.FunctionCall) -> glm.FunctionResponse | None: + def __call__(self, fc: protos.FunctionCall) -> protos.FunctionResponse | None: declaration = self[fc] if not callable(declaration): return None @@ -679,21 +665,21 @@ class ToolDict(TypedDict): ToolType = Union[ - Tool, glm.Tool, ToolDict, Iterable[FunctionDeclarationType], FunctionDeclarationType + Tool, protos.Tool, ToolDict, Iterable[FunctionDeclarationType], FunctionDeclarationType ] def _make_tool(tool: ToolType) -> Tool: if isinstance(tool, Tool): return tool - elif isinstance(tool, glm.Tool): + elif isinstance(tool, protos.Tool): return Tool(function_declarations=tool.function_declarations) elif isinstance(tool, dict): if "function_declarations" in tool: return Tool(**tool) else: fd = tool - return Tool(function_declarations=[glm.FunctionDeclaration(**fd)]) + return Tool(function_declarations=[protos.FunctionDeclaration(**fd)]) elif isinstance(tool, Iterable): return Tool(function_declarations=tool) else: @@ -701,8 +687,9 @@ def _make_tool(tool: ToolType) -> Tool: return Tool(function_declarations=[tool]) except Exception as e: raise TypeError( - "Expected an instance of `genai.ToolType`. Got a:\n" f" {type(tool)=}", - tool, + "Invalid input type. Expected an instance of `genai.ToolType`.\n" + f"However, received an object of type: {type(tool)}.\n" + f"Object Value: {tool}" ) from e @@ -718,26 +705,26 @@ def __init__(self, tools: Iterable[ToolType]): name = declaration.name if name in self._index: raise ValueError( - f"A `FunctionDeclaration` named {name} is already defined. " - "Each `FunctionDeclaration` must be uniquely named." + f"Invalid operation: A `FunctionDeclaration` named '{name}' is already defined. " + "Each `FunctionDeclaration` must have a unique name. Please use a different name." 
) self._index[declaration.name] = declaration def __getitem__( - self, name: str | glm.FunctionCall - ) -> FunctionDeclaration | glm.FunctionDeclaration: + self, name: str | protos.FunctionCall + ) -> FunctionDeclaration | protos.FunctionDeclaration: if not isinstance(name, str): name = name.name return self._index[name] - def __call__(self, fc: glm.FunctionCall) -> glm.Part | None: + def __call__(self, fc: protos.FunctionCall) -> protos.Part | None: declaration = self[fc] if not callable(declaration): return None response = declaration(fc) - return glm.Part(function_response=response) + return protos.Part(function_response=response) def to_proto(self): return [tool.to_proto() for tool in self._tools] @@ -770,7 +757,7 @@ def to_function_library(lib: FunctionLibraryType | None) -> FunctionLibrary | No return FunctionLibrary(tools=lib) -FunctionCallingMode = glm.FunctionCallingConfig.Mode +FunctionCallingMode = protos.FunctionCallingConfig.Mode # fmt: off _FUNCTION_CALLING_MODE = { @@ -806,12 +793,12 @@ class FunctionCallingConfigDict(TypedDict): FunctionCallingConfigType = Union[ - FunctionCallingModeType, FunctionCallingConfigDict, glm.FunctionCallingConfig + FunctionCallingModeType, FunctionCallingConfigDict, protos.FunctionCallingConfig ] -def to_function_calling_config(obj: FunctionCallingConfigType) -> glm.FunctionCallingConfig: - if isinstance(obj, glm.FunctionCallingConfig): +def to_function_calling_config(obj: FunctionCallingConfigType) -> protos.FunctionCallingConfig: + if isinstance(obj, protos.FunctionCallingConfig): return obj elif isinstance(obj, (FunctionCallingMode, str, int)): obj = {"mode": to_function_calling_mode(obj)} @@ -821,29 +808,32 @@ def to_function_calling_config(obj: FunctionCallingConfigType) -> glm.FunctionCa obj["mode"] = to_function_calling_mode(mode) else: raise TypeError( - f"Could not convert input to `glm.FunctionCallingConfig`: \n'" f" type: {type(obj)}\n", - obj, + "Invalid input type. Failed to convert input to `protos.FunctionCallingConfig`.\n" + f"Received an object of type: {type(obj)}.\n" + f"Object Value: {obj}" ) - return glm.FunctionCallingConfig(obj) + return protos.FunctionCallingConfig(obj) class ToolConfigDict: function_calling_config: FunctionCallingConfigType -ToolConfigType = Union[ToolConfigDict, glm.ToolConfig] +ToolConfigType = Union[ToolConfigDict, protos.ToolConfig] -def to_tool_config(obj: ToolConfigType) -> glm.ToolConfig: - if isinstance(obj, glm.ToolConfig): +def to_tool_config(obj: ToolConfigType) -> protos.ToolConfig: + if isinstance(obj, protos.ToolConfig): return obj elif isinstance(obj, dict): fcc = obj.pop("function_calling_config") fcc = to_function_calling_config(fcc) obj["function_calling_config"] = fcc - return glm.ToolConfig(**obj) + return protos.ToolConfig(**obj) else: raise TypeError( - f"Could not convert input to `glm.ToolConfig`: \n'" f" type: {type(obj)}\n", obj + "Invalid input type. 
Failed to convert input to `protos.ToolConfig`.\n" + f"Received an object of type: {type(obj)}.\n" + f"Object Value: {obj}" ) diff --git a/google/generativeai/types/discuss_types.py b/google/generativeai/types/discuss_types.py index 0cb393e5c..a538da65c 100644 --- a/google/generativeai/types/discuss_types.py +++ b/google/generativeai/types/discuss_types.py @@ -19,10 +19,10 @@ from typing import Any, Dict, Union, Iterable, Optional, Tuple, List from typing_extensions import TypedDict -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import string_utils -from google.generativeai.types import safety_types +from google.generativeai.types import palm_safety_types from google.generativeai.types import citation_types @@ -46,15 +46,15 @@ class TokenCount(TypedDict): class MessageDict(TypedDict): - """A dict representation of a `glm.Message`.""" + """A dict representation of a `protos.Message`.""" author: str content: str citation_metadata: Optional[citation_types.CitationMetadataDict] -MessageOptions = Union[str, MessageDict, glm.Message] -MESSAGE_OPTIONS = (str, dict, glm.Message) +MessageOptions = Union[str, MessageDict, protos.Message] +MESSAGE_OPTIONS = (str, dict, protos.Message) MessagesOptions = Union[ MessageOptions, @@ -64,7 +64,7 @@ class MessageDict(TypedDict): class ExampleDict(TypedDict): - """A dict representation of a `glm.Example`.""" + """A dict representation of a `protos.Example`.""" input: MessageOptions output: MessageOptions @@ -74,14 +74,14 @@ class ExampleDict(TypedDict): Tuple[MessageOptions, MessageOptions], Iterable[MessageOptions], ExampleDict, - glm.Example, + protos.Example, ] -EXAMPLE_OPTIONS = (glm.Example, dict, Iterable) +EXAMPLE_OPTIONS = (protos.Example, dict, Iterable) ExamplesOptions = Union[ExampleOptions, Iterable[ExampleOptions]] class MessagePromptDict(TypedDict, total=False): - """A dict representation of a `glm.MessagePrompt`.""" + """A dict representation of a `protos.MessagePrompt`.""" context: str examples: ExamplesOptions @@ -90,16 +90,16 @@ class MessagePromptDict(TypedDict, total=False): MessagePromptOptions = Union[ str, - glm.Message, - Iterable[Union[str, glm.Message]], + protos.Message, + Iterable[Union[str, protos.Message]], MessagePromptDict, - glm.MessagePrompt, + protos.MessagePrompt, ] MESSAGE_PROMPT_KEYS = {"context", "examples", "messages"} class ResponseDict(TypedDict): - """A dict representation of a `glm.GenerateMessageResponse`.""" + """A dict representation of a `protos.GenerateMessageResponse`.""" messages: List[MessageDict] candidates: List[MessageDict] @@ -169,7 +169,7 @@ class ChatResponse(abc.ABC): temperature: Optional[float] candidate_count: Optional[int] candidates: List[MessageDict] - filters: List[safety_types.ContentFilterDict] + filters: List[palm_safety_types.ContentFilterDict] top_p: Optional[float] = None top_k: Optional[float] = None diff --git a/google/generativeai/types/file_types.py b/google/generativeai/types/file_types.py index d18404871..ef251e296 100644 --- a/google/generativeai/types/file_types.py +++ b/google/generativeai/types/file_types.py @@ -15,19 +15,22 @@ from __future__ import annotations import datetime +from typing import Union +from typing_extensions import TypedDict +from google.rpc.status_pb2 import Status from google.generativeai.client import get_default_file_client -import google.ai.generativelanguage as glm +from google.generativeai import protos class File: - def __init__(self, proto: glm.File | File | dict): + def __init__(self, 
proto: protos.File | File | dict): if isinstance(proto, File): proto = proto.to_proto() - self._proto = glm.File(proto) + self._proto = protos.File(proto) - def to_proto(self): + def to_proto(self) -> protos.File: return self._proto @property @@ -67,9 +70,51 @@ def uri(self) -> str: return self._proto.uri @property - def state(self) -> glm.File.State: + def state(self) -> protos.File.State: return self._proto.state + @property + def video_metadata(self) -> protos.VideoMetadata: + return self._proto.video_metadata + + @property + def error(self) -> Status: + return self._proto.error + def delete(self): client = get_default_file_client() client.delete_file(name=self.name) + + +class FileDataDict(TypedDict): + mime_type: str + file_uri: str + + +FileDataType = Union[FileDataDict, protos.FileData, protos.File, File] + + +def to_file_data(file_data: FileDataType): + if isinstance(file_data, dict): + if "file_uri" in file_data: + file_data = protos.FileData(file_data) + else: + file_data = protos.File(file_data) + + if isinstance(file_data, File): + file_data = file_data.to_proto() + + if isinstance(file_data, protos.File): + file_data = protos.FileData( + mime_type=file_data.mime_type, + file_uri=file_data.uri, + ) + + if isinstance(file_data, protos.FileData): + return file_data + else: + raise TypeError( + f"Invalid input type. Failed to convert input to `FileData`.\n" + f"Received an object of type: {type(file_data)}.\n" + f"Object Value: {file_data}" + ) diff --git a/google/generativeai/types/generation_types.py b/google/generativeai/types/generation_types.py index b7a342b37..20686a156 100644 --- a/google/generativeai/types/generation_types.py +++ b/google/generativeai/types/generation_types.py @@ -30,7 +30,7 @@ import google.protobuf.json_format import google.api_core.exceptions -from google.ai import generativelanguage as glm +from google.generativeai import protos from google.generativeai import string_utils from google.generativeai.types import content_types from google.generativeai.responder import _rename_schema_fields @@ -85,7 +85,7 @@ class GenerationConfigDict(TypedDict, total=False): max_output_tokens: int temperature: float response_mime_type: str - response_schema: glm.Schema | Mapping[str, Any] # fmt: off + response_schema: protos.Schema | Mapping[str, Any] # fmt: off @dataclasses.dataclass @@ -165,19 +165,19 @@ class GenerationConfig: top_p: float | None = None top_k: int | None = None response_mime_type: str | None = None - response_schema: glm.Schema | Mapping[str, Any] | None = None + response_schema: protos.Schema | Mapping[str, Any] | None = None -GenerationConfigType = Union[glm.GenerationConfig, GenerationConfigDict, GenerationConfig] +GenerationConfigType = Union[protos.GenerationConfig, GenerationConfigDict, GenerationConfig] def _normalize_schema(generation_config): - # Convert response_schema to glm.Schema for request + # Convert response_schema to protos.Schema for request response_schema = generation_config.get("response_schema", None) if response_schema is None: return - if isinstance(response_schema, glm.Schema): + if isinstance(response_schema, protos.Schema): return if isinstance(response_schema, type): @@ -185,19 +185,19 @@ def _normalize_schema(generation_config): elif isinstance(response_schema, types.GenericAlias): if not str(response_schema).startswith("list["): raise ValueError( - f"Could not understand {response_schema}, expected: `int`, `float`, `str`, `bool`, " - "`typing_extensions.TypedDict`, `dataclass`, or `list[...]`" + f"Invalid input: Could not 
understand the type of '{response_schema}'. " + "Expected one of the following types: `int`, `float`, `str`, `bool`, `typing_extensions.TypedDict`, `dataclass`, or `list[...]`." ) response_schema = content_types._schema_for_class(response_schema) response_schema = _rename_schema_fields(response_schema) - generation_config["response_schema"] = glm.Schema(response_schema) + generation_config["response_schema"] = protos.Schema(response_schema) def to_generation_config_dict(generation_config: GenerationConfigType): if generation_config is None: return {} - elif isinstance(generation_config, glm.GenerationConfig): + elif isinstance(generation_config, protos.GenerationConfig): schema = generation_config.response_schema generation_config = type(generation_config).to_dict( generation_config @@ -214,21 +214,21 @@ def to_generation_config_dict(generation_config: GenerationConfigType): return generation_config else: raise TypeError( - "Did not understand `generation_config`, expected a `dict` or" - f" `GenerationConfig`\nGot type: {type(generation_config)}\nValue:" - f" {generation_config}" + "Invalid input type. Expected a `dict` or `GenerationConfig` for `generation_config`.\n" + f"However, received an object of type: {type(generation_config)}.\n" + f"Object Value: {generation_config}" ) def _join_citation_metadatas( - citation_metadatas: Iterable[glm.CitationMetadata], + citation_metadatas: Iterable[protos.CitationMetadata], ): citation_metadatas = list(citation_metadatas) return citation_metadatas[-1] def _join_safety_ratings_lists( - safety_ratings_lists: Iterable[list[glm.SafetyRating]], + safety_ratings_lists: Iterable[list[protos.SafetyRating]], ): ratings = {} blocked = collections.defaultdict(list) @@ -243,13 +243,13 @@ def _join_safety_ratings_lists( safety_list = [] for (category, probability), blocked in zip(ratings.items(), blocked.values()): safety_list.append( - glm.SafetyRating(category=category, probability=probability, blocked=blocked) + protos.SafetyRating(category=category, probability=probability, blocked=blocked) ) return safety_list -def _join_contents(contents: Iterable[glm.Content]): +def _join_contents(contents: Iterable[protos.Content]): contents = tuple(contents) roles = [c.role for c in contents if c.role] if roles: @@ -271,22 +271,22 @@ def _join_contents(contents: Iterable[glm.Content]): merged_parts.append(part) continue - merged_part = glm.Part(merged_parts[-1]) + merged_part = protos.Part(merged_parts[-1]) merged_part.text += part.text merged_parts[-1] = merged_part - return glm.Content( + return protos.Content( role=role, parts=merged_parts, ) -def _join_candidates(candidates: Iterable[glm.Candidate]): +def _join_candidates(candidates: Iterable[protos.Candidate]): candidates = tuple(candidates) index = candidates[0].index # These should all be the same. 
- return glm.Candidate( + return protos.Candidate( index=index, content=_join_contents([c.content for c in candidates]), finish_reason=candidates[-1].finish_reason, @@ -296,7 +296,7 @@ def _join_candidates(candidates: Iterable[glm.Candidate]): ) -def _join_candidate_lists(candidate_lists: Iterable[list[glm.Candidate]]): +def _join_candidate_lists(candidate_lists: Iterable[list[protos.Candidate]]): # Assuming that is a candidate ends, it is no longer returned in the list of # candidates and that's why candidates have an index candidates = collections.defaultdict(list) @@ -312,15 +312,15 @@ def _join_candidate_lists(candidate_lists: Iterable[list[glm.Candidate]]): def _join_prompt_feedbacks( - prompt_feedbacks: Iterable[glm.GenerateContentResponse.PromptFeedback], + prompt_feedbacks: Iterable[protos.GenerateContentResponse.PromptFeedback], ): # Always return the first prompt feedback. return next(iter(prompt_feedbacks)) -def _join_chunks(chunks: Iterable[glm.GenerateContentResponse]): +def _join_chunks(chunks: Iterable[protos.GenerateContentResponse]): chunks = tuple(chunks) - return glm.GenerateContentResponse( + return protos.GenerateContentResponse( candidates=_join_candidate_lists(c.candidates for c in chunks), prompt_feedback=_join_prompt_feedbacks(c.prompt_feedback for c in chunks), usage_metadata=chunks[-1].usage_metadata, @@ -338,11 +338,11 @@ def __init__( done: bool, iterator: ( None - | Iterable[glm.GenerateContentResponse] - | AsyncIterable[glm.GenerateContentResponse] + | Iterable[protos.GenerateContentResponse] + | AsyncIterable[protos.GenerateContentResponse] ), - result: glm.GenerateContentResponse, - chunks: Iterable[glm.GenerateContentResponse] | None = None, + result: protos.GenerateContentResponse, + chunks: Iterable[protos.GenerateContentResponse] | None = None, ): self._done = done self._iterator = iterator @@ -356,6 +356,18 @@ def __init__( else: self._error = None + def to_dict(self): + """Returns the result as a JSON-compatible dict. + + Note: This doesn't capture the iterator state when streaming, it only captures the accumulated + `GenerateContentResponse` fields. + + >>> import json + >>> response = model.generate_content('Hello?') + >>> json.dumps(response.to_dict()) + """ + return type(self._result).to_dict(self._result) + @property def candidates(self): """The list of candidate responses. @@ -377,14 +389,13 @@ def parts(self): candidates = self.candidates if not candidates: raise ValueError( - "The `response.parts` quick accessor only works for a single candidate, " - "but none were returned. Check the `response.prompt_feedback` to see if the prompt was blocked." + "Invalid operation: The `response.parts` quick accessor requires a single candidate, " + "but none were returned. Please check the `response.prompt_feedback` to determine if the prompt was blocked." ) if len(candidates) > 1: raise ValueError( - "The `response.parts` quick accessor only works with a " - "single candidate. With multiple candidates use " - "result.candidates[index].text" + "Invalid operation: The `response.parts` quick accessor requires a single candidate. " + "For multiple candidates, please use `result.candidates[index].text`." ) parts = candidates[0].content.parts return parts @@ -399,18 +410,14 @@ def text(self): parts = self.parts if not parts: raise ValueError( - "The `response.text` quick accessor only works when the response contains a valid " - "`Part`, but none was returned. Check the `candidate.safety_ratings` to see if the " - "response was blocked." 
+ "Invalid operation: The `response.text` quick accessor requires the response to contain a valid `Part`, " + "but none were returned. Please check the `candidate.safety_ratings` to determine if the response was blocked." ) - if len(parts) != 1 or "text" not in parts[0]: raise ValueError( - "The `response.text` quick accessor only works for " - "simple (single-`Part`) text responses. This response is not simple text. " - "Use the `result.parts` accessor or the full " - "`result.candidates[index].content.parts` lookup " - "instead." + "Invalid operation: The `response.text` quick accessor requires a simple (single-`Part`) text response. " + "This response is not simple text. Please use the `result.parts` accessor or the full " + "`result.candidates[index].content.parts` lookup instead." ) return parts[0].text @@ -428,10 +435,12 @@ def __str__(self) -> str: else: _iterator = f"<{self._iterator.__class__.__name__}>" - as_dict = type(self._result).to_dict(self._result) + as_dict = type(self._result).to_dict( + self._result, use_integers_for_enums=False, including_default_value_fields=False + ) json_str = json.dumps(as_dict, indent=2) - _result = f"glm.GenerateContentResponse({json_str})" + _result = f"protos.GenerateContentResponse({json_str})" _result = _result.replace("\n", "\n ") if self._error: @@ -469,7 +478,7 @@ def rewrite_stream_error(): GENERATE_CONTENT_RESPONSE_DOC = """Instances of this class manage the response of the `generate_content` method. These are returned by `GenerativeModel.generate_content` and `ChatSession.send_message`. - This object is based on the low level `glm.GenerateContentResponse` class which just has `prompt_feedback` + This object is based on the low level `protos.GenerateContentResponse` class which just has `prompt_feedback` and `candidates` attributes. This class adds several quick accessors for common use cases. The same object type is returned for both `stream=True/False`. 
@@ -498,7 +507,7 @@ def rewrite_stream_error(): @string_utils.set_doc(GENERATE_CONTENT_RESPONSE_DOC) class GenerateContentResponse(BaseGenerateContentResponse): @classmethod - def from_iterator(cls, iterator: Iterable[glm.GenerateContentResponse]): + def from_iterator(cls, iterator: Iterable[protos.GenerateContentResponse]): iterator = iter(iterator) with rewrite_stream_error(): response = next(iterator) @@ -510,7 +519,7 @@ def from_iterator(cls, iterator: Iterable[glm.GenerateContentResponse]): ) @classmethod - def from_response(cls, response: glm.GenerateContentResponse): + def from_response(cls, response: protos.GenerateContentResponse): return cls( done=True, iterator=None, @@ -565,7 +574,7 @@ def resolve(self): @string_utils.set_doc(ASYNC_GENERATE_CONTENT_RESPONSE_DOC) class AsyncGenerateContentResponse(BaseGenerateContentResponse): @classmethod - async def from_aiterator(cls, iterator: AsyncIterable[glm.GenerateContentResponse]): + async def from_aiterator(cls, iterator: AsyncIterable[protos.GenerateContentResponse]): iterator = aiter(iterator) # type: ignore with rewrite_stream_error(): response = await anext(iterator) # type: ignore @@ -577,7 +586,7 @@ async def from_aiterator(cls, iterator: AsyncIterable[glm.GenerateContentRespons ) @classmethod - def from_response(cls, response: glm.GenerateContentResponse): + def from_response(cls, response: protos.GenerateContentResponse): return cls( done=True, iterator=None, diff --git a/google/generativeai/types/helper_types.py b/google/generativeai/types/helper_types.py new file mode 100644 index 000000000..fd8c1882b --- /dev/null +++ b/google/generativeai/types/helper_types.py @@ -0,0 +1,87 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import google.api_core.timeout +import google.api_core.retry + +import collections +import dataclasses + +from typing import Union +from typing_extensions import TypedDict + +__all__ = ["RequestOptions", "RequestOptionsType"] + + +class RequestOptionsDict(TypedDict, total=False): + retry: google.api_core.retry.Retry + timeout: Union[int, float, google.api_core.timeout.TimeToDeadlineTimeout] + + +@dataclasses.dataclass(init=False) +class RequestOptions(collections.abc.Mapping): + """Request options + + >>> import google.generativeai as genai + >>> from google.generativeai.types import RequestOptions + >>> from google.api_core import retry + >>> + >>> model = genai.GenerativeModel() + >>> response = model.generate_content('Hello', + ... request_options=RequestOptions( + ... retry=retry.Retry(initial=10, multiplier=2, maximum=60, timeout=300))) + >>> response = model.generate_content('Hello', + ... request_options=RequestOptions(timeout=600))) + + Args: + retry: Refer to [retry docs](https://googleapis.dev/python/google-api-core/latest/retry.html) for details. + timeout: In seconds (or provide a [TimeToDeadlineTimeout](https://googleapis.dev/python/google-api-core/latest/timeout.html) object). 
+ """ + + retry: google.api_core.retry.Retry | None + timeout: int | float | google.api_core.timeout.TimeToDeadlineTimeout | None + + def __init__( + self, + *, + retry: google.api_core.retry.Retry | None = None, + timeout: int | float | google.api_core.timeout.TimeToDeadlineTimeout | None = None, + ): + self.retry = retry + self.timeout = timeout + + # Inherit from Mapping for **unpacking + def __getitem__(self, item): + if item == "retry": + return self.retry + elif item == "timeout": + return self.timeout + else: + raise KeyError( + f"Invalid key: 'RequestOptions' does not contain a key named '{item}'. " + "Please use a valid key." + ) + + def __iter__(self): + yield "retry" + yield "timeout" + + def __len__(self): + return 2 + + +RequestOptionsType = Union[RequestOptions, RequestOptionsDict] diff --git a/google/generativeai/types/model_types.py b/google/generativeai/types/model_types.py index 0f85acfe8..81a545b30 100644 --- a/google/generativeai/types/model_types.py +++ b/google/generativeai/types/model_types.py @@ -28,7 +28,7 @@ import urllib.request from typing_extensions import TypedDict -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai.types import permission_types from google.generativeai import string_utils @@ -44,7 +44,7 @@ "TunedModelState", ] -TunedModelState = glm.TunedModel.State +TunedModelState = protos.TunedModel.State TunedModelStateOptions = Union[None, str, int, TunedModelState] @@ -91,7 +91,7 @@ def to_tuned_model_state(x: TunedModelStateOptions) -> TunedModelState: @string_utils.prettyprint @dataclasses.dataclass class Model: - """A dataclass representation of a `glm.Model`. + """A dataclass representation of a `protos.Model`. Attributes: name: The resource name of the `Model`. Format: `models/{model}` with a `{model}` naming @@ -140,8 +140,8 @@ def idecode_time(parent: dict["str", Any], name: str): parent[name] = dt -def decode_tuned_model(tuned_model: glm.TunedModel | dict["str", Any]) -> TunedModel: - if isinstance(tuned_model, glm.TunedModel): +def decode_tuned_model(tuned_model: protos.TunedModel | dict["str", Any]) -> TunedModel: + if isinstance(tuned_model, protos.TunedModel): tuned_model = type(tuned_model).to_dict(tuned_model) # pytype: disable=attribute-error tuned_model["state"] = to_tuned_model_state(tuned_model.pop("state", None)) @@ -180,7 +180,7 @@ def decode_tuned_model(tuned_model: glm.TunedModel | dict["str", Any]) -> TunedM @string_utils.prettyprint @dataclasses.dataclass class TunedModel: - """A dataclass representation of a `glm.TunedModel`.""" + """A dataclass representation of a `protos.TunedModel`.""" name: str | None = None source_model: str | None = None @@ -214,13 +214,13 @@ class TuningExampleDict(TypedDict): output: str -TuningExampleOptions = Union[TuningExampleDict, glm.TuningExample, tuple[str, str], list[str]] +TuningExampleOptions = Union[TuningExampleDict, protos.TuningExample, tuple[str, str], list[str]] # TODO(markdaoust): gs:// URLS? File-type argument for files without extension? 
TuningDataOptions = Union[ pathlib.Path, str, - glm.Dataset, + protos.Dataset, Mapping[str, Iterable[str]], Iterable[TuningExampleOptions], ] @@ -228,8 +228,8 @@ class TuningExampleDict(TypedDict): def encode_tuning_data( data: TuningDataOptions, input_key="text_input", output_key="output" -) -> glm.Dataset: - if isinstance(data, glm.Dataset): +) -> protos.Dataset: + if isinstance(data, protos.Dataset): return data if isinstance(data, str): @@ -287,16 +287,22 @@ def _convert_dict(data, input_key, output_key): try: inputs = data[input_key] except KeyError: - raise KeyError(f'input_key is "{input_key}", but data has keys: {sorted(data.keys())}') + raise KeyError( + f"Invalid key: The input key '{input_key}' does not exist in the data. " + f"Available keys are: {sorted(data.keys())}." + ) try: outputs = data[output_key] except KeyError: - raise KeyError(f'output_key is "{output_key}", but data has keys: {sorted(data.keys())}') + raise KeyError( + f"Invalid key: The output key '{output_key}' does not exist in the data. " + f"Available keys are: {sorted(data.keys())}." + ) for i, o in zip(inputs, outputs): - new_data.append(glm.TuningExample({"text_input": str(i), "output": str(o)})) - return glm.Dataset(examples=glm.TuningExamples(examples=new_data)) + new_data.append(protos.TuningExample({"text_input": str(i), "output": str(o)})) + return protos.Dataset(examples=protos.TuningExamples(examples=new_data)) def _convert_iterable(data, input_key, output_key): @@ -304,17 +310,17 @@ def _convert_iterable(data, input_key, output_key): for example in data: example = encode_tuning_example(example, input_key, output_key) new_data.append(example) - return glm.Dataset(examples=glm.TuningExamples(examples=new_data)) + return protos.Dataset(examples=protos.TuningExamples(examples=new_data)) def encode_tuning_example(example: TuningExampleOptions, input_key, output_key): - if isinstance(example, glm.TuningExample): + if isinstance(example, protos.TuningExample): return example elif isinstance(example, (tuple, list)): a, b = example - example = glm.TuningExample(text_input=a, output=b) + example = protos.TuningExample(text_input=a, output=b) else: # dict - example = glm.TuningExample(text_input=example[input_key], output=example[output_key]) + example = protos.TuningExample(text_input=example[input_key], output=example[output_key]) return example @@ -335,22 +341,26 @@ class Hyperparameters: learning_rate: float = 0.0 -BaseModelNameOptions = Union[str, Model, glm.Model] -TunedModelNameOptions = Union[str, TunedModel, glm.TunedModel] -AnyModelNameOptions = Union[str, Model, glm.Model, TunedModel, glm.TunedModel] +BaseModelNameOptions = Union[str, Model, protos.Model] +TunedModelNameOptions = Union[str, TunedModel, protos.TunedModel] +AnyModelNameOptions = Union[str, Model, protos.Model, TunedModel, protos.TunedModel] ModelNameOptions = AnyModelNameOptions def make_model_name(name: AnyModelNameOptions): - if isinstance(name, (Model, glm.Model, TunedModel, glm.TunedModel)): + if isinstance(name, (Model, protos.Model, TunedModel, protos.TunedModel)): name = name.name # pytype: disable=attribute-error elif isinstance(name, str): name = name else: - raise TypeError("Expected: str, Model, or TunedModel") + raise TypeError( + "Invalid input type. Expected one of the following types: `str`, `Model`, or `TunedModel`." 
+ ) if not (name.startswith("models/") or name.startswith("tunedModels/")): - raise ValueError(f"Model names should start with `models/` or `tunedModels/`, got: {name}") + raise ValueError( + f"Invalid model name: '{name}'. Model names should start with 'models/' or 'tunedModels/'." + ) return name @@ -362,7 +372,7 @@ def make_model_name(name: AnyModelNameOptions): @string_utils.prettyprint @dataclasses.dataclass class TokenCount: - """A dataclass representation of a `glm.TokenCountResponse`. + """A dataclass representation of a `protos.TokenCountResponse`. Attributes: token_count: The number of tokens returned by the model's tokenizer for the `input_text`. diff --git a/google/generativeai/types/palm_safety_types.py b/google/generativeai/types/palm_safety_types.py new file mode 100644 index 000000000..0ab85e1b2 --- /dev/null +++ b/google/generativeai/types/palm_safety_types.py @@ -0,0 +1,286 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from collections.abc import Mapping + +import enum +import typing +from typing import Dict, Iterable, List, Union + +from typing_extensions import TypedDict + + +from google.generativeai import protos +from google.generativeai import string_utils + + +__all__ = [ + "HarmCategory", + "HarmProbability", + "HarmBlockThreshold", + "BlockedReason", + "ContentFilterDict", + "SafetyRatingDict", + "SafetySettingDict", + "SafetyFeedbackDict", +] + +# These are basic python enums, it's okay to expose them +HarmProbability = protos.SafetyRating.HarmProbability +HarmBlockThreshold = protos.SafetySetting.HarmBlockThreshold +BlockedReason = protos.ContentFilter.BlockedReason + + +class HarmCategory: + """ + Harm Categories supported by the palm-family models + """ + + HARM_CATEGORY_UNSPECIFIED = protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED.value + HARM_CATEGORY_DEROGATORY = protos.HarmCategory.HARM_CATEGORY_DEROGATORY.value + HARM_CATEGORY_TOXICITY = protos.HarmCategory.HARM_CATEGORY_TOXICITY.value + HARM_CATEGORY_VIOLENCE = protos.HarmCategory.HARM_CATEGORY_VIOLENCE.value + HARM_CATEGORY_SEXUAL = protos.HarmCategory.HARM_CATEGORY_SEXUAL.value + HARM_CATEGORY_MEDICAL = protos.HarmCategory.HARM_CATEGORY_MEDICAL.value + HARM_CATEGORY_DANGEROUS = protos.HarmCategory.HARM_CATEGORY_DANGEROUS.value + + +HarmCategoryOptions = Union[str, int, HarmCategory] + +# fmt: off +_HARM_CATEGORIES: Dict[HarmCategoryOptions, protos.HarmCategory] = { + protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED: protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + HarmCategory.HARM_CATEGORY_UNSPECIFIED: protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + 0: protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + "harm_category_unspecified": protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + "unspecified": protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + + protos.HarmCategory.HARM_CATEGORY_DEROGATORY: protos.HarmCategory.HARM_CATEGORY_DEROGATORY, + HarmCategory.HARM_CATEGORY_DEROGATORY: 
protos.HarmCategory.HARM_CATEGORY_DEROGATORY, + 1: protos.HarmCategory.HARM_CATEGORY_DEROGATORY, + "harm_category_derogatory": protos.HarmCategory.HARM_CATEGORY_DEROGATORY, + "derogatory": protos.HarmCategory.HARM_CATEGORY_DEROGATORY, + + protos.HarmCategory.HARM_CATEGORY_TOXICITY: protos.HarmCategory.HARM_CATEGORY_TOXICITY, + HarmCategory.HARM_CATEGORY_TOXICITY: protos.HarmCategory.HARM_CATEGORY_TOXICITY, + 2: protos.HarmCategory.HARM_CATEGORY_TOXICITY, + "harm_category_toxicity": protos.HarmCategory.HARM_CATEGORY_TOXICITY, + "toxicity": protos.HarmCategory.HARM_CATEGORY_TOXICITY, + "toxic": protos.HarmCategory.HARM_CATEGORY_TOXICITY, + + protos.HarmCategory.HARM_CATEGORY_VIOLENCE: protos.HarmCategory.HARM_CATEGORY_VIOLENCE, + HarmCategory.HARM_CATEGORY_VIOLENCE: protos.HarmCategory.HARM_CATEGORY_VIOLENCE, + 3: protos.HarmCategory.HARM_CATEGORY_VIOLENCE, + "harm_category_violence": protos.HarmCategory.HARM_CATEGORY_VIOLENCE, + "violence": protos.HarmCategory.HARM_CATEGORY_VIOLENCE, + "violent": protos.HarmCategory.HARM_CATEGORY_VIOLENCE, + + protos.HarmCategory.HARM_CATEGORY_SEXUAL: protos.HarmCategory.HARM_CATEGORY_SEXUAL, + HarmCategory.HARM_CATEGORY_SEXUAL: protos.HarmCategory.HARM_CATEGORY_SEXUAL, + 4: protos.HarmCategory.HARM_CATEGORY_SEXUAL, + "harm_category_sexual": protos.HarmCategory.HARM_CATEGORY_SEXUAL, + "sexual": protos.HarmCategory.HARM_CATEGORY_SEXUAL, + "sex": protos.HarmCategory.HARM_CATEGORY_SEXUAL, + + protos.HarmCategory.HARM_CATEGORY_MEDICAL: protos.HarmCategory.HARM_CATEGORY_MEDICAL, + HarmCategory.HARM_CATEGORY_MEDICAL: protos.HarmCategory.HARM_CATEGORY_MEDICAL, + 5: protos.HarmCategory.HARM_CATEGORY_MEDICAL, + "harm_category_medical": protos.HarmCategory.HARM_CATEGORY_MEDICAL, + "medical": protos.HarmCategory.HARM_CATEGORY_MEDICAL, + "med": protos.HarmCategory.HARM_CATEGORY_MEDICAL, + + protos.HarmCategory.HARM_CATEGORY_DANGEROUS: protos.HarmCategory.HARM_CATEGORY_DANGEROUS, + HarmCategory.HARM_CATEGORY_DANGEROUS: protos.HarmCategory.HARM_CATEGORY_DANGEROUS, + 6: protos.HarmCategory.HARM_CATEGORY_DANGEROUS, + "harm_category_dangerous": protos.HarmCategory.HARM_CATEGORY_DANGEROUS, + "dangerous": protos.HarmCategory.HARM_CATEGORY_DANGEROUS, + "danger": protos.HarmCategory.HARM_CATEGORY_DANGEROUS, +} +# fmt: on + + +def to_harm_category(x: HarmCategoryOptions) -> protos.HarmCategory: + if isinstance(x, str): + x = x.lower() + return _HARM_CATEGORIES[x] + + +HarmBlockThresholdOptions = Union[str, int, HarmBlockThreshold] + +# fmt: off +_BLOCK_THRESHOLDS: Dict[HarmBlockThresholdOptions, HarmBlockThreshold] = { + HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED: HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED, + 0: HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED, + "harm_block_threshold_unspecified": HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED, + "block_threshold_unspecified": HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED, + "unspecified": HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED, + + HarmBlockThreshold.BLOCK_LOW_AND_ABOVE: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + 1: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + "block_low_and_above": HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + "low": HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + + HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + 2: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + "block_medium_and_above": HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + "medium": HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + "med": HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + + 
HarmBlockThreshold.BLOCK_ONLY_HIGH: HarmBlockThreshold.BLOCK_ONLY_HIGH, + 3: HarmBlockThreshold.BLOCK_ONLY_HIGH, + "block_only_high": HarmBlockThreshold.BLOCK_ONLY_HIGH, + "high": HarmBlockThreshold.BLOCK_ONLY_HIGH, + + HarmBlockThreshold.BLOCK_NONE: HarmBlockThreshold.BLOCK_NONE, + 4: HarmBlockThreshold.BLOCK_NONE, + "block_none": HarmBlockThreshold.BLOCK_NONE, +} +# fmt: on + + +def to_block_threshold(x: HarmBlockThresholdOptions) -> HarmBlockThreshold: + if isinstance(x, str): + x = x.lower() + return _BLOCK_THRESHOLDS[x] + + +class ContentFilterDict(TypedDict): + reason: BlockedReason + message: str + + __doc__ = string_utils.strip_oneof(protos.ContentFilter.__doc__) + + +def convert_filters_to_enums( + filters: Iterable[dict], +) -> List[ContentFilterDict]: + result = [] + for f in filters: + f = f.copy() + f["reason"] = BlockedReason(f["reason"]) + f = typing.cast(ContentFilterDict, f) + result.append(f) + return result + + +class SafetyRatingDict(TypedDict): + category: protos.HarmCategory + probability: HarmProbability + + __doc__ = string_utils.strip_oneof(protos.SafetyRating.__doc__) + + +def convert_rating_to_enum(rating: dict) -> SafetyRatingDict: + return { + "category": protos.HarmCategory(rating["category"]), + "probability": HarmProbability(rating["probability"]), + } + + +def convert_ratings_to_enum(ratings: Iterable[dict]) -> List[SafetyRatingDict]: + result = [] + for r in ratings: + result.append(convert_rating_to_enum(r)) + return result + + +class SafetySettingDict(TypedDict): + category: protos.HarmCategory + threshold: HarmBlockThreshold + + __doc__ = string_utils.strip_oneof(protos.SafetySetting.__doc__) + + +class LooseSafetySettingDict(TypedDict): + category: HarmCategoryOptions + threshold: HarmBlockThresholdOptions + + +EasySafetySetting = Mapping[HarmCategoryOptions, HarmBlockThresholdOptions] +EasySafetySettingDict = dict[HarmCategoryOptions, HarmBlockThresholdOptions] + +SafetySettingOptions = Union[EasySafetySetting, Iterable[LooseSafetySettingDict], None] + + +def to_easy_safety_dict(settings: SafetySettingOptions) -> EasySafetySettingDict: + if settings is None: + return {} + elif isinstance(settings, Mapping): + return {to_harm_category(key): to_block_threshold(value) for key, value in settings.items()} + else: # Iterable + return { + to_harm_category(d["category"]): to_block_threshold(d["threshold"]) for d in settings + } + + +def normalize_safety_settings( + settings: SafetySettingOptions, +) -> list[SafetySettingDict] | None: + if settings is None: + return None + if isinstance(settings, Mapping): + return [ + { + "category": to_harm_category(key), + "threshold": to_block_threshold(value), + } + for key, value in settings.items() + ] + else: + return [ + { + "category": to_harm_category(d["category"]), + "threshold": to_block_threshold(d["threshold"]), + } + for d in settings + ] + + +def convert_setting_to_enum(setting: dict) -> SafetySettingDict: + return { + "category": protos.HarmCategory(setting["category"]), + "threshold": HarmBlockThreshold(setting["threshold"]), + } + + +class SafetyFeedbackDict(TypedDict): + rating: SafetyRatingDict + setting: SafetySettingDict + + __doc__ = string_utils.strip_oneof(protos.SafetyFeedback.__doc__) + + +def convert_safety_feedback_to_enums( + safety_feedback: Iterable[dict], +) -> List[SafetyFeedbackDict]: + result = [] + for sf in safety_feedback: + result.append( + { + "rating": convert_rating_to_enum(sf["rating"]), + "setting": convert_setting_to_enum(sf["setting"]), + } + ) + return result + + +def 
convert_candidate_enums(candidates): + result = [] + for candidate in candidates: + candidate = candidate.copy() + candidate["safety_ratings"] = convert_ratings_to_enum(candidate["safety_ratings"]) + result.append(candidate) + return result diff --git a/google/generativeai/types/permission_types.py b/google/generativeai/types/permission_types.py index ef9242999..1df831db0 100644 --- a/google/generativeai/types/permission_types.py +++ b/google/generativeai/types/permission_types.py @@ -19,17 +19,18 @@ import re import google.ai.generativelanguage as glm +from google.generativeai import protos from google.protobuf import field_mask_pb2 -from google.generativeai.client import get_dafault_permission_client -from google.generativeai.client import get_dafault_permission_async_client +from google.generativeai.client import get_default_permission_client +from google.generativeai.client import get_default_permission_async_client from google.generativeai.utils import flatten_update_paths from google.generativeai import string_utils -GranteeType = glm.Permission.GranteeType -Role = glm.Permission.Role +GranteeType = protos.Permission.GranteeType +Role = protos.Permission.Role GranteeTypeOptions = Union[str, int, GranteeType] RoleOptions = Union[str, int, Role] @@ -107,8 +108,8 @@ def delete( Delete permission (self). """ if client is None: - client = get_dafault_permission_client() - delete_request = glm.DeletePermissionRequest(name=self.name) + client = get_default_permission_client() + delete_request = protos.DeletePermissionRequest(name=self.name) client.delete_permission(request=delete_request) async def delete_async( @@ -119,8 +120,8 @@ async def delete_async( This is the async version of `Permission.delete`. """ if client is None: - client = get_dafault_permission_async_client() - delete_request = glm.DeletePermissionRequest(name=self.name) + client = get_default_permission_async_client() + delete_request = protos.DeletePermissionRequest(name=self.name) await client.delete_permission(request=delete_request) # TODO (magashe): Add a method to validate update value. As of now only `role` is supported as a mask path @@ -146,13 +147,13 @@ def update( `Permission` object with specified updates. """ if client is None: - client = get_dafault_permission_client() + client = get_default_permission_client() updates = flatten_update_paths(updates) for update_path in updates: if update_path != "role": raise ValueError( - f"As of now, only `role` can be updated for `Permission`. Got: `{update_path}` instead." + f"Invalid update path: '{update_path}'. Currently, only the 'role' attribute can be updated for 'Permission'." ) field_mask = field_mask_pb2.FieldMask() @@ -161,7 +162,7 @@ def update( for path, value in updates.items(): self._apply_update(path, value) - update_request = glm.UpdatePermissionRequest( + update_request = protos.UpdatePermissionRequest( permission=self._to_proto(), update_mask=field_mask ) client.update_permission(request=update_request) @@ -176,13 +177,13 @@ async def update_async( This is the async version of `Permission.update`. """ if client is None: - client = get_dafault_permission_async_client() + client = get_default_permission_async_client() updates = flatten_update_paths(updates) for update_path in updates: if update_path != "role": raise ValueError( - f"As of now, only `role` can be updated for `Permission`. Got: `{update_path}` instead." + f"Invalid update path: '{update_path}'. Currently, only the 'role' attribute can be updated for 'Permission'." 
) field_mask = field_mask_pb2.FieldMask() @@ -191,14 +192,14 @@ async def update_async( for path, value in updates.items(): self._apply_update(path, value) - update_request = glm.UpdatePermissionRequest( + update_request = protos.UpdatePermissionRequest( permission=self._to_proto(), update_mask=field_mask ) await client.update_permission(request=update_request) return self - def _to_proto(self) -> glm.Permission: - return glm.Permission( + def _to_proto(self) -> protos.Permission: + return protos.Permission( name=self.name, role=self.role, grantee_type=self.grantee_type, @@ -224,8 +225,8 @@ def get( Requested permission as an instance of `Permission`. """ if client is None: - client = get_dafault_permission_client() - get_perm_request = glm.GetPermissionRequest(name=name) + client = get_default_permission_client() + get_perm_request = protos.GetPermissionRequest(name=name) get_perm_response = client.get_permission(request=get_perm_request) get_perm_response = type(get_perm_response).to_dict(get_perm_response) return cls(**get_perm_response) @@ -240,8 +241,8 @@ async def get_async( This is the async version of `Permission.get`. """ if client is None: - client = get_dafault_permission_async_client() - get_perm_request = glm.GetPermissionRequest(name=name) + client = get_default_permission_async_client() + get_perm_request = protos.GetPermissionRequest(name=name) get_perm_response = await client.get_permission(request=get_perm_request) get_perm_response = type(get_perm_response).to_dict(get_perm_response) return cls(**get_perm_response) @@ -263,7 +264,7 @@ def _make_create_permission_request( role: RoleOptions, grantee_type: Optional[GranteeTypeOptions] = None, email_address: Optional[str] = None, - ) -> glm.CreatePermissionRequest: + ) -> protos.CreatePermissionRequest: role = to_role(role) if grantee_type: @@ -271,20 +272,19 @@ def _make_create_permission_request( if email_address and grantee_type == GranteeType.EVERYONE: raise ValueError( - f"Cannot limit access for: `{email_address}` when `grantee_type` is set to `EVERYONE`." + f"Invalid operation: Access cannot be limited for a specific email address ('{email_address}') when 'grantee_type' is set to 'EVERYONE'." ) - if not email_address and grantee_type != GranteeType.EVERYONE: raise ValueError( - f"`email_address` must be specified unless `grantee_type` is set to `EVERYONE`." + f"Invalid operation: An 'email_address' must be provided when 'grantee_type' is not set to 'EVERYONE'. Currently, 'grantee_type' is set to '{grantee_type}' and 'email_address' is '{email_address if email_address else 'not provided'}'." ) - permission = glm.Permission( + permission = protos.Permission( role=role, grantee_type=grantee_type, email_address=email_address, ) - return glm.CreatePermissionRequest( + return protos.CreatePermissionRequest( parent=self.parent, permission=permission, ) @@ -313,7 +313,7 @@ def create( ValueError: When email_address is not specified and grantee_type is not set to EVERYONE. """ if client is None: - client = get_dafault_permission_client() + client = get_default_permission_client() request = self._make_create_permission_request( role=role, grantee_type=grantee_type, email_address=email_address @@ -333,7 +333,7 @@ async def create_async( This is the async version of `PermissionAdapter.create_permission`. 
""" if client is None: - client = get_dafault_permission_async_client() + client = get_default_permission_async_client() request = self._make_create_permission_request( role=role, grantee_type=grantee_type, email_address=email_address @@ -358,9 +358,9 @@ def list( Paginated list of `Permission` objects. """ if client is None: - client = get_dafault_permission_client() + client = get_default_permission_client() - request = glm.ListPermissionsRequest( + request = protos.ListPermissionsRequest( parent=self.parent, page_size=page_size # pytype: disable=attribute-error ) for permission in client.list_permissions(request): @@ -376,9 +376,9 @@ async def list_async( This is the async version of `PermissionAdapter.list_permissions`. """ if client is None: - client = get_dafault_permission_async_client() + client = get_default_permission_async_client() - request = glm.ListPermissionsRequest( + request = protos.ListPermissionsRequest( parent=self.parent, page_size=page_size # pytype: disable=attribute-error ) async for permission in await client.list_permissions(request): @@ -400,8 +400,8 @@ def transfer_ownership( if self.parent.startswith("corpora"): raise NotImplementedError("Can'/t transfer_ownership for a Corpus") if client is None: - client = get_dafault_permission_client() - transfer_request = glm.TransferOwnershipRequest( + client = get_default_permission_client() + transfer_request = protos.TransferOwnershipRequest( name=self.parent, email_address=email_address # pytype: disable=attribute-error ) return client.transfer_ownership(request=transfer_request) @@ -415,8 +415,8 @@ async def transfer_ownership_async( if self.parent.startswith("corpora"): raise NotImplementedError("Can'/t transfer_ownership for a Corpus") if client is None: - client = get_dafault_permission_async_client() - transfer_request = glm.TransferOwnershipRequest( + client = get_default_permission_async_client() + transfer_request = protos.TransferOwnershipRequest( name=self.parent, email_address=email_address # pytype: disable=attribute-error ) return await client.transfer_ownership(request=transfer_request) diff --git a/google/generativeai/types/retriever_types.py b/google/generativeai/types/retriever_types.py index 72859f207..9931ee58d 100644 --- a/google/generativeai/types/retriever_types.py +++ b/google/generativeai/types/retriever_types.py @@ -22,11 +22,14 @@ from typing_extensions import deprecated # type: ignore import google.ai.generativelanguage as glm +from google.generativeai import protos from google.protobuf import field_mask_pb2 from google.generativeai.client import get_default_retriever_client from google.generativeai.client import get_default_retriever_async_client from google.generativeai import string_utils +from google.generativeai.types import helper_types + from google.generativeai.types import permission_types from google.generativeai.types.model_types import idecode_time from google.generativeai.utils import flatten_update_paths @@ -42,14 +45,14 @@ def valid_name(name): return re.match(_VALID_NAME, name) and len(name) < 40 -Operator = glm.Condition.Operator -State = glm.Chunk.State +Operator = protos.Condition.Operator +State = protos.Chunk.State OperatorOptions = Union[str, int, Operator] StateOptions = Union[str, int, State] ChunkOptions = Union[ - glm.Chunk, + protos.Chunk, str, tuple[str, str], tuple[str, str, Any], @@ -57,17 +60,17 @@ def valid_name(name): ] # fmt: no BatchCreateChunkOptions = Union[ - glm.BatchCreateChunksRequest, + protos.BatchCreateChunksRequest, Mapping[str, str], 
Mapping[str, tuple[str, str]], Iterable[ChunkOptions], ] # fmt: no -UpdateChunkOptions = Union[glm.UpdateChunkRequest, Mapping[str, Any], tuple[str, Any]] +UpdateChunkOptions = Union[protos.UpdateChunkRequest, Mapping[str, Any], tuple[str, Any]] -BatchUpdateChunksOptions = Union[glm.BatchUpdateChunksRequest, Iterable[UpdateChunkOptions]] +BatchUpdateChunksOptions = Union[protos.BatchUpdateChunksRequest, Iterable[UpdateChunkOptions]] -BatchDeleteChunkOptions = Union[list[glm.DeleteChunkRequest], Iterable[str]] +BatchDeleteChunkOptions = Union[list[protos.DeleteChunkRequest], Iterable[str]] _OPERATOR: dict[OperatorOptions, Operator] = { Operator.OPERATOR_UNSPECIFIED: Operator.OPERATOR_UNSPECIFIED, @@ -156,15 +159,15 @@ def _to_proto(self): elif isinstance(c.value, (int, float)): kwargs["numeric_value"] = float(c.value) else: - ValueError( - f"The value for the condition must be either a string or an integer/float, but got {c.value}." + raise ValueError( + f"Invalid value type: The value for the condition must be either a string or an integer/float. Received: '{c.value}' of type {type(c.value).__name__}." ) kwargs["operation"] = c.operation - condition = glm.Condition(**kwargs) + condition = protos.Condition(**kwargs) conditions.append(condition) - return glm.MetadataFilter(key=self.key, conditions=conditions) + return protos.MetadataFilter(key=self.key, conditions=conditions) @string_utils.prettyprint @@ -186,18 +189,17 @@ def _to_proto(self): kwargs["string_value"] = self.value elif isinstance(self.value, Iterable): if isinstance(self.value, Mapping): - # If already converted to a glm.StringList, get the values + # If already converted to a protos.StringList, get the values kwargs["string_list_value"] = self.value else: - kwargs["string_list_value"] = glm.StringList(values=self.value) + kwargs["string_list_value"] = protos.StringList(values=self.value) elif isinstance(self.value, (int, float)): kwargs["numeric_value"] = float(self.value) else: - ValueError( - f"The value for a custom_metadata specification must be either a list of string values, a string, or an integer/float, but got {self.value}." + raise ValueError( + f"Invalid value type: The value for a custom_metadata specification must be either a list of string values, a string, or an integer/float. Received: '{self.value}' of type {type(self.value).__name__}." ) - - return glm.CustomMetadata(key=self.key, **kwargs) + return protos.CustomMetadata(key=self.key, **kwargs) @classmethod def _from_dict(cls, cm): @@ -215,21 +217,21 @@ def _to_dict(self): return type(proto).to_dict(proto) -CustomMetadataOptions = Union[CustomMetadata, glm.CustomMetadata, dict] +CustomMetadataOptions = Union[CustomMetadata, protos.CustomMetadata, dict] def make_custom_metadata(cm: CustomMetadataOptions) -> CustomMetadata: if isinstance(cm, CustomMetadata): return cm - if isinstance(cm, glm.CustomMetadata): + if isinstance(cm, protos.CustomMetadata): cm = type(cm).to_dict(cm) if isinstance(cm, dict): return CustomMetadata._from_dict(cm) else: raise ValueError( # nofmt - "Could not create a `CustomMetadata` from:\n" f" type: {type(cm)}\n" f" value: {cm}" + f"Invalid input: Could not create a 'CustomMetadata' from the provided input. Received type: '{type(cm).__name__}', value: '{cm}'." 
) @@ -261,7 +263,7 @@ def create_document( display_name: str | None = None, custom_metadata: Iterable[CustomMetadata] | None = None, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> Document: """ Request to create a `Document`. @@ -292,9 +294,9 @@ def create_document( c_data.append(cm._to_proto()) if name is None: - document = glm.Document(display_name=display_name, custom_metadata=c_data) + document = protos.Document(display_name=display_name, custom_metadata=c_data) elif valid_name(name): - document = glm.Document( + document = protos.Document( name=f"{self.name}/documents/{name}", display_name=display_name, custom_metadata=c_data, @@ -302,7 +304,7 @@ def create_document( else: raise ValueError(NAME_ERROR_MSG.format(length=len(name), name=name)) - request = glm.CreateDocumentRequest(parent=self.name, document=document) + request = protos.CreateDocumentRequest(parent=self.name, document=document) response = client.create_document(request, **request_options) return decode_document(response) @@ -312,7 +314,7 @@ async def create_document_async( display_name: str | None = None, custom_metadata: Iterable[CustomMetadata] | None = None, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> Document: """This is the async version of `Corpus.create_document`.""" if request_options is None: @@ -328,9 +330,9 @@ async def create_document_async( c_data.append(cm._to_proto()) if name is None: - document = glm.Document(display_name=display_name, custom_metadata=c_data) + document = protos.Document(display_name=display_name, custom_metadata=c_data) elif valid_name(name): - document = glm.Document( + document = protos.Document( name=f"{self.name}/documents/{name}", display_name=display_name, custom_metadata=c_data, @@ -338,7 +340,7 @@ async def create_document_async( else: raise ValueError(NAME_ERROR_MSG.format(length=len(name), name=name)) - request = glm.CreateDocumentRequest(parent=self.name, document=document) + request = protos.CreateDocumentRequest(parent=self.name, document=document) response = await client.create_document(request, **request_options) return decode_document(response) @@ -346,7 +348,7 @@ def get_document( self, name: str, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> Document: """ Get information about a specific `Document`. 
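Note on the `request_options` retyping in the hunks above: the parameter moves from a bare `dict[str, Any]` to `helper_types.RequestOptionsType`. A minimal calling sketch, assuming a plain mapping with a `timeout` key still satisfies the new type and that `corpus` is an existing `Corpus` instance obtained elsewhere (not shown in this patch):

```
# Sketch under the assumptions above; names are illustrative only.
doc = corpus.create_document(
    display_name="quarterly-report",      # hypothetical display name
    request_options={"timeout": 600},     # assumed to satisfy RequestOptionsType
)
fetched = corpus.get_document(doc.name, request_options={"timeout": 600})
```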
@@ -367,7 +369,7 @@ def get_document( if "/" not in name: name = f"{self.name}/documents/{name}" - request = glm.GetDocumentRequest(name=name) + request = protos.GetDocumentRequest(name=name) response = client.get_document(request, **request_options) return decode_document(response) @@ -375,7 +377,7 @@ async def get_document_async( self, name: str, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> Document: """This is the async version of `Corpus.get_document`.""" if request_options is None: @@ -387,7 +389,7 @@ async def get_document_async( if "/" not in name: name = f"{self.name}/documents/{name}" - request = glm.GetDocumentRequest(name=name) + request = protos.GetDocumentRequest(name=name) response = await client.get_document(request, **request_options) return decode_document(response) @@ -401,7 +403,7 @@ def update( self, updates: dict[str, Any], client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """ Update a list of fields for a specified `Corpus`. @@ -423,7 +425,9 @@ def update( # At this time, only `display_name` can be updated for item in updates: if item != "display_name": - raise ValueError("At this time, only `display_name` can be updated for `Corpus`.") + raise ValueError( + "Invalid operation: Currently, only the 'display_name' attribute can be updated for a 'Corpus'." + ) field_mask = field_mask_pb2.FieldMask() for path in updates.keys(): @@ -431,7 +435,7 @@ def update( for path, value in updates.items(): self._apply_update(path, value) - request = glm.UpdateCorpusRequest(corpus=self.to_dict(), update_mask=field_mask) + request = protos.UpdateCorpusRequest(corpus=self.to_dict(), update_mask=field_mask) client.update_corpus(request, **request_options) return self @@ -439,7 +443,7 @@ async def update_async( self, updates: dict[str, Any], client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """This is the async version of `Corpus.update`.""" if request_options is None: @@ -452,7 +456,9 @@ async def update_async( # At this time, only `display_name` can be updated for item in updates: if item != "display_name": - raise ValueError("At this time, only `display_name` can be updated for `Corpus`.") + raise ValueError( + "Invalid operation: Currently, only the 'display_name' attribute can be updated for a 'Corpus'." + ) field_mask = field_mask_pb2.FieldMask() for path in updates.keys(): @@ -460,7 +466,7 @@ async def update_async( for path, value in updates.items(): self._apply_update(path, value) - request = glm.UpdateCorpusRequest(corpus=self.to_dict(), update_mask=field_mask) + request = protos.UpdateCorpusRequest(corpus=self.to_dict(), update_mask=field_mask) await client.update_corpus(request, **request_options) return self @@ -470,7 +476,7 @@ def query( metadata_filters: Iterable[MetadataFilter] | None = None, results_count: int | None = None, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> Iterable[RelevantChunk]: """ Query a corpus for information. 
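The `Corpus.update` hunks below keep `display_name` as the only updatable field and only reword the error raised for anything else. A short sketch of the supported path, assuming an existing `Corpus` instance named `corpus`:

```
# Only `display_name` may change on a Corpus.
corpus.update({"display_name": "Docs corpus (renamed)"})

# Any other key now fails fast with the clearer message added in this patch:
# corpus.update({"name": "corpora/other"})  # raises ValueError
```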
@@ -492,14 +498,16 @@ def query( if results_count: if results_count > 100: - raise ValueError("Number of results returned must be between 1 and 100.") + raise ValueError( + "Invalid operation: The number of results returned must be between 1 and 100." + ) m_f_ = [] if metadata_filters: for mf in metadata_filters: m_f_.append(mf._to_proto()) - request = glm.QueryCorpusRequest( + request = protos.QueryCorpusRequest( name=self.name, query=query, metadata_filters=m_f_, @@ -524,7 +532,7 @@ async def query_async( metadata_filters: Iterable[MetadataFilter] | None = None, results_count: int | None = None, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> Iterable[RelevantChunk]: """This is the async version of `Corpus.query`.""" if request_options is None: @@ -535,14 +543,16 @@ async def query_async( if results_count: if results_count > 100: - raise ValueError("Number of results returned must be between 1 and 100.") + raise ValueError( + "Invalid operation: The number of results returned must be between 1 and 100." + ) m_f_ = [] if metadata_filters: for mf in metadata_filters: m_f_.append(mf._to_proto()) - request = glm.QueryCorpusRequest( + request = protos.QueryCorpusRequest( name=self.name, query=query, metadata_filters=m_f_, @@ -566,7 +576,7 @@ def delete_document( name: str, force: bool = False, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """ Delete a document in the corpus. @@ -585,7 +595,7 @@ def delete_document( if "/" not in name: name = f"{self.name}/documents/{name}" - request = glm.DeleteDocumentRequest(name=name, force=bool(force)) + request = protos.DeleteDocumentRequest(name=name, force=bool(force)) client.delete_document(request, **request_options) async def delete_document_async( @@ -593,7 +603,7 @@ async def delete_document_async( name: str, force: bool = False, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """This is the async version of `Corpus.delete_document`.""" if request_options is None: @@ -605,14 +615,14 @@ async def delete_document_async( if "/" not in name: name = f"{self.name}/documents/{name}" - request = glm.DeleteDocumentRequest(name=name, force=bool(force)) + request = protos.DeleteDocumentRequest(name=name, force=bool(force)) await client.delete_document(request, **request_options) def list_documents( self, page_size: int | None = None, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> Iterable[Document]: """ List documents in corpus. 
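For the `Corpus.query` validation touched below (results capped at 100), a usage sketch assuming an existing `corpus`; per the signature above, the call yields `RelevantChunk` objects:

```
# results_count must stay within the 1-100 range enforced below.
for relevant_chunk in corpus.query(query="How do I rotate an API key?", results_count=5):
    print(relevant_chunk)

# corpus.query(query="...", results_count=500)  # raises the ValueError added below
```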
@@ -631,7 +641,7 @@ def list_documents( if client is None: client = get_default_retriever_client() - request = glm.ListDocumentsRequest( + request = protos.ListDocumentsRequest( parent=self.name, page_size=page_size, ) @@ -642,7 +652,7 @@ async def list_documents_async( self, page_size: int | None = None, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> AsyncIterable[Document]: """This is the async version of `Corpus.list_documents`.""" if request_options is None: @@ -651,7 +661,7 @@ async def list_documents_async( if client is None: client = get_default_retriever_async_client() - request = glm.ListDocumentsRequest( + request = protos.ListDocumentsRequest( parent=self.name, page_size=page_size, ) @@ -744,7 +754,7 @@ def create_chunk( name: str | None = None, custom_metadata: Iterable[CustomMetadata] | None = None, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> Chunk: """ Create a `Chunk` object which has textual data. @@ -783,15 +793,17 @@ def create_chunk( chunk_name = name if isinstance(data, str): - chunk = glm.Chunk(name=chunk_name, data={"string_value": data}, custom_metadata=c_data) + chunk = protos.Chunk( + name=chunk_name, data={"string_value": data}, custom_metadata=c_data + ) else: - chunk = glm.Chunk( + chunk = protos.Chunk( name=chunk_name, data={"string_value": data.string_value}, custom_metadata=c_data, ) - request = glm.CreateChunkRequest(parent=self.name, chunk=chunk) + request = protos.CreateChunkRequest(parent=self.name, chunk=chunk) response = client.create_chunk(request, **request_options) return decode_chunk(response) @@ -801,7 +813,7 @@ async def create_chunk_async( name: str | None = None, custom_metadata: Iterable[CustomMetadata] | None = None, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> Chunk: """This is the async version of `Document.create_chunk`.""" if request_options is None: @@ -825,24 +837,26 @@ async def create_chunk_async( chunk_name = name if isinstance(data, str): - chunk = glm.Chunk(name=chunk_name, data={"string_value": data}, custom_metadata=c_data) + chunk = protos.Chunk( + name=chunk_name, data={"string_value": data}, custom_metadata=c_data + ) else: - chunk = glm.Chunk( + chunk = protos.Chunk( name=chunk_name, data={"string_value": data.string_value}, custom_metadata=c_data, ) - request = glm.CreateChunkRequest(parent=self.name, chunk=chunk) + request = protos.CreateChunkRequest(parent=self.name, chunk=chunk) response = await client.create_chunk(request, **request_options) return decode_chunk(response) - def _make_chunk(self, chunk: ChunkOptions) -> glm.Chunk: + def _make_chunk(self, chunk: ChunkOptions) -> protos.Chunk: # del self - if isinstance(chunk, glm.Chunk): - return glm.Chunk(chunk) + if isinstance(chunk, protos.Chunk): + return protos.Chunk(chunk) elif isinstance(chunk, str): - return glm.Chunk(data={"string_value": chunk}) + return protos.Chunk(data={"string_value": chunk}) elif isinstance(chunk, tuple): if len(chunk) == 2: name, data = chunk # pytype: disable=bad-unpacking @@ -855,7 +869,7 @@ def _make_chunk(self, chunk: ChunkOptions) -> glm.Chunk: f"value: {chunk}" ) - return glm.Chunk( + return protos.Chunk( name=name, data={"string_value": data}, 
custom_metadata=custom_metadata, @@ -864,16 +878,16 @@ def _make_chunk(self, chunk: ChunkOptions) -> glm.Chunk: if isinstance(chunk["data"], str): chunk = dict(chunk) chunk["data"] = {"string_value": chunk["data"]} - return glm.Chunk(chunk) + return protos.Chunk(chunk) else: raise TypeError( - f"Could not convert instance of `{type(chunk)}` chunk:" f"value: {chunk}" + f"Invalid input: Could not convert instance of type '{type(chunk).__name__}' to a chunk. Received value: '{chunk}'." ) def _make_batch_create_chunk_request( self, chunks: BatchCreateChunkOptions - ) -> glm.BatchCreateChunksRequest: - if isinstance(chunks, glm.BatchCreateChunksRequest): + ) -> protos.BatchCreateChunksRequest: + if isinstance(chunks, protos.BatchCreateChunksRequest): return chunks if isinstance(chunks, Mapping): @@ -892,15 +906,15 @@ def _make_batch_create_chunk_request( chunk.name = f"{self.name}/chunks/{chunk.name}" - requests.append(glm.CreateChunkRequest(parent=self.name, chunk=chunk)) + requests.append(protos.CreateChunkRequest(parent=self.name, chunk=chunk)) - return glm.BatchCreateChunksRequest(parent=self.name, requests=requests) + return protos.BatchCreateChunksRequest(parent=self.name, requests=requests) def batch_create_chunks( self, chunks: BatchCreateChunkOptions, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """ Create chunks within the given document. @@ -926,7 +940,7 @@ async def batch_create_chunks_async( self, chunks: BatchCreateChunkOptions, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """This is the async version of `Document.batch_create_chunk`.""" if request_options is None: @@ -943,7 +957,7 @@ def get_chunk( self, name: str, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """ Get information about a specific chunk. @@ -964,7 +978,7 @@ def get_chunk( if "/" not in name: name = f"{self.name}/chunks/{name}" - request = glm.GetChunkRequest(name=name) + request = protos.GetChunkRequest(name=name) response = client.get_chunk(request, **request_options) return decode_chunk(response) @@ -972,7 +986,7 @@ async def get_chunk_async( self, name: str, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """This is the async version of `Document.get_chunk`.""" if request_options is None: @@ -984,7 +998,7 @@ async def get_chunk_async( if "/" not in name: name = f"{self.name}/chunks/{name}" - request = glm.GetChunkRequest(name=name) + request = protos.GetChunkRequest(name=name) response = await client.get_chunk(request, **request_options) return decode_chunk(response) @@ -992,7 +1006,7 @@ def list_chunks( self, page_size: int | None = None, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> Iterable[Chunk]: """ List chunks of a document. 
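`_make_chunk` above accepts a bare string, a `(name, data)` tuple, or a `protos.Chunk`, and `batch_create_chunks` takes any iterable of those. A sketch of the mixed forms, assuming `document` is an existing `Document` instance (chunk names are hypothetical):

```
# Each entry is normalized by _make_chunk before a BatchCreateChunksRequest is assembled.
document.batch_create_chunks(
    [
        "A bare string becomes Chunk(data={'string_value': ...}).",
        ("chunk-intro", "A (name, data) tuple also works."),
    ]
)

for chunk in document.list_chunks(page_size=10):
    print(chunk.name)
```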
@@ -1010,7 +1024,7 @@ def list_chunks( if client is None: client = get_default_retriever_client() - request = glm.ListChunksRequest(parent=self.name, page_size=page_size) + request = protos.ListChunksRequest(parent=self.name, page_size=page_size) for chunk in client.list_chunks(request, **request_options): yield decode_chunk(chunk) @@ -1018,7 +1032,7 @@ async def list_chunks_async( self, page_size: int | None = None, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> AsyncIterable[Chunk]: """This is the async version of `Document.list_chunks`.""" if request_options is None: @@ -1027,7 +1041,7 @@ async def list_chunks_async( if client is None: client = get_default_retriever_async_client() - request = glm.ListChunksRequest(parent=self.name, page_size=page_size) + request = protos.ListChunksRequest(parent=self.name, page_size=page_size) async for chunk in await client.list_chunks(request, **request_options): yield decode_chunk(chunk) @@ -1037,7 +1051,7 @@ def query( metadata_filters: Iterable[MetadataFilter] | None = None, results_count: int | None = None, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> list[RelevantChunk]: """ Query a `Document` in the `Corpus` for information. @@ -1058,14 +1072,16 @@ def query( if results_count: if results_count < 0 or results_count >= 100: - raise ValueError("Number of results returned must be between 1 and 100.") + raise ValueError( + "Invalid operation: The number of results returned must be between 1 and 100." + ) m_f_ = [] if metadata_filters: for mf in metadata_filters: m_f_.append(mf._to_proto()) - request = glm.QueryDocumentRequest( + request = protos.QueryDocumentRequest( name=self.name, query=query, metadata_filters=m_f_, @@ -1090,7 +1106,7 @@ async def query_async( metadata_filters: Iterable[MetadataFilter] | None = None, results_count: int | None = None, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ) -> list[RelevantChunk]: """This is the async version of `Document.query`.""" if request_options is None: @@ -1101,14 +1117,16 @@ async def query_async( if results_count: if results_count < 0 or results_count >= 100: - raise ValueError("Number of results returned must be between 1 and 100.") + raise ValueError( + "Invalid operation: The number of results returned must be between 1 and 100." + ) m_f_ = [] if metadata_filters: for mf in metadata_filters: m_f_.append(mf._to_proto()) - request = glm.QueryDocumentRequest( + request = protos.QueryDocumentRequest( name=self.name, query=query, metadata_filters=m_f_, @@ -1137,7 +1155,7 @@ def update( self, updates: dict[str, Any], client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """ Update a list of fields for a specified document. @@ -1159,14 +1177,16 @@ def update( # At this time, only `display_name` can be updated for item in updates: if item != "display_name": - raise ValueError("At this time, only `display_name` can be updated for `Document`.") + raise ValueError( + "Invalid operation: Currently, only the 'display_name' attribute can be updated for a 'Document'." 
+ ) field_mask = field_mask_pb2.FieldMask() for path in updates.keys(): field_mask.paths.append(path) for path, value in updates.items(): self._apply_update(path, value) - request = glm.UpdateDocumentRequest(document=self.to_dict(), update_mask=field_mask) + request = protos.UpdateDocumentRequest(document=self.to_dict(), update_mask=field_mask) client.update_document(request, **request_options) return self @@ -1174,7 +1194,7 @@ async def update_async( self, updates: dict[str, Any], client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """This is the async version of `Document.update`.""" if request_options is None: @@ -1187,14 +1207,16 @@ async def update_async( # At this time, only `display_name` can be updated for item in updates: if item != "display_name": - raise ValueError("At this time, only `display_name` can be updated for `Document`.") + raise ValueError( + "Invalid operation: Currently, only the 'display_name' attribute can be updated for a 'Document'." + ) field_mask = field_mask_pb2.FieldMask() for path in updates.keys(): field_mask.paths.append(path) for path, value in updates.items(): self._apply_update(path, value) - request = glm.UpdateDocumentRequest(document=self.to_dict(), update_mask=field_mask) + request = protos.UpdateDocumentRequest(document=self.to_dict(), update_mask=field_mask) await client.update_document(request, **request_options) return self @@ -1202,7 +1224,7 @@ def batch_update_chunks( self, chunks: BatchUpdateChunksOptions, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """ Update multiple chunks within the same document. @@ -1220,7 +1242,7 @@ def batch_update_chunks( if client is None: client = get_default_retriever_client() - if isinstance(chunks, glm.BatchUpdateChunksRequest): + if isinstance(chunks, protos.BatchUpdateChunksRequest): response = client.batch_update_chunks(chunks) response = type(response).to_dict(response) return response @@ -1245,7 +1267,7 @@ def batch_update_chunks( for item in updates: if item != "data.string_value": raise ValueError( - f"At this time, only `data` can be updated for `Chunk`. Got {item}." + f"Invalid operation: Currently, only the 'data' attribute can be updated for a 'Chunk'. Attempted to update '{item}'." 
) field_mask = field_mask_pb2.FieldMask() for path in updates.keys(): @@ -1253,15 +1275,17 @@ def batch_update_chunks( for path, value in updates.items(): chunk_to_update._apply_update(path, value) _requests.append( - glm.UpdateChunkRequest(chunk=chunk_to_update.to_dict(), update_mask=field_mask) + protos.UpdateChunkRequest( + chunk=chunk_to_update.to_dict(), update_mask=field_mask + ) ) - request = glm.BatchUpdateChunksRequest(parent=self.name, requests=_requests) + request = protos.BatchUpdateChunksRequest(parent=self.name, requests=_requests) response = client.batch_update_chunks(request, **request_options) response = type(response).to_dict(response) return response if isinstance(chunks, Iterable) and not isinstance(chunks, Mapping): for chunk in chunks: - if isinstance(chunk, glm.UpdateChunkRequest): + if isinstance(chunk, protos.UpdateChunkRequest): _requests.append(chunk) elif isinstance(chunk, tuple): # First element is name of chunk, second element contains updates @@ -1287,10 +1311,10 @@ def batch_update_chunks( ) else: raise TypeError( - "The `chunks` parameter must be a list of glm.UpdateChunkRequests," - "dictionaries, or tuples of dictionaries." + "Invalid input: The 'chunks' parameter must be a list of 'protos.UpdateChunkRequests'," + " dictionaries, or tuples of dictionaries." ) - request = glm.BatchUpdateChunksRequest(parent=self.name, requests=_requests) + request = protos.BatchUpdateChunksRequest(parent=self.name, requests=_requests) response = client.batch_update_chunks(request, **request_options) response = type(response).to_dict(response) return response @@ -1299,7 +1323,7 @@ async def batch_update_chunks_async( self, chunks: BatchUpdateChunksOptions, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """This is the async version of `Document.batch_update_chunks`.""" if request_options is None: @@ -1308,7 +1332,7 @@ async def batch_update_chunks_async( if client is None: client = get_default_retriever_async_client() - if isinstance(chunks, glm.BatchUpdateChunksRequest): + if isinstance(chunks, protos.BatchUpdateChunksRequest): response = client.batch_update_chunks(chunks) response = type(response).to_dict(response) return response @@ -1333,7 +1357,7 @@ async def batch_update_chunks_async( for item in updates: if item != "data.string_value": raise ValueError( - f"At this time, only `data` can be updated for `Chunk`. Got {item}." + f"Invalid operation: Currently, only the 'data' attribute can be updated for a 'Chunk'. Attempted to update '{item}'." 
) field_mask = field_mask_pb2.FieldMask() for path in updates.keys(): @@ -1341,15 +1365,17 @@ async def batch_update_chunks_async( for path, value in updates.items(): chunk_to_update._apply_update(path, value) _requests.append( - glm.UpdateChunkRequest(chunk=chunk_to_update.to_dict(), update_mask=field_mask) + protos.UpdateChunkRequest( + chunk=chunk_to_update.to_dict(), update_mask=field_mask + ) ) - request = glm.BatchUpdateChunksRequest(parent=self.name, requests=_requests) + request = protos.BatchUpdateChunksRequest(parent=self.name, requests=_requests) response = await client.batch_update_chunks(request, **request_options) response = type(response).to_dict(response) return response if isinstance(chunks, Iterable) and not isinstance(chunks, Mapping): for chunk in chunks: - if isinstance(chunk, glm.UpdateChunkRequest): + if isinstance(chunk, protos.UpdateChunkRequest): _requests.append(chunk) elif isinstance(chunk, tuple): # First element is name of chunk, second element contains updates @@ -1375,10 +1401,10 @@ async def batch_update_chunks_async( ) else: raise TypeError( - "The `chunks` parameter must be a list of glm.UpdateChunkRequests," + "Invalid input: The 'chunks' parameter must be a list of 'protos.UpdateChunkRequests', " "dictionaries, or tuples of dictionaries." ) - request = glm.BatchUpdateChunksRequest(parent=self.name, requests=_requests) + request = protos.BatchUpdateChunksRequest(parent=self.name, requests=_requests) response = await client.batch_update_chunks(request, **request_options) response = type(response).to_dict(response) return response @@ -1387,7 +1413,7 @@ def delete_chunk( self, name: str, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, # fmt: {} + request_options: helper_types.RequestOptionsType | None = None, # fmt: {} ): """ Delete a `Chunk`. @@ -1405,14 +1431,14 @@ def delete_chunk( if "/" not in name: name = f"{self.name}/chunks/{name}" - request = glm.DeleteChunkRequest(name=name) + request = protos.DeleteChunkRequest(name=name) client.delete_chunk(request, **request_options) async def delete_chunk_async( self, name: str, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, # fmt: {} + request_options: helper_types.RequestOptionsType | None = None, # fmt: {} ): """This is the async version of `Document.delete_chunk`.""" if request_options is None: @@ -1424,14 +1450,14 @@ async def delete_chunk_async( if "/" not in name: name = f"{self.name}/chunks/{name}" - request = glm.DeleteChunkRequest(name=name) + request = protos.DeleteChunkRequest(name=name) await client.delete_chunk(request, **request_options) def batch_delete_chunks( self, chunks: BatchDeleteChunkOptions, client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """ Delete multiple `Chunk`s from a document. 
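The delete paths in the hunks above and below accept either chunk names or full `protos.DeleteChunkRequest`s. A sketch, assuming an existing `document` and hypothetical chunk names:

```
# Single delete: a bare name is expanded to f"{document.name}/chunks/{name}" (see hunk above).
document.delete_chunk("chunk-intro")

# Batch delete: names are wrapped in protos.DeleteChunkRequest unchanged, so
# fully qualified names are the safer form here.
document.batch_delete_chunks(
    [f"{document.name}/chunks/chunk-intro", f"{document.name}/chunks/chunk-outro"]
)
```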
@@ -1446,25 +1472,26 @@ def batch_delete_chunks( if client is None: client = get_default_retriever_client() - if all(isinstance(x, glm.DeleteChunkRequest) for x in chunks): - request = glm.BatchDeleteChunksRequest(parent=self.name, requests=chunks) + if all(isinstance(x, protos.DeleteChunkRequest) for x in chunks): + request = protos.BatchDeleteChunksRequest(parent=self.name, requests=chunks) client.batch_delete_chunks(request, **request_options) elif isinstance(chunks, Iterable): _request_list = [] for chunk_name in chunks: - _request_list.append(glm.DeleteChunkRequest(name=chunk_name)) - request = glm.BatchDeleteChunksRequest(parent=self.name, requests=_request_list) + _request_list.append(protos.DeleteChunkRequest(name=chunk_name)) + request = protos.BatchDeleteChunksRequest(parent=self.name, requests=_request_list) client.batch_delete_chunks(request, **request_options) else: raise ValueError( - "To delete chunks, you must pass in either the names of the chunks as an iterable, or multiple `glm.DeleteChunkRequest`s." + "Invalid operation: To delete chunks, you must pass in either the names of the chunks as an iterable, " + "or multiple 'protos.DeleteChunkRequest's." ) async def batch_delete_chunks_async( self, chunks: BatchDeleteChunkOptions, client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """This is the async version of `Document.batch_delete_chunks`.""" if request_options is None: @@ -1473,18 +1500,19 @@ async def batch_delete_chunks_async( if client is None: client = get_default_retriever_async_client() - if all(isinstance(x, glm.DeleteChunkRequest) for x in chunks): - request = glm.BatchDeleteChunksRequest(parent=self.name, requests=chunks) + if all(isinstance(x, protos.DeleteChunkRequest) for x in chunks): + request = protos.BatchDeleteChunksRequest(parent=self.name, requests=chunks) await client.batch_delete_chunks(request, **request_options) elif isinstance(chunks, Iterable): _request_list = [] for chunk_name in chunks: - _request_list.append(glm.DeleteChunkRequest(name=chunk_name)) - request = glm.BatchDeleteChunksRequest(parent=self.name, requests=_request_list) + _request_list.append(protos.DeleteChunkRequest(name=chunk_name)) + request = protos.BatchDeleteChunksRequest(parent=self.name, requests=_request_list) await client.batch_delete_chunks(request, **request_options) else: raise ValueError( - "To delete chunks, you must pass in either the names of the chunks as an iterable, or multiple `glm.DeleteChunkRequest`s." + "Invalid operation: To delete chunks, you must pass in either the names of the chunks as an iterable, " + "or multiple 'protos.DeleteChunkRequest's." ) def to_dict(self) -> dict[str, Any]: @@ -1496,7 +1524,7 @@ def to_dict(self) -> dict[str, Any]: return result -def decode_chunk(chunk: glm.Chunk) -> Chunk: +def decode_chunk(chunk: protos.Chunk) -> Chunk: chunk = type(chunk).to_dict(chunk) idecode_time(chunk, "create_time") idecode_time(chunk, "update_time") @@ -1570,7 +1598,7 @@ def update( self, updates: dict[str, Any], client: glm.RetrieverServiceClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """ Update a list of fields for a specified `Chunk`. @@ -1602,7 +1630,7 @@ def update( for item in updates: if item != "data.string_value": raise ValueError( - f"At this time, only `data` can be updated for `Chunk`. Got {item}." 
+ f"Invalid operation: Currently, only the 'data' attribute can be updated for a 'Chunk'. Attempted to update '{item}'." ) field_mask = field_mask_pb2.FieldMask() @@ -1610,7 +1638,7 @@ def update( field_mask.paths.append(path) for path, value in updates.items(): self._apply_update(path, value) - request = glm.UpdateChunkRequest(chunk=self.to_dict(), update_mask=field_mask) + request = protos.UpdateChunkRequest(chunk=self.to_dict(), update_mask=field_mask) client.update_chunk(request, **request_options) return self @@ -1619,7 +1647,7 @@ async def update_async( self, updates: dict[str, Any], client: glm.RetrieverServiceAsyncClient | None = None, - request_options: dict[str, Any] | None = None, + request_options: helper_types.RequestOptionsType | None = None, ): """This is the async version of `Chunk.update`.""" if request_options is None: @@ -1642,7 +1670,7 @@ async def update_async( for item in updates: if item != "data.string_value": raise ValueError( - f"At this time, only `data` can be updated for `Chunk`. Got {item}." + f"Invalid operation: Currently, only the 'data' attribute can be updated for a 'Chunk'. Attempted to update '{item}'." ) field_mask = field_mask_pb2.FieldMask() @@ -1650,7 +1678,7 @@ async def update_async( field_mask.paths.append(path) for path, value in updates.items(): self._apply_update(path, value) - request = glm.UpdateChunkRequest(chunk=self.to_dict(), update_mask=field_mask) + request = protos.UpdateChunkRequest(chunk=self.to_dict(), update_mask=field_mask) await client.update_chunk(request, **request_options) return self diff --git a/google/generativeai/types/safety_types.py b/google/generativeai/types/safety_types.py index 7d94a5bb0..74da06e45 100644 --- a/google/generativeai/types/safety_types.py +++ b/google/generativeai/types/safety_types.py @@ -1,14 +1,29 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from __future__ import annotations from collections.abc import Mapping +import enum import typing from typing import Dict, Iterable, List, Union from typing_extensions import TypedDict -from google.ai import generativelanguage as glm +from google.generativeai import protos from google.generativeai import string_utils @@ -24,105 +39,72 @@ ] # These are basic python enums, it's okay to expose them -HarmCategory = glm.HarmCategory -HarmProbability = glm.SafetyRating.HarmProbability -HarmBlockThreshold = glm.SafetySetting.HarmBlockThreshold -BlockedReason = glm.ContentFilter.BlockedReason +HarmProbability = protos.SafetyRating.HarmProbability +HarmBlockThreshold = protos.SafetySetting.HarmBlockThreshold +BlockedReason = protos.ContentFilter.BlockedReason + +import proto + + +class HarmCategory(proto.Enum): + """ + Harm Categories supported by the gemini-family model + """ + + HARM_CATEGORY_UNSPECIFIED = protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED.value + HARM_CATEGORY_HARASSMENT = protos.HarmCategory.HARM_CATEGORY_HARASSMENT.value + HARM_CATEGORY_HATE_SPEECH = protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH.value + HARM_CATEGORY_SEXUALLY_EXPLICIT = protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT.value + HARM_CATEGORY_DANGEROUS_CONTENT = protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT.value + HarmCategoryOptions = Union[str, int, HarmCategory] # fmt: off -_OLD_HARM_CATEGORIES: Dict[HarmCategoryOptions, HarmCategory] = { - HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmCategory.HARM_CATEGORY_UNSPECIFIED, - 0: HarmCategory.HARM_CATEGORY_UNSPECIFIED, - "harm_category_unspecified": HarmCategory.HARM_CATEGORY_UNSPECIFIED, - "unspecified": HarmCategory.HARM_CATEGORY_UNSPECIFIED, - - HarmCategory.HARM_CATEGORY_DEROGATORY: HarmCategory.HARM_CATEGORY_DEROGATORY, - 1: HarmCategory.HARM_CATEGORY_DEROGATORY, - "harm_category_derogatory": HarmCategory.HARM_CATEGORY_DEROGATORY, - "derogatory": HarmCategory.HARM_CATEGORY_DEROGATORY, - - HarmCategory.HARM_CATEGORY_TOXICITY: HarmCategory.HARM_CATEGORY_TOXICITY, - 2: HarmCategory.HARM_CATEGORY_TOXICITY, - "harm_category_toxicity": HarmCategory.HARM_CATEGORY_TOXICITY, - "toxicity": HarmCategory.HARM_CATEGORY_TOXICITY, - "toxic": HarmCategory.HARM_CATEGORY_TOXICITY, - - HarmCategory.HARM_CATEGORY_VIOLENCE: HarmCategory.HARM_CATEGORY_VIOLENCE, - 3: HarmCategory.HARM_CATEGORY_VIOLENCE, - "harm_category_violence": HarmCategory.HARM_CATEGORY_VIOLENCE, - "violence": HarmCategory.HARM_CATEGORY_VIOLENCE, - "violent": HarmCategory.HARM_CATEGORY_VIOLENCE, - - HarmCategory.HARM_CATEGORY_SEXUAL: HarmCategory.HARM_CATEGORY_SEXUAL, - 4: HarmCategory.HARM_CATEGORY_SEXUAL, - "harm_category_sexual": HarmCategory.HARM_CATEGORY_SEXUAL, - "sexual": HarmCategory.HARM_CATEGORY_SEXUAL, - "sex": HarmCategory.HARM_CATEGORY_SEXUAL, - - HarmCategory.HARM_CATEGORY_MEDICAL: HarmCategory.HARM_CATEGORY_MEDICAL, - 5: HarmCategory.HARM_CATEGORY_MEDICAL, - "harm_category_medical": HarmCategory.HARM_CATEGORY_MEDICAL, - "medical": HarmCategory.HARM_CATEGORY_MEDICAL, - "med": HarmCategory.HARM_CATEGORY_MEDICAL, - - HarmCategory.HARM_CATEGORY_DANGEROUS: HarmCategory.HARM_CATEGORY_DANGEROUS, - 6: HarmCategory.HARM_CATEGORY_DANGEROUS, - "harm_category_dangerous": HarmCategory.HARM_CATEGORY_DANGEROUS, - "dangerous": HarmCategory.HARM_CATEGORY_DANGEROUS, - "danger": HarmCategory.HARM_CATEGORY_DANGEROUS, -} - -_NEW_HARM_CATEGORIES = { - 7: HarmCategory.HARM_CATEGORY_HARASSMENT, - HarmCategory.HARM_CATEGORY_HARASSMENT: HarmCategory.HARM_CATEGORY_HARASSMENT, - "harm_category_harassment": 
HarmCategory.HARM_CATEGORY_HARASSMENT, - "harassment": HarmCategory.HARM_CATEGORY_HARASSMENT, - - 8: HarmCategory.HARM_CATEGORY_HATE_SPEECH, - HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmCategory.HARM_CATEGORY_HATE_SPEECH, - 'harm_category_hate_speech': HarmCategory.HARM_CATEGORY_HATE_SPEECH, - 'hate_speech': HarmCategory.HARM_CATEGORY_HATE_SPEECH, - 'hate': HarmCategory.HARM_CATEGORY_HATE_SPEECH, - - 9: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - "harm_category_sexually_explicit": HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - "harm_category_sexual": HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - "sexually_explicit": HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - "sexual": HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - "sex": HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - - 10: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - "harm_category_dangerous_content": HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - "harm_category_dangerous": HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - "dangerous": HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - "danger": HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, +_HARM_CATEGORIES: Dict[HarmCategoryOptions, protos.HarmCategory] = { + protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED: protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + HarmCategory.HARM_CATEGORY_UNSPECIFIED: protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + 0: protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + "harm_category_unspecified": protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + "unspecified": protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + + 7: protos.HarmCategory.HARM_CATEGORY_HARASSMENT, + protos.HarmCategory.HARM_CATEGORY_HARASSMENT: protos.HarmCategory.HARM_CATEGORY_HARASSMENT, + HarmCategory.HARM_CATEGORY_HARASSMENT: protos.HarmCategory.HARM_CATEGORY_HARASSMENT, + "harm_category_harassment": protos.HarmCategory.HARM_CATEGORY_HARASSMENT, + "harassment": protos.HarmCategory.HARM_CATEGORY_HARASSMENT, + + 8: protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH: protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + HarmCategory.HARM_CATEGORY_HATE_SPEECH: protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + 'harm_category_hate_speech': protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + 'hate_speech': protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + 'hate': protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + + 9: protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + "harm_category_sexually_explicit": protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + "harm_category_sexual": protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + "sexually_explicit": protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + "sexual": protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + "sex": protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + + 10: protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + "harm_category_dangerous_content": 
protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + "harm_category_dangerous": protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + "dangerous": protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + "danger": protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, } # fmt: on -def to_old_harm_category(x: HarmCategoryOptions) -> HarmCategory: +def to_harm_category(x: HarmCategoryOptions) -> protos.HarmCategory: if isinstance(x, str): x = x.lower() - return _OLD_HARM_CATEGORIES[x] - - -def to_new_harm_category(x: HarmCategoryOptions) -> HarmCategory: - if isinstance(x, str): - x = x.lower() - return _NEW_HARM_CATEGORIES[x] - - -def to_harm_category(x, harm_category_set): - if harm_category_set == "old": - return to_old_harm_category(x) - elif harm_category_set == "new": - return to_new_harm_category(x) - else: - raise ValueError("harm_category_set must be 'new' or 'old'") + return _HARM_CATEGORIES[x] HarmBlockThresholdOptions = Union[str, int, HarmBlockThreshold] @@ -158,7 +140,7 @@ def to_harm_category(x, harm_category_set): # fmt: on -def to_block_threshold(x: HarmBlockThresholdOptions) -> HarmCategory: +def to_block_threshold(x: HarmBlockThresholdOptions) -> HarmBlockThreshold: if isinstance(x, str): x = x.lower() return _BLOCK_THRESHOLDS[x] @@ -168,7 +150,7 @@ class ContentFilterDict(TypedDict): reason: BlockedReason message: str - __doc__ = string_utils.strip_oneof(glm.ContentFilter.__doc__) + __doc__ = string_utils.strip_oneof(protos.ContentFilter.__doc__) def convert_filters_to_enums( @@ -184,15 +166,15 @@ def convert_filters_to_enums( class SafetyRatingDict(TypedDict): - category: HarmCategory + category: protos.HarmCategory probability: HarmProbability - __doc__ = string_utils.strip_oneof(glm.SafetyRating.__doc__) + __doc__ = string_utils.strip_oneof(protos.SafetyRating.__doc__) def convert_rating_to_enum(rating: dict) -> SafetyRatingDict: return { - "category": HarmCategory(rating["category"]), + "category": protos.HarmCategory(rating["category"]), "probability": HarmProbability(rating["probability"]), } @@ -205,10 +187,10 @@ def convert_ratings_to_enum(ratings: Iterable[dict]) -> List[SafetyRatingDict]: class SafetySettingDict(TypedDict): - category: HarmCategory + category: protos.HarmCategory threshold: HarmBlockThreshold - __doc__ = string_utils.strip_oneof(glm.SafetySetting.__doc__) + __doc__ = string_utils.strip_oneof(protos.SafetySetting.__doc__) class LooseSafetySettingDict(TypedDict): @@ -219,34 +201,56 @@ class LooseSafetySettingDict(TypedDict): EasySafetySetting = Mapping[HarmCategoryOptions, HarmBlockThresholdOptions] EasySafetySettingDict = dict[HarmCategoryOptions, HarmBlockThresholdOptions] -SafetySettingOptions = Union[EasySafetySetting, Iterable[LooseSafetySettingDict], None] +SafetySettingOptions = Union[ + HarmBlockThresholdOptions, EasySafetySetting, Iterable[LooseSafetySettingDict], None +] + +def _expand_block_threshold(block_threshold: HarmBlockThresholdOptions): + block_threshold = to_block_threshold(block_threshold) + set(_HARM_CATEGORIES.values()) + return {category: block_threshold for category in set(_HARM_CATEGORIES.values())} -def to_easy_safety_dict(settings: SafetySettingOptions, harm_category_set) -> EasySafetySettingDict: + +def to_easy_safety_dict(settings: SafetySettingOptions) -> EasySafetySettingDict: if settings is None: return {} - elif isinstance(settings, Mapping): - return { - to_harm_category(key, harm_category_set): to_block_threshold(value) - for key, value in settings.items() - } + + if isinstance(settings, (int, str, 
HarmBlockThreshold)): + settings = _expand_block_threshold(settings) + + if isinstance(settings, Mapping): + return {to_harm_category(key): to_block_threshold(value) for key, value in settings.items()} + else: # Iterable - return { - to_harm_category(d["category"], harm_category_set): to_block_threshold(d["threshold"]) - for d in settings - } + result = {} + for setting in settings: + if isinstance(setting, protos.SafetySetting): + result[to_harm_category(setting.category)] = to_block_threshold(setting.threshold) + elif isinstance(setting, dict): + result[to_harm_category(setting["category"])] = to_block_threshold( + setting["threshold"] + ) + else: + raise ValueError( + f"Could not understand safety setting:\n {type(setting)=}\n {setting=}" + ) + return result def normalize_safety_settings( settings: SafetySettingOptions, - harm_category_set, ) -> list[SafetySettingDict] | None: if settings is None: return None + + if isinstance(settings, (int, str, HarmBlockThreshold)): + settings = _expand_block_threshold(settings) + if isinstance(settings, Mapping): return [ { - "category": to_harm_category(key, harm_category_set), + "category": to_harm_category(key), "threshold": to_block_threshold(value), } for key, value in settings.items() @@ -254,7 +258,7 @@ def normalize_safety_settings( else: return [ { - "category": to_harm_category(d["category"], harm_category_set), + "category": to_harm_category(d["category"]), "threshold": to_block_threshold(d["threshold"]), } for d in settings @@ -263,7 +267,7 @@ def normalize_safety_settings( def convert_setting_to_enum(setting: dict) -> SafetySettingDict: return { - "category": HarmCategory(setting["category"]), + "category": protos.HarmCategory(setting["category"]), "threshold": HarmBlockThreshold(setting["threshold"]), } @@ -272,7 +276,7 @@ class SafetyFeedbackDict(TypedDict): rating: SafetyRatingDict setting: SafetySettingDict - __doc__ = string_utils.strip_oneof(glm.SafetyFeedback.__doc__) + __doc__ = string_utils.strip_oneof(protos.SafetyFeedback.__doc__) def convert_safety_feedback_to_enums( diff --git a/google/generativeai/types/text_types.py b/google/generativeai/types/text_types.py index f66c0fb32..61804fcaa 100644 --- a/google/generativeai/types/text_types.py +++ b/google/generativeai/types/text_types.py @@ -21,7 +21,7 @@ from typing_extensions import TypedDict from google.generativeai import string_utils -from google.generativeai.types import safety_types +from google.generativeai.types import palm_safety_types from google.generativeai.types import citation_types @@ -42,7 +42,7 @@ class BatchEmbeddingDict(TypedDict): class TextCompletion(TypedDict, total=False): output: str - safety_ratings: List[safety_types.SafetyRatingDict | None] + safety_ratings: List[palm_safety_types.SafetyRatingDict | None] citation_metadata: citation_types.CitationMetadataDict | None @@ -63,8 +63,8 @@ class Completion(abc.ABC): candidates: List[TextCompletion] result: str | None - filters: List[safety_types.ContentFilterDict | None] - safety_feedback: List[safety_types.SafetyFeedbackDict | None] + filters: List[palm_safety_types.ContentFilterDict | None] + safety_feedback: List[palm_safety_types.SafetyFeedbackDict | None] def to_dict(self) -> Dict[str, Any]: result = { diff --git a/google/generativeai/utils.py b/google/generativeai/utils.py index 6dc2b6a20..cd2c4cbf7 100644 --- a/google/generativeai/utils.py +++ b/google/generativeai/utils.py @@ -16,6 +16,8 @@ def flatten_update_paths(updates): + """Flattens a nested dictionary into a single level dictionary, with keys 
representing the original path.""" + new_updates = {} for key, value in updates.items(): if isinstance(value, dict): diff --git a/google/generativeai/version.py b/google/generativeai/version.py index e1ce17d66..8018b67ac 100644 --- a/google/generativeai/version.py +++ b/google/generativeai/version.py @@ -14,4 +14,4 @@ # limitations under the License. from __future__ import annotations -__version__ = "0.5.4" +__version__ = "0.6.0" diff --git a/tests/notebook/text_model_test.py b/tests/notebook/text_model_test.py index 9239ac9c3..428d44b26 100644 --- a/tests/notebook/text_model_test.py +++ b/tests/notebook/text_model_test.py @@ -68,21 +68,47 @@ def _generate_text( class TextModelTestCase(absltest.TestCase): - def test_generate_text(self): + def test_generate_text_without_args(self): model = TestModel() result = model.call_model("prompt goes in") self.assertEqual(result.text_results[0], "prompt goes in_1") - self.assertIsNone(result.text_results[1]) - self.assertIsNone(result.text_results[2]) - self.assertIsNone(result.text_results[3]) + def test_generate_text_without_args_none_results(self): + model = TestModel() + + result = model.call_model("prompt goes in") + self.assertEqual(result.text_results[1], "None") + self.assertEqual(result.text_results[2], "None") + self.assertEqual(result.text_results[3], "None") + + def test_generate_text_with_args_first_result(self): + model = TestModel() args = model_lib.ModelArguments(model="model_name", temperature=0.42, candidate_count=5) + result = model.call_model("prompt goes in", args) self.assertEqual(result.text_results[0], "prompt goes in_1") + + def test_generate_text_with_args_model_name(self): + model = TestModel() + args = model_lib.ModelArguments(model="model_name", temperature=0.42, candidate_count=5) + + result = model.call_model("prompt goes in", args) self.assertEqual(result.text_results[1], "model_name") - self.assertEqual(result.text_results[2], 0.42) - self.assertEqual(result.text_results[3], 5) + + def test_generate_text_with_args_temperature(self): + model = TestModel() + args = model_lib.ModelArguments(model="model_name", temperature=0.42, candidate_count=5) + result = model.call_model("prompt goes in", args) + + self.assertEqual(result.text_results[2], str(0.42)) + + def test_generate_text_with_args_candidate_count(self): + model = TestModel() + args = model_lib.ModelArguments(model="model_name", temperature=0.42, candidate_count=5) + + result = model.call_model("prompt goes in", args) + self.assertEqual(result.text_results[3], str(5)) def test_retry(self): model = TestModel() diff --git a/tests/test_answer.py b/tests/test_answer.py index 6fa12603c..2669b207c 100644 --- a/tests/test_answer.py +++ b/tests/test_answer.py @@ -18,9 +18,10 @@ import unittest import unittest.mock as mock -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import answer +from google.generativeai import types as genai_types from google.generativeai import client from absl.testing import absltest from absl.testing import parameterized @@ -46,14 +47,14 @@ def add_client_method(f): @add_client_method def generate_answer( - request: glm.GenerateAnswerRequest, + request: protos.GenerateAnswerRequest, **kwargs, - ) -> glm.GenerateAnswerResponse: + ) -> protos.GenerateAnswerResponse: self.observed_requests.append(request) - return glm.GenerateAnswerResponse( - answer=glm.Candidate( + return protos.GenerateAnswerResponse( + answer=protos.Candidate( index=1, - content=(glm.Content(parts=[glm.Part(text="Demo 
answer.")])), + content=(protos.Content(parts=[protos.Part(text="Demo answer.")])), ), answerable_probability=0.500, ) @@ -61,17 +62,23 @@ def generate_answer( def test_make_grounding_passages_mixed_types(self): inline_passages = [ "I am a chicken", - glm.Content(parts=[glm.Part(text="I am a bird.")]), - glm.Content(parts=[glm.Part(text="I can fly!")]), + protos.Content(parts=[protos.Part(text="I am a bird.")]), + protos.Content(parts=[protos.Part(text="I can fly!")]), ] x = answer._make_grounding_passages(inline_passages) - self.assertIsInstance(x, glm.GroundingPassages) + self.assertIsInstance(x, protos.GroundingPassages) self.assertEqual( - glm.GroundingPassages( + protos.GroundingPassages( passages=[ - {"id": "0", "content": glm.Content(parts=[glm.Part(text="I am a chicken")])}, - {"id": "1", "content": glm.Content(parts=[glm.Part(text="I am a bird.")])}, - {"id": "2", "content": glm.Content(parts=[glm.Part(text="I can fly!")])}, + { + "id": "0", + "content": protos.Content(parts=[protos.Part(text="I am a chicken")]), + }, + { + "id": "1", + "content": protos.Content(parts=[protos.Part(text="I am a bird.")]), + }, + {"id": "2", "content": protos.Content(parts=[protos.Part(text="I can fly!")])}, ] ), x, @@ -81,23 +88,29 @@ def test_make_grounding_passages_mixed_types(self): [ dict( testcase_name="grounding_passage", - inline_passages=glm.GroundingPassages( + inline_passages=protos.GroundingPassages( passages=[ { "id": "0", - "content": glm.Content(parts=[glm.Part(text="I am a chicken")]), + "content": protos.Content(parts=[protos.Part(text="I am a chicken")]), + }, + { + "id": "1", + "content": protos.Content(parts=[protos.Part(text="I am a bird.")]), + }, + { + "id": "2", + "content": protos.Content(parts=[protos.Part(text="I can fly!")]), }, - {"id": "1", "content": glm.Content(parts=[glm.Part(text="I am a bird.")])}, - {"id": "2", "content": glm.Content(parts=[glm.Part(text="I can fly!")])}, ] ), ), dict( testcase_name="content_object", inline_passages=[ - glm.Content(parts=[glm.Part(text="I am a chicken")]), - glm.Content(parts=[glm.Part(text="I am a bird.")]), - glm.Content(parts=[glm.Part(text="I can fly!")]), + protos.Content(parts=[protos.Part(text="I am a chicken")]), + protos.Content(parts=[protos.Part(text="I am a bird.")]), + protos.Content(parts=[protos.Part(text="I can fly!")]), ], ), dict( @@ -108,13 +121,19 @@ def test_make_grounding_passages_mixed_types(self): ) def test_make_grounding_passages(self, inline_passages): x = answer._make_grounding_passages(inline_passages) - self.assertIsInstance(x, glm.GroundingPassages) + self.assertIsInstance(x, protos.GroundingPassages) self.assertEqual( - glm.GroundingPassages( + protos.GroundingPassages( passages=[ - {"id": "0", "content": glm.Content(parts=[glm.Part(text="I am a chicken")])}, - {"id": "1", "content": glm.Content(parts=[glm.Part(text="I am a bird.")])}, - {"id": "2", "content": glm.Content(parts=[glm.Part(text="I can fly!")])}, + { + "id": "0", + "content": protos.Content(parts=[protos.Part(text="I am a chicken")]), + }, + { + "id": "1", + "content": protos.Content(parts=[protos.Part(text="I am a bird.")]), + }, + {"id": "2", "content": protos.Content(parts=[protos.Part(text="I can fly!")])}, ] ), x, @@ -132,27 +151,33 @@ def test_make_grounding_passages(self, inline_passages): dict( testcase_name="list_of_grounding_passages", inline_passages=[ - glm.GroundingPassage( - id="4", content=glm.Content(parts=[glm.Part(text="I am a chicken")]) + protos.GroundingPassage( + id="4", 
content=protos.Content(parts=[protos.Part(text="I am a chicken")]) ), - glm.GroundingPassage( - id="5", content=glm.Content(parts=[glm.Part(text="I am a bird.")]) + protos.GroundingPassage( + id="5", content=protos.Content(parts=[protos.Part(text="I am a bird.")]) ), - glm.GroundingPassage( - id="6", content=glm.Content(parts=[glm.Part(text="I can fly!")]) + protos.GroundingPassage( + id="6", content=protos.Content(parts=[protos.Part(text="I can fly!")]) ), ], ), ) def test_make_grounding_passages_different_id(self, inline_passages): x = answer._make_grounding_passages(inline_passages) - self.assertIsInstance(x, glm.GroundingPassages) + self.assertIsInstance(x, protos.GroundingPassages) self.assertEqual( - glm.GroundingPassages( + protos.GroundingPassages( passages=[ - {"id": "4", "content": glm.Content(parts=[glm.Part(text="I am a chicken")])}, - {"id": "5", "content": glm.Content(parts=[glm.Part(text="I am a bird.")])}, - {"id": "6", "content": glm.Content(parts=[glm.Part(text="I can fly!")])}, + { + "id": "4", + "content": protos.Content(parts=[protos.Part(text="I am a chicken")]), + }, + { + "id": "5", + "content": protos.Content(parts=[protos.Part(text="I am a bird.")]), + }, + {"id": "6", "content": protos.Content(parts=[protos.Part(text="I can fly!")])}, ] ), x, @@ -166,16 +191,22 @@ def test_make_grounding_passages_key_strings(self): } x = answer._make_grounding_passages(inline_passages) - self.assertIsInstance(x, glm.GroundingPassages) + self.assertIsInstance(x, protos.GroundingPassages) self.assertEqual( - glm.GroundingPassages( + protos.GroundingPassages( passages=[ { "id": "first", - "content": glm.Content(parts=[glm.Part(text="I am a chicken")]), + "content": protos.Content(parts=[protos.Part(text="I am a chicken")]), + }, + { + "id": "second", + "content": protos.Content(parts=[protos.Part(text="I am a bird.")]), + }, + { + "id": "third", + "content": protos.Content(parts=[protos.Part(text="I can fly!")]), }, - {"id": "second", "content": glm.Content(parts=[glm.Part(text="I am a bird.")])}, - {"id": "third", "content": glm.Content(parts=[glm.Part(text="I can fly!")])}, ] ), x, @@ -183,14 +214,14 @@ def test_make_grounding_passages_key_strings(self): def test_generate_answer_request(self): # Should be a list of contents to use to_contents() function. - contents = [glm.Content(parts=[glm.Part(text="I have wings.")])] + contents = [protos.Content(parts=[protos.Part(text="I have wings.")])] inline_passages = ["I am a chicken", "I am a bird.", "I can fly!"] - grounding_passages = glm.GroundingPassages( + grounding_passages = protos.GroundingPassages( passages=[ - {"id": "0", "content": glm.Content(parts=[glm.Part(text="I am a chicken")])}, - {"id": "1", "content": glm.Content(parts=[glm.Part(text="I am a bird.")])}, - {"id": "2", "content": glm.Content(parts=[glm.Part(text="I can fly!")])}, + {"id": "0", "content": protos.Content(parts=[protos.Part(text="I am a chicken")])}, + {"id": "1", "content": protos.Content(parts=[protos.Part(text="I am a bird.")])}, + {"id": "2", "content": protos.Content(parts=[protos.Part(text="I can fly!")])}, ] ) @@ -199,7 +230,7 @@ def test_generate_answer_request(self): ) self.assertEqual( - glm.GenerateAnswerRequest( + protos.GenerateAnswerRequest( model=DEFAULT_ANSWER_MODEL, contents=contents, inline_passages=grounding_passages ), x, @@ -207,13 +238,13 @@ def test_generate_answer_request(self): def test_generate_answer(self): # Test handling return value of generate_answer(). 
- contents = [glm.Content(parts=[glm.Part(text="I have wings.")])] + contents = [protos.Content(parts=[protos.Part(text="I have wings.")])] - grounding_passages = glm.GroundingPassages( + grounding_passages = protos.GroundingPassages( passages=[ - {"id": "0", "content": glm.Content(parts=[glm.Part(text="I am a chicken")])}, - {"id": "1", "content": glm.Content(parts=[glm.Part(text="I am a bird.")])}, - {"id": "2", "content": glm.Content(parts=[glm.Part(text="I can fly!")])}, + {"id": "0", "content": protos.Content(parts=[protos.Part(text="I am a chicken")])}, + {"id": "1", "content": protos.Content(parts=[protos.Part(text="I am a bird.")])}, + {"id": "2", "content": protos.Content(parts=[protos.Part(text="I can fly!")])}, ] ) @@ -224,13 +255,13 @@ def test_generate_answer(self): answer_style="ABSTRACTIVE", ) - self.assertIsInstance(a, glm.GenerateAnswerResponse) + self.assertIsInstance(a, protos.GenerateAnswerResponse) self.assertEqual( a, - glm.GenerateAnswerResponse( - answer=glm.Candidate( + protos.GenerateAnswerResponse( + answer=protos.Candidate( index=1, - content=(glm.Content(parts=[glm.Part(text="Demo answer.")])), + content=(protos.Content(parts=[protos.Part(text="Demo answer.")])), ), answerable_probability=0.500, ), @@ -239,7 +270,7 @@ def test_generate_answer(self): def test_generate_answer_called_with_request_options(self): self.client.generate_answer = mock.MagicMock() request = mock.ANY - request_options = {"timeout": 120} + request_options = genai_types.RequestOptions(timeout=120) answer.generate_answer(contents=[], inline_passages=[], request_options=request_options) diff --git a/tests/test_client.py b/tests/test_client.py index 34a0f9fc3..0cc3e05eb 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -4,8 +4,10 @@ from absl.testing import absltest from absl.testing import parameterized -from google.api_core import client_options import google.ai.generativelanguage as glm + +from google.api_core import client_options +from google.generativeai import protos from google.generativeai import client @@ -42,7 +44,7 @@ def test_api_key_from_environment(self): def test_api_key_cannot_be_set_twice(self): client_opts = client_options.ClientOptions(api_key="AIzA_client_opts") - with self.assertRaisesRegex(ValueError, "You can't set both"): + with self.assertRaisesRegex(ValueError, "Invalid configuration: Please set either"): client.configure(api_key="AIzA_client", client_options=client_opts) def test_api_key_and_client_options(self): diff --git a/tests/test_content.py b/tests/test_content.py index 5f22b93a1..6df5faad4 100644 --- a/tests/test_content.py +++ b/tests/test_content.py @@ -19,11 +19,12 @@ from absl.testing import absltest from absl.testing import parameterized -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai.types import content_types import IPython.display import PIL.Image +import numpy as np HERE = pathlib.Path(__file__).parent TEST_PNG_PATH = HERE / "test_img.png" @@ -67,43 +68,45 @@ class ADataClassWithList: class UnitTests(parameterized.TestCase): @parameterized.named_parameters( ["PIL", PIL.Image.open(TEST_PNG_PATH)], + ["RGBA", PIL.Image.fromarray(np.zeros([6, 6, 4], dtype=np.uint8))], ["IPython", IPython.display.Image(filename=TEST_PNG_PATH)], ) def test_png_to_blob(self, image): blob = content_types.image_to_blob(image) - self.assertIsInstance(blob, glm.Blob) + self.assertIsInstance(blob, protos.Blob) self.assertEqual(blob.mime_type, "image/png") self.assertStartsWith(blob.data, b"\x89PNG") 
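(Editor's note, not part of the patch.) The new "RGBA" and "RGB" parameterized cases added to `test_png_to_blob` and `test_jpg_to_blob` pin down how `content_types.image_to_blob` picks a MIME type for raw PIL images. A minimal sketch of what those two test cases assert, under the assumption that images with an alpha channel are encoded as PNG and plain RGB images as JPEG (exactly what the tests check):

```python
# Sketch only: mirrors the assertions in the new RGBA/RGB test cases above/below.
import numpy as np
import PIL.Image
from google.generativeai.types import content_types

rgba = PIL.Image.fromarray(np.zeros([6, 6, 4], dtype=np.uint8))  # has alpha channel
rgb = PIL.Image.fromarray(np.zeros([6, 6, 3], dtype=np.uint8))   # no alpha channel

assert content_types.image_to_blob(rgba).mime_type == "image/png"
assert content_types.image_to_blob(rgb).mime_type == "image/jpeg"
```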
@parameterized.named_parameters( ["PIL", PIL.Image.open(TEST_JPG_PATH)], + ["RGB", PIL.Image.fromarray(np.zeros([6, 6, 3], dtype=np.uint8))], ["IPython", IPython.display.Image(filename=TEST_JPG_PATH)], ) def test_jpg_to_blob(self, image): blob = content_types.image_to_blob(image) - self.assertIsInstance(blob, glm.Blob) + self.assertIsInstance(blob, protos.Blob) self.assertEqual(blob.mime_type, "image/jpeg") self.assertStartsWith(blob.data, b"\xff\xd8\xff\xe0\x00\x10JFIF") @parameterized.named_parameters( ["BlobDict", {"mime_type": "image/png", "data": TEST_PNG_DATA}], - ["glm.Blob", glm.Blob(mime_type="image/png", data=TEST_PNG_DATA)], + ["protos.Blob", protos.Blob(mime_type="image/png", data=TEST_PNG_DATA)], ["Image", IPython.display.Image(filename=TEST_PNG_PATH)], ) def test_to_blob(self, example): blob = content_types.to_blob(example) - self.assertIsInstance(blob, glm.Blob) + self.assertIsInstance(blob, protos.Blob) self.assertEqual(blob.mime_type, "image/png") self.assertStartsWith(blob.data, b"\x89PNG") @parameterized.named_parameters( ["dict", {"text": "Hello world!"}], - ["glm.Part", glm.Part(text="Hello world!")], + ["protos.Part", protos.Part(text="Hello world!")], ["str", "Hello world!"], ) def test_to_part(self, example): part = content_types.to_part(example) - self.assertIsInstance(part, glm.Part) + self.assertIsInstance(part, protos.Part) self.assertEqual(part.text, "Hello world!") @parameterized.named_parameters( @@ -116,12 +119,12 @@ def test_to_part(self, example): ) def test_img_to_part(self, example): blob = content_types.to_part(example).inline_data - self.assertIsInstance(blob, glm.Blob) + self.assertIsInstance(blob, protos.Blob) self.assertEqual(blob.mime_type, "image/png") self.assertStartsWith(blob.data, b"\x89PNG") @parameterized.named_parameters( - ["glm.Content", glm.Content(parts=[{"text": "Hello world!"}])], + ["protos.Content", protos.Content(parts=[{"text": "Hello world!"}])], ["ContentDict", {"parts": [{"text": "Hello world!"}]}], ["ContentDict-str", {"parts": ["Hello world!"]}], ["list[parts]", [{"text": "Hello world!"}]], @@ -135,7 +138,7 @@ def test_to_content(self, example): part = content.parts[0] self.assertLen(content.parts, 1) - self.assertIsInstance(part, glm.Part) + self.assertIsInstance(part, protos.Part) self.assertEqual(part.text, "Hello world!") @parameterized.named_parameters( @@ -147,12 +150,12 @@ def test_img_to_content(self, example): content = content_types.to_content(example) blob = content.parts[0].inline_data self.assertLen(content.parts, 1) - self.assertIsInstance(blob, glm.Blob) + self.assertIsInstance(blob, protos.Blob) self.assertEqual(blob.mime_type, "image/png") self.assertStartsWith(blob.data, b"\x89PNG") @parameterized.named_parameters( - ["glm.Content", glm.Content(parts=[{"text": "Hello world!"}])], + ["protos.Content", protos.Content(parts=[{"text": "Hello world!"}])], ["ContentDict", {"parts": [{"text": "Hello world!"}]}], ["ContentDict-str", {"parts": ["Hello world!"]}], ) @@ -161,7 +164,7 @@ def test_strict_to_content(self, example): part = content.parts[0] self.assertLen(content.parts, 1) - self.assertIsInstance(part, glm.Part) + self.assertIsInstance(part, protos.Part) self.assertEqual(part.text, "Hello world!") @parameterized.named_parameters( @@ -176,7 +179,7 @@ def test_strict_to_contents_fails(self, examples): content_types.strict_to_content(examples) @parameterized.named_parameters( - ["glm.Content", [glm.Content(parts=[{"text": "Hello world!"}])]], + ["protos.Content", [protos.Content(parts=[{"text": "Hello 
world!"}])]], ["ContentDict", [{"parts": [{"text": "Hello world!"}]}]], ["ContentDict-unwraped", [{"parts": ["Hello world!"]}]], ["ContentDict+str-part", [{"parts": "Hello world!"}]], @@ -188,7 +191,7 @@ def test_to_contents(self, example): self.assertLen(contents, 1) self.assertLen(contents[0].parts, 1) - self.assertIsInstance(part, glm.Part) + self.assertIsInstance(part, protos.Part) self.assertEqual(part.text, "Hello world!") def test_dict_to_content_fails(self): @@ -209,7 +212,7 @@ def test_img_to_contents(self, example): self.assertLen(contents, 1) self.assertLen(contents[0].parts, 1) - self.assertIsInstance(blob, glm.Blob) + self.assertIsInstance(blob, protos.Blob) self.assertEqual(blob.mime_type, "image/png") self.assertStartsWith(blob.data, b"\x89PNG") @@ -217,9 +220,9 @@ def test_img_to_contents(self, example): [ "FunctionLibrary", content_types.FunctionLibrary( - tools=glm.Tool( + tools=protos.Tool( function_declarations=[ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time." ) ] @@ -231,7 +234,7 @@ def test_img_to_contents(self, example): [ content_types.Tool( function_declarations=[ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time." ) ] @@ -239,11 +242,11 @@ def test_img_to_contents(self, example): ], ], [ - "IterableTool-glm.Tool", + "IterableTool-protos.Tool", [ - glm.Tool( + protos.Tool( function_declarations=[ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time.", ) @@ -268,7 +271,7 @@ def test_img_to_contents(self, example): "IterableTool-IterableFD", [ [ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time.", ) @@ -278,7 +281,7 @@ def test_img_to_contents(self, example): [ "IterableTool-FD", [ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time.", ) @@ -288,17 +291,17 @@ def test_img_to_contents(self, example): "Tool", content_types.Tool( function_declarations=[ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time." ) ] ), ], [ - "glm.Tool", - glm.Tool( + "protos.Tool", + protos.Tool( function_declarations=[ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time." ) ] @@ -350,8 +353,8 @@ def test_img_to_contents(self, example): ), ], [ - "glm.FD", - glm.FunctionDeclaration( + "protos.FD", + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time." 
), ], @@ -391,83 +394,83 @@ def b(): self.assertLen(tools[0].function_declarations, 2) @parameterized.named_parameters( - ["int", int, glm.Schema(type=glm.Type.INTEGER)], - ["float", float, glm.Schema(type=glm.Type.NUMBER)], - ["str", str, glm.Schema(type=glm.Type.STRING)], - ["nullable_str", Union[str, None], glm.Schema(type=glm.Type.STRING, nullable=True)], + ["int", int, protos.Schema(type=protos.Type.INTEGER)], + ["float", float, protos.Schema(type=protos.Type.NUMBER)], + ["str", str, protos.Schema(type=protos.Type.STRING)], + ["nullable_str", Union[str, None], protos.Schema(type=protos.Type.STRING, nullable=True)], [ "list", list[str], - glm.Schema( - type=glm.Type.ARRAY, - items=glm.Schema(type=glm.Type.STRING), + protos.Schema( + type=protos.Type.ARRAY, + items=protos.Schema(type=protos.Type.STRING), ), ], [ "list-list-int", list[list[int]], - glm.Schema( - type=glm.Type.ARRAY, - items=glm.Schema( - glm.Schema( - type=glm.Type.ARRAY, - items=glm.Schema(type=glm.Type.INTEGER), + protos.Schema( + type=protos.Type.ARRAY, + items=protos.Schema( + protos.Schema( + type=protos.Type.ARRAY, + items=protos.Schema(type=protos.Type.INTEGER), ), ), ), ], - ["dict", dict, glm.Schema(type=glm.Type.OBJECT)], - ["dict-str-any", dict[str, Any], glm.Schema(type=glm.Type.OBJECT)], + ["dict", dict, protos.Schema(type=protos.Type.OBJECT)], + ["dict-str-any", dict[str, Any], protos.Schema(type=protos.Type.OBJECT)], [ "dataclass", ADataClass, - glm.Schema( - type=glm.Type.OBJECT, - properties={"a": {"type_": glm.Type.INTEGER}}, + protos.Schema( + type=protos.Type.OBJECT, + properties={"a": {"type_": protos.Type.INTEGER}}, ), ], [ "nullable_dataclass", Union[ADataClass, None], - glm.Schema( - type=glm.Type.OBJECT, + protos.Schema( + type=protos.Type.OBJECT, nullable=True, - properties={"a": {"type_": glm.Type.INTEGER}}, + properties={"a": {"type_": protos.Type.INTEGER}}, ), ], [ "list_of_dataclass", list[ADataClass], - glm.Schema( + protos.Schema( type="ARRAY", - items=glm.Schema( - type=glm.Type.OBJECT, - properties={"a": {"type_": glm.Type.INTEGER}}, + items=protos.Schema( + type=protos.Type.OBJECT, + properties={"a": {"type_": protos.Type.INTEGER}}, ), ), ], [ "dataclass_with_nullable", ADataClassWithNullable, - glm.Schema( - type=glm.Type.OBJECT, - properties={"a": {"type_": glm.Type.INTEGER, "nullable": True}}, + protos.Schema( + type=protos.Type.OBJECT, + properties={"a": {"type_": protos.Type.INTEGER, "nullable": True}}, ), ], [ "dataclass_with_list", ADataClassWithList, - glm.Schema( - type=glm.Type.OBJECT, + protos.Schema( + type=protos.Type.OBJECT, properties={"a": {"type_": "ARRAY", "items": {"type_": "INTEGER"}}}, ), ], [ "list_of_dataclass_with_list", list[ADataClassWithList], - glm.Schema( - items=glm.Schema( - type=glm.Type.OBJECT, + protos.Schema( + items=protos.Schema( + type=protos.Type.OBJECT, properties={"a": {"type_": "ARRAY", "items": {"type_": "INTEGER"}}}, ), type="ARRAY", @@ -476,31 +479,31 @@ def b(): [ "list_of_nullable", list[Union[int, None]], - glm.Schema( + protos.Schema( type="ARRAY", - items={"type_": glm.Type.INTEGER, "nullable": True}, + items={"type_": protos.Type.INTEGER, "nullable": True}, ), ], [ "TypedDict", ATypedDict, - glm.Schema( - type=glm.Type.OBJECT, + protos.Schema( + type=protos.Type.OBJECT, properties={ - "a": {"type_": glm.Type.INTEGER}, + "a": {"type_": protos.Type.INTEGER}, }, ), ], [ "nested", Nested, - glm.Schema( - type=glm.Type.OBJECT, + protos.Schema( + type=protos.Type.OBJECT, properties={ - "x": glm.Schema( - type=glm.Type.OBJECT, + "x": 
protos.Schema( + type=protos.Type.OBJECT, properties={ - "a": {"type_": glm.Type.INTEGER}, + "a": {"type_": protos.Type.INTEGER}, }, ), }, diff --git a/tests/test_discuss.py b/tests/test_discuss.py index 9d628a42c..4e54cf754 100644 --- a/tests/test_discuss.py +++ b/tests/test_discuss.py @@ -13,16 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. import copy -from typing import Any import unittest.mock -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import discuss from google.generativeai import client import google.generativeai as genai -from google.generativeai.types import safety_types +from google.generativeai.types import palm_safety_types from absl.testing import absltest from absl.testing import parameterized @@ -38,18 +37,18 @@ def setUp(self): self.observed_request = None - self.mock_response = glm.GenerateMessageResponse( + self.mock_response = protos.GenerateMessageResponse( candidates=[ - glm.Message(content="a", author="1"), - glm.Message(content="b", author="1"), - glm.Message(content="c", author="1"), + protos.Message(content="a", author="1"), + protos.Message(content="b", author="1"), + protos.Message(content="c", author="1"), ], ) def fake_generate_message( - request: glm.GenerateMessageRequest, + request: protos.GenerateMessageRequest, **kwargs, - ) -> glm.GenerateMessageResponse: + ) -> protos.GenerateMessageResponse: self.observed_request = request response = copy.copy(self.mock_response) response.messages = request.prompt.messages @@ -61,22 +60,22 @@ def fake_generate_message( ["string", "Hello", ""], ["dict", {"content": "Hello"}, ""], ["dict_author", {"content": "Hello", "author": "me"}, "me"], - ["proto", glm.Message(content="Hello"), ""], - ["proto_author", glm.Message(content="Hello", author="me"), "me"], + ["proto", protos.Message(content="Hello"), ""], + ["proto_author", protos.Message(content="Hello", author="me"), "me"], ) def test_make_message(self, message, author): x = discuss._make_message(message) - self.assertIsInstance(x, glm.Message) + self.assertIsInstance(x, protos.Message) self.assertEqual("Hello", x.content) self.assertEqual(author, x.author) @parameterized.named_parameters( ["string", "Hello", ["Hello"]], ["dict", {"content": "Hello"}, ["Hello"]], - ["proto", glm.Message(content="Hello"), ["Hello"]], + ["proto", protos.Message(content="Hello"), ["Hello"]], [ "list", - ["hello0", {"content": "hello1"}, glm.Message(content="hello2")], + ["hello0", {"content": "hello1"}, protos.Message(content="hello2")], ["hello0", "hello1", "hello2"], ], ) @@ -91,15 +90,15 @@ def test_make_messages(self, messages, expected_contents): ["dict", {"input": "hello", "output": "goodbye"}], [ "proto", - glm.Example( - input=glm.Message(content="hello"), - output=glm.Message(content="goodbye"), + protos.Example( + input=protos.Message(content="hello"), + output=protos.Message(content="goodbye"), ), ], ) def test_make_example(self, example): x = discuss._make_example(example) - self.assertIsInstance(x, glm.Example) + self.assertIsInstance(x, protos.Example) self.assertEqual("hello", x.input.content) self.assertEqual("goodbye", x.output.content) return @@ -111,7 +110,7 @@ def test_make_example(self, example): "Hi", {"content": "Hello!"}, "what's your name?", - glm.Message(content="Dave, what's yours"), + protos.Message(content="Dave, what's yours"), ], ], [ @@ -146,15 +145,15 @@ def test_make_examples_from_example(self): @parameterized.named_parameters( 
["str", "hello"], - ["message", glm.Message(content="hello")], + ["message", protos.Message(content="hello")], ["messages", ["hello"]], ["dict", {"messages": "hello"}], ["dict2", {"messages": ["hello"]}], - ["proto", glm.MessagePrompt(messages=[glm.Message(content="hello")])], + ["proto", protos.MessagePrompt(messages=[protos.Message(content="hello")])], ) def test_make_message_prompt_from_messages(self, prompt): x = discuss._make_message_prompt(prompt) - self.assertIsInstance(x, glm.MessagePrompt) + self.assertIsInstance(x, protos.MessagePrompt) self.assertEqual(x.messages[0].content, "hello") return @@ -182,15 +181,15 @@ def test_make_message_prompt_from_messages(self, prompt): [ "proto", [ - glm.MessagePrompt( + protos.MessagePrompt( context="you are a cat", examples=[ - glm.Example( - input=glm.Message(content="are you hungry?"), - output=glm.Message(content="meow!"), + protos.Example( + input=protos.Message(content="are you hungry?"), + output=protos.Message(content="meow!"), ) ], - messages=[glm.Message(content="hello")], + messages=[protos.Message(content="hello")], ) ], {}, @@ -198,7 +197,7 @@ def test_make_message_prompt_from_messages(self, prompt): ) def test_make_message_prompt_from_prompt(self, args, kwargs): x = discuss._make_message_prompt(*args, **kwargs) - self.assertIsInstance(x, glm.MessagePrompt) + self.assertIsInstance(x, protos.MessagePrompt) self.assertEqual(x.context, "you are a cat") self.assertEqual(x.examples[0].input.content, "are you hungry?") self.assertEqual(x.examples[0].output.content, "meow!") @@ -230,8 +229,8 @@ def test_make_generate_message_request_nested( } ) - self.assertIsInstance(request0, glm.GenerateMessageRequest) - self.assertIsInstance(request1, glm.GenerateMessageRequest) + self.assertIsInstance(request0, protos.GenerateMessageRequest) + self.assertIsInstance(request1, protos.GenerateMessageRequest) self.assertEqual(request0, request1) @parameterized.parameters( @@ -286,39 +285,43 @@ def test_reply(self, kwargs): response = response.reply("again") def test_receive_and_reply_with_filters(self): - self.mock_response = mock_response = glm.GenerateMessageResponse( - candidates=[glm.Message(content="a", author="1")], + self.mock_response = mock_response = protos.GenerateMessageResponse( + candidates=[protos.Message(content="a", author="1")], filters=[ - glm.ContentFilter(reason=safety_types.BlockedReason.SAFETY, message="unsafe"), - glm.ContentFilter(reason=safety_types.BlockedReason.OTHER), + protos.ContentFilter( + reason=palm_safety_types.BlockedReason.SAFETY, message="unsafe" + ), + protos.ContentFilter(reason=palm_safety_types.BlockedReason.OTHER), ], ) response = discuss.chat(messages="do filters work?") filters = response.filters self.assertLen(filters, 2) - self.assertIsInstance(filters[0]["reason"], safety_types.BlockedReason) - self.assertEqual(filters[0]["reason"], safety_types.BlockedReason.SAFETY) + self.assertIsInstance(filters[0]["reason"], palm_safety_types.BlockedReason) + self.assertEqual(filters[0]["reason"], palm_safety_types.BlockedReason.SAFETY) self.assertEqual(filters[0]["message"], "unsafe") - self.mock_response = glm.GenerateMessageResponse( - candidates=[glm.Message(content="a", author="1")], + self.mock_response = protos.GenerateMessageResponse( + candidates=[protos.Message(content="a", author="1")], filters=[ - glm.ContentFilter(reason=safety_types.BlockedReason.BLOCKED_REASON_UNSPECIFIED) + protos.ContentFilter( + reason=palm_safety_types.BlockedReason.BLOCKED_REASON_UNSPECIFIED + ) ], ) response = response.reply("Does 
reply work?") filters = response.filters self.assertLen(filters, 1) - self.assertIsInstance(filters[0]["reason"], safety_types.BlockedReason) + self.assertIsInstance(filters[0]["reason"], palm_safety_types.BlockedReason) self.assertEqual( filters[0]["reason"], - safety_types.BlockedReason.BLOCKED_REASON_UNSPECIFIED, + palm_safety_types.BlockedReason.BLOCKED_REASON_UNSPECIFIED, ) def test_chat_citations(self): - self.mock_response = mock_response = glm.GenerateMessageResponse( + self.mock_response = mock_response = protos.GenerateMessageResponse( candidates=[ { "content": "Hello google!", diff --git a/tests/test_discuss_async.py b/tests/test_discuss_async.py index 7e1f7947c..d35d03525 100644 --- a/tests/test_discuss_async.py +++ b/tests/test_discuss_async.py @@ -17,7 +17,7 @@ from typing import Any import unittest -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import discuss from absl.testing import absltest @@ -31,14 +31,14 @@ async def test_chat_async(self): observed_request = None async def fake_generate_message( - request: glm.GenerateMessageRequest, + request: protos.GenerateMessageRequest, **kwargs, - ) -> glm.GenerateMessageResponse: + ) -> protos.GenerateMessageResponse: nonlocal observed_request observed_request = request - return glm.GenerateMessageResponse( + return protos.GenerateMessageResponse( candidates=[ - glm.Message( + protos.Message( author="1", content="Why did the chicken cross the road?", ) @@ -59,17 +59,17 @@ async def fake_generate_message( self.assertEqual( observed_request, - glm.GenerateMessageRequest( + protos.GenerateMessageRequest( model="models/bard", - prompt=glm.MessagePrompt( + prompt=protos.MessagePrompt( context="Example Prompt", examples=[ - glm.Example( - input=glm.Message(content="Example from human"), - output=glm.Message(content="Example response from AI"), + protos.Example( + input=protos.Message(content="Example from human"), + output=protos.Message(content="Example response from AI"), ) ], - messages=[glm.Message(author="0", content="Tell me a joke")], + messages=[protos.Message(author="0", content="Tell me a joke")], ), temperature=0.75, candidate_count=1, diff --git a/tests/test_embedding.py b/tests/test_embedding.py index 5f6aa8d89..a208a4743 100644 --- a/tests/test_embedding.py +++ b/tests/test_embedding.py @@ -18,7 +18,7 @@ import unittest import unittest.mock as mock -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import embedding @@ -45,20 +45,20 @@ def add_client_method(f): @add_client_method def embed_content( - request: glm.EmbedContentRequest, + request: protos.EmbedContentRequest, **kwargs, - ) -> glm.EmbedContentResponse: + ) -> protos.EmbedContentResponse: self.observed_requests.append(request) - return glm.EmbedContentResponse(embedding=glm.ContentEmbedding(values=[1, 2, 3])) + return protos.EmbedContentResponse(embedding=protos.ContentEmbedding(values=[1, 2, 3])) @add_client_method def batch_embed_contents( - request: glm.BatchEmbedContentsRequest, + request: protos.BatchEmbedContentsRequest, **kwargs, - ) -> glm.BatchEmbedContentsResponse: + ) -> protos.BatchEmbedContentsResponse: self.observed_requests.append(request) - return glm.BatchEmbedContentsResponse( - embeddings=[glm.ContentEmbedding(values=[1, 2, 3])] * len(request.requests) + return protos.BatchEmbedContentsResponse( + embeddings=[protos.ContentEmbedding(values=[1, 2, 3])] * len(request.requests) ) def test_embed_content(self): @@ -68,8 +68,9 @@ 
def test_embed_content(self): self.assertIsInstance(emb, dict) self.assertEqual( self.observed_requests[-1], - glm.EmbedContentRequest( - model=DEFAULT_EMB_MODEL, content=glm.Content(parts=[glm.Part(text="What are you?")]) + protos.EmbedContentRequest( + model=DEFAULT_EMB_MODEL, + content=protos.Content(parts=[protos.Part(text="What are you?")]), ), ) self.assertIsInstance(emb["embedding"][0], float) diff --git a/tests/test_embedding_async.py b/tests/test_embedding_async.py index d4ca16c08..367cf7ded 100644 --- a/tests/test_embedding_async.py +++ b/tests/test_embedding_async.py @@ -18,7 +18,7 @@ import unittest import unittest.mock as mock -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import embedding @@ -44,20 +44,20 @@ def add_client_method(f): @add_client_method async def embed_content( - request: glm.EmbedContentRequest, + request: protos.EmbedContentRequest, **kwargs, - ) -> glm.EmbedContentResponse: + ) -> protos.EmbedContentResponse: self.observed_requests.append(request) - return glm.EmbedContentResponse(embedding=glm.ContentEmbedding(values=[1, 2, 3])) + return protos.EmbedContentResponse(embedding=protos.ContentEmbedding(values=[1, 2, 3])) @add_client_method async def batch_embed_contents( - request: glm.BatchEmbedContentsRequest, + request: protos.BatchEmbedContentsRequest, **kwargs, - ) -> glm.BatchEmbedContentsResponse: + ) -> protos.BatchEmbedContentsResponse: self.observed_requests.append(request) - return glm.BatchEmbedContentsResponse( - embeddings=[glm.ContentEmbedding(values=[1, 2, 3])] * len(request.requests) + return protos.BatchEmbedContentsResponse( + embeddings=[protos.ContentEmbedding(values=[1, 2, 3])] * len(request.requests) ) async def test_embed_content_async(self): @@ -67,8 +67,9 @@ async def test_embed_content_async(self): self.assertIsInstance(emb, dict) self.assertEqual( self.observed_requests[-1], - glm.EmbedContentRequest( - model=DEFAULT_EMB_MODEL, content=glm.Content(parts=[glm.Part(text="What are you?")]) + protos.EmbedContentRequest( + model=DEFAULT_EMB_MODEL, + content=protos.Content(parts=[protos.Part(text="What are you?")]), ), ) self.assertIsInstance(emb["embedding"][0], float) diff --git a/tests/test_files.py b/tests/test_files.py new file mode 100644 index 000000000..7d9139450 --- /dev/null +++ b/tests/test_files.py @@ -0,0 +1,141 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
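(Editor's note, not part of the patch.) The new `tests/test_files.py` below stubs out `FileServiceClient` and exercises `file_types.to_file_data`, which normalizes several input shapes into a `protos.FileData`. A minimal sketch of the conversions the parameterized `test_to_file_data` cases cover; the URI is a placeholder taken from the tests:

```python
# Sketch only: mirrors the to_file_data cases in the new test file below.
from google.generativeai import protos
from google.generativeai.types import file_types

expected = protos.FileData(file_uri="https://test_uri")

for value in (
    {"file_uri": "https://test_uri"},              # FileDataDict
    {"uri": "https://test_uri"},                   # FileDict
    protos.FileData(file_uri="https://test_uri"),  # already a FileData
    protos.File(uri="https://test_uri"),           # a File proto
):
    assert file_types.to_file_data(value) == expected
```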
+ +from google.generativeai.types import file_types + +import collections +import datetime +import os +from typing import Iterable, Union +import pathlib + +import google + +import google.generativeai as genai +from google.generativeai import client as client_lib +from google.generativeai import protos +from absl.testing import parameterized + + +class FileServiceClient(client_lib.FileServiceClient): + def __init__(self, test): + self.test = test + self.observed_requests = [] + self.responses = collections.defaultdict(list) + + def create_file( + self, + path: Union[str, pathlib.Path, os.PathLike], + *, + mime_type: Union[str, None] = None, + name: Union[str, None] = None, + display_name: Union[str, None] = None, + resumable: bool = True, + ) -> protos.File: + self.observed_requests.append( + dict( + path=path, + mime_type=mime_type, + name=name, + display_name=display_name, + resumable=resumable, + ) + ) + return self.responses["create_file"].pop(0) + + def get_file( + self, + request: protos.GetFileRequest, + **kwargs, + ) -> protos.File: + self.observed_requests.append(request) + return self.responses["get_file"].pop(0) + + def list_files( + self, + request: protos.ListFilesRequest, + **kwargs, + ) -> Iterable[protos.File]: + self.observed_requests.append(request) + for f in self.responses["list_files"].pop(0): + yield f + + def delete_file( + self, + request: protos.DeleteFileRequest, + **kwargs, + ): + self.observed_requests.append(request) + return + + +class UnitTests(parameterized.TestCase): + def setUp(self): + self.client = FileServiceClient(self) + + client_lib._client_manager.clients["file"] = self.client + + @property + def observed_requests(self): + return self.client.observed_requests + + @property + def responses(self): + return self.client.responses + + def test_video_metadata(self): + self.responses["create_file"].append( + protos.File( + uri="https://test", + state="ACTIVE", + video_metadata=dict(video_duration=datetime.timedelta(seconds=30)), + error=dict(code=7, message="ok?"), + ) + ) + + f = genai.upload_file(path="dummy") + self.assertEqual(google.rpc.status_pb2.Status(code=7, message="ok?"), f.error) + self.assertEqual( + protos.VideoMetadata(dict(video_duration=datetime.timedelta(seconds=30))), + f.video_metadata, + ) + + @parameterized.named_parameters( + [ + dict( + testcase_name="FileDataDict", + file_data=dict(file_uri="https://test_uri"), + ), + dict( + testcase_name="FileDict", + file_data=dict(uri="https://test_uri"), + ), + dict( + testcase_name="FileData", + file_data=protos.FileData(file_uri="https://test_uri"), + ), + dict( + testcase_name="protos.File", + file_data=protos.File(uri="https://test_uri"), + ), + dict( + testcase_name="file_types.File", + file_data=file_types.File(dict(uri="https://test_uri")), + ), + ] + ) + def test_to_file_data(self, file_data): + file_data = file_types.to_file_data(file_data) + self.assertEqual(protos.FileData(file_uri="https://test_uri"), file_data) diff --git a/tests/test_generation.py b/tests/test_generation.py index 82beac16b..828577d21 100644 --- a/tests/test_generation.py +++ b/tests/test_generation.py @@ -5,7 +5,7 @@ from absl.testing import absltest from absl.testing import parameterized -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai.types import generation_types @@ -24,9 +24,11 @@ class Person(TypedDict): class UnitTests(parameterized.TestCase): @parameterized.named_parameters( [ - "glm.GenerationConfig", - glm.GenerationConfig( - temperature=0.1, 
stop_sequences=["end"], response_schema=glm.Schema(type="STRING") + "protos.GenerationConfig", + protos.GenerationConfig( + temperature=0.1, + stop_sequences=["end"], + response_schema=protos.Schema(type="STRING"), ), ], [ @@ -48,15 +50,15 @@ def test_to_generation_config(self, config): def test_join_citation_metadatas(self): citations = [ - glm.CitationMetadata( + protos.CitationMetadata( citation_sources=[ - glm.CitationSource(start_index=3, end_index=21, uri="https://google.com"), + protos.CitationSource(start_index=3, end_index=21, uri="https://google.com"), ] ), - glm.CitationMetadata( + protos.CitationMetadata( citation_sources=[ - glm.CitationSource(start_index=3, end_index=33, uri="https://google.com"), - glm.CitationSource(start_index=55, end_index=92, uri="https://google.com"), + protos.CitationSource(start_index=3, end_index=33, uri="https://google.com"), + protos.CitationSource(start_index=55, end_index=92, uri="https://google.com"), ] ), ] @@ -74,14 +76,14 @@ def test_join_citation_metadatas(self): def test_join_safety_ratings_list(self): ratings = [ [ - glm.SafetyRating(category="HARM_CATEGORY_DANGEROUS", probability="LOW"), - glm.SafetyRating(category="HARM_CATEGORY_MEDICAL", probability="LOW"), - glm.SafetyRating(category="HARM_CATEGORY_SEXUAL", probability="MEDIUM"), + protos.SafetyRating(category="HARM_CATEGORY_DANGEROUS", probability="LOW"), + protos.SafetyRating(category="HARM_CATEGORY_MEDICAL", probability="LOW"), + protos.SafetyRating(category="HARM_CATEGORY_SEXUAL", probability="MEDIUM"), ], [ - glm.SafetyRating(category="HARM_CATEGORY_DEROGATORY", probability="LOW"), - glm.SafetyRating(category="HARM_CATEGORY_SEXUAL", probability="LOW"), - glm.SafetyRating( + protos.SafetyRating(category="HARM_CATEGORY_DEROGATORY", probability="LOW"), + protos.SafetyRating(category="HARM_CATEGORY_SEXUAL", probability="LOW"), + protos.SafetyRating( category="HARM_CATEGORY_DANGEROUS", probability="HIGH", blocked=True, @@ -101,14 +103,14 @@ def test_join_safety_ratings_list(self): def test_join_contents(self): contents = [ - glm.Content(role="assistant", parts=[glm.Part(text="Tell me a story about a ")]), - glm.Content( + protos.Content(role="assistant", parts=[protos.Part(text="Tell me a story about a ")]), + protos.Content( role="assistant", - parts=[glm.Part(text="magic backpack that looks like this: ")], + parts=[protos.Part(text="magic backpack that looks like this: ")], ), - glm.Content( + protos.Content( role="assistant", - parts=[glm.Part(inline_data=glm.Blob(mime_type="image/png", data=b"DATA!"))], + parts=[protos.Part(inline_data=protos.Blob(mime_type="image/png", data=b"DATA!"))], ), ] result = generation_types._join_contents(contents) @@ -126,7 +128,8 @@ def test_many_join_contents(self): import string contents = [ - glm.Content(role="assistant", parts=[glm.Part(text=a)]) for a in string.ascii_lowercase + protos.Content(role="assistant", parts=[protos.Part(text=a)]) + for a in string.ascii_lowercase ] result = generation_types._join_contents(contents) @@ -139,41 +142,53 @@ def test_many_join_contents(self): def test_join_candidates(self): candidates = [ - glm.Candidate( + protos.Candidate( index=0, - content=glm.Content( + content=protos.Content( role="assistant", - parts=[glm.Part(text="Tell me a story about a ")], + parts=[protos.Part(text="Tell me a story about a ")], ), - citation_metadata=glm.CitationMetadata( + citation_metadata=protos.CitationMetadata( citation_sources=[ - glm.CitationSource(start_index=55, end_index=85, uri="https://google.com"), + 
protos.CitationSource( + start_index=55, end_index=85, uri="https://google.com" + ), ] ), ), - glm.Candidate( + protos.Candidate( index=0, - content=glm.Content( + content=protos.Content( role="assistant", - parts=[glm.Part(text="magic backpack that looks like this: ")], + parts=[protos.Part(text="magic backpack that looks like this: ")], ), - citation_metadata=glm.CitationMetadata( + citation_metadata=protos.CitationMetadata( citation_sources=[ - glm.CitationSource(start_index=55, end_index=92, uri="https://google.com"), - glm.CitationSource(start_index=3, end_index=21, uri="https://google.com"), + protos.CitationSource( + start_index=55, end_index=92, uri="https://google.com" + ), + protos.CitationSource( + start_index=3, end_index=21, uri="https://google.com" + ), ] ), ), - glm.Candidate( + protos.Candidate( index=0, - content=glm.Content( + content=protos.Content( role="assistant", - parts=[glm.Part(inline_data=glm.Blob(mime_type="image/png", data=b"DATA!"))], + parts=[ + protos.Part(inline_data=protos.Blob(mime_type="image/png", data=b"DATA!")) + ], ), - citation_metadata=glm.CitationMetadata( + citation_metadata=protos.CitationMetadata( citation_sources=[ - glm.CitationSource(start_index=55, end_index=92, uri="https://google.com"), - glm.CitationSource(start_index=3, end_index=21, uri="https://google.com"), + protos.CitationSource( + start_index=55, end_index=92, uri="https://google.com" + ), + protos.CitationSource( + start_index=3, end_index=21, uri="https://google.com" + ), ] ), finish_reason="STOP", @@ -213,17 +228,17 @@ def test_join_candidates(self): def test_join_prompt_feedbacks(self): feedbacks = [ - glm.GenerateContentResponse.PromptFeedback( + protos.GenerateContentResponse.PromptFeedback( block_reason="SAFETY", safety_ratings=[ - glm.SafetyRating(category="HARM_CATEGORY_DANGEROUS", probability="LOW"), + protos.SafetyRating(category="HARM_CATEGORY_DANGEROUS", probability="LOW"), ], ), - glm.GenerateContentResponse.PromptFeedback(), - glm.GenerateContentResponse.PromptFeedback(), - glm.GenerateContentResponse.PromptFeedback( + protos.GenerateContentResponse.PromptFeedback(), + protos.GenerateContentResponse.PromptFeedback(), + protos.GenerateContentResponse.PromptFeedback( safety_ratings=[ - glm.SafetyRating(category="HARM_CATEGORY_MEDICAL", probability="HIGH"), + protos.SafetyRating(category="HARM_CATEGORY_MEDICAL", probability="HIGH"), ] ), ] @@ -396,23 +411,23 @@ def test_join_prompt_feedbacks(self): ] def test_join_candidates(self): - candidate_lists = [[glm.Candidate(c) for c in cl] for cl in self.CANDIDATE_LISTS] + candidate_lists = [[protos.Candidate(c) for c in cl] for cl in self.CANDIDATE_LISTS] result = generation_types._join_candidate_lists(candidate_lists) self.assertEqual(self.MERGED_CANDIDATES, [type(r).to_dict(r) for r in result]) def test_join_chunks(self): - chunks = [glm.GenerateContentResponse(candidates=cl) for cl in self.CANDIDATE_LISTS] + chunks = [protos.GenerateContentResponse(candidates=cl) for cl in self.CANDIDATE_LISTS] - chunks[0].prompt_feedback = glm.GenerateContentResponse.PromptFeedback( + chunks[0].prompt_feedback = protos.GenerateContentResponse.PromptFeedback( block_reason="SAFETY", safety_ratings=[ - glm.SafetyRating(category="HARM_CATEGORY_DANGEROUS", probability="LOW"), + protos.SafetyRating(category="HARM_CATEGORY_DANGEROUS", probability="LOW"), ], ) result = generation_types._join_chunks(chunks) - expected = glm.GenerateContentResponse( + expected = protos.GenerateContentResponse( { "candidates": self.MERGED_CANDIDATES, 
"prompt_feedback": { @@ -431,7 +446,7 @@ def test_join_chunks(self): self.assertEqual(type(expected).to_dict(expected), type(result).to_dict(expected)) def test_generate_content_response_iterator_end_to_end(self): - chunks = [glm.GenerateContentResponse(candidates=cl) for cl in self.CANDIDATE_LISTS] + chunks = [protos.GenerateContentResponse(candidates=cl) for cl in self.CANDIDATE_LISTS] merged = generation_types._join_chunks(chunks) response = generation_types.GenerateContentResponse.from_iterator(iter(chunks)) @@ -453,7 +468,7 @@ def test_generate_content_response_iterator_end_to_end(self): def test_generate_content_response_multiple_iterators(self): chunks = [ - glm.GenerateContentResponse({"candidates": [{"content": {"parts": [{"text": a}]}}]}) + protos.GenerateContentResponse({"candidates": [{"content": {"parts": [{"text": a}]}}]}) for a in string.ascii_lowercase ] response = generation_types.GenerateContentResponse.from_iterator(iter(chunks)) @@ -483,7 +498,7 @@ def test_generate_content_response_multiple_iterators(self): def test_generate_content_response_resolve(self): chunks = [ - glm.GenerateContentResponse({"candidates": [{"content": {"parts": [{"text": a}]}}]}) + protos.GenerateContentResponse({"candidates": [{"content": {"parts": [{"text": a}]}}]}) for a in "abcd" ] response = generation_types.GenerateContentResponse.from_iterator(iter(chunks)) @@ -497,7 +512,7 @@ def test_generate_content_response_resolve(self): self.assertEqual(response.candidates[0].content.parts[0].text, "abcd") def test_generate_content_response_from_response(self): - raw_response = glm.GenerateContentResponse( + raw_response = protos.GenerateContentResponse( {"candidates": [{"content": {"parts": [{"text": "Hello world!"}]}}]} ) response = generation_types.GenerateContentResponse.from_response(raw_response) @@ -511,7 +526,7 @@ def test_generate_content_response_from_response(self): ) def test_repr_for_generate_content_response_from_response(self): - raw_response = glm.GenerateContentResponse( + raw_response = protos.GenerateContentResponse( {"candidates": [{"content": {"parts": [{"text": "Hello world!"}]}}]} ) response = generation_types.GenerateContentResponse.from_response(raw_response) @@ -523,7 +538,7 @@ def test_repr_for_generate_content_response_from_response(self): GenerateContentResponse( done=True, iterator=None, - result=glm.GenerateContentResponse({ + result=protos.GenerateContentResponse({ "candidates": [ { "content": { @@ -531,13 +546,8 @@ def test_repr_for_generate_content_response_from_response(self): { "text": "Hello world!" 
} - ], - "role": "" - }, - "finish_reason": 0, - "safety_ratings": [], - "token_count": 0, - "grounding_attributions": [] + ] + } } ] }), @@ -547,7 +557,7 @@ def test_repr_for_generate_content_response_from_response(self): def test_repr_for_generate_content_response_from_iterator(self): chunks = [ - glm.GenerateContentResponse({"candidates": [{"content": {"parts": [{"text": a}]}}]}) + protos.GenerateContentResponse({"candidates": [{"content": {"parts": [{"text": a}]}}]}) for a in "abcd" ] response = generation_types.GenerateContentResponse.from_iterator(iter(chunks)) @@ -559,7 +569,7 @@ def test_repr_for_generate_content_response_from_iterator(self): GenerateContentResponse( done=False, iterator=, - result=glm.GenerateContentResponse({ + result=protos.GenerateContentResponse({ "candidates": [ { "content": { @@ -567,13 +577,8 @@ def test_repr_for_generate_content_response_from_iterator(self): { "text": "a" } - ], - "role": "" - }, - "finish_reason": 0, - "safety_ratings": [], - "token_count": 0, - "grounding_attributions": [] + ] + } } ] }), @@ -583,35 +588,35 @@ def test_repr_for_generate_content_response_from_iterator(self): @parameterized.named_parameters( [ - "glm.Schema", - glm.Schema(type="STRING"), - glm.Schema(type="STRING"), + "protos.Schema", + protos.Schema(type="STRING"), + protos.Schema(type="STRING"), ], [ "SchemaDict", {"type": "STRING"}, - glm.Schema(type="STRING"), + protos.Schema(type="STRING"), ], [ "str", str, - glm.Schema(type="STRING"), + protos.Schema(type="STRING"), ], - ["list_of_str", list[str], glm.Schema(type="ARRAY", items=glm.Schema(type="STRING"))], + ["list_of_str", list[str], protos.Schema(type="ARRAY", items=protos.Schema(type="STRING"))], [ "fancy", Person, - glm.Schema( + protos.Schema( type="OBJECT", properties=dict( - name=glm.Schema(type="STRING"), - favorite_color=glm.Schema(type="STRING"), - birthday=glm.Schema( + name=protos.Schema(type="STRING"), + favorite_color=protos.Schema(type="STRING"), + birthday=protos.Schema( type="OBJECT", properties=dict( - day=glm.Schema(type="INTEGER"), - month=glm.Schema(type="INTEGER"), - year=glm.Schema(type="INTEGER"), + day=protos.Schema(type="INTEGER"), + month=protos.Schema(type="INTEGER"), + year=protos.Schema(type="INTEGER"), ), ), ), diff --git a/tests/test_generative_models.py b/tests/test_generative_models.py index 6fabd59e9..0ece77e94 100644 --- a/tests/test_generative_models.py +++ b/tests/test_generative_models.py @@ -7,11 +7,13 @@ import unittest.mock from absl.testing import absltest from absl.testing import parameterized -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import client as client_lib from google.generativeai import generative_models from google.generativeai.types import content_types from google.generativeai.types import generation_types +from google.generativeai.types import helper_types + import PIL.Image @@ -21,61 +23,79 @@ TEST_IMAGE_DATA = TEST_IMAGE_PATH.read_bytes() -def simple_part(text: str) -> glm.Content: - return glm.Content({"parts": [{"text": text}]}) +def simple_part(text: str) -> protos.Content: + return protos.Content({"parts": [{"text": text}]}) + + +def noop(x: int): + return x + +def iter_part(texts: Iterable[str]) -> protos.Content: + return protos.Content({"parts": [{"text": t} for t in texts]}) -def iter_part(texts: Iterable[str]) -> glm.Content: - return glm.Content({"parts": [{"text": t} for t in texts]}) +def simple_response(text: str) -> protos.GenerateContentResponse: + return 
protos.GenerateContentResponse({"candidates": [{"content": simple_part(text)}]}) + + +class MockGenerativeServiceClient: + def __init__(self, test): + self.test = test + self.observed_requests = [] + self.observed_kwargs = [] + self.responses = collections.defaultdict(list) -def simple_response(text: str) -> glm.GenerateContentResponse: - return glm.GenerateContentResponse({"candidates": [{"content": simple_part(text)}]}) + def generate_content( + self, + request: protos.GenerateContentRequest, + **kwargs, + ) -> protos.GenerateContentResponse: + self.test.assertIsInstance(request, protos.GenerateContentRequest) + self.observed_requests.append(request) + self.observed_kwargs.append(kwargs) + response = self.responses["generate_content"].pop(0) + return response + + def stream_generate_content( + self, + request: protos.GetModelRequest, + **kwargs, + ) -> Iterable[protos.GenerateContentResponse]: + self.observed_requests.append(request) + self.observed_kwargs.append(kwargs) + response = self.responses["stream_generate_content"].pop(0) + return response + + def count_tokens( + self, + request: protos.CountTokensRequest, + **kwargs, + ) -> Iterable[protos.GenerateContentResponse]: + self.observed_requests.append(request) + self.observed_kwargs.append(kwargs) + response = self.responses["count_tokens"].pop(0) + return response class CUJTests(parameterized.TestCase): """Tests are in order with the design doc.""" - def setUp(self): - self.client = unittest.mock.MagicMock() + @property + def observed_requests(self): + return self.client.observed_requests - client_lib._client_manager.clients["generative"] = self.client + @property + def observed_kwargs(self): + return self.client.observed_kwargs - def add_client_method(f): - name = f.__name__ - setattr(self.client, name, f) - return f + @property + def responses(self): + return self.client.responses - self.observed_requests = [] - self.responses = collections.defaultdict(list) - - @add_client_method - def generate_content( - request: glm.GenerateContentRequest, - **kwargs, - ) -> glm.GenerateContentResponse: - self.assertIsInstance(request, glm.GenerateContentRequest) - self.observed_requests.append(request) - response = self.responses["generate_content"].pop(0) - return response - - @add_client_method - def stream_generate_content( - request: glm.GetModelRequest, - **kwargs, - ) -> Iterable[glm.GenerateContentResponse]: - self.observed_requests.append(request) - response = self.responses["stream_generate_content"].pop(0) - return response - - @add_client_method - def count_tokens( - request: glm.CountTokensRequest, - **kwargs, - ) -> Iterable[glm.GenerateContentResponse]: - self.observed_requests.append(request) - response = self.responses["count_tokens"].pop(0) - return response + def setUp(self): + self.client = MockGenerativeServiceClient(self) + client_lib._client_manager.clients["generative"] = self.client def test_hello(self): # Generate text from text prompt @@ -129,9 +149,9 @@ def test_image(self, content): generation_types.GenerationConfig(temperature=0.5), ], [ - "glm", - glm.GenerationConfig(temperature=0.0), - glm.GenerationConfig(temperature=0.5), + "protos", + protos.GenerationConfig(temperature=0.0), + protos.GenerationConfig(temperature=0.5), ], ) def test_generation_config_overwrite(self, config1, config2): @@ -151,12 +171,13 @@ def test_generation_config_overwrite(self, config1, config2): @parameterized.named_parameters( ["dict", {"danger": "low"}, {"danger": "high"}], + ["quick", "low", "high"], [ "list-dict", [ dict( - 
category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS, - threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + category=protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=protos.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, ), ], [ @@ -166,22 +187,22 @@ def test_generation_config_overwrite(self, config1, config2): [ "object", [ - glm.SafetySetting( - category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS, - threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + protos.SafetySetting( + category=protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=protos.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, ), ], [ - glm.SafetySetting( - category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS, - threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + protos.SafetySetting( + category=protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=protos.SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH, ), ], ], ) def test_safety_overwrite(self, safe1, safe2): # Safety - model = generative_models.GenerativeModel("gemini-pro", safety_settings={"danger": "low"}) + model = generative_models.GenerativeModel("gemini-pro", safety_settings=safe1) self.responses["generate_content"] = [ simple_response(" world!"), @@ -189,23 +210,26 @@ def test_safety_overwrite(self, safe1, safe2): ] _ = model.generate_content("hello") - self.assertEqual( - self.observed_requests[-1].safety_settings[0].category, - glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - ) - self.assertEqual( - self.observed_requests[-1].safety_settings[0].threshold, - glm.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, - ) - _ = model.generate_content("hello", safety_settings={"danger": "high"}) + danger = [ + s + for s in self.observed_requests[-1].safety_settings + if s.category == protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT + ] self.assertEqual( - self.observed_requests[-1].safety_settings[0].category, - glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + danger[0].threshold, + protos.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, ) + + _ = model.generate_content("hello", safety_settings=safe2) + danger = [ + s + for s in self.observed_requests[-1].safety_settings + if s.category == protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT + ] self.assertEqual( - self.observed_requests[-1].safety_settings[0].threshold, - glm.SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH, + danger[0].threshold, + protos.SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH, ) def test_stream_basic(self): @@ -239,7 +263,7 @@ def test_stream_lookahead(self): def test_stream_prompt_feedback_blocked(self): chunks = [ - glm.GenerateContentResponse( + protos.GenerateContentResponse( { "prompt_feedback": {"block_reason": "SAFETY"}, } @@ -252,7 +276,7 @@ def test_stream_prompt_feedback_blocked(self): self.assertEqual( response.prompt_feedback.block_reason, - glm.GenerateContentResponse.PromptFeedback.BlockReason.SAFETY, + protos.GenerateContentResponse.PromptFeedback.BlockReason.SAFETY, ) with self.assertRaises(generation_types.BlockedPromptException): @@ -261,20 +285,20 @@ def test_stream_prompt_feedback_blocked(self): def test_stream_prompt_feedback_not_blocked(self): chunks = [ - glm.GenerateContentResponse( + protos.GenerateContentResponse( { "prompt_feedback": { "safety_ratings": [ { - "category": glm.HarmCategory.HARM_CATEGORY_DANGEROUS, - "probability": glm.SafetyRating.HarmProbability.NEGLIGIBLE, + "category": protos.HarmCategory.HARM_CATEGORY_DANGEROUS, + "probability": 
protos.SafetyRating.HarmProbability.NEGLIGIBLE, } ] }, "candidates": [{"content": {"parts": [{"text": "first"}]}}], } ), - glm.GenerateContentResponse( + protos.GenerateContentResponse( { "candidates": [{"content": {"parts": [{"text": " second"}]}}], } @@ -287,7 +311,7 @@ def test_stream_prompt_feedback_not_blocked(self): self.assertEqual( response.prompt_feedback.safety_ratings[0].category, - glm.HarmCategory.HARM_CATEGORY_DANGEROUS, + protos.HarmCategory.HARM_CATEGORY_DANGEROUS, ) text = "".join(chunk.text for chunk in response) @@ -443,7 +467,7 @@ def test_copy_history(self): chat1 = model.start_chat() chat1.send_message("hello1") - chat2 = copy.deepcopy(chat1) + chat2 = copy.copy(chat1) chat2.send_message("hello2") chat1.send_message("hello3") @@ -520,7 +544,7 @@ def no_throw(): def test_chat_prompt_blocked(self): self.responses["generate_content"] = [ - glm.GenerateContentResponse( + protos.GenerateContentResponse( { "prompt_feedback": {"block_reason": "SAFETY"}, } @@ -538,7 +562,7 @@ def test_chat_prompt_blocked(self): def test_chat_candidate_blocked(self): # I feel like chat needs a .last so you can look at the partial results. self.responses["generate_content"] = [ - glm.GenerateContentResponse( + protos.GenerateContentResponse( { "candidates": [{"finish_reason": "SAFETY"}], } @@ -558,7 +582,7 @@ def test_chat_streaming_unexpected_stop(self): simple_response("a"), simple_response("b"), simple_response("c"), - glm.GenerateContentResponse( + protos.GenerateContentResponse( { "candidates": [{"finish_reason": "SAFETY"}], } @@ -645,9 +669,9 @@ def test_tools(self): }, ), dict( - testcase_name="test_glm_FunctionCallingConfig", + testcase_name="test_protos.FunctionCallingConfig", tool_config={ - "function_calling_config": glm.FunctionCallingConfig( + "function_calling_config": protos.FunctionCallingConfig( mode=content_types.FunctionCallingMode.AUTO ) }, @@ -674,9 +698,9 @@ def test_tools(self): }, ), dict( - testcase_name="test_glm_ToolConfig", - tool_config=glm.ToolConfig( - function_calling_config=glm.FunctionCallingConfig( + testcase_name="test_protos.ToolConfig", + tool_config=protos.ToolConfig( + function_calling_config=protos.FunctionCallingConfig( mode=content_types.FunctionCallingMode.NONE ) ), @@ -725,18 +749,33 @@ def test_system_instruction(self, instruction, expected_instr): self.assertEqual(req.system_instruction, expected_instr) @parameterized.named_parameters( - ["basic", "Hello"], - ["list", ["Hello"]], + ["basic", {"contents": "Hello"}], + ["list", {"contents": ["Hello"]}], [ "list2", - [{"text": "Hello"}, {"inline_data": {"data": b"PNG!", "mime_type": "image/png"}}], + { + "contents": [ + {"text": "Hello"}, + {"inline_data": {"data": b"PNG!", "mime_type": "image/png"}}, + ] + }, + ], + [ + "contents", + {"contents": [{"role": "user", "parts": ["hello"]}]}, ], - ["contents", [{"role": "user", "parts": ["hello"]}]], + ["empty", {}], + [ + "system_instruction", + {"system_instruction": ["You are a cat"]}, + ], + ["tools", {"tools": [noop]}], ) - def test_count_tokens_smoke(self, contents): - self.responses["count_tokens"] = [glm.CountTokensResponse(total_tokens=7)] - model = generative_models.GenerativeModel("gemini-pro-vision") - response = model.count_tokens(contents) + def test_count_tokens_smoke(self, kwargs): + si = kwargs.pop("system_instruction", None) + self.responses["count_tokens"] = [protos.CountTokensResponse(total_tokens=7)] + model = generative_models.GenerativeModel("gemini-pro-vision", system_instruction=si) + response = model.count_tokens(**kwargs) 
self.assertEqual(type(response).to_dict(response), {"total_tokens": 7}) @parameterized.named_parameters( @@ -787,7 +826,7 @@ def test_async_code_match(self, obj, aobj): ) asource = re.sub(" *?# type: ignore", "", asource) - self.assertEqual(source, asource) + self.assertEqual(source, asource, f"error in {obj=}") def test_repr_for_unary_non_streamed_response(self): model = generative_models.GenerativeModel(model_name="gemini-pro") @@ -801,7 +840,7 @@ def test_repr_for_unary_non_streamed_response(self): GenerateContentResponse( done=True, iterator=None, - result=glm.GenerateContentResponse({ + result=protos.GenerateContentResponse({ "candidates": [ { "content": { @@ -809,13 +848,8 @@ def test_repr_for_unary_non_streamed_response(self): { "text": "world!" } - ], - "role": "" - }, - "finish_reason": 0, - "safety_ratings": [], - "token_count": 0, - "grounding_attributions": [] + ] + } } ] }), @@ -839,7 +873,7 @@ def test_repr_for_streaming_start_to_finish(self): GenerateContentResponse( done=False, iterator=, - result=glm.GenerateContentResponse({ + result=protos.GenerateContentResponse({ "candidates": [ { "content": { @@ -847,13 +881,8 @@ def test_repr_for_streaming_start_to_finish(self): { "text": "first" } - ], - "role": "" - }, - "finish_reason": 0, - "safety_ratings": [], - "token_count": 0, - "grounding_attributions": [] + ] + } } ] }), @@ -869,7 +898,7 @@ def test_repr_for_streaming_start_to_finish(self): GenerateContentResponse( done=False, iterator=, - result=glm.GenerateContentResponse({ + result=protos.GenerateContentResponse({ "candidates": [ { "content": { @@ -877,28 +906,14 @@ def test_repr_for_streaming_start_to_finish(self): { "text": "first second" } - ], - "role": "" + ] }, "index": 0, - "citation_metadata": { - "citation_sources": [] - }, - "finish_reason": 0, - "safety_ratings": [], - "token_count": 0, - "grounding_attributions": [] + "citation_metadata": {} } ], - "prompt_feedback": { - "block_reason": 0, - "safety_ratings": [] - }, - "usage_metadata": { - "prompt_token_count": 0, - "candidates_token_count": 0, - "total_token_count": 0 - } + "prompt_feedback": {}, + "usage_metadata": {} }), )""" ) @@ -912,7 +927,7 @@ def test_repr_for_streaming_start_to_finish(self): GenerateContentResponse( done=False, iterator=, - result=glm.GenerateContentResponse({ + result=protos.GenerateContentResponse({ "candidates": [ { "content": { @@ -920,28 +935,14 @@ def test_repr_for_streaming_start_to_finish(self): { "text": "first second third" } - ], - "role": "" + ] }, "index": 0, - "citation_metadata": { - "citation_sources": [] - }, - "finish_reason": 0, - "safety_ratings": [], - "token_count": 0, - "grounding_attributions": [] + "citation_metadata": {} } ], - "prompt_feedback": { - "block_reason": 0, - "safety_ratings": [] - }, - "usage_metadata": { - "prompt_token_count": 0, - "candidates_token_count": 0, - "total_token_count": 0 - } + "prompt_feedback": {}, + "usage_metadata": {} }), )""" ) @@ -950,7 +951,7 @@ def test_repr_for_streaming_start_to_finish(self): def test_repr_error_info_for_stream_prompt_feedback_blocked(self): # response._error => BlockedPromptException chunks = [ - glm.GenerateContentResponse( + protos.GenerateContentResponse( { "prompt_feedback": {"block_reason": "SAFETY"}, } @@ -968,12 +969,10 @@ def test_repr_error_info_for_stream_prompt_feedback_blocked(self): GenerateContentResponse( done=False, iterator=, - result=glm.GenerateContentResponse({ + result=protos.GenerateContentResponse({ "prompt_feedback": { - "block_reason": 1, - "safety_ratings": [] - }, - 
"candidates": [] + "block_reason": "SAFETY" + } }), ), error= prompt_feedback { @@ -1020,7 +1019,7 @@ def no_throw(): GenerateContentResponse( done=True, iterator=None, - result=glm.GenerateContentResponse({ + result=protos.GenerateContentResponse({ "candidates": [ { "content": { @@ -1028,28 +1027,14 @@ def no_throw(): { "text": "123" } - ], - "role": "" + ] }, "index": 0, - "citation_metadata": { - "citation_sources": [] - }, - "finish_reason": 0, - "safety_ratings": [], - "token_count": 0, - "grounding_attributions": [] + "citation_metadata": {} } ], - "prompt_feedback": { - "block_reason": 0, - "safety_ratings": [] - }, - "usage_metadata": { - "prompt_token_count": 0, - "candidates_token_count": 0, - "total_token_count": 0 - } + "prompt_feedback": {}, + "usage_metadata": {} }), ), error= """ @@ -1064,7 +1049,7 @@ def test_repr_error_info_for_chat_streaming_unexpected_stop(self): simple_response("a"), simple_response("b"), simple_response("c"), - glm.GenerateContentResponse( + protos.GenerateContentResponse( { "candidates": [{"finish_reason": "SAFETY"}], } @@ -1093,7 +1078,7 @@ def test_repr_error_info_for_chat_streaming_unexpected_stop(self): GenerateContentResponse( done=True, iterator=None, - result=glm.GenerateContentResponse({ + result=protos.GenerateContentResponse({ "candidates": [ { "content": { @@ -1101,28 +1086,15 @@ def test_repr_error_info_for_chat_streaming_unexpected_stop(self): { "text": "abc" } - ], - "role": "" + ] }, - "finish_reason": 3, + "finish_reason": "SAFETY", "index": 0, - "citation_metadata": { - "citation_sources": [] - }, - "safety_ratings": [], - "token_count": 0, - "grounding_attributions": [] + "citation_metadata": {} } ], - "prompt_feedback": { - "block_reason": 0, - "safety_ratings": [] - }, - "usage_metadata": { - "prompt_token_count": 0, - "candidates_token_count": 0, - "total_token_count": 0 - } + "prompt_feedback": {}, + "usage_metadata": {} }), ), error= index: 0 @@ -1169,7 +1141,7 @@ def test_repr_for_multi_turn_chat(self): tools=None, system_instruction=None, ), - history=[glm.Content({'parts': [{'text': 'I really like fantasy books.'}], 'role': 'user'}), glm.Content({'parts': [{'text': 'first'}], 'role': 'model'}), glm.Content({'parts': [{'text': 'I also like this image.'}, {'inline_data': {'data': 'iVBORw0KGgoA...AAElFTkSuQmCC', 'mime_type': 'image/png'}}], 'role': 'user'}), glm.Content({'parts': [{'text': 'second'}], 'role': 'model'}), glm.Content({'parts': [{'text': 'What things do I like?.'}], 'role': 'user'}), glm.Content({'parts': [{'text': 'third'}], 'role': 'model'})] + history=[protos.Content({'parts': [{'text': 'I really like fantasy books.'}], 'role': 'user'}), protos.Content({'parts': [{'text': 'first'}], 'role': 'model'}), protos.Content({'parts': [{'text': 'I also like this image.'}, {'inline_data': {'data': 'iVBORw0KGgoA...AAElFTkSuQmCC', 'mime_type': 'image/png'}}], 'role': 'user'}), protos.Content({'parts': [{'text': 'second'}], 'role': 'model'}), protos.Content({'parts': [{'text': 'What things do I like?.'}], 'role': 'user'}), protos.Content({'parts': [{'text': 'third'}], 'role': 'model'})] )""" ) self.assertEqual(expected, result) @@ -1197,7 +1169,7 @@ def test_repr_for_incomplete_streaming_chat(self): tools=None, system_instruction=None, ), - history=[glm.Content({'parts': [{'text': 'I really like fantasy books.'}], 'role': 'user'}), ] + history=[protos.Content({'parts': [{'text': 'I really like fantasy books.'}], 'role': 'user'}), ] )""" ) self.assertEqual(expected, result) @@ -1213,7 +1185,7 @@ def 
test_repr_for_broken_streaming_chat(self): for chunk in [ simple_response("first"), # FinishReason.SAFETY = 3 - glm.GenerateContentResponse( + protos.GenerateContentResponse( { "candidates": [ {"finish_reason": 3, "content": {"parts": [{"text": "second"}]}} @@ -1241,7 +1213,7 @@ def test_repr_for_broken_streaming_chat(self): tools=None, system_instruction=None, ), - history=[glm.Content({'parts': [{'text': 'I really like fantasy books.'}], 'role': 'user'}), ] + history=[protos.Content({'parts': [{'text': 'I really like fantasy books.'}], 'role': 'user'}), ] )""" ) self.assertEqual(expected, result) @@ -1252,15 +1224,30 @@ def test_repr_for_system_instruction(self): self.assertIn("system_instruction='Be excellent.'", result) def test_count_tokens_called_with_request_options(self): - self.client.count_tokens = unittest.mock.MagicMock() - request = unittest.mock.ANY + self.responses["count_tokens"].append(protos.CountTokensResponse(total_tokens=7)) request_options = {"timeout": 120} - self.responses["count_tokens"] = [glm.CountTokensResponse(total_tokens=7)] model = generative_models.GenerativeModel("gemini-pro-vision") model.count_tokens([{"role": "user", "parts": ["hello"]}], request_options=request_options) - self.client.count_tokens.assert_called_once_with(request, **request_options) + self.assertEqual(request_options, self.observed_kwargs[0]) + + def test_chat_with_request_options(self): + self.responses["generate_content"].append( + protos.GenerateContentResponse( + { + "candidates": [{"finish_reason": "STOP"}], + } + ) + ) + request_options = {"timeout": 120} + + model = generative_models.GenerativeModel("gemini-pro") + chat = model.start_chat() + chat.send_message("hello", request_options=helper_types.RequestOptions(**request_options)) + + request_options["retry"] = None + self.assertEqual(request_options, self.observed_kwargs[0]) if __name__ == "__main__": diff --git a/tests/test_generative_models_async.py b/tests/test_generative_models_async.py index 2c465d1d3..03055ffb3 100644 --- a/tests/test_generative_models_async.py +++ b/tests/test_generative_models_async.py @@ -24,14 +24,16 @@ from google.generativeai import client as client_lib from google.generativeai import generative_models from google.generativeai.types import content_types -import google.ai.generativelanguage as glm +from google.generativeai import protos from absl.testing import absltest from absl.testing import parameterized -def simple_response(text: str) -> glm.GenerateContentResponse: - return glm.GenerateContentResponse({"candidates": [{"content": {"parts": [{"text": text}]}}]}) +def simple_response(text: str) -> protos.GenerateContentResponse: + return protos.GenerateContentResponse( + {"candidates": [{"content": {"parts": [{"text": text}]}}]} + ) class AsyncTests(parameterized.TestCase, unittest.IsolatedAsyncioTestCase): @@ -50,28 +52,28 @@ def add_client_method(f): @add_client_method async def generate_content( - request: glm.GenerateContentRequest, + request: protos.GenerateContentRequest, **kwargs, - ) -> glm.GenerateContentResponse: - self.assertIsInstance(request, glm.GenerateContentRequest) + ) -> protos.GenerateContentResponse: + self.assertIsInstance(request, protos.GenerateContentRequest) self.observed_requests.append(request) response = self.responses["generate_content"].pop(0) return response @add_client_method async def stream_generate_content( - request: glm.GetModelRequest, + request: protos.GetModelRequest, **kwargs, - ) -> Iterable[glm.GenerateContentResponse]: + ) -> 
Iterable[protos.GenerateContentResponse]: self.observed_requests.append(request) response = self.responses["stream_generate_content"].pop(0) return response @add_client_method async def count_tokens( - request: glm.CountTokensRequest, + request: protos.CountTokensRequest, **kwargs, - ) -> Iterable[glm.GenerateContentResponse]: + ) -> Iterable[protos.GenerateContentResponse]: self.observed_requests.append(request) response = self.responses["count_tokens"].pop(0) return response @@ -140,9 +142,9 @@ async def responses(): }, ), dict( - testcase_name="test_glm_FunctionCallingConfig", + testcase_name="test_protos.FunctionCallingConfig", tool_config={ - "function_calling_config": glm.FunctionCallingConfig( + "function_calling_config": protos.FunctionCallingConfig( mode=content_types.FunctionCallingMode.AUTO ) }, @@ -169,9 +171,9 @@ async def responses(): }, ), dict( - testcase_name="test_glm_ToolConfig", - tool_config=glm.ToolConfig( - function_calling_config=glm.FunctionCallingConfig( + testcase_name="test_protos.ToolConfig", + tool_config=protos.ToolConfig( + function_calling_config=protos.FunctionCallingConfig( mode=content_types.FunctionCallingMode.NONE ) ), @@ -211,7 +213,7 @@ async def test_tool_config(self, tool_config, expected_tool_config): ["contents", [{"role": "user", "parts": ["hello"]}]], ) async def test_count_tokens_smoke(self, contents): - self.responses["count_tokens"] = [glm.CountTokensResponse(total_tokens=7)] + self.responses["count_tokens"] = [protos.CountTokensResponse(total_tokens=7)] model = generative_models.GenerativeModel("gemini-pro-vision") response = await model.count_tokens_async(contents) self.assertEqual(type(response).to_dict(response), {"total_tokens": 7}) diff --git a/tests/test_helpers.py b/tests/test_helpers.py new file mode 100644 index 000000000..f060caf88 --- /dev/null +++ b/tests/test_helpers.py @@ -0,0 +1,83 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import pathlib +import copy +import collections +from typing import Union + +from absl.testing import parameterized + +from google.generativeai import protos + +from google.generativeai import client +from google.generativeai import models +from google.generativeai.types import model_types +from google.generativeai.types import helper_types + +from google.api_core import retry + + +class MockModelClient: + def __init__(self, test): + self.test = test + + def get_model( + self, + request: Union[protos.GetModelRequest, None] = None, + *, + name=None, + timeout=None, + retry=None + ) -> protos.Model: + if request is None: + request = protos.GetModelRequest(name=name) + self.test.assertIsInstance(request, protos.GetModelRequest) + self.test.observed_requests.append(request) + self.test.observed_timeout.append(timeout) + self.test.observed_retry.append(retry) + response = copy.copy(self.test.responses["get_model"]) + return response + + +class HelperTests(parameterized.TestCase): + + def setUp(self): + self.client = MockModelClient(self) + client._client_manager.clients["model"] = self.client + + self.observed_requests = [] + self.observed_retry = [] + self.observed_timeout = [] + self.responses = collections.defaultdict(list) + + @parameterized.named_parameters( + ["None", None, None, None], + ["Empty", {}, None, None], + ["Timeout", {"timeout": 7}, 7, None], + ["Retry", {"retry": retry.Retry(timeout=7)}, None, retry.Retry(timeout=7)], + [ + "RequestOptions", + helper_types.RequestOptions(timeout=7, retry=retry.Retry(multiplier=3)), + 7, + retry.Retry(multiplier=3), + ], + ) + def test_get_model(self, request_options, expected_timeout, expected_retry): + self.responses = {"get_model": protos.Model(name="models/fake-bison-001")} + + _ = models.get_model("models/fake-bison-001", request_options=request_options) + + self.assertEqual(self.observed_timeout[0], expected_timeout) + self.assertEqual(str(self.observed_retry[0]), str(expected_retry)) diff --git a/tests/test_models.py b/tests/test_models.py index e971ef86d..23f80913a 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -25,12 +25,13 @@ from absl.testing import absltest from absl.testing import parameterized -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.api_core import operation from google.generativeai import models from google.generativeai import client from google.generativeai.types import model_types +from google.generativeai import types as genai_types import pandas as pd @@ -44,7 +45,7 @@ def setUp(self): client._client_manager.clients["model"] = self.client # TODO(markdaoust): Check if typechecking works better if wee define this as a - # subclass of `glm.ModelServiceClient`, would pyi files for `glm` help? + # subclass of `glm.ModelServiceClient`, would pyi files for `glm`. help? 
def add_client_method(f): name = f.__name__ setattr(self.client, name, f) @@ -54,63 +55,65 @@ def add_client_method(f): self.responses = {} @add_client_method - def get_model(request: Union[glm.GetModelRequest, None] = None, *, name=None) -> glm.Model: + def get_model( + request: Union[protos.GetModelRequest, None] = None, *, name=None + ) -> protos.Model: if request is None: - request = glm.GetModelRequest(name=name) - self.assertIsInstance(request, glm.GetModelRequest) + request = protos.GetModelRequest(name=name) + self.assertIsInstance(request, protos.GetModelRequest) self.observed_requests.append(request) response = copy.copy(self.responses["get_model"]) return response @add_client_method def get_tuned_model( - request: Union[glm.GetTunedModelRequest, None] = None, + request: Union[protos.GetTunedModelRequest, None] = None, *, name=None, **kwargs, - ) -> glm.TunedModel: + ) -> protos.TunedModel: if request is None: - request = glm.GetTunedModelRequest(name=name) - self.assertIsInstance(request, glm.GetTunedModelRequest) + request = protos.GetTunedModelRequest(name=name) + self.assertIsInstance(request, protos.GetTunedModelRequest) self.observed_requests.append(request) response = copy.copy(self.responses["get_tuned_model"]) return response @add_client_method def list_models( - request: Union[glm.ListModelsRequest, None] = None, + request: Union[protos.ListModelsRequest, None] = None, *, page_size=None, page_token=None, **kwargs, - ) -> glm.ListModelsResponse: + ) -> protos.ListModelsResponse: if request is None: - request = glm.ListModelsRequest(page_size=page_size, page_token=page_token) - self.assertIsInstance(request, glm.ListModelsRequest) + request = protos.ListModelsRequest(page_size=page_size, page_token=page_token) + self.assertIsInstance(request, protos.ListModelsRequest) self.observed_requests.append(request) response = self.responses["list_models"] return (item for item in response) @add_client_method def list_tuned_models( - request: glm.ListTunedModelsRequest = None, + request: protos.ListTunedModelsRequest = None, *, page_size=None, page_token=None, **kwargs, - ) -> Iterable[glm.TunedModel]: + ) -> Iterable[protos.TunedModel]: if request is None: - request = glm.ListTunedModelsRequest(page_size=page_size, page_token=page_token) - self.assertIsInstance(request, glm.ListTunedModelsRequest) + request = protos.ListTunedModelsRequest(page_size=page_size, page_token=page_token) + self.assertIsInstance(request, protos.ListTunedModelsRequest) self.observed_requests.append(request) response = self.responses["list_tuned_models"] return (item for item in response) @add_client_method def update_tuned_model( - request: glm.UpdateTunedModelRequest, + request: protos.UpdateTunedModelRequest, **kwargs, - ) -> glm.TunedModel: + ) -> protos.TunedModel: self.observed_requests.append(request) response = self.responses.get("update_tuned_model", None) if response is None: @@ -119,7 +122,7 @@ def update_tuned_model( @add_client_method def delete_tuned_model(name): - request = glm.DeleteTunedModelRequest(name=name) + request = protos.DeleteTunedModelRequest(name=name) self.observed_requests.append(request) response = True return response @@ -129,26 +132,26 @@ def create_tuned_model( request, **kwargs, ): - request = glm.CreateTunedModelRequest(request) + request = protos.CreateTunedModelRequest(request) self.observed_requests.append(request) return self.responses["create_tuned_model"] def test_decode_tuned_model_time_round_trip(self): example_dt = datetime.datetime(2000, 1, 2, 3, 4, 5, 600_000, 
pytz.UTC) - tuned_model = glm.TunedModel(name="tunedModels/house-mouse-001", create_time=example_dt) + tuned_model = protos.TunedModel(name="tunedModels/house-mouse-001", create_time=example_dt) tuned_model = model_types.decode_tuned_model(tuned_model) self.assertEqual(tuned_model.create_time, example_dt) @parameterized.named_parameters( ["simple", "models/fake-bison-001"], ["simple-tuned", "tunedModels/my-pig-001"], - ["model-instance", glm.Model(name="models/fake-bison-001")], - ["tuned-model-instance", glm.TunedModel(name="tunedModels/my-pig-001")], + ["model-instance", protos.Model(name="models/fake-bison-001")], + ["tuned-model-instance", protos.TunedModel(name="tunedModels/my-pig-001")], ) def test_get_model(self, name): self.responses = { - "get_model": glm.Model(name="models/fake-bison-001"), - "get_tuned_model": glm.TunedModel(name="tunedModels/my-pig-001"), + "get_model": protos.Model(name="models/fake-bison-001"), + "get_tuned_model": protos.TunedModel(name="tunedModels/my-pig-001"), } model = models.get_model(name) @@ -159,7 +162,7 @@ def test_get_model(self, name): @parameterized.named_parameters( ["simple", "mystery-bison-001"], - ["model-instance", glm.Model(name="how?-bison-001")], + ["model-instance", protos.Model(name="how?-bison-001")], ) def test_fail_with_unscoped_model_name(self, name): with self.assertRaises(ValueError): @@ -169,9 +172,9 @@ def test_list_models(self): # The low level lib wraps the response in an iterable, so this is a fair test. self.responses = { "list_models": [ - glm.Model(name="models/fake-bison-001"), - glm.Model(name="models/fake-bison-002"), - glm.Model(name="models/fake-bison-003"), + protos.Model(name="models/fake-bison-001"), + protos.Model(name="models/fake-bison-002"), + protos.Model(name="models/fake-bison-003"), ] } @@ -184,9 +187,9 @@ def test_list_tuned_models(self): self.responses = { # The low level lib wraps the response in an iterable, so this is a fair test. "list_tuned_models": [ - glm.TunedModel(name="tunedModels/my-pig-001"), - glm.TunedModel(name="tunedModels/my-pig-002"), - glm.TunedModel(name="tunedModels/my-pig-003"), + protos.TunedModel(name="tunedModels/my-pig-001"), + protos.TunedModel(name="tunedModels/my-pig-002"), + protos.TunedModel(name="tunedModels/my-pig-003"), ] } found_models = list(models.list_tuned_models()) @@ -196,8 +199,8 @@ def test_list_tuned_models(self): @parameterized.named_parameters( [ - "edited-glm-model", - glm.TunedModel( + "edited-protos.model", + protos.TunedModel( name="tunedModels/my-pig-001", description="Trained on my data", ), @@ -210,7 +213,7 @@ def test_list_tuned_models(self): ], ) def test_update_tuned_model_basics(self, tuned_model, updates): - self.responses["get_tuned_model"] = glm.TunedModel(name="tunedModels/my-pig-001") + self.responses["get_tuned_model"] = protos.TunedModel(name="tunedModels/my-pig-001") # No self.responses['update_tuned_model'] the mock just returns the input. 
updated_model = models.update_tuned_model(tuned_model, updates) updated_model.description = "Trained on my data" @@ -226,7 +229,7 @@ def test_update_tuned_model_basics(self, tuned_model, updates): ], ) def test_update_tuned_model_nested_fields(self, updates): - self.responses["get_tuned_model"] = glm.TunedModel( + self.responses["get_tuned_model"] = protos.TunedModel( name="tunedModels/my-pig-001", base_model="models/dance-monkey-007" ) @@ -249,8 +252,8 @@ def test_update_tuned_model_nested_fields(self, updates): @parameterized.named_parameters( ["name", "tunedModels/bipedal-pangolin-223"], [ - "glm.TunedModel", - glm.TunedModel(name="tunedModels/bipedal-pangolin-223"), + "protos.TunedModel", + protos.TunedModel(name="tunedModels/bipedal-pangolin-223"), ], [ "models.TunedModel", @@ -274,23 +277,23 @@ def test_decode_micros(self, time_str, micros): self.assertEqual(time["time"].microsecond, micros) def test_decode_tuned_model(self): - out_fields = glm.TunedModel( - state=glm.TunedModel.State.CREATING, + out_fields = protos.TunedModel( + state=protos.TunedModel.State.CREATING, create_time="2000-01-01T01:01:01.0Z", update_time="2001-01-01T01:01:01.0Z", - tuning_task=glm.TuningTask( - hyperparameters=glm.Hyperparameters( + tuning_task=protos.TuningTask( + hyperparameters=protos.Hyperparameters( batch_size=72, epoch_count=1, learning_rate=0.1 ), start_time="2002-01-01T01:01:01.0Z", complete_time="2003-01-01T01:01:01.0Z", snapshots=[ - glm.TuningSnapshot( + protos.TuningSnapshot( step=1, epoch=1, compute_time="2004-01-01T01:01:01.0Z", ), - glm.TuningSnapshot( + protos.TuningSnapshot( step=2, epoch=1, compute_time="2005-01-01T01:01:01.0Z", @@ -300,7 +303,7 @@ def test_decode_tuned_model(self): ) decoded = model_types.decode_tuned_model(out_fields) - self.assertEqual(decoded.state, glm.TunedModel.State.CREATING) + self.assertEqual(decoded.state, protos.TunedModel.State.CREATING) self.assertEqual(decoded.create_time.year, 2000) self.assertEqual(decoded.update_time.year, 2001) self.assertIsInstance(decoded.tuning_task.hyperparameters, model_types.Hyperparameters) @@ -313,10 +316,10 @@ def test_decode_tuned_model(self): self.assertEqual(decoded.tuning_task.snapshots[1]["compute_time"].year, 2005) @parameterized.named_parameters( - ["simple", glm.TunedModel(base_model="models/swim-fish-000")], + ["simple", protos.TunedModel(base_model="models/swim-fish-000")], [ "nested", - glm.TunedModel( + protos.TunedModel( tuned_model_source={ "tuned_model": "tunedModels/hidden-fish-55", "base_model": "models/swim-fish-000", @@ -340,7 +343,7 @@ def test_smoke_create_tuned_model(self): training_data=[ ("in", "out"), {"text_input": "in", "output": "out"}, - glm.TuningExample(text_input="in", output="out"), + protos.TuningExample(text_input="in", output="out"), ], ) req = self.observed_requests[-1] @@ -350,10 +353,10 @@ def test_smoke_create_tuned_model(self): self.assertLen(req.tuned_model.tuning_task.training_data.examples.examples, 3) @parameterized.named_parameters( - ["simple", glm.TunedModel(base_model="models/swim-fish-000")], + ["simple", protos.TunedModel(base_model="models/swim-fish-000")], [ "nested", - glm.TunedModel( + protos.TunedModel( tuned_model_source={ "tuned_model": "tunedModels/hidden-fish-55", "base_model": "models/swim-fish-000", @@ -379,9 +382,9 @@ def test_create_tuned_model_on_tuned_model(self, tuned_source): @parameterized.named_parameters( [ - "glm", - glm.Dataset( - examples=glm.TuningExamples( + "protos", + protos.Dataset( + examples=protos.TuningExamples( examples=[ {"text_input": "a", 
"output": "1"}, {"text_input": "b", "output": "2"}, @@ -395,7 +398,7 @@ def test_create_tuned_model_on_tuned_model(self, tuned_source): [ ("a", "1"), {"text_input": "b", "output": "2"}, - glm.TuningExample({"text_input": "c", "output": "3"}), + protos.TuningExample({"text_input": "c", "output": "3"}), ], ], ["dict", {"text_input": ["a", "b", "c"], "output": ["1", "2", "3"]}], @@ -444,8 +447,8 @@ def test_create_tuned_model_on_tuned_model(self, tuned_source): def test_create_dataset(self, data, ik="text_input", ok="output"): ds = model_types.encode_tuning_data(data, input_key=ik, output_key=ok) - expect = glm.Dataset( - examples=glm.TuningExamples( + expect = protos.Dataset( + examples=protos.TuningExamples( examples=[ {"text_input": "a", "output": "1"}, {"text_input": "b", "output": "2"}, @@ -470,7 +473,7 @@ def test_get_model_called_with_request_options(self): def test_get_tuned_model_called_with_request_options(self): self.client.get_tuned_model = unittest.mock.MagicMock() name = unittest.mock.ANY - request_options = {"timeout": 120} + request_options = genai_types.RequestOptions(timeout=120) try: models.get_model(name="tunedModels/", request_options=request_options) @@ -501,7 +504,7 @@ def test_update_tuned_model_called_with_request_options(self): self.client.update_tuned_model = unittest.mock.MagicMock() request = unittest.mock.ANY request_options = {"timeout": 120} - self.responses["get_tuned_model"] = glm.TunedModel(name="tunedModels/") + self.responses["get_tuned_model"] = protos.TunedModel(name="tunedModels/") try: models.update_tuned_model( @@ -533,7 +536,7 @@ def test_create_tuned_model_called_with_request_options(self): training_data=[ ("in", "out"), {"text_input": "in", "output": "out"}, - glm.TuningExample(text_input="in", output="out"), + protos.TuningExample(text_input="in", output="out"), ], request_options=request_options, ) diff --git a/tests/test_operations.py b/tests/test_operations.py index 80262db88..6529b77e5 100644 --- a/tests/test_operations.py +++ b/tests/test_operations.py @@ -16,7 +16,7 @@ from contextlib import redirect_stderr import io -import google.ai.generativelanguage as glm +from google.generativeai import protos import google.protobuf.any_pb2 import google.generativeai.operations as genai_operation @@ -41,7 +41,7 @@ def test_end_to_end(self): # `Any` takes a type name and a serialized proto. metadata = google.protobuf.any_pb2.Any( type_url=self.metadata_type, - value=glm.CreateTunedModelMetadata(tuned_model=name)._pb.SerializeToString(), + value=protos.CreateTunedModelMetadata(tuned_model=name)._pb.SerializeToString(), ) # Initially the `Operation` is not `done`, so it only gives a metadata. @@ -58,7 +58,7 @@ def test_end_to_end(self): metadata=metadata, response=google.protobuf.any_pb2.Any( type_url=self.result_type, - value=glm.TunedModel(name=name)._pb.SerializeToString(), + value=protos.TunedModel(name=name)._pb.SerializeToString(), ), ) @@ -72,8 +72,8 @@ def refresh(*_, **__): operation=initial_pb, refresh=refresh, cancel=lambda: print(f"cancel!"), - result_type=glm.TunedModel, - metadata_type=glm.CreateTunedModelMetadata, + result_type=protos.TunedModel, + metadata_type=protos.CreateTunedModelMetadata, ) # Use our wrapper instead. 
@@ -99,7 +99,7 @@ def gen_operations(): def make_metadata(completed_steps): return google.protobuf.any_pb2.Any( type_url=self.metadata_type, - value=glm.CreateTunedModelMetadata( + value=protos.CreateTunedModelMetadata( tuned_model=name, total_steps=total_steps, completed_steps=completed_steps, @@ -122,7 +122,7 @@ def make_metadata(completed_steps): metadata=make_metadata(total_steps), response=google.protobuf.any_pb2.Any( type_url=self.result_type, - value=glm.TunedModel(name=name)._pb.SerializeToString(), + value=protos.TunedModel(name=name)._pb.SerializeToString(), ), ) @@ -142,8 +142,8 @@ def refresh(*_, **__): operation=initial_pb, refresh=refresh, cancel=None, - result_type=glm.TunedModel, - metadata_type=glm.CreateTunedModelMetadata, + result_type=protos.TunedModel, + metadata_type=protos.CreateTunedModelMetadata, ) # Use our wrapper instead. diff --git a/tests/test_permission.py b/tests/test_permission.py index 55ad7a2f0..66b396977 100644 --- a/tests/test_permission.py +++ b/tests/test_permission.py @@ -17,7 +17,7 @@ import unittest import unittest.mock as mock -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import retriever from google.generativeai import permission @@ -50,11 +50,11 @@ def add_client_method(f): @add_client_method def create_corpus( - request: glm.CreateCorpusRequest, + request: protos.CreateCorpusRequest, **kwargs, - ) -> glm.Corpus: + ) -> protos.Corpus: self.observed_requests.append(request) - return glm.Corpus( + return protos.Corpus( name="corpora/demo-corpus", display_name="demo-corpus", create_time="2000-01-01T01:01:01.123456Z", @@ -63,24 +63,24 @@ def create_corpus( @add_client_method def get_tuned_model( - request: Optional[glm.GetTunedModelRequest] = None, + request: Optional[protos.GetTunedModelRequest] = None, *, name=None, **kwargs, - ) -> glm.TunedModel: + ) -> protos.TunedModel: if request is None: - request = glm.GetTunedModelRequest(name=name) - self.assertIsInstance(request, glm.GetTunedModelRequest) + request = protos.GetTunedModelRequest(name=name) + self.assertIsInstance(request, protos.GetTunedModelRequest) self.observed_requests.append(request) response = copy.copy(self.responses["get_tuned_model"]) return response @add_client_method def create_permission( - request: glm.CreatePermissionRequest, - ) -> glm.Permission: + request: protos.CreatePermissionRequest, + ) -> protos.Permission: self.observed_requests.append(request) - return glm.Permission( + return protos.Permission( name="corpora/demo-corpus/permissions/123456789", role=permission_services.to_role("writer"), grantee_type=permission_services.to_grantee_type("everyone"), @@ -88,17 +88,17 @@ def create_permission( @add_client_method def delete_permission( - request: glm.DeletePermissionRequest, + request: protos.DeletePermissionRequest, ) -> None: self.observed_requests.append(request) return None @add_client_method def get_permission( - request: glm.GetPermissionRequest, - ) -> glm.Permission: + request: protos.GetPermissionRequest, + ) -> protos.Permission: self.observed_requests.append(request) - return glm.Permission( + return protos.Permission( name="corpora/demo-corpus/permissions/123456789", role=permission_services.to_role("writer"), grantee_type=permission_services.to_grantee_type("everyone"), @@ -106,16 +106,16 @@ def get_permission( @add_client_method def list_permissions( - request: glm.ListPermissionsRequest, - ) -> glm.ListPermissionsResponse: + request: protos.ListPermissionsRequest, + ) -> 
protos.ListPermissionsResponse: self.observed_requests.append(request) return [ - glm.Permission( + protos.Permission( name="corpora/demo-corpus/permissions/123456789", role=permission_services.to_role("writer"), grantee_type=permission_services.to_grantee_type("everyone"), ), - glm.Permission( + protos.Permission( name="corpora/demo-corpus/permissions/987654321", role=permission_services.to_role("reader"), grantee_type=permission_services.to_grantee_type("everyone"), @@ -125,10 +125,10 @@ def list_permissions( @add_client_method def update_permission( - request: glm.UpdatePermissionRequest, - ) -> glm.Permission: + request: protos.UpdatePermissionRequest, + ) -> protos.Permission: self.observed_requests.append(request) - return glm.Permission( + return protos.Permission( name="corpora/demo-corpus/permissions/123456789", role=permission_services.to_role("reader"), grantee_type=permission_services.to_grantee_type("everyone"), @@ -136,16 +136,16 @@ def update_permission( @add_client_method def transfer_ownership( - request: glm.TransferOwnershipRequest, - ) -> glm.TransferOwnershipResponse: + request: protos.TransferOwnershipRequest, + ) -> protos.TransferOwnershipResponse: self.observed_requests.append(request) - return glm.TransferOwnershipResponse() + return protos.TransferOwnershipResponse() def test_create_permission_success(self): x = retriever.create_corpus("demo-corpus") perm = x.permissions.create(role="writer", grantee_type="everyone", email_address=None) self.assertIsInstance(perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.CreatePermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.CreatePermissionRequest) def test_create_permission_failure_email_set_when_grantee_type_is_everyone(self): x = retriever.create_corpus("demo-corpus") @@ -161,14 +161,14 @@ def test_delete_permission(self): x = retriever.create_corpus("demo-corpus") perm = x.permissions.create("writer", "everyone") perm.delete() - self.assertIsInstance(self.observed_requests[-1], glm.DeletePermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.DeletePermissionRequest) def test_get_permission_with_full_name(self): x = retriever.create_corpus("demo-corpus") perm = x.permissions.create("writer", "everyone") fetch_perm = permission.get_permission(name=perm.name) self.assertIsInstance(fetch_perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.GetPermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.GetPermissionRequest) self.assertEqual(fetch_perm, perm) def test_get_permission_with_resource_name_and_id_1(self): @@ -178,7 +178,7 @@ def test_get_permission_with_resource_name_and_id_1(self): resource_name="corpora/demo-corpus", permission_id=123456789 ) self.assertIsInstance(fetch_perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.GetPermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.GetPermissionRequest) self.assertEqual(fetch_perm, perm) def test_get_permission_with_resource_name_name_and_id_2(self): @@ -186,14 +186,14 @@ def test_get_permission_with_resource_name_name_and_id_2(self): resource_name="tunedModels/demo-corpus", permission_id=123456789 ) self.assertIsInstance(fetch_perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.GetPermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.GetPermissionRequest) def 
test_get_permission_with_resource_type(self): fetch_perm = permission.get_permission( resource_name="demo-model", permission_id=123456789, resource_type="tunedModels" ) self.assertIsInstance(fetch_perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.GetPermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.GetPermissionRequest) @parameterized.named_parameters( dict( @@ -257,14 +257,14 @@ def test_list_permission(self): self.assertEqual(perms[1].email_address, "_") for perm in perms: self.assertIsInstance(perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.ListPermissionsRequest) + self.assertIsInstance(self.observed_requests[-1], protos.ListPermissionsRequest) def test_update_permission_success(self): x = retriever.create_corpus("demo-corpus") perm = x.permissions.create("writer", "everyone") updated_perm = perm.update({"role": permission_services.to_role("reader")}) self.assertIsInstance(updated_perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.UpdatePermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.UpdatePermissionRequest) def test_update_permission_failure_restricted_update_path(self): x = retriever.create_corpus("demo-corpus") @@ -275,12 +275,12 @@ def test_update_permission_failure_restricted_update_path(self): ) def test_transfer_ownership(self): - self.responses["get_tuned_model"] = glm.TunedModel( + self.responses["get_tuned_model"] = protos.TunedModel( name="tunedModels/fake-pig-001", base_model="models/dance-monkey-007" ) x = models.get_tuned_model("tunedModels/fake-pig-001") response = x.permissions.transfer_ownership(email_address="_") - self.assertIsInstance(self.observed_requests[-1], glm.TransferOwnershipRequest) + self.assertIsInstance(self.observed_requests[-1], protos.TransferOwnershipRequest) def test_transfer_ownership_on_corpora(self): x = retriever.create_corpus("demo-corpus") diff --git a/tests/test_permission_async.py b/tests/test_permission_async.py index 165039122..ddc9c22a2 100644 --- a/tests/test_permission_async.py +++ b/tests/test_permission_async.py @@ -17,7 +17,7 @@ import unittest import unittest.mock as mock -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import retriever from google.generativeai import permission @@ -49,11 +49,11 @@ def add_client_method(f): @add_client_method async def create_corpus( - request: glm.CreateCorpusRequest, + request: protos.CreateCorpusRequest, **kwargs, - ) -> glm.Corpus: + ) -> protos.Corpus: self.observed_requests.append(request) - return glm.Corpus( + return protos.Corpus( name="corpora/demo-corpus", display_name="demo-corpus", create_time="2000-01-01T01:01:01.123456Z", @@ -62,24 +62,24 @@ async def create_corpus( @add_client_method def get_tuned_model( - request: Optional[glm.GetTunedModelRequest] = None, + request: Optional[protos.GetTunedModelRequest] = None, *, name=None, **kwargs, - ) -> glm.TunedModel: + ) -> protos.TunedModel: if request is None: - request = glm.GetTunedModelRequest(name=name) - self.assertIsInstance(request, glm.GetTunedModelRequest) + request = protos.GetTunedModelRequest(name=name) + self.assertIsInstance(request, protos.GetTunedModelRequest) self.observed_requests.append(request) response = copy.copy(self.responses["get_tuned_model"]) return response @add_client_method async def create_permission( - request: glm.CreatePermissionRequest, - ) -> glm.Permission: 
+ request: protos.CreatePermissionRequest, + ) -> protos.Permission: self.observed_requests.append(request) - return glm.Permission( + return protos.Permission( name="corpora/demo-corpus/permissions/123456789", role=permission_services.to_role("writer"), grantee_type=permission_services.to_grantee_type("everyone"), @@ -87,17 +87,17 @@ async def create_permission( @add_client_method async def delete_permission( - request: glm.DeletePermissionRequest, + request: protos.DeletePermissionRequest, ) -> None: self.observed_requests.append(request) return None @add_client_method async def get_permission( - request: glm.GetPermissionRequest, - ) -> glm.Permission: + request: protos.GetPermissionRequest, + ) -> protos.Permission: self.observed_requests.append(request) - return glm.Permission( + return protos.Permission( name="corpora/demo-corpus/permissions/123456789", role=permission_services.to_role("writer"), grantee_type=permission_services.to_grantee_type("everyone"), @@ -105,17 +105,17 @@ async def get_permission( @add_client_method async def list_permissions( - request: glm.ListPermissionsRequest, - ) -> glm.ListPermissionsResponse: + request: protos.ListPermissionsRequest, + ) -> protos.ListPermissionsResponse: self.observed_requests.append(request) async def results(): - yield glm.Permission( + yield protos.Permission( name="corpora/demo-corpus/permissions/123456789", role=permission_services.to_role("writer"), grantee_type=permission_services.to_grantee_type("everyone"), ) - yield glm.Permission( + yield protos.Permission( name="corpora/demo-corpus/permissions/987654321", role=permission_services.to_role("reader"), grantee_type=permission_services.to_grantee_type("everyone"), @@ -126,10 +126,10 @@ async def results(): @add_client_method async def update_permission( - request: glm.UpdatePermissionRequest, - ) -> glm.Permission: + request: protos.UpdatePermissionRequest, + ) -> protos.Permission: self.observed_requests.append(request) - return glm.Permission( + return protos.Permission( name="corpora/demo-corpus/permissions/123456789", role=permission_services.to_role("reader"), grantee_type=permission_services.to_grantee_type("everyone"), @@ -137,10 +137,10 @@ async def update_permission( @add_client_method async def transfer_ownership( - request: glm.TransferOwnershipRequest, - ) -> glm.TransferOwnershipResponse: + request: protos.TransferOwnershipRequest, + ) -> protos.TransferOwnershipResponse: self.observed_requests.append(request) - return glm.TransferOwnershipResponse() + return protos.TransferOwnershipResponse() async def test_create_permission_success(self): x = await retriever.create_corpus_async("demo-corpus") @@ -148,7 +148,7 @@ async def test_create_permission_success(self): role="writer", grantee_type="everyone", email_address=None ) self.assertIsInstance(perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.CreatePermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.CreatePermissionRequest) async def test_create_permission_failure_email_set_when_grantee_type_is_everyone(self): x = await retriever.create_corpus_async("demo-corpus") @@ -168,14 +168,14 @@ async def test_delete_permission(self): x = await retriever.create_corpus_async("demo-corpus") perm = await x.permissions.create_async("writer", "everyone") await perm.delete_async() - self.assertIsInstance(self.observed_requests[-1], glm.DeletePermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.DeletePermissionRequest) async def 
test_get_permission_with_full_name(self): x = await retriever.create_corpus_async("demo-corpus") perm = await x.permissions.create_async("writer", "everyone") fetch_perm = await permission.get_permission_async(name=perm.name) self.assertIsInstance(fetch_perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.GetPermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.GetPermissionRequest) self.assertEqual(fetch_perm, perm) async def test_get_permission_with_resource_name_and_id_1(self): @@ -185,7 +185,7 @@ async def test_get_permission_with_resource_name_and_id_1(self): resource_name="corpora/demo-corpus", permission_id=123456789 ) self.assertIsInstance(fetch_perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.GetPermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.GetPermissionRequest) self.assertEqual(fetch_perm, perm) async def test_get_permission_with_resource_name_name_and_id_2(self): @@ -193,14 +193,14 @@ async def test_get_permission_with_resource_name_name_and_id_2(self): resource_name="tunedModels/demo-corpus", permission_id=123456789 ) self.assertIsInstance(fetch_perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.GetPermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.GetPermissionRequest) async def test_get_permission_with_resource_type(self): fetch_perm = await permission.get_permission_async( resource_name="demo-model", permission_id=123456789, resource_type="tunedModels" ) self.assertIsInstance(fetch_perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.GetPermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.GetPermissionRequest) @parameterized.named_parameters( dict( @@ -264,14 +264,14 @@ async def test_list_permission(self): self.assertEqual(perms[1].email_address, "_") for perm in perms: self.assertIsInstance(perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.ListPermissionsRequest) + self.assertIsInstance(self.observed_requests[-1], protos.ListPermissionsRequest) async def test_update_permission_success(self): x = await retriever.create_corpus_async("demo-corpus") perm = await x.permissions.create_async("writer", "everyone") updated_perm = await perm.update_async({"role": permission_services.to_role("reader")}) self.assertIsInstance(updated_perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.UpdatePermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.UpdatePermissionRequest) async def test_update_permission_failure_restricted_update_path(self): x = await retriever.create_corpus_async("demo-corpus") @@ -282,12 +282,12 @@ async def test_update_permission_failure_restricted_update_path(self): ) async def test_transfer_ownership(self): - self.responses["get_tuned_model"] = glm.TunedModel( + self.responses["get_tuned_model"] = protos.TunedModel( name="tunedModels/fake-pig-001", base_model="models/dance-monkey-007" ) x = models.get_tuned_model("tunedModels/fake-pig-001") response = await x.permissions.transfer_ownership_async(email_address="_") - self.assertIsInstance(self.observed_requests[-1], glm.TransferOwnershipRequest) + self.assertIsInstance(self.observed_requests[-1], protos.TransferOwnershipRequest) async def test_transfer_ownership_on_corpora(self): x = await retriever.create_corpus_async("demo-corpus") 
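# A minimal usage sketch (not part of the patch) of the import pattern these
# test diffs migrate to: requests and canned responses are built from the
# re-exported `google.generativeai.protos` module instead of importing
# `google.ai.generativelanguage as glm` directly. The message names below
# (GenerateContentRequest, Content, CountTokensResponse) and the dict-style
# constructors are the ones already used in the hunks above; the model name is
# illustrative.
from google.generativeai import protos

# Build a request the same way the mocked generate_content() methods receive it.
request = protos.GenerateContentRequest(
    model="models/gemini-pro",
    contents=[protos.Content({"role": "user", "parts": [{"text": "Hello!"}]})],
)

# A canned response, mirroring the self.responses["count_tokens"] queues above.
response = protos.CountTokensResponse(total_tokens=7)
assert type(response).to_dict(response) == {"total_tokens": 7}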
diff --git a/tests/test_protos.py b/tests/test_protos.py new file mode 100644 index 000000000..1b59b0c6e --- /dev/null +++ b/tests/test_protos.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pathlib +import re + +from absl.testing import parameterized + +ROOT = pathlib.Path(__file__).parent.parent + + +class UnitTests(parameterized.TestCase): + def test_check_glm_imports(self): + for fpath in ROOT.rglob("*.py"): + if fpath.name == "build_docs.py": + continue + content = fpath.read_text() + for match in re.findall("glm\.\w+", content): + self.assertIn( + "Client", + match, + msg=f"Bad `glm.` usage, use `genai.protos` instead,\n in {fpath}", + ) diff --git a/tests/test_responder.py b/tests/test_responder.py index 4eb310815..c075fc65a 100644 --- a/tests/test_responder.py +++ b/tests/test_responder.py @@ -17,7 +17,7 @@ from absl.testing import absltest from absl.testing import parameterized -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import responder import IPython.display import PIL.Image @@ -42,9 +42,9 @@ class UnitTests(parameterized.TestCase): [ "FunctionLibrary", responder.FunctionLibrary( - tools=glm.Tool( + tools=protos.Tool( function_declarations=[ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time." ) ] @@ -56,7 +56,7 @@ class UnitTests(parameterized.TestCase): [ responder.Tool( function_declarations=[ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time." ) ] @@ -64,11 +64,11 @@ class UnitTests(parameterized.TestCase): ], ], [ - "IterableTool-glm.Tool", + "IterableTool-protos.Tool", [ - glm.Tool( + protos.Tool( function_declarations=[ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time.", ) @@ -93,7 +93,7 @@ class UnitTests(parameterized.TestCase): "IterableTool-IterableFD", [ [ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time.", ) @@ -103,7 +103,7 @@ class UnitTests(parameterized.TestCase): [ "IterableTool-FD", [ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time.", ) @@ -113,17 +113,17 @@ class UnitTests(parameterized.TestCase): "Tool", responder.Tool( function_declarations=[ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time." ) ] ), ], [ - "glm.Tool", - glm.Tool( + "protos.Tool", + protos.Tool( function_declarations=[ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time." 
) ] @@ -175,8 +175,8 @@ class UnitTests(parameterized.TestCase): ), ], [ - "glm.FD", - glm.FunctionDeclaration( + "protos.FD", + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time." ), ], @@ -216,32 +216,32 @@ def b(): self.assertLen(tools[0].function_declarations, 2) @parameterized.named_parameters( - ["int", int, glm.Schema(type=glm.Type.INTEGER)], - ["float", float, glm.Schema(type=glm.Type.NUMBER)], - ["str", str, glm.Schema(type=glm.Type.STRING)], + ["int", int, protos.Schema(type=protos.Type.INTEGER)], + ["float", float, protos.Schema(type=protos.Type.NUMBER)], + ["str", str, protos.Schema(type=protos.Type.STRING)], [ "list", list[str], - glm.Schema( - type=glm.Type.ARRAY, - items=glm.Schema(type=glm.Type.STRING), + protos.Schema( + type=protos.Type.ARRAY, + items=protos.Schema(type=protos.Type.STRING), ), ], [ "list-list-int", list[list[int]], - glm.Schema( - type=glm.Type.ARRAY, - items=glm.Schema( - glm.Schema( - type=glm.Type.ARRAY, - items=glm.Schema(type=glm.Type.INTEGER), + protos.Schema( + type=protos.Type.ARRAY, + items=protos.Schema( + protos.Schema( + type=protos.Type.ARRAY, + items=protos.Schema(type=protos.Type.INTEGER), ), ), ), ], - ["dict", dict, glm.Schema(type=glm.Type.OBJECT)], - ["dict-str-any", dict[str, Any], glm.Schema(type=glm.Type.OBJECT)], + ["dict", dict, protos.Schema(type=protos.Type.OBJECT)], + ["dict-str-any", dict[str, Any], protos.Schema(type=protos.Type.OBJECT)], ) def test_auto_schema(self, annotation, expected): def fun(a: annotation): diff --git a/tests/test_retriever.py b/tests/test_retriever.py index 910183789..bce9a402b 100644 --- a/tests/test_retriever.py +++ b/tests/test_retriever.py @@ -16,7 +16,7 @@ import unittest import unittest.mock as mock -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import retriever from google.generativeai import client @@ -42,11 +42,11 @@ def add_client_method(f): @add_client_method def create_corpus( - request: glm.CreateCorpusRequest, + request: protos.CreateCorpusRequest, **kwargs, - ) -> glm.Corpus: + ) -> protos.Corpus: self.observed_requests.append(request) - return glm.Corpus( + return protos.Corpus( name="corpora/demo_corpus", display_name="demo-corpus", create_time="2000-01-01T01:01:01.123456Z", @@ -55,11 +55,11 @@ def create_corpus( @add_client_method def get_corpus( - request: glm.GetCorpusRequest, + request: protos.GetCorpusRequest, **kwargs, - ) -> glm.Corpus: + ) -> protos.Corpus: self.observed_requests.append(request) - return glm.Corpus( + return protos.Corpus( name="corpora/demo_corpus", display_name="demo-corpus", create_time="2000-01-01T01:01:01.123456Z", @@ -68,11 +68,11 @@ def get_corpus( @add_client_method def update_corpus( - request: glm.UpdateCorpusRequest, + request: protos.UpdateCorpusRequest, **kwargs, - ) -> glm.Corpus: + ) -> protos.Corpus: self.observed_requests.append(request) - return glm.Corpus( + return protos.Corpus( name="corpora/demo-corpus", display_name="demo-corpus-1", create_time="2000-01-01T01:01:01.123456Z", @@ -81,18 +81,18 @@ def update_corpus( @add_client_method def list_corpora( - request: glm.ListCorporaRequest, + request: protos.ListCorporaRequest, **kwargs, - ) -> glm.ListCorporaResponse: + ) -> protos.ListCorporaResponse: self.observed_requests.append(request) return [ - glm.Corpus( + protos.Corpus( name="corpora/demo_corpus-1", display_name="demo-corpus-1", create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", ), - glm.Corpus( 
+ protos.Corpus( name="corpora/demo-corpus-2", display_name="demo-corpus-2", create_time="2000-01-01T01:01:01.123456Z", @@ -102,15 +102,15 @@ def list_corpora( @add_client_method def query_corpus( - request: glm.QueryCorpusRequest, + request: protos.QueryCorpusRequest, **kwargs, - ) -> glm.QueryCorpusResponse: + ) -> protos.QueryCorpusResponse: self.observed_requests.append(request) - return glm.QueryCorpusResponse( + return protos.QueryCorpusResponse( relevant_chunks=[ - glm.RelevantChunk( + protos.RelevantChunk( chunk_relevance_score=0.08, - chunk=glm.Chunk( + chunk=protos.Chunk( name="corpora/demo-corpus/documents/demo-doc/chunks/demo-chunk", data={"string_value": "This is a demo chunk."}, custom_metadata=[], @@ -124,18 +124,18 @@ def query_corpus( @add_client_method def delete_corpus( - request: glm.DeleteCorpusRequest, + request: protos.DeleteCorpusRequest, **kwargs, ) -> None: self.observed_requests.append(request) @add_client_method def create_document( - request: glm.CreateDocumentRequest, + request: protos.CreateDocumentRequest, **kwargs, ) -> retriever_service.Document: self.observed_requests.append(request) - return glm.Document( + return protos.Document( name="corpora/demo-corpus/documents/demo-doc", display_name="demo-doc", create_time="2000-01-01T01:01:01.123456Z", @@ -144,11 +144,11 @@ def create_document( @add_client_method def get_document( - request: glm.GetDocumentRequest, + request: protos.GetDocumentRequest, **kwargs, ) -> retriever_service.Document: self.observed_requests.append(request) - return glm.Document( + return protos.Document( name="corpora/demo-corpus/documents/demo_doc", display_name="demo-doc", create_time="2000-01-01T01:01:01.123456Z", @@ -157,11 +157,11 @@ def get_document( @add_client_method def update_document( - request: glm.UpdateDocumentRequest, + request: protos.UpdateDocumentRequest, **kwargs, - ) -> glm.Document: + ) -> protos.Document: self.observed_requests.append(request) - return glm.Document( + return protos.Document( name="corpora/demo-corpus/documents/demo_doc", display_name="demo-doc-1", create_time="2000-01-01T01:01:01.123456Z", @@ -170,18 +170,18 @@ def update_document( @add_client_method def list_documents( - request: glm.ListDocumentsRequest, + request: protos.ListDocumentsRequest, **kwargs, - ) -> glm.ListDocumentsResponse: + ) -> protos.ListDocumentsResponse: self.observed_requests.append(request) return [ - glm.Document( + protos.Document( name="corpora/demo-corpus/documents/demo_doc_1", display_name="demo-doc-1", create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", ), - glm.Document( + protos.Document( name="corpora/demo-corpus/documents/demo_doc_2", display_name="demo-doc-2", create_time="2000-01-01T01:01:01.123456Z", @@ -191,22 +191,22 @@ def list_documents( @add_client_method def delete_document( - request: glm.DeleteDocumentRequest, + request: protos.DeleteDocumentRequest, **kwargs, ) -> None: self.observed_requests.append(request) @add_client_method def query_document( - request: glm.QueryDocumentRequest, + request: protos.QueryDocumentRequest, **kwargs, - ) -> glm.QueryDocumentResponse: + ) -> protos.QueryDocumentResponse: self.observed_requests.append(request) - return glm.QueryDocumentResponse( + return protos.QueryDocumentResponse( relevant_chunks=[ - glm.RelevantChunk( + protos.RelevantChunk( chunk_relevance_score=0.08, - chunk=glm.Chunk( + chunk=protos.Chunk( name="demo-chunk", data={"string_value": "This is a demo chunk."}, custom_metadata=[], @@ -220,11 +220,11 @@ def query_document( 
@add_client_method def create_chunk( - request: glm.CreateChunkRequest, + request: protos.CreateChunkRequest, **kwargs, ) -> retriever_service.Chunk: self.observed_requests.append(request) - return glm.Chunk( + return protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is a demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -233,19 +233,19 @@ def create_chunk( @add_client_method def batch_create_chunks( - request: glm.BatchCreateChunksRequest, + request: protos.BatchCreateChunksRequest, **kwargs, - ) -> glm.BatchCreateChunksResponse: + ) -> protos.BatchCreateChunksResponse: self.observed_requests.append(request) - return glm.BatchCreateChunksResponse( + return protos.BatchCreateChunksResponse( chunks=[ - glm.Chunk( + protos.Chunk( name="corpora/demo-corpus/documents/demo-doc/chunks/dc", data={"string_value": "This is a demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", ), - glm.Chunk( + protos.Chunk( name="corpora/demo-corpus/documents/demo-doc/chunks/dc1", data={"string_value": "This is another demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -256,11 +256,11 @@ def batch_create_chunks( @add_client_method def get_chunk( - request: glm.GetChunkRequest, + request: protos.GetChunkRequest, **kwargs, ) -> retriever_service.Chunk: self.observed_requests.append(request) - return glm.Chunk( + return protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is a demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -269,18 +269,18 @@ def get_chunk( @add_client_method def list_chunks( - request: glm.ListChunksRequest, + request: protos.ListChunksRequest, **kwargs, - ) -> glm.ListChunksResponse: + ) -> protos.ListChunksResponse: self.observed_requests.append(request) return [ - glm.Chunk( + protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is a demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", ), - glm.Chunk( + protos.Chunk( name="corpora/demo-corpus/documents/demo-doc/chunks/demo-chunk-1", data={"string_value": "This is another demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -290,17 +290,17 @@ def list_chunks( @add_client_method def update_chunk( - request: glm.UpdateChunkRequest, + request: protos.UpdateChunkRequest, **kwargs, - ) -> glm.Chunk: + ) -> protos.Chunk: self.observed_requests.append(request) - return glm.Chunk( + return protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is an updated demo chunk."}, custom_metadata=[ - glm.CustomMetadata( + protos.CustomMetadata( key="tags", - string_list_value=glm.StringList( + string_list_value=protos.StringList( values=["Google For Developers", "Project IDX", "Blog", "Announcement"] ), ) @@ -311,19 +311,19 @@ def update_chunk( @add_client_method def batch_update_chunks( - request: glm.BatchUpdateChunksRequest, + request: protos.BatchUpdateChunksRequest, **kwargs, - ) -> glm.BatchUpdateChunksResponse: + ) -> protos.BatchUpdateChunksResponse: self.observed_requests.append(request) - return glm.BatchUpdateChunksResponse( + return protos.BatchUpdateChunksResponse( chunks=[ - glm.Chunk( + protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is an updated chunk."}, create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", ), - 
glm.Chunk( + protos.Chunk( name="corpora/demo-corpus/documents/demo-doc/chunks/demo-chunk-1", data={"string_value": "This is another updated chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -334,14 +334,14 @@ def batch_update_chunks( @add_client_method def delete_chunk( - request: glm.DeleteChunkRequest, + request: protos.DeleteChunkRequest, **kwargs, ) -> None: self.observed_requests.append(request) @add_client_method def batch_delete_chunks( - request: glm.BatchDeleteChunksRequest, + request: protos.BatchDeleteChunksRequest, **kwargs, ) -> None: self.observed_requests.append(request) @@ -366,7 +366,7 @@ def test_get_corpus(self, name="demo-corpus"): def test_update_corpus(self): demo_corpus = retriever.create_corpus(name="demo-corpus") update_request = demo_corpus.update(updates={"display_name": "demo-corpus_1"}) - self.assertIsInstance(self.observed_requests[-1], glm.UpdateCorpusRequest) + self.assertIsInstance(self.observed_requests[-1], protos.UpdateCorpusRequest) self.assertEqual("demo-corpus_1", demo_corpus.display_name) def test_list_corpora(self): @@ -402,7 +402,7 @@ def test_delete_corpus(self): demo_corpus = retriever.create_corpus(name="demo-corpus") demo_document = demo_corpus.create_document(name="demo-doc") delete_request = retriever.delete_corpus(name="corpora/demo_corpus", force=True) - self.assertIsInstance(self.observed_requests[-1], glm.DeleteCorpusRequest) + self.assertIsInstance(self.observed_requests[-1], protos.DeleteCorpusRequest) def test_create_document(self, display_name="demo-doc"): demo_corpus = retriever.create_corpus(name="demo-corpus") @@ -433,7 +433,7 @@ def test_delete_document(self): demo_document = demo_corpus.create_document(name="demo-doc") demo_doc2 = demo_corpus.create_document(name="demo-doc-2") delete_request = demo_corpus.delete_document(name="corpora/demo-corpus/documents/demo_doc") - self.assertIsInstance(self.observed_requests[-1], glm.DeleteDocumentRequest) + self.assertIsInstance(self.observed_requests[-1], protos.DeleteDocumentRequest) def test_list_documents(self): demo_corpus = retriever.create_corpus(name="demo-corpus") @@ -521,7 +521,7 @@ def test_batch_create_chunks(self, chunks): demo_corpus = retriever.create_corpus(name="demo-corpus") demo_document = demo_corpus.create_document(name="demo-doc") chunks = demo_document.batch_create_chunks(chunks=chunks) - self.assertIsInstance(self.observed_requests[-1], glm.BatchCreateChunksRequest) + self.assertIsInstance(self.observed_requests[-1], protos.BatchCreateChunksRequest) self.assertEqual("This is a demo chunk.", chunks[0].data.string_value) self.assertEqual("This is another demo chunk.", chunks[1].data.string_value) @@ -548,7 +548,7 @@ def test_list_chunks(self): ) list_req = list(demo_document.list_chunks()) - self.assertIsInstance(self.observed_requests[-1], glm.ListChunksRequest) + self.assertIsInstance(self.observed_requests[-1], protos.ListChunksRequest) self.assertLen(list_req, 2) def test_update_chunk(self): @@ -615,7 +615,7 @@ def test_batch_update_chunks_data_structures(self, updates): data="This is another demo chunk.", ) update_request = demo_document.batch_update_chunks(chunks=updates) - self.assertIsInstance(self.observed_requests[-1], glm.BatchUpdateChunksRequest) + self.assertIsInstance(self.observed_requests[-1], protos.BatchUpdateChunksRequest) self.assertEqual( "This is an updated chunk.", update_request["chunks"][0]["data"]["string_value"] ) @@ -631,7 +631,7 @@ def test_delete_chunk(self): data="This is a demo chunk.", ) delete_request = 
demo_document.delete_chunk(name="demo-chunk") - self.assertIsInstance(self.observed_requests[-1], glm.DeleteChunkRequest) + self.assertIsInstance(self.observed_requests[-1], protos.DeleteChunkRequest) def test_batch_delete_chunks(self): demo_corpus = retriever.create_corpus(name="demo-corpus") @@ -645,7 +645,7 @@ def test_batch_delete_chunks(self): data="This is another demo chunk.", ) delete_request = demo_document.batch_delete_chunks(chunks=[x.name, y.name]) - self.assertIsInstance(self.observed_requests[-1], glm.BatchDeleteChunksRequest) + self.assertIsInstance(self.observed_requests[-1], protos.BatchDeleteChunksRequest) @parameterized.parameters( {"method": "create_corpus"}, diff --git a/tests/test_retriever_async.py b/tests/test_retriever_async.py index b764c23b2..bb0c862d1 100644 --- a/tests/test_retriever_async.py +++ b/tests/test_retriever_async.py @@ -19,7 +19,7 @@ import unittest import unittest.mock as mock -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import retriever from google.generativeai import client as client_lib @@ -44,11 +44,11 @@ def add_client_method(f): @add_client_method async def create_corpus( - request: glm.CreateCorpusRequest, + request: protos.CreateCorpusRequest, **kwargs, - ) -> glm.Corpus: + ) -> protos.Corpus: self.observed_requests.append(request) - return glm.Corpus( + return protos.Corpus( name="corpora/demo-corpus", display_name="demo-corpus", create_time="2000-01-01T01:01:01.123456Z", @@ -57,11 +57,11 @@ async def create_corpus( @add_client_method async def get_corpus( - request: glm.GetCorpusRequest, + request: protos.GetCorpusRequest, **kwargs, - ) -> glm.Corpus: + ) -> protos.Corpus: self.observed_requests.append(request) - return glm.Corpus( + return protos.Corpus( name="corpora/demo-corpus", display_name="demo-corpus", create_time="2000-01-01T01:01:01.123456Z", @@ -70,11 +70,11 @@ async def get_corpus( @add_client_method async def update_corpus( - request: glm.UpdateCorpusRequest, + request: protos.UpdateCorpusRequest, **kwargs, - ) -> glm.Corpus: + ) -> protos.Corpus: self.observed_requests.append(request) - return glm.Corpus( + return protos.Corpus( name="corpora/demo-corpus", display_name="demo-corpus-1", create_time="2000-01-01T01:01:01.123456Z", @@ -83,19 +83,19 @@ async def update_corpus( @add_client_method async def list_corpora( - request: glm.ListCorporaRequest, + request: protos.ListCorporaRequest, **kwargs, - ) -> glm.ListCorporaResponse: + ) -> protos.ListCorporaResponse: self.observed_requests.append(request) async def results(): - yield glm.Corpus( + yield protos.Corpus( name="corpora/demo-corpus-1", display_name="demo-corpus-1", create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", ) - yield glm.Corpus( + yield protos.Corpus( name="corpora/demo-corpus_2", display_name="demo-corpus-2", create_time="2000-01-01T01:01:01.123456Z", @@ -106,15 +106,15 @@ async def results(): @add_client_method async def query_corpus( - request: glm.QueryCorpusRequest, + request: protos.QueryCorpusRequest, **kwargs, - ) -> glm.QueryCorpusResponse: + ) -> protos.QueryCorpusResponse: self.observed_requests.append(request) - return glm.QueryCorpusResponse( + return protos.QueryCorpusResponse( relevant_chunks=[ - glm.RelevantChunk( + protos.RelevantChunk( chunk_relevance_score=0.08, - chunk=glm.Chunk( + chunk=protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is a demo chunk."}, custom_metadata=[], @@ -128,18 
+128,18 @@ async def query_corpus( @add_client_method async def delete_corpus( - request: glm.DeleteCorpusRequest, + request: protos.DeleteCorpusRequest, **kwargs, ) -> None: self.observed_requests.append(request) @add_client_method async def create_document( - request: glm.CreateDocumentRequest, + request: protos.CreateDocumentRequest, **kwargs, ) -> retriever_service.Document: self.observed_requests.append(request) - return glm.Document( + return protos.Document( name="corpora/demo-corpus/documents/demo-doc", display_name="demo-doc", create_time="2000-01-01T01:01:01.123456Z", @@ -148,11 +148,11 @@ async def create_document( @add_client_method async def get_document( - request: glm.GetDocumentRequest, + request: protos.GetDocumentRequest, **kwargs, ) -> retriever_service.Document: self.observed_requests.append(request) - return glm.Document( + return protos.Document( name="corpora/demo-corpus/documents/demo-doc", display_name="demo-doc", create_time="2000-01-01T01:01:01.123456Z", @@ -161,11 +161,11 @@ async def get_document( @add_client_method async def update_document( - request: glm.UpdateDocumentRequest, + request: protos.UpdateDocumentRequest, **kwargs, - ) -> glm.Document: + ) -> protos.Document: self.observed_requests.append(request) - return glm.Document( + return protos.Document( name="corpora/demo-corpus/documents/demo-doc", display_name="demo-doc-1", create_time="2000-01-01T01:01:01.123456Z", @@ -174,19 +174,19 @@ async def update_document( @add_client_method async def list_documents( - request: glm.ListDocumentsRequest, + request: protos.ListDocumentsRequest, **kwargs, - ) -> glm.ListDocumentsResponse: + ) -> protos.ListDocumentsResponse: self.observed_requests.append(request) async def results(): - yield glm.Document( + yield protos.Document( name="corpora/demo-corpus/documents/dem-doc_1", display_name="demo-doc-1", create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", ) - yield glm.Document( + yield protos.Document( name="corpora/demo-corpus/documents/dem-doc_2", display_name="demo-doc_2", create_time="2000-01-01T01:01:01.123456Z", @@ -197,22 +197,22 @@ async def results(): @add_client_method async def delete_document( - request: glm.DeleteDocumentRequest, + request: protos.DeleteDocumentRequest, **kwargs, ) -> None: self.observed_requests.append(request) @add_client_method async def query_document( - request: glm.QueryDocumentRequest, + request: protos.QueryDocumentRequest, **kwargs, - ) -> glm.QueryDocumentResponse: + ) -> protos.QueryDocumentResponse: self.observed_requests.append(request) - return glm.QueryDocumentResponse( + return protos.QueryDocumentResponse( relevant_chunks=[ - glm.RelevantChunk( + protos.RelevantChunk( chunk_relevance_score=0.08, - chunk=glm.Chunk( + chunk=protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is a demo chunk."}, custom_metadata=[], @@ -226,11 +226,11 @@ async def query_document( @add_client_method async def create_chunk( - request: glm.CreateChunkRequest, + request: protos.CreateChunkRequest, **kwargs, ) -> retriever_service.Chunk: self.observed_requests.append(request) - return glm.Chunk( + return protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is a demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -239,19 +239,19 @@ async def create_chunk( @add_client_method async def batch_create_chunks( - request: glm.BatchCreateChunksRequest, + request: protos.BatchCreateChunksRequest, **kwargs, 
- ) -> glm.BatchCreateChunksResponse: + ) -> protos.BatchCreateChunksResponse: self.observed_requests.append(request) - return glm.BatchCreateChunksResponse( + return protos.BatchCreateChunksResponse( chunks=[ - glm.Chunk( + protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/dc", data={"string_value": "This is a demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", ), - glm.Chunk( + protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/dc1", data={"string_value": "This is another demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -262,11 +262,11 @@ async def batch_create_chunks( @add_client_method async def get_chunk( - request: glm.GetChunkRequest, + request: protos.GetChunkRequest, **kwargs, ) -> retriever_service.Chunk: self.observed_requests.append(request) - return glm.Chunk( + return protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is a demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -275,19 +275,19 @@ async def get_chunk( @add_client_method async def list_chunks( - request: glm.ListChunksRequest, + request: protos.ListChunksRequest, **kwargs, - ) -> glm.ListChunksResponse: + ) -> protos.ListChunksResponse: self.observed_requests.append(request) async def results(): - yield glm.Chunk( + yield protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is a demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", ) - yield glm.Chunk( + yield protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk-1", data={"string_value": "This is another demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -298,11 +298,11 @@ async def results(): @add_client_method async def update_chunk( - request: glm.UpdateChunkRequest, + request: protos.UpdateChunkRequest, **kwargs, - ) -> glm.Chunk: + ) -> protos.Chunk: self.observed_requests.append(request) - return glm.Chunk( + return protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is an updated demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -311,19 +311,19 @@ async def update_chunk( @add_client_method async def batch_update_chunks( - request: glm.BatchUpdateChunksRequest, + request: protos.BatchUpdateChunksRequest, **kwargs, - ) -> glm.BatchUpdateChunksResponse: + ) -> protos.BatchUpdateChunksResponse: self.observed_requests.append(request) - return glm.BatchUpdateChunksResponse( + return protos.BatchUpdateChunksResponse( chunks=[ - glm.Chunk( + protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is an updated chunk."}, create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", ), - glm.Chunk( + protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk-1", data={"string_value": "This is another updated chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -334,14 +334,14 @@ async def batch_update_chunks( @add_client_method async def delete_chunk( - request: glm.DeleteChunkRequest, + request: protos.DeleteChunkRequest, **kwargs, ) -> None: self.observed_requests.append(request) @add_client_method async def batch_delete_chunks( - request: glm.BatchDeleteChunksRequest, + request: protos.BatchDeleteChunksRequest, **kwargs, ) -> None: self.observed_requests.append(request) @@ -398,7 +398,7 @@ async def test_delete_corpus(self): 
         demo_corpus = await retriever.create_corpus_async(name="demo-corpus")
         demo_document = await demo_corpus.create_document_async(name="demo-doc")
         delete_request = await retriever.delete_corpus_async(name="corpora/demo-corpus", force=True)
-        self.assertIsInstance(self.observed_requests[-1], glm.DeleteCorpusRequest)
+        self.assertIsInstance(self.observed_requests[-1], protos.DeleteCorpusRequest)

     async def test_create_document(self, display_name="demo-doc"):
         demo_corpus = await retriever.create_corpus_async(name="demo-corpus")
@@ -425,7 +425,7 @@ async def test_delete_document(self):
         delete_request = await demo_corpus.delete_document_async(
             name="corpora/demo-corpus/documents/demo-doc"
         )
-        self.assertIsInstance(self.observed_requests[-1], glm.DeleteDocumentRequest)
+        self.assertIsInstance(self.observed_requests[-1], protos.DeleteDocumentRequest)

     async def test_list_documents(self):
         demo_corpus = await retriever.create_corpus_async(name="demo-corpus")
@@ -513,7 +513,7 @@ async def test_batch_create_chunks(self, chunks):
         demo_corpus = await retriever.create_corpus_async(name="demo-corpus")
         demo_document = await demo_corpus.create_document_async(name="demo-doc")
         chunks = await demo_document.batch_create_chunks_async(chunks=chunks)
-        self.assertIsInstance(self.observed_requests[-1], glm.BatchCreateChunksRequest)
+        self.assertIsInstance(self.observed_requests[-1], protos.BatchCreateChunksRequest)
         self.assertEqual("This is a demo chunk.", chunks[0].data.string_value)
         self.assertEqual("This is another demo chunk.", chunks[1].data.string_value)
@@ -541,7 +541,7 @@ async def test_list_chunks(self):
         chunks = []
         async for chunk in demo_document.list_chunks_async():
             chunks.append(chunk)
-        self.assertIsInstance(self.observed_requests[-1], glm.ListChunksRequest)
+        self.assertIsInstance(self.observed_requests[-1], protos.ListChunksRequest)
         self.assertLen(chunks, 2)

     async def test_update_chunk(self):
@@ -597,7 +597,7 @@ async def test_batch_update_chunks_data_structures(self, updates):
             data="This is another demo chunk.",
         )
         update_request = await demo_document.batch_update_chunks_async(chunks=updates)
-        self.assertIsInstance(self.observed_requests[-1], glm.BatchUpdateChunksRequest)
+        self.assertIsInstance(self.observed_requests[-1], protos.BatchUpdateChunksRequest)
         self.assertEqual(
             "This is an updated chunk.", update_request["chunks"][0]["data"]["string_value"]
         )
@@ -615,7 +615,7 @@ async def test_delete_chunk(self):
         delete_request = await demo_document.delete_chunk_async(
             name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk"
         )
-        self.assertIsInstance(self.observed_requests[-1], glm.DeleteChunkRequest)
+        self.assertIsInstance(self.observed_requests[-1], protos.DeleteChunkRequest)

     async def test_batch_delete_chunks(self):
         demo_corpus = await retriever.create_corpus_async(name="demo-corpus")
@@ -629,7 +629,7 @@ async def test_batch_delete_chunks(self):
             data="This is another demo chunk.",
         )
         delete_request = await demo_document.batch_delete_chunks_async(chunks=[x.name, y.name])
-        self.assertIsInstance(self.observed_requests[-1], glm.BatchDeleteChunksRequest)
+        self.assertIsInstance(self.observed_requests[-1], protos.BatchDeleteChunksRequest)

     async def test_get_corpus_called_with_request_options(self):
         self.client.get_corpus = unittest.mock.AsyncMock()
diff --git a/tests/test_safety.py b/tests/test_safety.py
new file mode 100644
index 000000000..2ac8aca46
--- /dev/null
+++ b/tests/test_safety.py
@@ -0,0 +1,57 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from google.generativeai.types import safety_types
+from google.generativeai import protos
+
+
+class SafetyTests(parameterized.TestCase):
+    """Tests are in order with the design doc."""
+
+    @parameterized.named_parameters(
+        ["block_threshold", protos.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE],
+        ["block_threshold2", "medium"],
+        ["block_threshold3", 2],
+        ["dict", {"danger": "medium"}],
+        ["dict2", {"danger": 2}],
+        ["dict3", {"danger": protos.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE}],
+        [
+            "list-dict",
+            [
+                dict(
+                    category=protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
+                    threshold=protos.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
+                ),
+            ],
+        ],
+        [
+            "list-dict2",
+            [
+                dict(category="danger", threshold="med"),
+            ],
+        ],
+    )
+    def test_safety_overwrite(self, setting):
+        setting = safety_types.to_easy_safety_dict(setting)
+        self.assertEqual(
+            setting[protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT],
+            protos.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
+        )
+
+
+if __name__ == "__main__":
+    absltest.main()
diff --git a/tests/test_text.py b/tests/test_text.py
index 0bc1d4e59..795c3dfcd 100644
--- a/tests/test_text.py
+++ b/tests/test_text.py
@@ -18,11 +18,11 @@
 import unittest
 import unittest.mock as mock

-import google.ai.generativelanguage as glm
+from google.generativeai import protos

 from google.generativeai import text as text_service
 from google.generativeai import client
-from google.generativeai.types import safety_types
+from google.generativeai.types import palm_safety_types
 from google.generativeai.types import model_types
 from absl.testing import absltest
 from absl.testing import parameterized
@@ -46,42 +46,42 @@ def add_client_method(f):

         @add_client_method
         def generate_text(
-            request: glm.GenerateTextRequest,
+            request: protos.GenerateTextRequest,
             **kwargs,
-        ) -> glm.GenerateTextResponse:
+        ) -> protos.GenerateTextResponse:
             self.observed_requests.append(request)
             return self.responses["generate_text"]

         @add_client_method
         def embed_text(
-            request: glm.EmbedTextRequest,
+            request: protos.EmbedTextRequest,
             **kwargs,
-        ) -> glm.EmbedTextResponse:
+        ) -> protos.EmbedTextResponse:
             self.observed_requests.append(request)
             return self.responses["embed_text"]

         @add_client_method
         def batch_embed_text(
-            request: glm.EmbedTextRequest,
+            request: protos.EmbedTextRequest,
             **kwargs,
-        ) -> glm.EmbedTextResponse:
+        ) -> protos.EmbedTextResponse:
             self.observed_requests.append(request)
-            return glm.BatchEmbedTextResponse(
-                embeddings=[glm.Embedding(value=[1, 2, 3])] * len(request.texts)
+            return protos.BatchEmbedTextResponse(
+                embeddings=[protos.Embedding(value=[1, 2, 3])] * len(request.texts)
             )

         @add_client_method
         def count_text_tokens(
-            request: glm.CountTextTokensRequest,
+            request: protos.CountTextTokensRequest,
             **kwargs,
-        ) -> glm.CountTextTokensResponse:
+        ) -> protos.CountTextTokensResponse:
             self.observed_requests.append(request)
             return self.responses["count_text_tokens"]

         @add_client_method
-        def get_tuned_model(name) -> glm.TunedModel:
-            request = glm.GetTunedModelRequest(name=name)
+        def get_tuned_model(name) -> protos.TunedModel:
+            request = protos.GetTunedModelRequest(name=name)
             self.observed_requests.append(request)
             response = copy.copy(self.responses["get_tuned_model"])
             return response
@@ -93,7 +93,7 @@ def get_tuned_model(name) -> glm.TunedModel:
     )
     def test_make_prompt(self, prompt):
         x = text_service._make_text_prompt(prompt)
-        self.assertIsInstance(x, glm.TextPrompt)
+        self.assertIsInstance(x, protos.TextPrompt)
         self.assertEqual("Hello how are", x.text)

     @parameterized.named_parameters(
@@ -104,7 +104,7 @@ def test_make_prompt(self, prompt):
     def test_make_generate_text_request(self, prompt):
         x = text_service._make_generate_text_request(model="models/chat-bison-001", prompt=prompt)
         self.assertEqual("models/chat-bison-001", x.model)
-        self.assertIsInstance(x, glm.GenerateTextRequest)
+        self.assertIsInstance(x, protos.GenerateTextRequest)

     @parameterized.named_parameters(
         [
@@ -116,14 +116,16 @@ def test_make_generate_text_request(self, prompt):
         ]
     )
     def test_generate_embeddings(self, model, text):
-        self.responses["embed_text"] = glm.EmbedTextResponse(
-            embedding=glm.Embedding(value=[1, 2, 3])
+        self.responses["embed_text"] = protos.EmbedTextResponse(
+            embedding=protos.Embedding(value=[1, 2, 3])
         )

         emb = text_service.generate_embeddings(model=model, text=text)

         self.assertIsInstance(emb, dict)
-        self.assertEqual(self.observed_requests[-1], glm.EmbedTextRequest(model=model, text=text))
+        self.assertEqual(
+            self.observed_requests[-1], protos.EmbedTextRequest(model=model, text=text)
+        )
         self.assertIsInstance(emb["embedding"][0], float)

     @parameterized.named_parameters(
@@ -191,11 +193,11 @@ def test_generate_embeddings_batch(self, model, text):
         ]
     )
     def test_generate_response(self, *, prompt, **kwargs):
-        self.responses["generate_text"] = glm.GenerateTextResponse(
+        self.responses["generate_text"] = protos.GenerateTextResponse(
             candidates=[
-                glm.TextCompletion(output=" road?"),
-                glm.TextCompletion(output=" bridge?"),
-                glm.TextCompletion(output=" river?"),
+                protos.TextCompletion(output=" road?"),
+                protos.TextCompletion(output=" bridge?"),
+                protos.TextCompletion(output=" river?"),
             ]
         )
@@ -203,8 +205,8 @@ def test_generate_response(self, *, prompt, **kwargs):

         self.assertEqual(
             self.observed_requests[-1],
-            glm.GenerateTextRequest(
-                model="models/text-bison-001", prompt=glm.TextPrompt(text=prompt), **kwargs
+            protos.GenerateTextRequest(
+                model="models/text-bison-001", prompt=protos.TextPrompt(text=prompt), **kwargs
             ),
         )
@@ -220,20 +222,20 @@ def test_generate_response(self, *, prompt, **kwargs):
         )

     def test_stop_string(self):
-        self.responses["generate_text"] = glm.GenerateTextResponse(
+        self.responses["generate_text"] = protos.GenerateTextResponse(
             candidates=[
-                glm.TextCompletion(output="Hello world?"),
-                glm.TextCompletion(output="Hell!"),
-                glm.TextCompletion(output="I'm going to stop"),
+                protos.TextCompletion(output="Hello world?"),
+                protos.TextCompletion(output="Hell!"),
+                protos.TextCompletion(output="I'm going to stop"),
             ]
         )
         complete = text_service.generate_text(prompt="Hello", stop_sequences="stop")

         self.assertEqual(
             self.observed_requests[-1],
-            glm.GenerateTextRequest(
+            protos.GenerateTextRequest(
                 model="models/text-bison-001",
-                prompt=glm.TextPrompt(text="Hello"),
+                prompt=protos.TextPrompt(text="Hello"),
                 stop_sequences=["stop"],
             ),
         )
@@ -246,12 +248,12 @@
testcase_name="basic", safety_settings=[ { - "category": safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, - "threshold": safety_types.HarmBlockThreshold.BLOCK_NONE, + "category": palm_safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, + "threshold": palm_safety_types.HarmBlockThreshold.BLOCK_NONE, }, { - "category": safety_types.HarmCategory.HARM_CATEGORY_VIOLENCE, - "threshold": safety_types.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + "category": palm_safety_types.HarmCategory.HARM_CATEGORY_VIOLENCE, + "threshold": palm_safety_types.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, }, ], ), @@ -275,16 +277,16 @@ def test_stop_string(self): dict( testcase_name="mixed", safety_settings={ - "medical": safety_types.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, - safety_types.HarmCategory.HARM_CATEGORY_VIOLENCE: 1, + "medical": palm_safety_types.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + palm_safety_types.HarmCategory.HARM_CATEGORY_VIOLENCE: 1, }, ), ] ) def test_safety_settings(self, safety_settings): - self.responses["generate_text"] = glm.GenerateTextResponse( + self.responses["generate_text"] = protos.GenerateTextResponse( candidates=[ - glm.TextCompletion(output="No"), + protos.TextCompletion(output="No"), ] ) # This test really just checks that the safety_settings get converted to a proto. @@ -294,36 +296,36 @@ def test_safety_settings(self, safety_settings): self.assertEqual( self.observed_requests[-1].safety_settings[0].category, - safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, + palm_safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, ) def test_filters(self): - self.responses["generate_text"] = glm.GenerateTextResponse( + self.responses["generate_text"] = protos.GenerateTextResponse( candidates=[{"output": "hello"}], filters=[ { - "reason": safety_types.BlockedReason.SAFETY, + "reason": palm_safety_types.BlockedReason.SAFETY, "message": "not safe", } ], ) response = text_service.generate_text(prompt="do filters work?") - self.assertIsInstance(response.filters[0]["reason"], safety_types.BlockedReason) - self.assertEqual(response.filters[0]["reason"], safety_types.BlockedReason.SAFETY) + self.assertIsInstance(response.filters[0]["reason"], palm_safety_types.BlockedReason) + self.assertEqual(response.filters[0]["reason"], palm_safety_types.BlockedReason.SAFETY) def test_safety_feedback(self): - self.responses["generate_text"] = glm.GenerateTextResponse( + self.responses["generate_text"] = protos.GenerateTextResponse( candidates=[{"output": "hello"}], safety_feedback=[ { "rating": { - "category": safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, - "probability": safety_types.HarmProbability.HIGH, + "category": palm_safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, + "probability": palm_safety_types.HarmProbability.HIGH, }, "setting": { - "category": safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, - "threshold": safety_types.HarmBlockThreshold.BLOCK_NONE, + "category": palm_safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, + "threshold": palm_safety_types.HarmBlockThreshold.BLOCK_NONE, }, } ], @@ -332,35 +334,35 @@ def test_safety_feedback(self): response = text_service.generate_text(prompt="does safety feedback work?") self.assertIsInstance( response.safety_feedback[0]["rating"]["probability"], - safety_types.HarmProbability, + palm_safety_types.HarmProbability, ) self.assertEqual( response.safety_feedback[0]["rating"]["probability"], - safety_types.HarmProbability.HIGH, + palm_safety_types.HarmProbability.HIGH, ) self.assertIsInstance( response.safety_feedback[0]["setting"]["category"], - 
safety_types.HarmCategory, + protos.HarmCategory, ) self.assertEqual( response.safety_feedback[0]["setting"]["category"], - safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, + palm_safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, ) def test_candidate_safety_feedback(self): - self.responses["generate_text"] = glm.GenerateTextResponse( + self.responses["generate_text"] = protos.GenerateTextResponse( candidates=[ { "output": "hello", "safety_ratings": [ { - "category": safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, - "probability": safety_types.HarmProbability.HIGH, + "category": palm_safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, + "probability": palm_safety_types.HarmProbability.HIGH, }, { - "category": safety_types.HarmCategory.HARM_CATEGORY_VIOLENCE, - "probability": safety_types.HarmProbability.LOW, + "category": palm_safety_types.HarmCategory.HARM_CATEGORY_VIOLENCE, + "probability": palm_safety_types.HarmProbability.LOW, }, ], } @@ -370,24 +372,24 @@ def test_candidate_safety_feedback(self): result = text_service.generate_text(prompt="Write a story from the ER.") self.assertIsInstance( result.candidates[0]["safety_ratings"][0]["category"], - safety_types.HarmCategory, + protos.HarmCategory, ) self.assertEqual( result.candidates[0]["safety_ratings"][0]["category"], - safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, + palm_safety_types.HarmCategory.HARM_CATEGORY_MEDICAL, ) self.assertIsInstance( result.candidates[0]["safety_ratings"][0]["probability"], - safety_types.HarmProbability, + palm_safety_types.HarmProbability, ) self.assertEqual( result.candidates[0]["safety_ratings"][0]["probability"], - safety_types.HarmProbability.HIGH, + palm_safety_types.HarmProbability.HIGH, ) def test_candidate_citations(self): - self.responses["generate_text"] = glm.GenerateTextResponse( + self.responses["generate_text"] = protos.GenerateTextResponse( candidates=[ { "output": "Hello Google!", @@ -434,21 +436,21 @@ def test_candidate_citations(self): ), ), dict( - testcase_name="glm_model", - model=glm.Model( + testcase_name="protos.model", + model=protos.Model( name="models/text-bison-001", ), ), dict( - testcase_name="glm_tuned_model", - model=glm.TunedModel( + testcase_name="protos.tuned_model", + model=protos.TunedModel( name="tunedModels/bipedal-pangolin-001", base_model="models/text-bison-001", ), ), dict( - testcase_name="glm_tuned_model_nested", - model=glm.TunedModel( + testcase_name="protos.tuned_model_nested", + model=protos.TunedModel( name="tunedModels/bipedal-pangolin-002", tuned_model_source={ "tuned_model": "tunedModels/bipedal-pangolin-002", @@ -459,10 +461,10 @@ def test_candidate_citations(self): ] ) def test_count_message_tokens(self, model): - self.responses["get_tuned_model"] = glm.TunedModel( + self.responses["get_tuned_model"] = protos.TunedModel( name="tunedModels/bipedal-pangolin-001", base_model="models/text-bison-001" ) - self.responses["count_text_tokens"] = glm.CountTextTokensResponse(token_count=7) + self.responses["count_text_tokens"] = protos.CountTextTokensResponse(token_count=7) response = text_service.count_text_tokens(model, "Tell me a story about a magic backpack.") self.assertEqual({"token_count": 7}, response) @@ -472,7 +474,7 @@ def test_count_message_tokens(self, model): self.assertLen(self.observed_requests, 2) self.assertEqual( self.observed_requests[0], - glm.GetTunedModelRequest(name="tunedModels/bipedal-pangolin-001"), + protos.GetTunedModelRequest(name="tunedModels/bipedal-pangolin-001"), ) def test_count_text_tokens_called_with_request_options(self):