Skip to content

Improve request_options #297

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions google/generativeai/answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,10 @@
get_default_generative_client,
get_default_generative_async_client,
)
from google.generativeai import string_utils
from google.generativeai.types import model_types
from google.generativeai import models
from google.generativeai.types import helper_types
from google.generativeai.types import safety_types
from google.generativeai.types import content_types
from google.generativeai.types import answer_types
from google.generativeai.types import retriever_types
from google.generativeai.types.retriever_types import MetadataFilter

Expand Down Expand Up @@ -247,7 +245,7 @@ def generate_answer(
safety_settings: safety_types.SafetySettingOptions | None = None,
temperature: float | None = None,
client: glm.GenerativeServiceClient | None = None,
request_options: dict[str, Any] | None = None,
request_options: helper_types.RequestOptionsType | None = None,
):
"""
Calls the GenerateAnswer API and returns a `types.Answer` containing the response.
Expand Down Expand Up @@ -320,7 +318,7 @@ async def generate_answer_async(
safety_settings: safety_types.SafetySettingOptions | None = None,
temperature: float | None = None,
client: glm.GenerativeServiceClient | None = None,
request_options: dict[str, Any] | None = None,
request_options: helper_types.RequestOptionsType | None = None,
):
"""
Calls the API and returns a `types.Answer` containing the answer.
Expand Down
13 changes: 7 additions & 6 deletions google/generativeai/discuss.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from google.generativeai.client import get_default_discuss_async_client
from google.generativeai import string_utils
from google.generativeai.types import discuss_types
from google.generativeai.types import helper_types
from google.generativeai.types import model_types
from google.generativeai.types import safety_types

Expand Down Expand Up @@ -316,7 +317,7 @@ def chat(
top_k: float | None = None,
prompt: discuss_types.MessagePromptOptions | None = None,
client: glm.DiscussServiceClient | None = None,
request_options: dict[str, Any] | None = None,
request_options: helper_types.RequestOptionsType | None = None,
) -> discuss_types.ChatResponse:
"""Calls the API and returns a `types.ChatResponse` containing the response.

Expand Down Expand Up @@ -416,7 +417,7 @@ async def chat_async(
top_k: float | None = None,
prompt: discuss_types.MessagePromptOptions | None = None,
client: glm.DiscussServiceAsyncClient | None = None,
request_options: dict[str, Any] | None = None,
request_options: helper_types.RequestOptionsType | None = None,
) -> discuss_types.ChatResponse:
request = _make_generate_message_request(
model=model,
Expand Down Expand Up @@ -469,7 +470,7 @@ def last(self, message: discuss_types.MessageOptions):
def reply(
self,
message: discuss_types.MessageOptions,
request_options: dict[str, Any] | None = None,
request_options: helper_types.RequestOptionsType | None = None,
) -> discuss_types.ChatResponse:
if isinstance(self._client, glm.DiscussServiceAsyncClient):
raise TypeError(f"reply can't be called on an async client, use reply_async instead.")
Expand Down Expand Up @@ -537,7 +538,7 @@ def _build_chat_response(
def _generate_response(
request: glm.GenerateMessageRequest,
client: glm.DiscussServiceClient | None = None,
request_options: dict[str, Any] | None = None,
request_options: helper_types.RequestOptionsType | None = None,
) -> ChatResponse:
if request_options is None:
request_options = {}
Expand All @@ -553,7 +554,7 @@ def _generate_response(
async def _generate_response_async(
request: glm.GenerateMessageRequest,
client: glm.DiscussServiceAsyncClient | None = None,
request_options: dict[str, Any] | None = None,
request_options: helper_types.RequestOptionsType | None = None,
) -> ChatResponse:
if request_options is None:
request_options = {}
Expand All @@ -574,7 +575,7 @@ def count_message_tokens(
messages: discuss_types.MessagesOptions | None = None,
model: model_types.AnyModelNameOptions = DEFAULT_DISCUSS_MODEL,
client: glm.DiscussServiceAsyncClient | None = None,
request_options: dict[str, Any] | None = None,
request_options: helper_types.RequestOptionsType | None = None,
) -> discuss_types.TokenCount:
model = model_types.make_model_name(model)
prompt = _make_message_prompt(prompt, context=context, examples=examples, messages=messages)
Expand Down
16 changes: 7 additions & 9 deletions google/generativeai/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
# limitations under the License.
from __future__ import annotations

import dataclasses
from collections.abc import Iterable, Sequence, Mapping
import itertools
from typing import Any, Iterable, overload, TypeVar, Union, Mapping

Expand All @@ -24,7 +22,7 @@
from google.generativeai.client import get_default_generative_client
from google.generativeai.client import get_default_generative_async_client

from google.generativeai import string_utils
from google.generativeai.types import helper_types
from google.generativeai.types import text_types
from google.generativeai.types import model_types
from google.generativeai.types import content_types
Expand Down Expand Up @@ -104,7 +102,7 @@ def embed_content(
title: str | None = None,
output_dimensionality: int | None = None,
client: glm.GenerativeServiceClient | None = None,
request_options: dict[str, Any] | None = None,
request_options: helper_types.RequestOptionsType | None = None,
) -> text_types.EmbeddingDict: ...


Expand All @@ -116,7 +114,7 @@ def embed_content(
title: str | None = None,
output_dimensionality: int | None = None,
client: glm.GenerativeServiceClient | None = None,
request_options: dict[str, Any] | None = None,
request_options: helper_types.RequestOptionsType | None = None,
) -> text_types.BatchEmbeddingDict: ...


Expand All @@ -127,7 +125,7 @@ def embed_content(
title: str | None = None,
output_dimensionality: int | None = None,
client: glm.GenerativeServiceClient = None,
request_options: dict[str, Any] | None = None,
request_options: helper_types.RequestOptionsType | None = None,
) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict:
"""Calls the API to create embeddings for content passed in.
Expand Down Expand Up @@ -224,7 +222,7 @@ async def embed_content_async(
title: str | None = None,
output_dimensionality: int | None = None,
client: glm.GenerativeServiceAsyncClient | None = None,
request_options: dict[str, Any] | None = None,
request_options: helper_types.RequestOptionsType | None = None,
) -> text_types.EmbeddingDict: ...


