diff --git a/google/generativeai/answer.py b/google/generativeai/answer.py
index f17a82a17..91b1e8848 100644
--- a/google/generativeai/answer.py
+++ b/google/generativeai/answer.py
@@ -26,12 +26,10 @@
     get_default_generative_client,
     get_default_generative_async_client,
 )
-from google.generativeai import string_utils
 from google.generativeai.types import model_types
-from google.generativeai import models
+from google.generativeai.types import helper_types
 from google.generativeai.types import safety_types
 from google.generativeai.types import content_types
-from google.generativeai.types import answer_types
 from google.generativeai.types import retriever_types
 from google.generativeai.types.retriever_types import MetadataFilter
@@ -247,7 +245,7 @@ def generate_answer(
     safety_settings: safety_types.SafetySettingOptions | None = None,
     temperature: float | None = None,
     client: glm.GenerativeServiceClient | None = None,
-    request_options: dict[str, Any] | None = None,
+    request_options: helper_types.RequestOptionsType | None = None,
 ):
     """
     Calls the GenerateAnswer API and returns a `types.Answer` containing the response.
@@ -320,7 +318,7 @@ async def generate_answer_async(
     safety_settings: safety_types.SafetySettingOptions | None = None,
     temperature: float | None = None,
     client: glm.GenerativeServiceClient | None = None,
-    request_options: dict[str, Any] | None = None,
+    request_options: helper_types.RequestOptionsType | None = None,
 ):
     """
     Calls the API and returns a `types.Answer` containing the answer.
diff --git a/google/generativeai/discuss.py b/google/generativeai/discuss.py
index 0cc342096..1a4345550 100644
--- a/google/generativeai/discuss.py
+++ b/google/generativeai/discuss.py
@@ -26,6 +26,7 @@
 from google.generativeai.client import get_default_discuss_async_client
 from google.generativeai import string_utils
 from google.generativeai.types import discuss_types
+from google.generativeai.types import helper_types
 from google.generativeai.types import model_types
 from google.generativeai.types import safety_types
 
@@ -316,7 +317,7 @@ def chat(
     top_k: float | None = None,
     prompt: discuss_types.MessagePromptOptions | None = None,
     client: glm.DiscussServiceClient | None = None,
-    request_options: dict[str, Any] | None = None,
+    request_options: helper_types.RequestOptionsType | None = None,
 ) -> discuss_types.ChatResponse:
     """Calls the API and returns a `types.ChatResponse` containing the response.
 
@@ -416,7 +417,7 @@ async def chat_async(
     top_k: float | None = None,
     prompt: discuss_types.MessagePromptOptions | None = None,
     client: glm.DiscussServiceAsyncClient | None = None,
-    request_options: dict[str, Any] | None = None,
+    request_options: helper_types.RequestOptionsType | None = None,
 ) -> discuss_types.ChatResponse:
     request = _make_generate_message_request(
         model=model,
@@ -469,7 +470,7 @@ def last(self, message: discuss_types.MessageOptions):
     def reply(
         self,
         message: discuss_types.MessageOptions,
-        request_options: dict[str, Any] | None = None,
+        request_options: helper_types.RequestOptionsType | None = None,
     ) -> discuss_types.ChatResponse:
         if isinstance(self._client, glm.DiscussServiceAsyncClient):
             raise TypeError(f"reply can't be called on an async client, use reply_async instead.")
@@ -537,7 +538,7 @@ def _build_chat_response(
 def _generate_response(
     request: glm.GenerateMessageRequest,
     client: glm.DiscussServiceClient | None = None,
-    request_options: dict[str, Any] | None = None,
+    request_options: helper_types.RequestOptionsType | None = None,
 ) -> ChatResponse:
     if request_options is None:
         request_options = {}
@@ -553,7 +554,7 @@ def _generate_response(
 async def _generate_response_async(
     request: glm.GenerateMessageRequest,
     client: glm.DiscussServiceAsyncClient | None = None,
-    request_options: dict[str, Any] | None = None,
+    request_options: helper_types.RequestOptionsType | None = None,
 ) -> ChatResponse:
     if request_options is None:
         request_options = {}
@@ -574,7 +575,7 @@ def count_message_tokens(
     messages: discuss_types.MessagesOptions | None = None,
     model: model_types.AnyModelNameOptions = DEFAULT_DISCUSS_MODEL,
     client: glm.DiscussServiceAsyncClient | None = None,
-    request_options: dict[str, Any] | None = None,
+    request_options: helper_types.RequestOptionsType | None = None,
 ) -> discuss_types.TokenCount:
     model = model_types.make_model_name(model)
     prompt = _make_message_prompt(prompt, context=context, examples=examples, messages=messages)
diff --git a/google/generativeai/embedding.py b/google/generativeai/embedding.py
index 375d5dcb4..14fff1737 100644
--- a/google/generativeai/embedding.py
+++ b/google/generativeai/embedding.py
@@ -14,8 +14,6 @@
 # limitations under the License.
 from __future__ import annotations
 
-import dataclasses
-from collections.abc import Iterable, Sequence, Mapping
 import itertools
 from typing import Any, Iterable, overload, TypeVar, Union, Mapping
 
@@ -24,7 +22,7 @@
 
 from google.generativeai.client import get_default_generative_client
 from google.generativeai.client import get_default_generative_async_client
-from google.generativeai import string_utils
+from google.generativeai.types import helper_types
 from google.generativeai.types import text_types
 from google.generativeai.types import model_types
 from google.generativeai.types import content_types
@@ -104,7 +102,7 @@ def embed_content(
     title: str | None = None,
     output_dimensionality: int | None = None,
     client: glm.GenerativeServiceClient | None = None,
-    request_options: dict[str, Any] | None = None,
+    request_options: helper_types.RequestOptionsType | None = None,
 ) -> text_types.EmbeddingDict: ...
 
 
@@ -116,7 +114,7 @@ def embed_content(
     title: str | None = None,
     output_dimensionality: int | None = None,
     client: glm.GenerativeServiceClient | None = None,
-    request_options: dict[str, Any] | None = None,
+    request_options: helper_types.RequestOptionsType | None = None,
 ) -> text_types.BatchEmbeddingDict: ...
 
 
@@ -127,7 +125,7 @@ def embed_content(
     title: str | None = None,
     output_dimensionality: int | None = None,
     client: glm.GenerativeServiceClient = None,
-    request_options: dict[str, Any] | None = None,
+    request_options: helper_types.RequestOptionsType | None = None,
 ) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict:
     """Calls the API to create embeddings for content passed in.
 
@@ -224,7 +222,7 @@ async def embed_content_async(
     title: str | None = None,
     output_dimensionality: int | None = None,
     client: glm.GenerativeServiceAsyncClient | None = None,
-    request_options: dict[str, Any] | None = None,
+    request_options: helper_types.RequestOptionsType | None = None,
 ) -> text_types.EmbeddingDict: ...
 
 
@@ -236,7 +234,7 @@ async def embed_content_async(
     title: str | None = None,
     output_dimensionality: int | None = None,
     client: glm.GenerativeServiceAsyncClient | None = None,
-    request_options: dict[str, Any] | None = None,
+    request_options: helper_types.RequestOptionsType | None = None,
 ) -> text_types.BatchEmbeddingDict: ...
 
 
@@ -247,7 +245,7 @@ async def embed_content_async(
     title: str | None = None,
     output_dimensionality: int | None = None,
     client: glm.GenerativeServiceAsyncClient = None,
-    request_options: dict[str, Any] | None = None,
+    request_options: helper_types.RequestOptionsType | None = None,
 ) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict:
     """The async version of `genai.embed_content`."""
     model = model_types.make_model_name(model)
diff --git a/google/generativeai/generative_models.py b/google/generativeai/generative_models.py
index a0e7df1e2..62893bf55 100644
--- a/google/generativeai/generative_models.py
+++ b/google/generativeai/generative_models.py
@@ -15,9 +15,9 @@
 
 import google.api_core.exceptions
 from google.ai import generativelanguage as glm
 from google.generativeai import client
-from google.generativeai import string_utils
 from google.generativeai.types import content_types
 from google.generativeai.types import generation_types
+from google.generativeai.types import helper_types
 from google.generativeai.types import safety_types
 
@@ -181,7 +181,7 @@ def generate_content(
         stream: bool = False,
         tools: content_types.FunctionLibraryType | None = None,
         tool_config: content_types.ToolConfigType | None = None,
-        request_options: dict[str, Any] | None = None,
+        request_options: helper_types.RequestOptionsType | None = None,
     ) -> generation_types.GenerateContentResponse:
         """A multipurpose function to generate responses from the model.
 
@@ -281,7 +281,7 @@ async def generate_content_async(
         stream: bool = False,
         tools: content_types.FunctionLibraryType | None = None,
         tool_config: content_types.ToolConfigType | None = None,
-        request_options: dict[str, Any] | None = None,
+        request_options: helper_types.RequestOptionsType | None = None,
     ) -> generation_types.AsyncGenerateContentResponse:
         """The async version of `GenerativeModel.generate_content`."""
         request = self._prepare_request(
@@ -328,7 +328,7 @@ def count_tokens(
         safety_settings: safety_types.SafetySettingOptions | None = None,
         tools: content_types.FunctionLibraryType | None = None,
         tool_config: content_types.ToolConfigType | None = None,
-        request_options: dict[str, Any] | None = None,
+        request_options: helper_types.RequestOptionsType | None = None,
     ) -> glm.CountTokensResponse:
         if request_options is None:
             request_options = {}
@@ -355,7 +355,7 @@ async def count_tokens_async(
         safety_settings: safety_types.SafetySettingOptions | None = None,
         tools: content_types.FunctionLibraryType | None = None,
         tool_config: content_types.ToolConfigType | None = None,
-        request_options: dict[str, Any] | None = None,
+        request_options: helper_types.RequestOptionsType | None = None,
     ) -> glm.CountTokensResponse:
         if request_options is None:
             request_options = {}
diff --git a/google/generativeai/models.py b/google/generativeai/models.py
index 7c7b8a5cf..bc517d7cb 100644
--- a/google/generativeai/models.py
+++ b/google/generativeai/models.py
@@ -21,6 +21,7 @@
 from google.generativeai import operations
 from google.generativeai.client import get_default_model_client
 from google.generativeai.types import model_types
+from google.generativeai.types import helper_types
 from google.api_core import operation
 from google.api_core import protobuf_helpers
 from google.protobuf import field_mask_pb2
@@ -31,7 +32,7 @@ def get_model(
     name: model_types.AnyModelNameOptions,
     *,
     client=None,
-    request_options: dict[str, Any] | None = None,
+    request_options: helper_types.RequestOptionsType | None = None,
 ) -> model_types.Model | model_types.TunedModel:
     """Given a model name, fetch the `types.Model` or `types.TunedModel` object.
 
@@ -62,7 +63,7 @@ def get_base_model(
     name: model_types.BaseModelNameOptions,
     *,
     client=None,
-    request_options: dict[str, Any] | None = None,
+    request_options: helper_types.RequestOptionsType | None = None,
 ) -> model_types.Model:
     """Get the `types.Model` for the given base model name.
 
@@ -99,7 +100,7 @@ def get_tuned_model(
     name: model_types.TunedModelNameOptions,
     *,
     client=None,
-    request_options: dict[str, Any] | None = None,
+    request_options: helper_types.RequestOptionsType | None = None,
 ) -> model_types.TunedModel:
     """Get the `types.TunedModel` for the given tuned model name.
 
@@ -162,7 +163,7 @@ def list_models(
     *,
     page_size: int | None = 50,
     client: glm.ModelServiceClient | None = None,
-    request_options: dict[str, Any] | None = None,
+    request_options: helper_types.RequestOptionsType | None = None,
 ) -> model_types.ModelsIterable:
     """Lists available models.
 
@@ -196,7 +197,7 @@ def list_tuned_models(
     *,
     page_size: int | None = 50,
     client: glm.ModelServiceClient | None = None,
-    request_options: dict[str, Any] | None = None,
+    request_options: helper_types.RequestOptionsType | None = None,
 ) -> model_types.TunedModelsIterable:
     """Lists available models.
 
@@ -244,7 +245,7 @@ def create_tuned_model(
     input_key: str = "text_input",
     output_key: str = "output",
     client: glm.ModelServiceClient | None = None,
-    request_options: dict[str, Any] | None = None,
+    request_options: helper_types.RequestOptionsType | None = None,
 ) -> operations.CreateTunedModelOperation:
     """Launches a tuning job to create a TunedModel.
 
@@ -344,6 +345,7 @@ def create_tuned_model(
         top_k=top_k,
         tuning_task=tuning_task,
     )
+
     operation = client.create_tuned_model(
         dict(tuned_model_id=id, tuned_model=tuned_model), **request_options
     )
@@ -357,7 +359,7 @@ def update_tuned_model(
     updates: None = None,
     *,
     client: glm.ModelServiceClient | None = None,
-    request_options: dict[str, Any] | None = None,
+    request_options: helper_types.RequestOptionsType | None = None,
 ) -> model_types.TunedModel:
     pass
 
@@ -368,7 +370,7 @@ def update_tuned_model(
     updates: dict[str, Any],
     *,
     client: glm.ModelServiceClient | None = None,
-    request_options: dict[str, Any] | None = None,
+    request_options: helper_types.RequestOptionsType | None = None,
 ) -> model_types.TunedModel:
     pass
 
@@ -378,7 +380,7 @@ def update_tuned_model(
     updates: dict[str, Any] | None = None,
     *,
     client: glm.ModelServiceClient | None = None,
-    request_options: dict[str, Any] | None = None,
+    request_options: helper_types.RequestOptionsType | None = None,
 ) -> model_types.TunedModel:
     """Push updates to the tuned model. Only certain attributes are updatable."""
     if request_options is None:
@@ -395,6 +397,7 @@ def update_tuned_model(
             "`updates` must be a `dict`.\n"
             f"got: {type(updates)}"
         )
+
     tuned_model = client.get_tuned_model(name=name, **request_options)
 
     updates = flatten_update_paths(updates)
@@ -436,7 +439,7 @@ def _apply_update(thing, path, value):
 def delete_tuned_model(
     tuned_model: model_types.TunedModelNameOptions,
     client: glm.ModelServiceClient | None = None,
-    request_options: dict[str, Any] | None = None,
+    request_options: helper_types.RequestOptionsType | None = None,
 ) -> None:
     if request_options is None:
         request_options = {}
diff --git a/google/generativeai/retriever.py b/google/generativeai/retriever.py
index dfd5e9026..190a222a6 100644
--- a/google/generativeai/retriever.py
+++ b/google/generativeai/retriever.py
@@ -23,6 +23,7 @@
 
 from google.generativeai.client import get_default_retriever_client
 from google.generativeai.client import get_default_retriever_async_client
+from google.generativeai.types import helper_types
 from google.generativeai.types.model_types import idecode_time
 from google.generativeai.types import retriever_types
 
@@ -31,7 +32,7 @@ def create_corpus(
     name: str | None = None,
     display_name: str | None = None,
     client: glm.RetrieverServiceClient | None = None,
-    request_options: dict[str, Any] | None = None,
+    request_options: helper_types.RequestOptionsType | None = None,
 ) -> retriever_types.Corpus:
     """
     Create a new `Corpus` in the retriever service, and return it as a `retriever_types.Corpus` instance.
@@ -78,7 +79,7 @@ async def create_corpus_async(
     name: str | None = None,
     display_name: str | None = None,
     client: glm.RetrieverServiceAsyncClient | None = None,
-    request_options: dict[str, Any] | None = None,
+    request_options: helper_types.RequestOptionsType | None = None,
 ) -> retriever_types.Corpus:
     """This is the async version of `retriever.create_corpus`."""
     if request_options is None:
@@ -106,7 +107,7 @@ async def create_corpus_async(
 def get_corpus(
     name: str,
     client: glm.RetrieverServiceClient | None = None,
-    request_options: dict[str, Any] | None = None,
+    request_options: helper_types.RequestOptionsType | None = None,
 ) -> retriever_types.Corpus:  # fmt: skip
     """
     Fetch a specific `Corpus` from the retriever service.
@@ -139,7 +140,7 @@ def get_corpus(
 async def get_corpus_async(
     name: str,
     client: glm.RetrieverServiceAsyncClient | None = None,
-    request_options: dict[str, Any] | None = None,
+    request_options: helper_types.RequestOptionsType | None = None,
 ) -> retriever_types.Corpus:  # fmt: skip
     """This is the async version of `retriever.get_corpus`."""
     if request_options is None:
@@ -164,7 +165,7 @@ def delete_corpus(
     name: str,
     force: bool = False,
     client: glm.RetrieverServiceClient | None = None,
-    request_options: dict[str, Any] | None = None,
+    request_options: helper_types.RequestOptionsType | None = None,
 ):  # fmt: skip
     """
     Delete a `Corpus` from the service.
@@ -191,7 +192,7 @@ async def delete_corpus_async(
     name: str,
     force: bool = False,
     client: glm.RetrieverServiceAsyncClient | None = None,
-    request_options: dict[str, Any] | None = None,
+    request_options: helper_types.RequestOptionsType | None = None,
 ):  # fmt: skip
     """This is the async version of `retriever.delete_corpus`."""
     if request_options is None:
@@ -211,7 +212,7 @@ def list_corpora(
     *,
     page_size: Optional[int] = None,
     client: glm.RetrieverServiceClient | None = None,
-    request_options: dict[str, Any] | None = None,
+    request_options: helper_types.RequestOptionsType | None = None,
 ) -> Iterable[retriever_types.Corpus]:
     """
     List the Corpuses you own in the service.
@@ -242,7 +243,7 @@ async def list_corpora_async(
     *,
     page_size: Optional[int] = None,
     client: glm.RetrieverServiceClient | None = None,
-    request_options: dict[str, Any] | None = None,
+    request_options: helper_types.RequestOptionsType | None = None,
 ) -> AsyncIterable[retriever_types.Corpus]:
     """This is the async version of `retriever.list_corpora`."""
     if request_options is None:
diff --git a/google/generativeai/text.py b/google/generativeai/text.py
index 3a147f945..2f3da6842 100644
--- a/google/generativeai/text.py
+++ b/google/generativeai/text.py
@@ -23,6 +23,7 @@
 
 from google.generativeai.client import get_default_text_client
 from google.generativeai import string_utils
+from google.generativeai.types import helper_types
 from google.generativeai.types import text_types
 from google.generativeai.types import model_types
 from google.generativeai import models
@@ -141,7 +142,7 @@ def generate_text(
     safety_settings: safety_types.SafetySettingOptions | None = None,
     stop_sequences: str | Iterable[str] | None = None,
     client: glm.TextServiceClient | None = None,
-    request_options: dict[str, Any] | None = None,
+    request_options: helper_types.RequestOptionsType | None = None,
 ) -> text_types.Completion:
     """Calls the API and returns a `types.Completion` containing the response.
 
@@ -217,7 +218,7 @@ def __init__(self, **kwargs):
 def _generate_response(
     request: glm.GenerateTextRequest,
     client: glm.TextServiceClient = None,
-    request_options: dict[str, Any] | None = None,
+    request_options: helper_types.RequestOptionsType | None = None,
 ) -> Completion:
     """
     Generates a response using the provided `glm.GenerateTextRequest` and client.
@@ -253,7 +254,7 @@ def count_text_tokens(
     model: model_types.AnyModelNameOptions,
     prompt: str,
     client: glm.TextServiceClient | None = None,
-    request_options: dict[str, Any] | None = None,
+    request_options: helper_types.RequestOptionsType | None = None,
 ) -> text_types.TokenCount:
     base_model = models.get_base_model_name(model)
 
@@ -276,7 +277,7 @@ def generate_embeddings(
     model: model_types.BaseModelNameOptions,
     text: str,
     client: glm.TextServiceClient = None,
-    request_options: dict[str, Any] | None = None,
+    request_options: helper_types.RequestOptionsType | None = None,
 ) -> text_types.EmbeddingDict: ...
 
 
@@ -285,7 +286,7 @@ def generate_embeddings(
     model: model_types.BaseModelNameOptions,
     text: Sequence[str],
     client: glm.TextServiceClient = None,
-    request_options: dict[str, Any] | None = None,
+    request_options: helper_types.RequestOptionsType | None = None,
 ) -> text_types.BatchEmbeddingDict: ...
 
 
@@ -293,7 +294,7 @@ def generate_embeddings(
     model: model_types.BaseModelNameOptions,
     text: str | Sequence[str],
     client: glm.TextServiceClient = None,
-    request_options: dict[str, Any] | None = None,
+    request_options: helper_types.RequestOptionsType | None = None,
 ) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict:
     """Calls the API to create an embedding for the text passed in.
 
diff --git a/google/generativeai/types/__init__.py b/google/generativeai/types/__init__.py
index dc0a76761..21768bbe6 100644
--- a/google/generativeai/types/__init__.py
+++ b/google/generativeai/types/__init__.py
@@ -19,6 +19,7 @@
 from google.generativeai.types.discuss_types import *
 from google.generativeai.types.file_types import *
 from google.generativeai.types.generation_types import *
+from google.generativeai.types.helper_types import *
 from google.generativeai.types.model_types import *
 from google.generativeai.types.safety_types import *
 from google.generativeai.types.text_types import *
diff --git a/google/generativeai/types/helper_types.py b/google/generativeai/types/helper_types.py
new file mode 100644
index 000000000..3eba4d3f9
--- /dev/null
+++ b/google/generativeai/types/helper_types.py
@@ -0,0 +1,84 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import google.api_core.timeout
+import google.api_core.retry
+
+import collections
+import dataclasses
+
+from typing import Union
+from typing_extensions import TypedDict
+
+__all__ = ["RequestOptions", "RequestOptionsType"]
+
+
+class RequestOptionsDict(TypedDict, total=False):
+    retry: google.api_core.retry.Retry
+    timeout: Union[int, float, google.api_core.timeout.TimeToDeadlineTimeout]
+
+
+@dataclasses.dataclass(init=False)
+class RequestOptions(collections.abc.Mapping):
+    """Request options
+
+    >>> import google.generativeai as genai
+    >>> from google.generativeai.types import RequestOptions
+    >>> from google.api_core import retry
+    >>>
+    >>> model = genai.GenerativeModel()
+    >>> response = model.generate_content('Hello',
+    ...     request_options=RequestOptions(
+    ...         retry=retry.Retry(initial=10, multiplier=2, maximum=60, timeout=300)))
+    >>> response = model.generate_content('Hello',
+    ...     request_options=RequestOptions(timeout=600))
+
+    Args:
+        retry: Refer to [retry docs](https://googleapis.dev/python/google-api-core/latest/retry.html) for details.
+        timeout: In seconds (or provide a [TimeToDeadlineTimeout](https://googleapis.dev/python/google-api-core/latest/timeout.html) object).
+    """
+
+    retry: google.api_core.retry.Retry | None
+    timeout: int | float | google.api_core.timeout.TimeToDeadlineTimeout | None
+
+    def __init__(
+        self,
+        *,
+        retry: google.api_core.retry.Retry | None = None,
+        timeout: int | float | google.api_core.timeout.TimeToDeadlineTimeout | None = None,
+    ):
+        self.retry = retry
+        self.timeout = timeout
+
+    # Inherit from Mapping for **unpacking
+    def __getitem__(self, item):
+        if item == "retry":
+            return self.retry
+        elif item == "timeout":
+            return self.timeout
+        else:
+            raise KeyError(f'RequestOptions does not have a "{item}" key')
+
+    def __iter__(self):
+        yield "retry"
+        yield "timeout"
+
+    def __len__(self):
+        return 2
+
+
+RequestOptionsType = Union[RequestOptions, RequestOptionsDict]
diff --git a/google/generativeai/types/retriever_types.py b/google/generativeai/types/retriever_types.py
index 72859f207..538d3924a 100644
--- a/google/generativeai/types/retriever_types.py
+++ b/google/generativeai/types/retriever_types.py
@@ -27,6 +27,8 @@
 from google.generativeai.client import get_default_retriever_client
 from google.generativeai.client import get_default_retriever_async_client
 from google.generativeai import string_utils
+from google.generativeai.types import helper_types
+
 from google.generativeai.types import permission_types
 from google.generativeai.types.model_types import idecode_time
 from google.generativeai.utils import flatten_update_paths
@@ -261,7 +263,7 @@ def create_document(
         display_name: str | None = None,
         custom_metadata: Iterable[CustomMetadata] | None = None,
         client: glm.RetrieverServiceClient | None = None,
-        request_options: dict[str, Any] | None = None,
+        request_options: helper_types.RequestOptionsType | None = None,
     ) -> Document:
         """
         Request to create a `Document`.
@@ -312,7 +314,7 @@ async def create_document_async(
         display_name: str | None = None,
         custom_metadata: Iterable[CustomMetadata] | None = None,
         client: glm.RetrieverServiceAsyncClient | None = None,
-        request_options: dict[str, Any] | None = None,
+        request_options: helper_types.RequestOptionsType | None = None,
     ) -> Document:
         """This is the async version of `Corpus.create_document`."""
         if request_options is None:
@@ -346,7 +348,7 @@ def get_document(
         self,
         name: str,
         client: glm.RetrieverServiceClient | None = None,
-        request_options: dict[str, Any] | None = None,
+        request_options: helper_types.RequestOptionsType | None = None,
     ) -> Document:
         """
         Get information about a specific `Document`.
@@ -375,7 +377,7 @@ async def get_document_async(
         self,
         name: str,
         client: glm.RetrieverServiceAsyncClient | None = None,
-        request_options: dict[str, Any] | None = None,
+        request_options: helper_types.RequestOptionsType | None = None,
     ) -> Document:
         """This is the async version of `Corpus.get_document`."""
         if request_options is None:
@@ -401,7 +403,7 @@ def update(
         self,
         updates: dict[str, Any],
         client: glm.RetrieverServiceClient | None = None,
-        request_options: dict[str, Any] | None = None,
+        request_options: helper_types.RequestOptionsType | None = None,
     ):
         """
         Update a list of fields for a specified `Corpus`.
@@ -439,7 +441,7 @@ async def update_async(
         self,
         updates: dict[str, Any],
         client: glm.RetrieverServiceAsyncClient | None = None,
-        request_options: dict[str, Any] | None = None,
+        request_options: helper_types.RequestOptionsType | None = None,
     ):
         """This is the async version of `Corpus.update`."""
         if request_options is None:
@@ -470,7 +472,7 @@ def query(
         metadata_filters: Iterable[MetadataFilter] | None = None,
         results_count: int | None = None,
         client: glm.RetrieverServiceClient | None = None,
-        request_options: dict[str, Any] | None = None,
+        request_options: helper_types.RequestOptionsType | None = None,
     ) -> Iterable[RelevantChunk]:
         """
         Query a corpus for information.
@@ -524,7 +526,7 @@ async def query_async(
         metadata_filters: Iterable[MetadataFilter] | None = None,
         results_count: int | None = None,
         client: glm.RetrieverServiceAsyncClient | None = None,
-        request_options: dict[str, Any] | None = None,
+        request_options: helper_types.RequestOptionsType | None = None,
     ) -> Iterable[RelevantChunk]:
         """This is the async version of `Corpus.query`."""
         if request_options is None:
@@ -566,7 +568,7 @@ def delete_document(
         name: str,
         force: bool = False,
         client: glm.RetrieverServiceClient | None = None,
-        request_options: dict[str, Any] | None = None,
+        request_options: helper_types.RequestOptionsType | None = None,
     ):
         """
         Delete a document in the corpus.
@@ -593,7 +595,7 @@ async def delete_document_async(
         name: str,
         force: bool = False,
         client: glm.RetrieverServiceAsyncClient | None = None,
-        request_options: dict[str, Any] | None = None,
+        request_options: helper_types.RequestOptionsType | None = None,
     ):
         """This is the async version of `Corpus.delete_document`."""
         if request_options is None:
@@ -612,7 +614,7 @@ def list_documents(
         self,
         page_size: int | None = None,
         client: glm.RetrieverServiceClient | None = None,
-        request_options: dict[str, Any] | None = None,
+        request_options: helper_types.RequestOptionsType | None = None,
     ) -> Iterable[Document]:
         """
         List documents in corpus.
@@ -642,7 +644,7 @@ async def list_documents_async(
         self,
         page_size: int | None = None,
         client: glm.RetrieverServiceAsyncClient | None = None,
-        request_options: dict[str, Any] | None = None,
+        request_options: helper_types.RequestOptionsType | None = None,
    ) -> AsyncIterable[Document]:
         """This is the async version of `Corpus.list_documents`."""
         if request_options is None:
@@ -744,7 +746,7 @@ def create_chunk(
         name: str | None = None,
         custom_metadata: Iterable[CustomMetadata] | None = None,
         client: glm.RetrieverServiceClient | None = None,
-        request_options: dict[str, Any] | None = None,
+        request_options: helper_types.RequestOptionsType | None = None,
     ) -> Chunk:
         """
         Create a `Chunk` object which has textual data.
@@ -801,7 +803,7 @@ async def create_chunk_async(
         name: str | None = None,
         custom_metadata: Iterable[CustomMetadata] | None = None,
         client: glm.RetrieverServiceAsyncClient | None = None,
-        request_options: dict[str, Any] | None = None,
+        request_options: helper_types.RequestOptionsType | None = None,
     ) -> Chunk:
         """This is the async version of `Document.create_chunk`."""
         if request_options is None:
@@ -900,7 +902,7 @@ def batch_create_chunks(
         self,
         chunks: BatchCreateChunkOptions,
         client: glm.RetrieverServiceClient | None = None,
-        request_options: dict[str, Any] | None = None,
+        request_options: helper_types.RequestOptionsType | None = None,
     ):
         """
         Create chunks within the given document.
@@ -926,7 +928,7 @@ async def batch_create_chunks_async(
         self,
         chunks: BatchCreateChunkOptions,
         client: glm.RetrieverServiceAsyncClient | None = None,
-        request_options: dict[str, Any] | None = None,
+        request_options: helper_types.RequestOptionsType | None = None,
     ):
         """This is the async version of `Document.batch_create_chunk`."""
         if request_options is None:
@@ -943,7 +945,7 @@ def get_chunk(
         self,
         name: str,
         client: glm.RetrieverServiceClient | None = None,
-        request_options: dict[str, Any] | None = None,
+        request_options: helper_types.RequestOptionsType | None = None,
     ):
         """
         Get information about a specific chunk.
@@ -972,7 +974,7 @@ async def get_chunk_async(
         self,
         name: str,
         client: glm.RetrieverServiceAsyncClient | None = None,
-        request_options: dict[str, Any] | None = None,
+        request_options: helper_types.RequestOptionsType | None = None,
     ):
         """This is the async version of `Document.get_chunk`."""
         if request_options is None:
@@ -992,7 +994,7 @@ def list_chunks(
         self,
         page_size: int | None = None,
         client: glm.RetrieverServiceClient | None = None,
-        request_options: dict[str, Any] | None = None,
+        request_options: helper_types.RequestOptionsType | None = None,
     ) -> Iterable[Chunk]:
         """
         List chunks of a document.
@@ -1018,7 +1020,7 @@ async def list_chunks_async(
         self,
         page_size: int | None = None,
         client: glm.RetrieverServiceClient | None = None,
-        request_options: dict[str, Any] | None = None,
+        request_options: helper_types.RequestOptionsType | None = None,
     ) -> AsyncIterable[Chunk]:
         """This is the async version of `Document.list_chunks`."""
         if request_options is None:
@@ -1037,7 +1039,7 @@ def query(
         metadata_filters: Iterable[MetadataFilter] | None = None,
         results_count: int | None = None,
         client: glm.RetrieverServiceClient | None = None,
-        request_options: dict[str, Any] | None = None,
+        request_options: helper_types.RequestOptionsType | None = None,
     ) -> list[RelevantChunk]:
         """
         Query a `Document` in the `Corpus` for information.
@@ -1090,7 +1092,7 @@ async def query_async(
         metadata_filters: Iterable[MetadataFilter] | None = None,
         results_count: int | None = None,
         client: glm.RetrieverServiceAsyncClient | None = None,
-        request_options: dict[str, Any] | None = None,
+        request_options: helper_types.RequestOptionsType | None = None,
     ) -> list[RelevantChunk]:
         """This is the async version of `Document.query`."""
         if request_options is None:
@@ -1137,7 +1139,7 @@ def update(
         self,
         updates: dict[str, Any],
         client: glm.RetrieverServiceClient | None = None,
-        request_options: dict[str, Any] | None = None,
+        request_options: helper_types.RequestOptionsType | None = None,
     ):
         """
         Update a list of fields for a specified document.
@@ -1174,7 +1176,7 @@ async def update_async(
         self,
         updates: dict[str, Any],
         client: glm.RetrieverServiceAsyncClient | None = None,
-        request_options: dict[str, Any] | None = None,
+        request_options: helper_types.RequestOptionsType | None = None,
     ):
         """This is the async version of `Document.update`."""
         if request_options is None:
@@ -1202,7 +1204,7 @@ def batch_update_chunks(
         self,
         chunks: BatchUpdateChunksOptions,
         client: glm.RetrieverServiceClient | None = None,
-        request_options: dict[str, Any] | None = None,
+        request_options: helper_types.RequestOptionsType | None = None,
     ):
         """
         Update multiple chunks within the same document.
@@ -1299,7 +1301,7 @@ async def batch_update_chunks_async(
         self,
         chunks: BatchUpdateChunksOptions,
         client: glm.RetrieverServiceAsyncClient | None = None,
-        request_options: dict[str, Any] | None = None,
+        request_options: helper_types.RequestOptionsType | None = None,
     ):
         """This is the async version of `Document.batch_update_chunks`."""
         if request_options is None:
@@ -1387,7 +1389,7 @@ def delete_chunk(
         self,
         name: str,
         client: glm.RetrieverServiceClient | None = None,
-        request_options: dict[str, Any] | None = None,  # fmt: {}
+        request_options: helper_types.RequestOptionsType | None = None,  # fmt: {}
     ):
         """
         Delete a `Chunk`.
@@ -1412,7 +1414,7 @@ async def delete_chunk_async(
         self,
         name: str,
         client: glm.RetrieverServiceAsyncClient | None = None,
-        request_options: dict[str, Any] | None = None,  # fmt: {}
+        request_options: helper_types.RequestOptionsType | None = None,  # fmt: {}
     ):
         """This is the async version of `Document.delete_chunk`."""
         if request_options is None:
@@ -1431,7 +1433,7 @@ def batch_delete_chunks(
         self,
         chunks: BatchDeleteChunkOptions,
         client: glm.RetrieverServiceClient | None = None,
-        request_options: dict[str, Any] | None = None,
+        request_options: helper_types.RequestOptionsType | None = None,
     ):
         """
         Delete multiple `Chunk`s from a document.
@@ -1464,7 +1466,7 @@ async def batch_delete_chunks_async(
         self,
         chunks: BatchDeleteChunkOptions,
         client: glm.RetrieverServiceAsyncClient | None = None,
-        request_options: dict[str, Any] | None = None,
+        request_options: helper_types.RequestOptionsType | None = None,
     ):
         """This is the async version of `Document.batch_delete_chunks`."""
         if request_options is None:
@@ -1570,7 +1572,7 @@ def update(
         self,
         updates: dict[str, Any],
         client: glm.RetrieverServiceClient | None = None,
-        request_options: dict[str, Any] | None = None,
+        request_options: helper_types.RequestOptionsType | None = None,
     ):
         """
         Update a list of fields for a specified `Chunk`.
@@ -1619,7 +1621,7 @@ async def update_async(
         self,
         updates: dict[str, Any],
         client: glm.RetrieverServiceAsyncClient | None = None,
-        request_options: dict[str, Any] | None = None,
+        request_options: helper_types.RequestOptionsType | None = None,
     ):
         """This is the async version of `Chunk.update`."""
         if request_options is None:
diff --git a/tests/test_answer.py b/tests/test_answer.py
index 6fa12603c..4128567f4 100644
--- a/tests/test_answer.py
+++ b/tests/test_answer.py
@@ -21,6 +21,7 @@
 import google.ai.generativelanguage as glm
 
 from google.generativeai import answer
+from google.generativeai import types as genai_types
 from google.generativeai import client
 from absl.testing import absltest
 from absl.testing import parameterized
@@ -239,7 +240,7 @@ def test_generate_answer(self):
     def test_generate_answer_called_with_request_options(self):
         self.client.generate_answer = mock.MagicMock()
         request = mock.ANY
-        request_options = {"timeout": 120}
+        request_options = genai_types.RequestOptions(timeout=120)
 
         answer.generate_answer(contents=[], inline_passages=[], request_options=request_options)
 
diff --git a/tests/test_discuss.py b/tests/test_discuss.py
index 9d628a42c..e7efd5ef2 100644
--- a/tests/test_discuss.py
+++ b/tests/test_discuss.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import copy
-from typing import Any
 
 import unittest.mock
 
diff --git a/tests/test_helpers.py b/tests/test_helpers.py
new file mode 100644
index 000000000..0c2de7f29
--- /dev/null
+++ b/tests/test_helpers.py
@@ -0,0 +1,83 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pathlib
+import copy
+import collections
+from typing import Union
+
+from absl.testing import parameterized
+
+import google.ai.generativelanguage as glm
+
+from google.generativeai import client
+from google.generativeai import models
+from google.generativeai.types import model_types
+from google.generativeai.types import helper_types
+
+from google.api_core import retry
+
+
+class MockModelClient:
+    def __init__(self, test):
+        self.test = test
+
+    def get_model(
+        self,
+        request: Union[glm.GetModelRequest, None] = None,
+        *,
+        name=None,
+        timeout=None,
+        retry=None
+    ) -> glm.Model:
+        if request is None:
+            request = glm.GetModelRequest(name=name)
+        self.test.assertIsInstance(request, glm.GetModelRequest)
+        self.test.observed_requests.append(request)
+        self.test.observed_timeout.append(timeout)
+        self.test.observed_retry.append(retry)
+        response = copy.copy(self.test.responses["get_model"])
+        return response
+
+
+class HelperTests(parameterized.TestCase):
+
+    def setUp(self):
+        self.client = MockModelClient(self)
+        client._client_manager.clients["model"] = self.client
+
+        self.observed_requests = []
+        self.observed_retry = []
+        self.observed_timeout = []
+        self.responses = collections.defaultdict(list)
+
+    @parameterized.named_parameters(
+        ["None", None, None, None],
+        ["Empty", {}, None, None],
+        ["Timeout", {"timeout": 7}, 7, None],
+        ["Retry", {"retry": retry.Retry(timeout=7)}, None, retry.Retry(timeout=7)],
+        [
+            "RequestOptions",
+            helper_types.RequestOptions(timeout=7, retry=retry.Retry(multiplier=3)),
+            7,
+            retry.Retry(multiplier=3),
+        ],
+    )
+    def test_get_model(self, request_options, expected_timeout, expected_retry):
+        self.responses = {"get_model": glm.Model(name="models/fake-bison-001")}
+
+        _ = models.get_model("models/fake-bison-001", request_options=request_options)
+
+        self.assertEqual(self.observed_timeout[0], expected_timeout)
+        self.assertEqual(str(self.observed_retry[0]), str(expected_retry))
diff --git a/tests/test_models.py b/tests/test_models.py
index e971ef86d..f39ed3a2c 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -31,6 +31,7 @@
 
 from google.generativeai import models
 from google.generativeai import client
 from google.generativeai.types import model_types
+from google.generativeai import types as genai_types
 import pandas as pd
 
@@ -470,7 +471,7 @@ def test_get_model_called_with_request_options(self):
     def test_get_tuned_model_called_with_request_options(self):
         self.client.get_tuned_model = unittest.mock.MagicMock()
         name = unittest.mock.ANY
-        request_options = {"timeout": 120}
+        request_options = genai_types.RequestOptions(timeout=120)
 
         try:
             models.get_model(name="tunedModels/", request_options=request_options)
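Reviewer note: every `request_options` parameter in this diff is retyped to `helper_types.RequestOptionsType`, a union of a plain dict and the new `RequestOptions` class. Because `RequestOptions` subclasses `collections.abc.Mapping`, the existing `client.method(request, **request_options)` call sites keep working with either form. Below is a minimal sketch of that behavior; `fake_rpc` is a hypothetical stand-in for a `glm` service method, and the example assumes this branch plus `google-api-core` are installed:

```python
from google.api_core import retry as retry_lib

from google.generativeai.types import helper_types


def fake_rpc(request, *, timeout=None, retry=None):
    # Hypothetical stand-in for a glm client method: the real methods
    # accept `timeout` and `retry` keyword arguments, which is what
    # `**request_options` relies on at the call sites in this diff.
    return timeout, retry


opts = helper_types.RequestOptions(
    timeout=600,
    retry=retry_lib.Retry(initial=10, multiplier=2, maximum=60, timeout=300),
)

# The Mapping interface makes RequestOptions and a plain dict interchangeable.
assert dict(opts)["timeout"] == 600
assert fake_rpc("req", **opts)[0] == 600
assert fake_rpc("req", **{"timeout": 600}) == (600, None)
```

The `dataclass(init=False)` plus hand-written `__init__` keeps keyword-only construction while the `Mapping` methods (`__getitem__`, `__iter__`, `__len__`) provide the `**`-unpacking; `tests/test_helpers.py` exercises exactly this equivalence between the dict and `RequestOptions` forms.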