diff --git a/google/generativeai/caching.py b/google/generativeai/caching.py index a28a50256..9acad4726 100644 --- a/google/generativeai/caching.py +++ b/google/generativeai/caching.py @@ -14,90 +14,130 @@ # limitations under the License. from __future__ import annotations -import dataclasses import datetime -from typing import Any, Iterable, Optional +import textwrap +from typing import Iterable, Optional from google.generativeai import protos -from google.generativeai.types.model_types import idecode_time from google.generativeai.types import caching_types from google.generativeai.types import content_types -from google.generativeai.utils import flatten_update_paths from google.generativeai.client import get_default_cache_client from google.protobuf import field_mask_pb2 -import google.ai.generativelanguage as glm + +_USER_ROLE = "user" +_MODEL_ROLE = "model" -@dataclasses.dataclass class CachedContent: """Cached content resource.""" - name: str - model: str - create_time: datetime.datetime - update_time: datetime.datetime - expire_time: datetime.datetime + def __init__(self, name): + """Fetches a `CachedContent` resource. - # NOTE: Automatic CachedContent deletion using contextmanager is not P0(P1+). - # Adding basic support for now. - def __enter__(self): - return self + Identical to `CachedContent.get`. - def __exit__(self, exc_type, exc_value, exc_tb): - self.delete() - - def _to_dict(self) -> protos.CachedContent: - proto_paths = { - "name": self.name, - "model": self.model, - } - return protos.CachedContent(**proto_paths) - - def _apply_update(self, path, value): - parts = path.split(".") - for part in parts[:-1]: - self = getattr(self, part) - if parts[-1] == "ttl": - value = self.expire_time + datetime.timedelta(seconds=value["seconds"]) - parts[-1] = "expire_time" - setattr(self, parts[-1], value) + Args: + name: The resource name referring to the cached content. + """ + client = get_default_cache_client() - @classmethod - def _decode_cached_content(cls, cached_content: protos.CachedContent) -> CachedContent: - # not supposed to get INPUT_ONLY repeated fields, but local gapic lib build - # is returning these, hence setting including_default_value_fields to False - cached_content = type(cached_content).to_dict( - cached_content, including_default_value_fields=False + if "cachedContents/" not in name: + name = "cachedContents/" + name + + request = protos.GetCachedContentRequest(name=name) + response = client.get_cached_content(request) + self._proto = response + + @property + def name(self) -> str: + return self._proto.name + + @property + def model(self) -> str: + return self._proto.model + + @property + def display_name(self) -> str: + return self._proto.display_name + + @property + def usage_metadata(self) -> protos.CachedContent.UsageMetadata: + return self._proto.usage_metadata + + @property + def create_time(self) -> datetime.datetime: + return self._proto.create_time + + @property + def update_time(self) -> datetime.datetime: + return self._proto.update_time + + @property + def expire_time(self) -> datetime.datetime: + return self._proto.expire_time + + def __str__(self): + return textwrap.dedent( + f"""\ + CachedContent( + name='{self.name}', + model='{self.model}', + display_name='{self.display_name}', + usage_metadata={'{'} + 'total_token_count': {self.usage_metadata.total_token_count}, + {'}'}, + create_time={self.create_time}, + update_time={self.update_time}, + expire_time={self.expire_time} + )""" ) - idecode_time(cached_content, "create_time") - idecode_time(cached_content, "update_time") - # always decode `expire_time` as Timestamp is returned - # regardless of what was sent on input - idecode_time(cached_content, "expire_time") - return cls(**cached_content) + __repr__ = __str__ + + @classmethod + def _from_obj(cls, obj: CachedContent | protos.CachedContent | dict) -> CachedContent: + """Creates an instance of CachedContent form an object, without calling `get`.""" + self = cls.__new__(cls) + self._proto = protos.CachedContent() + self._update(obj) + return self + + def _update(self, updates): + """Updates this instance inplace, does not call the API's `update` method""" + if isinstance(updates, CachedContent): + updates = updates._proto + + if not isinstance(updates, dict): + updates = type(updates).to_dict(updates, including_default_value_fields=False) + + for key, value in updates.items(): + setattr(self._proto, key, value) @staticmethod def _prepare_create_request( model: str, - name: str | None = None, + *, + display_name: str | None = None, system_instruction: Optional[content_types.ContentType] = None, contents: Optional[content_types.ContentsType] = None, tools: Optional[content_types.FunctionLibraryType] = None, tool_config: Optional[content_types.ToolConfigType] = None, - ttl: Optional[caching_types.ExpirationTypes] = datetime.timedelta(hours=1), + ttl: Optional[caching_types.TTLTypes] = None, + expire_time: Optional[caching_types.ExpireTimeTypes] = None, ) -> protos.CreateCachedContentRequest: """Prepares a CreateCachedContentRequest.""" - if name is not None: - if not caching_types.valid_cached_content_name(name): - raise ValueError(caching_types.NAME_ERROR_MESSAGE.format(name=name)) - - name = "cachedContents/" + name + if ttl and expire_time: + raise ValueError( + "Exclusive arguments: Please provide either `ttl` or `expire_time`, not both." + ) if "/" not in model: model = "models/" + model + if display_name and len(display_name) > 128: + raise ValueError("`display_name` must be no more than 128 unicode characters.") + if system_instruction: system_instruction = content_types.to_content(system_instruction) @@ -110,18 +150,21 @@ def _prepare_create_request( if contents: contents = content_types.to_contents(contents) + if not contents[-1].role: + contents[-1].role = _USER_ROLE - if ttl: - ttl = caching_types.to_ttl(ttl) + ttl = caching_types.to_optional_ttl(ttl) + expire_time = caching_types.to_optional_expire_time(expire_time) cached_content = protos.CachedContent( - name=name, model=model, + display_name=display_name, system_instruction=system_instruction, contents=contents, tools=tools_lib, tool_config=tool_config, ttl=ttl, + expire_time=expire_time, ) return protos.CreateCachedContentRequest(cached_content=cached_content) @@ -130,13 +173,14 @@ def _prepare_create_request( def create( cls, model: str, - name: str | None = None, + *, + display_name: str | None = None, system_instruction: Optional[content_types.ContentType] = None, contents: Optional[content_types.ContentsType] = None, tools: Optional[content_types.FunctionLibraryType] = None, tool_config: Optional[content_types.ToolConfigType] = None, - ttl: Optional[caching_types.ExpirationTypes] = datetime.timedelta(hours=1), - client: glm.CacheServiceClient | None = None, + ttl: Optional[caching_types.TTLTypes] = None, + expire_time: Optional[caching_types.ExpireTimeTypes] = None, ) -> CachedContent: """Creates `CachedContent` resource. @@ -144,34 +188,40 @@ def create( model: The name of the `model` to use for cached content creation. Any `CachedContent` resource can be only used with the `model` it was created for. - name: The resource name referring to the cached content. + display_name: The user-generated meaningful display name + of the cached content. `display_name` must be no + more than 128 unicode characters. system_instruction: Developer set system instruction. contents: Contents to cache. tools: A list of `Tools` the model may use to generate response. tool_config: Config to apply to all tools. ttl: TTL for cached resource (in seconds). Defaults to 1 hour. + `ttl` and `expire_time` are exclusive arguments. + expire_time: Expiration time for cached resource. + `ttl` and `expire_time` are exclusive arguments. Returns: `CachedContent` resource with specified name. """ - if client is None: - client = get_default_cache_client() + client = get_default_cache_client() request = cls._prepare_create_request( model=model, - name=name, + display_name=display_name, system_instruction=system_instruction, contents=contents, tools=tools, tool_config=tool_config, ttl=ttl, + expire_time=expire_time, ) response = client.create_cached_content(request) - return cls._decode_cached_content(response) + result = CachedContent._from_obj(response) + return result @classmethod - def get(cls, name: str, client: glm.CacheServiceClient | None = None) -> CachedContent: + def get(cls, name: str) -> CachedContent: """Fetches required `CachedContent` resource. Args: @@ -180,20 +230,18 @@ def get(cls, name: str, client: glm.CacheServiceClient | None = None) -> CachedC Returns: `CachedContent` resource with specified `name`. """ - if client is None: - client = get_default_cache_client() + client = get_default_cache_client() if "cachedContents/" not in name: name = "cachedContents/" + name request = protos.GetCachedContentRequest(name=name) response = client.get_cached_content(request) - return cls._decode_cached_content(response) + result = CachedContent._from_obj(response) + return result @classmethod - def list( - cls, page_size: Optional[int] = 1, client: glm.CacheServiceClient | None = None - ) -> Iterable[CachedContent]: + def list(cls, page_size: Optional[int] = 1) -> Iterable[CachedContent]: """Lists `CachedContent` objects associated with the project. Args: @@ -203,17 +251,16 @@ def list( Returns: A paginated list of `CachedContent` objects. """ - if client is None: - client = get_default_cache_client() + client = get_default_cache_client() request = protos.ListCachedContentsRequest(page_size=page_size) for cached_content in client.list_cached_contents(request): - yield cls._decode_cached_content(cached_content) + cached_content = CachedContent._from_obj(cached_content) + yield cached_content - def delete(self, client: glm.CachedServiceClient | None = None) -> None: + def delete(self) -> None: """Deletes `CachedContent` resource.""" - if client is None: - client = get_default_cache_client() + client = get_default_cache_client() request = protos.DeleteCachedContentRequest(name=self.name) client.delete_cached_content(request) @@ -221,40 +268,47 @@ def delete(self, client: glm.CachedServiceClient | None = None) -> None: def update( self, - updates: dict[str, Any], - client: glm.CacheServiceClient | None = None, - ) -> CachedContent: + *, + ttl: Optional[caching_types.TTLTypes] = None, + expire_time: Optional[caching_types.ExpireTimeTypes] = None, + ) -> None: """Updates requested `CachedContent` resource. Args: - updates: The list of fields to update. Currently only - `ttl/expire_time` is supported as an update path. - - Returns: - `CachedContent` object with specified updates. + ttl: TTL for cached resource (in seconds). Defaults to 1 hour. + `ttl` and `expire_time` are exclusive arguments. + expire_time: Expiration time for cached resource. + `ttl` and `expire_time` are exclusive arguments. """ - if client is None: - client = get_default_cache_client() - - updates = flatten_update_paths(updates) - for update_path in updates: - if update_path == "ttl": - updates = updates.copy() - update_path_val = updates.get(update_path) - updates[update_path] = caching_types.to_ttl(update_path_val) - else: - raise ValueError( - f"As of now, only `ttl` can be updated for `CachedContent`. Got: `{update_path}` instead." - ) - field_mask = field_mask_pb2.FieldMask() + client = get_default_cache_client() - for path in updates.keys(): - field_mask.paths.append(path) - for path, value in updates.items(): - self._apply_update(path, value) + if ttl and expire_time: + raise ValueError( + "Exclusive arguments: Please provide either `ttl` or `expire_time`, not both." + ) - request = protos.UpdateCachedContentRequest( - cached_content=self._to_dict(), update_mask=field_mask + ttl = caching_types.to_optional_ttl(ttl) + expire_time = caching_types.to_optional_expire_time(expire_time) + + updates = protos.CachedContent( + name=self.name, + ttl=ttl, + expire_time=expire_time, ) - client.update_cached_content(request) - return self + + field_mask = field_mask_pb2.FieldMask() + + if ttl: + field_mask.paths.append("ttl") + elif expire_time: + field_mask.paths.append("expire_time") + else: + raise ValueError( + f"Bad update name: Only `ttl` or `expire_time` can be updated for `CachedContent`." + ) + + request = protos.UpdateCachedContentRequest(cached_content=updates, update_mask=field_mask) + updated_cc = client.update_cached_content(request) + self._update(updated_cc) + + return diff --git a/google/generativeai/generative_models.py b/google/generativeai/generative_models.py index 10744a948..e3387a64f 100644 --- a/google/generativeai/generative_models.py +++ b/google/generativeai/generative_models.py @@ -20,6 +20,9 @@ from google.generativeai.types import helper_types from google.generativeai.types import safety_types +_USER_ROLE = "user" +_MODEL_ROLE = "model" + class GenerativeModel: """ @@ -96,14 +99,9 @@ def __init__( self._client = None self._async_client = None - def __new__(cls, *args, **kwargs) -> GenerativeModel: - self = super().__new__(cls) - - if cached_instance := kwargs.pop("cached_content", None): - setattr(self, "_cached_content", cached_instance.name) - setattr(cls, "cached_content", property(fget=lambda self: self._cached_content)) - - return self + @property + def cached_content(self) -> str: + return getattr(self, "_cached_content", None) @property def model_name(self): @@ -123,7 +121,7 @@ def maybe_text(content): safety_settings={self._safety_settings}, tools={self._tools}, system_instruction={maybe_text(self._system_instruction)}, - cached_content={getattr(self, "cached_content", None)} + cached_content={self.cached_content} )""" ) @@ -139,13 +137,11 @@ def _prepare_request( tool_config: content_types.ToolConfigType | None, ) -> protos.GenerateContentRequest: """Creates a `protos.GenerateContentRequest` from raw inputs.""" - if hasattr(self, "cached_content") and any([self._system_instruction, tools, tool_config]): + if hasattr(self, "_cached_content") and any([self._system_instruction, tools, tool_config]): raise ValueError( "`tools`, `tool_config`, `system_instruction` cannot be set on a model instantinated with `cached_content` as its context." ) - cached_content = getattr(self, "cached_content", None) - tools_lib = self._get_tools_lib(tools) if tools_lib is not None: tools_lib = tools_lib.to_proto() @@ -174,7 +170,7 @@ def _prepare_request( tools=tools_lib, tool_config=tool_config, system_instruction=self._system_instruction, - cached_content=cached_content, + cached_content=self.cached_content, ) def _get_tools_lib( @@ -190,6 +186,7 @@ def _get_tools_lib( def from_cached_content( cls, cached_content: str, + *, generation_config: generation_types.GenerationConfigType | None = None, safety_settings: safety_types.SafetySettingOptions | None = None, ) -> GenerativeModel: ... @@ -199,6 +196,7 @@ def from_cached_content( def from_cached_content( cls, cached_content: caching.CachedContent, + *, generation_config: generation_types.GenerationConfigType | None = None, safety_settings: safety_types.SafetySettingOptions | None = None, ) -> GenerativeModel: ... @@ -207,6 +205,7 @@ def from_cached_content( def from_cached_content( cls, cached_content: str | caching.CachedContent, + *, generation_config: generation_types.GenerationConfigType | None = None, safety_settings: safety_types.SafetySettingOptions | None = None, ) -> GenerativeModel: @@ -214,6 +213,8 @@ def from_cached_content( Args: cached_content: context for the model. + generation_config: Overrides for the model's generation config. + safety_settings: Overrides for the model's safety settings. Returns: `GenerativeModel` object with `cached_content` as its context. @@ -221,17 +222,16 @@ def from_cached_content( if isinstance(cached_content, str): cached_content = caching.CachedContent.get(name=cached_content) - # call __new__ with the cached_content to set the model's context. This is done to avoid - # the exposing `cached_content` as a public attribute. - self = cls.__new__(cls, cached_content=cached_content) - # call __init__ to set the model's `generation_config`, `safety_settings`. # `model_name` will be the name of the model for which the `cached_content` was created. - self.__init__( + self = cls( model_name=cached_content.model, generation_config=generation_config, safety_settings=safety_settings, ) + + # set the model's context. + setattr(self, "_cached_content", cached_content.name) return self def generate_content( @@ -309,6 +309,10 @@ def generate_content( tools=tools, tool_config=tool_config, ) + + if request.contents and not request.contents[-1].role: + request.contents[-1].role = _USER_ROLE + if self._client is None: self._client = client.get_default_generative_client() @@ -359,6 +363,10 @@ async def generate_content_async( tools=tools, tool_config=tool_config, ) + + if request.contents and not request.contents[-1].role: + request.contents[-1].role = _USER_ROLE + if self._async_client is None: self._async_client = client.get_default_generative_async_client() @@ -489,9 +497,6 @@ class ChatSession: history: A chat history to initialize the object with. """ - _USER_ROLE = "user" - _MODEL_ROLE = "model" - def __init__( self, model: GenerativeModel, @@ -559,7 +564,7 @@ def send_message( content = content_types.to_content(content) if not content.role: - content.role = self._USER_ROLE + content.role = _USER_ROLE history = self.history[:] history.append(content) @@ -646,7 +651,7 @@ def _handle_afc( ) function_response_parts.append(fr) - send = protos.Content(role=self._USER_ROLE, parts=function_response_parts) + send = protos.Content(role=_USER_ROLE, parts=function_response_parts) history.append(send) response = self.model.generate_content( @@ -688,7 +693,7 @@ async def send_message_async( content = content_types.to_content(content) if not content.role: - content.role = self._USER_ROLE + content.role = _USER_ROLE history = self.history[:] history.append(content) @@ -753,7 +758,7 @@ async def _handle_afc_async( ) function_response_parts.append(fr) - send = protos.Content(role=self._USER_ROLE, parts=function_response_parts) + send = protos.Content(role=_USER_ROLE, parts=function_response_parts) history.append(send) response = await self.model.generate_content_async( @@ -820,7 +825,7 @@ def history(self) -> list[protos.Content]: sent = self._last_sent received = last.candidates[0].content if not received.role: - received.role = self._MODEL_ROLE + received.role = _MODEL_ROLE self._history.extend([sent, received]) self._last_sent = None diff --git a/google/generativeai/types/caching_types.py b/google/generativeai/types/caching_types.py index 8d55b70b2..4f1a6b8be 100644 --- a/google/generativeai/types/caching_types.py +++ b/google/generativeai/types/caching_types.py @@ -15,39 +15,69 @@ from __future__ import annotations import datetime -from typing import Optional, Union +from typing import Union from typing_extensions import TypedDict -import re -__all__ = ["TTL"] +__all__ = [ + "ExpireTime", + "TTL", + "TTLTypes", + "ExpireTimeTypes", +] -_VALID_CACHED_CONTENT_NAME = r"([a-z0-9-\.]+)$" -NAME_ERROR_MESSAGE = ( - "The `name` must consist of alphanumeric characters (or `-` or `.`). Received: `{name}`" -) +class TTL(TypedDict): + # Represents datetime.datetime.now() + desired ttl + seconds: int + nanos: int -def valid_cached_content_name(name: str) -> bool: - return re.match(_VALID_CACHED_CONTENT_NAME, name) is not None +class ExpireTime(TypedDict): + # Represents seconds of UTC time since Unix epoch + seconds: int + nanos: int -class TTL(TypedDict): - seconds: int +TTLTypes = Union[TTL, int, datetime.timedelta] +ExpireTimeTypes = Union[ExpireTime, int, datetime.datetime] -ExpirationTypes = Union[TTL, int, datetime.timedelta] +def to_optional_ttl(ttl: TTLTypes | None) -> TTL | None: + if ttl is None: + return None + elif isinstance(ttl, datetime.timedelta): + return { + "seconds": int(ttl.total_seconds()), + "nanos": int(ttl.microseconds * 1000), + } + elif isinstance(ttl, dict): + return ttl + elif isinstance(ttl, int): + return {"seconds": ttl, "nanos": 0} + else: + raise TypeError( + f"Could not convert input to `ttl` \n'" f" type: {type(ttl)}\n", + ttl, + ) -def to_ttl(expiration: Optional[ExpirationTypes]) -> TTL: - if isinstance(expiration, datetime.timedelta): - return {"seconds": int(expiration.total_seconds())} - elif isinstance(expiration, dict): - return expiration - elif isinstance(expiration, int): - return {"seconds": expiration} +def to_optional_expire_time(expire_time: ExpireTimeTypes | None) -> ExpireTime | None: + if expire_time is None: + return expire_time + elif isinstance(expire_time, datetime.datetime): + timestamp = expire_time.timestamp() + seconds = int(timestamp) + nanos = int((seconds % 1) * 1000) + return { + "seconds": seconds, + "nanos": nanos, + } + elif isinstance(expire_time, dict): + return expire_time + elif isinstance(expire_time, int): + return {"seconds": expire_time, "nanos": 0} else: raise TypeError( - f"Could not convert input to `expire_time` \n'" f" type: {type(expiration)}\n", - expiration, + f"Could not convert input to `expire_time` \n'" f" type: {type(expire_time)}\n", + expire_time, ) diff --git a/google/generativeai/version.py b/google/generativeai/version.py index 8018b67ac..69a8b817e 100644 --- a/google/generativeai/version.py +++ b/google/generativeai/version.py @@ -14,4 +14,4 @@ # limitations under the License. from __future__ import annotations -__version__ = "0.6.0" +__version__ = "0.7.0" diff --git a/tests/test_caching.py b/tests/test_caching.py index 47692325b..1d1b2608c 100644 --- a/tests/test_caching.py +++ b/tests/test_caching.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import datetime +import textwrap import unittest from google.generativeai import caching @@ -44,7 +45,9 @@ def create_cached_content( self.observed_requests.append(request) return protos.CachedContent( name="cachedContents/test-cached-content", - model="models/gemini-1.0-pro-001", + model="models/gemini-1.5-pro", + display_name="Cached content for test", + usage_metadata={"total_token_count": 1}, create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", expire_time="2000-01-01T01:01:01.123456Z", @@ -58,7 +61,9 @@ def get_cached_content( self.observed_requests.append(request) return protos.CachedContent( name="cachedContents/test-cached-content", - model="models/gemini-1.0-pro-001", + model="models/gemini-1.5-pro", + display_name="Cached content for test", + usage_metadata={"total_token_count": 1}, create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", expire_time="2000-01-01T01:01:01.123456Z", @@ -73,14 +78,18 @@ def list_cached_contents( return [ protos.CachedContent( name="cachedContents/test-cached-content-1", - model="models/gemini-1.0-pro-001", + model="models/gemini-1.5-pro", + display_name="Cached content for test", + usage_metadata={"total_token_count": 1}, create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", expire_time="2000-01-01T01:01:01.123456Z", ), protos.CachedContent( name="cachedContents/test-cached-content-2", - model="models/gemini-1.0-pro-001", + model="models/gemini-1.5-pro", + display_name="Cached content for test", + usage_metadata={"total_token_count": 1}, create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", expire_time="2000-01-01T01:01:01.123456Z", @@ -95,7 +104,9 @@ def update_cached_content( self.observed_requests.append(request) return protos.CachedContent( name="cachedContents/test-cached-content", - model="models/gemini-1.0-pro-001", + model="models/gemini-1.5-pro", + display_name="Cached content for test", + usage_metadata={"total_token_count": 1}, create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", expire_time="2000-01-01T03:01:01.123456Z", @@ -114,8 +125,7 @@ def add(a: int, b: int) -> int: return a + b cc = caching.CachedContent.create( - name="test-cached-content", - model="models/gemini-1.0-pro-001", + model="models/gemini-1.5-pro", contents=["Add 5 and 6"], tools=[add], tool_config={"function_calling_config": "ANY"}, @@ -125,7 +135,7 @@ def add(a: int, b: int) -> int: self.assertIsInstance(self.observed_requests[-1], protos.CreateCachedContentRequest) self.assertIsInstance(cc, caching.CachedContent) self.assertEqual(cc.name, "cachedContents/test-cached-content") - self.assertEqual(cc.model, "models/gemini-1.0-pro-001") + self.assertEqual(cc.model, "models/gemini-1.5-pro") @parameterized.named_parameters( [ @@ -147,10 +157,9 @@ def add(a: int, b: int) -> int: ), ] ) - def test_expiration_types_for_create_cached_content(self, ttl): + def test_ttl_types_for_create_cached_content(self, ttl): cc = caching.CachedContent.create( - name="test-cached-content", - model="models/gemini-1.0-pro-001", + model="models/gemini-1.5-pro", contents=["cache this please for 2 hours"], ttl=ttl, ) @@ -160,28 +169,39 @@ def test_expiration_types_for_create_cached_content(self, ttl): @parameterized.named_parameters( [ dict( - testcase_name="upper_case", - name="Test-cached-content", + testcase_name="expire_time-is-int-seconds", + expire_time=1717653421, ), dict( - testcase_name="special_characters_except_dot_and_hyphen", - name="test-cac*@/hed-conte#nt", + testcase_name="expire_time-is-datetime", + expire_time=datetime.datetime.now(), ), dict( - testcase_name="empty_name", - name="", + testcase_name="expire_time-is-dict", + expire_time={"seconds": 1717653421}, ), dict( - testcase_name="blank_spaces", - name="test cached content", + testcase_name="expire_time-is-none-default-to-1-hr", + expire_time=None, ), ] ) - def test_create_cached_content_with_invalid_name_format(self, name): + def test_expire_time_types_for_create_cached_content(self, expire_time): + cc = caching.CachedContent.create( + model="models/gemini-1.5-pro", + contents=["cache this please for 2 hours"], + expire_time=expire_time, + ) + self.assertIsInstance(self.observed_requests[-1], protos.CreateCachedContentRequest) + self.assertIsInstance(cc, caching.CachedContent) + + def test_mutual_exclusivity_for_ttl_and_expire_time_in_create_cached_content(self): with self.assertRaises(ValueError): _ = caching.CachedContent.create( - name=name, - model="models/gemini-1.0-pro-001", + model="models/gemini-1.5-pro", + contents=["cache this please for 2 hours"], + ttl=datetime.timedelta(hours=2), + expire_time=datetime.datetime.now(), ) def test_get_cached_content(self): @@ -189,7 +209,7 @@ def test_get_cached_content(self): self.assertIsInstance(self.observed_requests[-1], protos.GetCachedContentRequest) self.assertIsInstance(cc, caching.CachedContent) self.assertEqual(cc.name, "cachedContents/test-cached-content") - self.assertEqual(cc.model, "models/gemini-1.0-pro-001") + self.assertEqual(cc.model, "models/gemini-1.5-pro") def test_list_cached_contents(self): ccs = list(caching.CachedContent.list(page_size=2)) @@ -198,25 +218,27 @@ def test_list_cached_contents(self): self.assertIsInstance(ccs[0], caching.CachedContent) self.assertIsInstance(ccs[1], caching.CachedContent) - def test_update_cached_content_invalid_update_paths(self): - update_masks = dict( - name="change", - model="models/gemini-1.5-pro-001", - system_instruction="Always add 10 to the result.", - contents=["add this Content"], - ) + def test_update_cached_content_ttl_and_expire_time_are_mutually_exclusive(self): + ttl = datetime.timedelta(hours=2) + expire_time = datetime.datetime.now() cc = caching.CachedContent.get(name="cachedContents/test-cached-content") with self.assertRaises(ValueError): - cc.update(updates=update_masks) + cc.update(ttl=ttl, expire_time=expire_time) - def test_update_cached_content_valid_update_paths(self): - update_masks = dict( - ttl=datetime.timedelta(hours=2), - ) + @parameterized.named_parameters( + [ + dict(testcase_name="ttl", ttl=datetime.timedelta(hours=2)), + dict( + testcase_name="expire_time", + expire_time=datetime.datetime(2024, 6, 5, 12, 12, 12, 23), + ), + ] + ) + def test_update_cached_content_valid_update_paths(self, ttl=None, expire_time=None): cc = caching.CachedContent.get(name="cachedContents/test-cached-content") - cc = cc.update(updates=update_masks) + cc.update(ttl=ttl, expire_time=expire_time) self.assertIsInstance(self.observed_requests[-1], protos.UpdateCachedContentRequest) self.assertIsInstance(cc, caching.CachedContent) @@ -229,17 +251,23 @@ def test_delete_cached_content(self): cc.delete() self.assertIsInstance(self.observed_requests[-1], protos.DeleteCachedContentRequest) - def test_auto_delete_cached_content_with_context_manager(self): - with caching.CachedContent.create( - name="test-cached-content", - model="models/gemini-1.0-pro-001", - contents=["Add 5 and 6"], - system_instruction="Always add 10 to the result.", - ttl=datetime.timedelta(minutes=30), - ) as cc: - ... # some logic - - self.assertIsInstance(self.observed_requests[-1], protos.DeleteCachedContentRequest) + def test_repr_cached_content(self): + expexted_repr = textwrap.dedent( + """\ + CachedContent( + name='cachedContents/test-cached-content', + model='models/gemini-1.5-pro', + display_name='Cached content for test', + usage_metadata={ + 'total_token_count': 1, + }, + create_time=2000-01-01 01:01:01.123456+00:00, + update_time=2000-01-01 01:01:01.123456+00:00, + expire_time=2000-01-01 01:01:01.123456+00:00 + )""" + ) + cc = caching.CachedContent.get(name="cachedContents/test-cached-content") + self.assertEqual(repr(cc), expexted_repr) if __name__ == "__main__": diff --git a/tests/test_generative_models.py b/tests/test_generative_models.py index c4d46ffec..0a824647f 100644 --- a/tests/test_generative_models.py +++ b/tests/test_generative_models.py @@ -3,9 +3,7 @@ import copy import datetime import pathlib -from typing import Any import textwrap -import unittest.mock from absl.testing import absltest from absl.testing import parameterized from google.generativeai import protos @@ -86,7 +84,9 @@ def get_cached_content( self.observed_requests.append(request) return protos.CachedContent( name="cachedContents/test-cached-content", - model="models/gemini-1.0-pro-001", + model="models/gemini-1.5-pro", + display_name="Cached content for test", + usage_metadata={"total_token_count": 1}, create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", expire_time="2000-01-01T01:01:01.123456Z", @@ -338,12 +338,16 @@ def test_stream_prompt_feedback_not_blocked(self): dict(testcase_name="test_cached_content_as_id", cached_content="test-cached-content"), dict( testcase_name="test_cached_content_as_CachedContent_object", - cached_content=caching.CachedContent( - name="cachedContents/test-cached-content", - model="models/gemini-1.0-pro-001", - create_time=datetime.datetime.now(), - update_time=datetime.datetime.now(), - expire_time=datetime.datetime.now(), + cached_content=caching.CachedContent._from_obj( + dict( + name="cachedContents/test-cached-content", + model="models/gemini-1.5-pro", + display_name="Cached content for test", + usage_metadata={"total_token_count": 1}, + create_time=datetime.datetime.now(), + update_time=datetime.datetime.now(), + expire_time=datetime.datetime.now(), + ) ), ), ], @@ -353,7 +357,7 @@ def test_model_with_cached_content_as_context(self, cached_content): cc_name = model.cached_content # pytype: disable=attribute-error model_name = model.model_name self.assertEqual(cc_name, "cachedContents/test-cached-content") - self.assertEqual(model_name, "models/gemini-1.0-pro-001") + self.assertEqual(model_name, "models/gemini-1.5-pro") self.assertEqual( model.cached_content, # pytype: disable=attribute-error "cachedContents/test-cached-content", @@ -801,9 +805,9 @@ def test_tool_config(self, tool_config, expected_tool_config): [ "part_dict", {"parts": [{"text": "talk like a pirate"}]}, - simple_part("talk like a pirate"), + protos.Content(parts=[{"text": "talk like a pirate"}]), ], - ["part_list", ["talk like:", "a pirate"], iter_part(["talk like:", "a pirate"])], + ["part_list", ["talk like", "a pirate"], iter_part(["talk like", "a pirate"])], ) def test_system_instruction(self, instruction, expected_instr): self.responses["generate_content"] = [simple_response("echo echo")] @@ -1301,7 +1305,7 @@ def test_repr_for_model_created_from_cahced_content(self): ) result = repr(model) self.assertIn("cached_content=cachedContents/test-cached-content", result) - self.assertIn("model_name='models/gemini-1.0-pro-001'", result) + self.assertIn("model_name='models/gemini-1.5-pro'", result) def test_count_tokens_called_with_request_options(self): self.responses["count_tokens"].append(protos.CountTokensResponse(total_tokens=7))