Expand All @@ -236,7 +234,7 @@ async def embed_content_async(
title: str | None = None,
output_dimensionality: int | None = None,
client: glm.GenerativeServiceAsyncClient | None = None,
request_options: dict[str, Any] | None = None,
request_options: helper_types.RequestOptionsType | None = None,
) -> text_types.BatchEmbeddingDict: ...


Expand All @@ -247,7 +245,7 @@ async def embed_content_async(
title: str | None = None,
output_dimensionality: int | None = None,
client: glm.GenerativeServiceAsyncClient = None,
request_options: dict[str, Any] | None = None,
request_options: helper_types.RequestOptionsType | None = None,
) -> text_types.EmbeddingDict | text_types.BatchEmbeddingDict:
"""The async version of `genai.embed_content`."""
model = model_types.make_model_name(model)
Expand Down
10 changes: 5 additions & 5 deletions google/generativeai/generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
import google.api_core.exceptions
from google.ai import generativelanguage as glm
from google.generativeai import client
from google.generativeai import string_utils
from google.generativeai.types import content_types
from google.generativeai.types import generation_types
from google.generativeai.types import helper_types
from google.generativeai.types import safety_types


Expand Down Expand Up @@ -181,7 +181,7 @@ def generate_content(
stream: bool = False,
tools: content_types.FunctionLibraryType | None = None,
tool_config: content_types.ToolConfigType | None = None,
request_options: dict[str, Any] | None = None,
request_options: helper_types.RequestOptionsType | None = None,
) -> generation_types.GenerateContentResponse:
"""A multipurpose function to generate responses from the model.

Expand Down Expand Up @@ -281,7 +281,7 @@ async def generate_content_async(
stream: bool = False,
tools: content_types.FunctionLibraryType | None = None,
tool_config: content_types.ToolConfigType | None = None,
request_options: dict[str, Any] | None = None,
request_options: helper_types.RequestOptionsType | None = None,
) -> generation_types.AsyncGenerateContentResponse:
"""The async version of `GenerativeModel.generate_content`."""
request = self._prepare_request(
Expand Down Expand Up @@ -328,7 +328,7 @@ def count_tokens(
safety_settings: safety_types.SafetySettingOptions | None = None,
tools: content_types.FunctionLibraryType | None = None,
tool_config: content_types.ToolConfigType | None = None,
request_options: dict[str, Any] | None = None,
request_options: helper_types.RequestOptionsType | None = None,
) -> glm.CountTokensResponse:
if request_options is None:
request_options = {}
Expand All @@ -355,7 +355,7 @@ async def count_tokens_async(
safety_settings: safety_types.SafetySettingOptions | None = None,
tools: content_types.FunctionLibraryType | None = None,
tool_config: content_types.ToolConfigType | None = None,
request_options: dict[str, Any] | None = None,
request_options: helper_types.RequestOptionsType | None = None,
) -> glm.CountTokensResponse:
if request_options is None:
request_options = {}
Expand Down
23 changes: 13 additions & 10 deletions google/generativeai/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from google.generativeai import operations
from google.generativeai.client import get_default_model_client
from google.generativeai.types import model_types
from google.generativeai.types import helper_types
from google.api_core import operation
from google.api_core import protobuf_helpers
from google.protobuf import field_mask_pb2
Expand All @@ -31,7 +32,7 @@ def get_model(
name: model_types.AnyModelNameOptions,
*,
client=None,
request_options: dict[str, Any] | None = None,
request_options: helper_types.RequestOptionsType | None = None,
) -> model_types.Model | model_types.TunedModel:
"""Given a model name, fetch the `types.Model` or `types.TunedModel` object.

