-
- As of January 1, 2020 this library no longer supports Python 2 on the latest released version.
+
+ As of January 1, 2020 this library no longer supports Python 2 on the latest released version. Library versions released prior to that date will continue to be available. For more information please visit Python 2 support on Google Cloud.
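The `google/cloud/aiplatform/models.py` hunks below add `dedicated_endpoint_dns` and `dedicated_endpoint_enabled` properties to `Endpoint` and make `predict()` route requests to the dedicated DNS on its own, matching the updated docstring example that drops `use_dedicated_endpoint=True`. A minimal sketch of the resulting call pattern, assuming an already-deployed endpoint; the project, region, and endpoint ID here are placeholders, not values from this change:

```python
from google.cloud import aiplatform

# Placeholder project, region, and endpoint ID -- substitute your own resources.
aiplatform.init(project="my-project", location="us-central1")
endpoint = aiplatform.Endpoint("1234567890")

# After this change, predict() checks endpoint.dedicated_endpoint_enabled itself
# and posts to endpoint.dedicated_endpoint_dns when it is set, so callers no
# longer need to pass a use_dedicated_endpoint flag.
response = endpoint.predict(instances=[{"feat_1": 1.0, "feat_2": 2.0}])
print(response.predictions)
```

The same properties back `raw_predict` and `stream_raw_predict`, which now build the dedicated URL from `dedicated_endpoint_dns` instead of the removed cached `raw_predict_request_url` / `stream_raw_predict_request_url` attributes.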
diff --git a/docs/summary_overview.md b/docs/summary_overview.md index 7d0ca2569b..38025e91ac 100644 --- a/docs/summary_overview.md +++ b/docs/summary_overview.md @@ -1,14 +1,22 @@ -# Vertex AI API +[ +This is a templated file. Adding content to this file may result in it being +reverted. Instead, if you want to place additional content, create an +"overview_content.md" file in `docs/` directory. The Sphinx tool will +pick up on the content and merge the content. +]: # -Overview of the APIs available for Vertex AI API. +# AI Platform API + +Overview of the APIs available for AI Platform API. ## All entries -Classes, methods and properties & attributes for Vertex AI API. +Classes, methods and properties & attributes for +AI Platform API. [classes](https://cloud.google.com/python/docs/reference/aiplatform/latest/summary_class.html) [methods](https://cloud.google.com/python/docs/reference/aiplatform/latest/summary_method.html) [properties and -attributes](https://cloud.google.com/python/docs/reference/aiplatform/latest/summary_property.html) \ No newline at end of file +attributes](https://cloud.google.com/python/docs/reference/aiplatform/latest/summary_property.html) diff --git a/google/cloud/aiplatform/gapic_version.py b/google/cloud/aiplatform/gapic_version.py index dbb6c23a4a..4cfdd5d77d 100644 --- a/google/cloud/aiplatform/gapic_version.py +++ b/google/cloud/aiplatform/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.91.0" # {x-release-please-version} +__version__ = "1.92.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py index fd079a2cd0..0e2910acc9 100644 --- a/google/cloud/aiplatform/models.py +++ b/google/cloud/aiplatform/models.py @@ -694,8 +694,6 @@ def __init__( self._gca_resource = gca_endpoint_compat.Endpoint(name=endpoint_name) self.authorized_session = None - self.raw_predict_request_url = None - self.stream_raw_predict_request_url = None @property def _prediction_client(self) -> utils.PredictionClientWithOverride: @@ -783,7 +781,30 @@ def private_service_connect_config( ) -> Optional[gca_service_networking.PrivateServiceConnectConfig]: """The Private Service Connect configuration for this Endpoint.""" self._assert_gca_resource_is_available() - return self._gca_resource.private_service_connect_config + return getattr(self._gca_resource, "private_service_connect_config", None) + + @property + def dedicated_endpoint_dns(self) -> Optional[str]: + """The dedicated endpoint dns for this Endpoint. + + This property is only available if dedicated endpoint is enabled. + If dedicated endpoint is not enabled, this property returns None. + """ + if re.match(r"^projects/.*/endpoints/.*$", self._gca_resource.name): + self._assert_gca_resource_is_available() + return getattr(self._gca_resource, "dedicated_endpoint_dns", None) + return None + + @property + def dedicated_endpoint_enabled(self) -> bool: + """The dedicated endpoint is enabled for this Endpoint. + + This property will be true if dedicated endpoint is enabled. 
+ """ + if re.match(r"^projects/.*/endpoints/.*$", self._gca_resource.name): + self._assert_gca_resource_is_available() + return getattr(self._gca_resource, "dedicated_endpoint_enabled", False) + return False @classmethod def create( @@ -1113,8 +1134,6 @@ def _construct_sdk_resource_from_gapic( credentials=credentials, ) endpoint.authorized_session = None - endpoint.raw_predict_request_url = None - endpoint.stream_raw_predict_request_url = None return endpoint @@ -2338,10 +2357,9 @@ def predict( ) -> Prediction: """Make a prediction against this Endpoint. - For dedicated endpoint, set use_dedicated_endpoint = True: + Example usage: ``` - response = my_endpoint.predict(instances=[...], - use_dedicated_endpoint=True) + response = my_endpoint.predict(instances=[...]) my_predictions = response.predictions ``` @@ -2379,11 +2397,19 @@ def predict( Raises: ImportError: If there is an issue importing the `TCPKeepAliveAdapter` package. + ValueError: If the dedicated endpoint DNS is empty for dedicated endpoints. + ValueError: If the prediction request fails for dedicated endpoints. """ self.wait() + + if parameters is not None: + data = json.dumps({"instances": instances, "parameters": parameters}) + else: + data = json.dumps({"instances": instances}) + if use_raw_predict: raw_predict_response = self.raw_predict( - body=json.dumps({"instances": instances, "parameters": parameters}), + body=data, headers={"Content-Type": "application/json"}, use_dedicated_endpoint=use_dedicated_endpoint, timeout=timeout, @@ -2403,63 +2429,7 @@ def predict( ), ) - if use_dedicated_endpoint: - self._sync_gca_resource_if_skipped() - if ( - not self._gca_resource.dedicated_endpoint_enabled - or self._gca_resource.dedicated_endpoint_dns is None - ): - raise ValueError( - "Dedicated endpoint is not enabled or DNS is empty." - "Please make sure endpoint has dedicated endpoint enabled" - "and model are ready before making a prediction." - ) - try: - from requests_toolbelt.adapters.socket_options import ( - TCPKeepAliveAdapter, - ) - except ImportError: - raise ImportError( - "Cannot import the requests-toolbelt library. Please install requests-toolbelt." 
- ) - - if not self.authorized_session: - self.credentials._scopes = constants.base.DEFAULT_AUTHED_SCOPES - self.authorized_session = google_auth_requests.AuthorizedSession( - self.credentials - ) - - headers = { - "Content-Type": "application/json", - } - - url = f"https://{self._gca_resource.dedicated_endpoint_dns}/v1/{self.resource_name}:predict" - # count * interval need to be larger than 1 hr (3600s) - keep_alive = TCPKeepAliveAdapter(idle=120, count=100, interval=100) - self.authorized_session.mount("https://", keep_alive) - response = self.authorized_session.post( - url=url, - data=json.dumps( - { - "instances": instances, - "parameters": parameters, - } - ), - headers=headers, - timeout=timeout, - ) - - prediction_response = json.loads(response.text) - - return Prediction( - predictions=prediction_response.get("predictions"), - metadata=prediction_response.get("metadata"), - deployed_model_id=prediction_response.get("deployedModelId"), - model_resource_name=prediction_response.get("model"), - model_version_id=prediction_response.get("modelVersionId"), - ) - - else: + if not self.dedicated_endpoint_enabled: prediction_response = self._prediction_client.predict( endpoint=self._gca_resource.name, instances=instances, @@ -2482,6 +2452,58 @@ def predict( model_resource_name=prediction_response.model, ) + if self.dedicated_endpoint_dns is None: + raise ValueError( + "Dedicated endpoint DNS is empty. Please make sure endpoint" + "and model are ready before making a prediction." + ) + + if not self.authorized_session: + self.credentials._scopes = constants.base.DEFAULT_AUTHED_SCOPES + self.authorized_session = google_auth_requests.AuthorizedSession( + self.credentials + ) + + if timeout is not None and timeout > google_auth_requests._DEFAULT_TIMEOUT: + try: + from requests_toolbelt.adapters.socket_options import ( + TCPKeepAliveAdapter, + ) + except ImportError: + raise ImportError( + "Cannot import the requests-toolbelt library." + "Please install requests-toolbelt." + ) + # count * interval need to be larger than 1 hr (3600s) + keep_alive = TCPKeepAliveAdapter(idle=120, count=100, interval=100) + self.authorized_session.mount("https://", keep_alive) + + url = f"https://{self.dedicated_endpoint_dns}/v1/{self.resource_name}:predict" + headers = { + "Content-Type": "application/json", + } + response = self.authorized_session.post( + url=url, + data=data, + headers=headers, + timeout=timeout, + ) + + if response.status_code != 200: + raise ValueError( + f"Failed to make prediction request. Status code:" + f"{response.status_code}, response: {response.text}." 
+ ) + prediction_response = json.loads(response.text) + + return Prediction( + predictions=prediction_response.get("predictions"), + metadata=prediction_response.get("metadata"), + deployed_model_id=prediction_response.get("deployedModelId"), + model_resource_name=prediction_response.get("model"), + model_version_id=prediction_response.get("modelVersionId"), + ) + async def predict_async( self, instances: List, @@ -2562,12 +2584,6 @@ def raw_predict( body = b'{"instances":[{"feat_1":val_1, "feat_2":val_2}]}' headers = {'Content-Type':'application/json'} ) - # For dedicated endpoint: - response = my_endpoint.raw_predict( - body = b'{"instances":[{"feat_1":val_1, "feat_2":val_2}]}', - headers = {'Content-Type':'application/json'}, - dedicated_endpoint=True, - ) status_code = response.status_code results = json.dumps(response.text) @@ -2593,34 +2609,29 @@ def raw_predict( self.credentials ) - if self.raw_predict_request_url is None: - self.raw_predict_request_url = f"https://{self.location}-{constants.base.API_BASE_PATH}/v1/projects/{self.project}/locations/{self.location}/endpoints/{self.name}:rawPredict" - - url = self.raw_predict_request_url - - if use_dedicated_endpoint: - try: - from requests_toolbelt.adapters.socket_options import ( - TCPKeepAliveAdapter, - ) - except ImportError: - raise ImportError( - "Cannot import the requests-toolbelt library. Please install requests-toolbelt." - ) - self._sync_gca_resource_if_skipped() - if ( - not self._gca_resource.dedicated_endpoint_enabled - or self._gca_resource.dedicated_endpoint_dns is None - ): + if self.dedicated_endpoint_enabled: + if self.dedicated_endpoint_dns is None: raise ValueError( - "Dedicated endpoint is not enabled or DNS is empty." - "Please make sure endpoint has dedicated endpoint enabled" + "Dedicated endpoint DNS is empty. Please make sure endpoint" "and model are ready before making a prediction." ) - url = f"https://{self._gca_resource.dedicated_endpoint_dns}/v1/{self.resource_name}:rawPredict" - # count * interval need to be larger than 1 hr (3600s) - keep_alive = TCPKeepAliveAdapter(idle=120, count=100, interval=100) - self.authorized_session.mount("https://", keep_alive) + url = f"https://{self.dedicated_endpoint_dns}/v1/{self.resource_name}:rawPredict" + + if timeout is not None and timeout > google_auth_requests._DEFAULT_TIMEOUT: + try: + from requests_toolbelt.adapters.socket_options import ( + TCPKeepAliveAdapter, + ) + except ImportError: + raise ImportError( + "Cannot import the requests-toolbelt library." + "Please install requests-toolbelt." + ) + # count * interval need to be larger than 1 hr (3600s) + keep_alive = TCPKeepAliveAdapter(idle=120, count=100, interval=100) + self.authorized_session.mount("https://", keep_alive) + else: + url = f"https://{self.location}-{constants.base.API_BASE_PATH}/v1/projects/{self.project}/locations/{self.location}/endpoints/{self.name}:rawPredict" return self.authorized_session.post( url=url, data=body, headers=headers, timeout=timeout @@ -2648,18 +2659,6 @@ def stream_raw_predict( stream_result = json.dumps(response.text) ``` - For dedicated endpoint: - ``` - my_endpoint = aiplatform.Endpoint(ENDPOINT_ID) - for stream_response in my_endpoint.stream_raw_predict( - body = b'{"instances":[{"feat_1":val_1, "feat_2":val_2}]}', - headers = {'Content-Type':'application/json'}, - use_dedicated_endpoint=True, - ): - status_code = response.status_code - stream_result = json.dumps(response.text) - ``` - Args: body (bytes): The body of the prediction request in bytes. 
This must not @@ -2682,23 +2681,15 @@ def stream_raw_predict( self.credentials ) - if self.stream_raw_predict_request_url is None: - self.stream_raw_predict_request_url = f"https://{self.location}-{constants.base.API_BASE_PATH}/v1/projects/{self.project}/locations/{self.location}/endpoints/{self.name}:streamRawPredict" - - url = self.stream_raw_predict_request_url - - if use_dedicated_endpoint: - self._sync_gca_resource_if_skipped() - if ( - not self._gca_resource.dedicated_endpoint_enabled - or self._gca_resource.dedicated_endpoint_dns is None - ): + if self.dedicated_endpoint_enabled: + if self.dedicated_endpoint_dns is None: raise ValueError( - "Dedicated endpoint is not enabled or DNS is empty." - "Please make sure endpoint has dedicated endpoint enabled" + "Dedicated endpoint DNS is empty. Please make sure endpoint" "and model are ready before making a prediction." ) - url = f"https://{self._gca_resource.dedicated_endpoint_dns}/v1/{self.resource_name}:streamRawPredict" + url = f"https://{self.dedicated_endpoint_dns}/v1/{self.resource_name}:streamRawPredict" + else: + url = f"https://{self.location}-{constants.base.API_BASE_PATH}/v1/projects/{self.project}/locations/{self.location}/endpoints/{self.name}:streamRawPredict" with self.authorized_session.post( url=url, @@ -3332,6 +3323,7 @@ def __init__( ) self._http_client = urllib3.PoolManager(cert_reqs="CERT_NONE") + self._authorized_session = None @property def predict_http_uri(self) -> Optional[str]: @@ -3604,6 +3596,7 @@ def _construct_sdk_resource_from_gapic( ) endpoint._http_client = urllib3.PoolManager(cert_reqs="CERT_NONE") + endpoint._authorized_session = None return endpoint @@ -3776,24 +3769,31 @@ def predict( "address or DNS." ) - if not self.credentials.valid: - self.credentials.refresh(google_auth_requests.Request()) + if not self._authorized_session: + self.credentials._scopes = constants.base.DEFAULT_AUTHED_SCOPES + self._authorized_session = google_auth_requests.AuthorizedSession( + self.credentials, + ) + self._authorized_session.verify = False - token = self.credentials.token - headers = { - "Authorization": f"Bearer {token}", - "Content-Type": "application/json", - } + if parameters: + data = json.dumps({"instances": instances, "parameters": parameters}) + else: + data = json.dumps({"instances": instances}) url = f"https://{endpoint_override}/v1/projects/{self.project}/locations/{self.location}/endpoints/{self.name}:predict" - response = self._http_request( - method="POST", + response = self._authorized_session.post( url=url, - body=json.dumps({"instances": instances, "parameters": parameters}), - headers=headers, + data=data, + headers={"Content-Type": "application/json"}, ) - prediction_response = json.loads(response.data) + if response.status_code != 200: + raise ValueError( + f"Failed to make prediction request. Status code:" + f"{response.status_code}, response: {response.text}." + ) + prediction_response = json.loads(response.text) return Prediction( predictions=prediction_response.get("predictions"), @@ -3873,19 +3873,19 @@ def raw_predict( "Invalid endpoint override provided. Please only use IP" "address or DNS." 
) - if not self.credentials.valid: - self.credentials.refresh(google_auth_requests.Request()) - token = self.credentials.token - headers_with_token = dict(headers) - headers_with_token["Authorization"] = f"Bearer {token}" + if not self._authorized_session: + self.credentials._scopes = constants.base.DEFAULT_AUTHED_SCOPES + self._authorized_session = google_auth_requests.AuthorizedSession( + self.credentials, + ) + self._authorized_session.verify = False url = f"https://{endpoint_override}/v1/projects/{self.project}/locations/{self.location}/endpoints/{self.name}:rawPredict" - return self._http_request( - method="POST", + return self._authorized_session.post( url=url, body=body, - headers=headers_with_token, + headers=headers, ) def stream_raw_predict( @@ -3953,24 +3953,19 @@ def stream_raw_predict( "Invalid endpoint override provided. Please only use IP" "address or DNS." ) - if not self.credentials.valid: - self.credentials.refresh(google_auth_requests.Request()) - token = self.credentials.token - headers_with_token = dict(headers) - headers_with_token["Authorization"] = f"Bearer {token}" - - if not self.authorized_session: + if not self._authorized_session: self.credentials._scopes = constants.base.DEFAULT_AUTHED_SCOPES - self.authorized_session = google_auth_requests.AuthorizedSession( - self.credentials + self._authorized_session = google_auth_requests.AuthorizedSession( + self.credentials, ) + self._authorized_session.verify = False url = f"https://{endpoint_override}/v1/projects/{self.project}/locations/{self.location}/endpoints/{self.name}:streamRawPredict" - with self.authorized_session.post( + with self._authorized_session.post( url=url, data=body, - headers=headers_with_token, + headers=headers, stream=True, verify=False, ) as resp: diff --git a/google/cloud/aiplatform/persistent_resource.py b/google/cloud/aiplatform/persistent_resource.py index f0944a5bb4..06c1914676 100644 --- a/google/cloud/aiplatform/persistent_resource.py +++ b/google/cloud/aiplatform/persistent_resource.py @@ -171,6 +171,7 @@ def create( labels: Optional[Dict[str, str]] = None, network: Optional[str] = None, kms_key_name: Optional[str] = None, + enable_custom_service_account: Optional[bool] = None, service_account: Optional[str] = None, reserved_ip_ranges: List[str] = None, sync: Optional[bool] = True, # pylint: disable=unused-argument @@ -234,6 +235,16 @@ def create( PersistentResource. If set, this PersistentResource and all sub-resources of this PersistentResource will be secured by this key. + enable_custom_service_account (bool): + Optional. When set to True, allows the `service_account` + parameter to specify a custom service account for workloads on this + PersistentResource. Defaults to None (False behavior). + + If True, the service account provided in the `service_account` parameter + will be used for workloads (runtimes, jobs), provided the user has the + ``iam.serviceAccounts.actAs`` permission. If False, the + `service_account` parameter is ignored, and the PersistentResource + will use the default service account. service_account (str): Optional. Default service account that this PersistentResource's workloads run as. 
The workloads @@ -295,7 +306,31 @@ def create( gca_encryption_spec_compat.EncryptionSpec(kms_key_name=kms_key_name) ) - if service_account: + # Raise ValueError if enable_custom_service_account is False but + # service_account is provided + if ( + enable_custom_service_account is False and service_account is not None + ): # pylint: disable=g-bool-id-comparison + raise ValueError( + "The parameter `enable_custom_service_account` was set to False, " + "but a value was provided for `service_account`. These two " + "settings are incompatible. If you want to use a custom " + "service account, set `enable_custom_service_account` to True." + ) + + elif enable_custom_service_account: + service_account_spec = gca_persistent_resource_compat.ServiceAccountSpec( + enable_custom_service_account=True, + # Set service_account if it is provided, otherwise set to None + service_account=service_account if service_account else None, + ) + gca_persistent_resource.resource_runtime_spec = ( + gca_persistent_resource_compat.ResourceRuntimeSpec( + service_account_spec=service_account_spec + ) + ) + elif service_account: + # Handle the deprecated case where only service_account is provided service_account_spec = gca_persistent_resource_compat.ServiceAccountSpec( enable_custom_service_account=True, service_account=service_account ) diff --git a/google/cloud/aiplatform/preview/persistent_resource.py b/google/cloud/aiplatform/preview/persistent_resource.py index e06b08e90b..54a9d55d3d 100644 --- a/google/cloud/aiplatform/preview/persistent_resource.py +++ b/google/cloud/aiplatform/preview/persistent_resource.py @@ -177,6 +177,7 @@ def create( labels: Optional[Dict[str, str]] = None, network: Optional[str] = None, kms_key_name: Optional[str] = None, + enable_custom_service_account: Optional[bool] = None, service_account: Optional[str] = None, reserved_ip_ranges: List[str] = None, sync: Optional[bool] = True, # pylint: disable=unused-argument @@ -240,6 +241,16 @@ def create( PersistentResource. If set, this PersistentResource and all sub-resources of this PersistentResource will be secured by this key. + enable_custom_service_account (bool): + Optional. When set to True, allows the `service_account` + parameter to specify a custom service account for workloads on this + PersistentResource. Defaults to None (False behavior). + + If True, the service account provided in the `service_account` parameter + will be used for workloads (runtimes, jobs), provided the user has the + ``iam.serviceAccounts.actAs`` permission. If False, the + `service_account` parameter is ignored, and the PersistentResource + will use the default service account. service_account (str): Optional. Default service account that this PersistentResource's workloads run as. The workloads @@ -301,7 +312,29 @@ def create( gca_encryption_spec_compat.EncryptionSpec(kms_key_name=kms_key_name) ) - if service_account: + if ( + enable_custom_service_account is False and service_account is not None + ): # pylint: disable=g-bool-id-comparison + raise ValueError( + "The parameter `enable_custom_service_account` was set to False, " + "but a value was provided for `service_account`. These two " + "settings are incompatible. If you want to use a custom " + "service account, set `enable_custom_service_account` to True." 
+ ) + + elif enable_custom_service_account: + service_account_spec = gca_persistent_resource_compat.ServiceAccountSpec( + enable_custom_service_account=True, + # Set service_account if it is provided, otherwise set to None + service_account=service_account if service_account else None, + ) + gca_persistent_resource.resource_runtime_spec = ( + gca_persistent_resource_compat.ResourceRuntimeSpec( + service_account_spec=service_account_spec + ) + ) + elif service_account: + # Handle the deprecated case where only service_account is provided service_account_spec = gca_persistent_resource_compat.ServiceAccountSpec( enable_custom_service_account=True, service_account=service_account ) diff --git a/google/cloud/aiplatform/releases.txt b/google/cloud/aiplatform/releases.txt index ad3b915fdf..09a949b7da 100644 --- a/google/cloud/aiplatform/releases.txt +++ b/google/cloud/aiplatform/releases.txt @@ -1,4 +1,4 @@ Use this file when you need to force a patch release with release-please. Edit line 4 below with the version for the release. -1.55.0 \ No newline at end of file +1.92.0 \ No newline at end of file diff --git a/google/cloud/aiplatform/tensorboard/uploader_constants.py b/google/cloud/aiplatform/tensorboard/uploader_constants.py index 78ce23c614..16e3c3672a 100644 --- a/google/cloud/aiplatform/tensorboard/uploader_constants.py +++ b/google/cloud/aiplatform/tensorboard/uploader_constants.py @@ -2,28 +2,24 @@ import dataclasses -from tensorboard.plugins.distribution import ( - metadata as distribution_metadata, -) -from tensorboard.plugins.graph import metadata as graphs_metadata -from tensorboard.plugins.histogram import ( - metadata as histogram_metadata, -) -from tensorboard.plugins.hparams import metadata as hparams_metadata -from tensorboard.plugins.image import metadata as images_metadata -from tensorboard.plugins.scalar import metadata as scalar_metadata -from tensorboard.plugins.text import metadata as text_metadata +from tensorboard.plugins.distribution import distributions_plugin +from tensorboard.plugins.graph import graphs_plugin +from tensorboard.plugins.histogram import histograms_plugin +from tensorboard.plugins.hparams import hparams_plugin +from tensorboard.plugins.image import images_plugin +from tensorboard.plugins.scalar import scalars_plugin +from tensorboard.plugins.text import text_plugin PROFILE_PLUGIN_NAME = "profile" ALLOWED_PLUGINS = frozenset( [ - scalar_metadata.PLUGIN_NAME, - histogram_metadata.PLUGIN_NAME, - distribution_metadata.PLUGIN_NAME, - text_metadata.PLUGIN_NAME, - hparams_metadata.PLUGIN_NAME, - images_metadata.PLUGIN_NAME, - graphs_metadata.PLUGIN_NAME, + scalars_plugin.ScalarsPlugin.plugin_name, + histograms_plugin.HistogramsPlugin.plugin_name, + distributions_plugin.DistributionsPlugin.plugin_name, + text_plugin.TextPlugin.plugin_name, + hparams_plugin.HParamsPlugin.plugin_name, + images_plugin.ImagesPlugin.plugin_name, + graphs_plugin.GraphsPlugin.plugin_name, PROFILE_PLUGIN_NAME, ] ) diff --git a/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py index dbb6c23a4a..4cfdd5d77d 100644 --- a/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.91.0" # {x-release-please-version} +__version__ = "1.92.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py index dbb6c23a4a..4cfdd5d77d 100644 --- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.91.0" # {x-release-please-version} +__version__ = "1.92.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py index dbb6c23a4a..4cfdd5d77d 100644 --- a/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.91.0" # {x-release-please-version} +__version__ = "1.92.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py index dbb6c23a4a..4cfdd5d77d 100644 --- a/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.91.0" # {x-release-please-version} +__version__ = "1.92.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py index dbb6c23a4a..4cfdd5d77d 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.91.0" # {x-release-please-version} +__version__ = "1.92.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py index dbb6c23a4a..4cfdd5d77d 100644 --- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.91.0" # {x-release-please-version} +__version__ = "1.92.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py index dbb6c23a4a..4cfdd5d77d 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.91.0" # {x-release-please-version} +__version__ = "1.92.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py index dbb6c23a4a..4cfdd5d77d 100644 --- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py +++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.91.0" # {x-release-please-version} +__version__ = "1.92.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py index dbb6c23a4a..4cfdd5d77d 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.91.0" # {x-release-please-version} +__version__ = "1.92.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py index dbb6c23a4a..4cfdd5d77d 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.91.0" # {x-release-please-version} +__version__ = "1.92.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py index dbb6c23a4a..4cfdd5d77d 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.91.0" # {x-release-please-version} +__version__ = "1.92.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py index dbb6c23a4a..4cfdd5d77d 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.91.0" # {x-release-please-version} +__version__ = "1.92.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py index dbb6c23a4a..4cfdd5d77d 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.91.0" # {x-release-please-version} +__version__ = "1.92.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py index dbb6c23a4a..4cfdd5d77d 100644 --- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.91.0" # {x-release-please-version} +__version__ = "1.92.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py index dbb6c23a4a..4cfdd5d77d 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.91.0" # {x-release-please-version} +__version__ = "1.92.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py index dbb6c23a4a..4cfdd5d77d 100644 --- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.91.0" # {x-release-please-version} +__version__ = "1.92.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform/version.py b/google/cloud/aiplatform/version.py index fedf961642..2982cf33c8 100644 --- a/google/cloud/aiplatform/version.py +++ b/google/cloud/aiplatform/version.py @@ -15,4 +15,4 @@ # limitations under the License. 
# -__version__ = "1.91.0" +__version__ = "1.92.0" diff --git a/google/cloud/aiplatform_v1/__init__.py b/google/cloud/aiplatform_v1/__init__.py index a4287d4f6c..71170aee11 100644 --- a/google/cloud/aiplatform_v1/__init__.py +++ b/google/cloud/aiplatform_v1/__init__.py @@ -652,6 +652,7 @@ from .types.migration_service import MigrateResourceResponse from .types.migration_service import SearchMigratableResourcesRequest from .types.migration_service import SearchMigratableResourcesResponse +from .types.model import Checkpoint from .types.model import GenieSource from .types.model import LargeModelReference from .types.model import Model @@ -1134,6 +1135,7 @@ "CheckTrialEarlyStoppingStateMetatdata", "CheckTrialEarlyStoppingStateRequest", "CheckTrialEarlyStoppingStateResponse", + "Checkpoint", "Citation", "CitationMetadata", "Claim", diff --git a/google/cloud/aiplatform_v1/gapic_version.py b/google/cloud/aiplatform_v1/gapic_version.py index dbb6c23a4a..4cfdd5d77d 100644 --- a/google/cloud/aiplatform_v1/gapic_version.py +++ b/google/cloud/aiplatform_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "1.91.0" # {x-release-please-version} +__version__ = "1.92.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform_v1/services/gen_ai_cache_service/async_client.py b/google/cloud/aiplatform_v1/services/gen_ai_cache_service/async_client.py index 16267080e8..c229cffbbc 100644 --- a/google/cloud/aiplatform_v1/services/gen_ai_cache_service/async_client.py +++ b/google/cloud/aiplatform_v1/services/gen_ai_cache_service/async_client.py @@ -48,6 +48,7 @@ from google.cloud.aiplatform_v1.types import cached_content from google.cloud.aiplatform_v1.types import cached_content as gca_cached_content from google.cloud.aiplatform_v1.types import content +from google.cloud.aiplatform_v1.types import encryption_spec from google.cloud.aiplatform_v1.types import gen_ai_cache_service from google.cloud.aiplatform_v1.types import tool from google.cloud.location import locations_pb2 # type: ignore diff --git a/google/cloud/aiplatform_v1/services/gen_ai_cache_service/client.py b/google/cloud/aiplatform_v1/services/gen_ai_cache_service/client.py index 82ba257ad2..b327f67bd9 100644 --- a/google/cloud/aiplatform_v1/services/gen_ai_cache_service/client.py +++ b/google/cloud/aiplatform_v1/services/gen_ai_cache_service/client.py @@ -64,6 +64,7 @@ from google.cloud.aiplatform_v1.types import cached_content from google.cloud.aiplatform_v1.types import cached_content as gca_cached_content from google.cloud.aiplatform_v1.types import content +from google.cloud.aiplatform_v1.types import encryption_spec from google.cloud.aiplatform_v1.types import gen_ai_cache_service from google.cloud.aiplatform_v1.types import tool from google.cloud.location import locations_pb2 # type: ignore diff --git a/google/cloud/aiplatform_v1/services/migration_service/client.py b/google/cloud/aiplatform_v1/services/migration_service/client.py index 886fa6bf79..da85ad6aed 100644 --- a/google/cloud/aiplatform_v1/services/migration_service/client.py +++ b/google/cloud/aiplatform_v1/services/migration_service/client.py @@ -242,62 +242,57 @@ def parse_annotated_dataset_path(path: str) -> Dict[str, str]: @staticmethod def dataset_path( project: str, - location: str, dataset: str, ) -> str: """Returns a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format( + return 
"projects/{project}/datasets/{dataset}".format( project=project, - location=location, dataset=dataset, ) @staticmethod def parse_dataset_path(path: str) -> Dict[str, str]: """Parses a dataset path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod def dataset_path( project: str, + location: str, dataset: str, ) -> str: """Returns a fully-qualified dataset string.""" - return "projects/{project}/datasets/{dataset}".format( + return "projects/{project}/locations/{location}/datasets/{dataset}".format( project=project, + location=location, dataset=dataset, ) @staticmethod def parse_dataset_path(path: str) -> Dict[str, str]: """Parses a dataset path into its component segments.""" - m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path) + m = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", + path, + ) return m.groupdict() if m else {} @staticmethod def dataset_path( project: str, - location: str, dataset: str, ) -> str: """Returns a fully-qualified dataset string.""" - return "projects/{project}/locations/{location}/datasets/{dataset}".format( + return "projects/{project}/datasets/{dataset}".format( project=project, - location=location, dataset=dataset, ) @staticmethod def parse_dataset_path(path: str) -> Dict[str, str]: """Parses a dataset path into its component segments.""" - m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$", - path, - ) + m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path) return m.groupdict() if m else {} @staticmethod diff --git a/google/cloud/aiplatform_v1/types/__init__.py b/google/cloud/aiplatform_v1/types/__init__.py index cfffa3544d..48888a1e70 100644 --- a/google/cloud/aiplatform_v1/types/__init__.py +++ b/google/cloud/aiplatform_v1/types/__init__.py @@ -689,6 +689,7 @@ SearchMigratableResourcesResponse, ) from .model import ( + Checkpoint, GenieSource, LargeModelReference, Model, @@ -1713,6 +1714,7 @@ "MigrateResourceResponse", "SearchMigratableResourcesRequest", "SearchMigratableResourcesResponse", + "Checkpoint", "GenieSource", "LargeModelReference", "Model", diff --git a/google/cloud/aiplatform_v1/types/cached_content.py b/google/cloud/aiplatform_v1/types/cached_content.py index f4e0711a8d..f3fb112609 100644 --- a/google/cloud/aiplatform_v1/types/cached_content.py +++ b/google/cloud/aiplatform_v1/types/cached_content.py @@ -20,6 +20,7 @@ import proto # type: ignore from google.cloud.aiplatform_v1.types import content +from google.cloud.aiplatform_v1.types import encryption_spec as gca_encryption_spec from google.cloud.aiplatform_v1.types import tool from google.protobuf import duration_pb2 # type: ignore from google.protobuf import timestamp_pb2 # type: ignore @@ -64,10 +65,10 @@ class CachedContent(proto.Message): Optional. Immutable. The user-generated meaningful display name of the cached content. model (str): - Immutable. The name of the publisher model to - use for cached content. Format: - - projects/{project}/locations/{location}/publishers/{publisher}/models/{model} + Immutable. The name of the ``Model`` to use for cached + content. Currently, only the published Gemini base models + are supported, in form of + projects/{PROJECT}/locations/{LOCATION}/publishers/google/models/{MODEL} system_instruction (google.cloud.aiplatform_v1.types.Content): Optional. Input only. Immutable. Developer set system instruction. 
Currently, text only @@ -81,7 +82,7 @@ class CachedContent(proto.Message): Optional. Input only. Immutable. Tool config. This config is shared for all tools create_time (google.protobuf.timestamp_pb2.Timestamp): - Output only. Creatation time of the cache + Output only. Creation time of the cache entry. update_time (google.protobuf.timestamp_pb2.Timestamp): Output only. When the cache entry was last @@ -89,6 +90,10 @@ class CachedContent(proto.Message): usage_metadata (google.cloud.aiplatform_v1.types.CachedContent.UsageMetadata): Output only. Metadata on the usage of the cached content. + encryption_spec (google.cloud.aiplatform_v1.types.EncryptionSpec): + Input only. Immutable. Customer-managed encryption key spec + for a ``CachedContent``. If set, this ``CachedContent`` and + all its sub-resources will be secured by this key. """ class UsageMetadata(proto.Message): @@ -188,6 +193,11 @@ class UsageMetadata(proto.Message): number=12, message=UsageMetadata, ) + encryption_spec: gca_encryption_spec.EncryptionSpec = proto.Field( + proto.MESSAGE, + number=13, + message=gca_encryption_spec.EncryptionSpec, + ) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/content.py b/google/cloud/aiplatform_v1/types/content.py index d94dbf0c1b..76c2a8ff61 100644 --- a/google/cloud/aiplatform_v1/types/content.py +++ b/google/cloud/aiplatform_v1/types/content.py @@ -70,7 +70,8 @@ class HarmCategory(proto.Enum): The harm category is sexually explicit content. HARM_CATEGORY_CIVIC_INTEGRITY (5): - The harm category is civic integrity. + Deprecated: Election filter is not longer + supported. The harm category is civic integrity. """ HARM_CATEGORY_UNSPECIFIED = 0 HARM_CATEGORY_HATE_SPEECH = 1 @@ -393,6 +394,10 @@ class GenerationConfig(proto.Message): Optional. Routing configuration. This field is a member of `oneof`_ ``_routing_config``. + thinking_config (google.cloud.aiplatform_v1.types.GenerationConfig.ThinkingConfig): + Optional. Config for thinking features. + An error will be returned if this field is set + for models that don't support thinking. """ class RoutingConfig(proto.Message): @@ -491,6 +496,25 @@ class ManualRoutingMode(proto.Message): message="GenerationConfig.RoutingConfig.ManualRoutingMode", ) + class ThinkingConfig(proto.Message): + r"""Config for thinking features. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + thinking_budget (int): + Optional. Indicates the thinking budget in tokens. This is + only applied when enable_thinking is true. + + This field is a member of `oneof`_ ``_thinking_budget``. + """ + + thinking_budget: int = proto.Field( + proto.INT32, + number=3, + optional=True, + ) + temperature: float = proto.Field( proto.FLOAT, number=1, @@ -561,6 +585,11 @@ class ManualRoutingMode(proto.Message): optional=True, message=RoutingConfig, ) + thinking_config: ThinkingConfig = proto.Field( + proto.MESSAGE, + number=25, + message=ThinkingConfig, + ) class SafetySetting(proto.Message): diff --git a/google/cloud/aiplatform_v1/types/io.py b/google/cloud/aiplatform_v1/types/io.py index 7954b603e9..0fcc4b8080 100644 --- a/google/cloud/aiplatform_v1/types/io.py +++ b/google/cloud/aiplatform_v1/types/io.py @@ -82,7 +82,7 @@ class GcsSource(proto.Message): Required. Google Cloud Storage URI(-s) to the input file(s). May contain wildcards. For more information on wildcards, see - https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames. 
+ https://cloud.google.com/storage/docs/wildcards. """ uris: MutableSequence[str] = proto.RepeatedField( diff --git a/google/cloud/aiplatform_v1/types/model.py b/google/cloud/aiplatform_v1/types/model.py index 3b1ac173e9..878ebdee0d 100644 --- a/google/cloud/aiplatform_v1/types/model.py +++ b/google/cloud/aiplatform_v1/types/model.py @@ -40,6 +40,7 @@ "Port", "ModelSourceInfo", "Probe", + "Checkpoint", }, ) @@ -322,6 +323,9 @@ class Model(proto.Message): Output only. Reserved for future use. satisfies_pzi (bool): Output only. Reserved for future use. + checkpoints (MutableSequence[google.cloud.aiplatform_v1.types.Checkpoint]): + Optional. Output only. The checkpoints of the + model. """ class DeploymentResourcesType(proto.Enum): @@ -681,6 +685,11 @@ class BaseModelSource(proto.Message): proto.BOOL, number=52, ) + checkpoints: MutableSequence["Checkpoint"] = proto.RepeatedField( + proto.MESSAGE, + number=57, + message="Checkpoint", + ) class LargeModelReference(proto.Message): @@ -1468,4 +1477,30 @@ class HttpHeader(proto.Message): ) +class Checkpoint(proto.Message): + r"""Describes the machine learning model version checkpoint. + + Attributes: + checkpoint_id (str): + The ID of the checkpoint. + epoch (int): + The epoch of the checkpoint. + step (int): + The step of the checkpoint. + """ + + checkpoint_id: str = proto.Field( + proto.STRING, + number=1, + ) + epoch: int = proto.Field( + proto.INT64, + number=2, + ) + step: int = proto.Field( + proto.INT64, + number=3, + ) + + __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/openapi.py b/google/cloud/aiplatform_v1/types/openapi.py index 79525e4834..9f5da2ed40 100644 --- a/google/cloud/aiplatform_v1/types/openapi.py +++ b/google/cloud/aiplatform_v1/types/openapi.py @@ -138,6 +138,23 @@ class Schema(proto.Message): Optional. The value should be validated against any (one or more) of the subschemas in the list. + ref (str): + Optional. Allows indirect references between schema nodes. + The value should be a valid reference to a child of the root + ``defs``. + + For example, the following schema defines a reference to a + schema node named "Pet": + + type: object properties: pet: ref: #/defs/Pet defs: Pet: + type: object properties: name: type: string + + The value of the "pet" property is a reference to the schema + node named "Pet". See details in + https://json-schema.org/understanding-json-schema/structuring + defs (MutableMapping[str, google.cloud.aiplatform_v1.types.Schema]): + Optional. A map of definitions for use by ``ref`` Only + allowed at the root of the schema. """ type_: "Type" = proto.Field( @@ -235,6 +252,16 @@ class Schema(proto.Message): number=11, message="Schema", ) + ref: str = proto.Field( + proto.STRING, + number=27, + ) + defs: MutableMapping[str, "Schema"] = proto.MapField( + proto.STRING, + proto.MESSAGE, + number=28, + message="Schema", + ) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1/types/vertex_rag_data.py b/google/cloud/aiplatform_v1/types/vertex_rag_data.py index fd251f3db8..22692ee84b 100644 --- a/google/cloud/aiplatform_v1/types/vertex_rag_data.py +++ b/google/cloud/aiplatform_v1/types/vertex_rag_data.py @@ -615,12 +615,21 @@ class RagFileTransformationConfig(proto.Message): class RagFileParsingConfig(proto.Message): r"""Specifies the parsing config for RagFiles. + This message has `oneof`_ fields (mutually exclusive fields). + For each oneof, at most one member field can be set at the same time. 
+ Setting any member of the oneof automatically clears all other + members. + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields Attributes: layout_parser (google.cloud.aiplatform_v1.types.RagFileParsingConfig.LayoutParser): The Layout Parser to use for RagFiles. + This field is a member of `oneof`_ ``parser``. + llm_parser (google.cloud.aiplatform_v1.types.RagFileParsingConfig.LlmParser): + The LLM Parser to use for RagFiles. + This field is a member of `oneof`_ ``parser``. """ @@ -656,12 +665,52 @@ class LayoutParser(proto.Message): number=2, ) + class LlmParser(proto.Message): + r"""Specifies the advanced parsing for RagFiles. + + Attributes: + model_name (str): + The name of a LLM model used for parsing. Format: + + - ``projects/{project_id}/locations/{location}/publishers/{publisher}/models/{model}`` + max_parsing_requests_per_min (int): + The maximum number of requests the job is + allowed to make to the LLM model per minute. + Consult + https://cloud.google.com/vertex-ai/generative-ai/docs/quotas + and your document size to set an appropriate + value here. If unspecified, a default value of + 5000 QPM would be used. + custom_parsing_prompt (str): + The prompt to use for parsing. If not + specified, a default prompt will be used. + """ + + model_name: str = proto.Field( + proto.STRING, + number=1, + ) + max_parsing_requests_per_min: int = proto.Field( + proto.INT32, + number=2, + ) + custom_parsing_prompt: str = proto.Field( + proto.STRING, + number=3, + ) + layout_parser: LayoutParser = proto.Field( proto.MESSAGE, number=4, oneof="parser", message=LayoutParser, ) + llm_parser: LlmParser = proto.Field( + proto.MESSAGE, + number=5, + oneof="parser", + message=LlmParser, + ) class UploadRagFileConfig(proto.Message): diff --git a/google/cloud/aiplatform_v1beta1/__init__.py b/google/cloud/aiplatform_v1beta1/__init__.py index fe88ef9c23..07334a27e8 100644 --- a/google/cloud/aiplatform_v1beta1/__init__.py +++ b/google/cloud/aiplatform_v1beta1/__init__.py @@ -175,6 +175,7 @@ from .types.dataset_service import ExportDataRequest from .types.dataset_service import ExportDataResponse from .types.dataset_service import GeminiExample +from .types.dataset_service import GeminiRequestReadConfig from .types.dataset_service import GeminiTemplateConfig from .types.dataset_service import GetAnnotationSpecRequest from .types.dataset_service import GetDatasetRequest @@ -1703,6 +1704,7 @@ "GcsDestination", "GcsSource", "GeminiExample", + "GeminiRequestReadConfig", "GeminiTemplateConfig", "GenAiCacheServiceClient", "GenAiTuningServiceClient", diff --git a/google/cloud/aiplatform_v1beta1/gapic_version.py b/google/cloud/aiplatform_v1beta1/gapic_version.py index dbb6c23a4a..4cfdd5d77d 100644 --- a/google/cloud/aiplatform_v1beta1/gapic_version.py +++ b/google/cloud/aiplatform_v1beta1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -__version__ = "1.91.0" # {x-release-please-version} +__version__ = "1.92.0" # {x-release-please-version} diff --git a/google/cloud/aiplatform_v1beta1/services/gen_ai_cache_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/gen_ai_cache_service/async_client.py index 2b03b00214..429f6b5505 100644 --- a/google/cloud/aiplatform_v1beta1/services/gen_ai_cache_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/gen_ai_cache_service/async_client.py @@ -48,6 +48,7 @@ from google.cloud.aiplatform_v1beta1.types import cached_content from google.cloud.aiplatform_v1beta1.types import cached_content as gca_cached_content from google.cloud.aiplatform_v1beta1.types import content +from google.cloud.aiplatform_v1beta1.types import encryption_spec from google.cloud.aiplatform_v1beta1.types import gen_ai_cache_service from google.cloud.aiplatform_v1beta1.types import tool from google.cloud.location import locations_pb2 # type: ignore diff --git a/google/cloud/aiplatform_v1beta1/services/gen_ai_cache_service/client.py b/google/cloud/aiplatform_v1beta1/services/gen_ai_cache_service/client.py index 6eb1bc6e15..bdcd049578 100644 --- a/google/cloud/aiplatform_v1beta1/services/gen_ai_cache_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/gen_ai_cache_service/client.py @@ -64,6 +64,7 @@ from google.cloud.aiplatform_v1beta1.types import cached_content from google.cloud.aiplatform_v1beta1.types import cached_content as gca_cached_content from google.cloud.aiplatform_v1beta1.types import content +from google.cloud.aiplatform_v1beta1.types import encryption_spec from google.cloud.aiplatform_v1beta1.types import gen_ai_cache_service from google.cloud.aiplatform_v1beta1.types import tool from google.cloud.location import locations_pb2 # type: ignore diff --git a/google/cloud/aiplatform_v1beta1/services/session_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/session_service/async_client.py index e605cf2bb3..80722280ed 100644 --- a/google/cloud/aiplatform_v1beta1/services/session_service/async_client.py +++ b/google/cloud/aiplatform_v1beta1/services/session_service/async_client.py @@ -313,8 +313,8 @@ async def create_session( timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> operation_async.AsyncOperation: - r"""Creates a new [Session][google.cloud.aiplatform.v1beta1.Session] - in a given project and location. + r"""Creates a new + [Session][google.cloud.aiplatform.v1beta1.Session]. .. code-block:: python @@ -358,7 +358,6 @@ async def sample_create_session(): parent (:class:`str`): Required. The resource name of the location to create the session in. Format: - ``projects/{project}/locations/{location}`` or ``projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}`` This corresponds to the ``parent`` field @@ -568,7 +567,7 @@ async def list_sessions( metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> pagers.ListSessionsAsyncPager: r"""Lists [Sessions][google.cloud.aiplatform.v1beta1.Session] in a - given project and location. + given reasoning engine. .. code-block:: python @@ -864,8 +863,6 @@ async def sample_delete_session(): [SessionService.DeleteSession][google.cloud.aiplatform.v1beta1.SessionService.DeleteSession]. name (:class:`str`): Required. The resource name of the session. 
Format: - ``projects/{project}/locations/{location}/sessions/{session}`` - or ``projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}/sessions/{session}`` This corresponds to the ``name`` field diff --git a/google/cloud/aiplatform_v1beta1/services/session_service/client.py b/google/cloud/aiplatform_v1beta1/services/session_service/client.py index d25604a46a..599f1bdc1b 100644 --- a/google/cloud/aiplatform_v1beta1/services/session_service/client.py +++ b/google/cloud/aiplatform_v1beta1/services/session_service/client.py @@ -225,12 +225,14 @@ def transport(self) -> SessionServiceTransport: def session_path( project: str, location: str, + reasoning_engine: str, session: str, ) -> str: """Returns a fully-qualified session string.""" - return "projects/{project}/locations/{location}/sessions/{session}".format( + return "projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}/sessions/{session}".format( project=project, location=location, + reasoning_engine=reasoning_engine, session=session, ) @@ -238,7 +240,7 @@ def session_path( def parse_session_path(path: str) -> Dict[str, str]: """Parses a session path into its component segments.""" m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/sessions/(?P.+?)$", + r"^projects/(?P.+?)/locations/(?P.+?)/reasoningEngines/(?P.+?)/sessions/(?P.+?)$", path, ) return m.groupdict() if m else {} @@ -247,13 +249,15 @@ def parse_session_path(path: str) -> Dict[str, str]: def session_event_path( project: str, location: str, + reasoning_engine: str, session: str, event: str, ) -> str: """Returns a fully-qualified session_event string.""" - return "projects/{project}/locations/{location}/sessions/{session}/events/{event}".format( + return "projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}/sessions/{session}/events/{event}".format( project=project, location=location, + reasoning_engine=reasoning_engine, session=session, event=event, ) @@ -262,7 +266,7 @@ def session_event_path( def parse_session_event_path(path: str) -> Dict[str, str]: """Parses a session_event path into its component segments.""" m = re.match( - r"^projects/(?P.+?)/locations/(?P.+?)/sessions/(?P.+?)/events/(?P.+?)$", + r"^projects/(?P.+?)/locations/(?P.+?)/reasoningEngines/(?P.+?)/sessions/(?P.+?)/events/(?P.+?)$", path, ) return m.groupdict() if m else {} @@ -794,8 +798,8 @@ def create_session( timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> gac_operation.Operation: - r"""Creates a new [Session][google.cloud.aiplatform.v1beta1.Session] - in a given project and location. + r"""Creates a new + [Session][google.cloud.aiplatform.v1beta1.Session]. .. code-block:: python @@ -839,7 +843,6 @@ def sample_create_session(): parent (str): Required. The resource name of the location to create the session in. Format: - ``projects/{project}/locations/{location}`` or ``projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}`` This corresponds to the ``parent`` field @@ -1043,7 +1046,7 @@ def list_sessions( metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), ) -> pagers.ListSessionsPager: r"""Lists [Sessions][google.cloud.aiplatform.v1beta1.Session] in a - given project and location. + given reasoning engine. .. code-block:: python @@ -1333,8 +1336,6 @@ def sample_delete_session(): [SessionService.DeleteSession][google.cloud.aiplatform.v1beta1.SessionService.DeleteSession]. name (str): Required. The resource name of the session. 
Format: - ``projects/{project}/locations/{location}/sessions/{session}`` - or ``projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}/sessions/{session}`` This corresponds to the ``name`` field diff --git a/google/cloud/aiplatform_v1beta1/services/session_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/session_service/transports/grpc.py index 938378fc83..7f0ffd3884 100644 --- a/google/cloud/aiplatform_v1beta1/services/session_service/transports/grpc.py +++ b/google/cloud/aiplatform_v1beta1/services/session_service/transports/grpc.py @@ -350,8 +350,8 @@ def create_session( ) -> Callable[[session_service.CreateSessionRequest], operations_pb2.Operation]: r"""Return a callable for the create session method over gRPC. - Creates a new [Session][google.cloud.aiplatform.v1beta1.Session] - in a given project and location. + Creates a new + [Session][google.cloud.aiplatform.v1beta1.Session]. Returns: Callable[[~.CreateSessionRequest], @@ -407,7 +407,7 @@ def list_sessions( r"""Return a callable for the list sessions method over gRPC. Lists [Sessions][google.cloud.aiplatform.v1beta1.Session] in a - given project and location. + given reasoning engine. Returns: Callable[[~.ListSessionsRequest], diff --git a/google/cloud/aiplatform_v1beta1/services/session_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/session_service/transports/grpc_asyncio.py index 6842c62242..36f4cc53ec 100644 --- a/google/cloud/aiplatform_v1beta1/services/session_service/transports/grpc_asyncio.py +++ b/google/cloud/aiplatform_v1beta1/services/session_service/transports/grpc_asyncio.py @@ -360,8 +360,8 @@ def create_session( ]: r"""Return a callable for the create session method over gRPC. - Creates a new [Session][google.cloud.aiplatform.v1beta1.Session] - in a given project and location. + Creates a new + [Session][google.cloud.aiplatform.v1beta1.Session]. Returns: Callable[[~.CreateSessionRequest], @@ -418,7 +418,7 @@ def list_sessions( r"""Return a callable for the list sessions method over gRPC. Lists [Sessions][google.cloud.aiplatform.v1beta1.Session] in a - given project and location. + given reasoning engine. 
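A minimal sketch of the updated session path helpers, assuming only what the diff above shows (the project, location, engine, and session IDs are hypothetical):

from google.cloud.aiplatform_v1beta1.services.session_service import (
    SessionServiceClient,
)

# Session resources are now nested under a reasoning engine, so the path
# helper takes an extra reasoning_engine segment.
path = SessionServiceClient.session_path(
    project="my-project",
    location="us-central1",
    reasoning_engine="my-engine",
    session="my-session",
)
# -> "projects/my-project/locations/us-central1/reasoningEngines/my-engine/sessions/my-session"
assert SessionServiceClient.parse_session_path(path) == {
    "project": "my-project",
    "location": "us-central1",
    "reasoning_engine": "my-engine",
    "session": "my-session",
}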
Returns: Callable[[~.ListSessionsRequest], diff --git a/google/cloud/aiplatform_v1beta1/services/session_service/transports/rest_base.py b/google/cloud/aiplatform_v1beta1/services/session_service/transports/rest_base.py index f7a621b52a..73df9faad5 100644 --- a/google/cloud/aiplatform_v1beta1/services/session_service/transports/rest_base.py +++ b/google/cloud/aiplatform_v1beta1/services/session_service/transports/rest_base.py @@ -169,11 +169,6 @@ def _get_unset_required_fields(cls, message_dict): @staticmethod def _get_http_options(): http_options: List[Dict[str, str]] = [ - { - "method": "post", - "uri": "/v1beta1/{parent=projects/*/locations/*}/sessions", - "body": "session", - }, { "method": "post", "uri": "/v1beta1/{parent=projects/*/locations/*/reasoningEngines/*}/sessions", @@ -231,10 +226,6 @@ def _get_unset_required_fields(cls, message_dict): @staticmethod def _get_http_options(): http_options: List[Dict[str, str]] = [ - { - "method": "delete", - "uri": "/v1beta1/{name=projects/*/locations/*/sessions/*}", - }, { "method": "delete", "uri": "/v1beta1/{name=projects/*/locations/*/reasoningEngines/*/sessions/*}", @@ -282,10 +273,6 @@ def _get_unset_required_fields(cls, message_dict): @staticmethod def _get_http_options(): http_options: List[Dict[str, str]] = [ - { - "method": "get", - "uri": "/v1beta1/{name=projects/*/locations/*/sessions/*}", - }, { "method": "get", "uri": "/v1beta1/{name=projects/*/locations/*/reasoningEngines/*/sessions/*}", @@ -337,10 +324,6 @@ def _get_http_options(): "method": "get", "uri": "/v1beta1/{parent=projects/*/locations/*/reasoningEngines/*/sessions/*}/events", }, - { - "method": "get", - "uri": "/v1beta1/{parent=projects/*/locations/*/sessions/*}/events", - }, ] return http_options @@ -384,10 +367,6 @@ def _get_unset_required_fields(cls, message_dict): @staticmethod def _get_http_options(): http_options: List[Dict[str, str]] = [ - { - "method": "get", - "uri": "/v1beta1/{parent=projects/*/locations/*}/sessions", - }, { "method": "get", "uri": "/v1beta1/{parent=projects/*/locations/*/reasoningEngines/*}/sessions", @@ -435,11 +414,6 @@ def _get_unset_required_fields(cls, message_dict): @staticmethod def _get_http_options(): http_options: List[Dict[str, str]] = [ - { - "method": "patch", - "uri": "/v1beta1/{session.name=projects/*/locations/*/sessions/*}", - "body": "session", - }, { "method": "patch", "uri": "/v1beta1/{session.name=projects/*/locations/*/reasoningEngines/*/sessions/*}", diff --git a/google/cloud/aiplatform_v1beta1/types/__init__.py b/google/cloud/aiplatform_v1beta1/types/__init__.py index b72fd73b19..121f8afc9e 100644 --- a/google/cloud/aiplatform_v1beta1/types/__init__.py +++ b/google/cloud/aiplatform_v1beta1/types/__init__.py @@ -108,6 +108,7 @@ ExportDataRequest, ExportDataResponse, GeminiExample, + GeminiRequestReadConfig, GeminiTemplateConfig, GetAnnotationSpecRequest, GetDatasetRequest, @@ -1481,6 +1482,7 @@ "ExportDataRequest", "ExportDataResponse", "GeminiExample", + "GeminiRequestReadConfig", "GeminiTemplateConfig", "GetAnnotationSpecRequest", "GetDatasetRequest", diff --git a/google/cloud/aiplatform_v1beta1/types/cached_content.py b/google/cloud/aiplatform_v1beta1/types/cached_content.py index b804f9d2f1..a202168461 100644 --- a/google/cloud/aiplatform_v1beta1/types/cached_content.py +++ b/google/cloud/aiplatform_v1beta1/types/cached_content.py @@ -20,6 +20,7 @@ import proto # type: ignore from google.cloud.aiplatform_v1beta1.types import content +from google.cloud.aiplatform_v1beta1.types import encryption_spec as 
gca_encryption_spec from google.cloud.aiplatform_v1beta1.types import tool from google.protobuf import duration_pb2 # type: ignore from google.protobuf import timestamp_pb2 # type: ignore @@ -64,10 +65,10 @@ class CachedContent(proto.Message): Optional. Immutable. The user-generated meaningful display name of the cached content. model (str): - Immutable. The name of the publisher model to - use for cached content. Format: - - projects/{project}/locations/{location}/publishers/{publisher}/models/{model} + Immutable. The name of the ``Model`` to use for cached + content. Currently, only the published Gemini base models + are supported, in form of + projects/{PROJECT}/locations/{LOCATION}/publishers/google/models/{MODEL} system_instruction (google.cloud.aiplatform_v1beta1.types.Content): Optional. Input only. Immutable. Developer set system instruction. Currently, text only @@ -81,7 +82,7 @@ class CachedContent(proto.Message): Optional. Input only. Immutable. Tool config. This config is shared for all tools create_time (google.protobuf.timestamp_pb2.Timestamp): - Output only. Creatation time of the cache + Output only. Creation time of the cache entry. update_time (google.protobuf.timestamp_pb2.Timestamp): Output only. When the cache entry was last @@ -89,6 +90,10 @@ class CachedContent(proto.Message): usage_metadata (google.cloud.aiplatform_v1beta1.types.CachedContent.UsageMetadata): Output only. Metadata on the usage of the cached content. + encryption_spec (google.cloud.aiplatform_v1beta1.types.EncryptionSpec): + Input only. Immutable. Customer-managed encryption key spec + for a ``CachedContent``. If set, this ``CachedContent`` and + all its sub-resources will be secured by this key. """ class UsageMetadata(proto.Message): @@ -188,6 +193,11 @@ class UsageMetadata(proto.Message): number=12, message=UsageMetadata, ) + encryption_spec: gca_encryption_spec.EncryptionSpec = proto.Field( + proto.MESSAGE, + number=13, + message=gca_encryption_spec.EncryptionSpec, + ) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/content.py b/google/cloud/aiplatform_v1beta1/types/content.py index 9da0b87d36..438ab47ec8 100644 --- a/google/cloud/aiplatform_v1beta1/types/content.py +++ b/google/cloud/aiplatform_v1beta1/types/content.py @@ -73,7 +73,8 @@ class HarmCategory(proto.Enum): The harm category is sexually explicit content. HARM_CATEGORY_CIVIC_INTEGRITY (5): - The harm category is civic integrity. + Deprecated: Election filter is not longer + supported. The harm category is civic integrity. """ HARM_CATEGORY_UNSPECIFIED = 0 HARM_CATEGORY_HATE_SPEECH = 1 @@ -474,6 +475,10 @@ class GenerationConfig(proto.Message): Optional. The speech generation config. This field is a member of `oneof`_ ``_speech_config``. + thinking_config (google.cloud.aiplatform_v1beta1.types.GenerationConfig.ThinkingConfig): + Optional. Config for thinking features. + An error will be returned if this field is set + for models that don't support thinking. model_config (google.cloud.aiplatform_v1beta1.types.GenerationConfig.ModelConfig): Optional. Config for model selection. """ @@ -612,6 +617,25 @@ class ManualRoutingMode(proto.Message): message="GenerationConfig.RoutingConfig.ManualRoutingMode", ) + class ThinkingConfig(proto.Message): + r"""Config for thinking features. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + thinking_budget (int): + Optional. Indicates the thinking budget in tokens. 
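A hedged usage sketch for the ThinkingConfig message introduced above; the field names come from this diff, and the budget value is arbitrary:

from google.cloud.aiplatform_v1beta1.types import content as gca_content

# Attach the new ThinkingConfig to a GenerationConfig via the
# thinking_config field added in this change.
generation_config = gca_content.GenerationConfig(
    temperature=0.2,
    thinking_config=gca_content.GenerationConfig.ThinkingConfig(
        thinking_budget=1024,  # arbitrary token budget for illustration
    ),
)
print(generation_config.thinking_config.thinking_budget)  # 1024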
This is + only applied when enable_thinking is true. + + This field is a member of `oneof`_ ``_thinking_budget``. + """ + + thinking_budget: int = proto.Field( + proto.INT32, + number=3, + optional=True, + ) + class ModelConfig(proto.Message): r"""Config for model selection. @@ -736,6 +760,11 @@ class FeatureSelectionPreference(proto.Enum): optional=True, message="SpeechConfig", ) + thinking_config: ThinkingConfig = proto.Field( + proto.MESSAGE, + number=25, + message=ThinkingConfig, + ) model_config: ModelConfig = proto.Field( proto.MESSAGE, number=27, diff --git a/google/cloud/aiplatform_v1beta1/types/dataset_service.py b/google/cloud/aiplatform_v1beta1/types/dataset_service.py index 4760a85c17..4acfa55464 100644 --- a/google/cloud/aiplatform_v1beta1/types/dataset_service.py +++ b/google/cloud/aiplatform_v1beta1/types/dataset_service.py @@ -70,6 +70,7 @@ "AssessDataResponse", "AssessDataOperationMetadata", "GeminiTemplateConfig", + "GeminiRequestReadConfig", "GeminiExample", "AssembleDataRequest", "AssembleDataResponse", @@ -1179,6 +1180,9 @@ class AssessDataRequest(proto.Message): Required. The name of the Dataset resource. Used only for MULTIMODAL datasets. Format: ``projects/{project}/locations/{location}/datasets/{dataset}`` + gemini_request_read_config (google.cloud.aiplatform_v1beta1.types.GeminiRequestReadConfig): + Optional. The Gemini request read config for + the dataset. """ class TuningValidationAssessmentConfig(proto.Message): @@ -1302,6 +1306,11 @@ class BatchPredictionResourceUsageAssessmentConfig(proto.Message): proto.STRING, number=1, ) + gemini_request_read_config: "GeminiRequestReadConfig" = proto.Field( + proto.MESSAGE, + number=8, + message="GeminiRequestReadConfig", + ) class AssessDataResponse(proto.Message): @@ -1450,7 +1459,7 @@ class GeminiTemplateConfig(proto.Message): assembling the request to use for downstream applications. field_mapping (MutableMapping[str, str]): - Required. Map of template params to the + Required. Map of template parameters to the columns in the dataset table. """ @@ -1466,6 +1475,43 @@ class GeminiTemplateConfig(proto.Message): ) +class GeminiRequestReadConfig(proto.Message): + r"""Configuration for how to read Gemini requests from a + multimodal dataset. + + This message has `oneof`_ fields (mutually exclusive fields). + For each oneof, at most one member field can be set at the same time. + Setting any member of the oneof automatically clears all other + members. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + template_config (google.cloud.aiplatform_v1beta1.types.GeminiTemplateConfig): + Gemini request template with placeholders. + + This field is a member of `oneof`_ ``read_config``. + assembled_request_column_name (str): + Optional. Column name in the dataset table + that contains already fully assembled Gemini + requests. + + This field is a member of `oneof`_ ``read_config``. + """ + + template_config: "GeminiTemplateConfig" = proto.Field( + proto.MESSAGE, + number=1, + oneof="read_config", + message="GeminiTemplateConfig", + ) + assembled_request_column_name: str = proto.Field( + proto.STRING, + number=4, + oneof="read_config", + ) + + class GeminiExample(proto.Message): r"""Format for Gemini examples used for Vertex Multimodal datasets. @@ -1607,6 +1653,8 @@ class AssembleDataRequest(proto.Message): Required. The name of the Dataset resource (used only for MULTIMODAL datasets). 
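A hedged sketch of the new GeminiRequestReadConfig oneof shown above; the column name is hypothetical, and only fields present in this diff are used:

from google.cloud.aiplatform_v1beta1.types import dataset_service

# Option 1: point at a dataset column that already holds fully assembled
# Gemini requests. (Option 2 would set template_config instead; the two
# fields are mutually exclusive members of the read_config oneof.)
read_config = dataset_service.GeminiRequestReadConfig(
    assembled_request_column_name="request_json",
)
assemble_request = dataset_service.AssembleDataRequest(
    gemini_request_read_config=read_config,
)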
Format: ``projects/{project}/locations/{location}/datasets/{dataset}`` + gemini_request_read_config (google.cloud.aiplatform_v1beta1.types.GeminiRequestReadConfig): + Optional. The read config for the dataset. """ gemini_template_config: "GeminiTemplateConfig" = proto.Field( @@ -1624,6 +1672,11 @@ class AssembleDataRequest(proto.Message): proto.STRING, number=1, ) + gemini_request_read_config: "GeminiRequestReadConfig" = proto.Field( + proto.MESSAGE, + number=6, + message="GeminiRequestReadConfig", + ) class AssembleDataResponse(proto.Message): diff --git a/google/cloud/aiplatform_v1beta1/types/io.py b/google/cloud/aiplatform_v1beta1/types/io.py index f55a7fb21d..93e761fdfe 100644 --- a/google/cloud/aiplatform_v1beta1/types/io.py +++ b/google/cloud/aiplatform_v1beta1/types/io.py @@ -82,7 +82,7 @@ class GcsSource(proto.Message): Required. Google Cloud Storage URI(-s) to the input file(s). May contain wildcards. For more information on wildcards, see - https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames. + https://cloud.google.com/storage/docs/wildcards. """ uris: MutableSequence[str] = proto.RepeatedField( diff --git a/google/cloud/aiplatform_v1beta1/types/model_garden_service.py b/google/cloud/aiplatform_v1beta1/types/model_garden_service.py index 9b195a10c7..36da1c66c3 100644 --- a/google/cloud/aiplatform_v1beta1/types/model_garden_service.py +++ b/google/cloud/aiplatform_v1beta1/types/model_garden_service.py @@ -356,6 +356,10 @@ class DeployConfig(proto.Message): fast_tryout_enabled (bool): Optional. If true, enable the QMT fast tryout feature for this model if possible. + system_labels (MutableMapping[str, str]): + Optional. System labels for Model Garden + deployments. These labels are managed by Google + and for tracking purposes only. """ dedicated_resources: machine_resources.DedicatedResources = proto.Field( @@ -367,6 +371,11 @@ class DeployConfig(proto.Message): proto.BOOL, number=2, ) + system_labels: MutableMapping[str, str] = proto.MapField( + proto.STRING, + proto.STRING, + number=3, + ) publisher_model_name: str = proto.Field( proto.STRING, diff --git a/google/cloud/aiplatform_v1beta1/types/model_monitoring_spec.py b/google/cloud/aiplatform_v1beta1/types/model_monitoring_spec.py index 4ee655ba70..fcab13f678 100644 --- a/google/cloud/aiplatform_v1beta1/types/model_monitoring_spec.py +++ b/google/cloud/aiplatform_v1beta1/types/model_monitoring_spec.py @@ -349,7 +349,7 @@ class ModelMonitoringGcsSource(proto.Message): Google Cloud Storage URI to the input file(s). May contain wildcards. For more information on wildcards, see - https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames. + https://cloud.google.com/storage/docs/wildcards. format_ (google.cloud.aiplatform_v1beta1.types.ModelMonitoringInput.ModelMonitoringDataset.ModelMonitoringGcsSource.DataFormat): Data format of the dataset. """ diff --git a/google/cloud/aiplatform_v1beta1/types/openapi.py b/google/cloud/aiplatform_v1beta1/types/openapi.py index c5f3916a4c..0c1300840f 100644 --- a/google/cloud/aiplatform_v1beta1/types/openapi.py +++ b/google/cloud/aiplatform_v1beta1/types/openapi.py @@ -135,6 +135,23 @@ class Schema(proto.Message): Optional. The value should be validated against any (one or more) of the subschemas in the list. + ref (str): + Optional. Allows indirect references between schema nodes. + The value should be a valid reference to a child of the root + ``defs``. 
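A minimal proto-plus sketch of the new ref/defs pair described here; the "Pet" node mirrors the docstring example that follows:

from google.cloud.aiplatform_v1beta1.types import openapi

# "pet" refers indirectly to the shared "Pet" definition held in the
# root-level defs map.
pet_schema = openapi.Schema(
    type_=openapi.Type.OBJECT,
    properties={"name": openapi.Schema(type_=openapi.Type.STRING)},
)
root = openapi.Schema(
    type_=openapi.Type.OBJECT,
    properties={"pet": openapi.Schema(ref="#/defs/Pet")},
    defs={"Pet": pet_schema},
)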
+ + For example, the following schema defines a reference to a + schema node named "Pet": + + type: object properties: pet: ref: #/defs/Pet defs: Pet: + type: object properties: name: type: string + + The value of the "pet" property is a reference to the schema + node named "Pet". See details in + https://json-schema.org/understanding-json-schema/structuring + defs (MutableMapping[str, google.cloud.aiplatform_v1beta1.types.Schema]): + Optional. A map of definitions for use by ``ref`` Only + allowed at the root of the schema. """ type_: "Type" = proto.Field( @@ -232,6 +249,16 @@ class Schema(proto.Message): number=11, message="Schema", ) + ref: str = proto.Field( + proto.STRING, + number=27, + ) + defs: MutableMapping[str, "Schema"] = proto.MapField( + proto.STRING, + proto.MESSAGE, + number=28, + message="Schema", + ) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/session.py b/google/cloud/aiplatform_v1beta1/types/session.py index 75ee6c7263..17063ba983 100644 --- a/google/cloud/aiplatform_v1beta1/types/session.py +++ b/google/cloud/aiplatform_v1beta1/types/session.py @@ -238,8 +238,8 @@ class EventActions(proto.Message): updating an artifact. key is the filename, value is the version. transfer_to_agent (bool): - Optional. If set, the event transfers to the - specified agent. + Deprecated. If set, the event transfers to + the specified agent. escalate (bool): Optional. The agent is escalating to a higher level agent. @@ -251,6 +251,9 @@ class EventActions(proto.Message): multiple function calls. Struct value is the required auth config, which can be another struct. + transfer_agent (str): + Optional. If set, the event transfers to the + specified agent. """ skip_summarization: bool = proto.Field( @@ -280,6 +283,10 @@ class EventActions(proto.Message): number=7, message=struct_pb2.Struct, ) + transfer_agent: str = proto.Field( + proto.STRING, + number=8, + ) __all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/google/cloud/aiplatform_v1beta1/types/session_service.py b/google/cloud/aiplatform_v1beta1/types/session_service.py index 88243b1317..06aa0d576d 100644 --- a/google/cloud/aiplatform_v1beta1/types/session_service.py +++ b/google/cloud/aiplatform_v1beta1/types/session_service.py @@ -50,7 +50,6 @@ class CreateSessionRequest(proto.Message): parent (str): Required. The resource name of the location to create the session in. Format: - ``projects/{project}/locations/{location}`` or ``projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}`` session (google.cloud.aiplatform_v1beta1.types.Session): Required. The session to create. @@ -220,8 +219,6 @@ class DeleteSessionRequest(proto.Message): Attributes: name (str): Required. The resource name of the session. Format: - ``projects/{project}/locations/{location}/sessions/{session}`` - or ``projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}/sessions/{session}`` """ @@ -244,7 +241,8 @@ class ListEventsRequest(proto.Message): Optional. The maximum number of events to return. The service may return fewer than this value. If unspecified, at most 100 events will - be returned. + be returned. These events are ordered by + timestamp in ascending order. page_token (str): Optional. 
The [next_page_token][google.cloud.aiplatform.v1beta1.ListEventsResponse.next_page_token] @@ -274,6 +272,7 @@ class ListEventsResponse(proto.Message): Attributes: session_events (MutableSequence[google.cloud.aiplatform_v1beta1.types.SessionEvent]): A list of events matching the request. + Ordered by timestamp in ascending order. next_page_token (str): A token, which can be sent as [ListEventsRequest.page_token][google.cloud.aiplatform.v1beta1.ListEventsRequest.page_token] diff --git a/owlbot.py b/owlbot.py index 0f6378b815..55c08974e6 100644 --- a/owlbot.py +++ b/owlbot.py @@ -117,6 +117,8 @@ ".kokoro/release.sh", ".kokoro/release/common.cfg", ".kokoro/requirements*", + ".kokoro/samples/python3.7/**", + ".kokoro/samples/python3.8/**", # exclude sample configs so periodic samples are tested against main # instead of pypi ".kokoro/samples/python3.9/common.cfg", @@ -127,6 +129,8 @@ ".kokoro/samples/python3.10/periodic.cfg", ".kokoro/samples/python3.11/periodic.cfg", ".kokoro/samples/python3.12/periodic.cfg", + ".github/auto-label.yaml", + ".github/blunderbuss.yml", ".github/CODEOWNERS", ".github/PULL_REQUEST_TEMPLATE.md", ".github/workflows", # exclude gh actions as credentials are needed for tests @@ -161,41 +165,4 @@ "python3", ) - # Update publish-docs to include gemini docs workflow. - s.replace( - ".kokoro/publish-docs.sh", - "# build docs", - """\ -# build Gemini docs -nox -s gemini_docs -# create metadata -python3 -m docuploader create-metadata \\ - --name="vertexai" \\ - --version=$(python3 setup.py --version) \\ - --language=$(jq --raw-output '.language // empty' .repo-metadata.json) \\ - --distribution-name="google-cloud-vertexai" \\ - --product-page=$(jq --raw-output '.product_documentation // empty' .repo-metadata.json) \\ - --github-repository=$(jq --raw-output '.repo // empty' .repo-metadata.json) \\ - --issue-tracker=$(jq --raw-output '.issue_tracker // empty' .repo-metadata.json) -cat docs.metadata -# upload docs -python3 -m docuploader upload gemini_docs/_build/html --metadata-file docs.metadata --staging-bucket "${STAGING_BUCKET}" -# Gemini docfx yaml files -nox -s gemini_docfx -# create metadata. -python3 -m docuploader create-metadata \\ - --name="vertexai" \\ - --version=$(python3 setup.py --version) \\ - --language=$(jq --raw-output '.language // empty' .repo-metadata.json) \\ - --distribution-name="google-cloud-vertexai" \\ - --product-page=$(jq --raw-output '.product_documentation // empty' .repo-metadata.json) \\ - --github-repository=$(jq --raw-output '.repo // empty' .repo-metadata.json) \\ - --issue-tracker=$(jq --raw-output '.issue_tracker // empty' .repo-metadata.json) \\ - --stem="/vertex-ai/generative-ai/docs/reference/python" -cat docs.metadata -# upload docs -python3 -m docuploader upload gemini_docs/_build/html/docfx_yaml --metadata-file docs.metadata --destination-prefix docfx --staging-bucket "${V2_STAGING_BUCKET}" -# build docs""", - ) - s.shell.run(["nox", "-s", "blacken"], hide_output=False) diff --git a/pypi/_vertex_ai_placeholder/version.py b/pypi/_vertex_ai_placeholder/version.py index 9600dc46d6..83d974d1fc 100644 --- a/pypi/_vertex_ai_placeholder/version.py +++ b/pypi/_vertex_ai_placeholder/version.py @@ -15,4 +15,4 @@ # limitations under the License. 
# -__version__ = "1.91.0" +__version__ = "1.92.0" diff --git a/renovate.json b/renovate.json index 39b2a0ec92..c7875c469b 100644 --- a/renovate.json +++ b/renovate.json @@ -5,7 +5,7 @@ ":preserveSemverRanges", ":disableDependencyDashboard" ], - "ignorePaths": [".pre-commit-config.yaml", ".kokoro/requirements.txt", "setup.py"], + "ignorePaths": [".pre-commit-config.yaml", ".kokoro/requirements.txt", "setup.py", ".github/workflows/unittest.yml"], "pip_requirements": { "fileMatch": ["requirements-test.txt", "samples/[\\S/]*constraints.txt", "samples/[\\S/]*constraints-test.txt"] } diff --git a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json index d046841f64..1b98adb366 100644 --- a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json +++ b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json @@ -8,7 +8,7 @@ ], "language": "PYTHON", "name": "google-cloud-aiplatform", - "version": "1.91.0" + "version": "1.92.0" }, "snippets": [ { diff --git a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json index ba930474b0..ef15cef83c 100644 --- a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json +++ b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json @@ -8,7 +8,7 @@ ], "language": "PYTHON", "name": "google-cloud-aiplatform", - "version": "1.91.0" + "version": "1.92.0" }, "snippets": [ { diff --git a/samples/model-builder/noxfile.py b/samples/model-builder/noxfile.py index 483b559017..a169b5b5b4 100644 --- a/samples/model-builder/noxfile.py +++ b/samples/model-builder/noxfile.py @@ -89,7 +89,7 @@ def get_pytest_env_vars() -> Dict[str, str]: # DO NOT EDIT - automatically generated. # All versions used to test samples. -ALL_VERSIONS = ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"] +ALL_VERSIONS = ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12", "3.13"] # Any default versions that should be ignored. IGNORED_VERSIONS = TEST_CONFIG["ignored_versions"] diff --git a/samples/snippets/noxfile.py b/samples/snippets/noxfile.py index 483b559017..a169b5b5b4 100644 --- a/samples/snippets/noxfile.py +++ b/samples/snippets/noxfile.py @@ -89,7 +89,7 @@ def get_pytest_env_vars() -> Dict[str, str]: # DO NOT EDIT - automatically generated. # All versions used to test samples. -ALL_VERSIONS = ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"] +ALL_VERSIONS = ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12", "3.13"] # Any default versions that should be ignored. IGNORED_VERSIONS = TEST_CONFIG["ignored_versions"] diff --git a/setup.py b/setup.py index ba9bafc0b5..7ca103ce29 100644 --- a/setup.py +++ b/setup.py @@ -84,7 +84,7 @@ prediction_extra_require = [ "docker >= 5.0.3", "fastapi >= 0.71.0, <=0.114.0", - "httpx >=0.23.0, <0.25.0", # Optional dependency of fastapi + "httpx >=0.23.0, <=0.28.1", # Optional dependency of fastapi "starlette >= 0.17.1", "uvicorn[standard] >= 0.16.0", ] @@ -103,7 +103,7 @@ ray_extra_require = [ # Cluster only supports 2.9.3, 2.33.0, and 2.42.0. Keep 2.4.0 for our # testing environment. - # Note that testing is submiting a job in a cluster with Ray 2.9.3 remotely. + # Note that testing is submitting a job in a cluster with Ray 2.9.3 remotely. 
( "ray[default] >= 2.4, <= 2.42.0,!= 2.5.*,!= 2.6.*,!= 2.7.*,!=" " 2.8.*,!=2.9.0,!=2.9.1,!=2.9.2, !=2.10.*, !=2.11.*, !=2.12.*, !=2.13.*, !=" @@ -258,8 +258,9 @@ "scikit-learn<1.6.0; python_version<='3.10'", "scikit-learn; python_version>'3.10'", # Lazy import requires > 2.12.0 - "tensorflow == 2.13.0; python_version<='3.11'", - "tensorflow == 2.16.1; python_version>'3.11'", + "tensorflow == 2.14.1; python_version<='3.11'", + "tensorflow == 2.19.0; python_version>'3.11'", + "protobuf <= 5.29.4", # TODO(jayceeli) torch 2.1.0 has conflict with pyfakefs, will check if # future versions fix this issue "torch >= 2.0.0, < 2.1.0; python_version<='3.11'", @@ -304,6 +305,7 @@ "google-cloud-bigquery >= 1.15.0, < 4.0.0, !=3.20.0", "google-cloud-resource-manager >= 1.3.3, < 3.0.0", "shapely < 3.0.0", + "google-genai >= 1.0.0, <2.0.0", ) + genai_requires, extras_require={ diff --git a/testing/constraints-3.12.txt b/testing/constraints-3.12.txt index caa5b754bc..64a02b12c6 100644 --- a/testing/constraints-3.12.txt +++ b/testing/constraints-3.12.txt @@ -4,11 +4,11 @@ google-api-core==2.21.0 # Tests google-api-core with rest async support google-auth==2.35.0 # Tests google-auth with rest async support proto-plus -protobuf>=6 mock==4.0.2 google-cloud-storage==2.2.1 # Increased for kfp 2.0 compatibility packaging==24.1 # Increased to unbreak canonicalize_version error (b/377774673) pytest-xdist==3.3.1 # Pinned to unbreak unit tests ray==2.5.0 # Pinned until 2.9.3 is verified for Ray tests ipython==8.22.2 # Pinned to unbreak TypeAliasType import error -google-adk==0.0.2 \ No newline at end of file +google-adk==0.0.2 +google-genai>=1.10.0 \ No newline at end of file diff --git a/tests/system/aiplatform/test_persistent_resource.py b/tests/system/aiplatform/test_persistent_resource.py index 66268d1818..fc8cc3dd24 100644 --- a/tests/system/aiplatform/test_persistent_resource.py +++ b/tests/system/aiplatform/test_persistent_resource.py @@ -27,11 +27,15 @@ persistent_resource_v1 as gca_persistent_resource, ) from tests.system.aiplatform import e2e_base +from google.cloud.aiplatform.tests.unit.aiplatform import constants as test_constants import pytest _TEST_MACHINE_TYPE = "n1-standard-4" _TEST_INITIAL_REPLICA_COUNT = 2 +_TEST_ENABLE_CUSTOM_SERVICE_ACCOUNT_TRUE = ( + test_constants.ProjectConstants._TEST_ENABLE_CUSTOM_SERVICE_ACCOUNT_TRUE +) @pytest.mark.usefixtures("tear_down_resources") @@ -59,7 +63,9 @@ def test_create_persistent_resource(self, shared_state): ] test_resource = persistent_resource.PersistentResource.create( - persistent_resource_id=resource_id, resource_pools=resource_pools + persistent_resource_id=resource_id, + resource_pools=resource_pools, + enable_custom_service_account=_TEST_ENABLE_CUSTOM_SERVICE_ACCOUNT_TRUE, ) shared_state["resources"] = [test_resource] diff --git a/tests/unit/aiplatform/constants.py b/tests/unit/aiplatform/constants.py index 8c3897141b..633659b449 100644 --- a/tests/unit/aiplatform/constants.py +++ b/tests/unit/aiplatform/constants.py @@ -59,6 +59,7 @@ class ProjectConstants: _TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" _TEST_SERVICE_ACCOUNT = "vinnys@my-project.iam.gserviceaccount.com" _TEST_LABELS = {"my_key": "my_value"} + _TEST_ENABLE_CUSTOM_SERVICE_ACCOUNT_TRUE = True @dataclasses.dataclass(frozen=True) diff --git a/tests/unit/aiplatform/test_endpoints.py b/tests/unit/aiplatform/test_endpoints.py index a5d9fa08eb..14e27abbc4 100644 --- a/tests/unit/aiplatform/test_endpoints.py +++ b/tests/unit/aiplatform/test_endpoints.py @@ -97,6 +97,7 @@ 
_TEST_PROJECT_ALLOWLIST = [_TEST_PROJECT] _TEST_ENDPOINT_OVERRIDE = "endpoint-override.aiplatform.vertex.goog" +_TEST_SHARED_ENDPOINT_DNS = f"{_TEST_LOCATION}-aiplatform.googleapis.com" _TEST_DEDICATED_ENDPOINT_DNS = ( f"{_TEST_ID}.{_TEST_PROJECT}.{_TEST_LOCATION}-aiplatform.vertex.goog" ) @@ -2689,19 +2690,30 @@ def test_predict_dedicated_endpoint_with_timeout(self, predict_endpoint_http_moc ) @pytest.mark.usefixtures("get_endpoint_mock") - def test_predict_use_dedicated_endpoint_for_regular_endpoint(self): - test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + def test_predict_use_dedicated_endpoint_for_regular_endpoint( + self, predict_client_predict_mock + ): + test_endpoint = models.Endpoint(_TEST_ID) + test_prediction = test_endpoint.predict( + instances=_TEST_INSTANCES, + parameters={"param": 3.0}, + use_dedicated_endpoint=True, + ) - with pytest.raises(ValueError) as err: - test_endpoint.predict( - instances=_TEST_INSTANCES, - parameters={"param": 3.0}, - use_dedicated_endpoint=True, - ) - assert err.match( - regexp=r"Dedicated endpoint is not enabled or DNS is empty." - "Please make sure endpoint has dedicated endpoint enabled" - "and model are ready before making a prediction." + true_prediction = models.Prediction( + predictions=_TEST_PREDICTION, + deployed_model_id=_TEST_ID, + metadata=_TEST_METADATA, + model_version_id=_TEST_VERSION_ID, + model_resource_name=_TEST_MODEL_NAME, + ) + + assert true_prediction == test_prediction + predict_client_predict_mock.assert_called_once_with( + endpoint=_TEST_ENDPOINT_NAME, + instances=_TEST_INSTANCES, + parameters={"param": 3.0}, + timeout=None, ) @pytest.mark.usefixtures("get_dedicated_endpoint_mock") @@ -2768,19 +2780,35 @@ def test_raw_predict_dedicated_endpoint_with_timeout( ) @pytest.mark.usefixtures("get_endpoint_mock") - def test_raw_predict_use_dedicated_endpoint_for_regular_endpoint(self): - test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME) + def test_raw_predict_use_dedicated_endpoint_for_regular_endpoint( + self, predict_endpoint_http_mock + ): + test_endpoint = models.Endpoint(_TEST_ID) - with pytest.raises(ValueError) as err: - test_endpoint.raw_predict( - body=_TEST_RAW_INPUTS, - headers={"Content-Type": "application/json"}, - use_dedicated_endpoint=True, - ) - assert err.match( - regexp=r"Dedicated endpoint is not enabled or DNS is empty." - "Please make sure endpoint has dedicated endpoint enabled" - "and model are ready before making a prediction." 
+ test_prediction = test_endpoint.raw_predict( + body=_TEST_RAW_INPUTS, + headers={"Content-Type": "application/json"}, + use_dedicated_endpoint=True, + ) + + true_prediction = requests.Response() + true_prediction.status_code = 200 + true_prediction._content = json.dumps( + { + "predictions": _TEST_PREDICTION, + "metadata": _TEST_METADATA, + "deployedModelId": _TEST_DEPLOYED_MODELS[0].id, + "model": _TEST_MODEL_NAME, + "modelVersionId": "1", + } + ).encode("utf-8") + assert true_prediction.status_code == test_prediction.status_code + assert true_prediction.text == test_prediction.text + predict_endpoint_http_mock.assert_called_once_with( + url=f"https://{_TEST_SHARED_ENDPOINT_DNS}/v1/projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/endpoints/{_TEST_ID}:rawPredict", + data=_TEST_RAW_INPUTS, + headers={"Content-Type": "application/json"}, + timeout=None, ) @pytest.mark.asyncio @@ -3658,7 +3686,7 @@ def test_psa_predict(self, predict_private_endpoint_mock): ) @pytest.mark.usefixtures("get_psc_private_endpoint_mock") - def test_psc_predict(self, predict_private_endpoint_mock): + def test_psc_predict(self, predict_endpoint_http_mock): test_endpoint = models.PrivateEndpoint( project=_TEST_PROJECT, location=_TEST_LOCATION, endpoint_name=_TEST_ID ) @@ -3677,14 +3705,10 @@ def test_psc_predict(self, predict_private_endpoint_mock): ) assert true_prediction == test_prediction - predict_private_endpoint_mock.assert_called_once_with( - method="POST", + predict_endpoint_http_mock.assert_called_once_with( url=f"https://{_TEST_ENDPOINT_OVERRIDE}/v1/projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/endpoints/{_TEST_ID}:predict", - body='{"instances": [[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]], "parameters": {"param": 3.0}}', - headers={ - "Content-Type": "application/json", - "Authorization": "Bearer None", - }, + data='{"instances": [[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]], "parameters": {"param": 3.0}}', + headers={"Content-Type": "application/json"}, ) @pytest.mark.usefixtures("get_psc_private_endpoint_mock") @@ -3697,7 +3721,6 @@ def test_psc_stream_raw_predict(self, stream_raw_predict_private_endpoint_mock): body='{"instances": [[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]]}', headers={ "Content-Type": "application/json", - "Authorization": "Bearer None", }, endpoint_override=_TEST_ENDPOINT_OVERRIDE, ) @@ -3709,7 +3732,6 @@ def test_psc_stream_raw_predict(self, stream_raw_predict_private_endpoint_mock): data='{"instances": [[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]]}', headers={ "Content-Type": "application/json", - "Authorization": "Bearer None", }, stream=True, verify=False, diff --git a/tests/unit/aiplatform/test_persistent_resource.py b/tests/unit/aiplatform/test_persistent_resource.py index 55c460b113..fe6e2bbd90 100644 --- a/tests/unit/aiplatform/test_persistent_resource.py +++ b/tests/unit/aiplatform/test_persistent_resource.py @@ -49,6 +49,10 @@ _TEST_RESERVED_IP_RANGES = test_constants.TrainingJobConstants._TEST_RESERVED_IP_RANGES _TEST_KEY_NAME = test_constants.TrainingJobConstants._TEST_DEFAULT_ENCRYPTION_KEY_NAME _TEST_SERVICE_ACCOUNT = test_constants.ProjectConstants._TEST_SERVICE_ACCOUNT +_TEST_ENABLE_CUSTOM_SERVICE_ACCOUNT_TRUE = ( + test_constants.ProjectConstants._TEST_ENABLE_CUSTOM_SERVICE_ACCOUNT_TRUE +) + _TEST_PERSISTENT_RESOURCE_PROTO = persistent_resource_v1.PersistentResource( name=_TEST_PERSISTENT_RESOURCE_ID, @@ -298,7 +302,7 @@ def test_create_persistent_resource_with_kms_key( ) @pytest.mark.parametrize("sync", [True, False]) - def test_create_persistent_resource_with_service_account( + def 
test_create_persistent_resource_enable_custom_sa_true_with_sa( self, create_persistent_resource_mock, get_persistent_resource_mock, @@ -309,6 +313,7 @@ def test_create_persistent_resource_with_service_account( resource_pools=[ test_constants.PersistentResourceConstants._TEST_RESOURCE_POOL, ], + enable_custom_service_account=_TEST_ENABLE_CUSTOM_SERVICE_ACCOUNT_TRUE, service_account=_TEST_SERVICE_ACCOUNT, sync=sync, ) @@ -321,7 +326,8 @@ def test_create_persistent_resource_with_service_account( ) service_account_spec = persistent_resource_v1.ServiceAccountSpec( - enable_custom_service_account=True, service_account=_TEST_SERVICE_ACCOUNT + enable_custom_service_account=_TEST_ENABLE_CUSTOM_SERVICE_ACCOUNT_TRUE, + service_account=_TEST_SERVICE_ACCOUNT, ) expected_persistent_resource_arg.resource_runtime_spec = ( persistent_resource_v1.ResourceRuntimeSpec( @@ -341,6 +347,164 @@ def test_create_persistent_resource_with_service_account( name=_TEST_PERSISTENT_RESOURCE_ID ) + @pytest.mark.parametrize("sync", [True, False]) + def test_create_persistent_resource_enable_custom_sa_true_no_sa( + self, + create_persistent_resource_mock, + get_persistent_resource_mock, + sync, + ): + my_test_resource = persistent_resource.PersistentResource.create( + persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID, + resource_pools=[ + test_constants.PersistentResourceConstants._TEST_RESOURCE_POOL, + ], + enable_custom_service_account=True, + sync=sync, + ) + + if not sync: + my_test_resource.wait() + + expected_persistent_resource_arg = _get_persistent_resource_proto( + name=_TEST_PERSISTENT_RESOURCE_ID, + ) + service_account_spec = persistent_resource_v1.ServiceAccountSpec( + enable_custom_service_account=True, + service_account=None, + ) + expected_persistent_resource_arg.resource_runtime_spec = ( + persistent_resource_v1.ResourceRuntimeSpec( + service_account_spec=service_account_spec + ) + ) + + create_persistent_resource_mock.assert_called_once_with( + parent=_TEST_PARENT, + persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID, + persistent_resource=expected_persistent_resource_arg, + timeout=None, + ) + get_persistent_resource_mock.assert_called_once() + _, mock_kwargs = get_persistent_resource_mock.call_args + assert mock_kwargs["name"] == _get_resource_name( + name=_TEST_PERSISTENT_RESOURCE_ID + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_create_persistent_resource_enable_custom_sa_false_raises_error( + self, + create_persistent_resource_mock, + get_persistent_resource_mock, + sync, + ): + with pytest.raises(ValueError) as excinfo: + my_test_resource = persistent_resource.PersistentResource.create( + persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID, + resource_pools=[ + test_constants.PersistentResourceConstants._TEST_RESOURCE_POOL, + ], + enable_custom_service_account=False, + service_account=_TEST_SERVICE_ACCOUNT, + sync=sync, + ) + if not sync: + my_test_resource.wait() + + assert str(excinfo.value) == ( + "The parameter `enable_custom_service_account` was set to False, " + "but a value was provided for `service_account`. These two " + "settings are incompatible. If you want to use a custom " + "service account, set `enable_custom_service_account` to True." 
+ ) + + create_persistent_resource_mock.assert_not_called() + get_persistent_resource_mock.assert_not_called() + + @pytest.mark.parametrize("sync", [True, False]) + def test_create_persistent_resource_enable_custom_sa_none_with_sa( + self, + create_persistent_resource_mock, + get_persistent_resource_mock, + sync, + ): + my_test_resource = persistent_resource.PersistentResource.create( + persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID, + resource_pools=[ + test_constants.PersistentResourceConstants._TEST_RESOURCE_POOL, + ], + enable_custom_service_account=None, + service_account=_TEST_SERVICE_ACCOUNT, + sync=sync, + ) + + if not sync: + my_test_resource.wait() + + expected_persistent_resource_arg = _get_persistent_resource_proto( + name=_TEST_PERSISTENT_RESOURCE_ID, + ) + service_account_spec = persistent_resource_v1.ServiceAccountSpec( + enable_custom_service_account=True, + service_account=_TEST_SERVICE_ACCOUNT, + ) + expected_persistent_resource_arg.resource_runtime_spec = ( + persistent_resource_v1.ResourceRuntimeSpec( + service_account_spec=service_account_spec + ) + ) + + create_persistent_resource_mock.assert_called_once_with( + parent=_TEST_PARENT, + persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID, + persistent_resource=expected_persistent_resource_arg, + timeout=None, + ) + get_persistent_resource_mock.assert_called_once() + _, mock_kwargs = get_persistent_resource_mock.call_args + assert mock_kwargs["name"] == _get_resource_name( + name=_TEST_PERSISTENT_RESOURCE_ID + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_create_persistent_resource_enable_custom_sa_none_no_sa( + self, + create_persistent_resource_mock, + get_persistent_resource_mock, + sync, + ): + my_test_resource = persistent_resource.PersistentResource.create( + persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID, + resource_pools=[ + test_constants.PersistentResourceConstants._TEST_RESOURCE_POOL, + ], + enable_custom_service_account=None, + sync=sync, + ) + + if not sync: + my_test_resource.wait() + + expected_persistent_resource_arg = _get_persistent_resource_proto( + name=_TEST_PERSISTENT_RESOURCE_ID, + ) + + # Assert that resource_runtime_spec is NOT set + call_args = create_persistent_resource_mock.call_args.kwargs + assert "resource_runtime_spec" not in call_args["persistent_resource"] + + create_persistent_resource_mock.assert_called_once_with( + parent=_TEST_PARENT, + persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID, + persistent_resource=expected_persistent_resource_arg, + timeout=None, + ) + get_persistent_resource_mock.assert_called_once() + _, mock_kwargs = get_persistent_resource_mock.call_args + assert mock_kwargs["name"] == _get_resource_name( + name=_TEST_PERSISTENT_RESOURCE_ID + ) + def test_list_persistent_resources(self, list_persistent_resources_mock): resource_list = persistent_resource.PersistentResource.list() diff --git a/tests/unit/aiplatform/test_persistent_resource_preview.py b/tests/unit/aiplatform/test_persistent_resource_preview.py index f189af2a8b..d45bf51816 100644 --- a/tests/unit/aiplatform/test_persistent_resource_preview.py +++ b/tests/unit/aiplatform/test_persistent_resource_preview.py @@ -56,6 +56,9 @@ _TEST_RESERVED_IP_RANGES = test_constants.TrainingJobConstants._TEST_RESERVED_IP_RANGES _TEST_KEY_NAME = test_constants.TrainingJobConstants._TEST_DEFAULT_ENCRYPTION_KEY_NAME _TEST_SERVICE_ACCOUNT = test_constants.ProjectConstants._TEST_SERVICE_ACCOUNT +_TEST_ENABLE_CUSTOM_SERVICE_ACCOUNT_TRUE = ( + 
test_constants.ProjectConstants._TEST_ENABLE_CUSTOM_SERVICE_ACCOUNT_TRUE +) _TEST_PERSISTENT_RESOURCE_PROTO = persistent_resource_compat.PersistentResource( name=_TEST_PERSISTENT_RESOURCE_ID, @@ -291,7 +294,7 @@ def test_create_persistent_resource_with_kms_key( ) @pytest.mark.parametrize("sync", [True, False]) - def test_create_persistent_resource_with_service_account( + def test_create_persistent_resource_enable_custom_sa_true_with_sa( self, create_preview_persistent_resource_mock, get_persistent_resource_mock, @@ -302,6 +305,7 @@ def test_create_persistent_resource_with_service_account( resource_pools=[ test_constants.PersistentResourceConstants._TEST_RESOURCE_POOL, ], + enable_custom_service_account=_TEST_ENABLE_CUSTOM_SERVICE_ACCOUNT_TRUE, service_account=_TEST_SERVICE_ACCOUNT, sync=sync, ) @@ -312,9 +316,53 @@ def test_create_persistent_resource_with_service_account( expected_persistent_resource_arg = _get_persistent_resource_proto( name=_TEST_PERSISTENT_RESOURCE_ID, ) + service_account_spec = persistent_resource_compat.ServiceAccountSpec( + enable_custom_service_account=_TEST_ENABLE_CUSTOM_SERVICE_ACCOUNT_TRUE, + service_account=_TEST_SERVICE_ACCOUNT, + ) + expected_persistent_resource_arg.resource_runtime_spec = ( + persistent_resource_compat.ResourceRuntimeSpec( + service_account_spec=service_account_spec + ) + ) + + create_preview_persistent_resource_mock.assert_called_once_with( + parent=_TEST_PARENT, + persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID, + persistent_resource=expected_persistent_resource_arg, + timeout=None, + ) + get_persistent_resource_mock.assert_called_once() + _, mock_kwargs = get_persistent_resource_mock.call_args + assert mock_kwargs["name"] == _get_resource_name( + name=_TEST_PERSISTENT_RESOURCE_ID + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_create_persistent_resource_enable_custom_sa_true_no_sa( + self, + create_preview_persistent_resource_mock, + get_persistent_resource_mock, + sync, + ): + my_test_resource = persistent_resource.PersistentResource.create( + persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID, + resource_pools=[ + test_constants.PersistentResourceConstants._TEST_RESOURCE_POOL, + ], + enable_custom_service_account=True, + sync=sync, + ) + + if not sync: + my_test_resource.wait() + expected_persistent_resource_arg = _get_persistent_resource_proto( + name=_TEST_PERSISTENT_RESOURCE_ID, + ) service_account_spec = persistent_resource_compat.ServiceAccountSpec( - enable_custom_service_account=True, service_account=_TEST_SERVICE_ACCOUNT + enable_custom_service_account=True, + service_account=None, ) expected_persistent_resource_arg.resource_runtime_spec = ( persistent_resource_compat.ResourceRuntimeSpec( @@ -334,6 +382,120 @@ def test_create_persistent_resource_with_service_account( name=_TEST_PERSISTENT_RESOURCE_ID ) + @pytest.mark.parametrize("sync", [True, False]) + def test_create_persistent_resource_enable_custom_sa_false_with_sa_raises_error( + self, + create_preview_persistent_resource_mock, + get_persistent_resource_mock, + sync, + ): + with pytest.raises(ValueError) as excinfo: + my_test_resource = persistent_resource.PersistentResource.create( + persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID, + resource_pools=[ + test_constants.PersistentResourceConstants._TEST_RESOURCE_POOL, + ], + enable_custom_service_account=False, + service_account=_TEST_SERVICE_ACCOUNT, + sync=sync, + ) + if not sync: + my_test_resource.wait() + + assert str(excinfo.value) == ( + "The parameter `enable_custom_service_account` was set 
to False, " + "but a value was provided for `service_account`. These two " + "settings are incompatible. If you want to use a custom " + "service account, set `enable_custom_service_account` to True." + ) + + create_preview_persistent_resource_mock.assert_not_called() + get_persistent_resource_mock.assert_not_called() + + @pytest.mark.parametrize("sync", [True, False]) + def test_create_persistent_resource_enable_custom_sa_none_with_sa( + self, + create_preview_persistent_resource_mock, + get_persistent_resource_mock, + sync, + ): + my_test_resource = persistent_resource.PersistentResource.create( + persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID, + resource_pools=[ + test_constants.PersistentResourceConstants._TEST_RESOURCE_POOL, + ], + enable_custom_service_account=None, + service_account=_TEST_SERVICE_ACCOUNT, + sync=sync, + ) + + if not sync: + my_test_resource.wait() + + expected_persistent_resource_arg = _get_persistent_resource_proto( + name=_TEST_PERSISTENT_RESOURCE_ID, + ) + service_account_spec = persistent_resource_compat.ServiceAccountSpec( + enable_custom_service_account=True, + service_account=_TEST_SERVICE_ACCOUNT, + ) + expected_persistent_resource_arg.resource_runtime_spec = ( + persistent_resource_compat.ResourceRuntimeSpec( + service_account_spec=service_account_spec + ) + ) + + create_preview_persistent_resource_mock.assert_called_once_with( + parent=_TEST_PARENT, + persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID, + persistent_resource=expected_persistent_resource_arg, + timeout=None, + ) + get_persistent_resource_mock.assert_called_once() + _, mock_kwargs = get_persistent_resource_mock.call_args + assert mock_kwargs["name"] == _get_resource_name( + name=_TEST_PERSISTENT_RESOURCE_ID + ) + + @pytest.mark.parametrize("sync", [True, False]) + def test_create_persistent_resource_enable_custom_sa_none_no_sa( + self, + create_preview_persistent_resource_mock, + get_persistent_resource_mock, + sync, + ): + my_test_resource = persistent_resource.PersistentResource.create( + persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID, + resource_pools=[ + test_constants.PersistentResourceConstants._TEST_RESOURCE_POOL, + ], + enable_custom_service_account=None, + sync=sync, + ) + + if not sync: + my_test_resource.wait() + + expected_persistent_resource_arg = _get_persistent_resource_proto( + name=_TEST_PERSISTENT_RESOURCE_ID, + ) + + # Assert that resource_runtime_spec is NOT set + call_args = create_preview_persistent_resource_mock.call_args.kwargs + assert "resource_runtime_spec" not in call_args["persistent_resource"] + + create_preview_persistent_resource_mock.assert_called_once_with( + parent=_TEST_PARENT, + persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID, + persistent_resource=expected_persistent_resource_arg, + timeout=None, + ) + get_persistent_resource_mock.assert_called_once() + _, mock_kwargs = get_persistent_resource_mock.call_args + assert mock_kwargs["name"] == _get_resource_name( + name=_TEST_PERSISTENT_RESOURCE_ID + ) + def test_list_persistent_resources(self, list_persistent_resources_mock): resource_list = persistent_resource.PersistentResource.list() diff --git a/tests/unit/gapic/aiplatform_v1/test_gen_ai_cache_service.py b/tests/unit/gapic/aiplatform_v1/test_gen_ai_cache_service.py index ce2fe8eddd..20effd6e03 100644 --- a/tests/unit/gapic/aiplatform_v1/test_gen_ai_cache_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_gen_ai_cache_service.py @@ -73,6 +73,7 @@ from google.cloud.aiplatform_v1.types import cached_content from 
google.cloud.aiplatform_v1.types import cached_content as gca_cached_content from google.cloud.aiplatform_v1.types import content +from google.cloud.aiplatform_v1.types import encryption_spec from google.cloud.aiplatform_v1.types import gen_ai_cache_service from google.cloud.aiplatform_v1.types import openapi from google.cloud.aiplatform_v1.types import tool @@ -4665,6 +4666,8 @@ def test_create_cached_content_rest_call_success(request_type): "pattern": "pattern_value", "example": {}, "any_of": {}, + "ref": "ref_value", + "defs": {}, }, "response": {}, } @@ -4731,6 +4734,7 @@ def test_create_cached_content_rest_call_success(request_type): "video_duration_seconds": 2346, "audio_duration_seconds": 2341, }, + "encryption_spec": {"kms_key_name": "kms_key_name_value"}, } # The version of a generated dependency at test runtime may differ from the version used during generation. # Delete any fields which are not present in the current runtime dependency @@ -5144,6 +5148,8 @@ def test_update_cached_content_rest_call_success(request_type): "pattern": "pattern_value", "example": {}, "any_of": {}, + "ref": "ref_value", + "defs": {}, }, "response": {}, } @@ -5210,6 +5216,7 @@ def test_update_cached_content_rest_call_success(request_type): "video_duration_seconds": 2346, "audio_duration_seconds": 2341, }, + "encryption_spec": {"kms_key_name": "kms_key_name_value"}, } # The version of a generated dependency at test runtime may differ from the version used during generation. # Delete any fields which are not present in the current runtime dependency @@ -6485,6 +6492,8 @@ async def test_create_cached_content_rest_asyncio_call_success(request_type): "pattern": "pattern_value", "example": {}, "any_of": {}, + "ref": "ref_value", + "defs": {}, }, "response": {}, } @@ -6551,6 +6560,7 @@ async def test_create_cached_content_rest_asyncio_call_success(request_type): "video_duration_seconds": 2346, "audio_duration_seconds": 2341, }, + "encryption_spec": {"kms_key_name": "kms_key_name_value"}, } # The version of a generated dependency at test runtime may differ from the version used during generation. # Delete any fields which are not present in the current runtime dependency @@ -6996,6 +7006,8 @@ async def test_update_cached_content_rest_asyncio_call_success(request_type): "pattern": "pattern_value", "example": {}, "any_of": {}, + "ref": "ref_value", + "defs": {}, }, "response": {}, } @@ -7062,6 +7074,7 @@ async def test_update_cached_content_rest_asyncio_call_success(request_type): "video_duration_seconds": 2346, "audio_duration_seconds": 2341, }, + "encryption_spec": {"kms_key_name": "kms_key_name_value"}, } # The version of a generated dependency at test runtime may differ from the version used during generation. 
# Delete any fields which are not present in the current runtime dependency diff --git a/tests/unit/gapic/aiplatform_v1/test_migration_service.py b/tests/unit/gapic/aiplatform_v1/test_migration_service.py index 9e8dced432..f0c5ab4866 100644 --- a/tests/unit/gapic/aiplatform_v1/test_migration_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_migration_service.py @@ -5398,22 +5398,19 @@ def test_parse_annotated_dataset_path(): def test_dataset_path(): project = "cuttlefish" - location = "mussel" - dataset = "winkle" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( + dataset = "mussel" + expected = "projects/{project}/datasets/{dataset}".format( project=project, - location=location, dataset=dataset, ) - actual = MigrationServiceClient.dataset_path(project, location, dataset) + actual = MigrationServiceClient.dataset_path(project, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "nautilus", - "location": "scallop", - "dataset": "abalone", + "project": "winkle", + "dataset": "nautilus", } path = MigrationServiceClient.dataset_path(**expected) @@ -5424,18 +5421,21 @@ def test_parse_dataset_path(): def test_dataset_path(): project = "squid" - dataset = "clam" - expected = "projects/{project}/datasets/{dataset}".format( + location = "clam" + dataset = "whelk" + expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( project=project, + location=location, dataset=dataset, ) - actual = MigrationServiceClient.dataset_path(project, dataset) + actual = MigrationServiceClient.dataset_path(project, location, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "whelk", + "project": "clam", + "location": "whelk", "dataset": "octopus", } path = MigrationServiceClient.dataset_path(**expected) @@ -5446,22 +5446,19 @@ def test_parse_dataset_path(): def test_dataset_path(): - project = "oyster" - location = "nudibranch" - dataset = "cuttlefish" - expected = "projects/{project}/locations/{location}/datasets/{dataset}".format( + project = "cuttlefish" + dataset = "mussel" + expected = "projects/{project}/datasets/{dataset}".format( project=project, - location=location, dataset=dataset, ) - actual = MigrationServiceClient.dataset_path(project, location, dataset) + actual = MigrationServiceClient.dataset_path(project, dataset) assert expected == actual def test_parse_dataset_path(): expected = { - "project": "mussel", - "location": "winkle", + "project": "winkle", "dataset": "nautilus", } path = MigrationServiceClient.dataset_path(**expected) diff --git a/tests/unit/gapic/aiplatform_v1/test_model_service.py b/tests/unit/gapic/aiplatform_v1/test_model_service.py index ea52d506db..4a299dca9c 100644 --- a/tests/unit/gapic/aiplatform_v1/test_model_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_model_service.py @@ -14989,6 +14989,9 @@ def test_update_model_rest_call_success(request_type): }, "satisfies_pzs": True, "satisfies_pzi": True, + "checkpoints": [ + {"checkpoint_id": "checkpoint_id_value", "epoch": 527, "step": 444} + ], } # The version of a generated dependency at test runtime may differ from the version used during generation. 
# Delete any fields which are not present in the current runtime dependency @@ -18970,6 +18973,9 @@ async def test_update_model_rest_asyncio_call_success(request_type): }, "satisfies_pzs": True, "satisfies_pzi": True, + "checkpoints": [ + {"checkpoint_id": "checkpoint_id_value", "epoch": 527, "step": 444} + ], } # The version of a generated dependency at test runtime may differ from the version used during generation. # Delete any fields which are not present in the current runtime dependency diff --git a/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py b/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py index 9bf770ce4e..d60c2e7eb5 100644 --- a/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py +++ b/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py @@ -9226,6 +9226,9 @@ def test_create_training_pipeline_rest_call_success(request_type): }, "satisfies_pzs": True, "satisfies_pzi": True, + "checkpoints": [ + {"checkpoint_id": "checkpoint_id_value", "epoch": 527, "step": 444} + ], }, "model_id": "model_id_value", "parent_model": "parent_model_value", @@ -12161,6 +12164,9 @@ async def test_create_training_pipeline_rest_asyncio_call_success(request_type): }, "satisfies_pzs": True, "satisfies_pzi": True, + "checkpoints": [ + {"checkpoint_id": "checkpoint_id_value", "epoch": 527, "step": 444} + ], }, "model_id": "model_id_value", "parent_model": "parent_model_value", diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_extension_registry_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_extension_registry_service.py index 5123fc4e2d..0288684a06 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_extension_registry_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_extension_registry_service.py @@ -4577,6 +4577,8 @@ def test_import_extension_rest_call_success(request_type): "pattern": "pattern_value", "example": {}, "any_of": {}, + "ref": "ref_value", + "defs": {}, }, "response": {}, }, @@ -5153,6 +5155,8 @@ def test_update_extension_rest_call_success(request_type): "pattern": "pattern_value", "example": {}, "any_of": {}, + "ref": "ref_value", + "defs": {}, }, "response": {}, }, @@ -6359,6 +6363,8 @@ async def test_import_extension_rest_asyncio_call_success(request_type): "pattern": "pattern_value", "example": {}, "any_of": {}, + "ref": "ref_value", + "defs": {}, }, "response": {}, }, @@ -6983,6 +6989,8 @@ async def test_update_extension_rest_asyncio_call_success(request_type): "pattern": "pattern_value", "example": {}, "any_of": {}, + "ref": "ref_value", + "defs": {}, }, "response": {}, }, diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_gen_ai_cache_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_gen_ai_cache_service.py index 5ad43802b0..13ad36d108 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_gen_ai_cache_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_gen_ai_cache_service.py @@ -73,6 +73,7 @@ from google.cloud.aiplatform_v1beta1.types import cached_content from google.cloud.aiplatform_v1beta1.types import cached_content as gca_cached_content from google.cloud.aiplatform_v1beta1.types import content +from google.cloud.aiplatform_v1beta1.types import encryption_spec from google.cloud.aiplatform_v1beta1.types import gen_ai_cache_service from google.cloud.aiplatform_v1beta1.types import openapi from google.cloud.aiplatform_v1beta1.types import tool @@ -4674,6 +4675,8 @@ def test_create_cached_content_rest_call_success(request_type): "pattern": "pattern_value", "example": {}, "any_of": {}, + "ref": "ref_value", + "defs": {}, }, 
"response": {}, } @@ -4742,6 +4745,7 @@ def test_create_cached_content_rest_call_success(request_type): "video_duration_seconds": 2346, "audio_duration_seconds": 2341, }, + "encryption_spec": {"kms_key_name": "kms_key_name_value"}, } # The version of a generated dependency at test runtime may differ from the version used during generation. # Delete any fields which are not present in the current runtime dependency @@ -5164,6 +5168,8 @@ def test_update_cached_content_rest_call_success(request_type): "pattern": "pattern_value", "example": {}, "any_of": {}, + "ref": "ref_value", + "defs": {}, }, "response": {}, } @@ -5232,6 +5238,7 @@ def test_update_cached_content_rest_call_success(request_type): "video_duration_seconds": 2346, "audio_duration_seconds": 2341, }, + "encryption_spec": {"kms_key_name": "kms_key_name_value"}, } # The version of a generated dependency at test runtime may differ from the version used during generation. # Delete any fields which are not present in the current runtime dependency @@ -6516,6 +6523,8 @@ async def test_create_cached_content_rest_asyncio_call_success(request_type): "pattern": "pattern_value", "example": {}, "any_of": {}, + "ref": "ref_value", + "defs": {}, }, "response": {}, } @@ -6584,6 +6593,7 @@ async def test_create_cached_content_rest_asyncio_call_success(request_type): "video_duration_seconds": 2346, "audio_duration_seconds": 2341, }, + "encryption_spec": {"kms_key_name": "kms_key_name_value"}, } # The version of a generated dependency at test runtime may differ from the version used during generation. # Delete any fields which are not present in the current runtime dependency @@ -7038,6 +7048,8 @@ async def test_update_cached_content_rest_asyncio_call_success(request_type): "pattern": "pattern_value", "example": {}, "any_of": {}, + "ref": "ref_value", + "defs": {}, }, "response": {}, } @@ -7106,6 +7118,7 @@ async def test_update_cached_content_rest_asyncio_call_success(request_type): "video_duration_seconds": 2346, "audio_duration_seconds": 2341, }, + "encryption_spec": {"kms_key_name": "kms_key_name_value"}, } # The version of a generated dependency at test runtime may differ from the version used during generation. 
# Delete any fields which are not present in the current runtime dependency diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_session_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_session_service.py index 99c757d313..66289a380b 100644 --- a/tests/unit/gapic/aiplatform_v1beta1/test_session_service.py +++ b/tests/unit/gapic/aiplatform_v1beta1/test_session_service.py @@ -3980,7 +3980,9 @@ def test_create_session_rest_flattened(): return_value = operations_pb2.Operation(name="operations/spam") # get arguments that satisfy an http rule for this method - sample_request = {"parent": "projects/sample1/locations/sample2"} + sample_request = { + "parent": "projects/sample1/locations/sample2/reasoningEngines/sample3" + } # get truthy value for each flattened field mock_args = dict( @@ -4004,7 +4006,7 @@ def test_create_session_rest_flattened(): assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] assert path_template.validate( - "%s/v1beta1/{parent=projects/*/locations/*}/sessions" + "%s/v1beta1/{parent=projects/*/locations/*/reasoningEngines/*}/sessions" % client.transport._host, args[1], ) @@ -4158,7 +4160,9 @@ def test_get_session_rest_flattened(): return_value = session.Session() # get arguments that satisfy an http rule for this method - sample_request = {"name": "projects/sample1/locations/sample2/sessions/sample3"} + sample_request = { + "name": "projects/sample1/locations/sample2/reasoningEngines/sample3/sessions/sample4" + } # get truthy value for each flattened field mock_args = dict( @@ -4183,7 +4187,7 @@ def test_get_session_rest_flattened(): assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] assert path_template.validate( - "%s/v1beta1/{name=projects/*/locations/*/sessions/*}" + "%s/v1beta1/{name=projects/*/locations/*/reasoningEngines/*/sessions/*}" % client.transport._host, args[1], ) @@ -4355,7 +4359,9 @@ def test_list_sessions_rest_flattened(): return_value = session_service.ListSessionsResponse() # get arguments that satisfy an http rule for this method - sample_request = {"parent": "projects/sample1/locations/sample2"} + sample_request = { + "parent": "projects/sample1/locations/sample2/reasoningEngines/sample3" + } # get truthy value for each flattened field mock_args = dict( @@ -4380,7 +4386,7 @@ def test_list_sessions_rest_flattened(): assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] assert path_template.validate( - "%s/v1beta1/{parent=projects/*/locations/*}/sessions" + "%s/v1beta1/{parent=projects/*/locations/*/reasoningEngines/*}/sessions" % client.transport._host, args[1], ) @@ -4451,7 +4457,9 @@ def test_list_sessions_rest_pager(transport: str = "rest"): return_val.status_code = 200 req.side_effect = return_values - sample_request = {"parent": "projects/sample1/locations/sample2"} + sample_request = { + "parent": "projects/sample1/locations/sample2/reasoningEngines/sample3" + } pager = client.list_sessions(request=sample_request) @@ -4595,7 +4603,9 @@ def test_update_session_rest_flattened(): # get arguments that satisfy an http rule for this method sample_request = { - "session": {"name": "projects/sample1/locations/sample2/sessions/sample3"} + "session": { + "name": "projects/sample1/locations/sample2/reasoningEngines/sample3/sessions/sample4" + } } # get truthy value for each flattened field @@ -4622,7 +4632,7 @@ def test_update_session_rest_flattened(): assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] assert path_template.validate( - "%s/v1beta1/{session.name=projects/*/locations/*/sessions/*}" + 
"%s/v1beta1/{session.name=projects/*/locations/*/reasoningEngines/*/sessions/*}" % client.transport._host, args[1], ) @@ -4777,7 +4787,9 @@ def test_delete_session_rest_flattened(): return_value = operations_pb2.Operation(name="operations/spam") # get arguments that satisfy an http rule for this method - sample_request = {"name": "projects/sample1/locations/sample2/sessions/sample3"} + sample_request = { + "name": "projects/sample1/locations/sample2/reasoningEngines/sample3/sessions/sample4" + } # get truthy value for each flattened field mock_args = dict( @@ -4800,7 +4812,7 @@ def test_delete_session_rest_flattened(): assert len(req.mock_calls) == 1 _, args, _ = req.mock_calls[0] assert path_template.validate( - "%s/v1beta1/{name=projects/*/locations/*/sessions/*}" + "%s/v1beta1/{name=projects/*/locations/*/reasoningEngines/*/sessions/*}" % client.transport._host, args[1], ) @@ -5740,7 +5752,9 @@ def test_create_session_rest_bad_request( credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) # send a request that will satisfy transcoding - request_init = {"parent": "projects/sample1/locations/sample2"} + request_init = { + "parent": "projects/sample1/locations/sample2/reasoningEngines/sample3" + } request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. @@ -5771,7 +5785,9 @@ def test_create_session_rest_call_success(request_type): ) # send a request that will satisfy transcoding - request_init = {"parent": "projects/sample1/locations/sample2"} + request_init = { + "parent": "projects/sample1/locations/sample2/reasoningEngines/sample3" + } request_init["session"] = { "name": "name_value", "create_time": {"seconds": 751, "nanos": 543}, @@ -5936,7 +5952,9 @@ def test_get_session_rest_bad_request(request_type=session_service.GetSessionReq credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) # send a request that will satisfy transcoding - request_init = {"name": "projects/sample1/locations/sample2/sessions/sample3"} + request_init = { + "name": "projects/sample1/locations/sample2/reasoningEngines/sample3/sessions/sample4" + } request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. @@ -5967,7 +5985,9 @@ def test_get_session_rest_call_success(request_type): ) # send a request that will satisfy transcoding - request_init = {"name": "projects/sample1/locations/sample2/sessions/sample3"} + request_init = { + "name": "projects/sample1/locations/sample2/reasoningEngines/sample3/sessions/sample4" + } request = request_type(**request_init) # Mock the http request call within the method and fake a response. @@ -6067,7 +6087,9 @@ def test_list_sessions_rest_bad_request( credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) # send a request that will satisfy transcoding - request_init = {"parent": "projects/sample1/locations/sample2"} + request_init = { + "parent": "projects/sample1/locations/sample2/reasoningEngines/sample3" + } request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. @@ -6098,7 +6120,9 @@ def test_list_sessions_rest_call_success(request_type): ) # send a request that will satisfy transcoding - request_init = {"parent": "projects/sample1/locations/sample2"} + request_init = { + "parent": "projects/sample1/locations/sample2/reasoningEngines/sample3" + } request = request_type(**request_init) # Mock the http request call within the method and fake a response. 
@@ -6200,7 +6224,9 @@ def test_update_session_rest_bad_request( ) # send a request that will satisfy transcoding request_init = { - "session": {"name": "projects/sample1/locations/sample2/sessions/sample3"} + "session": { + "name": "projects/sample1/locations/sample2/reasoningEngines/sample3/sessions/sample4" + } } request = request_type(**request_init) @@ -6233,10 +6259,12 @@ def test_update_session_rest_call_success(request_type): # send a request that will satisfy transcoding request_init = { - "session": {"name": "projects/sample1/locations/sample2/sessions/sample3"} + "session": { + "name": "projects/sample1/locations/sample2/reasoningEngines/sample3/sessions/sample4" + } } request_init["session"] = { - "name": "projects/sample1/locations/sample2/sessions/sample3", + "name": "projects/sample1/locations/sample2/reasoningEngines/sample3/sessions/sample4", "create_time": {"seconds": 751, "nanos": 543}, "update_time": {}, "display_name": "display_name_value", @@ -6409,7 +6437,9 @@ def test_delete_session_rest_bad_request( credentials=ga_credentials.AnonymousCredentials(), transport="rest" ) # send a request that will satisfy transcoding - request_init = {"name": "projects/sample1/locations/sample2/sessions/sample3"} + request_init = { + "name": "projects/sample1/locations/sample2/reasoningEngines/sample3/sessions/sample4" + } request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. @@ -6440,7 +6470,9 @@ def test_delete_session_rest_call_success(request_type): ) # send a request that will satisfy transcoding - request_init = {"name": "projects/sample1/locations/sample2/sessions/sample3"} + request_init = { + "name": "projects/sample1/locations/sample2/reasoningEngines/sample3/sessions/sample4" + } request = request_type(**request_init) # Mock the http request call within the method and fake a response. @@ -6741,6 +6773,7 @@ def test_append_event_rest_call_success(request_type): "transfer_to_agent": True, "escalate": True, "requested_auth_configs": {}, + "transfer_agent": "transfer_agent_value", }, "timestamp": {"seconds": 751, "nanos": 543}, "error_code": "error_code_value", @@ -7766,7 +7799,9 @@ async def test_create_session_rest_asyncio_bad_request( credentials=async_anonymous_credentials(), transport="rest_asyncio" ) # send a request that will satisfy transcoding - request_init = {"parent": "projects/sample1/locations/sample2"} + request_init = { + "parent": "projects/sample1/locations/sample2/reasoningEngines/sample3" + } request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. @@ -7801,7 +7836,9 @@ async def test_create_session_rest_asyncio_call_success(request_type): ) # send a request that will satisfy transcoding - request_init = {"parent": "projects/sample1/locations/sample2"} + request_init = { + "parent": "projects/sample1/locations/sample2/reasoningEngines/sample3" + } request_init["session"] = { "name": "name_value", "create_time": {"seconds": 751, "nanos": 543}, @@ -7981,7 +8018,9 @@ async def test_get_session_rest_asyncio_bad_request( credentials=async_anonymous_credentials(), transport="rest_asyncio" ) # send a request that will satisfy transcoding - request_init = {"name": "projects/sample1/locations/sample2/sessions/sample3"} + request_init = { + "name": "projects/sample1/locations/sample2/reasoningEngines/sample3/sessions/sample4" + } request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. 
@@ -8016,7 +8055,9 @@ async def test_get_session_rest_asyncio_call_success(request_type): ) # send a request that will satisfy transcoding - request_init = {"name": "projects/sample1/locations/sample2/sessions/sample3"} + request_init = { + "name": "projects/sample1/locations/sample2/reasoningEngines/sample3/sessions/sample4" + } request = request_type(**request_init) # Mock the http request call within the method and fake a response. @@ -8128,7 +8169,9 @@ async def test_list_sessions_rest_asyncio_bad_request( credentials=async_anonymous_credentials(), transport="rest_asyncio" ) # send a request that will satisfy transcoding - request_init = {"parent": "projects/sample1/locations/sample2"} + request_init = { + "parent": "projects/sample1/locations/sample2/reasoningEngines/sample3" + } request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. @@ -8163,7 +8206,9 @@ async def test_list_sessions_rest_asyncio_call_success(request_type): ) # send a request that will satisfy transcoding - request_init = {"parent": "projects/sample1/locations/sample2"} + request_init = { + "parent": "projects/sample1/locations/sample2/reasoningEngines/sample3" + } request = request_type(**request_init) # Mock the http request call within the method and fake a response. @@ -8278,7 +8323,9 @@ async def test_update_session_rest_asyncio_bad_request( ) # send a request that will satisfy transcoding request_init = { - "session": {"name": "projects/sample1/locations/sample2/sessions/sample3"} + "session": { + "name": "projects/sample1/locations/sample2/reasoningEngines/sample3/sessions/sample4" + } } request = request_type(**request_init) @@ -8315,10 +8362,12 @@ async def test_update_session_rest_asyncio_call_success(request_type): # send a request that will satisfy transcoding request_init = { - "session": {"name": "projects/sample1/locations/sample2/sessions/sample3"} + "session": { + "name": "projects/sample1/locations/sample2/reasoningEngines/sample3/sessions/sample4" + } } request_init["session"] = { - "name": "projects/sample1/locations/sample2/sessions/sample3", + "name": "projects/sample1/locations/sample2/reasoningEngines/sample3/sessions/sample4", "create_time": {"seconds": 751, "nanos": 543}, "update_time": {}, "display_name": "display_name_value", @@ -8504,7 +8553,9 @@ async def test_delete_session_rest_asyncio_bad_request( credentials=async_anonymous_credentials(), transport="rest_asyncio" ) # send a request that will satisfy transcoding - request_init = {"name": "projects/sample1/locations/sample2/sessions/sample3"} + request_init = { + "name": "projects/sample1/locations/sample2/reasoningEngines/sample3/sessions/sample4" + } request = request_type(**request_init) # Mock the http request call within the method and fake a BadRequest error. @@ -8539,7 +8590,9 @@ async def test_delete_session_rest_asyncio_call_success(request_type): ) # send a request that will satisfy transcoding - request_init = {"name": "projects/sample1/locations/sample2/sessions/sample3"} + request_init = { + "name": "projects/sample1/locations/sample2/reasoningEngines/sample3/sessions/sample4" + } request = request_type(**request_init) # Mock the http request call within the method and fake a response. 
@@ -8877,6 +8930,7 @@ async def test_append_event_rest_asyncio_call_success(request_type): "transfer_to_agent": True, "escalate": True, "requested_auth_configs": {}, + "transfer_agent": "transfer_agent_value", }, "timestamp": {"seconds": 751, "nanos": 543}, "error_code": "error_code_value", @@ -10554,21 +10608,26 @@ def test_session_service_grpc_lro_async_client(): def test_session_path(): project = "squid" location = "clam" - session = "whelk" - expected = "projects/{project}/locations/{location}/sessions/{session}".format( + reasoning_engine = "whelk" + session = "octopus" + expected = "projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}/sessions/{session}".format( project=project, location=location, + reasoning_engine=reasoning_engine, session=session, ) - actual = SessionServiceClient.session_path(project, location, session) + actual = SessionServiceClient.session_path( + project, location, reasoning_engine, session + ) assert expected == actual def test_parse_session_path(): expected = { - "project": "octopus", - "location": "oyster", - "session": "nudibranch", + "project": "oyster", + "location": "nudibranch", + "reasoning_engine": "cuttlefish", + "session": "mussel", } path = SessionServiceClient.session_path(**expected) @@ -10578,26 +10637,31 @@ def test_parse_session_path(): def test_session_event_path(): - project = "cuttlefish" - location = "mussel" - session = "winkle" - event = "nautilus" - expected = "projects/{project}/locations/{location}/sessions/{session}/events/{event}".format( + project = "winkle" + location = "nautilus" + reasoning_engine = "scallop" + session = "abalone" + event = "squid" + expected = "projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}/sessions/{session}/events/{event}".format( project=project, location=location, + reasoning_engine=reasoning_engine, session=session, event=event, ) - actual = SessionServiceClient.session_event_path(project, location, session, event) + actual = SessionServiceClient.session_event_path( + project, location, reasoning_engine, session, event + ) assert expected == actual def test_parse_session_event_path(): expected = { - "project": "scallop", - "location": "abalone", - "session": "squid", - "event": "clam", + "project": "clam", + "location": "whelk", + "reasoning_engine": "octopus", + "session": "oyster", + "event": "nudibranch", } path = SessionServiceClient.session_event_path(**expected) @@ -10607,7 +10671,7 @@ def test_parse_session_event_path(): def test_common_billing_account_path(): - billing_account = "whelk" + billing_account = "cuttlefish" expected = "billingAccounts/{billing_account}".format( billing_account=billing_account, ) @@ -10617,7 +10681,7 @@ def test_common_billing_account_path(): def test_parse_common_billing_account_path(): expected = { - "billing_account": "octopus", + "billing_account": "mussel", } path = SessionServiceClient.common_billing_account_path(**expected) @@ -10627,7 +10691,7 @@ def test_parse_common_billing_account_path(): def test_common_folder_path(): - folder = "oyster" + folder = "winkle" expected = "folders/{folder}".format( folder=folder, ) @@ -10637,7 +10701,7 @@ def test_common_folder_path(): def test_parse_common_folder_path(): expected = { - "folder": "nudibranch", + "folder": "nautilus", } path = SessionServiceClient.common_folder_path(**expected) @@ -10647,7 +10711,7 @@ def test_parse_common_folder_path(): def test_common_organization_path(): - organization = "cuttlefish" + organization = "scallop" expected = 
"organizations/{organization}".format( organization=organization, ) @@ -10657,7 +10721,7 @@ def test_common_organization_path(): def test_parse_common_organization_path(): expected = { - "organization": "mussel", + "organization": "abalone", } path = SessionServiceClient.common_organization_path(**expected) @@ -10667,7 +10731,7 @@ def test_parse_common_organization_path(): def test_common_project_path(): - project = "winkle" + project = "squid" expected = "projects/{project}".format( project=project, ) @@ -10677,7 +10741,7 @@ def test_common_project_path(): def test_parse_common_project_path(): expected = { - "project": "nautilus", + "project": "clam", } path = SessionServiceClient.common_project_path(**expected) @@ -10687,8 +10751,8 @@ def test_parse_common_project_path(): def test_common_location_path(): - project = "scallop" - location = "abalone" + project = "whelk" + location = "octopus" expected = "projects/{project}/locations/{location}".format( project=project, location=location, @@ -10699,8 +10763,8 @@ def test_common_location_path(): def test_parse_common_location_path(): expected = { - "project": "squid", - "location": "clam", + "project": "oyster", + "location": "nudibranch", } path = SessionServiceClient.common_location_path(**expected) diff --git a/tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py b/tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py index e576cc9120..46452de32e 100644 --- a/tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py +++ b/tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py @@ -17,6 +17,7 @@ from unittest import mock from google import auth +from google.genai import types import vertexai from google.cloud.aiplatform import initializer from vertexai.preview import reasoning_engines @@ -172,6 +173,28 @@ def test_stream_query(self): ) assert len(events) == 1 + def test_stream_query_with_content(self): + app = reasoning_engines.AdkApp( + agent=Agent(name="test_agent", model=_TEST_MODEL) + ) + assert app._tmpl_attrs.get("runner") is None + app.set_up() + app._tmpl_attrs["runner"] = _MockRunner() + events = list( + app.stream_query( + user_id="test_user_id", + message=types.Content( + role="user", + parts=[ + types.Part( + text="test message with content", + ) + ], + ).model_dump(), + ) + ) + assert len(events) == 1 + def test_streaming_agent_run_with_events(self): app = reasoning_engines.AdkApp( agent=Agent(name="test_agent", model=_TEST_MODEL) diff --git a/tests/unit/vertex_langchain/test_agent_engine_templates_module.py b/tests/unit/vertex_langchain/test_agent_engine_templates_module.py new file mode 100644 index 0000000000..20b9a0b61f --- /dev/null +++ b/tests/unit/vertex_langchain/test_agent_engine_templates_module.py @@ -0,0 +1,89 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from vertexai import agent_engines +from test_constants import test_agent + +_TEST_MODULE_NAME = "test_constants" +_TEST_AGENT_NAME = "test_agent" +_TEST_REGISTER_OPERATIONS = {"": ["query"], "stream": ["stream_query"]} +_TEST_QUERY_INPUT = "test query" +_TEST_STREAM_QUERY_INPUT = 5 + + +class TestModuleAgent: + def test_initialization(self): + agent = agent_engines.ModuleAgent( + module_name=_TEST_MODULE_NAME, + agent_name=_TEST_AGENT_NAME, + register_operations=_TEST_REGISTER_OPERATIONS, + ) + assert agent._tmpl_attrs.get("module_name") == _TEST_MODULE_NAME + assert agent._tmpl_attrs.get("agent_name") == _TEST_AGENT_NAME + assert agent._tmpl_attrs.get("register_operations") == _TEST_REGISTER_OPERATIONS + + def test_set_up(self): + agent = agent_engines.ModuleAgent( + module_name=_TEST_MODULE_NAME, + agent_name=_TEST_AGENT_NAME, + register_operations=_TEST_REGISTER_OPERATIONS, + ) + assert agent._tmpl_attrs.get("agent") is None + agent.set_up() + assert agent._tmpl_attrs.get("agent") is not None + + def test_clone(self): + agent = agent_engines.ModuleAgent( + module_name=_TEST_MODULE_NAME, + agent_name=_TEST_AGENT_NAME, + register_operations=_TEST_REGISTER_OPERATIONS, + ) + agent.set_up() + assert agent._tmpl_attrs.get("agent") is not None + agent_clone = agent.clone() + assert agent._tmpl_attrs.get("agent") is not None + assert agent_clone._tmpl_attrs.get("agent") is None + agent_clone.set_up() + assert agent_clone._tmpl_attrs.get("agent") is not None + + def test_query(self): + agent = agent_engines.ModuleAgent( + module_name=_TEST_MODULE_NAME, + agent_name=_TEST_AGENT_NAME, + register_operations=_TEST_REGISTER_OPERATIONS, + ) + agent.set_up() + got_result = agent.query(input=_TEST_QUERY_INPUT) + expected_result = agent._tmpl_attrs.get("agent").query(input=_TEST_QUERY_INPUT) + assert got_result == expected_result + expected_result = test_agent.query(input=_TEST_QUERY_INPUT) + assert got_result == expected_result + + def test_stream_query(self): + agent = agent_engines.ModuleAgent( + module_name=_TEST_MODULE_NAME, + agent_name=_TEST_AGENT_NAME, + register_operations=_TEST_REGISTER_OPERATIONS, + ) + agent.set_up() + for got_result, expected_result in zip( + agent.stream_query(n=_TEST_STREAM_QUERY_INPUT), + agent._tmpl_attrs.get("agent").stream_query(n=_TEST_STREAM_QUERY_INPUT), + ): + assert got_result == expected_result + for got_result, expected_result in zip( + agent.stream_query(n=_TEST_STREAM_QUERY_INPUT), + test_agent.stream_query(n=_TEST_STREAM_QUERY_INPUT), + ): + assert got_result == expected_result diff --git a/tests/unit/vertex_langchain/test_agent_engines.py b/tests/unit/vertex_langchain/test_agent_engines.py index 8e66529f57..e58d878635 100644 --- a/tests/unit/vertex_langchain/test_agent_engines.py +++ b/tests/unit/vertex_langchain/test_agent_engines.py @@ -2315,6 +2315,47 @@ def test_invalid_operation_schema( assert want_log_output in caplog.text +@pytest.mark.usefixtures("google_auth_mock") +class TestLightweightAgentEngine: + def setup_method(self): + importlib.reload(initializer) + importlib.reload(aiplatform) + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + credentials=_TEST_CREDENTIALS, + ) + self.test_agent = CapitalizeEngine() + + def test_create_agent_engine_with_no_spec( + self, + create_agent_engine_mock, + cloud_storage_create_bucket_mock, + tarfile_open_mock, + cloudpickle_dump_mock, + cloudpickle_load_mock, + importlib_metadata_version_mock, + get_agent_engine_mock, + ): + importlib.reload(initializer) + importlib.reload(aiplatform) + 
aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + credentials=_TEST_CREDENTIALS, + ) + agent_engines.create( + display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME, + description=_TEST_AGENT_ENGINE_DESCRIPTION, + ) + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + credentials=_TEST_CREDENTIALS, + staging_bucket=_TEST_STAGING_BUCKET, + ) + + def _generate_agent_engine_to_update() -> "agent_engines.AgentEngine": test_agent_engine = agent_engines.create(CapitalizeEngine()) # Resource name is required for the update method. diff --git a/tests/unit/vertex_langchain/test_constants.py b/tests/unit/vertex_langchain/test_constants.py new file mode 100644 index 0000000000..7d66f2dec2 --- /dev/null +++ b/tests/unit/vertex_langchain/test_constants.py @@ -0,0 +1,24 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +class _CustomAgent: + def query(self, input: str): + return input + + def stream_query(self, n: int): + for i in range(n): + yield i + + +test_agent = _CustomAgent() diff --git a/tests/unit/vertexai/genai/test_genai_client.py b/tests/unit/vertexai/genai/test_genai_client.py new file mode 100644 index 0000000000..7732bb3496 --- /dev/null +++ b/tests/unit/vertexai/genai/test_genai_client.py @@ -0,0 +1,74 @@ +# -*- coding: utf-8 -*- + +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# pylint: disable=protected-access,bad-continuation +import pytest +import importlib +from unittest import mock +from google.cloud import aiplatform +import vertexai +from google.cloud.aiplatform import initializer as aiplatform_initializer + +from vertexai import _genai + +_TEST_PROJECT = "test-project" +_TEST_LOCATION = "us-central1" + + +pytestmark = pytest.mark.usefixtures("google_auth_mock") + + +class TestGenAiClient: + """Unit tests for the GenAI client.""" + + def setup_method(self): + importlib.reload(aiplatform_initializer) + importlib.reload(aiplatform) + importlib.reload(vertexai) + vertexai.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + + @pytest.mark.usefixtures("google_auth_mock") + def test_genai_client(self): + test_client = _genai.client.Client( + project=_TEST_PROJECT, location=_TEST_LOCATION + ) + assert test_client is not None + assert test_client._api_client.vertexai + assert test_client._api_client.project == _TEST_PROJECT + assert test_client._api_client.location == _TEST_LOCATION + + @pytest.mark.usefixtures("google_auth_mock") + def test_evaluate_instances(self): + test_client = _genai.client.Client( + project=_TEST_PROJECT, location=_TEST_LOCATION + ) + with mock.patch.object( + test_client.evals, "_evaluate_instances" + ) as mock_evaluate: + test_client.evals._evaluate_instances(bleu_input=_genai.types.BleuInput()) + mock_evaluate.assert_called_once_with(bleu_input=_genai.types.BleuInput()) + + @pytest.mark.usefixtures("google_auth_mock") + def test_eval_run(self): + test_client = _genai.client.Client( + project=_TEST_PROJECT, location=_TEST_LOCATION + ) + with pytest.raises(NotImplementedError): + test_client.evals.run() diff --git a/tests/unit/vertexai/model_garden/test_model_garden.py b/tests/unit/vertexai/model_garden/test_model_garden.py index 23fefcfae2..beec2521cf 100644 --- a/tests/unit/vertexai/model_garden/test_model_garden.py +++ b/tests/unit/vertexai/model_garden/test_model_garden.py @@ -14,6 +14,7 @@ """Unit tests for ModelGarden class.""" import importlib +import textwrap from unittest import mock from google import auth @@ -41,6 +42,7 @@ _TEST_PROJECT = "test-project" _TEST_LOCATION = "us-central1" +_TEST_PROJECT_NUMBER = "1234567890" _TEST_MODEL_FULL_RESOURCE_NAME = ( "publishers/google/models/paligemma@paligemma-224-float32" @@ -179,14 +181,39 @@ def get_publisher_model_mock(): multi_deploy_vertex=types.PublisherModel.CallToAction.DeployVertex( multi_deploy_vertex=[ types.PublisherModel.CallToAction.Deploy( + container_spec=types.ModelContainerSpec( + image_uri="us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/pytorch-vllm-serve:20241202_0916_RC00", + command=["python", "main.py"], + args=["--model-id=gemma-2b"], + env=[ + types.EnvVar(name="MODEL_ID", value="gemma-2b") + ], + ), dedicated_resources=types.DedicatedResources( machine_spec=types.MachineSpec( machine_type="g2-standard-16", accelerator_type="NVIDIA_L4", accelerator_count=1, ) - ) - ) + ), + ), + types.PublisherModel.CallToAction.Deploy( + container_spec=types.ModelContainerSpec( + image_uri="us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/text-generation-inference-cu121.2-1.py310:latest", + command=["python", "main.py"], + args=["--model-id=gemma-2b"], + env=[ + types.EnvVar(name="MODEL_ID", value="gemma-2b") + ], + ), + dedicated_resources=types.DedicatedResources( + machine_spec=types.MachineSpec( + machine_type="g2-standard-32", + accelerator_type="NVIDIA_L4", + accelerator_count=4, + ) + ), + ), ] ) ), @@ -197,14 +224,39 @@ def 
get_publisher_model_mock(): multi_deploy_vertex=types.PublisherModel.CallToAction.DeployVertex( multi_deploy_vertex=[ types.PublisherModel.CallToAction.Deploy( + container_spec=types.ModelContainerSpec( + image_uri="us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/pytorch-vllm-serve:20241202_0916_RC00", + command=["python", "main.py"], + args=["--model-id=gemma-2b"], + env=[ + types.EnvVar(name="MODEL_ID", value="gemma-2b") + ], + ), dedicated_resources=types.DedicatedResources( machine_spec=types.MachineSpec( machine_type="g2-standard-16", accelerator_type="NVIDIA_L4", accelerator_count=1, ) - ) - ) + ), + ), + types.PublisherModel.CallToAction.Deploy( + container_spec=types.ModelContainerSpec( + image_uri="us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/text-generation-inference-cu121.2-1.py310:latest", + command=["python", "main.py"], + args=["--model-id=gemma-2b"], + env=[ + types.EnvVar(name="MODEL_ID", value="gemma-2b") + ], + ), + dedicated_resources=types.DedicatedResources( + machine_spec=types.MachineSpec( + machine_type="g2-standard-32", + accelerator_type="NVIDIA_L4", + accelerator_count=4, + ) + ), + ), ] ) ), @@ -398,6 +450,40 @@ def list_publisher_models_mock(): yield list_publisher_models +@pytest.fixture +def check_license_agreement_status_mock(): + """Mocks the check_license_agreement_status method.""" + with mock.patch.object( + model_garden_service.ModelGardenServiceClient, + "check_publisher_model_eula_acceptance", + ) as check_license_agreement_status: + check_license_agreement_status.return_value = ( + types.PublisherModelEulaAcceptance( + project_number=_TEST_PROJECT_NUMBER, + publisher_model=_TEST_MODEL_FULL_RESOURCE_NAME, + publisher_model_eula_acked=True, + ) + ) + yield check_license_agreement_status + + +@pytest.fixture +def accept_model_license_agreement_mock(): + """Mocks the accept_model_license_agreement method.""" + with mock.patch.object( + model_garden_service.ModelGardenServiceClient, + "accept_publisher_model_eula", + ) as accept_model_license_agreement: + accept_model_license_agreement.return_value = ( + types.PublisherModelEulaAcceptance( + project_number=_TEST_PROJECT_NUMBER, + publisher_model=_TEST_MODEL_FULL_RESOURCE_NAME, + publisher_model_eula_acked=True, + ) + ) + yield accept_model_license_agreement + + @pytest.mark.usefixtures( "google_auth_mock", "deploy_mock", @@ -406,9 +492,11 @@ def list_publisher_models_mock(): "export_publisher_model_mock", "batch_prediction_mock", "complete_bq_uri_mock", + "check_license_agreement_status_mock", + "accept_model_license_agreement_mock", ) -class TestModelGarden: - """Test cases for ModelGarden class.""" +class TestModelGardenOpenModel: + """Test cases for Model Garden OpenModel class.""" def setup_method(self): importlib.reload(aiplatform.initializer) @@ -709,6 +797,24 @@ def test_deploy_with_dedicated_endpoint_success(self, deploy_mock): ) ) + def test_deploy_with_system_labels_success(self, deploy_mock): + """Tests deploying a model with system labels.""" + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + model = model_garden.OpenModel(model_name=_TEST_MODEL_FULL_RESOURCE_NAME) + model.deploy(system_labels={"test-key": "test-value"}) + deploy_mock.assert_called_once_with( + types.DeployRequest( + publisher_model_name=_TEST_MODEL_FULL_RESOURCE_NAME, + destination=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}", + deploy_config=types.DeployRequest.DeployConfig( + system_labels={"test-key": "test-value"} + ), + ) + ) + def 
test_deploy_with_fast_tryout_enabled_success(self, deploy_mock): """Tests deploying a model with fast tryout enabled.""" aiplatform.init( @@ -880,9 +986,8 @@ def test_list_deploy_options(self, get_publisher_model_mock): ) expected_message = ( - "Model does not support deployment, please use a deploy-able model" - " instead. You can use the list_deployable_models() method to find out" - " which ones currently support deployment." + "Model does not support deployment. " + "Use `list_deployable_models()` to find supported models." ) with pytest.raises(ValueError) as exception: model = model_garden.OpenModel(model_name=_TEST_MODEL_FULL_RESOURCE_NAME) @@ -908,6 +1013,71 @@ def test_list_deploy_options(self, get_publisher_model_mock): ) ) + def test_list_deploy_options_concise(self, get_publisher_model_mock): + """Tests getting the supported deploy options for a model.""" + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + + expected_message = ( + "Model does not support deployment. " + "Use `list_deployable_models()` to find supported models." + ) + with pytest.raises(ValueError) as exception: + model = model_garden.OpenModel(model_name=_TEST_MODEL_FULL_RESOURCE_NAME) + _ = model.list_deploy_options(concise=True) + assert str(exception.value) == expected_message + + result = model.list_deploy_options(concise=True) + expected_result = textwrap.dedent( + """\ + [Option 1] + serving_container_image_uri="us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/pytorch-vllm-serve:20241202_0916_RC00", + machine_type="g2-standard-16", + accelerator_type="NVIDIA_L4", + accelerator_count=1, + + [Option 2] + serving_container_image_uri="us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/text-generation-inference-cu121.2-1.py310:latest", + machine_type="g2-standard-32", + accelerator_type="NVIDIA_L4", + accelerator_count=4,""" + ) + assert result == expected_result + get_publisher_model_mock.assert_called_with( + types.GetPublisherModelRequest( + name=_TEST_MODEL_FULL_RESOURCE_NAME, + is_hugging_face_model=False, + include_equivalent_model_garden_model_deployment_configs=True, + ) + ) + + hf_model = model_garden.OpenModel(_TEST_MODEL_HUGGING_FACE_ID) + hf_result = hf_model.list_deploy_options(concise=True) + expected_hf_result = textwrap.dedent( + """\ + [Option 1] + serving_container_image_uri="us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/pytorch-vllm-serve:20241202_0916_RC00", + machine_type="g2-standard-16", + accelerator_type="NVIDIA_L4", + accelerator_count=1, + + [Option 2] + serving_container_image_uri="us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/text-generation-inference-cu121.2-1.py310:latest", + machine_type="g2-standard-32", + accelerator_type="NVIDIA_L4", + accelerator_count=4,""" + ) + assert hf_result == expected_hf_result + get_publisher_model_mock.assert_called_with( + types.GetPublisherModelRequest( + name=_TEST_HUGGING_FACE_MODEL_FULL_RESOURCE_NAME, + is_hugging_face_model=True, + include_equivalent_model_garden_model_deployment_configs=True, + ) + ) + def test_list_deployable_models(self, list_publisher_models_mock): """Tests getting the supported deploy options for a model.""" aiplatform.init( @@ -999,3 +1169,43 @@ def test_batch_prediction_success(self, batch_prediction_mock): batch_prediction_job=expected_gapic_batch_prediction_job, timeout=None, ) + + def test_check_license_agreement_status_success( + self, check_license_agreement_status_mock + ): + """Tests checking EULA acceptance for a model.""" + 
aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + model = model_garden.OpenModel(model_name=_TEST_MODEL_FULL_RESOURCE_NAME) + eula_acceptance = model.check_license_agreement_status() + check_license_agreement_status_mock.assert_called_once_with( + types.CheckPublisherModelEulaAcceptanceRequest( + parent=f"projects/{_TEST_PROJECT}", + publisher_model=_TEST_MODEL_FULL_RESOURCE_NAME, + ) + ) + assert eula_acceptance + + def test_accept_model_license_agreement_success( + self, accept_model_license_agreement_mock + ): + """Tests accepting EULA for a model.""" + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + model = model_garden.OpenModel(model_name=_TEST_MODEL_FULL_RESOURCE_NAME) + eula_acceptance = model.accept_model_license_agreement() + accept_model_license_agreement_mock.assert_called_once_with( + types.AcceptPublisherModelEulaRequest( + parent=f"projects/{_TEST_PROJECT}", + publisher_model=_TEST_MODEL_FULL_RESOURCE_NAME, + ) + ) + assert eula_acceptance == types.PublisherModelEulaAcceptance( + project_number=_TEST_PROJECT_NUMBER, + publisher_model=_TEST_MODEL_FULL_RESOURCE_NAME, + publisher_model_eula_acked=True, + ) diff --git a/vertexai/_genai/__init__.py b/vertexai/_genai/__init__.py new file mode 100644 index 0000000000..f1d1598d81 --- /dev/null +++ b/vertexai/_genai/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""The vertexai module.""" + +from . import evals +from .client import Client + +__all__ = [ + "Client", + "evals", +] diff --git a/vertexai/_genai/client.py b/vertexai/_genai/client.py new file mode 100644 index 0000000000..35e36a6064 --- /dev/null +++ b/vertexai/_genai/client.py @@ -0,0 +1,92 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Optional, Union + +import google.auth +from google.genai import client +from google.genai import types +from .evals import Evals +from .evals import AsyncEvals + + +class AsyncClient: + """Async Client for the GenAI SDK.""" + + def __init__(self, api_client: client.Client): + self._api_client = api_client + self._aio = AsyncClient(self._api_client) + self._evals = AsyncEvals(self._api_client) + + @property + def evals(self) -> AsyncEvals: + return self._evals + + +class Client: + """Client for the GenAI SDK. + + Use this client to interact with Vertex-specific Gemini features. 
+ """ + + def __init__( + self, + *, + credentials: Optional[google.auth.credentials.Credentials] = None, + project: Optional[str] = None, + location: Optional[str] = None, + debug_config: Optional[client.DebugConfig] = None, + http_options: Optional[Union[types.HttpOptions, types.HttpOptionsDict]] = None, + ): + """Initializes the client. + + Args: + credentials (google.auth.credentials.Credentials): The credentials to use + for authentication when calling the Vertex AI APIs. Credentials can be + obtained from environment variables and default credentials. For more + information, see `Set up Application Default Credentials + `_. + project (str): The `Google Cloud project ID + `_ to + use for quota. Can be obtained from environment variables (for example, + ``GOOGLE_CLOUD_PROJECT``). + location (str): The `location + `_ + to send API requests to (for example, ``us-central1``). Can be obtained + from environment variables. + debug_config (DebugConfig): Config settings that control network behavior + of the client. This is typically used when running test code. + http_options (Union[HttpOptions, HttpOptionsDict]): Http options to use + for the client. + """ + + self._debug_config = debug_config or client.DebugConfig() + if isinstance(http_options, dict): + http_options = types.HttpOptions(**http_options) + + self._api_client = client.Client._get_api_client( + vertexai=True, + credentials=credentials, + project=project, + location=location, + debug_config=self._debug_config, + http_options=http_options, + ) + + self._evals = Evals(self._api_client) + + @property + def evals(self) -> Evals: + return self._evals diff --git a/vertexai/_genai/evals.py b/vertexai/_genai/evals.py new file mode 100644 index 0000000000..0aa2d16bea --- /dev/null +++ b/vertexai/_genai/evals.py @@ -0,0 +1,894 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Code generated by the Google Gen AI SDK generator DO NOT EDIT. + +import logging +from typing import Any, Optional, Union +from urllib.parse import urlencode + +from google.genai import _api_module +from google.genai import _common +from google.genai import types as genai_types +from google.genai._api_client import BaseApiClient +from google.genai._common import get_value_by_path as getv +from google.genai._common import set_value_by_path as setv + +from . 
import types + +logger = logging.getLogger("google_genai.evals") + + +def _BleuInstance_to_vertex( + api_client: BaseApiClient, + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["prediction"]) is not None: + setv(to_object, ["prediction"], getv(from_object, ["prediction"])) + + if getv(from_object, ["reference"]) is not None: + setv(to_object, ["reference"], getv(from_object, ["reference"])) + + return to_object + + +def _BleuSpec_to_vertex( + api_client: BaseApiClient, + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["use_effective_order"]) is not None: + setv( + to_object, + ["useEffectiveOrder"], + getv(from_object, ["use_effective_order"]), + ) + + return to_object + + +def _BleuInput_to_vertex( + api_client: BaseApiClient, + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["instances"]) is not None: + setv( + to_object, + ["instances"], + [ + _BleuInstance_to_vertex(api_client, item, to_object) + for item in getv(from_object, ["instances"]) + ], + ) + + if getv(from_object, ["metric_spec"]) is not None: + setv( + to_object, + ["metricSpec"], + _BleuSpec_to_vertex( + api_client, getv(from_object, ["metric_spec"]), to_object + ), + ) + + return to_object + + +def _ExactMatchInstance_to_vertex( + api_client: BaseApiClient, + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["prediction"]) is not None: + setv(to_object, ["prediction"], getv(from_object, ["prediction"])) + + if getv(from_object, ["reference"]) is not None: + setv(to_object, ["reference"], getv(from_object, ["reference"])) + + return to_object + + +def _ExactMatchSpec_to_vertex( + api_client: BaseApiClient, + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + return to_object + + +def _ExactMatchInput_to_vertex( + api_client: BaseApiClient, + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["instances"]) is not None: + setv( + to_object, + ["instances"], + [ + _ExactMatchInstance_to_vertex(api_client, item, to_object) + for item in getv(from_object, ["instances"]) + ], + ) + + if getv(from_object, ["metric_spec"]) is not None: + setv( + to_object, + ["metricSpec"], + _ExactMatchSpec_to_vertex( + api_client, getv(from_object, ["metric_spec"]), to_object + ), + ) + + return to_object + + +def _RougeInstance_to_vertex( + api_client: BaseApiClient, + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["prediction"]) is not None: + setv(to_object, ["prediction"], getv(from_object, ["prediction"])) + + if getv(from_object, ["reference"]) is not None: + setv(to_object, ["reference"], getv(from_object, ["reference"])) + + return to_object + + +def _RougeSpec_to_vertex( + api_client: BaseApiClient, + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> 
dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["rouge_type"]) is not None: + setv(to_object, ["rougeType"], getv(from_object, ["rouge_type"])) + + if getv(from_object, ["split_summaries"]) is not None: + setv( + to_object, + ["splitSummaries"], + getv(from_object, ["split_summaries"]), + ) + + if getv(from_object, ["use_stemmer"]) is not None: + setv(to_object, ["useStemmer"], getv(from_object, ["use_stemmer"])) + + return to_object + + +def _RougeInput_to_vertex( + api_client: BaseApiClient, + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["instances"]) is not None: + setv( + to_object, + ["instances"], + [ + _RougeInstance_to_vertex(api_client, item, to_object) + for item in getv(from_object, ["instances"]) + ], + ) + + if getv(from_object, ["metric_spec"]) is not None: + setv( + to_object, + ["metricSpec"], + _RougeSpec_to_vertex( + api_client, getv(from_object, ["metric_spec"]), to_object + ), + ) + + return to_object + + +def _PointwiseMetricInstance_to_vertex( + api_client: BaseApiClient, + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["json_instance"]) is not None: + setv(to_object, ["jsonInstance"], getv(from_object, ["json_instance"])) + + return to_object + + +def _PointwiseMetricSpec_to_vertex( + api_client: BaseApiClient, + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["metric_prompt_template"]) is not None: + setv( + to_object, + ["metricPromptTemplate"], + getv(from_object, ["metric_prompt_template"]), + ) + + return to_object + + +def _PointwiseMetricInput_to_vertex( + api_client: BaseApiClient, + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["instance"]) is not None: + setv( + to_object, + ["instance"], + _PointwiseMetricInstance_to_vertex( + api_client, getv(from_object, ["instance"]), to_object + ), + ) + + if getv(from_object, ["metric_spec"]) is not None: + setv( + to_object, + ["metricSpec"], + _PointwiseMetricSpec_to_vertex( + api_client, getv(from_object, ["metric_spec"]), to_object + ), + ) + + return to_object + + +def _PairwiseMetricInstance_to_vertex( + api_client: BaseApiClient, + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["json_instance"]) is not None: + setv(to_object, ["jsonInstance"], getv(from_object, ["json_instance"])) + + return to_object + + +def _PairwiseMetricSpec_to_vertex( + api_client: BaseApiClient, + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["metric_prompt_template"]) is not None: + setv( + to_object, + ["metricPromptTemplate"], + getv(from_object, ["metric_prompt_template"]), + ) + + return to_object + + +def _PairwiseMetricInput_to_vertex( + api_client: BaseApiClient, + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["instance"]) is not None: 
+ setv( + to_object, + ["instance"], + _PairwiseMetricInstance_to_vertex( + api_client, getv(from_object, ["instance"]), to_object + ), + ) + + if getv(from_object, ["metric_spec"]) is not None: + setv( + to_object, + ["metricSpec"], + _PairwiseMetricSpec_to_vertex( + api_client, getv(from_object, ["metric_spec"]), to_object + ), + ) + + return to_object + + +def _ToolCallValidInstance_to_vertex( + api_client: BaseApiClient, + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["prediction"]) is not None: + setv(to_object, ["prediction"], getv(from_object, ["prediction"])) + + if getv(from_object, ["reference"]) is not None: + setv(to_object, ["reference"], getv(from_object, ["reference"])) + + return to_object + + +def _ToolCallValidSpec_to_vertex( + api_client: BaseApiClient, + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + return to_object + + +def _ToolCallValidInput_to_vertex( + api_client: BaseApiClient, + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["instances"]) is not None: + setv( + to_object, + ["instances"], + [ + _ToolCallValidInstance_to_vertex(api_client, item, to_object) + for item in getv(from_object, ["instances"]) + ], + ) + + if getv(from_object, ["metric_spec"]) is not None: + setv( + to_object, + ["metricSpec"], + _ToolCallValidSpec_to_vertex( + api_client, getv(from_object, ["metric_spec"]), to_object + ), + ) + + return to_object + + +def _ToolNameMatchInstance_to_vertex( + api_client: BaseApiClient, + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["prediction"]) is not None: + setv(to_object, ["prediction"], getv(from_object, ["prediction"])) + + if getv(from_object, ["reference"]) is not None: + setv(to_object, ["reference"], getv(from_object, ["reference"])) + + return to_object + + +def _ToolNameMatchSpec_to_vertex( + api_client: BaseApiClient, + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + return to_object + + +def _ToolNameMatchInput_to_vertex( + api_client: BaseApiClient, + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["instances"]) is not None: + setv( + to_object, + ["instances"], + [ + _ToolNameMatchInstance_to_vertex(api_client, item, to_object) + for item in getv(from_object, ["instances"]) + ], + ) + + if getv(from_object, ["metric_spec"]) is not None: + setv( + to_object, + ["metricSpec"], + _ToolNameMatchSpec_to_vertex( + api_client, getv(from_object, ["metric_spec"]), to_object + ), + ) + + return to_object + + +def _ToolParameterKeyMatchInstance_to_vertex( + api_client: BaseApiClient, + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["prediction"]) is not None: + setv(to_object, ["prediction"], getv(from_object, ["prediction"])) + + if getv(from_object, ["reference"]) is not None: + setv(to_object, ["reference"], 
getv(from_object, ["reference"])) + + return to_object + + +def _ToolParameterKeyMatchSpec_to_vertex( + api_client: BaseApiClient, + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + return to_object + + +def _ToolParameterKeyMatchInput_to_vertex( + api_client: BaseApiClient, + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["instances"]) is not None: + setv( + to_object, + ["instances"], + [ + _ToolParameterKeyMatchInstance_to_vertex(api_client, item, to_object) + for item in getv(from_object, ["instances"]) + ], + ) + + if getv(from_object, ["metric_spec"]) is not None: + setv( + to_object, + ["metricSpec"], + _ToolParameterKeyMatchSpec_to_vertex( + api_client, getv(from_object, ["metric_spec"]), to_object + ), + ) + + return to_object + + +def _ToolParameterKVMatchInstance_to_vertex( + api_client: BaseApiClient, + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["prediction"]) is not None: + setv(to_object, ["prediction"], getv(from_object, ["prediction"])) + + if getv(from_object, ["reference"]) is not None: + setv(to_object, ["reference"], getv(from_object, ["reference"])) + + return to_object + + +def _ToolParameterKVMatchSpec_to_vertex( + api_client: BaseApiClient, + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["use_strict_string_match"]) is not None: + setv( + to_object, + ["useStrictStringMatch"], + getv(from_object, ["use_strict_string_match"]), + ) + + return to_object + + +def _ToolParameterKVMatchInput_to_vertex( + api_client: BaseApiClient, + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["instances"]) is not None: + setv( + to_object, + ["instances"], + [ + _ToolParameterKVMatchInstance_to_vertex(api_client, item, to_object) + for item in getv(from_object, ["instances"]) + ], + ) + + if getv(from_object, ["metric_spec"]) is not None: + setv( + to_object, + ["metricSpec"], + _ToolParameterKVMatchSpec_to_vertex( + api_client, getv(from_object, ["metric_spec"]), to_object + ), + ) + + return to_object + + +def _EvaluateInstancesRequestParameters_to_vertex( + api_client: BaseApiClient, + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["bleu_input"]) is not None: + setv( + to_object, + ["bleuInput"], + _BleuInput_to_vertex( + api_client, getv(from_object, ["bleu_input"]), to_object + ), + ) + + if getv(from_object, ["exact_match_input"]) is not None: + setv( + to_object, + ["exactMatchInput"], + _ExactMatchInput_to_vertex( + api_client, getv(from_object, ["exact_match_input"]), to_object + ), + ) + + if getv(from_object, ["rouge_input"]) is not None: + setv( + to_object, + ["rougeInput"], + _RougeInput_to_vertex( + api_client, getv(from_object, ["rouge_input"]), to_object + ), + ) + + if getv(from_object, ["pointwise_metric_input"]) is not None: + setv( + to_object, + ["pointwiseMetricInput"], + _PointwiseMetricInput_to_vertex( + api_client, + getv(from_object, 
["pointwise_metric_input"]), + to_object, + ), + ) + + if getv(from_object, ["pairwise_metric_input"]) is not None: + setv( + to_object, + ["pairwiseMetricInput"], + _PairwiseMetricInput_to_vertex( + api_client, + getv(from_object, ["pairwise_metric_input"]), + to_object, + ), + ) + + if getv(from_object, ["tool_call_valid_input"]) is not None: + setv( + to_object, + ["toolCallValidInput"], + _ToolCallValidInput_to_vertex( + api_client, + getv(from_object, ["tool_call_valid_input"]), + to_object, + ), + ) + + if getv(from_object, ["tool_name_match_input"]) is not None: + setv( + to_object, + ["toolNameMatchInput"], + _ToolNameMatchInput_to_vertex( + api_client, + getv(from_object, ["tool_name_match_input"]), + to_object, + ), + ) + + if getv(from_object, ["tool_parameter_key_match_input"]) is not None: + setv( + to_object, + ["toolParameterKeyMatchInput"], + _ToolParameterKeyMatchInput_to_vertex( + api_client, + getv(from_object, ["tool_parameter_key_match_input"]), + to_object, + ), + ) + + if getv(from_object, ["tool_parameter_kv_match_input"]) is not None: + setv( + to_object, + ["toolParameterKvMatchInput"], + _ToolParameterKVMatchInput_to_vertex( + api_client, + getv(from_object, ["tool_parameter_kv_match_input"]), + to_object, + ), + ) + + if getv(from_object, ["config"]) is not None: + setv(to_object, ["config"], getv(from_object, ["config"])) + + return to_object + + +def _EvaluateInstancesResponse_from_vertex( + api_client: BaseApiClient, + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["bleuResults"]) is not None: + setv(to_object, ["bleu_results"], getv(from_object, ["bleuResults"])) + + if getv(from_object, ["exactMatchResults"]) is not None: + setv( + to_object, + ["exact_match_results"], + getv(from_object, ["exactMatchResults"]), + ) + + if getv(from_object, ["pairwiseMetricResult"]) is not None: + setv( + to_object, + ["pairwise_metric_result"], + getv(from_object, ["pairwiseMetricResult"]), + ) + + if getv(from_object, ["pointwiseMetricResult"]) is not None: + setv( + to_object, + ["pointwise_metric_result"], + getv(from_object, ["pointwiseMetricResult"]), + ) + + if getv(from_object, ["rougeResults"]) is not None: + setv(to_object, ["rouge_results"], getv(from_object, ["rougeResults"])) + + if getv(from_object, ["summarizationVerbosityResult"]) is not None: + setv( + to_object, + ["summarization_verbosity_result"], + getv(from_object, ["summarizationVerbosityResult"]), + ) + + if getv(from_object, ["toolCallValidResults"]) is not None: + setv( + to_object, + ["tool_call_valid_results"], + getv(from_object, ["toolCallValidResults"]), + ) + + if getv(from_object, ["toolNameMatchResults"]) is not None: + setv( + to_object, + ["tool_name_match_results"], + getv(from_object, ["toolNameMatchResults"]), + ) + + if getv(from_object, ["toolParameterKeyMatchResults"]) is not None: + setv( + to_object, + ["tool_parameter_key_match_results"], + getv(from_object, ["toolParameterKeyMatchResults"]), + ) + + if getv(from_object, ["toolParameterKvMatchResults"]) is not None: + setv( + to_object, + ["tool_parameter_kv_match_results"], + getv(from_object, ["toolParameterKvMatchResults"]), + ) + + return to_object + + +class Evals(_api_module.BaseModule): + def _evaluate_instances( + self, + *, + bleu_input: Optional[types.BleuInputOrDict] = None, + exact_match_input: Optional[types.ExactMatchInputOrDict] = None, + rouge_input: Optional[types.RougeInputOrDict] = 
None, + pointwise_metric_input: Optional[types.PointwiseMetricInputOrDict] = None, + pairwise_metric_input: Optional[types.PairwiseMetricInputOrDict] = None, + tool_call_valid_input: Optional[types.ToolCallValidInputOrDict] = None, + tool_name_match_input: Optional[types.ToolNameMatchInputOrDict] = None, + tool_parameter_key_match_input: Optional[ + types.ToolParameterKeyMatchInputOrDict + ] = None, + tool_parameter_kv_match_input: Optional[ + types.ToolParameterKVMatchInputOrDict + ] = None, + config: Optional[types.EvaluateInstancesConfigOrDict] = None, + ) -> types.EvaluateInstancesResponse: + """Evaluates instances based on a given metric.""" + + parameter_model = types._EvaluateInstancesRequestParameters( + bleu_input=bleu_input, + exact_match_input=exact_match_input, + rouge_input=rouge_input, + pointwise_metric_input=pointwise_metric_input, + pairwise_metric_input=pairwise_metric_input, + tool_call_valid_input=tool_call_valid_input, + tool_name_match_input=tool_name_match_input, + tool_parameter_key_match_input=tool_parameter_key_match_input, + tool_parameter_kv_match_input=tool_parameter_kv_match_input, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError("This method is only supported in the Vertex AI client.") + else: + request_dict = _EvaluateInstancesRequestParameters_to_vertex( + self._api_client, parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = ":evaluateInstances".format_map(request_url_dict) + else: + path = ":evaluateInstances" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[genai_types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = self._api_client.request( + "post", path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _EvaluateInstancesResponse_from_vertex( + self._api_client, response_dict + ) + + return_value = types.EvaluateInstancesResponse._from_response( + response=response_dict, kwargs=parameter_model.model_dump() + ) + self._api_client._verify_response(return_value) + return return_value + + def run(self) -> types.EvaluateInstancesResponse: + """Evaluates an instance of a model. 
+ + This should eventually call _evaluate_instances() + """ + raise NotImplementedError() + + def evaluate_instances( + self, + *, + metric_config: types._EvaluateInstancesRequestParameters, + ) -> types.EvaluateInstancesResponse: + """Evaluates an instance of a model.""" + + if isinstance(metric_config, types._EvaluateInstancesRequestParameters): + metric_config = metric_config.model_dump() + else: + metric_config = dict(metric_config) + + return self._evaluate_instances( + **metric_config, + ) + + +class AsyncEvals(_api_module.BaseModule): + async def _evaluate_instances( + self, + *, + bleu_input: Optional[types.BleuInputOrDict] = None, + exact_match_input: Optional[types.ExactMatchInputOrDict] = None, + rouge_input: Optional[types.RougeInputOrDict] = None, + pointwise_metric_input: Optional[types.PointwiseMetricInputOrDict] = None, + pairwise_metric_input: Optional[types.PairwiseMetricInputOrDict] = None, + tool_call_valid_input: Optional[types.ToolCallValidInputOrDict] = None, + tool_name_match_input: Optional[types.ToolNameMatchInputOrDict] = None, + tool_parameter_key_match_input: Optional[ + types.ToolParameterKeyMatchInputOrDict + ] = None, + tool_parameter_kv_match_input: Optional[ + types.ToolParameterKVMatchInputOrDict + ] = None, + config: Optional[types.EvaluateInstancesConfigOrDict] = None, + ) -> types.EvaluateInstancesResponse: + """Evaluates instances based on a given metric.""" + + parameter_model = types._EvaluateInstancesRequestParameters( + bleu_input=bleu_input, + exact_match_input=exact_match_input, + rouge_input=rouge_input, + pointwise_metric_input=pointwise_metric_input, + pairwise_metric_input=pairwise_metric_input, + tool_call_valid_input=tool_call_valid_input, + tool_name_match_input=tool_name_match_input, + tool_parameter_key_match_input=tool_parameter_key_match_input, + tool_parameter_kv_match_input=tool_parameter_kv_match_input, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError("This method is only supported in the Vertex AI client.") + else: + request_dict = _EvaluateInstancesRequestParameters_to_vertex( + self._api_client, parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = ":evaluateInstances".format_map(request_url_dict) + else: + path = ":evaluateInstances" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[genai_types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response_dict = await self._api_client.async_request( + "post", path, request_dict, http_options + ) + + if self._api_client.vertexai: + response_dict = _EvaluateInstancesResponse_from_vertex( + self._api_client, response_dict + ) + + return_value = types.EvaluateInstancesResponse._from_response( + response=response_dict, kwargs=parameter_model.model_dump() + ) + self._api_client._verify_response(return_value) + return return_value diff --git a/vertexai/_genai/types.py b/vertexai/_genai/types.py new file mode 100644 index 0000000000..7c02c495db --- /dev/null +++ b/vertexai/_genai/types.py @@ -0,0 +1,1245 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Code generated by the Google Gen AI SDK generator DO NOT EDIT. + +import logging +from typing import Any, Optional, Union +from google.genai import _common +from pydantic import Field +from typing_extensions import TypedDict + +logger = logging.getLogger("google_genai.types") + + +class PairwiseChoice(_common.CaseInSensitiveEnum): + """Output only. Pairwise metric choice.""" + + PAIRWISE_CHOICE_UNSPECIFIED = "PAIRWISE_CHOICE_UNSPECIFIED" + BASELINE = "BASELINE" + CANDIDATE = "CANDIDATE" + TIE = "TIE" + + +class BleuInstance(_common.BaseModel): + """Bleu instance.""" + + prediction: Optional[str] = Field( + default=None, description="""Required. Output of the evaluated model.""" + ) + reference: Optional[str] = Field( + default=None, + description="""Required. Ground truth used to compare against the prediction.""", + ) + + +class BleuInstanceDict(TypedDict, total=False): + """Bleu instance.""" + + prediction: Optional[str] + """Required. Output of the evaluated model.""" + + reference: Optional[str] + """Required. Ground truth used to compare against the prediction.""" + + +BleuInstanceOrDict = Union[BleuInstance, BleuInstanceDict] + + +class BleuSpec(_common.BaseModel): + """Spec for bleu metric.""" + + use_effective_order: Optional[bool] = Field( + default=None, + description="""Optional. Whether to use_effective_order to compute bleu score.""", + ) + + +class BleuSpecDict(TypedDict, total=False): + """Spec for bleu metric.""" + + use_effective_order: Optional[bool] + """Optional. Whether to use_effective_order to compute bleu score.""" + + +BleuSpecOrDict = Union[BleuSpec, BleuSpecDict] + + +class BleuInput(_common.BaseModel): + + instances: Optional[list[BleuInstance]] = Field( + default=None, description="""Required. Repeated bleu instances.""" + ) + metric_spec: Optional[BleuSpec] = Field( + default=None, description="""Required. 
Spec for bleu score metric.""" + ) + + +class BleuInputDict(TypedDict, total=False): + + instances: Optional[list[BleuInstanceDict]] + """Required. Repeated bleu instances.""" + + metric_spec: Optional[BleuSpecDict] + """Required. Spec for bleu score metric.""" + + +BleuInputOrDict = Union[BleuInput, BleuInputDict] + + +class ExactMatchInstance(_common.BaseModel): + """Exact match instance.""" + + prediction: Optional[str] = Field( + default=None, description="""Required. Output of the evaluated model.""" + ) + reference: Optional[str] = Field( + default=None, + description="""Required. Ground truth used to compare against the prediction.""", + ) + + +class ExactMatchInstanceDict(TypedDict, total=False): + """Exact match instance.""" + + prediction: Optional[str] + """Required. Output of the evaluated model.""" + + reference: Optional[str] + """Required. Ground truth used to compare against the prediction.""" + + +ExactMatchInstanceOrDict = Union[ExactMatchInstance, ExactMatchInstanceDict] + + +class ExactMatchSpec(_common.BaseModel): + """Spec for exact match metric.""" + + pass + + +class ExactMatchSpecDict(TypedDict, total=False): + """Spec for exact match metric.""" + + pass + + +ExactMatchSpecOrDict = Union[ExactMatchSpec, ExactMatchSpecDict] + + +class ExactMatchInput(_common.BaseModel): + + instances: Optional[list[ExactMatchInstance]] = Field( + default=None, + description="""Required. Repeated exact match instances.""", + ) + metric_spec: Optional[ExactMatchSpec] = Field( + default=None, description="""Required. Spec for exact match metric.""" + ) + + +class ExactMatchInputDict(TypedDict, total=False): + + instances: Optional[list[ExactMatchInstanceDict]] + """Required. Repeated exact match instances.""" + + metric_spec: Optional[ExactMatchSpecDict] + """Required. Spec for exact match metric.""" + + +ExactMatchInputOrDict = Union[ExactMatchInput, ExactMatchInputDict] + + +class RougeInstance(_common.BaseModel): + """Rouge instance.""" + + prediction: Optional[str] = Field( + default=None, description="""Required. Output of the evaluated model.""" + ) + reference: Optional[str] = Field( + default=None, + description="""Required. Ground truth used to compare against the prediction.""", + ) + + +class RougeInstanceDict(TypedDict, total=False): + """Rouge instance.""" + + prediction: Optional[str] + """Required. Output of the evaluated model.""" + + reference: Optional[str] + """Required. Ground truth used to compare against the prediction.""" + + +RougeInstanceOrDict = Union[RougeInstance, RougeInstanceDict] + + +class RougeSpec(_common.BaseModel): + """Spec for rouge metric.""" + + rouge_type: Optional[str] = Field( + default=None, + description="""Optional. Supported rouge types are rougen[1-9], rougeL, and rougeLsum.""", + ) + split_summaries: Optional[bool] = Field( + default=None, + description="""Optional. Whether to split summaries while using rougeLsum.""", + ) + use_stemmer: Optional[bool] = Field( + default=None, + description="""Optional. Whether to use stemmer to compute rouge score.""", + ) + + +class RougeSpecDict(TypedDict, total=False): + """Spec for rouge metric.""" + + rouge_type: Optional[str] + """Optional. Supported rouge types are rougen[1-9], rougeL, and rougeLsum.""" + + split_summaries: Optional[bool] + """Optional. Whether to split summaries while using rougeLsum.""" + + use_stemmer: Optional[bool] + """Optional. 
Whether to use stemmer to compute rouge score.""" + + +RougeSpecOrDict = Union[RougeSpec, RougeSpecDict] + + +class RougeInput(_common.BaseModel): + """Rouge input.""" + + instances: Optional[list[RougeInstance]] = Field( + default=None, description="""Required. Repeated rouge instances.""" + ) + metric_spec: Optional[RougeSpec] = Field( + default=None, description="""Required. Spec for rouge score metric.""" + ) + + +class RougeInputDict(TypedDict, total=False): + """Rouge input.""" + + instances: Optional[list[RougeInstanceDict]] + """Required. Repeated rouge instances.""" + + metric_spec: Optional[RougeSpecDict] + """Required. Spec for rouge score metric.""" + + +RougeInputOrDict = Union[RougeInput, RougeInputDict] + + +class PointwiseMetricInstance(_common.BaseModel): + """Pointwise metric instance.""" + + json_instance: Optional[str] = Field( + default=None, + description="""Instance specified as a json string. String key-value pairs are expected in the json_instance to render PointwiseMetricSpec.instance_prompt_template.""", + ) + + +class PointwiseMetricInstanceDict(TypedDict, total=False): + """Pointwise metric instance.""" + + json_instance: Optional[str] + """Instance specified as a json string. String key-value pairs are expected in the json_instance to render PointwiseMetricSpec.instance_prompt_template.""" + + +PointwiseMetricInstanceOrDict = Union[ + PointwiseMetricInstance, PointwiseMetricInstanceDict +] + + +class PointwiseMetricSpec(_common.BaseModel): + """Spec for pointwise metric.""" + + metric_prompt_template: Optional[str] = Field( + default=None, + description="""Required. Metric prompt template for pointwise metric.""", + ) + + +class PointwiseMetricSpecDict(TypedDict, total=False): + """Spec for pointwise metric.""" + + metric_prompt_template: Optional[str] + """Required. Metric prompt template for pointwise metric.""" + + +PointwiseMetricSpecOrDict = Union[PointwiseMetricSpec, PointwiseMetricSpecDict] + + +class PointwiseMetricInput(_common.BaseModel): + """Pointwise metric input.""" + + instance: Optional[PointwiseMetricInstance] = Field( + default=None, description="""Required. Pointwise metric instance.""" + ) + metric_spec: Optional[PointwiseMetricSpec] = Field( + default=None, description="""Required. Spec for pointwise metric.""" + ) + + +class PointwiseMetricInputDict(TypedDict, total=False): + """Pointwise metric input.""" + + instance: Optional[PointwiseMetricInstanceDict] + """Required. Pointwise metric instance.""" + + metric_spec: Optional[PointwiseMetricSpecDict] + """Required. Spec for pointwise metric.""" + + +PointwiseMetricInputOrDict = Union[PointwiseMetricInput, PointwiseMetricInputDict] + + +class PairwiseMetricInstance(_common.BaseModel): + """Pairwise metric instance.""" + + json_instance: Optional[str] = Field( + default=None, + description="""Instance specified as a json string. String key-value pairs are expected in the json_instance to render PairwiseMetricSpec.instance_prompt_template.""", + ) + + +class PairwiseMetricInstanceDict(TypedDict, total=False): + """Pairwise metric instance.""" + + json_instance: Optional[str] + """Instance specified as a json string. 
String key-value pairs are expected in the json_instance to render PairwiseMetricSpec.instance_prompt_template.""" + + +PairwiseMetricInstanceOrDict = Union[PairwiseMetricInstance, PairwiseMetricInstanceDict] + + +class PairwiseMetricSpec(_common.BaseModel): + """Spec for pairwise metric.""" + + metric_prompt_template: Optional[str] = Field( + default=None, + description="""Required. Metric prompt template for pairwise metric.""", + ) + + +class PairwiseMetricSpecDict(TypedDict, total=False): + """Spec for pairwise metric.""" + + metric_prompt_template: Optional[str] + """Required. Metric prompt template for pairwise metric.""" + + +PairwiseMetricSpecOrDict = Union[PairwiseMetricSpec, PairwiseMetricSpecDict] + + +class PairwiseMetricInput(_common.BaseModel): + """Pairwise metric instance.""" + + instance: Optional[PairwiseMetricInstance] = Field( + default=None, description="""Required. Pairwise metric instance.""" + ) + metric_spec: Optional[PairwiseMetricSpec] = Field( + default=None, description="""Required. Spec for pairwise metric.""" + ) + + +class PairwiseMetricInputDict(TypedDict, total=False): + """Pairwise metric instance.""" + + instance: Optional[PairwiseMetricInstanceDict] + """Required. Pairwise metric instance.""" + + metric_spec: Optional[PairwiseMetricSpecDict] + """Required. Spec for pairwise metric.""" + + +PairwiseMetricInputOrDict = Union[PairwiseMetricInput, PairwiseMetricInputDict] + + +class ToolCallValidInstance(_common.BaseModel): + """Tool call valid instance.""" + + prediction: Optional[str] = Field( + default=None, description="""Required. Output of the evaluated model.""" + ) + reference: Optional[str] = Field( + default=None, + description="""Required. Ground truth used to compare against the prediction.""", + ) + + +class ToolCallValidInstanceDict(TypedDict, total=False): + """Tool call valid instance.""" + + prediction: Optional[str] + """Required. Output of the evaluated model.""" + + reference: Optional[str] + """Required. Ground truth used to compare against the prediction.""" + + +ToolCallValidInstanceOrDict = Union[ToolCallValidInstance, ToolCallValidInstanceDict] + + +class ToolCallValidSpec(_common.BaseModel): + """Spec for tool call valid metric.""" + + pass + + +class ToolCallValidSpecDict(TypedDict, total=False): + """Spec for tool call valid metric.""" + + pass + + +ToolCallValidSpecOrDict = Union[ToolCallValidSpec, ToolCallValidSpecDict] + + +class ToolCallValidInput(_common.BaseModel): + """Tool call valid input.""" + + instances: Optional[list[ToolCallValidInstance]] = Field( + default=None, + description="""Required. Repeated tool call valid instances.""", + ) + metric_spec: Optional[ToolCallValidSpec] = Field( + default=None, + description="""Required. Spec for tool call valid metric.""", + ) + + +class ToolCallValidInputDict(TypedDict, total=False): + """Tool call valid input.""" + + instances: Optional[list[ToolCallValidInstanceDict]] + """Required. Repeated tool call valid instances.""" + + metric_spec: Optional[ToolCallValidSpecDict] + """Required. Spec for tool call valid metric.""" + + +ToolCallValidInputOrDict = Union[ToolCallValidInput, ToolCallValidInputDict] + + +class ToolNameMatchInstance(_common.BaseModel): + """Tool name match instance.""" + + prediction: Optional[str] = Field( + default=None, description="""Required. Output of the evaluated model.""" + ) + reference: Optional[str] = Field( + default=None, + description="""Required. 
Ground truth used to compare against the prediction.""", + ) + + +class ToolNameMatchInstanceDict(TypedDict, total=False): + """Tool name match instance.""" + + prediction: Optional[str] + """Required. Output of the evaluated model.""" + + reference: Optional[str] + """Required. Ground truth used to compare against the prediction.""" + + +ToolNameMatchInstanceOrDict = Union[ToolNameMatchInstance, ToolNameMatchInstanceDict] + + +class ToolNameMatchSpec(_common.BaseModel): + """Spec for tool name match metric.""" + + pass + + +class ToolNameMatchSpecDict(TypedDict, total=False): + """Spec for tool name match metric.""" + + pass + + +ToolNameMatchSpecOrDict = Union[ToolNameMatchSpec, ToolNameMatchSpecDict] + + +class ToolNameMatchInput(_common.BaseModel): + """Tool name match input.""" + + instances: Optional[list[ToolNameMatchInstance]] = Field( + default=None, + description="""Required. Repeated tool name match instances.""", + ) + metric_spec: Optional[ToolNameMatchSpec] = Field( + default=None, + description="""Required. Spec for tool name match metric.""", + ) + + +class ToolNameMatchInputDict(TypedDict, total=False): + """Tool name match input.""" + + instances: Optional[list[ToolNameMatchInstanceDict]] + """Required. Repeated tool name match instances.""" + + metric_spec: Optional[ToolNameMatchSpecDict] + """Required. Spec for tool name match metric.""" + + +ToolNameMatchInputOrDict = Union[ToolNameMatchInput, ToolNameMatchInputDict] + + +class ToolParameterKeyMatchInstance(_common.BaseModel): + """Tool parameter key match instance.""" + + prediction: Optional[str] = Field( + default=None, description="""Required. Output of the evaluated model.""" + ) + reference: Optional[str] = Field( + default=None, + description="""Required. Ground truth used to compare against the prediction.""", + ) + + +class ToolParameterKeyMatchInstanceDict(TypedDict, total=False): + """Tool parameter key match instance.""" + + prediction: Optional[str] + """Required. Output of the evaluated model.""" + + reference: Optional[str] + """Required. Ground truth used to compare against the prediction.""" + + +ToolParameterKeyMatchInstanceOrDict = Union[ + ToolParameterKeyMatchInstance, ToolParameterKeyMatchInstanceDict +] + + +class ToolParameterKeyMatchSpec(_common.BaseModel): + """Spec for tool parameter key match metric.""" + + pass + + +class ToolParameterKeyMatchSpecDict(TypedDict, total=False): + """Spec for tool parameter key match metric.""" + + pass + + +ToolParameterKeyMatchSpecOrDict = Union[ + ToolParameterKeyMatchSpec, ToolParameterKeyMatchSpecDict +] + + +class ToolParameterKeyMatchInput(_common.BaseModel): + """Tool parameter key match input.""" + + instances: Optional[list[ToolParameterKeyMatchInstance]] = Field( + default=None, + description="""Required. Repeated tool parameter key match instances.""", + ) + metric_spec: Optional[ToolParameterKeyMatchSpec] = Field( + default=None, + description="""Required. Spec for tool parameter key match metric.""", + ) + + +class ToolParameterKeyMatchInputDict(TypedDict, total=False): + """Tool parameter key match input.""" + + instances: Optional[list[ToolParameterKeyMatchInstanceDict]] + """Required. Repeated tool parameter key match instances.""" + + metric_spec: Optional[ToolParameterKeyMatchSpecDict] + """Required. 
Spec for tool parameter key match metric.""" + + +ToolParameterKeyMatchInputOrDict = Union[ + ToolParameterKeyMatchInput, ToolParameterKeyMatchInputDict +] + + +class ToolParameterKVMatchInstance(_common.BaseModel): + """Tool parameter kv match instance.""" + + prediction: Optional[str] = Field( + default=None, description="""Required. Output of the evaluated model.""" + ) + reference: Optional[str] = Field( + default=None, + description="""Required. Ground truth used to compare against the prediction.""", + ) + + +class ToolParameterKVMatchInstanceDict(TypedDict, total=False): + """Tool parameter kv match instance.""" + + prediction: Optional[str] + """Required. Output of the evaluated model.""" + + reference: Optional[str] + """Required. Ground truth used to compare against the prediction.""" + + +ToolParameterKVMatchInstanceOrDict = Union[ + ToolParameterKVMatchInstance, ToolParameterKVMatchInstanceDict +] + + +class ToolParameterKVMatchSpec(_common.BaseModel): + """Spec for tool parameter kv match metric.""" + + use_strict_string_match: Optional[bool] = Field( + default=None, + description="""Optional. Whether to use STRICT string match on parameter values.""", + ) + + +class ToolParameterKVMatchSpecDict(TypedDict, total=False): + """Spec for tool parameter kv match metric.""" + + use_strict_string_match: Optional[bool] + """Optional. Whether to use STRICT string match on parameter values.""" + + +ToolParameterKVMatchSpecOrDict = Union[ + ToolParameterKVMatchSpec, ToolParameterKVMatchSpecDict +] + + +class ToolParameterKVMatchInput(_common.BaseModel): + """Tool parameter kv match input.""" + + instances: Optional[list[ToolParameterKVMatchInstance]] = Field( + default=None, + description="""Required. Repeated tool parameter key value match instances.""", + ) + metric_spec: Optional[ToolParameterKVMatchSpec] = Field( + default=None, + description="""Required. Spec for tool parameter key value match metric.""", + ) + + +class ToolParameterKVMatchInputDict(TypedDict, total=False): + """Tool parameter kv match input.""" + + instances: Optional[list[ToolParameterKVMatchInstanceDict]] + """Required. Repeated tool parameter key value match instances.""" + + metric_spec: Optional[ToolParameterKVMatchSpecDict] + """Required. 
Spec for tool parameter key value match metric.""" + + +ToolParameterKVMatchInputOrDict = Union[ + ToolParameterKVMatchInput, ToolParameterKVMatchInputDict +] + + +class HttpOptions(_common.BaseModel): + """HTTP options to be used in each of the requests.""" + + base_url: Optional[str] = Field( + default=None, + description="""The base URL for the AI platform service endpoint.""", + ) + api_version: Optional[str] = Field( + default=None, description="""Specifies the version of the API to use.""" + ) + headers: Optional[dict[str, str]] = Field( + default=None, + description="""Additional HTTP headers to be sent with the request.""", + ) + timeout: Optional[int] = Field( + default=None, description="""Timeout for the request in milliseconds.""" + ) + client_args: Optional[dict[str, Any]] = Field( + default=None, description="""Args passed to the HTTP client.""" + ) + async_client_args: Optional[dict[str, Any]] = Field( + default=None, description="""Args passed to the async HTTP client.""" + ) + + +class HttpOptionsDict(TypedDict, total=False): + """HTTP options to be used in each of the requests.""" + + base_url: Optional[str] + """The base URL for the AI platform service endpoint.""" + + api_version: Optional[str] + """Specifies the version of the API to use.""" + + headers: Optional[dict[str, str]] + """Additional HTTP headers to be sent with the request.""" + + timeout: Optional[int] + """Timeout for the request in milliseconds.""" + + client_args: Optional[dict[str, Any]] + """Args passed to the HTTP client.""" + + async_client_args: Optional[dict[str, Any]] + """Args passed to the async HTTP client.""" + + +HttpOptionsOrDict = Union[HttpOptions, HttpOptionsDict] + + +class EvaluateInstancesConfig(_common.BaseModel): + """Config for evaluate instances.""" + + http_options: Optional[HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class EvaluateInstancesConfigDict(TypedDict, total=False): + """Config for evaluate instances.""" + + http_options: Optional[HttpOptionsDict] + """Used to override HTTP request options.""" + + +EvaluateInstancesConfigOrDict = Union[ + EvaluateInstancesConfig, EvaluateInstancesConfigDict +] + + +class _EvaluateInstancesRequestParameters(_common.BaseModel): + """Parameters for evaluating instances.""" + + bleu_input: Optional[BleuInput] = Field(default=None, description="""""") + exact_match_input: Optional[ExactMatchInput] = Field( + default=None, description="""""" + ) + rouge_input: Optional[RougeInput] = Field(default=None, description="""""") + pointwise_metric_input: Optional[PointwiseMetricInput] = Field( + default=None, description="""""" + ) + pairwise_metric_input: Optional[PairwiseMetricInput] = Field( + default=None, description="""""" + ) + tool_call_valid_input: Optional[ToolCallValidInput] = Field( + default=None, description="""""" + ) + tool_name_match_input: Optional[ToolNameMatchInput] = Field( + default=None, description="""""" + ) + tool_parameter_key_match_input: Optional[ToolParameterKeyMatchInput] = Field( + default=None, description="""""" + ) + tool_parameter_kv_match_input: Optional[ToolParameterKVMatchInput] = Field( + default=None, description="""""" + ) + config: Optional[EvaluateInstancesConfig] = Field(default=None, description="""""") + + +class _EvaluateInstancesRequestParametersDict(TypedDict, total=False): + """Parameters for evaluating instances.""" + + bleu_input: Optional[BleuInputDict] + """""" + + exact_match_input: Optional[ExactMatchInputDict] + """""" + + 
rouge_input: Optional[RougeInputDict] + """""" + + pointwise_metric_input: Optional[PointwiseMetricInputDict] + """""" + + pairwise_metric_input: Optional[PairwiseMetricInputDict] + """""" + + tool_call_valid_input: Optional[ToolCallValidInputDict] + """""" + + tool_name_match_input: Optional[ToolNameMatchInputDict] + """""" + + tool_parameter_key_match_input: Optional[ToolParameterKeyMatchInputDict] + """""" + + tool_parameter_kv_match_input: Optional[ToolParameterKVMatchInputDict] + """""" + + config: Optional[EvaluateInstancesConfigDict] + """""" + + +_EvaluateInstancesRequestParametersOrDict = Union[ + _EvaluateInstancesRequestParameters, _EvaluateInstancesRequestParametersDict +] + + +class BleuMetricValue(_common.BaseModel): + """Bleu metric value for an instance.""" + + score: Optional[float] = Field( + default=None, description="""Output only. Bleu score.""" + ) + + +class BleuMetricValueDict(TypedDict, total=False): + """Bleu metric value for an instance.""" + + score: Optional[float] + """Output only. Bleu score.""" + + +BleuMetricValueOrDict = Union[BleuMetricValue, BleuMetricValueDict] + + +class BleuResults(_common.BaseModel): + """Results for bleu metric.""" + + bleu_metric_values: Optional[list[BleuMetricValue]] = Field( + default=None, description="""Output only. Bleu metric values.""" + ) + + +class BleuResultsDict(TypedDict, total=False): + """Results for bleu metric.""" + + bleu_metric_values: Optional[list[BleuMetricValueDict]] + """Output only. Bleu metric values.""" + + +BleuResultsOrDict = Union[BleuResults, BleuResultsDict] + + +class ExactMatchMetricValue(_common.BaseModel): + """Exact match metric value for an instance.""" + + score: Optional[float] = Field( + default=None, description="""Output only. Exact match score.""" + ) + + +class ExactMatchMetricValueDict(TypedDict, total=False): + """Exact match metric value for an instance.""" + + score: Optional[float] + """Output only. Exact match score.""" + + +ExactMatchMetricValueOrDict = Union[ExactMatchMetricValue, ExactMatchMetricValueDict] + + +class ExactMatchResults(_common.BaseModel): + """Results for exact match metric.""" + + exact_match_metric_values: Optional[list[ExactMatchMetricValue]] = Field( + default=None, description="""Output only. Exact match metric values.""" + ) + + +class ExactMatchResultsDict(TypedDict, total=False): + """Results for exact match metric.""" + + exact_match_metric_values: Optional[list[ExactMatchMetricValueDict]] + """Output only. Exact match metric values.""" + + +ExactMatchResultsOrDict = Union[ExactMatchResults, ExactMatchResultsDict] + + +class PairwiseMetricResult(_common.BaseModel): + """Spec for pairwise metric result.""" + + explanation: Optional[str] = Field( + default=None, + description="""Output only. Explanation for pairwise metric score.""", + ) + pairwise_choice: Optional[PairwiseChoice] = Field( + default=None, description="""Output only. Pairwise metric choice.""" + ) + + +class PairwiseMetricResultDict(TypedDict, total=False): + """Spec for pairwise metric result.""" + + explanation: Optional[str] + """Output only. Explanation for pairwise metric score.""" + + pairwise_choice: Optional[PairwiseChoice] + """Output only. Pairwise metric choice.""" + + +PairwiseMetricResultOrDict = Union[PairwiseMetricResult, PairwiseMetricResultDict] + + +class PointwiseMetricResult(_common.BaseModel): + """Spec for pointwise metric result.""" + + explanation: Optional[str] = Field( + default=None, + description="""Output only. 
Explanation for pointwise metric score.""", + ) + score: Optional[float] = Field( + default=None, description="""Output only. Pointwise metric score.""" + ) + + +class PointwiseMetricResultDict(TypedDict, total=False): + """Spec for pointwise metric result.""" + + explanation: Optional[str] + """Output only. Explanation for pointwise metric score.""" + + score: Optional[float] + """Output only. Pointwise metric score.""" + + +PointwiseMetricResultOrDict = Union[PointwiseMetricResult, PointwiseMetricResultDict] + + +class RougeMetricValue(_common.BaseModel): + """Rouge metric value for an instance.""" + + score: Optional[float] = Field( + default=None, description="""Output only. Rouge score.""" + ) + + +class RougeMetricValueDict(TypedDict, total=False): + """Rouge metric value for an instance.""" + + score: Optional[float] + """Output only. Rouge score.""" + + +RougeMetricValueOrDict = Union[RougeMetricValue, RougeMetricValueDict] + + +class RougeResults(_common.BaseModel): + """Results for rouge metric.""" + + rouge_metric_values: Optional[list[RougeMetricValue]] = Field( + default=None, description="""Output only. Rouge metric values.""" + ) + + +class RougeResultsDict(TypedDict, total=False): + """Results for rouge metric.""" + + rouge_metric_values: Optional[list[RougeMetricValueDict]] + """Output only. Rouge metric values.""" + + +RougeResultsOrDict = Union[RougeResults, RougeResultsDict] + + +class SummarizationVerbosityResult(_common.BaseModel): + """Spec for summarization verbosity result.""" + + confidence: Optional[float] = Field( + default=None, + description="""Output only. Confidence for summarization verbosity score.""", + ) + explanation: Optional[str] = Field( + default=None, + description="""Output only. Explanation for summarization verbosity score.""", + ) + score: Optional[float] = Field( + default=None, + description="""Output only. Summarization Verbosity score.""", + ) + + +class SummarizationVerbosityResultDict(TypedDict, total=False): + """Spec for summarization verbosity result.""" + + confidence: Optional[float] + """Output only. Confidence for summarization verbosity score.""" + + explanation: Optional[str] + """Output only. Explanation for summarization verbosity score.""" + + score: Optional[float] + """Output only. Summarization Verbosity score.""" + + +SummarizationVerbosityResultOrDict = Union[ + SummarizationVerbosityResult, SummarizationVerbosityResultDict +] + + +class ToolCallValidMetricValue(_common.BaseModel): + """Tool call valid metric value for an instance.""" + + score: Optional[float] = Field( + default=None, description="""Output only. Tool call valid score.""" + ) + + +class ToolCallValidMetricValueDict(TypedDict, total=False): + """Tool call valid metric value for an instance.""" + + score: Optional[float] + """Output only. Tool call valid score.""" + + +ToolCallValidMetricValueOrDict = Union[ + ToolCallValidMetricValue, ToolCallValidMetricValueDict +] + + +class ToolCallValidResults(_common.BaseModel): + """Results for tool call valid metric.""" + + tool_call_valid_metric_values: Optional[list[ToolCallValidMetricValue]] = Field( + default=None, + description="""Output only. Tool call valid metric values.""", + ) + + +class ToolCallValidResultsDict(TypedDict, total=False): + """Results for tool call valid metric.""" + + tool_call_valid_metric_values: Optional[list[ToolCallValidMetricValueDict]] + """Output only. 
Tool call valid metric values.""" + + +ToolCallValidResultsOrDict = Union[ToolCallValidResults, ToolCallValidResultsDict] + + +class ToolNameMatchMetricValue(_common.BaseModel): + """Tool name match metric value for an instance.""" + + score: Optional[float] = Field( + default=None, description="""Output only. Tool name match score.""" + ) + + +class ToolNameMatchMetricValueDict(TypedDict, total=False): + """Tool name match metric value for an instance.""" + + score: Optional[float] + """Output only. Tool name match score.""" + + +ToolNameMatchMetricValueOrDict = Union[ + ToolNameMatchMetricValue, ToolNameMatchMetricValueDict +] + + +class ToolNameMatchResults(_common.BaseModel): + """Results for tool name match metric.""" + + tool_name_match_metric_values: Optional[list[ToolNameMatchMetricValue]] = Field( + default=None, + description="""Output only. Tool name match metric values.""", + ) + + +class ToolNameMatchResultsDict(TypedDict, total=False): + """Results for tool name match metric.""" + + tool_name_match_metric_values: Optional[list[ToolNameMatchMetricValueDict]] + """Output only. Tool name match metric values.""" + + +ToolNameMatchResultsOrDict = Union[ToolNameMatchResults, ToolNameMatchResultsDict] + + +class ToolParameterKeyMatchMetricValue(_common.BaseModel): + """Tool parameter key match metric value for an instance.""" + + score: Optional[float] = Field( + default=None, + description="""Output only. Tool parameter key match score.""", + ) + + +class ToolParameterKeyMatchMetricValueDict(TypedDict, total=False): + """Tool parameter key match metric value for an instance.""" + + score: Optional[float] + """Output only. Tool parameter key match score.""" + + +ToolParameterKeyMatchMetricValueOrDict = Union[ + ToolParameterKeyMatchMetricValue, ToolParameterKeyMatchMetricValueDict +] + + +class ToolParameterKeyMatchResults(_common.BaseModel): + """Results for tool parameter key match metric.""" + + tool_parameter_key_match_metric_values: Optional[ + list[ToolParameterKeyMatchMetricValue] + ] = Field( + default=None, + description="""Output only. Tool parameter key match metric values.""", + ) + + +class ToolParameterKeyMatchResultsDict(TypedDict, total=False): + """Results for tool parameter key match metric.""" + + tool_parameter_key_match_metric_values: Optional[ + list[ToolParameterKeyMatchMetricValueDict] + ] + """Output only. Tool parameter key match metric values.""" + + +ToolParameterKeyMatchResultsOrDict = Union[ + ToolParameterKeyMatchResults, ToolParameterKeyMatchResultsDict +] + + +class ToolParameterKVMatchMetricValue(_common.BaseModel): + """Tool parameter key value match metric value for an instance.""" + + score: Optional[float] = Field( + default=None, + description="""Output only. Tool parameter key value match score.""", + ) + + +class ToolParameterKVMatchMetricValueDict(TypedDict, total=False): + """Tool parameter key value match metric value for an instance.""" + + score: Optional[float] + """Output only. Tool parameter key value match score.""" + + +ToolParameterKVMatchMetricValueOrDict = Union[ + ToolParameterKVMatchMetricValue, ToolParameterKVMatchMetricValueDict +] + + +class ToolParameterKVMatchResults(_common.BaseModel): + """Results for tool parameter key value match metric.""" + + tool_parameter_kv_match_metric_values: Optional[ + list[ToolParameterKVMatchMetricValue] + ] = Field( + default=None, + description="""Output only. 
Tool parameter key value match metric values.""", + ) + + +class ToolParameterKVMatchResultsDict(TypedDict, total=False): + """Results for tool parameter key value match metric.""" + + tool_parameter_kv_match_metric_values: Optional[ + list[ToolParameterKVMatchMetricValueDict] + ] + """Output only. Tool parameter key value match metric values.""" + + +ToolParameterKVMatchResultsOrDict = Union[ + ToolParameterKVMatchResults, ToolParameterKVMatchResultsDict +] + + +class EvaluateInstancesResponse(_common.BaseModel): + """Result of evaluating an LLM metric.""" + + bleu_results: Optional[BleuResults] = Field( + default=None, description="""Results for bleu metric.""" + ) + exact_match_results: Optional[ExactMatchResults] = Field( + default=None, + description="""Auto metric evaluation results. Results for exact match metric.""", + ) + pairwise_metric_result: Optional[PairwiseMetricResult] = Field( + default=None, description="""Result for pairwise metric.""" + ) + pointwise_metric_result: Optional[PointwiseMetricResult] = Field( + default=None, + description="""Generic metrics. Result for pointwise metric.""", + ) + rouge_results: Optional[RougeResults] = Field( + default=None, description="""Results for rouge metric.""" + ) + summarization_verbosity_result: Optional[SummarizationVerbosityResult] = Field( + default=None, + description="""Result for summarization verbosity metric.""", + ) + tool_call_valid_results: Optional[ToolCallValidResults] = Field( + default=None, + description="""Tool call metrics. Results for tool call valid metric.""", + ) + tool_name_match_results: Optional[ToolNameMatchResults] = Field( + default=None, description="""Results for tool name match metric.""" + ) + tool_parameter_key_match_results: Optional[ToolParameterKeyMatchResults] = Field( + default=None, + description="""Results for tool parameter key match metric.""", + ) + tool_parameter_kv_match_results: Optional[ToolParameterKVMatchResults] = Field( + default=None, + description="""Results for tool parameter key value match metric.""", + ) + + +class EvaluateInstancesResponseDict(TypedDict, total=False): + """Result of evaluating an LLM metric.""" + + bleu_results: Optional[BleuResultsDict] + """Results for bleu metric.""" + + exact_match_results: Optional[ExactMatchResultsDict] + """Auto metric evaluation results. Results for exact match metric.""" + + pairwise_metric_result: Optional[PairwiseMetricResultDict] + """Result for pairwise metric.""" + + pointwise_metric_result: Optional[PointwiseMetricResultDict] + """Generic metrics. Result for pointwise metric.""" + + rouge_results: Optional[RougeResultsDict] + """Results for rouge metric.""" + + summarization_verbosity_result: Optional[SummarizationVerbosityResultDict] + """Result for summarization verbosity metric.""" + + tool_call_valid_results: Optional[ToolCallValidResultsDict] + """Tool call metrics. 
Results for tool call valid metric.""" + + tool_name_match_results: Optional[ToolNameMatchResultsDict] + """Results for tool name match metric.""" + + tool_parameter_key_match_results: Optional[ToolParameterKeyMatchResultsDict] + """Results for tool parameter key match metric.""" + + tool_parameter_kv_match_results: Optional[ToolParameterKVMatchResultsDict] + """Results for tool parameter key value match metric.""" + + +EvaluateInstancesResponseOrDict = Union[ + EvaluateInstancesResponse, EvaluateInstancesResponseDict +] + + +class EvalDataset(_common.BaseModel): + + file: Optional[str] = Field(default=None, description="""""") + + +class EvalDatasetDict(TypedDict, total=False): + + file: Optional[str] + """""" + + +EvalDatasetOrDict = Union[EvalDataset, EvalDatasetDict] diff --git a/vertexai/agent_engines/_agent_engines.py b/vertexai/agent_engines/_agent_engines.py index 37ba528134..983e3c572a 100644 --- a/vertexai/agent_engines/_agent_engines.py +++ b/vertexai/agent_engines/_agent_engines.py @@ -338,8 +338,11 @@ def create( """ sys_version = f"{sys.version_info.major}.{sys.version_info.minor}" _validate_sys_version_or_raise(sys_version) + gcs_dir_name = gcs_dir_name or _DEFAULT_GCS_DIR_NAME + staging_bucket = initializer.global_config.staging_bucket if agent_engine is not None: agent_engine = _validate_agent_engine_or_raise(agent_engine) + _validate_staging_bucket_or_raise(staging_bucket) if agent_engine is None: if requirements is not None: raise ValueError("requirements must be None if agent_engine is None.") @@ -350,12 +353,9 @@ def create( requirements=requirements, ) extra_packages = _validate_extra_packages_or_raise(extra_packages) - gcs_dir_name = gcs_dir_name or _DEFAULT_GCS_DIR_NAME sdk_resource = cls.__new__(cls) base.VertexAiResourceNounWithFutureManager.__init__(sdk_resource) - staging_bucket = initializer.global_config.staging_bucket - _validate_staging_bucket_or_raise(staging_bucket) # Prepares the Agent Engine for creation in Vertex AI. # This involves packaging and uploading the artifacts for # agent_engine, requirements and extra_packages to @@ -881,17 +881,18 @@ def _prepare( gcs_dir_name (str): The GCS bucket directory under `staging_bucket` to use for staging the artifacts needed. """ + if agent_engine is None: + return gcs_bucket = _get_gcs_bucket( project=project, location=location, staging_bucket=staging_bucket, ) - if agent_engine is not None: - _upload_agent_engine( - agent_engine=agent_engine, - gcs_bucket=gcs_bucket, - gcs_dir_name=gcs_dir_name, - ) + _upload_agent_engine( + agent_engine=agent_engine, + gcs_bucket=gcs_bucket, + gcs_dir_name=gcs_dir_name, + ) if requirements is not None: _upload_requirements( requirements=requirements, @@ -992,7 +993,7 @@ def _generate_update_request_or_raise( Union[Sequence[str], Dict[str, Union[str, aip_types.SecretRef]]] ] = None, ) -> reasoning_engine_service.UpdateReasoningEngineRequest: - """Tries to generates the update request for the agent engine.""" + """Tries to generate the update request for the agent engine.""" is_spec_update = False update_masks: List[str] = [] agent_engine_spec = aip_types.ReasoningEngineSpec() diff --git a/vertexai/model_garden/_model_garden.py b/vertexai/model_garden/_model_garden.py index 826bc6a7dc..a311f40c8a 100644 --- a/vertexai/model_garden/_model_garden.py +++ b/vertexai/model_garden/_model_garden.py @@ -257,7 +257,17 @@ class _ModelGardenClientWithOverride(utils.ClientWithOverride): class OpenModel: - """Represents a Model Garden Open model.""" + """Represents a Model Garden Open model. 
+ + Attributes: + model_name: Model Garden model resource name in the format of + `publishers/{publisher}/models/{model}@{version}`, or a + simplified resource name in the format of + `{publisher}/{model}@{version}`, or a Hugging Face model ID in + the format of `{organization}/{model}`. + """ + + __module__ = "vertexai.preview.model_garden" def __init__( self, @@ -373,6 +383,7 @@ def deploy( reservation_affinity_values: Optional[List[str]] = None, use_dedicated_endpoint: Optional[bool] = False, fast_tryout_enabled: Optional[bool] = False, + system_labels: Optional[Dict[str, str]] = None, endpoint_display_name: Optional[str] = None, model_display_name: Optional[str] = None, deploy_request_timeout: Optional[float] = None, @@ -399,139 +410,138 @@ def deploy( Args: accept_eula (bool): Whether to accept the End User License Agreement. hugging_face_access_token (str): The access token to access Hugging Face - models. Reference: https://huggingface.co/docs/hub/en/security-tokens - machine_type (str): - Optional. The type of machine. Not specifying machine type will - result in model to be deployed with automatic resources. - min_replica_count (int): - Optional. The minimum number of machine replicas this deployed - model will be always deployed on. If traffic against it increases, - it may dynamically be deployed onto more replicas, and as traffic - decreases, some of these extra replicas may be freed. - max_replica_count (int): - Optional. The maximum number of replicas this deployed model may - be deployed on when the traffic against it increases. If requested - value is too large, the deployment will error, but if deployment - succeeds then the ability to scale the model to that many replicas - is guaranteed (barring service outages). If traffic against the - deployed model increases beyond what its replicas at maximum may - handle, a portion of the traffic will be dropped. If this value - is not provided, the larger value of min_replica_count or 1 will - be used. If value provided is smaller than min_replica_count, it - will automatically be increased to be min_replica_count. - accelerator_type (str): - Optional. Hardware accelerator type. Must also set accelerator_count if used. - One of ACCELERATOR_TYPE_UNSPECIFIED, NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, - NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, NVIDIA_TESLA_T4 - accelerator_count (int): - Optional. The number of accelerators to attach to a worker replica. - spot (bool): - Optional. Whether to schedule the deployment workload on spot VMs. - reservation_affinity_type (str): - Optional. The type of reservation affinity. - One of NO_RESERVATION, ANY_RESERVATION, SPECIFIC_RESERVATION, - SPECIFIC_THEN_ANY_RESERVATION, SPECIFIC_THEN_NO_RESERVATION - reservation_affinity_key (str): - Optional. Corresponds to the label key of a reservation resource. - To target a SPECIFIC_RESERVATION by name, use `compute.googleapis.com/reservation-name` as the key - and specify the name of your reservation as its value. - reservation_affinity_values (List[str]): - Optional. Corresponds to the label values of a reservation resource. - This must be the full resource name of the reservation. - Format: 'projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}' - use_dedicated_endpoint (bool): - Optional. Default value is False. If set to True, the underlying prediction call will be made - using the dedicated endpoint dns. - fast_tryout_enabled (bool): - Optional. Defaults to False. - If True, model will be deployed using faster deployment path. 
- Useful for quick experiments. Not for production workloads. Only - available for most popular models with certain machine types. + models. Reference: https://huggingface.co/docs/hub/en/security-tokens + machine_type (str): Optional. The type of machine. Not specifying + machine type will result in model to be deployed with automatic + resources. + min_replica_count (int): Optional. The minimum number of machine + replicas this deployed model will be always deployed on. If traffic + against it increases, it may dynamically be deployed onto more + replicas, and as traffic decreases, some of these extra replicas may + be freed. + max_replica_count (int): Optional. The maximum number of replicas this + deployed model may be deployed on when the traffic against it + increases. If requested value is too large, the deployment will error, + but if deployment succeeds then the ability to scale the model to that + many replicas is guaranteed (barring service outages). If traffic + against the deployed model increases beyond what its replicas at + maximum may handle, a portion of the traffic will be dropped. If this + value is not provided, the larger value of min_replica_count or 1 will + be used. If value provided is smaller than min_replica_count, it will + automatically be increased to be min_replica_count. + accelerator_type (str): Optional. Hardware accelerator type. Must also + set accelerator_count if used. One of ACCELERATOR_TYPE_UNSPECIFIED, + NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100, + NVIDIA_TESLA_P4, NVIDIA_TESLA_T4 + accelerator_count (int): Optional. The number of accelerators to attach + to a worker replica. + spot (bool): Optional. Whether to schedule the deployment workload on + spot VMs. + reservation_affinity_type (str): Optional. The type of reservation + affinity. One of NO_RESERVATION, ANY_RESERVATION, + SPECIFIC_RESERVATION, SPECIFIC_THEN_ANY_RESERVATION, + SPECIFIC_THEN_NO_RESERVATION + reservation_affinity_key (str): Optional. Corresponds to the label key + of a reservation resource. To target a SPECIFIC_RESERVATION by name, + use `compute.googleapis.com/reservation-name` as the key and specify + the name of your reservation as its value. + reservation_affinity_values (List[str]): Optional. Corresponds to the + label values of a reservation resource. This must be the full resource + name of the reservation. + Format: + 'projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}' + use_dedicated_endpoint (bool): Optional. Default value is False. If set + to True, the underlying prediction call will be made using the + dedicated endpoint dns. + fast_tryout_enabled (bool): Optional. Defaults to False. If True, model + will be deployed using faster deployment path. Useful for quick + experiments. Not for production workloads. Only available for most + popular models with certain machine types. + system_labels (Dict[str, str]): Optional. System labels for Model Garden + deployments. These labels are managed by Google and for tracking + purposes only. endpoint_display_name: The display name of the created endpoint. model_display_name: The display name of the uploaded model. - deploy_request_timeout: The timeout for the deploy request. Default - is 2 hours. - serving_container_spec (types.ModelContainerSpec): - Optional. The container specification for the model instance. - This specification overrides the default container specification - and other serving container parameters. - serving_container_image_uri (str): - Optional. 
The URI of the Model serving container. This parameter is required - if the parameter `local_model` is not specified. - serving_container_predict_route (str): - Optional. An HTTP path to send prediction requests to the container, and - which must be supported by it. If not specified a default HTTP path will - be used by Vertex AI. - serving_container_health_route (str): - Optional. An HTTP path to send health check requests to the container, and which - must be supported by it. If not specified a standard HTTP path will be - used by Vertex AI. - serving_container_command: Optional[Sequence[str]]=None, - The command with which the container is run. Not executed within a - shell. The Docker image's ENTRYPOINT is used if this is not provided. - Variable references $(VAR_NAME) are expanded using the container's - environment. If a variable cannot be resolved, the reference in the - input string will be unchanged. The $(VAR_NAME) syntax can be escaped - with a double $$, ie: $$(VAR_NAME). Escaped references will never be - expanded, regardless of whether the variable exists or not. - serving_container_args: Optional[Sequence[str]]=None, - The arguments to the command. The Docker image's CMD is used if this is - not provided. Variable references $(VAR_NAME) are expanded using the - container's environment. If a variable cannot be resolved, the reference - in the input string will be unchanged. The $(VAR_NAME) syntax can be - escaped with a double $$, ie: $$(VAR_NAME). Escaped references will - never be expanded, regardless of whether the variable exists or not. + deploy_request_timeout: The timeout for the deploy request. Default is 2 + hours. + serving_container_spec (types.ModelContainerSpec): Optional. The + container specification for the model instance. This specification + overrides the default container specification and other serving + container parameters. + serving_container_image_uri (str): Optional. The URI of the Model + serving container. This parameter is required if the parameter + `local_model` is not specified. + serving_container_predict_route (str): Optional. An HTTP path to send + prediction requests to the container, and which must be supported by + it. If not specified a default HTTP path will be used by Vertex AI. + serving_container_health_route (str): Optional. An HTTP path to send + health check requests to the container, and which must be supported by + it. If not specified a standard HTTP path will be used by Vertex AI. + serving_container_command: Optional[Sequence[str]]=None, The command + with which the container is run. Not executed within a shell. The + Docker image's ENTRYPOINT is used if this is not provided. Variable + references $(VAR_NAME) are expanded using the container's environment. + If a variable cannot be resolved, the reference in the input string + will be unchanged. The $(VAR_NAME) syntax can be escaped with a double + $$, ie: $$(VAR_NAME). Escaped references will never be expanded, + regardless of whether the variable exists or not. + serving_container_args: Optional[Sequence[str]]=None, The arguments to + the command. The Docker image's CMD is used if this is not provided. + Variable references $(VAR_NAME) are expanded using the container's + environment. If a variable cannot be resolved, the reference in the + input string will be unchanged. The $(VAR_NAME) syntax can be escaped + with a double $$, ie: $$(VAR_NAME). Escaped references will never be + expanded, regardless of whether the variable exists or not. 
serving_container_environment_variables: Optional[Dict[str, str]]=None, - The environment variables that are to be present in the container. - Should be a dictionary where keys are environment variable names - and values are environment variable values for those names. - serving_container_ports: Optional[Sequence[int]]=None, - Declaration of ports that are exposed by the container. This field is - primarily informational, it gives Vertex AI information about the - network connections the container uses. Listing or not a port here has - no impact on whether the port is actually exposed, any port listening on - the default "0.0.0.0" address inside a container will be accessible from - the network. - serving_container_grpc_ports: Optional[Sequence[int]]=None, - Declaration of ports that are exposed by the container. Vertex AI sends gRPC - prediction requests that it receives to the first port on this list. Vertex - AI also sends liveness and health checks to this port. - If you do not specify this field, gRPC requests to the container will be - disabled. - Vertex AI does not use ports other than the first one listed. This field - corresponds to the `ports` field of the Kubernetes Containers v1 core API. - serving_container_deployment_timeout (int): - Optional. Deployment timeout in seconds. - serving_container_shared_memory_size_mb (int): - Optional. The amount of the VM memory to reserve as the shared - memory for the model in megabytes. - serving_container_startup_probe_exec (Sequence[str]): - Optional. Exec specifies the action to take. Used by startup - probe. An example of this argument would be - ["cat", "/tmp/healthy"] - serving_container_startup_probe_period_seconds (int): - Optional. How often (in seconds) to perform the startup probe. - Default to 10 seconds. Minimum value is 1. - serving_container_startup_probe_timeout_seconds (int): - Optional. Number of seconds after which the startup probe times - out. Defaults to 1 second. Minimum value is 1. - serving_container_health_probe_exec (Sequence[str]): - Optional. Exec specifies the action to take. Used by health - probe. An example of this argument would be - ["cat", "/tmp/healthy"] - serving_container_health_probe_period_seconds (int): - Optional. How often (in seconds) to perform the health probe. - Default to 10 seconds. Minimum value is 1. - serving_container_health_probe_timeout_seconds (int): - Optional. Number of seconds after which the health probe times - out. Defaults to 1 second. Minimum value is 1. + The environment variables that are to be present in the container. + Should be a dictionary where keys are environment variable names and + values are environment variable values for those names. + serving_container_ports: Optional[Sequence[int]]=None, Declaration of + ports that are exposed by the container. This field is primarily + informational, it gives Vertex AI information about the network + connections the container uses. Listing or not a port here has no + impact on whether the port is actually exposed, any port listening on + the default "0.0.0.0" address inside a container will be accessible + from the network. + serving_container_grpc_ports: Optional[Sequence[int]]=None, Declaration + of ports that are exposed by the container. Vertex AI sends gRPC + prediction requests that it receives to the first port on this list. + Vertex AI also sends liveness and health checks to this port. If you + do not specify this field, gRPC requests to the container will be + disabled. 
Vertex AI does not use ports other than the first one + listed. This field corresponds to the `ports` field of the Kubernetes + Containers v1 core API. + serving_container_deployment_timeout (int): Optional. Deployment timeout + in seconds. + serving_container_shared_memory_size_mb (int): Optional. The amount of + the VM memory to reserve as the shared memory for the model in + megabytes. + serving_container_startup_probe_exec (Sequence[str]): Optional. Exec + specifies the action to take. Used by startup probe. An example of + this argument would be ["cat", "/tmp/healthy"] + serving_container_startup_probe_period_seconds (int): Optional. How + often (in seconds) to perform the startup probe. Default to 10 + seconds. Minimum value is 1. + serving_container_startup_probe_timeout_seconds (int): Optional. Number + of seconds after which the startup probe times out. Defaults to 1 + second. Minimum value is 1. + serving_container_health_probe_exec (Sequence[str]): Optional. Exec + specifies the action to take. Used by health probe. An example of this + argument would be ["cat", "/tmp/healthy"] + serving_container_health_probe_period_seconds (int): Optional. How often + (in seconds) to perform the health probe. Default to 10 seconds. + Minimum value is 1. + serving_container_health_probe_timeout_seconds (int): Optional. Number + of seconds after which the health probe times out. Defaults to 1 + second. Minimum value is 1. Returns: endpoint (aiplatform.Endpoint): Created endpoint. Raises: - ValueError: If ``serving_container_spec`` is specified but ``serving_container_spec.image_uri`` + ValueError: If ``serving_container_spec`` is specified but + ``serving_container_spec.image_uri`` is ``None``, or if ``serving_container_spec`` is specified but other serving container parameters are specified. """ @@ -589,14 +599,19 @@ def deploy( if fast_tryout_enabled: request.deploy_config.fast_tryout_enabled = fast_tryout_enabled + if system_labels: + request.deploy_config.system_labels = system_labels + if serving_container_spec: if not serving_container_spec.image_uri: raise ValueError( - "Serving container image uri is required for the serving container spec." + "Serving container image uri is required for the serving container" + " spec." ) if serving_container_image_uri: raise ValueError( - "Serving container image uri is already set in the serving container spec." + "Serving container image uri is already set in the serving" + " container spec." ) request.model_config.container_spec = serving_container_spec @@ -640,24 +655,65 @@ def deploy( def list_deploy_options( self, - ) -> Sequence[types.PublisherModel.CallToAction.Deploy]: - """Lists the verified deploy options for the model.""" + concise: bool = False, + ) -> Union[str, Sequence[types.PublisherModel.CallToAction.Deploy]]: + """Lists the verified deploy options for the model. + + Args: + concise: If true, returns a human-readable string with container and + machine specs. + + Returns: + A list of deploy options or a concise formatted string. 
+ """ request = types.GetPublisherModelRequest( name=self._publisher_model_name, is_hugging_face_model=bool(self._is_hugging_face_model), include_equivalent_model_garden_model_deployment_configs=True, ) response = self._us_central1_model_garden_client.get_publisher_model(request) - multi_deploy = ( + deploy_options = ( response.supported_actions.multi_deploy_vertex.multi_deploy_vertex ) - if not multi_deploy: + + if not deploy_options: raise ValueError( - "Model does not support deployment, please use a deploy-able model" - " instead. You can use the list_deployable_models() method" - " to find out which ones currently support deployment." + "Model does not support deployment. " + "Use `list_deployable_models()` to find supported models." + ) + + if not concise: + return deploy_options + + def _extract_config(option): + container = ( + option.container_spec.image_uri if option.container_spec else None + ) + machine = ( + option.dedicated_resources.machine_spec + if option.dedicated_resources + else None + ) + + return { + "serving_container_image_uri": container, + "machine_type": getattr(machine, "machine_type", None), + "accelerator_type": getattr( + getattr(machine, "accelerator_type", None), "name", None + ), + "accelerator_count": getattr(machine, "accelerator_count", None), + } + + concise_deploy_options = [_extract_config(opt) for opt in deploy_options] + return "\n\n".join( + f"[Option {i + 1}]\n" + + "\n".join( + f' {k}="{v}",' if k != "accelerator_count" else f" {k}={v}," + for k, v in config.items() + if v is not None ) - return multi_deploy + for i, config in enumerate(concise_deploy_options) + ) def batch_predict( self, @@ -715,3 +771,44 @@ def batch_predict( starting_replica_count=starting_replica_count, max_replica_count=max_replica_count, ) + + def check_license_agreement_status(self) -> bool: + """Check whether the project has accepted the license agreement of the model. + + EULA (End User License Agreement) is a legal document that the user must + accept before using the model. For Models having license restrictions, + the user must accept the EULA before using the model. You can check the + details of the License in Model Garden. + + Returns: + bool : True if the project has accepted the End User License + Agreement, False otherwise. + """ + request = types.CheckPublisherModelEulaAcceptanceRequest( + parent=f"projects/{self._project}", + publisher_model=self._publisher_model_name, + ) + response = self._model_garden_client.check_publisher_model_eula_acceptance( + request + ) + return response.publisher_model_eula_acked + + def accept_model_license_agreement( + self, + ) -> types.model_garden_service.PublisherModelEulaAcceptance: + """Accepts the EULA(End User License Agreement) of the model for the project. + + For Models having license restrictions, the user must accept the EULA + before using the model. Calling this method will mark the EULA as accepted + for the project. + + Returns: + types.model_garden_service.PublisherModelEulaAcceptance: + The response of the accept_eula call, containing project number, + model name and acceptance status. 
+ """ + request = types.AcceptPublisherModelEulaRequest( + parent=f"projects/{self._project}", + publisher_model=self._publisher_model_name, + ) + return self._model_garden_client.accept_publisher_model_eula(request) diff --git a/vertexai/preview/reasoning_engines/templates/adk.py b/vertexai/preview/reasoning_engines/templates/adk.py index 435cafa6f4..3ef0d10a53 100644 --- a/vertexai/preview/reasoning_engines/templates/adk.py +++ b/vertexai/preview/reasoning_engines/templates/adk.py @@ -13,9 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union + if TYPE_CHECKING: + try: + from google.genai import types + + ContentDict = types.Content + except (ImportError, AttributeError): + ContentDict = Dict + try: from google.adk.events.event import Event @@ -397,6 +405,11 @@ def set_up(self): ) for key, value in self._tmpl_attrs.get("env_vars").items(): os.environ[key] = value + if "GOOGLE_CLOUD_AGENT_ENGINE_ID" in os.environ: + self._tmpl_attrs["app_name"] = os.environ.get( + "GOOGLE_CLOUD_AGENT_ENGINE_ID", + self._tmpl_attrs.get("app_name"), + ) artifact_service_builder = self._tmpl_attrs.get("artifact_service_builder") if artifact_service_builder: @@ -416,10 +429,6 @@ def set_up(self): project=project, location=location, ) - self._tmpl_attrs["app_name"] = os.environ.get( - "GOOGLE_CLOUD_AGENT_ENGINE_ID", - self._tmpl_attrs.get("app_name"), - ) else: self._tmpl_attrs["session_service"] = InMemorySessionService() @@ -441,7 +450,7 @@ def set_up(self): def stream_query( self, *, - message: str, + message: Union[str, "ContentDict"], user_id: str, session_id: Optional[str] = None, **kwargs, @@ -465,7 +474,15 @@ def stream_query( """ from google.genai import types - content = types.Content(role="user", parts=[types.Part(text=message)]) + if isinstance(message, Dict): + content = types.Content.model_validate(message) + elif isinstance(message, str): + content = types.Content(role="user", parts=[types.Part(text=message)]) + else: + raise TypeError( + "message must be a string or a dictionary representing a Content object." + ) + if not self._tmpl_attrs.get("runner"): self.set_up() if not session_id:
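
The block of new `deploy()` arguments documented in this diff (machine shape, autoscaling bounds, reservation affinity, `use_dedicated_endpoint`, `fast_tryout_enabled`, `system_labels`) is easiest to read as a single call. A minimal sketch, assuming the Model Garden `OpenModel` class is exposed as `vertexai.model_garden.OpenModel` in this release; the project, region, and model ID are hypothetical:

```python
import vertexai
from vertexai import model_garden  # assumed import path for OpenModel

# Hypothetical project, region, and Model Garden model ID.
vertexai.init(project="my-project", location="us-central1")
model = model_garden.OpenModel("google/gemma-2-2b-it")

endpoint = model.deploy(
    machine_type="n1-standard-8",        # omit to fall back to automatic resources
    accelerator_type="NVIDIA_TESLA_T4",  # must be paired with accelerator_count
    accelerator_count=1,
    min_replica_count=1,
    max_replica_count=2,
    use_dedicated_endpoint=True,         # predictions go through the dedicated DNS
    fast_tryout_enabled=False,           # True only for quick, non-production trials
    system_labels={"origin": "my-pipeline"},  # Google-managed tracking labels
)
```

The same call accepts the serving-container overrides listed in the docstring when a custom container is needed; `serving_container_spec` and `serving_container_image_uri` remain mutually exclusive, as enforced by the validation added in this change.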
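The `concise` flag on `list_deploy_options()` and the two EULA helpers added here are natural pre-deployment checks. A sketch under the same assumption about `OpenModel` (the model ID is again hypothetical):

```python
from vertexai import model_garden  # assumed import path for OpenModel

model = model_garden.OpenModel("meta/llama-3.1-8b-instruct")  # hypothetical ID

# concise=True returns the human-readable "[Option N]" string built in this diff;
# the default (concise=False) still returns the raw Deploy objects.
print(model.list_deploy_options(concise=True))

# For license-restricted models, accept the EULA for this project once
# before calling deploy().
if not model.check_license_agreement_status():
    model.accept_model_license_agreement()
```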
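Finally, the `adk.py` change lets `stream_query()` take either a plain string or a dict that `google.genai` can validate into a `Content` object. A sketch, assuming the template is consumed through `vertexai.preview.reasoning_engines.AdkApp` and that a simple `google.adk` agent is available; the agent and model names are illustrative only:

```python
from google.adk.agents import Agent  # assumed ADK agent class
from vertexai.preview import reasoning_engines

root_agent = Agent(
    name="assistant",
    model="gemini-2.0-flash",  # illustrative model name
    instruction="Answer briefly.",
)
app = reasoning_engines.AdkApp(agent=root_agent)

# String form (unchanged): wrapped into a user Content with a single text part.
for event in app.stream_query(message="What is Vertex AI?", user_id="user-1"):
    print(event)

# Dict form (new): validated via types.Content.model_validate, so the role and
# multiple parts can be spelled out explicitly.
message = {"role": "user", "parts": [{"text": "Summarize that in one sentence."}]}
for event in app.stream_query(message=message, user_id="user-1"):
    print(event)
```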