Expand Down Expand Up @@ -62,7 +63,7 @@ def get_base_model(
name: model_types.BaseModelNameOptions,
*,
client=None,
request_options: dict[str, Any] | None = None,
request_options: helper_types.RequestOptionsType | None = None,
) -> model_types.Model:
"""Get the `types.Model` for the given base model name.

Expand Down Expand Up @@ -99,7 +100,7 @@ def get_tuned_model(
name: model_types.TunedModelNameOptions,
*,
client=None,
request_options: dict[str, Any] | None = None,
request_options: helper_types.RequestOptionsType | None = None,
) -> model_types.TunedModel:
"""Get the `types.TunedModel` for the given tuned model name.

Expand Down Expand Up @@ -162,7 +163,7 @@ def list_models(
*,
page_size: int | None = 50,
client: glm.ModelServiceClient | None = None,
request_options: dict[str, Any] | None = None,
request_options: helper_types.RequestOptionsType | None = None,
) -> model_types.ModelsIterable:
"""Lists available models.

Expand Down Expand Up @@ -196,7 +197,7 @@ def list_tuned_models(
*,
page_size: int | None = 50,
client: glm.ModelServiceClient | None = None,
request_options: dict[str, Any] | None = None,
request_options: helper_types.RequestOptionsType | None = None,
) -> model_types.TunedModelsIterable:
"""Lists available models.

Expand Down Expand Up @@ -244,7 +245,7 @@ def create_tuned_model(
input_key: str = "text_input",
output_key: str = "output",
client: glm.ModelServiceClient | None = None,
request_options: dict[str, Any] | None = None,
request_options: helper_types.RequestOptionsType | None = None,
) -> operations.CreateTunedModelOperation:
"""Launches a tuning job to create a TunedModel.

Expand Down Expand Up @@ -344,6 +345,7 @@ def create_tuned_model(
top_k=top_k,
tuning_task=tuning_task,
)

operation = client.create_tuned_model(
dict(tuned_model_id=id, tuned_model=tuned_model), **request_options
)
Expand All @@ -357,7 +359,7 @@ def update_tuned_model(
updates: None = None,
*,
client: glm.ModelServiceClient | None = None,
request_options: dict[str, Any] | None = None,
request_options: helper_types.RequestOptionsType | None = None,
) -> model_types.TunedModel:
pass

Expand All @@ -368,7 +370,7 @@ def update_tuned_model(
updates: dict[str, Any],
*,
client: glm.ModelServiceClient | None = None,
request_options: dict[str, Any] | None = None,
request_options: helper_types.RequestOptionsType | None = None,
) -> model_types.TunedModel:
pass

Expand All @@ -378,7 +380,7 @@ def update_tuned_model(
updates: dict[str, Any] | None = None,
*,
client: glm.ModelServiceClient | None = None,
request_options: dict[str, Any] | None = None,
request_options: helper_types.RequestOptionsType | None = None,
) -> model_types.TunedModel:
"""Push updates to the tuned model. Only certain attributes are updatable."""
if request_options is None:
Expand All @@ -395,6 +397,7 @@ def update_tuned_model(
"`updates` must be a `dict`.\n"
f"got: {type(updates)}"
)

tuned_model = client.get_tuned_model(name=name, **request_options)

updates = flatten_update_paths(updates)
Expand Down Expand Up @@ -436,7 +439,7 @@ def _apply_update(thing, path, value):
def delete_tuned_model(
tuned_model: model_types.TunedModelNameOptions,
client: glm.ModelServiceClient | None = None,
request_options: dict[str, Any] | None = None,
request_options: helper_types.RequestOptionsType | None = None,
) -> None:
if request_options is None:
request_options = {}
Expand Down
Loading