- As of January 1, 2020 this library no longer supports Python 2 on the latest released version.
+
+ As of January 1, 2020 this library no longer supports Python 2 on the latest released version.
Library versions released prior to that date will continue to be available. For more information please
visit
Python 2 support on Google Cloud.
diff --git a/docs/summary_overview.md b/docs/summary_overview.md
index 7d0ca2569b..38025e91ac 100644
--- a/docs/summary_overview.md
+++ b/docs/summary_overview.md
@@ -1,14 +1,22 @@
-# Vertex AI API
+[
+This is a templated file. Adding content to this file may result in it being
+reverted. Instead, if you want to place additional content, create an
+"overview_content.md" file in `docs/` directory. The Sphinx tool will
+pick up on the content and merge the content.
+]: #
-Overview of the APIs available for Vertex AI API.
+# AI Platform API
+
+Overview of the APIs available for AI Platform API.
## All entries
-Classes, methods and properties & attributes for Vertex AI API.
+Classes, methods and properties & attributes for
+AI Platform API.
[classes](https://cloud.google.com/python/docs/reference/aiplatform/latest/summary_class.html)
[methods](https://cloud.google.com/python/docs/reference/aiplatform/latest/summary_method.html)
[properties and
-attributes](https://cloud.google.com/python/docs/reference/aiplatform/latest/summary_property.html)
\ No newline at end of file
+attributes](https://cloud.google.com/python/docs/reference/aiplatform/latest/summary_property.html)
diff --git a/google/cloud/aiplatform/gapic_version.py b/google/cloud/aiplatform/gapic_version.py
index dbb6c23a4a..4cfdd5d77d 100644
--- a/google/cloud/aiplatform/gapic_version.py
+++ b/google/cloud/aiplatform/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.91.0" # {x-release-please-version}
+__version__ = "1.92.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/models.py b/google/cloud/aiplatform/models.py
index fd079a2cd0..0e2910acc9 100644
--- a/google/cloud/aiplatform/models.py
+++ b/google/cloud/aiplatform/models.py
@@ -694,8 +694,6 @@ def __init__(
self._gca_resource = gca_endpoint_compat.Endpoint(name=endpoint_name)
self.authorized_session = None
- self.raw_predict_request_url = None
- self.stream_raw_predict_request_url = None
@property
def _prediction_client(self) -> utils.PredictionClientWithOverride:
@@ -783,7 +781,30 @@ def private_service_connect_config(
) -> Optional[gca_service_networking.PrivateServiceConnectConfig]:
"""The Private Service Connect configuration for this Endpoint."""
self._assert_gca_resource_is_available()
- return self._gca_resource.private_service_connect_config
+ return getattr(self._gca_resource, "private_service_connect_config", None)
+
+ @property
+ def dedicated_endpoint_dns(self) -> Optional[str]:
+ """The dedicated endpoint dns for this Endpoint.
+
+ This property is only available if dedicated endpoint is enabled.
+ If dedicated endpoint is not enabled, this property returns None.
+ """
+ if re.match(r"^projects/.*/endpoints/.*$", self._gca_resource.name):
+ self._assert_gca_resource_is_available()
+ return getattr(self._gca_resource, "dedicated_endpoint_dns", None)
+ return None
+
+ @property
+ def dedicated_endpoint_enabled(self) -> bool:
+ """The dedicated endpoint is enabled for this Endpoint.
+
+ This property will be true if dedicated endpoint is enabled.
+ """
+ if re.match(r"^projects/.*/endpoints/.*$", self._gca_resource.name):
+ self._assert_gca_resource_is_available()
+ return getattr(self._gca_resource, "dedicated_endpoint_enabled", False)
+ return False
@classmethod
def create(
@@ -1113,8 +1134,6 @@ def _construct_sdk_resource_from_gapic(
credentials=credentials,
)
endpoint.authorized_session = None
- endpoint.raw_predict_request_url = None
- endpoint.stream_raw_predict_request_url = None
return endpoint
@@ -2338,10 +2357,9 @@ def predict(
) -> Prediction:
"""Make a prediction against this Endpoint.
- For dedicated endpoint, set use_dedicated_endpoint = True:
+ Example usage:
```
- response = my_endpoint.predict(instances=[...],
- use_dedicated_endpoint=True)
+ response = my_endpoint.predict(instances=[...])
my_predictions = response.predictions
```
@@ -2379,11 +2397,19 @@ def predict(
Raises:
ImportError: If there is an issue importing the `TCPKeepAliveAdapter` package.
+ ValueError: If the dedicated endpoint DNS is empty for dedicated endpoints.
+ ValueError: If the prediction request fails for dedicated endpoints.
"""
self.wait()
+
+ if parameters is not None:
+ data = json.dumps({"instances": instances, "parameters": parameters})
+ else:
+ data = json.dumps({"instances": instances})
+
if use_raw_predict:
raw_predict_response = self.raw_predict(
- body=json.dumps({"instances": instances, "parameters": parameters}),
+ body=data,
headers={"Content-Type": "application/json"},
use_dedicated_endpoint=use_dedicated_endpoint,
timeout=timeout,
@@ -2403,63 +2429,7 @@ def predict(
),
)
- if use_dedicated_endpoint:
- self._sync_gca_resource_if_skipped()
- if (
- not self._gca_resource.dedicated_endpoint_enabled
- or self._gca_resource.dedicated_endpoint_dns is None
- ):
- raise ValueError(
- "Dedicated endpoint is not enabled or DNS is empty."
- "Please make sure endpoint has dedicated endpoint enabled"
- "and model are ready before making a prediction."
- )
- try:
- from requests_toolbelt.adapters.socket_options import (
- TCPKeepAliveAdapter,
- )
- except ImportError:
- raise ImportError(
- "Cannot import the requests-toolbelt library. Please install requests-toolbelt."
- )
-
- if not self.authorized_session:
- self.credentials._scopes = constants.base.DEFAULT_AUTHED_SCOPES
- self.authorized_session = google_auth_requests.AuthorizedSession(
- self.credentials
- )
-
- headers = {
- "Content-Type": "application/json",
- }
-
- url = f"https://{self._gca_resource.dedicated_endpoint_dns}/v1/{self.resource_name}:predict"
- # count * interval need to be larger than 1 hr (3600s)
- keep_alive = TCPKeepAliveAdapter(idle=120, count=100, interval=100)
- self.authorized_session.mount("https://", keep_alive)
- response = self.authorized_session.post(
- url=url,
- data=json.dumps(
- {
- "instances": instances,
- "parameters": parameters,
- }
- ),
- headers=headers,
- timeout=timeout,
- )
-
- prediction_response = json.loads(response.text)
-
- return Prediction(
- predictions=prediction_response.get("predictions"),
- metadata=prediction_response.get("metadata"),
- deployed_model_id=prediction_response.get("deployedModelId"),
- model_resource_name=prediction_response.get("model"),
- model_version_id=prediction_response.get("modelVersionId"),
- )
-
- else:
+ if not self.dedicated_endpoint_enabled:
prediction_response = self._prediction_client.predict(
endpoint=self._gca_resource.name,
instances=instances,
@@ -2482,6 +2452,58 @@ def predict(
model_resource_name=prediction_response.model,
)
+ if self.dedicated_endpoint_dns is None:
+ raise ValueError(
+ "Dedicated endpoint DNS is empty. Please make sure endpoint"
+ "and model are ready before making a prediction."
+ )
+
+ if not self.authorized_session:
+ self.credentials._scopes = constants.base.DEFAULT_AUTHED_SCOPES
+ self.authorized_session = google_auth_requests.AuthorizedSession(
+ self.credentials
+ )
+
+ if timeout is not None and timeout > google_auth_requests._DEFAULT_TIMEOUT:
+ try:
+ from requests_toolbelt.adapters.socket_options import (
+ TCPKeepAliveAdapter,
+ )
+ except ImportError:
+ raise ImportError(
+ "Cannot import the requests-toolbelt library."
+ "Please install requests-toolbelt."
+ )
+ # count * interval needs to be larger than 1 hr (3600s)
+ keep_alive = TCPKeepAliveAdapter(idle=120, count=100, interval=100)
+ self.authorized_session.mount("https://", keep_alive)
+
+ url = f"https://{self.dedicated_endpoint_dns}/v1/{self.resource_name}:predict"
+ headers = {
+ "Content-Type": "application/json",
+ }
+ response = self.authorized_session.post(
+ url=url,
+ data=data,
+ headers=headers,
+ timeout=timeout,
+ )
+
+ if response.status_code != 200:
+ raise ValueError(
+ f"Failed to make prediction request. Status code:"
+ f"{response.status_code}, response: {response.text}."
+ )
+ prediction_response = json.loads(response.text)
+
+ return Prediction(
+ predictions=prediction_response.get("predictions"),
+ metadata=prediction_response.get("metadata"),
+ deployed_model_id=prediction_response.get("deployedModelId"),
+ model_resource_name=prediction_response.get("model"),
+ model_version_id=prediction_response.get("modelVersionId"),
+ )
+
async def predict_async(
self,
instances: List,
@@ -2562,12 +2584,6 @@ def raw_predict(
body = b'{"instances":[{"feat_1":val_1, "feat_2":val_2}]}'
headers = {'Content-Type':'application/json'}
)
- # For dedicated endpoint:
- response = my_endpoint.raw_predict(
- body = b'{"instances":[{"feat_1":val_1, "feat_2":val_2}]}',
- headers = {'Content-Type':'application/json'},
- dedicated_endpoint=True,
- )
status_code = response.status_code
results = json.dumps(response.text)
@@ -2593,34 +2609,29 @@ def raw_predict(
self.credentials
)
- if self.raw_predict_request_url is None:
- self.raw_predict_request_url = f"https://{self.location}-{constants.base.API_BASE_PATH}/v1/projects/{self.project}/locations/{self.location}/endpoints/{self.name}:rawPredict"
-
- url = self.raw_predict_request_url
-
- if use_dedicated_endpoint:
- try:
- from requests_toolbelt.adapters.socket_options import (
- TCPKeepAliveAdapter,
- )
- except ImportError:
- raise ImportError(
- "Cannot import the requests-toolbelt library. Please install requests-toolbelt."
- )
- self._sync_gca_resource_if_skipped()
- if (
- not self._gca_resource.dedicated_endpoint_enabled
- or self._gca_resource.dedicated_endpoint_dns is None
- ):
+ if self.dedicated_endpoint_enabled:
+ if self.dedicated_endpoint_dns is None:
raise ValueError(
- "Dedicated endpoint is not enabled or DNS is empty."
- "Please make sure endpoint has dedicated endpoint enabled"
+ "Dedicated endpoint DNS is empty. Please make sure endpoint"
"and model are ready before making a prediction."
)
- url = f"https://{self._gca_resource.dedicated_endpoint_dns}/v1/{self.resource_name}:rawPredict"
- # count * interval need to be larger than 1 hr (3600s)
- keep_alive = TCPKeepAliveAdapter(idle=120, count=100, interval=100)
- self.authorized_session.mount("https://", keep_alive)
+ url = f"https://{self.dedicated_endpoint_dns}/v1/{self.resource_name}:rawPredict"
+
+ if timeout is not None and timeout > google_auth_requests._DEFAULT_TIMEOUT:
+ try:
+ from requests_toolbelt.adapters.socket_options import (
+ TCPKeepAliveAdapter,
+ )
+ except ImportError:
+ raise ImportError(
+ "Cannot import the requests-toolbelt library."
+ "Please install requests-toolbelt."
+ )
+ # count * interval needs to be larger than 1 hr (3600s)
+ keep_alive = TCPKeepAliveAdapter(idle=120, count=100, interval=100)
+ self.authorized_session.mount("https://", keep_alive)
+ else:
+ url = f"https://{self.location}-{constants.base.API_BASE_PATH}/v1/projects/{self.project}/locations/{self.location}/endpoints/{self.name}:rawPredict"
return self.authorized_session.post(
url=url, data=body, headers=headers, timeout=timeout
@@ -2648,18 +2659,6 @@ def stream_raw_predict(
stream_result = json.dumps(response.text)
```
- For dedicated endpoint:
- ```
- my_endpoint = aiplatform.Endpoint(ENDPOINT_ID)
- for stream_response in my_endpoint.stream_raw_predict(
- body = b'{"instances":[{"feat_1":val_1, "feat_2":val_2}]}',
- headers = {'Content-Type':'application/json'},
- use_dedicated_endpoint=True,
- ):
- status_code = response.status_code
- stream_result = json.dumps(response.text)
- ```
-
Args:
body (bytes):
The body of the prediction request in bytes. This must not
@@ -2682,23 +2681,15 @@ def stream_raw_predict(
self.credentials
)
- if self.stream_raw_predict_request_url is None:
- self.stream_raw_predict_request_url = f"https://{self.location}-{constants.base.API_BASE_PATH}/v1/projects/{self.project}/locations/{self.location}/endpoints/{self.name}:streamRawPredict"
-
- url = self.stream_raw_predict_request_url
-
- if use_dedicated_endpoint:
- self._sync_gca_resource_if_skipped()
- if (
- not self._gca_resource.dedicated_endpoint_enabled
- or self._gca_resource.dedicated_endpoint_dns is None
- ):
+ if self.dedicated_endpoint_enabled:
+ if self.dedicated_endpoint_dns is None:
raise ValueError(
- "Dedicated endpoint is not enabled or DNS is empty."
- "Please make sure endpoint has dedicated endpoint enabled"
+ "Dedicated endpoint DNS is empty. Please make sure endpoint"
"and model are ready before making a prediction."
)
- url = f"https://{self._gca_resource.dedicated_endpoint_dns}/v1/{self.resource_name}:streamRawPredict"
+ url = f"https://{self.dedicated_endpoint_dns}/v1/{self.resource_name}:streamRawPredict"
+ else:
+ url = f"https://{self.location}-{constants.base.API_BASE_PATH}/v1/projects/{self.project}/locations/{self.location}/endpoints/{self.name}:streamRawPredict"
with self.authorized_session.post(
url=url,
@@ -3332,6 +3323,7 @@ def __init__(
)
self._http_client = urllib3.PoolManager(cert_reqs="CERT_NONE")
+ self._authorized_session = None
@property
def predict_http_uri(self) -> Optional[str]:
@@ -3604,6 +3596,7 @@ def _construct_sdk_resource_from_gapic(
)
endpoint._http_client = urllib3.PoolManager(cert_reqs="CERT_NONE")
+ endpoint._authorized_session = None
return endpoint
@@ -3776,24 +3769,31 @@ def predict(
"address or DNS."
)
- if not self.credentials.valid:
- self.credentials.refresh(google_auth_requests.Request())
+ if not self._authorized_session:
+ self.credentials._scopes = constants.base.DEFAULT_AUTHED_SCOPES
+ self._authorized_session = google_auth_requests.AuthorizedSession(
+ self.credentials,
+ )
+ self._authorized_session.verify = False
- token = self.credentials.token
- headers = {
- "Authorization": f"Bearer {token}",
- "Content-Type": "application/json",
- }
+ if parameters:
+ data = json.dumps({"instances": instances, "parameters": parameters})
+ else:
+ data = json.dumps({"instances": instances})
url = f"https://{endpoint_override}/v1/projects/{self.project}/locations/{self.location}/endpoints/{self.name}:predict"
- response = self._http_request(
- method="POST",
+ response = self._authorized_session.post(
url=url,
- body=json.dumps({"instances": instances, "parameters": parameters}),
- headers=headers,
+ data=data,
+ headers={"Content-Type": "application/json"},
)
- prediction_response = json.loads(response.data)
+ if response.status_code != 200:
+ raise ValueError(
+ f"Failed to make prediction request. Status code:"
+ f"{response.status_code}, response: {response.text}."
+ )
+ prediction_response = json.loads(response.text)
return Prediction(
predictions=prediction_response.get("predictions"),
@@ -3873,19 +3873,19 @@ def raw_predict(
"Invalid endpoint override provided. Please only use IP"
"address or DNS."
)
- if not self.credentials.valid:
- self.credentials.refresh(google_auth_requests.Request())
- token = self.credentials.token
- headers_with_token = dict(headers)
- headers_with_token["Authorization"] = f"Bearer {token}"
+ if not self._authorized_session:
+ self.credentials._scopes = constants.base.DEFAULT_AUTHED_SCOPES
+ self._authorized_session = google_auth_requests.AuthorizedSession(
+ self.credentials,
+ )
+ self._authorized_session.verify = False
url = f"https://{endpoint_override}/v1/projects/{self.project}/locations/{self.location}/endpoints/{self.name}:rawPredict"
- return self._http_request(
- method="POST",
+ return self._authorized_session.post(
url=url,
- body=body,
+ data=body,
- headers=headers_with_token,
+ headers=headers,
)
def stream_raw_predict(
@@ -3953,24 +3953,19 @@ def stream_raw_predict(
"Invalid endpoint override provided. Please only use IP"
"address or DNS."
)
- if not self.credentials.valid:
- self.credentials.refresh(google_auth_requests.Request())
- token = self.credentials.token
- headers_with_token = dict(headers)
- headers_with_token["Authorization"] = f"Bearer {token}"
-
- if not self.authorized_session:
+ if not self._authorized_session:
self.credentials._scopes = constants.base.DEFAULT_AUTHED_SCOPES
- self.authorized_session = google_auth_requests.AuthorizedSession(
- self.credentials
+ self._authorized_session = google_auth_requests.AuthorizedSession(
+ self.credentials,
)
+ self._authorized_session.verify = False
url = f"https://{endpoint_override}/v1/projects/{self.project}/locations/{self.location}/endpoints/{self.name}:streamRawPredict"
- with self.authorized_session.post(
+ with self._authorized_session.post(
url=url,
data=body,
- headers=headers_with_token,
+ headers=headers,
stream=True,
verify=False,
) as resp:
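
A minimal usage sketch of the dedicated-endpoint behavior introduced in `models.py` above; the project, location, endpoint ID, and instance payload are placeholders, not values from this change:

```python
# Hedged usage sketch for the dedicated-endpoint flow above; the project,
# location, endpoint ID, and instance payload are placeholders.
from google.cloud import aiplatform

aiplatform.init(project="my-project", location="us-central1")
endpoint = aiplatform.Endpoint("1234567890")

# The new properties read directly from the endpoint resource.
if endpoint.dedicated_endpoint_enabled:
    print("Dedicated endpoint DNS:", endpoint.dedicated_endpoint_dns)

# predict() now routes to the dedicated endpoint automatically when it is
# enabled, so no use_dedicated_endpoint flag is required.
response = endpoint.predict(instances=[{"feat_1": 1.0, "feat_2": 2.0}])
print(response.predictions)
```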
diff --git a/google/cloud/aiplatform/persistent_resource.py b/google/cloud/aiplatform/persistent_resource.py
index f0944a5bb4..06c1914676 100644
--- a/google/cloud/aiplatform/persistent_resource.py
+++ b/google/cloud/aiplatform/persistent_resource.py
@@ -171,6 +171,7 @@ def create(
labels: Optional[Dict[str, str]] = None,
network: Optional[str] = None,
kms_key_name: Optional[str] = None,
+ enable_custom_service_account: Optional[bool] = None,
service_account: Optional[str] = None,
reserved_ip_ranges: List[str] = None,
sync: Optional[bool] = True, # pylint: disable=unused-argument
@@ -234,6 +235,16 @@ def create(
PersistentResource. If set, this PersistentResource and all
sub-resources of this PersistentResource will be secured by
this key.
+ enable_custom_service_account (bool):
+ Optional. When set to True, allows the `service_account`
+ parameter to specify a custom service account for workloads on this
+ PersistentResource. Defaults to None.
+
+ If True, the service account provided in the `service_account` parameter
+ will be used for workloads (runtimes, jobs), provided the user has the
+ ``iam.serviceAccounts.actAs`` permission. If False, providing a
+ `service_account` raises a ValueError, and the PersistentResource
+ uses the default service account.
service_account (str):
Optional. Default service account that this
PersistentResource's workloads run as. The workloads
@@ -295,7 +306,31 @@ def create(
gca_encryption_spec_compat.EncryptionSpec(kms_key_name=kms_key_name)
)
- if service_account:
+ # Raise ValueError if enable_custom_service_account is False but
+ # service_account is provided
+ if (
+ enable_custom_service_account is False and service_account is not None
+ ): # pylint: disable=g-bool-id-comparison
+ raise ValueError(
+ "The parameter `enable_custom_service_account` was set to False, "
+ "but a value was provided for `service_account`. These two "
+ "settings are incompatible. If you want to use a custom "
+ "service account, set `enable_custom_service_account` to True."
+ )
+
+ elif enable_custom_service_account:
+ service_account_spec = gca_persistent_resource_compat.ServiceAccountSpec(
+ enable_custom_service_account=True,
+ # Set service_account if it is provided, otherwise set to None
+ service_account=service_account if service_account else None,
+ )
+ gca_persistent_resource.resource_runtime_spec = (
+ gca_persistent_resource_compat.ResourceRuntimeSpec(
+ service_account_spec=service_account_spec
+ )
+ )
+ elif service_account:
+ # Handle the deprecated case where only service_account is provided
service_account_spec = gca_persistent_resource_compat.ServiceAccountSpec(
enable_custom_service_account=True, service_account=service_account
)
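
A hedged sketch of the new `enable_custom_service_account` parameter on `PersistentResource.create()`; the project, service account, and resource pool values are placeholders, and the `ResourcePool` construction is assumed from the existing `aiplatform_v1` types rather than taken from this diff:

```python
# Hedged sketch of enable_custom_service_account; the project, service
# account, and resource pool values are placeholders.
from google.cloud import aiplatform
from google.cloud.aiplatform_v1.types import machine_resources
from google.cloud.aiplatform_v1.types import persistent_resource as gca_pr

aiplatform.init(project="my-project", location="us-central1")

pool = gca_pr.ResourcePool(
    machine_spec=machine_resources.MachineSpec(machine_type="n1-standard-4"),
    replica_count=1,
)

resource = aiplatform.PersistentResource.create(
    persistent_resource_id="my-persistent-resource",
    resource_pools=[pool],
    enable_custom_service_account=True,
    service_account="workload-sa@my-project.iam.gserviceaccount.com",
)

# Passing service_account while enable_custom_service_account=False now
# raises ValueError, per the validation added above.
```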
diff --git a/google/cloud/aiplatform/preview/persistent_resource.py b/google/cloud/aiplatform/preview/persistent_resource.py
index e06b08e90b..54a9d55d3d 100644
--- a/google/cloud/aiplatform/preview/persistent_resource.py
+++ b/google/cloud/aiplatform/preview/persistent_resource.py
@@ -177,6 +177,7 @@ def create(
labels: Optional[Dict[str, str]] = None,
network: Optional[str] = None,
kms_key_name: Optional[str] = None,
+ enable_custom_service_account: Optional[bool] = None,
service_account: Optional[str] = None,
reserved_ip_ranges: List[str] = None,
sync: Optional[bool] = True, # pylint: disable=unused-argument
@@ -240,6 +241,16 @@ def create(
PersistentResource. If set, this PersistentResource and all
sub-resources of this PersistentResource will be secured by
this key.
+ enable_custom_service_account (bool):
+ Optional. When set to True, allows the `service_account`
+ parameter to specify a custom service account for workloads on this
+ PersistentResource. Defaults to None.
+
+ If True, the service account provided in the `service_account` parameter
+ will be used for workloads (runtimes, jobs), provided the user has the
+ ``iam.serviceAccounts.actAs`` permission. If False, providing a
+ `service_account` raises a ValueError, and the PersistentResource
+ uses the default service account.
service_account (str):
Optional. Default service account that this
PersistentResource's workloads run as. The workloads
@@ -301,7 +312,29 @@ def create(
gca_encryption_spec_compat.EncryptionSpec(kms_key_name=kms_key_name)
)
- if service_account:
+ if (
+ enable_custom_service_account is False and service_account is not None
+ ): # pylint: disable=g-bool-id-comparison
+ raise ValueError(
+ "The parameter `enable_custom_service_account` was set to False, "
+ "but a value was provided for `service_account`. These two "
+ "settings are incompatible. If you want to use a custom "
+ "service account, set `enable_custom_service_account` to True."
+ )
+
+ elif enable_custom_service_account:
+ service_account_spec = gca_persistent_resource_compat.ServiceAccountSpec(
+ enable_custom_service_account=True,
+ # Set service_account if it is provided, otherwise set to None
+ service_account=service_account if service_account else None,
+ )
+ gca_persistent_resource.resource_runtime_spec = (
+ gca_persistent_resource_compat.ResourceRuntimeSpec(
+ service_account_spec=service_account_spec
+ )
+ )
+ elif service_account:
+ # Handle the deprecated case where only service_account is provided
service_account_spec = gca_persistent_resource_compat.ServiceAccountSpec(
enable_custom_service_account=True, service_account=service_account
)
diff --git a/google/cloud/aiplatform/releases.txt b/google/cloud/aiplatform/releases.txt
index ad3b915fdf..09a949b7da 100644
--- a/google/cloud/aiplatform/releases.txt
+++ b/google/cloud/aiplatform/releases.txt
@@ -1,4 +1,4 @@
Use this file when you need to force a patch release with release-please.
Edit line 4 below with the version for the release.
-1.55.0
\ No newline at end of file
+1.92.0
\ No newline at end of file
diff --git a/google/cloud/aiplatform/tensorboard/uploader_constants.py b/google/cloud/aiplatform/tensorboard/uploader_constants.py
index 78ce23c614..16e3c3672a 100644
--- a/google/cloud/aiplatform/tensorboard/uploader_constants.py
+++ b/google/cloud/aiplatform/tensorboard/uploader_constants.py
@@ -2,28 +2,24 @@
import dataclasses
-from tensorboard.plugins.distribution import (
- metadata as distribution_metadata,
-)
-from tensorboard.plugins.graph import metadata as graphs_metadata
-from tensorboard.plugins.histogram import (
- metadata as histogram_metadata,
-)
-from tensorboard.plugins.hparams import metadata as hparams_metadata
-from tensorboard.plugins.image import metadata as images_metadata
-from tensorboard.plugins.scalar import metadata as scalar_metadata
-from tensorboard.plugins.text import metadata as text_metadata
+from tensorboard.plugins.distribution import distributions_plugin
+from tensorboard.plugins.graph import graphs_plugin
+from tensorboard.plugins.histogram import histograms_plugin
+from tensorboard.plugins.hparams import hparams_plugin
+from tensorboard.plugins.image import images_plugin
+from tensorboard.plugins.scalar import scalars_plugin
+from tensorboard.plugins.text import text_plugin
PROFILE_PLUGIN_NAME = "profile"
ALLOWED_PLUGINS = frozenset(
[
- scalar_metadata.PLUGIN_NAME,
- histogram_metadata.PLUGIN_NAME,
- distribution_metadata.PLUGIN_NAME,
- text_metadata.PLUGIN_NAME,
- hparams_metadata.PLUGIN_NAME,
- images_metadata.PLUGIN_NAME,
- graphs_metadata.PLUGIN_NAME,
+ scalars_plugin.ScalarsPlugin.plugin_name,
+ histograms_plugin.HistogramsPlugin.plugin_name,
+ distributions_plugin.DistributionsPlugin.plugin_name,
+ text_plugin.TextPlugin.plugin_name,
+ hparams_plugin.HParamsPlugin.plugin_name,
+ images_plugin.ImagesPlugin.plugin_name,
+ graphs_plugin.GraphsPlugin.plugin_name,
PROFILE_PLUGIN_NAME,
]
)
diff --git a/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py
index dbb6c23a4a..4cfdd5d77d 100644
--- a/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py
+++ b/google/cloud/aiplatform/v1/schema/predict/instance/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.91.0" # {x-release-please-version}
+__version__ = "1.92.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py
index dbb6c23a4a..4cfdd5d77d 100644
--- a/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py
+++ b/google/cloud/aiplatform/v1/schema/predict/instance_v1/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.91.0" # {x-release-please-version}
+__version__ = "1.92.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py
index dbb6c23a4a..4cfdd5d77d 100644
--- a/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py
+++ b/google/cloud/aiplatform/v1/schema/predict/params/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.91.0" # {x-release-please-version}
+__version__ = "1.92.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py
index dbb6c23a4a..4cfdd5d77d 100644
--- a/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py
+++ b/google/cloud/aiplatform/v1/schema/predict/params_v1/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.91.0" # {x-release-please-version}
+__version__ = "1.92.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py
index dbb6c23a4a..4cfdd5d77d 100644
--- a/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py
+++ b/google/cloud/aiplatform/v1/schema/predict/prediction/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.91.0" # {x-release-please-version}
+__version__ = "1.92.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py
index dbb6c23a4a..4cfdd5d77d 100644
--- a/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py
+++ b/google/cloud/aiplatform/v1/schema/predict/prediction_v1/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.91.0" # {x-release-please-version}
+__version__ = "1.92.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py
index dbb6c23a4a..4cfdd5d77d 100644
--- a/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py
+++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.91.0" # {x-release-please-version}
+__version__ = "1.92.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py
index dbb6c23a4a..4cfdd5d77d 100644
--- a/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py
+++ b/google/cloud/aiplatform/v1/schema/trainingjob/definition_v1/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.91.0" # {x-release-please-version}
+__version__ = "1.92.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py
index dbb6c23a4a..4cfdd5d77d 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.91.0" # {x-release-please-version}
+__version__ = "1.92.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py
index dbb6c23a4a..4cfdd5d77d 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/instance_v1beta1/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.91.0" # {x-release-please-version}
+__version__ = "1.92.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py
index dbb6c23a4a..4cfdd5d77d 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/params/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.91.0" # {x-release-please-version}
+__version__ = "1.92.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py
index dbb6c23a4a..4cfdd5d77d 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/params_v1beta1/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.91.0" # {x-release-please-version}
+__version__ = "1.92.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py
index dbb6c23a4a..4cfdd5d77d 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.91.0" # {x-release-please-version}
+__version__ = "1.92.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py
index dbb6c23a4a..4cfdd5d77d 100644
--- a/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py
+++ b/google/cloud/aiplatform/v1beta1/schema/predict/prediction_v1beta1/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.91.0" # {x-release-please-version}
+__version__ = "1.92.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py
index dbb6c23a4a..4cfdd5d77d 100644
--- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py
+++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.91.0" # {x-release-please-version}
+__version__ = "1.92.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py
index dbb6c23a4a..4cfdd5d77d 100644
--- a/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py
+++ b/google/cloud/aiplatform/v1beta1/schema/trainingjob/definition_v1beta1/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.91.0" # {x-release-please-version}
+__version__ = "1.92.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform/version.py b/google/cloud/aiplatform/version.py
index fedf961642..2982cf33c8 100644
--- a/google/cloud/aiplatform/version.py
+++ b/google/cloud/aiplatform/version.py
@@ -15,4 +15,4 @@
# limitations under the License.
#
-__version__ = "1.91.0"
+__version__ = "1.92.0"
diff --git a/google/cloud/aiplatform_v1/__init__.py b/google/cloud/aiplatform_v1/__init__.py
index a4287d4f6c..71170aee11 100644
--- a/google/cloud/aiplatform_v1/__init__.py
+++ b/google/cloud/aiplatform_v1/__init__.py
@@ -652,6 +652,7 @@
from .types.migration_service import MigrateResourceResponse
from .types.migration_service import SearchMigratableResourcesRequest
from .types.migration_service import SearchMigratableResourcesResponse
+from .types.model import Checkpoint
from .types.model import GenieSource
from .types.model import LargeModelReference
from .types.model import Model
@@ -1134,6 +1135,7 @@
"CheckTrialEarlyStoppingStateMetatdata",
"CheckTrialEarlyStoppingStateRequest",
"CheckTrialEarlyStoppingStateResponse",
+ "Checkpoint",
"Citation",
"CitationMetadata",
"Claim",
diff --git a/google/cloud/aiplatform_v1/gapic_version.py b/google/cloud/aiplatform_v1/gapic_version.py
index dbb6c23a4a..4cfdd5d77d 100644
--- a/google/cloud/aiplatform_v1/gapic_version.py
+++ b/google/cloud/aiplatform_v1/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.91.0" # {x-release-please-version}
+__version__ = "1.92.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform_v1/services/gen_ai_cache_service/async_client.py b/google/cloud/aiplatform_v1/services/gen_ai_cache_service/async_client.py
index 16267080e8..c229cffbbc 100644
--- a/google/cloud/aiplatform_v1/services/gen_ai_cache_service/async_client.py
+++ b/google/cloud/aiplatform_v1/services/gen_ai_cache_service/async_client.py
@@ -48,6 +48,7 @@
from google.cloud.aiplatform_v1.types import cached_content
from google.cloud.aiplatform_v1.types import cached_content as gca_cached_content
from google.cloud.aiplatform_v1.types import content
+from google.cloud.aiplatform_v1.types import encryption_spec
from google.cloud.aiplatform_v1.types import gen_ai_cache_service
from google.cloud.aiplatform_v1.types import tool
from google.cloud.location import locations_pb2 # type: ignore
diff --git a/google/cloud/aiplatform_v1/services/gen_ai_cache_service/client.py b/google/cloud/aiplatform_v1/services/gen_ai_cache_service/client.py
index 82ba257ad2..b327f67bd9 100644
--- a/google/cloud/aiplatform_v1/services/gen_ai_cache_service/client.py
+++ b/google/cloud/aiplatform_v1/services/gen_ai_cache_service/client.py
@@ -64,6 +64,7 @@
from google.cloud.aiplatform_v1.types import cached_content
from google.cloud.aiplatform_v1.types import cached_content as gca_cached_content
from google.cloud.aiplatform_v1.types import content
+from google.cloud.aiplatform_v1.types import encryption_spec
from google.cloud.aiplatform_v1.types import gen_ai_cache_service
from google.cloud.aiplatform_v1.types import tool
from google.cloud.location import locations_pb2 # type: ignore
diff --git a/google/cloud/aiplatform_v1/services/migration_service/client.py b/google/cloud/aiplatform_v1/services/migration_service/client.py
index 886fa6bf79..da85ad6aed 100644
--- a/google/cloud/aiplatform_v1/services/migration_service/client.py
+++ b/google/cloud/aiplatform_v1/services/migration_service/client.py
@@ -242,62 +242,57 @@ def parse_annotated_dataset_path(path: str) -> Dict[str, str]:
@staticmethod
def dataset_path(
project: str,
- location: str,
dataset: str,
) -> str:
"""Returns a fully-qualified dataset string."""
- return "projects/{project}/locations/{location}/datasets/{dataset}".format(
+ return "projects/{project}/datasets/{dataset}".format(
project=project,
- location=location,
dataset=dataset,
)
@staticmethod
def parse_dataset_path(path: str) -> Dict[str, str]:
"""Parses a dataset path into its component segments."""
- m = re.match(
- r"^projects/(?P
.+?)/locations/(?P.+?)/datasets/(?P.+?)$",
- path,
- )
+ m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path)
return m.groupdict() if m else {}
@staticmethod
def dataset_path(
project: str,
+ location: str,
dataset: str,
) -> str:
"""Returns a fully-qualified dataset string."""
- return "projects/{project}/datasets/{dataset}".format(
+ return "projects/{project}/locations/{location}/datasets/{dataset}".format(
project=project,
+ location=location,
dataset=dataset,
)
@staticmethod
def parse_dataset_path(path: str) -> Dict[str, str]:
"""Parses a dataset path into its component segments."""
- m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path)
+ m = re.match(
+ r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$",
+ path,
+ )
return m.groupdict() if m else {}
@staticmethod
def dataset_path(
project: str,
- location: str,
dataset: str,
) -> str:
"""Returns a fully-qualified dataset string."""
- return "projects/{project}/locations/{location}/datasets/{dataset}".format(
+ return "projects/{project}/datasets/{dataset}".format(
project=project,
- location=location,
dataset=dataset,
)
@staticmethod
def parse_dataset_path(path: str) -> Dict[str, str]:
"""Parses a dataset path into its component segments."""
- m = re.match(
- r"^projects/(?P.+?)/locations/(?P.+?)/datasets/(?P.+?)$",
- path,
- )
+ m = re.match(r"^projects/(?P.+?)/datasets/(?P.+?)$", path)
return m.groupdict() if m else {}
@staticmethod
diff --git a/google/cloud/aiplatform_v1/types/__init__.py b/google/cloud/aiplatform_v1/types/__init__.py
index cfffa3544d..48888a1e70 100644
--- a/google/cloud/aiplatform_v1/types/__init__.py
+++ b/google/cloud/aiplatform_v1/types/__init__.py
@@ -689,6 +689,7 @@
SearchMigratableResourcesResponse,
)
from .model import (
+ Checkpoint,
GenieSource,
LargeModelReference,
Model,
@@ -1713,6 +1714,7 @@
"MigrateResourceResponse",
"SearchMigratableResourcesRequest",
"SearchMigratableResourcesResponse",
+ "Checkpoint",
"GenieSource",
"LargeModelReference",
"Model",
diff --git a/google/cloud/aiplatform_v1/types/cached_content.py b/google/cloud/aiplatform_v1/types/cached_content.py
index f4e0711a8d..f3fb112609 100644
--- a/google/cloud/aiplatform_v1/types/cached_content.py
+++ b/google/cloud/aiplatform_v1/types/cached_content.py
@@ -20,6 +20,7 @@
import proto # type: ignore
from google.cloud.aiplatform_v1.types import content
+from google.cloud.aiplatform_v1.types import encryption_spec as gca_encryption_spec
from google.cloud.aiplatform_v1.types import tool
from google.protobuf import duration_pb2 # type: ignore
from google.protobuf import timestamp_pb2 # type: ignore
@@ -64,10 +65,10 @@ class CachedContent(proto.Message):
Optional. Immutable. The user-generated
meaningful display name of the cached content.
model (str):
- Immutable. The name of the publisher model to
- use for cached content. Format:
-
- projects/{project}/locations/{location}/publishers/{publisher}/models/{model}
+ Immutable. The name of the ``Model`` to use for cached
+ content. Currently, only the published Gemini base models
+ are supported, in the form of
+ projects/{PROJECT}/locations/{LOCATION}/publishers/google/models/{MODEL}
system_instruction (google.cloud.aiplatform_v1.types.Content):
Optional. Input only. Immutable. Developer
set system instruction. Currently, text only
@@ -81,7 +82,7 @@ class CachedContent(proto.Message):
Optional. Input only. Immutable. Tool config.
This config is shared for all tools
create_time (google.protobuf.timestamp_pb2.Timestamp):
- Output only. Creatation time of the cache
+ Output only. Creation time of the cache
entry.
update_time (google.protobuf.timestamp_pb2.Timestamp):
Output only. When the cache entry was last
@@ -89,6 +90,10 @@ class CachedContent(proto.Message):
usage_metadata (google.cloud.aiplatform_v1.types.CachedContent.UsageMetadata):
Output only. Metadata on the usage of the
cached content.
+ encryption_spec (google.cloud.aiplatform_v1.types.EncryptionSpec):
+ Input only. Immutable. Customer-managed encryption key spec
+ for a ``CachedContent``. If set, this ``CachedContent`` and
+ all its sub-resources will be secured by this key.
"""
class UsageMetadata(proto.Message):
@@ -188,6 +193,11 @@ class UsageMetadata(proto.Message):
number=12,
message=UsageMetadata,
)
+ encryption_spec: gca_encryption_spec.EncryptionSpec = proto.Field(
+ proto.MESSAGE,
+ number=13,
+ message=gca_encryption_spec.EncryptionSpec,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
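
An illustrative sketch of the new `CachedContent.encryption_spec` field; the model and KMS key names below are placeholder formats only:

```python
# Illustrative sketch of CachedContent.encryption_spec; the model and KMS
# key resource names are placeholder formats.
from google.cloud.aiplatform_v1.types import cached_content, encryption_spec

cc = cached_content.CachedContent(
    model="projects/{PROJECT}/locations/{LOCATION}/publishers/google/models/{MODEL}",
    encryption_spec=encryption_spec.EncryptionSpec(
        kms_key_name="projects/{PROJECT}/locations/{LOCATION}/keyRings/{RING}/cryptoKeys/{KEY}",
    ),
)
print(cc.encryption_spec.kms_key_name)
```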
diff --git a/google/cloud/aiplatform_v1/types/content.py b/google/cloud/aiplatform_v1/types/content.py
index d94dbf0c1b..76c2a8ff61 100644
--- a/google/cloud/aiplatform_v1/types/content.py
+++ b/google/cloud/aiplatform_v1/types/content.py
@@ -70,7 +70,8 @@ class HarmCategory(proto.Enum):
The harm category is sexually explicit
content.
HARM_CATEGORY_CIVIC_INTEGRITY (5):
- The harm category is civic integrity.
+ Deprecated: Election filter is no longer
+ supported. The harm category is civic integrity.
"""
HARM_CATEGORY_UNSPECIFIED = 0
HARM_CATEGORY_HATE_SPEECH = 1
@@ -393,6 +394,10 @@ class GenerationConfig(proto.Message):
Optional. Routing configuration.
This field is a member of `oneof`_ ``_routing_config``.
+ thinking_config (google.cloud.aiplatform_v1.types.GenerationConfig.ThinkingConfig):
+ Optional. Config for thinking features.
+ An error will be returned if this field is set
+ for models that don't support thinking.
"""
class RoutingConfig(proto.Message):
@@ -491,6 +496,25 @@ class ManualRoutingMode(proto.Message):
message="GenerationConfig.RoutingConfig.ManualRoutingMode",
)
+ class ThinkingConfig(proto.Message):
+ r"""Config for thinking features.
+
+ .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields
+
+ Attributes:
+ thinking_budget (int):
+ Optional. Indicates the thinking budget in tokens. This is
+ only applied when enable_thinking is true.
+
+ This field is a member of `oneof`_ ``_thinking_budget``.
+ """
+
+ thinking_budget: int = proto.Field(
+ proto.INT32,
+ number=3,
+ optional=True,
+ )
+
temperature: float = proto.Field(
proto.FLOAT,
number=1,
@@ -561,6 +585,11 @@ class ManualRoutingMode(proto.Message):
optional=True,
message=RoutingConfig,
)
+ thinking_config: ThinkingConfig = proto.Field(
+ proto.MESSAGE,
+ number=25,
+ message=ThinkingConfig,
+ )
class SafetySetting(proto.Message):
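
An illustrative sketch of the new `GenerationConfig.ThinkingConfig` message; the temperature and token budget are placeholder values:

```python
# Illustrative sketch of GenerationConfig.ThinkingConfig; the temperature
# and thinking budget values are placeholders.
from google.cloud.aiplatform_v1.types import content as gca_content

generation_config = gca_content.GenerationConfig(
    temperature=0.2,
    thinking_config=gca_content.GenerationConfig.ThinkingConfig(
        thinking_budget=1024,  # budget in tokens, optional per the proto
    ),
)
print(generation_config.thinking_config.thinking_budget)
```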
diff --git a/google/cloud/aiplatform_v1/types/io.py b/google/cloud/aiplatform_v1/types/io.py
index 7954b603e9..0fcc4b8080 100644
--- a/google/cloud/aiplatform_v1/types/io.py
+++ b/google/cloud/aiplatform_v1/types/io.py
@@ -82,7 +82,7 @@ class GcsSource(proto.Message):
Required. Google Cloud Storage URI(-s) to the
input file(s). May contain wildcards. For more
information on wildcards, see
- https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames.
+ https://cloud.google.com/storage/docs/wildcards.
"""
uris: MutableSequence[str] = proto.RepeatedField(
diff --git a/google/cloud/aiplatform_v1/types/model.py b/google/cloud/aiplatform_v1/types/model.py
index 3b1ac173e9..878ebdee0d 100644
--- a/google/cloud/aiplatform_v1/types/model.py
+++ b/google/cloud/aiplatform_v1/types/model.py
@@ -40,6 +40,7 @@
"Port",
"ModelSourceInfo",
"Probe",
+ "Checkpoint",
},
)
@@ -322,6 +323,9 @@ class Model(proto.Message):
Output only. Reserved for future use.
satisfies_pzi (bool):
Output only. Reserved for future use.
+ checkpoints (MutableSequence[google.cloud.aiplatform_v1.types.Checkpoint]):
+ Optional. Output only. The checkpoints of the
+ model.
"""
class DeploymentResourcesType(proto.Enum):
@@ -681,6 +685,11 @@ class BaseModelSource(proto.Message):
proto.BOOL,
number=52,
)
+ checkpoints: MutableSequence["Checkpoint"] = proto.RepeatedField(
+ proto.MESSAGE,
+ number=57,
+ message="Checkpoint",
+ )
class LargeModelReference(proto.Message):
@@ -1468,4 +1477,30 @@ class HttpHeader(proto.Message):
)
+class Checkpoint(proto.Message):
+ r"""Describes the machine learning model version checkpoint.
+
+ Attributes:
+ checkpoint_id (str):
+ The ID of the checkpoint.
+ epoch (int):
+ The epoch of the checkpoint.
+ step (int):
+ The step of the checkpoint.
+ """
+
+ checkpoint_id: str = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ epoch: int = proto.Field(
+ proto.INT64,
+ number=2,
+ )
+ step: int = proto.Field(
+ proto.INT64,
+ number=3,
+ )
+
+
__all__ = tuple(sorted(__protobuf__.manifest))
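
An illustrative sketch of the new `Checkpoint` message and the `Model.checkpoints` field; the identifiers and counts are placeholders:

```python
# Illustrative sketch of Checkpoint and Model.checkpoints; the identifiers
# and counts are placeholders.
from google.cloud.aiplatform_v1.types import model as gca_model

ckpt = gca_model.Checkpoint(checkpoint_id="ckpt-1", epoch=3, step=1500)
m = gca_model.Model(display_name="my-model", checkpoints=[ckpt])
for c in m.checkpoints:
    print(c.checkpoint_id, c.epoch, c.step)
```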
diff --git a/google/cloud/aiplatform_v1/types/openapi.py b/google/cloud/aiplatform_v1/types/openapi.py
index 79525e4834..9f5da2ed40 100644
--- a/google/cloud/aiplatform_v1/types/openapi.py
+++ b/google/cloud/aiplatform_v1/types/openapi.py
@@ -138,6 +138,23 @@ class Schema(proto.Message):
Optional. The value should be validated
against any (one or more) of the subschemas in
the list.
+ ref (str):
+ Optional. Allows indirect references between schema nodes.
+ The value should be a valid reference to a child of the root
+ ``defs``.
+
+ For example, the following schema defines a reference to a
+ schema node named "Pet":
+
+ type: object
+ properties:
+   pet:
+     ref: #/defs/Pet
+ defs:
+   Pet:
+     type: object
+     properties:
+       name:
+         type: string
+
+ The value of the "pet" property is a reference to the schema
+ node named "Pet". See details in
+ https://json-schema.org/understanding-json-schema/structuring
+ defs (MutableMapping[str, google.cloud.aiplatform_v1.types.Schema]):
+ defs (MutableMapping[str, google.cloud.aiplatform_v1.types.Schema]):
+ Optional. A map of definitions for use by ``ref``. Only
+ allowed at the root of the schema.
"""
type_: "Type" = proto.Field(
@@ -235,6 +252,16 @@ class Schema(proto.Message):
number=11,
message="Schema",
)
+ ref: str = proto.Field(
+ proto.STRING,
+ number=27,
+ )
+ defs: MutableMapping[str, "Schema"] = proto.MapField(
+ proto.STRING,
+ proto.MESSAGE,
+ number=28,
+ message="Schema",
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
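
An illustrative sketch of the new `Schema.ref` and `Schema.defs` fields, mirroring the "Pet" example in the docstring above:

```python
# Illustrative sketch of Schema.ref / Schema.defs, mirroring the "Pet"
# example from the docstring above.
from google.cloud.aiplatform_v1.types import openapi

pet = openapi.Schema(
    type_=openapi.Type.OBJECT,
    properties={"name": openapi.Schema(type_=openapi.Type.STRING)},
)
root = openapi.Schema(
    type_=openapi.Type.OBJECT,
    properties={"pet": openapi.Schema(ref="#/defs/Pet")},
    defs={"Pet": pet},  # defs is only allowed at the root of the schema
)
print(root.properties["pet"].ref)
```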
diff --git a/google/cloud/aiplatform_v1/types/vertex_rag_data.py b/google/cloud/aiplatform_v1/types/vertex_rag_data.py
index fd251f3db8..22692ee84b 100644
--- a/google/cloud/aiplatform_v1/types/vertex_rag_data.py
+++ b/google/cloud/aiplatform_v1/types/vertex_rag_data.py
@@ -615,12 +615,21 @@ class RagFileTransformationConfig(proto.Message):
class RagFileParsingConfig(proto.Message):
r"""Specifies the parsing config for RagFiles.
+ This message has `oneof`_ fields (mutually exclusive fields).
+ For each oneof, at most one member field can be set at the same time.
+ Setting any member of the oneof automatically clears all other
+ members.
+
.. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields
Attributes:
layout_parser (google.cloud.aiplatform_v1.types.RagFileParsingConfig.LayoutParser):
The Layout Parser to use for RagFiles.
+ This field is a member of `oneof`_ ``parser``.
+ llm_parser (google.cloud.aiplatform_v1.types.RagFileParsingConfig.LlmParser):
+ The LLM Parser to use for RagFiles.
+
This field is a member of `oneof`_ ``parser``.
"""
@@ -656,12 +665,52 @@ class LayoutParser(proto.Message):
number=2,
)
+ class LlmParser(proto.Message):
+ r"""Specifies the advanced parsing for RagFiles.
+
+ Attributes:
+ model_name (str):
+ The name of an LLM model used for parsing. Format:
+
+ - ``projects/{project_id}/locations/{location}/publishers/{publisher}/models/{model}``
+ max_parsing_requests_per_min (int):
+ The maximum number of requests the job is
+ allowed to make to the LLM model per minute.
+ Consult
+ https://cloud.google.com/vertex-ai/generative-ai/docs/quotas
+ and your document size to set an appropriate
+ value here. If unspecified, a default value of
+ 5000 QPM will be used.
+ custom_parsing_prompt (str):
+ The prompt to use for parsing. If not
+ specified, a default prompt will be used.
+ """
+
+ model_name: str = proto.Field(
+ proto.STRING,
+ number=1,
+ )
+ max_parsing_requests_per_min: int = proto.Field(
+ proto.INT32,
+ number=2,
+ )
+ custom_parsing_prompt: str = proto.Field(
+ proto.STRING,
+ number=3,
+ )
+
layout_parser: LayoutParser = proto.Field(
proto.MESSAGE,
number=4,
oneof="parser",
message=LayoutParser,
)
+ llm_parser: LlmParser = proto.Field(
+ proto.MESSAGE,
+ number=5,
+ oneof="parser",
+ message=LlmParser,
+ )
class UploadRagFileConfig(proto.Message):
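
An illustrative sketch of the new `RagFileParsingConfig.LlmParser` oneof member; the model resource name is a placeholder in the documented format:

```python
# Illustrative sketch of RagFileParsingConfig.LlmParser; the model resource
# name is a placeholder in the documented format.
from google.cloud.aiplatform_v1.types import vertex_rag_data

parsing_config = vertex_rag_data.RagFileParsingConfig(
    llm_parser=vertex_rag_data.RagFileParsingConfig.LlmParser(
        model_name="projects/{project_id}/locations/{location}/publishers/{publisher}/models/{model}",
        max_parsing_requests_per_min=500,
    ),
)
# Setting llm_parser clears layout_parser, since both live in the "parser" oneof.
print(parsing_config.llm_parser.model_name)
```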
diff --git a/google/cloud/aiplatform_v1beta1/__init__.py b/google/cloud/aiplatform_v1beta1/__init__.py
index fe88ef9c23..07334a27e8 100644
--- a/google/cloud/aiplatform_v1beta1/__init__.py
+++ b/google/cloud/aiplatform_v1beta1/__init__.py
@@ -175,6 +175,7 @@
from .types.dataset_service import ExportDataRequest
from .types.dataset_service import ExportDataResponse
from .types.dataset_service import GeminiExample
+from .types.dataset_service import GeminiRequestReadConfig
from .types.dataset_service import GeminiTemplateConfig
from .types.dataset_service import GetAnnotationSpecRequest
from .types.dataset_service import GetDatasetRequest
@@ -1703,6 +1704,7 @@
"GcsDestination",
"GcsSource",
"GeminiExample",
+ "GeminiRequestReadConfig",
"GeminiTemplateConfig",
"GenAiCacheServiceClient",
"GenAiTuningServiceClient",
diff --git a/google/cloud/aiplatform_v1beta1/gapic_version.py b/google/cloud/aiplatform_v1beta1/gapic_version.py
index dbb6c23a4a..4cfdd5d77d 100644
--- a/google/cloud/aiplatform_v1beta1/gapic_version.py
+++ b/google/cloud/aiplatform_v1beta1/gapic_version.py
@@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-__version__ = "1.91.0" # {x-release-please-version}
+__version__ = "1.92.0" # {x-release-please-version}
diff --git a/google/cloud/aiplatform_v1beta1/services/gen_ai_cache_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/gen_ai_cache_service/async_client.py
index 2b03b00214..429f6b5505 100644
--- a/google/cloud/aiplatform_v1beta1/services/gen_ai_cache_service/async_client.py
+++ b/google/cloud/aiplatform_v1beta1/services/gen_ai_cache_service/async_client.py
@@ -48,6 +48,7 @@
from google.cloud.aiplatform_v1beta1.types import cached_content
from google.cloud.aiplatform_v1beta1.types import cached_content as gca_cached_content
from google.cloud.aiplatform_v1beta1.types import content
+from google.cloud.aiplatform_v1beta1.types import encryption_spec
from google.cloud.aiplatform_v1beta1.types import gen_ai_cache_service
from google.cloud.aiplatform_v1beta1.types import tool
from google.cloud.location import locations_pb2 # type: ignore
diff --git a/google/cloud/aiplatform_v1beta1/services/gen_ai_cache_service/client.py b/google/cloud/aiplatform_v1beta1/services/gen_ai_cache_service/client.py
index 6eb1bc6e15..bdcd049578 100644
--- a/google/cloud/aiplatform_v1beta1/services/gen_ai_cache_service/client.py
+++ b/google/cloud/aiplatform_v1beta1/services/gen_ai_cache_service/client.py
@@ -64,6 +64,7 @@
from google.cloud.aiplatform_v1beta1.types import cached_content
from google.cloud.aiplatform_v1beta1.types import cached_content as gca_cached_content
from google.cloud.aiplatform_v1beta1.types import content
+from google.cloud.aiplatform_v1beta1.types import encryption_spec
from google.cloud.aiplatform_v1beta1.types import gen_ai_cache_service
from google.cloud.aiplatform_v1beta1.types import tool
from google.cloud.location import locations_pb2 # type: ignore
diff --git a/google/cloud/aiplatform_v1beta1/services/session_service/async_client.py b/google/cloud/aiplatform_v1beta1/services/session_service/async_client.py
index e605cf2bb3..80722280ed 100644
--- a/google/cloud/aiplatform_v1beta1/services/session_service/async_client.py
+++ b/google/cloud/aiplatform_v1beta1/services/session_service/async_client.py
@@ -313,8 +313,8 @@ async def create_session(
timeout: Union[float, object] = gapic_v1.method.DEFAULT,
metadata: Sequence[Tuple[str, Union[str, bytes]]] = (),
) -> operation_async.AsyncOperation:
- r"""Creates a new [Session][google.cloud.aiplatform.v1beta1.Session]
- in a given project and location.
+ r"""Creates a new
+ [Session][google.cloud.aiplatform.v1beta1.Session].
.. code-block:: python
@@ -358,7 +358,6 @@ async def sample_create_session():
parent (:class:`str`):
Required. The resource name of the location to create
the session in. Format:
- ``projects/{project}/locations/{location}`` or
``projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}``
This corresponds to the ``parent`` field
@@ -568,7 +567,7 @@ async def list_sessions(
metadata: Sequence[Tuple[str, Union[str, bytes]]] = (),
) -> pagers.ListSessionsAsyncPager:
r"""Lists [Sessions][google.cloud.aiplatform.v1beta1.Session] in a
- given project and location.
+ given reasoning engine.
.. code-block:: python
@@ -864,8 +863,6 @@ async def sample_delete_session():
[SessionService.DeleteSession][google.cloud.aiplatform.v1beta1.SessionService.DeleteSession].
name (:class:`str`):
Required. The resource name of the session. Format:
- ``projects/{project}/locations/{location}/sessions/{session}``
- or
``projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}/sessions/{session}``
This corresponds to the ``name`` field
diff --git a/google/cloud/aiplatform_v1beta1/services/session_service/client.py b/google/cloud/aiplatform_v1beta1/services/session_service/client.py
index d25604a46a..599f1bdc1b 100644
--- a/google/cloud/aiplatform_v1beta1/services/session_service/client.py
+++ b/google/cloud/aiplatform_v1beta1/services/session_service/client.py
@@ -225,12 +225,14 @@ def transport(self) -> SessionServiceTransport:
def session_path(
project: str,
location: str,
+ reasoning_engine: str,
session: str,
) -> str:
"""Returns a fully-qualified session string."""
- return "projects/{project}/locations/{location}/sessions/{session}".format(
+ return "projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}/sessions/{session}".format(
project=project,
location=location,
+ reasoning_engine=reasoning_engine,
session=session,
)
@@ -238,7 +240,7 @@ def session_path(
def parse_session_path(path: str) -> Dict[str, str]:
"""Parses a session path into its component segments."""
m = re.match(
- r"^projects/(?P.+?)/locations/(?P.+?)/sessions/(?P.+?)$",
+ r"^projects/(?P.+?)/locations/(?P.+?)/reasoningEngines/(?P.+?)/sessions/(?P.+?)$",
path,
)
return m.groupdict() if m else {}
@@ -247,13 +249,15 @@ def parse_session_path(path: str) -> Dict[str, str]:
def session_event_path(
project: str,
location: str,
+ reasoning_engine: str,
session: str,
event: str,
) -> str:
"""Returns a fully-qualified session_event string."""
- return "projects/{project}/locations/{location}/sessions/{session}/events/{event}".format(
+ return "projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}/sessions/{session}/events/{event}".format(
project=project,
location=location,
+ reasoning_engine=reasoning_engine,
session=session,
event=event,
)
@@ -262,7 +266,7 @@ def session_event_path(
def parse_session_event_path(path: str) -> Dict[str, str]:
"""Parses a session_event path into its component segments."""
m = re.match(
- r"^projects/(?P.+?)/locations/(?P.+?)/sessions/(?P.+?)/events/(?P.+?)$",
+ r"^projects/(?P.+?)/locations/(?P.+?)/reasoningEngines/(?P.+?)/sessions/(?P.+?)/events/(?P.+?)$",
path,
)
return m.groupdict() if m else {}
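
For illustration, a minimal sketch of the updated path helpers with the new reasoningEngines segment; the project, location, engine, and session IDs below are placeholders:

    from google.cloud.aiplatform_v1beta1.services.session_service import (
        SessionServiceClient,
    )

    # Build a fully-qualified session name that now includes the reasoning engine.
    name = SessionServiceClient.session_path(
        project="my-project",
        location="us-central1",
        reasoning_engine="my-engine",
        session="my-session",
    )
    # projects/my-project/locations/us-central1/reasoningEngines/my-engine/sessions/my-session

    # Parse the name back into its components; the dict now includes the
    # reasoning_engine segment as well.
    parts = SessionServiceClient.parse_session_path(name)
    assert parts["reasoning_engine"] == "my-engine"
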
@@ -794,8 +798,8 @@ def create_session(
timeout: Union[float, object] = gapic_v1.method.DEFAULT,
metadata: Sequence[Tuple[str, Union[str, bytes]]] = (),
) -> gac_operation.Operation:
- r"""Creates a new [Session][google.cloud.aiplatform.v1beta1.Session]
- in a given project and location.
+ r"""Creates a new
+ [Session][google.cloud.aiplatform.v1beta1.Session].
.. code-block:: python
@@ -839,7 +843,6 @@ def sample_create_session():
parent (str):
Required. The resource name of the location to create
the session in. Format:
- ``projects/{project}/locations/{location}`` or
``projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}``
This corresponds to the ``parent`` field
@@ -1043,7 +1046,7 @@ def list_sessions(
metadata: Sequence[Tuple[str, Union[str, bytes]]] = (),
) -> pagers.ListSessionsPager:
r"""Lists [Sessions][google.cloud.aiplatform.v1beta1.Session] in a
- given project and location.
+ given reasoning engine.
.. code-block:: python
@@ -1333,8 +1336,6 @@ def sample_delete_session():
[SessionService.DeleteSession][google.cloud.aiplatform.v1beta1.SessionService.DeleteSession].
name (str):
Required. The resource name of the session. Format:
- ``projects/{project}/locations/{location}/sessions/{session}``
- or
``projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}/sessions/{session}``
This corresponds to the ``name`` field
diff --git a/google/cloud/aiplatform_v1beta1/services/session_service/transports/grpc.py b/google/cloud/aiplatform_v1beta1/services/session_service/transports/grpc.py
index 938378fc83..7f0ffd3884 100644
--- a/google/cloud/aiplatform_v1beta1/services/session_service/transports/grpc.py
+++ b/google/cloud/aiplatform_v1beta1/services/session_service/transports/grpc.py
@@ -350,8 +350,8 @@ def create_session(
) -> Callable[[session_service.CreateSessionRequest], operations_pb2.Operation]:
r"""Return a callable for the create session method over gRPC.
- Creates a new [Session][google.cloud.aiplatform.v1beta1.Session]
- in a given project and location.
+ Creates a new
+ [Session][google.cloud.aiplatform.v1beta1.Session].
Returns:
Callable[[~.CreateSessionRequest],
@@ -407,7 +407,7 @@ def list_sessions(
r"""Return a callable for the list sessions method over gRPC.
Lists [Sessions][google.cloud.aiplatform.v1beta1.Session] in a
- given project and location.
+ given reasoning engine.
Returns:
Callable[[~.ListSessionsRequest],
diff --git a/google/cloud/aiplatform_v1beta1/services/session_service/transports/grpc_asyncio.py b/google/cloud/aiplatform_v1beta1/services/session_service/transports/grpc_asyncio.py
index 6842c62242..36f4cc53ec 100644
--- a/google/cloud/aiplatform_v1beta1/services/session_service/transports/grpc_asyncio.py
+++ b/google/cloud/aiplatform_v1beta1/services/session_service/transports/grpc_asyncio.py
@@ -360,8 +360,8 @@ def create_session(
]:
r"""Return a callable for the create session method over gRPC.
- Creates a new [Session][google.cloud.aiplatform.v1beta1.Session]
- in a given project and location.
+ Creates a new
+ [Session][google.cloud.aiplatform.v1beta1.Session].
Returns:
Callable[[~.CreateSessionRequest],
@@ -418,7 +418,7 @@ def list_sessions(
r"""Return a callable for the list sessions method over gRPC.
Lists [Sessions][google.cloud.aiplatform.v1beta1.Session] in a
- given project and location.
+ given reasoning engine.
Returns:
Callable[[~.ListSessionsRequest],
diff --git a/google/cloud/aiplatform_v1beta1/services/session_service/transports/rest_base.py b/google/cloud/aiplatform_v1beta1/services/session_service/transports/rest_base.py
index f7a621b52a..73df9faad5 100644
--- a/google/cloud/aiplatform_v1beta1/services/session_service/transports/rest_base.py
+++ b/google/cloud/aiplatform_v1beta1/services/session_service/transports/rest_base.py
@@ -169,11 +169,6 @@ def _get_unset_required_fields(cls, message_dict):
@staticmethod
def _get_http_options():
http_options: List[Dict[str, str]] = [
- {
- "method": "post",
- "uri": "/v1beta1/{parent=projects/*/locations/*}/sessions",
- "body": "session",
- },
{
"method": "post",
"uri": "/v1beta1/{parent=projects/*/locations/*/reasoningEngines/*}/sessions",
@@ -231,10 +226,6 @@ def _get_unset_required_fields(cls, message_dict):
@staticmethod
def _get_http_options():
http_options: List[Dict[str, str]] = [
- {
- "method": "delete",
- "uri": "/v1beta1/{name=projects/*/locations/*/sessions/*}",
- },
{
"method": "delete",
"uri": "/v1beta1/{name=projects/*/locations/*/reasoningEngines/*/sessions/*}",
@@ -282,10 +273,6 @@ def _get_unset_required_fields(cls, message_dict):
@staticmethod
def _get_http_options():
http_options: List[Dict[str, str]] = [
- {
- "method": "get",
- "uri": "/v1beta1/{name=projects/*/locations/*/sessions/*}",
- },
{
"method": "get",
"uri": "/v1beta1/{name=projects/*/locations/*/reasoningEngines/*/sessions/*}",
@@ -337,10 +324,6 @@ def _get_http_options():
"method": "get",
"uri": "/v1beta1/{parent=projects/*/locations/*/reasoningEngines/*/sessions/*}/events",
},
- {
- "method": "get",
- "uri": "/v1beta1/{parent=projects/*/locations/*/sessions/*}/events",
- },
]
return http_options
@@ -384,10 +367,6 @@ def _get_unset_required_fields(cls, message_dict):
@staticmethod
def _get_http_options():
http_options: List[Dict[str, str]] = [
- {
- "method": "get",
- "uri": "/v1beta1/{parent=projects/*/locations/*}/sessions",
- },
{
"method": "get",
"uri": "/v1beta1/{parent=projects/*/locations/*/reasoningEngines/*}/sessions",
@@ -435,11 +414,6 @@ def _get_unset_required_fields(cls, message_dict):
@staticmethod
def _get_http_options():
http_options: List[Dict[str, str]] = [
- {
- "method": "patch",
- "uri": "/v1beta1/{session.name=projects/*/locations/*/sessions/*}",
- "body": "session",
- },
{
"method": "patch",
"uri": "/v1beta1/{session.name=projects/*/locations/*/reasoningEngines/*/sessions/*}",
diff --git a/google/cloud/aiplatform_v1beta1/types/__init__.py b/google/cloud/aiplatform_v1beta1/types/__init__.py
index b72fd73b19..121f8afc9e 100644
--- a/google/cloud/aiplatform_v1beta1/types/__init__.py
+++ b/google/cloud/aiplatform_v1beta1/types/__init__.py
@@ -108,6 +108,7 @@
ExportDataRequest,
ExportDataResponse,
GeminiExample,
+ GeminiRequestReadConfig,
GeminiTemplateConfig,
GetAnnotationSpecRequest,
GetDatasetRequest,
@@ -1481,6 +1482,7 @@
"ExportDataRequest",
"ExportDataResponse",
"GeminiExample",
+ "GeminiRequestReadConfig",
"GeminiTemplateConfig",
"GetAnnotationSpecRequest",
"GetDatasetRequest",
diff --git a/google/cloud/aiplatform_v1beta1/types/cached_content.py b/google/cloud/aiplatform_v1beta1/types/cached_content.py
index b804f9d2f1..a202168461 100644
--- a/google/cloud/aiplatform_v1beta1/types/cached_content.py
+++ b/google/cloud/aiplatform_v1beta1/types/cached_content.py
@@ -20,6 +20,7 @@
import proto # type: ignore
from google.cloud.aiplatform_v1beta1.types import content
+from google.cloud.aiplatform_v1beta1.types import encryption_spec as gca_encryption_spec
from google.cloud.aiplatform_v1beta1.types import tool
from google.protobuf import duration_pb2 # type: ignore
from google.protobuf import timestamp_pb2 # type: ignore
@@ -64,10 +65,10 @@ class CachedContent(proto.Message):
Optional. Immutable. The user-generated
meaningful display name of the cached content.
model (str):
- Immutable. The name of the publisher model to
- use for cached content. Format:
-
- projects/{project}/locations/{location}/publishers/{publisher}/models/{model}
+ Immutable. The name of the ``Model`` to use for cached
+ content. Currently, only the published Gemini base models
+            are supported, in the form of
+ projects/{PROJECT}/locations/{LOCATION}/publishers/google/models/{MODEL}
system_instruction (google.cloud.aiplatform_v1beta1.types.Content):
Optional. Input only. Immutable. Developer
set system instruction. Currently, text only
@@ -81,7 +82,7 @@ class CachedContent(proto.Message):
Optional. Input only. Immutable. Tool config.
This config is shared for all tools
create_time (google.protobuf.timestamp_pb2.Timestamp):
- Output only. Creatation time of the cache
+ Output only. Creation time of the cache
entry.
update_time (google.protobuf.timestamp_pb2.Timestamp):
Output only. When the cache entry was last
@@ -89,6 +90,10 @@ class CachedContent(proto.Message):
usage_metadata (google.cloud.aiplatform_v1beta1.types.CachedContent.UsageMetadata):
Output only. Metadata on the usage of the
cached content.
+ encryption_spec (google.cloud.aiplatform_v1beta1.types.EncryptionSpec):
+ Input only. Immutable. Customer-managed encryption key spec
+ for a ``CachedContent``. If set, this ``CachedContent`` and
+ all its sub-resources will be secured by this key.
"""
class UsageMetadata(proto.Message):
@@ -188,6 +193,11 @@ class UsageMetadata(proto.Message):
number=12,
message=UsageMetadata,
)
+ encryption_spec: gca_encryption_spec.EncryptionSpec = proto.Field(
+ proto.MESSAGE,
+ number=13,
+ message=gca_encryption_spec.EncryptionSpec,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
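
As a rough sketch of the new CMEK field, constructing a CachedContent with an encryption_spec could look like the following; the model and KMS key names are placeholders:

    from google.cloud.aiplatform_v1beta1.types import cached_content
    from google.cloud.aiplatform_v1beta1.types import encryption_spec

    cache = cached_content.CachedContent(
        display_name="my-cache",
        model="projects/my-project/locations/us-central1/publishers/google/models/gemini-1.5-pro-001",
        encryption_spec=encryption_spec.EncryptionSpec(
            kms_key_name=(
                "projects/my-project/locations/us-central1/"
                "keyRings/my-ring/cryptoKeys/my-key"
            ),
        ),
    )
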
diff --git a/google/cloud/aiplatform_v1beta1/types/content.py b/google/cloud/aiplatform_v1beta1/types/content.py
index 9da0b87d36..438ab47ec8 100644
--- a/google/cloud/aiplatform_v1beta1/types/content.py
+++ b/google/cloud/aiplatform_v1beta1/types/content.py
@@ -73,7 +73,8 @@ class HarmCategory(proto.Enum):
The harm category is sexually explicit
content.
HARM_CATEGORY_CIVIC_INTEGRITY (5):
- The harm category is civic integrity.
+                Deprecated: Election filter is no longer
+                supported. The harm category is civic integrity.
"""
HARM_CATEGORY_UNSPECIFIED = 0
HARM_CATEGORY_HATE_SPEECH = 1
@@ -474,6 +475,10 @@ class GenerationConfig(proto.Message):
Optional. The speech generation config.
This field is a member of `oneof`_ ``_speech_config``.
+ thinking_config (google.cloud.aiplatform_v1beta1.types.GenerationConfig.ThinkingConfig):
+ Optional. Config for thinking features.
+ An error will be returned if this field is set
+ for models that don't support thinking.
model_config (google.cloud.aiplatform_v1beta1.types.GenerationConfig.ModelConfig):
Optional. Config for model selection.
"""
@@ -612,6 +617,25 @@ class ManualRoutingMode(proto.Message):
message="GenerationConfig.RoutingConfig.ManualRoutingMode",
)
+ class ThinkingConfig(proto.Message):
+ r"""Config for thinking features.
+
+ .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields
+
+ Attributes:
+ thinking_budget (int):
+ Optional. Indicates the thinking budget in tokens. This is
+ only applied when enable_thinking is true.
+
+ This field is a member of `oneof`_ ``_thinking_budget``.
+ """
+
+ thinking_budget: int = proto.Field(
+ proto.INT32,
+ number=3,
+ optional=True,
+ )
+
class ModelConfig(proto.Message):
r"""Config for model selection.
@@ -736,6 +760,11 @@ class FeatureSelectionPreference(proto.Enum):
optional=True,
message="SpeechConfig",
)
+ thinking_config: ThinkingConfig = proto.Field(
+ proto.MESSAGE,
+ number=25,
+ message=ThinkingConfig,
+ )
model_config: ModelConfig = proto.Field(
proto.MESSAGE,
number=27,
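
A minimal sketch of the new thinking_config field on GenerationConfig; the budget value is a placeholder and the field only applies to models that support thinking:

    from google.cloud.aiplatform_v1beta1.types import content

    generation_config = content.GenerationConfig(
        temperature=0.2,
        thinking_config=content.GenerationConfig.ThinkingConfig(
            thinking_budget=1024,  # placeholder token budget
        ),
    )
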
diff --git a/google/cloud/aiplatform_v1beta1/types/dataset_service.py b/google/cloud/aiplatform_v1beta1/types/dataset_service.py
index 4760a85c17..4acfa55464 100644
--- a/google/cloud/aiplatform_v1beta1/types/dataset_service.py
+++ b/google/cloud/aiplatform_v1beta1/types/dataset_service.py
@@ -70,6 +70,7 @@
"AssessDataResponse",
"AssessDataOperationMetadata",
"GeminiTemplateConfig",
+ "GeminiRequestReadConfig",
"GeminiExample",
"AssembleDataRequest",
"AssembleDataResponse",
@@ -1179,6 +1180,9 @@ class AssessDataRequest(proto.Message):
Required. The name of the Dataset resource. Used only for
MULTIMODAL datasets. Format:
``projects/{project}/locations/{location}/datasets/{dataset}``
+ gemini_request_read_config (google.cloud.aiplatform_v1beta1.types.GeminiRequestReadConfig):
+ Optional. The Gemini request read config for
+ the dataset.
"""
class TuningValidationAssessmentConfig(proto.Message):
@@ -1302,6 +1306,11 @@ class BatchPredictionResourceUsageAssessmentConfig(proto.Message):
proto.STRING,
number=1,
)
+ gemini_request_read_config: "GeminiRequestReadConfig" = proto.Field(
+ proto.MESSAGE,
+ number=8,
+ message="GeminiRequestReadConfig",
+ )
class AssessDataResponse(proto.Message):
@@ -1450,7 +1459,7 @@ class GeminiTemplateConfig(proto.Message):
assembling the request to use for downstream
applications.
field_mapping (MutableMapping[str, str]):
- Required. Map of template params to the
+ Required. Map of template parameters to the
columns in the dataset table.
"""
@@ -1466,6 +1475,43 @@ class GeminiTemplateConfig(proto.Message):
)
+class GeminiRequestReadConfig(proto.Message):
+ r"""Configuration for how to read Gemini requests from a
+ multimodal dataset.
+
+ This message has `oneof`_ fields (mutually exclusive fields).
+ For each oneof, at most one member field can be set at the same time.
+ Setting any member of the oneof automatically clears all other
+ members.
+
+ .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields
+
+ Attributes:
+ template_config (google.cloud.aiplatform_v1beta1.types.GeminiTemplateConfig):
+ Gemini request template with placeholders.
+
+ This field is a member of `oneof`_ ``read_config``.
+ assembled_request_column_name (str):
+ Optional. Column name in the dataset table
+ that contains already fully assembled Gemini
+ requests.
+
+ This field is a member of `oneof`_ ``read_config``.
+ """
+
+ template_config: "GeminiTemplateConfig" = proto.Field(
+ proto.MESSAGE,
+ number=1,
+ oneof="read_config",
+ message="GeminiTemplateConfig",
+ )
+ assembled_request_column_name: str = proto.Field(
+ proto.STRING,
+ number=4,
+ oneof="read_config",
+ )
+
+
class GeminiExample(proto.Message):
r"""Format for Gemini examples used for Vertex Multimodal
datasets.
@@ -1607,6 +1653,8 @@ class AssembleDataRequest(proto.Message):
Required. The name of the Dataset resource (used only for
MULTIMODAL datasets). Format:
``projects/{project}/locations/{location}/datasets/{dataset}``
+ gemini_request_read_config (google.cloud.aiplatform_v1beta1.types.GeminiRequestReadConfig):
+ Optional. The read config for the dataset.
"""
gemini_template_config: "GeminiTemplateConfig" = proto.Field(
@@ -1624,6 +1672,11 @@ class AssembleDataRequest(proto.Message):
proto.STRING,
number=1,
)
+ gemini_request_read_config: "GeminiRequestReadConfig" = proto.Field(
+ proto.MESSAGE,
+ number=6,
+ message="GeminiRequestReadConfig",
+ )
class AssembleDataResponse(proto.Message):
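
For reference, a sketch of how the new GeminiRequestReadConfig oneof could be populated on an AssessDataRequest; the dataset name, column name, and field mapping below are placeholders:

    from google.cloud.aiplatform_v1beta1.types import dataset_service

    # Read requests that are already fully assembled in a table column...
    read_config = dataset_service.GeminiRequestReadConfig(
        assembled_request_column_name="request_json",
    )

    # ...or assemble them from a template; setting this clears the other
    # oneof member because the two options are mutually exclusive.
    read_config.template_config = dataset_service.GeminiTemplateConfig(
        field_mapping={"prompt": "prompt_column"},
    )

    request = dataset_service.AssessDataRequest(
        name="projects/my-project/locations/us-central1/datasets/123",
        gemini_request_read_config=read_config,
    )
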
diff --git a/google/cloud/aiplatform_v1beta1/types/io.py b/google/cloud/aiplatform_v1beta1/types/io.py
index f55a7fb21d..93e761fdfe 100644
--- a/google/cloud/aiplatform_v1beta1/types/io.py
+++ b/google/cloud/aiplatform_v1beta1/types/io.py
@@ -82,7 +82,7 @@ class GcsSource(proto.Message):
Required. Google Cloud Storage URI(-s) to the
input file(s). May contain wildcards. For more
information on wildcards, see
- https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames.
+ https://cloud.google.com/storage/docs/wildcards.
"""
uris: MutableSequence[str] = proto.RepeatedField(
diff --git a/google/cloud/aiplatform_v1beta1/types/model_garden_service.py b/google/cloud/aiplatform_v1beta1/types/model_garden_service.py
index 9b195a10c7..36da1c66c3 100644
--- a/google/cloud/aiplatform_v1beta1/types/model_garden_service.py
+++ b/google/cloud/aiplatform_v1beta1/types/model_garden_service.py
@@ -356,6 +356,10 @@ class DeployConfig(proto.Message):
fast_tryout_enabled (bool):
Optional. If true, enable the QMT fast tryout
feature for this model if possible.
+ system_labels (MutableMapping[str, str]):
+ Optional. System labels for Model Garden
+ deployments. These labels are managed by Google
+            and are used for tracking purposes only.
"""
dedicated_resources: machine_resources.DedicatedResources = proto.Field(
@@ -367,6 +371,11 @@ class DeployConfig(proto.Message):
proto.BOOL,
number=2,
)
+ system_labels: MutableMapping[str, str] = proto.MapField(
+ proto.STRING,
+ proto.STRING,
+ number=3,
+ )
publisher_model_name: str = proto.Field(
proto.STRING,
diff --git a/google/cloud/aiplatform_v1beta1/types/model_monitoring_spec.py b/google/cloud/aiplatform_v1beta1/types/model_monitoring_spec.py
index 4ee655ba70..fcab13f678 100644
--- a/google/cloud/aiplatform_v1beta1/types/model_monitoring_spec.py
+++ b/google/cloud/aiplatform_v1beta1/types/model_monitoring_spec.py
@@ -349,7 +349,7 @@ class ModelMonitoringGcsSource(proto.Message):
Google Cloud Storage URI to the input
file(s). May contain wildcards. For more
information on wildcards, see
- https://cloud.google.com/storage/docs/gsutil/addlhelp/WildcardNames.
+ https://cloud.google.com/storage/docs/wildcards.
format_ (google.cloud.aiplatform_v1beta1.types.ModelMonitoringInput.ModelMonitoringDataset.ModelMonitoringGcsSource.DataFormat):
Data format of the dataset.
"""
diff --git a/google/cloud/aiplatform_v1beta1/types/openapi.py b/google/cloud/aiplatform_v1beta1/types/openapi.py
index c5f3916a4c..0c1300840f 100644
--- a/google/cloud/aiplatform_v1beta1/types/openapi.py
+++ b/google/cloud/aiplatform_v1beta1/types/openapi.py
@@ -135,6 +135,23 @@ class Schema(proto.Message):
Optional. The value should be validated
against any (one or more) of the subschemas in
the list.
+ ref (str):
+ Optional. Allows indirect references between schema nodes.
+ The value should be a valid reference to a child of the root
+ ``defs``.
+
+ For example, the following schema defines a reference to a
+ schema node named "Pet":
+
+                type: object
+                properties:
+                  pet:
+                    ref: "#/defs/Pet"
+                defs:
+                  Pet:
+                    type: object
+                    properties:
+                      name:
+                        type: string
+
+ The value of the "pet" property is a reference to the schema
+ node named "Pet". See details in
+ https://json-schema.org/understanding-json-schema/structuring
+ defs (MutableMapping[str, google.cloud.aiplatform_v1beta1.types.Schema]):
+            Optional. A map of definitions for use by ``ref``. Only
+ allowed at the root of the schema.
"""
type_: "Type" = proto.Field(
@@ -232,6 +249,16 @@ class Schema(proto.Message):
number=11,
message="Schema",
)
+ ref: str = proto.Field(
+ proto.STRING,
+ number=27,
+ )
+ defs: MutableMapping[str, "Schema"] = proto.MapField(
+ proto.STRING,
+ proto.MESSAGE,
+ number=28,
+ message="Schema",
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
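
A small sketch of the new ref/defs fields, mirroring the "Pet" example from the docstring above; the schema contents are illustrative only:

    from google.cloud.aiplatform_v1beta1.types import openapi

    schema = openapi.Schema(
        type_=openapi.Type.OBJECT,
        properties={
            # The "pet" property is an indirect reference into the root defs.
            "pet": openapi.Schema(ref="#/defs/Pet"),
        },
        defs={
            "Pet": openapi.Schema(
                type_=openapi.Type.OBJECT,
                properties={"name": openapi.Schema(type_=openapi.Type.STRING)},
            ),
        },
    )
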
diff --git a/google/cloud/aiplatform_v1beta1/types/session.py b/google/cloud/aiplatform_v1beta1/types/session.py
index 75ee6c7263..17063ba983 100644
--- a/google/cloud/aiplatform_v1beta1/types/session.py
+++ b/google/cloud/aiplatform_v1beta1/types/session.py
@@ -238,8 +238,8 @@ class EventActions(proto.Message):
updating an artifact. key is the filename, value
is the version.
transfer_to_agent (bool):
- Optional. If set, the event transfers to the
- specified agent.
+ Deprecated. If set, the event transfers to
+ the specified agent.
escalate (bool):
Optional. The agent is escalating to a higher
level agent.
@@ -251,6 +251,9 @@ class EventActions(proto.Message):
multiple function calls. Struct value is the
required auth config, which can be another
struct.
+ transfer_agent (str):
+ Optional. If set, the event transfers to the
+ specified agent.
"""
skip_summarization: bool = proto.Field(
@@ -280,6 +283,10 @@ class EventActions(proto.Message):
number=7,
message=struct_pb2.Struct,
)
+ transfer_agent: str = proto.Field(
+ proto.STRING,
+ number=8,
+ )
__all__ = tuple(sorted(__protobuf__.manifest))
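
A two-line sketch of the replacement field: the boolean transfer_to_agent flag is deprecated in favor of naming the target agent explicitly (the agent name below is a placeholder):

    from google.cloud.aiplatform_v1beta1.types import session

    actions = session.EventActions(transfer_agent="my-sub-agent")
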
diff --git a/google/cloud/aiplatform_v1beta1/types/session_service.py b/google/cloud/aiplatform_v1beta1/types/session_service.py
index 88243b1317..06aa0d576d 100644
--- a/google/cloud/aiplatform_v1beta1/types/session_service.py
+++ b/google/cloud/aiplatform_v1beta1/types/session_service.py
@@ -50,7 +50,6 @@ class CreateSessionRequest(proto.Message):
parent (str):
Required. The resource name of the location to create the
session in. Format:
- ``projects/{project}/locations/{location}`` or
``projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}``
session (google.cloud.aiplatform_v1beta1.types.Session):
Required. The session to create.
@@ -220,8 +219,6 @@ class DeleteSessionRequest(proto.Message):
Attributes:
name (str):
Required. The resource name of the session. Format:
- ``projects/{project}/locations/{location}/sessions/{session}``
- or
``projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}/sessions/{session}``
"""
@@ -244,7 +241,8 @@ class ListEventsRequest(proto.Message):
Optional. The maximum number of events to
return. The service may return fewer than this
value. If unspecified, at most 100 events will
- be returned.
+ be returned. These events are ordered by
+ timestamp in ascending order.
page_token (str):
Optional. The
[next_page_token][google.cloud.aiplatform.v1beta1.ListEventsResponse.next_page_token]
@@ -274,6 +272,7 @@ class ListEventsResponse(proto.Message):
Attributes:
session_events (MutableSequence[google.cloud.aiplatform_v1beta1.types.SessionEvent]):
A list of events matching the request.
+ Ordered by timestamp in ascending order.
next_page_token (str):
A token, which can be sent as
[ListEventsRequest.page_token][google.cloud.aiplatform.v1beta1.ListEventsRequest.page_token]
diff --git a/owlbot.py b/owlbot.py
index 0f6378b815..55c08974e6 100644
--- a/owlbot.py
+++ b/owlbot.py
@@ -117,6 +117,8 @@
".kokoro/release.sh",
".kokoro/release/common.cfg",
".kokoro/requirements*",
+ ".kokoro/samples/python3.7/**",
+ ".kokoro/samples/python3.8/**",
# exclude sample configs so periodic samples are tested against main
# instead of pypi
".kokoro/samples/python3.9/common.cfg",
@@ -127,6 +129,8 @@
".kokoro/samples/python3.10/periodic.cfg",
".kokoro/samples/python3.11/periodic.cfg",
".kokoro/samples/python3.12/periodic.cfg",
+ ".github/auto-label.yaml",
+ ".github/blunderbuss.yml",
".github/CODEOWNERS",
".github/PULL_REQUEST_TEMPLATE.md",
".github/workflows", # exclude gh actions as credentials are needed for tests
@@ -161,41 +165,4 @@
"python3",
)
- # Update publish-docs to include gemini docs workflow.
- s.replace(
- ".kokoro/publish-docs.sh",
- "# build docs",
- """\
-# build Gemini docs
-nox -s gemini_docs
-# create metadata
-python3 -m docuploader create-metadata \\
- --name="vertexai" \\
- --version=$(python3 setup.py --version) \\
- --language=$(jq --raw-output '.language // empty' .repo-metadata.json) \\
- --distribution-name="google-cloud-vertexai" \\
- --product-page=$(jq --raw-output '.product_documentation // empty' .repo-metadata.json) \\
- --github-repository=$(jq --raw-output '.repo // empty' .repo-metadata.json) \\
- --issue-tracker=$(jq --raw-output '.issue_tracker // empty' .repo-metadata.json)
-cat docs.metadata
-# upload docs
-python3 -m docuploader upload gemini_docs/_build/html --metadata-file docs.metadata --staging-bucket "${STAGING_BUCKET}"
-# Gemini docfx yaml files
-nox -s gemini_docfx
-# create metadata.
-python3 -m docuploader create-metadata \\
- --name="vertexai" \\
- --version=$(python3 setup.py --version) \\
- --language=$(jq --raw-output '.language // empty' .repo-metadata.json) \\
- --distribution-name="google-cloud-vertexai" \\
- --product-page=$(jq --raw-output '.product_documentation // empty' .repo-metadata.json) \\
- --github-repository=$(jq --raw-output '.repo // empty' .repo-metadata.json) \\
- --issue-tracker=$(jq --raw-output '.issue_tracker // empty' .repo-metadata.json) \\
- --stem="/vertex-ai/generative-ai/docs/reference/python"
-cat docs.metadata
-# upload docs
-python3 -m docuploader upload gemini_docs/_build/html/docfx_yaml --metadata-file docs.metadata --destination-prefix docfx --staging-bucket "${V2_STAGING_BUCKET}"
-# build docs""",
- )
-
s.shell.run(["nox", "-s", "blacken"], hide_output=False)
diff --git a/pypi/_vertex_ai_placeholder/version.py b/pypi/_vertex_ai_placeholder/version.py
index 9600dc46d6..83d974d1fc 100644
--- a/pypi/_vertex_ai_placeholder/version.py
+++ b/pypi/_vertex_ai_placeholder/version.py
@@ -15,4 +15,4 @@
# limitations under the License.
#
-__version__ = "1.91.0"
+__version__ = "1.92.0"
diff --git a/renovate.json b/renovate.json
index 39b2a0ec92..c7875c469b 100644
--- a/renovate.json
+++ b/renovate.json
@@ -5,7 +5,7 @@
":preserveSemverRanges",
":disableDependencyDashboard"
],
- "ignorePaths": [".pre-commit-config.yaml", ".kokoro/requirements.txt", "setup.py"],
+ "ignorePaths": [".pre-commit-config.yaml", ".kokoro/requirements.txt", "setup.py", ".github/workflows/unittest.yml"],
"pip_requirements": {
"fileMatch": ["requirements-test.txt", "samples/[\\S/]*constraints.txt", "samples/[\\S/]*constraints-test.txt"]
}
diff --git a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json
index d046841f64..1b98adb366 100644
--- a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json
+++ b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1.json
@@ -8,7 +8,7 @@
],
"language": "PYTHON",
"name": "google-cloud-aiplatform",
- "version": "1.91.0"
+ "version": "1.92.0"
},
"snippets": [
{
diff --git a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json
index ba930474b0..ef15cef83c 100644
--- a/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json
+++ b/samples/generated_samples/snippet_metadata_google.cloud.aiplatform.v1beta1.json
@@ -8,7 +8,7 @@
],
"language": "PYTHON",
"name": "google-cloud-aiplatform",
- "version": "1.91.0"
+ "version": "1.92.0"
},
"snippets": [
{
diff --git a/samples/model-builder/noxfile.py b/samples/model-builder/noxfile.py
index 483b559017..a169b5b5b4 100644
--- a/samples/model-builder/noxfile.py
+++ b/samples/model-builder/noxfile.py
@@ -89,7 +89,7 @@ def get_pytest_env_vars() -> Dict[str, str]:
# DO NOT EDIT - automatically generated.
# All versions used to test samples.
-ALL_VERSIONS = ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"]
+ALL_VERSIONS = ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12", "3.13"]
# Any default versions that should be ignored.
IGNORED_VERSIONS = TEST_CONFIG["ignored_versions"]
diff --git a/samples/snippets/noxfile.py b/samples/snippets/noxfile.py
index 483b559017..a169b5b5b4 100644
--- a/samples/snippets/noxfile.py
+++ b/samples/snippets/noxfile.py
@@ -89,7 +89,7 @@ def get_pytest_env_vars() -> Dict[str, str]:
# DO NOT EDIT - automatically generated.
# All versions used to test samples.
-ALL_VERSIONS = ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"]
+ALL_VERSIONS = ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12", "3.13"]
# Any default versions that should be ignored.
IGNORED_VERSIONS = TEST_CONFIG["ignored_versions"]
diff --git a/setup.py b/setup.py
index ba9bafc0b5..7ca103ce29 100644
--- a/setup.py
+++ b/setup.py
@@ -84,7 +84,7 @@
prediction_extra_require = [
"docker >= 5.0.3",
"fastapi >= 0.71.0, <=0.114.0",
- "httpx >=0.23.0, <0.25.0", # Optional dependency of fastapi
+ "httpx >=0.23.0, <=0.28.1", # Optional dependency of fastapi
"starlette >= 0.17.1",
"uvicorn[standard] >= 0.16.0",
]
@@ -103,7 +103,7 @@
ray_extra_require = [
# Cluster only supports 2.9.3, 2.33.0, and 2.42.0. Keep 2.4.0 for our
# testing environment.
- # Note that testing is submiting a job in a cluster with Ray 2.9.3 remotely.
+ # Note that testing is submitting a job in a cluster with Ray 2.9.3 remotely.
(
"ray[default] >= 2.4, <= 2.42.0,!= 2.5.*,!= 2.6.*,!= 2.7.*,!="
" 2.8.*,!=2.9.0,!=2.9.1,!=2.9.2, !=2.10.*, !=2.11.*, !=2.12.*, !=2.13.*, !="
@@ -258,8 +258,9 @@
"scikit-learn<1.6.0; python_version<='3.10'",
"scikit-learn; python_version>'3.10'",
# Lazy import requires > 2.12.0
- "tensorflow == 2.13.0; python_version<='3.11'",
- "tensorflow == 2.16.1; python_version>'3.11'",
+ "tensorflow == 2.14.1; python_version<='3.11'",
+ "tensorflow == 2.19.0; python_version>'3.11'",
+ "protobuf <= 5.29.4",
# TODO(jayceeli) torch 2.1.0 has conflict with pyfakefs, will check if
# future versions fix this issue
"torch >= 2.0.0, < 2.1.0; python_version<='3.11'",
@@ -304,6 +305,7 @@
"google-cloud-bigquery >= 1.15.0, < 4.0.0, !=3.20.0",
"google-cloud-resource-manager >= 1.3.3, < 3.0.0",
"shapely < 3.0.0",
+ "google-genai >= 1.0.0, <2.0.0",
)
+ genai_requires,
extras_require={
diff --git a/testing/constraints-3.12.txt b/testing/constraints-3.12.txt
index caa5b754bc..64a02b12c6 100644
--- a/testing/constraints-3.12.txt
+++ b/testing/constraints-3.12.txt
@@ -4,11 +4,11 @@
google-api-core==2.21.0 # Tests google-api-core with rest async support
google-auth==2.35.0 # Tests google-auth with rest async support
proto-plus
-protobuf>=6
mock==4.0.2
google-cloud-storage==2.2.1 # Increased for kfp 2.0 compatibility
packaging==24.1 # Increased to unbreak canonicalize_version error (b/377774673)
pytest-xdist==3.3.1 # Pinned to unbreak unit tests
ray==2.5.0 # Pinned until 2.9.3 is verified for Ray tests
ipython==8.22.2 # Pinned to unbreak TypeAliasType import error
-google-adk==0.0.2
\ No newline at end of file
+google-adk==0.0.2
+google-genai>=1.10.0
\ No newline at end of file
diff --git a/tests/system/aiplatform/test_persistent_resource.py b/tests/system/aiplatform/test_persistent_resource.py
index 66268d1818..fc8cc3dd24 100644
--- a/tests/system/aiplatform/test_persistent_resource.py
+++ b/tests/system/aiplatform/test_persistent_resource.py
@@ -27,11 +27,15 @@
persistent_resource_v1 as gca_persistent_resource,
)
from tests.system.aiplatform import e2e_base
+from google.cloud.aiplatform.tests.unit.aiplatform import constants as test_constants
import pytest
_TEST_MACHINE_TYPE = "n1-standard-4"
_TEST_INITIAL_REPLICA_COUNT = 2
+_TEST_ENABLE_CUSTOM_SERVICE_ACCOUNT_TRUE = (
+ test_constants.ProjectConstants._TEST_ENABLE_CUSTOM_SERVICE_ACCOUNT_TRUE
+)
@pytest.mark.usefixtures("tear_down_resources")
@@ -59,7 +63,9 @@ def test_create_persistent_resource(self, shared_state):
]
test_resource = persistent_resource.PersistentResource.create(
- persistent_resource_id=resource_id, resource_pools=resource_pools
+ persistent_resource_id=resource_id,
+ resource_pools=resource_pools,
+ enable_custom_service_account=_TEST_ENABLE_CUSTOM_SERVICE_ACCOUNT_TRUE,
)
shared_state["resources"] = [test_resource]
diff --git a/tests/unit/aiplatform/constants.py b/tests/unit/aiplatform/constants.py
index 8c3897141b..633659b449 100644
--- a/tests/unit/aiplatform/constants.py
+++ b/tests/unit/aiplatform/constants.py
@@ -59,6 +59,7 @@ class ProjectConstants:
_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}"
_TEST_SERVICE_ACCOUNT = "vinnys@my-project.iam.gserviceaccount.com"
_TEST_LABELS = {"my_key": "my_value"}
+ _TEST_ENABLE_CUSTOM_SERVICE_ACCOUNT_TRUE = True
@dataclasses.dataclass(frozen=True)
diff --git a/tests/unit/aiplatform/test_endpoints.py b/tests/unit/aiplatform/test_endpoints.py
index a5d9fa08eb..14e27abbc4 100644
--- a/tests/unit/aiplatform/test_endpoints.py
+++ b/tests/unit/aiplatform/test_endpoints.py
@@ -97,6 +97,7 @@
_TEST_PROJECT_ALLOWLIST = [_TEST_PROJECT]
_TEST_ENDPOINT_OVERRIDE = "endpoint-override.aiplatform.vertex.goog"
+_TEST_SHARED_ENDPOINT_DNS = f"{_TEST_LOCATION}-aiplatform.googleapis.com"
_TEST_DEDICATED_ENDPOINT_DNS = (
f"{_TEST_ID}.{_TEST_PROJECT}.{_TEST_LOCATION}-aiplatform.vertex.goog"
)
@@ -2689,19 +2690,30 @@ def test_predict_dedicated_endpoint_with_timeout(self, predict_endpoint_http_moc
)
@pytest.mark.usefixtures("get_endpoint_mock")
- def test_predict_use_dedicated_endpoint_for_regular_endpoint(self):
- test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
+ def test_predict_use_dedicated_endpoint_for_regular_endpoint(
+ self, predict_client_predict_mock
+ ):
+ test_endpoint = models.Endpoint(_TEST_ID)
+ test_prediction = test_endpoint.predict(
+ instances=_TEST_INSTANCES,
+ parameters={"param": 3.0},
+ use_dedicated_endpoint=True,
+ )
- with pytest.raises(ValueError) as err:
- test_endpoint.predict(
- instances=_TEST_INSTANCES,
- parameters={"param": 3.0},
- use_dedicated_endpoint=True,
- )
- assert err.match(
- regexp=r"Dedicated endpoint is not enabled or DNS is empty."
- "Please make sure endpoint has dedicated endpoint enabled"
- "and model are ready before making a prediction."
+ true_prediction = models.Prediction(
+ predictions=_TEST_PREDICTION,
+ deployed_model_id=_TEST_ID,
+ metadata=_TEST_METADATA,
+ model_version_id=_TEST_VERSION_ID,
+ model_resource_name=_TEST_MODEL_NAME,
+ )
+
+ assert true_prediction == test_prediction
+ predict_client_predict_mock.assert_called_once_with(
+ endpoint=_TEST_ENDPOINT_NAME,
+ instances=_TEST_INSTANCES,
+ parameters={"param": 3.0},
+ timeout=None,
)
@pytest.mark.usefixtures("get_dedicated_endpoint_mock")
@@ -2768,19 +2780,35 @@ def test_raw_predict_dedicated_endpoint_with_timeout(
)
@pytest.mark.usefixtures("get_endpoint_mock")
- def test_raw_predict_use_dedicated_endpoint_for_regular_endpoint(self):
- test_endpoint = models.Endpoint(_TEST_ENDPOINT_NAME)
+ def test_raw_predict_use_dedicated_endpoint_for_regular_endpoint(
+ self, predict_endpoint_http_mock
+ ):
+ test_endpoint = models.Endpoint(_TEST_ID)
- with pytest.raises(ValueError) as err:
- test_endpoint.raw_predict(
- body=_TEST_RAW_INPUTS,
- headers={"Content-Type": "application/json"},
- use_dedicated_endpoint=True,
- )
- assert err.match(
- regexp=r"Dedicated endpoint is not enabled or DNS is empty."
- "Please make sure endpoint has dedicated endpoint enabled"
- "and model are ready before making a prediction."
+ test_prediction = test_endpoint.raw_predict(
+ body=_TEST_RAW_INPUTS,
+ headers={"Content-Type": "application/json"},
+ use_dedicated_endpoint=True,
+ )
+
+ true_prediction = requests.Response()
+ true_prediction.status_code = 200
+ true_prediction._content = json.dumps(
+ {
+ "predictions": _TEST_PREDICTION,
+ "metadata": _TEST_METADATA,
+ "deployedModelId": _TEST_DEPLOYED_MODELS[0].id,
+ "model": _TEST_MODEL_NAME,
+ "modelVersionId": "1",
+ }
+ ).encode("utf-8")
+ assert true_prediction.status_code == test_prediction.status_code
+ assert true_prediction.text == test_prediction.text
+ predict_endpoint_http_mock.assert_called_once_with(
+ url=f"https://{_TEST_SHARED_ENDPOINT_DNS}/v1/projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/endpoints/{_TEST_ID}:rawPredict",
+ data=_TEST_RAW_INPUTS,
+ headers={"Content-Type": "application/json"},
+ timeout=None,
)
@pytest.mark.asyncio
@@ -3658,7 +3686,7 @@ def test_psa_predict(self, predict_private_endpoint_mock):
)
@pytest.mark.usefixtures("get_psc_private_endpoint_mock")
- def test_psc_predict(self, predict_private_endpoint_mock):
+ def test_psc_predict(self, predict_endpoint_http_mock):
test_endpoint = models.PrivateEndpoint(
project=_TEST_PROJECT, location=_TEST_LOCATION, endpoint_name=_TEST_ID
)
@@ -3677,14 +3705,10 @@ def test_psc_predict(self, predict_private_endpoint_mock):
)
assert true_prediction == test_prediction
- predict_private_endpoint_mock.assert_called_once_with(
- method="POST",
+ predict_endpoint_http_mock.assert_called_once_with(
url=f"https://{_TEST_ENDPOINT_OVERRIDE}/v1/projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/endpoints/{_TEST_ID}:predict",
- body='{"instances": [[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]], "parameters": {"param": 3.0}}',
- headers={
- "Content-Type": "application/json",
- "Authorization": "Bearer None",
- },
+ data='{"instances": [[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]], "parameters": {"param": 3.0}}',
+ headers={"Content-Type": "application/json"},
)
@pytest.mark.usefixtures("get_psc_private_endpoint_mock")
@@ -3697,7 +3721,6 @@ def test_psc_stream_raw_predict(self, stream_raw_predict_private_endpoint_mock):
body='{"instances": [[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]]}',
headers={
"Content-Type": "application/json",
- "Authorization": "Bearer None",
},
endpoint_override=_TEST_ENDPOINT_OVERRIDE,
)
@@ -3709,7 +3732,6 @@ def test_psc_stream_raw_predict(self, stream_raw_predict_private_endpoint_mock):
data='{"instances": [[1.0, 2.0, 3.0], [1.0, 3.0, 4.0]]}',
headers={
"Content-Type": "application/json",
- "Authorization": "Bearer None",
},
stream=True,
verify=False,
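
The dedicated-endpoint test changes above reflect a behavior change: calling predict or raw_predict with use_dedicated_endpoint=True on an endpoint without a dedicated DNS now falls back to the shared regional endpoint instead of raising ValueError. A rough sketch of the user-facing call (project, region, and endpoint ID are placeholders):

    from google.cloud import aiplatform

    aiplatform.init(project="my-project", location="us-central1")
    endpoint = aiplatform.Endpoint("1234567890")  # placeholder endpoint ID

    # With this change the request is routed through the shared endpoint when
    # no dedicated DNS is configured, rather than raising an error.
    prediction = endpoint.predict(
        instances=[[1.0, 2.0, 3.0]],
        parameters={"param": 3.0},
        use_dedicated_endpoint=True,
    )
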
diff --git a/tests/unit/aiplatform/test_persistent_resource.py b/tests/unit/aiplatform/test_persistent_resource.py
index 55c460b113..fe6e2bbd90 100644
--- a/tests/unit/aiplatform/test_persistent_resource.py
+++ b/tests/unit/aiplatform/test_persistent_resource.py
@@ -49,6 +49,10 @@
_TEST_RESERVED_IP_RANGES = test_constants.TrainingJobConstants._TEST_RESERVED_IP_RANGES
_TEST_KEY_NAME = test_constants.TrainingJobConstants._TEST_DEFAULT_ENCRYPTION_KEY_NAME
_TEST_SERVICE_ACCOUNT = test_constants.ProjectConstants._TEST_SERVICE_ACCOUNT
+_TEST_ENABLE_CUSTOM_SERVICE_ACCOUNT_TRUE = (
+ test_constants.ProjectConstants._TEST_ENABLE_CUSTOM_SERVICE_ACCOUNT_TRUE
+)
+
_TEST_PERSISTENT_RESOURCE_PROTO = persistent_resource_v1.PersistentResource(
name=_TEST_PERSISTENT_RESOURCE_ID,
@@ -298,7 +302,7 @@ def test_create_persistent_resource_with_kms_key(
)
@pytest.mark.parametrize("sync", [True, False])
- def test_create_persistent_resource_with_service_account(
+ def test_create_persistent_resource_enable_custom_sa_true_with_sa(
self,
create_persistent_resource_mock,
get_persistent_resource_mock,
@@ -309,6 +313,7 @@ def test_create_persistent_resource_with_service_account(
resource_pools=[
test_constants.PersistentResourceConstants._TEST_RESOURCE_POOL,
],
+ enable_custom_service_account=_TEST_ENABLE_CUSTOM_SERVICE_ACCOUNT_TRUE,
service_account=_TEST_SERVICE_ACCOUNT,
sync=sync,
)
@@ -321,7 +326,8 @@ def test_create_persistent_resource_with_service_account(
)
service_account_spec = persistent_resource_v1.ServiceAccountSpec(
- enable_custom_service_account=True, service_account=_TEST_SERVICE_ACCOUNT
+ enable_custom_service_account=_TEST_ENABLE_CUSTOM_SERVICE_ACCOUNT_TRUE,
+ service_account=_TEST_SERVICE_ACCOUNT,
)
expected_persistent_resource_arg.resource_runtime_spec = (
persistent_resource_v1.ResourceRuntimeSpec(
@@ -341,6 +347,164 @@ def test_create_persistent_resource_with_service_account(
name=_TEST_PERSISTENT_RESOURCE_ID
)
+ @pytest.mark.parametrize("sync", [True, False])
+ def test_create_persistent_resource_enable_custom_sa_true_no_sa(
+ self,
+ create_persistent_resource_mock,
+ get_persistent_resource_mock,
+ sync,
+ ):
+ my_test_resource = persistent_resource.PersistentResource.create(
+ persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID,
+ resource_pools=[
+ test_constants.PersistentResourceConstants._TEST_RESOURCE_POOL,
+ ],
+ enable_custom_service_account=True,
+ sync=sync,
+ )
+
+ if not sync:
+ my_test_resource.wait()
+
+ expected_persistent_resource_arg = _get_persistent_resource_proto(
+ name=_TEST_PERSISTENT_RESOURCE_ID,
+ )
+ service_account_spec = persistent_resource_v1.ServiceAccountSpec(
+ enable_custom_service_account=True,
+ service_account=None,
+ )
+ expected_persistent_resource_arg.resource_runtime_spec = (
+ persistent_resource_v1.ResourceRuntimeSpec(
+ service_account_spec=service_account_spec
+ )
+ )
+
+ create_persistent_resource_mock.assert_called_once_with(
+ parent=_TEST_PARENT,
+ persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID,
+ persistent_resource=expected_persistent_resource_arg,
+ timeout=None,
+ )
+ get_persistent_resource_mock.assert_called_once()
+ _, mock_kwargs = get_persistent_resource_mock.call_args
+ assert mock_kwargs["name"] == _get_resource_name(
+ name=_TEST_PERSISTENT_RESOURCE_ID
+ )
+
+ @pytest.mark.parametrize("sync", [True, False])
+ def test_create_persistent_resource_enable_custom_sa_false_raises_error(
+ self,
+ create_persistent_resource_mock,
+ get_persistent_resource_mock,
+ sync,
+ ):
+ with pytest.raises(ValueError) as excinfo:
+ my_test_resource = persistent_resource.PersistentResource.create(
+ persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID,
+ resource_pools=[
+ test_constants.PersistentResourceConstants._TEST_RESOURCE_POOL,
+ ],
+ enable_custom_service_account=False,
+ service_account=_TEST_SERVICE_ACCOUNT,
+ sync=sync,
+ )
+ if not sync:
+ my_test_resource.wait()
+
+ assert str(excinfo.value) == (
+ "The parameter `enable_custom_service_account` was set to False, "
+ "but a value was provided for `service_account`. These two "
+ "settings are incompatible. If you want to use a custom "
+ "service account, set `enable_custom_service_account` to True."
+ )
+
+ create_persistent_resource_mock.assert_not_called()
+ get_persistent_resource_mock.assert_not_called()
+
+ @pytest.mark.parametrize("sync", [True, False])
+ def test_create_persistent_resource_enable_custom_sa_none_with_sa(
+ self,
+ create_persistent_resource_mock,
+ get_persistent_resource_mock,
+ sync,
+ ):
+ my_test_resource = persistent_resource.PersistentResource.create(
+ persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID,
+ resource_pools=[
+ test_constants.PersistentResourceConstants._TEST_RESOURCE_POOL,
+ ],
+ enable_custom_service_account=None,
+ service_account=_TEST_SERVICE_ACCOUNT,
+ sync=sync,
+ )
+
+ if not sync:
+ my_test_resource.wait()
+
+ expected_persistent_resource_arg = _get_persistent_resource_proto(
+ name=_TEST_PERSISTENT_RESOURCE_ID,
+ )
+ service_account_spec = persistent_resource_v1.ServiceAccountSpec(
+ enable_custom_service_account=True,
+ service_account=_TEST_SERVICE_ACCOUNT,
+ )
+ expected_persistent_resource_arg.resource_runtime_spec = (
+ persistent_resource_v1.ResourceRuntimeSpec(
+ service_account_spec=service_account_spec
+ )
+ )
+
+ create_persistent_resource_mock.assert_called_once_with(
+ parent=_TEST_PARENT,
+ persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID,
+ persistent_resource=expected_persistent_resource_arg,
+ timeout=None,
+ )
+ get_persistent_resource_mock.assert_called_once()
+ _, mock_kwargs = get_persistent_resource_mock.call_args
+ assert mock_kwargs["name"] == _get_resource_name(
+ name=_TEST_PERSISTENT_RESOURCE_ID
+ )
+
+ @pytest.mark.parametrize("sync", [True, False])
+ def test_create_persistent_resource_enable_custom_sa_none_no_sa(
+ self,
+ create_persistent_resource_mock,
+ get_persistent_resource_mock,
+ sync,
+ ):
+ my_test_resource = persistent_resource.PersistentResource.create(
+ persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID,
+ resource_pools=[
+ test_constants.PersistentResourceConstants._TEST_RESOURCE_POOL,
+ ],
+ enable_custom_service_account=None,
+ sync=sync,
+ )
+
+ if not sync:
+ my_test_resource.wait()
+
+ expected_persistent_resource_arg = _get_persistent_resource_proto(
+ name=_TEST_PERSISTENT_RESOURCE_ID,
+ )
+
+ # Assert that resource_runtime_spec is NOT set
+ call_args = create_persistent_resource_mock.call_args.kwargs
+ assert "resource_runtime_spec" not in call_args["persistent_resource"]
+
+ create_persistent_resource_mock.assert_called_once_with(
+ parent=_TEST_PARENT,
+ persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID,
+ persistent_resource=expected_persistent_resource_arg,
+ timeout=None,
+ )
+ get_persistent_resource_mock.assert_called_once()
+ _, mock_kwargs = get_persistent_resource_mock.call_args
+ assert mock_kwargs["name"] == _get_resource_name(
+ name=_TEST_PERSISTENT_RESOURCE_ID
+ )
+
def test_list_persistent_resources(self, list_persistent_resources_mock):
resource_list = persistent_resource.PersistentResource.list()
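
These tests exercise the new enable_custom_service_account parameter. A hedged sketch of the user-facing call, assuming the v1 types are used for the resource pool; the project, IDs, and service account below are placeholders:

    from google.cloud import aiplatform
    from google.cloud.aiplatform import persistent_resource
    from google.cloud.aiplatform_v1.types import machine_resources
    from google.cloud.aiplatform_v1.types import persistent_resource as pr_types

    aiplatform.init(project="my-project", location="us-central1")

    pool = pr_types.ResourcePool(
        machine_spec=machine_resources.MachineSpec(machine_type="n1-standard-4"),
        replica_count=2,
    )

    resource = persistent_resource.PersistentResource.create(
        persistent_resource_id="my-persistent-resource",
        resource_pools=[pool],
        enable_custom_service_account=True,
        service_account="my-sa@my-project.iam.gserviceaccount.com",
    )
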
diff --git a/tests/unit/aiplatform/test_persistent_resource_preview.py b/tests/unit/aiplatform/test_persistent_resource_preview.py
index f189af2a8b..d45bf51816 100644
--- a/tests/unit/aiplatform/test_persistent_resource_preview.py
+++ b/tests/unit/aiplatform/test_persistent_resource_preview.py
@@ -56,6 +56,9 @@
_TEST_RESERVED_IP_RANGES = test_constants.TrainingJobConstants._TEST_RESERVED_IP_RANGES
_TEST_KEY_NAME = test_constants.TrainingJobConstants._TEST_DEFAULT_ENCRYPTION_KEY_NAME
_TEST_SERVICE_ACCOUNT = test_constants.ProjectConstants._TEST_SERVICE_ACCOUNT
+_TEST_ENABLE_CUSTOM_SERVICE_ACCOUNT_TRUE = (
+ test_constants.ProjectConstants._TEST_ENABLE_CUSTOM_SERVICE_ACCOUNT_TRUE
+)
_TEST_PERSISTENT_RESOURCE_PROTO = persistent_resource_compat.PersistentResource(
name=_TEST_PERSISTENT_RESOURCE_ID,
@@ -291,7 +294,7 @@ def test_create_persistent_resource_with_kms_key(
)
@pytest.mark.parametrize("sync", [True, False])
- def test_create_persistent_resource_with_service_account(
+ def test_create_persistent_resource_enable_custom_sa_true_with_sa(
self,
create_preview_persistent_resource_mock,
get_persistent_resource_mock,
@@ -302,6 +305,7 @@ def test_create_persistent_resource_with_service_account(
resource_pools=[
test_constants.PersistentResourceConstants._TEST_RESOURCE_POOL,
],
+ enable_custom_service_account=_TEST_ENABLE_CUSTOM_SERVICE_ACCOUNT_TRUE,
service_account=_TEST_SERVICE_ACCOUNT,
sync=sync,
)
@@ -312,9 +316,53 @@ def test_create_persistent_resource_with_service_account(
expected_persistent_resource_arg = _get_persistent_resource_proto(
name=_TEST_PERSISTENT_RESOURCE_ID,
)
+ service_account_spec = persistent_resource_compat.ServiceAccountSpec(
+ enable_custom_service_account=_TEST_ENABLE_CUSTOM_SERVICE_ACCOUNT_TRUE,
+ service_account=_TEST_SERVICE_ACCOUNT,
+ )
+ expected_persistent_resource_arg.resource_runtime_spec = (
+ persistent_resource_compat.ResourceRuntimeSpec(
+ service_account_spec=service_account_spec
+ )
+ )
+
+ create_preview_persistent_resource_mock.assert_called_once_with(
+ parent=_TEST_PARENT,
+ persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID,
+ persistent_resource=expected_persistent_resource_arg,
+ timeout=None,
+ )
+ get_persistent_resource_mock.assert_called_once()
+ _, mock_kwargs = get_persistent_resource_mock.call_args
+ assert mock_kwargs["name"] == _get_resource_name(
+ name=_TEST_PERSISTENT_RESOURCE_ID
+ )
+
+ @pytest.mark.parametrize("sync", [True, False])
+ def test_create_persistent_resource_enable_custom_sa_true_no_sa(
+ self,
+ create_preview_persistent_resource_mock,
+ get_persistent_resource_mock,
+ sync,
+ ):
+ my_test_resource = persistent_resource.PersistentResource.create(
+ persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID,
+ resource_pools=[
+ test_constants.PersistentResourceConstants._TEST_RESOURCE_POOL,
+ ],
+ enable_custom_service_account=True,
+ sync=sync,
+ )
+
+ if not sync:
+ my_test_resource.wait()
+ expected_persistent_resource_arg = _get_persistent_resource_proto(
+ name=_TEST_PERSISTENT_RESOURCE_ID,
+ )
service_account_spec = persistent_resource_compat.ServiceAccountSpec(
- enable_custom_service_account=True, service_account=_TEST_SERVICE_ACCOUNT
+ enable_custom_service_account=True,
+ service_account=None,
)
expected_persistent_resource_arg.resource_runtime_spec = (
persistent_resource_compat.ResourceRuntimeSpec(
@@ -334,6 +382,120 @@ def test_create_persistent_resource_with_service_account(
name=_TEST_PERSISTENT_RESOURCE_ID
)
+ @pytest.mark.parametrize("sync", [True, False])
+ def test_create_persistent_resource_enable_custom_sa_false_with_sa_raises_error(
+ self,
+ create_preview_persistent_resource_mock,
+ get_persistent_resource_mock,
+ sync,
+ ):
+ with pytest.raises(ValueError) as excinfo:
+ my_test_resource = persistent_resource.PersistentResource.create(
+ persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID,
+ resource_pools=[
+ test_constants.PersistentResourceConstants._TEST_RESOURCE_POOL,
+ ],
+ enable_custom_service_account=False,
+ service_account=_TEST_SERVICE_ACCOUNT,
+ sync=sync,
+ )
+ if not sync:
+ my_test_resource.wait()
+
+ assert str(excinfo.value) == (
+ "The parameter `enable_custom_service_account` was set to False, "
+ "but a value was provided for `service_account`. These two "
+ "settings are incompatible. If you want to use a custom "
+ "service account, set `enable_custom_service_account` to True."
+ )
+
+ create_preview_persistent_resource_mock.assert_not_called()
+ get_persistent_resource_mock.assert_not_called()
+
+ @pytest.mark.parametrize("sync", [True, False])
+ def test_create_persistent_resource_enable_custom_sa_none_with_sa(
+ self,
+ create_preview_persistent_resource_mock,
+ get_persistent_resource_mock,
+ sync,
+ ):
+ my_test_resource = persistent_resource.PersistentResource.create(
+ persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID,
+ resource_pools=[
+ test_constants.PersistentResourceConstants._TEST_RESOURCE_POOL,
+ ],
+ enable_custom_service_account=None,
+ service_account=_TEST_SERVICE_ACCOUNT,
+ sync=sync,
+ )
+
+ if not sync:
+ my_test_resource.wait()
+
+ expected_persistent_resource_arg = _get_persistent_resource_proto(
+ name=_TEST_PERSISTENT_RESOURCE_ID,
+ )
+ service_account_spec = persistent_resource_compat.ServiceAccountSpec(
+ enable_custom_service_account=True,
+ service_account=_TEST_SERVICE_ACCOUNT,
+ )
+ expected_persistent_resource_arg.resource_runtime_spec = (
+ persistent_resource_compat.ResourceRuntimeSpec(
+ service_account_spec=service_account_spec
+ )
+ )
+
+ create_preview_persistent_resource_mock.assert_called_once_with(
+ parent=_TEST_PARENT,
+ persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID,
+ persistent_resource=expected_persistent_resource_arg,
+ timeout=None,
+ )
+ get_persistent_resource_mock.assert_called_once()
+ _, mock_kwargs = get_persistent_resource_mock.call_args
+ assert mock_kwargs["name"] == _get_resource_name(
+ name=_TEST_PERSISTENT_RESOURCE_ID
+ )
+
+ @pytest.mark.parametrize("sync", [True, False])
+ def test_create_persistent_resource_enable_custom_sa_none_no_sa(
+ self,
+ create_preview_persistent_resource_mock,
+ get_persistent_resource_mock,
+ sync,
+ ):
+ my_test_resource = persistent_resource.PersistentResource.create(
+ persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID,
+ resource_pools=[
+ test_constants.PersistentResourceConstants._TEST_RESOURCE_POOL,
+ ],
+ enable_custom_service_account=None,
+ sync=sync,
+ )
+
+ if not sync:
+ my_test_resource.wait()
+
+ expected_persistent_resource_arg = _get_persistent_resource_proto(
+ name=_TEST_PERSISTENT_RESOURCE_ID,
+ )
+
+ # Assert that resource_runtime_spec is NOT set
+ call_args = create_preview_persistent_resource_mock.call_args.kwargs
+ assert "resource_runtime_spec" not in call_args["persistent_resource"]
+
+ create_preview_persistent_resource_mock.assert_called_once_with(
+ parent=_TEST_PARENT,
+ persistent_resource_id=_TEST_PERSISTENT_RESOURCE_ID,
+ persistent_resource=expected_persistent_resource_arg,
+ timeout=None,
+ )
+ get_persistent_resource_mock.assert_called_once()
+ _, mock_kwargs = get_persistent_resource_mock.call_args
+ assert mock_kwargs["name"] == _get_resource_name(
+ name=_TEST_PERSISTENT_RESOURCE_ID
+ )
+
def test_list_persistent_resources(self, list_persistent_resources_mock):
resource_list = persistent_resource.PersistentResource.list()
diff --git a/tests/unit/gapic/aiplatform_v1/test_gen_ai_cache_service.py b/tests/unit/gapic/aiplatform_v1/test_gen_ai_cache_service.py
index ce2fe8eddd..20effd6e03 100644
--- a/tests/unit/gapic/aiplatform_v1/test_gen_ai_cache_service.py
+++ b/tests/unit/gapic/aiplatform_v1/test_gen_ai_cache_service.py
@@ -73,6 +73,7 @@
from google.cloud.aiplatform_v1.types import cached_content
from google.cloud.aiplatform_v1.types import cached_content as gca_cached_content
from google.cloud.aiplatform_v1.types import content
+from google.cloud.aiplatform_v1.types import encryption_spec
from google.cloud.aiplatform_v1.types import gen_ai_cache_service
from google.cloud.aiplatform_v1.types import openapi
from google.cloud.aiplatform_v1.types import tool
@@ -4665,6 +4666,8 @@ def test_create_cached_content_rest_call_success(request_type):
"pattern": "pattern_value",
"example": {},
"any_of": {},
+ "ref": "ref_value",
+ "defs": {},
},
"response": {},
}
@@ -4731,6 +4734,7 @@ def test_create_cached_content_rest_call_success(request_type):
"video_duration_seconds": 2346,
"audio_duration_seconds": 2341,
},
+ "encryption_spec": {"kms_key_name": "kms_key_name_value"},
}
# The version of a generated dependency at test runtime may differ from the version used during generation.
# Delete any fields which are not present in the current runtime dependency
@@ -5144,6 +5148,8 @@ def test_update_cached_content_rest_call_success(request_type):
"pattern": "pattern_value",
"example": {},
"any_of": {},
+ "ref": "ref_value",
+ "defs": {},
},
"response": {},
}
@@ -5210,6 +5216,7 @@ def test_update_cached_content_rest_call_success(request_type):
"video_duration_seconds": 2346,
"audio_duration_seconds": 2341,
},
+ "encryption_spec": {"kms_key_name": "kms_key_name_value"},
}
# The version of a generated dependency at test runtime may differ from the version used during generation.
# Delete any fields which are not present in the current runtime dependency
@@ -6485,6 +6492,8 @@ async def test_create_cached_content_rest_asyncio_call_success(request_type):
"pattern": "pattern_value",
"example": {},
"any_of": {},
+ "ref": "ref_value",
+ "defs": {},
},
"response": {},
}
@@ -6551,6 +6560,7 @@ async def test_create_cached_content_rest_asyncio_call_success(request_type):
"video_duration_seconds": 2346,
"audio_duration_seconds": 2341,
},
+ "encryption_spec": {"kms_key_name": "kms_key_name_value"},
}
# The version of a generated dependency at test runtime may differ from the version used during generation.
# Delete any fields which are not present in the current runtime dependency
@@ -6996,6 +7006,8 @@ async def test_update_cached_content_rest_asyncio_call_success(request_type):
"pattern": "pattern_value",
"example": {},
"any_of": {},
+ "ref": "ref_value",
+ "defs": {},
},
"response": {},
}
@@ -7062,6 +7074,7 @@ async def test_update_cached_content_rest_asyncio_call_success(request_type):
"video_duration_seconds": 2346,
"audio_duration_seconds": 2341,
},
+ "encryption_spec": {"kms_key_name": "kms_key_name_value"},
}
# The version of a generated dependency at test runtime may differ from the version used during generation.
# Delete any fields which are not present in the current runtime dependency
diff --git a/tests/unit/gapic/aiplatform_v1/test_migration_service.py b/tests/unit/gapic/aiplatform_v1/test_migration_service.py
index 9e8dced432..f0c5ab4866 100644
--- a/tests/unit/gapic/aiplatform_v1/test_migration_service.py
+++ b/tests/unit/gapic/aiplatform_v1/test_migration_service.py
@@ -5398,22 +5398,19 @@ def test_parse_annotated_dataset_path():
def test_dataset_path():
project = "cuttlefish"
- location = "mussel"
- dataset = "winkle"
- expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(
+ dataset = "mussel"
+ expected = "projects/{project}/datasets/{dataset}".format(
project=project,
- location=location,
dataset=dataset,
)
- actual = MigrationServiceClient.dataset_path(project, location, dataset)
+ actual = MigrationServiceClient.dataset_path(project, dataset)
assert expected == actual
def test_parse_dataset_path():
expected = {
- "project": "nautilus",
- "location": "scallop",
- "dataset": "abalone",
+ "project": "winkle",
+ "dataset": "nautilus",
}
path = MigrationServiceClient.dataset_path(**expected)
@@ -5424,18 +5421,21 @@ def test_parse_dataset_path():
def test_dataset_path():
project = "squid"
- dataset = "clam"
- expected = "projects/{project}/datasets/{dataset}".format(
+ location = "clam"
+ dataset = "whelk"
+ expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(
project=project,
+ location=location,
dataset=dataset,
)
- actual = MigrationServiceClient.dataset_path(project, dataset)
+ actual = MigrationServiceClient.dataset_path(project, location, dataset)
assert expected == actual
def test_parse_dataset_path():
expected = {
- "project": "whelk",
+ "project": "clam",
+ "location": "whelk",
"dataset": "octopus",
}
path = MigrationServiceClient.dataset_path(**expected)
@@ -5446,22 +5446,19 @@ def test_parse_dataset_path():
def test_dataset_path():
- project = "oyster"
- location = "nudibranch"
- dataset = "cuttlefish"
- expected = "projects/{project}/locations/{location}/datasets/{dataset}".format(
+ project = "cuttlefish"
+ dataset = "mussel"
+ expected = "projects/{project}/datasets/{dataset}".format(
project=project,
- location=location,
dataset=dataset,
)
- actual = MigrationServiceClient.dataset_path(project, location, dataset)
+ actual = MigrationServiceClient.dataset_path(project, dataset)
assert expected == actual
def test_parse_dataset_path():
expected = {
- "project": "mussel",
- "location": "winkle",
+ "project": "winkle",
"dataset": "nautilus",
}
path = MigrationServiceClient.dataset_path(**expected)
diff --git a/tests/unit/gapic/aiplatform_v1/test_model_service.py b/tests/unit/gapic/aiplatform_v1/test_model_service.py
index ea52d506db..4a299dca9c 100644
--- a/tests/unit/gapic/aiplatform_v1/test_model_service.py
+++ b/tests/unit/gapic/aiplatform_v1/test_model_service.py
@@ -14989,6 +14989,9 @@ def test_update_model_rest_call_success(request_type):
},
"satisfies_pzs": True,
"satisfies_pzi": True,
+ "checkpoints": [
+ {"checkpoint_id": "checkpoint_id_value", "epoch": 527, "step": 444}
+ ],
}
# The version of a generated dependency at test runtime may differ from the version used during generation.
# Delete any fields which are not present in the current runtime dependency
@@ -18970,6 +18973,9 @@ async def test_update_model_rest_asyncio_call_success(request_type):
},
"satisfies_pzs": True,
"satisfies_pzi": True,
+ "checkpoints": [
+ {"checkpoint_id": "checkpoint_id_value", "epoch": 527, "step": 444}
+ ],
}
# The version of a generated dependency at test runtime may differ from the version used during generation.
# Delete any fields which are not present in the current runtime dependency
diff --git a/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py b/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py
index 9bf770ce4e..d60c2e7eb5 100644
--- a/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py
+++ b/tests/unit/gapic/aiplatform_v1/test_pipeline_service.py
@@ -9226,6 +9226,9 @@ def test_create_training_pipeline_rest_call_success(request_type):
},
"satisfies_pzs": True,
"satisfies_pzi": True,
+ "checkpoints": [
+ {"checkpoint_id": "checkpoint_id_value", "epoch": 527, "step": 444}
+ ],
},
"model_id": "model_id_value",
"parent_model": "parent_model_value",
@@ -12161,6 +12164,9 @@ async def test_create_training_pipeline_rest_asyncio_call_success(request_type):
},
"satisfies_pzs": True,
"satisfies_pzi": True,
+ "checkpoints": [
+ {"checkpoint_id": "checkpoint_id_value", "epoch": 527, "step": 444}
+ ],
},
"model_id": "model_id_value",
"parent_model": "parent_model_value",
diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_extension_registry_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_extension_registry_service.py
index 5123fc4e2d..0288684a06 100644
--- a/tests/unit/gapic/aiplatform_v1beta1/test_extension_registry_service.py
+++ b/tests/unit/gapic/aiplatform_v1beta1/test_extension_registry_service.py
@@ -4577,6 +4577,8 @@ def test_import_extension_rest_call_success(request_type):
"pattern": "pattern_value",
"example": {},
"any_of": {},
+ "ref": "ref_value",
+ "defs": {},
},
"response": {},
},
@@ -5153,6 +5155,8 @@ def test_update_extension_rest_call_success(request_type):
"pattern": "pattern_value",
"example": {},
"any_of": {},
+ "ref": "ref_value",
+ "defs": {},
},
"response": {},
},
@@ -6359,6 +6363,8 @@ async def test_import_extension_rest_asyncio_call_success(request_type):
"pattern": "pattern_value",
"example": {},
"any_of": {},
+ "ref": "ref_value",
+ "defs": {},
},
"response": {},
},
@@ -6983,6 +6989,8 @@ async def test_update_extension_rest_asyncio_call_success(request_type):
"pattern": "pattern_value",
"example": {},
"any_of": {},
+ "ref": "ref_value",
+ "defs": {},
},
"response": {},
},
diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_gen_ai_cache_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_gen_ai_cache_service.py
index 5ad43802b0..13ad36d108 100644
--- a/tests/unit/gapic/aiplatform_v1beta1/test_gen_ai_cache_service.py
+++ b/tests/unit/gapic/aiplatform_v1beta1/test_gen_ai_cache_service.py
@@ -73,6 +73,7 @@
from google.cloud.aiplatform_v1beta1.types import cached_content
from google.cloud.aiplatform_v1beta1.types import cached_content as gca_cached_content
from google.cloud.aiplatform_v1beta1.types import content
+from google.cloud.aiplatform_v1beta1.types import encryption_spec
from google.cloud.aiplatform_v1beta1.types import gen_ai_cache_service
from google.cloud.aiplatform_v1beta1.types import openapi
from google.cloud.aiplatform_v1beta1.types import tool
@@ -4674,6 +4675,8 @@ def test_create_cached_content_rest_call_success(request_type):
"pattern": "pattern_value",
"example": {},
"any_of": {},
+ "ref": "ref_value",
+ "defs": {},
},
"response": {},
}
@@ -4742,6 +4745,7 @@ def test_create_cached_content_rest_call_success(request_type):
"video_duration_seconds": 2346,
"audio_duration_seconds": 2341,
},
+ "encryption_spec": {"kms_key_name": "kms_key_name_value"},
}
# The version of a generated dependency at test runtime may differ from the version used during generation.
# Delete any fields which are not present in the current runtime dependency
@@ -5164,6 +5168,8 @@ def test_update_cached_content_rest_call_success(request_type):
"pattern": "pattern_value",
"example": {},
"any_of": {},
+ "ref": "ref_value",
+ "defs": {},
},
"response": {},
}
@@ -5232,6 +5238,7 @@ def test_update_cached_content_rest_call_success(request_type):
"video_duration_seconds": 2346,
"audio_duration_seconds": 2341,
},
+ "encryption_spec": {"kms_key_name": "kms_key_name_value"},
}
# The version of a generated dependency at test runtime may differ from the version used during generation.
# Delete any fields which are not present in the current runtime dependency
@@ -6516,6 +6523,8 @@ async def test_create_cached_content_rest_asyncio_call_success(request_type):
"pattern": "pattern_value",
"example": {},
"any_of": {},
+ "ref": "ref_value",
+ "defs": {},
},
"response": {},
}
@@ -6584,6 +6593,7 @@ async def test_create_cached_content_rest_asyncio_call_success(request_type):
"video_duration_seconds": 2346,
"audio_duration_seconds": 2341,
},
+ "encryption_spec": {"kms_key_name": "kms_key_name_value"},
}
# The version of a generated dependency at test runtime may differ from the version used during generation.
# Delete any fields which are not present in the current runtime dependency
@@ -7038,6 +7048,8 @@ async def test_update_cached_content_rest_asyncio_call_success(request_type):
"pattern": "pattern_value",
"example": {},
"any_of": {},
+ "ref": "ref_value",
+ "defs": {},
},
"response": {},
}
@@ -7106,6 +7118,7 @@ async def test_update_cached_content_rest_asyncio_call_success(request_type):
"video_duration_seconds": 2346,
"audio_duration_seconds": 2341,
},
+ "encryption_spec": {"kms_key_name": "kms_key_name_value"},
}
# The version of a generated dependency at test runtime may differ from the version used during generation.
# Delete any fields which are not present in the current runtime dependency
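
The cache-service fixtures now include an `encryption_spec` on the CachedContent payload, and the test imports the `encryption_spec` types module. A hedged sketch of attaching a customer-managed encryption key through the v1beta1 types; the model and KMS key names are placeholders:

    from google.cloud.aiplatform_v1beta1.types import cached_content, encryption_spec

    # Sketch: cached content protected by a CMEK key (placeholder resource names).
    content = cached_content.CachedContent(
        model="publishers/google/models/gemini-1.5-pro",  # assumed model name
        encryption_spec=encryption_spec.EncryptionSpec(
            kms_key_name="projects/p/locations/l/keyRings/r/cryptoKeys/k",
        ),
    )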
diff --git a/tests/unit/gapic/aiplatform_v1beta1/test_session_service.py b/tests/unit/gapic/aiplatform_v1beta1/test_session_service.py
index 99c757d313..66289a380b 100644
--- a/tests/unit/gapic/aiplatform_v1beta1/test_session_service.py
+++ b/tests/unit/gapic/aiplatform_v1beta1/test_session_service.py
@@ -3980,7 +3980,9 @@ def test_create_session_rest_flattened():
return_value = operations_pb2.Operation(name="operations/spam")
# get arguments that satisfy an http rule for this method
- sample_request = {"parent": "projects/sample1/locations/sample2"}
+ sample_request = {
+ "parent": "projects/sample1/locations/sample2/reasoningEngines/sample3"
+ }
# get truthy value for each flattened field
mock_args = dict(
@@ -4004,7 +4006,7 @@ def test_create_session_rest_flattened():
assert len(req.mock_calls) == 1
_, args, _ = req.mock_calls[0]
assert path_template.validate(
- "%s/v1beta1/{parent=projects/*/locations/*}/sessions"
+ "%s/v1beta1/{parent=projects/*/locations/*/reasoningEngines/*}/sessions"
% client.transport._host,
args[1],
)
@@ -4158,7 +4160,9 @@ def test_get_session_rest_flattened():
return_value = session.Session()
# get arguments that satisfy an http rule for this method
- sample_request = {"name": "projects/sample1/locations/sample2/sessions/sample3"}
+ sample_request = {
+ "name": "projects/sample1/locations/sample2/reasoningEngines/sample3/sessions/sample4"
+ }
# get truthy value for each flattened field
mock_args = dict(
@@ -4183,7 +4187,7 @@ def test_get_session_rest_flattened():
assert len(req.mock_calls) == 1
_, args, _ = req.mock_calls[0]
assert path_template.validate(
- "%s/v1beta1/{name=projects/*/locations/*/sessions/*}"
+ "%s/v1beta1/{name=projects/*/locations/*/reasoningEngines/*/sessions/*}"
% client.transport._host,
args[1],
)
@@ -4355,7 +4359,9 @@ def test_list_sessions_rest_flattened():
return_value = session_service.ListSessionsResponse()
# get arguments that satisfy an http rule for this method
- sample_request = {"parent": "projects/sample1/locations/sample2"}
+ sample_request = {
+ "parent": "projects/sample1/locations/sample2/reasoningEngines/sample3"
+ }
# get truthy value for each flattened field
mock_args = dict(
@@ -4380,7 +4386,7 @@ def test_list_sessions_rest_flattened():
assert len(req.mock_calls) == 1
_, args, _ = req.mock_calls[0]
assert path_template.validate(
- "%s/v1beta1/{parent=projects/*/locations/*}/sessions"
+ "%s/v1beta1/{parent=projects/*/locations/*/reasoningEngines/*}/sessions"
% client.transport._host,
args[1],
)
@@ -4451,7 +4457,9 @@ def test_list_sessions_rest_pager(transport: str = "rest"):
return_val.status_code = 200
req.side_effect = return_values
- sample_request = {"parent": "projects/sample1/locations/sample2"}
+ sample_request = {
+ "parent": "projects/sample1/locations/sample2/reasoningEngines/sample3"
+ }
pager = client.list_sessions(request=sample_request)
@@ -4595,7 +4603,9 @@ def test_update_session_rest_flattened():
# get arguments that satisfy an http rule for this method
sample_request = {
- "session": {"name": "projects/sample1/locations/sample2/sessions/sample3"}
+ "session": {
+ "name": "projects/sample1/locations/sample2/reasoningEngines/sample3/sessions/sample4"
+ }
}
# get truthy value for each flattened field
@@ -4622,7 +4632,7 @@ def test_update_session_rest_flattened():
assert len(req.mock_calls) == 1
_, args, _ = req.mock_calls[0]
assert path_template.validate(
- "%s/v1beta1/{session.name=projects/*/locations/*/sessions/*}"
+ "%s/v1beta1/{session.name=projects/*/locations/*/reasoningEngines/*/sessions/*}"
% client.transport._host,
args[1],
)
@@ -4777,7 +4787,9 @@ def test_delete_session_rest_flattened():
return_value = operations_pb2.Operation(name="operations/spam")
# get arguments that satisfy an http rule for this method
- sample_request = {"name": "projects/sample1/locations/sample2/sessions/sample3"}
+ sample_request = {
+ "name": "projects/sample1/locations/sample2/reasoningEngines/sample3/sessions/sample4"
+ }
# get truthy value for each flattened field
mock_args = dict(
@@ -4800,7 +4812,7 @@ def test_delete_session_rest_flattened():
assert len(req.mock_calls) == 1
_, args, _ = req.mock_calls[0]
assert path_template.validate(
- "%s/v1beta1/{name=projects/*/locations/*/sessions/*}"
+ "%s/v1beta1/{name=projects/*/locations/*/reasoningEngines/*/sessions/*}"
% client.transport._host,
args[1],
)
@@ -5740,7 +5752,9 @@ def test_create_session_rest_bad_request(
credentials=ga_credentials.AnonymousCredentials(), transport="rest"
)
# send a request that will satisfy transcoding
- request_init = {"parent": "projects/sample1/locations/sample2"}
+ request_init = {
+ "parent": "projects/sample1/locations/sample2/reasoningEngines/sample3"
+ }
request = request_type(**request_init)
# Mock the http request call within the method and fake a BadRequest error.
@@ -5771,7 +5785,9 @@ def test_create_session_rest_call_success(request_type):
)
# send a request that will satisfy transcoding
- request_init = {"parent": "projects/sample1/locations/sample2"}
+ request_init = {
+ "parent": "projects/sample1/locations/sample2/reasoningEngines/sample3"
+ }
request_init["session"] = {
"name": "name_value",
"create_time": {"seconds": 751, "nanos": 543},
@@ -5936,7 +5952,9 @@ def test_get_session_rest_bad_request(request_type=session_service.GetSessionReq
credentials=ga_credentials.AnonymousCredentials(), transport="rest"
)
# send a request that will satisfy transcoding
- request_init = {"name": "projects/sample1/locations/sample2/sessions/sample3"}
+ request_init = {
+ "name": "projects/sample1/locations/sample2/reasoningEngines/sample3/sessions/sample4"
+ }
request = request_type(**request_init)
# Mock the http request call within the method and fake a BadRequest error.
@@ -5967,7 +5985,9 @@ def test_get_session_rest_call_success(request_type):
)
# send a request that will satisfy transcoding
- request_init = {"name": "projects/sample1/locations/sample2/sessions/sample3"}
+ request_init = {
+ "name": "projects/sample1/locations/sample2/reasoningEngines/sample3/sessions/sample4"
+ }
request = request_type(**request_init)
# Mock the http request call within the method and fake a response.
@@ -6067,7 +6087,9 @@ def test_list_sessions_rest_bad_request(
credentials=ga_credentials.AnonymousCredentials(), transport="rest"
)
# send a request that will satisfy transcoding
- request_init = {"parent": "projects/sample1/locations/sample2"}
+ request_init = {
+ "parent": "projects/sample1/locations/sample2/reasoningEngines/sample3"
+ }
request = request_type(**request_init)
# Mock the http request call within the method and fake a BadRequest error.
@@ -6098,7 +6120,9 @@ def test_list_sessions_rest_call_success(request_type):
)
# send a request that will satisfy transcoding
- request_init = {"parent": "projects/sample1/locations/sample2"}
+ request_init = {
+ "parent": "projects/sample1/locations/sample2/reasoningEngines/sample3"
+ }
request = request_type(**request_init)
# Mock the http request call within the method and fake a response.
@@ -6200,7 +6224,9 @@ def test_update_session_rest_bad_request(
)
# send a request that will satisfy transcoding
request_init = {
- "session": {"name": "projects/sample1/locations/sample2/sessions/sample3"}
+ "session": {
+ "name": "projects/sample1/locations/sample2/reasoningEngines/sample3/sessions/sample4"
+ }
}
request = request_type(**request_init)
@@ -6233,10 +6259,12 @@ def test_update_session_rest_call_success(request_type):
# send a request that will satisfy transcoding
request_init = {
- "session": {"name": "projects/sample1/locations/sample2/sessions/sample3"}
+ "session": {
+ "name": "projects/sample1/locations/sample2/reasoningEngines/sample3/sessions/sample4"
+ }
}
request_init["session"] = {
- "name": "projects/sample1/locations/sample2/sessions/sample3",
+ "name": "projects/sample1/locations/sample2/reasoningEngines/sample3/sessions/sample4",
"create_time": {"seconds": 751, "nanos": 543},
"update_time": {},
"display_name": "display_name_value",
@@ -6409,7 +6437,9 @@ def test_delete_session_rest_bad_request(
credentials=ga_credentials.AnonymousCredentials(), transport="rest"
)
# send a request that will satisfy transcoding
- request_init = {"name": "projects/sample1/locations/sample2/sessions/sample3"}
+ request_init = {
+ "name": "projects/sample1/locations/sample2/reasoningEngines/sample3/sessions/sample4"
+ }
request = request_type(**request_init)
# Mock the http request call within the method and fake a BadRequest error.
@@ -6440,7 +6470,9 @@ def test_delete_session_rest_call_success(request_type):
)
# send a request that will satisfy transcoding
- request_init = {"name": "projects/sample1/locations/sample2/sessions/sample3"}
+ request_init = {
+ "name": "projects/sample1/locations/sample2/reasoningEngines/sample3/sessions/sample4"
+ }
request = request_type(**request_init)
# Mock the http request call within the method and fake a response.
@@ -6741,6 +6773,7 @@ def test_append_event_rest_call_success(request_type):
"transfer_to_agent": True,
"escalate": True,
"requested_auth_configs": {},
+ "transfer_agent": "transfer_agent_value",
},
"timestamp": {"seconds": 751, "nanos": 543},
"error_code": "error_code_value",
@@ -7766,7 +7799,9 @@ async def test_create_session_rest_asyncio_bad_request(
credentials=async_anonymous_credentials(), transport="rest_asyncio"
)
# send a request that will satisfy transcoding
- request_init = {"parent": "projects/sample1/locations/sample2"}
+ request_init = {
+ "parent": "projects/sample1/locations/sample2/reasoningEngines/sample3"
+ }
request = request_type(**request_init)
# Mock the http request call within the method and fake a BadRequest error.
@@ -7801,7 +7836,9 @@ async def test_create_session_rest_asyncio_call_success(request_type):
)
# send a request that will satisfy transcoding
- request_init = {"parent": "projects/sample1/locations/sample2"}
+ request_init = {
+ "parent": "projects/sample1/locations/sample2/reasoningEngines/sample3"
+ }
request_init["session"] = {
"name": "name_value",
"create_time": {"seconds": 751, "nanos": 543},
@@ -7981,7 +8018,9 @@ async def test_get_session_rest_asyncio_bad_request(
credentials=async_anonymous_credentials(), transport="rest_asyncio"
)
# send a request that will satisfy transcoding
- request_init = {"name": "projects/sample1/locations/sample2/sessions/sample3"}
+ request_init = {
+ "name": "projects/sample1/locations/sample2/reasoningEngines/sample3/sessions/sample4"
+ }
request = request_type(**request_init)
# Mock the http request call within the method and fake a BadRequest error.
@@ -8016,7 +8055,9 @@ async def test_get_session_rest_asyncio_call_success(request_type):
)
# send a request that will satisfy transcoding
- request_init = {"name": "projects/sample1/locations/sample2/sessions/sample3"}
+ request_init = {
+ "name": "projects/sample1/locations/sample2/reasoningEngines/sample3/sessions/sample4"
+ }
request = request_type(**request_init)
# Mock the http request call within the method and fake a response.
@@ -8128,7 +8169,9 @@ async def test_list_sessions_rest_asyncio_bad_request(
credentials=async_anonymous_credentials(), transport="rest_asyncio"
)
# send a request that will satisfy transcoding
- request_init = {"parent": "projects/sample1/locations/sample2"}
+ request_init = {
+ "parent": "projects/sample1/locations/sample2/reasoningEngines/sample3"
+ }
request = request_type(**request_init)
# Mock the http request call within the method and fake a BadRequest error.
@@ -8163,7 +8206,9 @@ async def test_list_sessions_rest_asyncio_call_success(request_type):
)
# send a request that will satisfy transcoding
- request_init = {"parent": "projects/sample1/locations/sample2"}
+ request_init = {
+ "parent": "projects/sample1/locations/sample2/reasoningEngines/sample3"
+ }
request = request_type(**request_init)
# Mock the http request call within the method and fake a response.
@@ -8278,7 +8323,9 @@ async def test_update_session_rest_asyncio_bad_request(
)
# send a request that will satisfy transcoding
request_init = {
- "session": {"name": "projects/sample1/locations/sample2/sessions/sample3"}
+ "session": {
+ "name": "projects/sample1/locations/sample2/reasoningEngines/sample3/sessions/sample4"
+ }
}
request = request_type(**request_init)
@@ -8315,10 +8362,12 @@ async def test_update_session_rest_asyncio_call_success(request_type):
# send a request that will satisfy transcoding
request_init = {
- "session": {"name": "projects/sample1/locations/sample2/sessions/sample3"}
+ "session": {
+ "name": "projects/sample1/locations/sample2/reasoningEngines/sample3/sessions/sample4"
+ }
}
request_init["session"] = {
- "name": "projects/sample1/locations/sample2/sessions/sample3",
+ "name": "projects/sample1/locations/sample2/reasoningEngines/sample3/sessions/sample4",
"create_time": {"seconds": 751, "nanos": 543},
"update_time": {},
"display_name": "display_name_value",
@@ -8504,7 +8553,9 @@ async def test_delete_session_rest_asyncio_bad_request(
credentials=async_anonymous_credentials(), transport="rest_asyncio"
)
# send a request that will satisfy transcoding
- request_init = {"name": "projects/sample1/locations/sample2/sessions/sample3"}
+ request_init = {
+ "name": "projects/sample1/locations/sample2/reasoningEngines/sample3/sessions/sample4"
+ }
request = request_type(**request_init)
# Mock the http request call within the method and fake a BadRequest error.
@@ -8539,7 +8590,9 @@ async def test_delete_session_rest_asyncio_call_success(request_type):
)
# send a request that will satisfy transcoding
- request_init = {"name": "projects/sample1/locations/sample2/sessions/sample3"}
+ request_init = {
+ "name": "projects/sample1/locations/sample2/reasoningEngines/sample3/sessions/sample4"
+ }
request = request_type(**request_init)
# Mock the http request call within the method and fake a response.
@@ -8877,6 +8930,7 @@ async def test_append_event_rest_asyncio_call_success(request_type):
"transfer_to_agent": True,
"escalate": True,
"requested_auth_configs": {},
+ "transfer_agent": "transfer_agent_value",
},
"timestamp": {"seconds": 751, "nanos": 543},
"error_code": "error_code_value",
@@ -10554,21 +10608,26 @@ def test_session_service_grpc_lro_async_client():
def test_session_path():
project = "squid"
location = "clam"
- session = "whelk"
- expected = "projects/{project}/locations/{location}/sessions/{session}".format(
+ reasoning_engine = "whelk"
+ session = "octopus"
+ expected = "projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}/sessions/{session}".format(
project=project,
location=location,
+ reasoning_engine=reasoning_engine,
session=session,
)
- actual = SessionServiceClient.session_path(project, location, session)
+ actual = SessionServiceClient.session_path(
+ project, location, reasoning_engine, session
+ )
assert expected == actual
def test_parse_session_path():
expected = {
- "project": "octopus",
- "location": "oyster",
- "session": "nudibranch",
+ "project": "oyster",
+ "location": "nudibranch",
+ "reasoning_engine": "cuttlefish",
+ "session": "mussel",
}
path = SessionServiceClient.session_path(**expected)
@@ -10578,26 +10637,31 @@ def test_parse_session_path():
def test_session_event_path():
- project = "cuttlefish"
- location = "mussel"
- session = "winkle"
- event = "nautilus"
- expected = "projects/{project}/locations/{location}/sessions/{session}/events/{event}".format(
+ project = "winkle"
+ location = "nautilus"
+ reasoning_engine = "scallop"
+ session = "abalone"
+ event = "squid"
+ expected = "projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}/sessions/{session}/events/{event}".format(
project=project,
location=location,
+ reasoning_engine=reasoning_engine,
session=session,
event=event,
)
- actual = SessionServiceClient.session_event_path(project, location, session, event)
+ actual = SessionServiceClient.session_event_path(
+ project, location, reasoning_engine, session, event
+ )
assert expected == actual
def test_parse_session_event_path():
expected = {
- "project": "scallop",
- "location": "abalone",
- "session": "squid",
- "event": "clam",
+ "project": "clam",
+ "location": "whelk",
+ "reasoning_engine": "octopus",
+ "session": "oyster",
+ "event": "nudibranch",
}
path = SessionServiceClient.session_event_path(**expected)
@@ -10607,7 +10671,7 @@ def test_parse_session_event_path():
def test_common_billing_account_path():
- billing_account = "whelk"
+ billing_account = "cuttlefish"
expected = "billingAccounts/{billing_account}".format(
billing_account=billing_account,
)
@@ -10617,7 +10681,7 @@ def test_common_billing_account_path():
def test_parse_common_billing_account_path():
expected = {
- "billing_account": "octopus",
+ "billing_account": "mussel",
}
path = SessionServiceClient.common_billing_account_path(**expected)
@@ -10627,7 +10691,7 @@ def test_parse_common_billing_account_path():
def test_common_folder_path():
- folder = "oyster"
+ folder = "winkle"
expected = "folders/{folder}".format(
folder=folder,
)
@@ -10637,7 +10701,7 @@ def test_common_folder_path():
def test_parse_common_folder_path():
expected = {
- "folder": "nudibranch",
+ "folder": "nautilus",
}
path = SessionServiceClient.common_folder_path(**expected)
@@ -10647,7 +10711,7 @@ def test_parse_common_folder_path():
def test_common_organization_path():
- organization = "cuttlefish"
+ organization = "scallop"
expected = "organizations/{organization}".format(
organization=organization,
)
@@ -10657,7 +10721,7 @@ def test_common_organization_path():
def test_parse_common_organization_path():
expected = {
- "organization": "mussel",
+ "organization": "abalone",
}
path = SessionServiceClient.common_organization_path(**expected)
@@ -10667,7 +10731,7 @@ def test_parse_common_organization_path():
def test_common_project_path():
- project = "winkle"
+ project = "squid"
expected = "projects/{project}".format(
project=project,
)
@@ -10677,7 +10741,7 @@ def test_common_project_path():
def test_parse_common_project_path():
expected = {
- "project": "nautilus",
+ "project": "clam",
}
path = SessionServiceClient.common_project_path(**expected)
@@ -10687,8 +10751,8 @@ def test_parse_common_project_path():
def test_common_location_path():
- project = "scallop"
- location = "abalone"
+ project = "whelk"
+ location = "octopus"
expected = "projects/{project}/locations/{location}".format(
project=project,
location=location,
@@ -10699,8 +10763,8 @@ def test_common_location_path():
def test_parse_common_location_path():
expected = {
- "project": "squid",
- "location": "clam",
+ "project": "oyster",
+ "location": "nudibranch",
}
path = SessionServiceClient.common_location_path(**expected)
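
The path-helper tests above reflect that v1beta1 sessions are now scoped under a reasoning engine. A short sketch of building the new resource names with the same helpers the tests exercise (all identifiers are placeholders):

    from google.cloud.aiplatform_v1beta1.services.session_service import (
        SessionServiceClient,
    )

    # Sessions (and their events) now live under a reasoningEngines collection.
    session_name = SessionServiceClient.session_path(
        "my-project", "us-central1", "my-reasoning-engine", "my-session"
    )
    # projects/my-project/locations/us-central1/reasoningEngines/my-reasoning-engine/sessions/my-session

    event_name = SessionServiceClient.session_event_path(
        "my-project", "us-central1", "my-reasoning-engine", "my-session", "my-event"
    )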
diff --git a/tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py b/tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py
index e576cc9120..46452de32e 100644
--- a/tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py
+++ b/tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py
@@ -17,6 +17,7 @@
from unittest import mock
from google import auth
+from google.genai import types
import vertexai
from google.cloud.aiplatform import initializer
from vertexai.preview import reasoning_engines
@@ -172,6 +173,28 @@ def test_stream_query(self):
)
assert len(events) == 1
+ def test_stream_query_with_content(self):
+ app = reasoning_engines.AdkApp(
+ agent=Agent(name="test_agent", model=_TEST_MODEL)
+ )
+ assert app._tmpl_attrs.get("runner") is None
+ app.set_up()
+ app._tmpl_attrs["runner"] = _MockRunner()
+ events = list(
+ app.stream_query(
+ user_id="test_user_id",
+ message=types.Content(
+ role="user",
+ parts=[
+ types.Part(
+ text="test message with content",
+ )
+ ],
+ ).model_dump(),
+ )
+ )
+ assert len(events) == 1
+
def test_streaming_agent_run_with_events(self):
app = reasoning_engines.AdkApp(
agent=Agent(name="test_agent", model=_TEST_MODEL)
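
The new ADK test streams a query whose message is a `google.genai` Content dict rather than a plain string. A hedged sketch of the same call outside the test harness; the ADK `Agent` import, model name, and user id are assumptions, not taken from the diff:

    from google.adk.agents import Agent  # assumed ADK import
    from google.genai import types
    from vertexai.preview import reasoning_engines

    app = reasoning_engines.AdkApp(
        agent=Agent(name="my_agent", model="gemini-2.0-flash")  # placeholder model
    )
    app.set_up()

    # stream_query also accepts a Content payload serialized to a dict.
    message = types.Content(
        role="user",
        parts=[types.Part(text="hello from a structured message")],
    ).model_dump()

    for event in app.stream_query(user_id="user-123", message=message):
        print(event)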
diff --git a/tests/unit/vertex_langchain/test_agent_engine_templates_module.py b/tests/unit/vertex_langchain/test_agent_engine_templates_module.py
new file mode 100644
index 0000000000..20b9a0b61f
--- /dev/null
+++ b/tests/unit/vertex_langchain/test_agent_engine_templates_module.py
@@ -0,0 +1,89 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from vertexai import agent_engines
+from test_constants import test_agent
+
+_TEST_MODULE_NAME = "test_constants"
+_TEST_AGENT_NAME = "test_agent"
+_TEST_REGISTER_OPERATIONS = {"": ["query"], "stream": ["stream_query"]}
+_TEST_QUERY_INPUT = "test query"
+_TEST_STREAM_QUERY_INPUT = 5
+
+
+class TestModuleAgent:
+ def test_initialization(self):
+ agent = agent_engines.ModuleAgent(
+ module_name=_TEST_MODULE_NAME,
+ agent_name=_TEST_AGENT_NAME,
+ register_operations=_TEST_REGISTER_OPERATIONS,
+ )
+ assert agent._tmpl_attrs.get("module_name") == _TEST_MODULE_NAME
+ assert agent._tmpl_attrs.get("agent_name") == _TEST_AGENT_NAME
+ assert agent._tmpl_attrs.get("register_operations") == _TEST_REGISTER_OPERATIONS
+
+ def test_set_up(self):
+ agent = agent_engines.ModuleAgent(
+ module_name=_TEST_MODULE_NAME,
+ agent_name=_TEST_AGENT_NAME,
+ register_operations=_TEST_REGISTER_OPERATIONS,
+ )
+ assert agent._tmpl_attrs.get("agent") is None
+ agent.set_up()
+ assert agent._tmpl_attrs.get("agent") is not None
+
+ def test_clone(self):
+ agent = agent_engines.ModuleAgent(
+ module_name=_TEST_MODULE_NAME,
+ agent_name=_TEST_AGENT_NAME,
+ register_operations=_TEST_REGISTER_OPERATIONS,
+ )
+ agent.set_up()
+ assert agent._tmpl_attrs.get("agent") is not None
+ agent_clone = agent.clone()
+ assert agent._tmpl_attrs.get("agent") is not None
+ assert agent_clone._tmpl_attrs.get("agent") is None
+ agent_clone.set_up()
+ assert agent_clone._tmpl_attrs.get("agent") is not None
+
+ def test_query(self):
+ agent = agent_engines.ModuleAgent(
+ module_name=_TEST_MODULE_NAME,
+ agent_name=_TEST_AGENT_NAME,
+ register_operations=_TEST_REGISTER_OPERATIONS,
+ )
+ agent.set_up()
+ got_result = agent.query(input=_TEST_QUERY_INPUT)
+ expected_result = agent._tmpl_attrs.get("agent").query(input=_TEST_QUERY_INPUT)
+ assert got_result == expected_result
+ expected_result = test_agent.query(input=_TEST_QUERY_INPUT)
+ assert got_result == expected_result
+
+ def test_stream_query(self):
+ agent = agent_engines.ModuleAgent(
+ module_name=_TEST_MODULE_NAME,
+ agent_name=_TEST_AGENT_NAME,
+ register_operations=_TEST_REGISTER_OPERATIONS,
+ )
+ agent.set_up()
+ for got_result, expected_result in zip(
+ agent.stream_query(n=_TEST_STREAM_QUERY_INPUT),
+ agent._tmpl_attrs.get("agent").stream_query(n=_TEST_STREAM_QUERY_INPUT),
+ ):
+ assert got_result == expected_result
+ for got_result, expected_result in zip(
+ agent.stream_query(n=_TEST_STREAM_QUERY_INPUT),
+ test_agent.stream_query(n=_TEST_STREAM_QUERY_INPUT),
+ ):
+ assert got_result == expected_result
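
The new module-agent tests exercise `agent_engines.ModuleAgent`, which wraps an agent object exposed by an importable module. A minimal sketch, assuming a module named `my_agents` on the import path that defines an object `agent` with `query` and `stream_query` methods (both names are hypothetical):

    from vertexai import agent_engines

    module_agent = agent_engines.ModuleAgent(
        module_name="my_agents",   # hypothetical module
        agent_name="agent",        # hypothetical attribute inside that module
        register_operations={"": ["query"], "stream": ["stream_query"]},
    )
    module_agent.set_up()          # imports the module and resolves the agent

    print(module_agent.query(input="hello"))
    for chunk in module_agent.stream_query(n=3):
        print(chunk)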
diff --git a/tests/unit/vertex_langchain/test_agent_engines.py b/tests/unit/vertex_langchain/test_agent_engines.py
index 8e66529f57..e58d878635 100644
--- a/tests/unit/vertex_langchain/test_agent_engines.py
+++ b/tests/unit/vertex_langchain/test_agent_engines.py
@@ -2315,6 +2315,47 @@ def test_invalid_operation_schema(
assert want_log_output in caplog.text
+@pytest.mark.usefixtures("google_auth_mock")
+class TestLightweightAgentEngine:
+ def setup_method(self):
+ importlib.reload(initializer)
+ importlib.reload(aiplatform)
+ aiplatform.init(
+ project=_TEST_PROJECT,
+ location=_TEST_LOCATION,
+ credentials=_TEST_CREDENTIALS,
+ )
+ self.test_agent = CapitalizeEngine()
+
+ def test_create_agent_engine_with_no_spec(
+ self,
+ create_agent_engine_mock,
+ cloud_storage_create_bucket_mock,
+ tarfile_open_mock,
+ cloudpickle_dump_mock,
+ cloudpickle_load_mock,
+ importlib_metadata_version_mock,
+ get_agent_engine_mock,
+ ):
+ importlib.reload(initializer)
+ importlib.reload(aiplatform)
+ aiplatform.init(
+ project=_TEST_PROJECT,
+ location=_TEST_LOCATION,
+ credentials=_TEST_CREDENTIALS,
+ )
+ agent_engines.create(
+ display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME,
+ description=_TEST_AGENT_ENGINE_DESCRIPTION,
+ )
+ aiplatform.init(
+ project=_TEST_PROJECT,
+ location=_TEST_LOCATION,
+ credentials=_TEST_CREDENTIALS,
+ staging_bucket=_TEST_STAGING_BUCKET,
+ )
+
+
def _generate_agent_engine_to_update() -> "agent_engines.AgentEngine":
test_agent_engine = agent_engines.create(CapitalizeEngine())
# Resource name is required for the update method.
diff --git a/tests/unit/vertex_langchain/test_constants.py b/tests/unit/vertex_langchain/test_constants.py
new file mode 100644
index 0000000000..7d66f2dec2
--- /dev/null
+++ b/tests/unit/vertex_langchain/test_constants.py
@@ -0,0 +1,24 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+class _CustomAgent:
+ def query(self, input: str):
+ return input
+
+ def stream_query(self, n: int):
+ for i in range(n):
+ yield i
+
+
+test_agent = _CustomAgent()
diff --git a/tests/unit/vertexai/genai/test_genai_client.py b/tests/unit/vertexai/genai/test_genai_client.py
new file mode 100644
index 0000000000..7732bb3496
--- /dev/null
+++ b/tests/unit/vertexai/genai/test_genai_client.py
@@ -0,0 +1,74 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# pylint: disable=protected-access,bad-continuation
+import pytest
+import importlib
+from unittest import mock
+from google.cloud import aiplatform
+import vertexai
+from google.cloud.aiplatform import initializer as aiplatform_initializer
+
+from vertexai import _genai
+
+_TEST_PROJECT = "test-project"
+_TEST_LOCATION = "us-central1"
+
+
+pytestmark = pytest.mark.usefixtures("google_auth_mock")
+
+
+class TestGenAiClient:
+ """Unit tests for the GenAI client."""
+
+ def setup_method(self):
+ importlib.reload(aiplatform_initializer)
+ importlib.reload(aiplatform)
+ importlib.reload(vertexai)
+ vertexai.init(
+ project=_TEST_PROJECT,
+ location=_TEST_LOCATION,
+ )
+
+ @pytest.mark.usefixtures("google_auth_mock")
+ def test_genai_client(self):
+ test_client = _genai.client.Client(
+ project=_TEST_PROJECT, location=_TEST_LOCATION
+ )
+ assert test_client is not None
+ assert test_client._api_client.vertexai
+ assert test_client._api_client.project == _TEST_PROJECT
+ assert test_client._api_client.location == _TEST_LOCATION
+
+ @pytest.mark.usefixtures("google_auth_mock")
+ def test_evaluate_instances(self):
+ test_client = _genai.client.Client(
+ project=_TEST_PROJECT, location=_TEST_LOCATION
+ )
+ with mock.patch.object(
+ test_client.evals, "_evaluate_instances"
+ ) as mock_evaluate:
+ test_client.evals._evaluate_instances(bleu_input=_genai.types.BleuInput())
+ mock_evaluate.assert_called_once_with(bleu_input=_genai.types.BleuInput())
+
+ @pytest.mark.usefixtures("google_auth_mock")
+ def test_eval_run(self):
+ test_client = _genai.client.Client(
+ project=_TEST_PROJECT, location=_TEST_LOCATION
+ )
+ with pytest.raises(NotImplementedError):
+ test_client.evals.run()
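
The new `test_genai_client.py` covers the experimental `vertexai._genai` client and its evals surface. A hedged usage sketch mirroring those tests; the project and location are placeholders, and `_genai` is a private module whose API may change:

    import vertexai
    from vertexai import _genai

    vertexai.init(project="my-project", location="us-central1")  # placeholders

    client = _genai.client.Client(project="my-project", location="us-central1")
    assert client._api_client.vertexai   # requests are routed through Vertex AI

    # The evals surface hangs off the client; the unit test mocks its private
    # _evaluate_instances helper instead of calling the service.
    bleu_input = _genai.types.BleuInput()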
diff --git a/tests/unit/vertexai/model_garden/test_model_garden.py b/tests/unit/vertexai/model_garden/test_model_garden.py
index 23fefcfae2..beec2521cf 100644
--- a/tests/unit/vertexai/model_garden/test_model_garden.py
+++ b/tests/unit/vertexai/model_garden/test_model_garden.py
@@ -14,6 +14,7 @@
"""Unit tests for ModelGarden class."""
import importlib
+import textwrap
from unittest import mock
from google import auth
@@ -41,6 +42,7 @@
_TEST_PROJECT = "test-project"
_TEST_LOCATION = "us-central1"
+_TEST_PROJECT_NUMBER = "1234567890"
_TEST_MODEL_FULL_RESOURCE_NAME = (
"publishers/google/models/paligemma@paligemma-224-float32"
@@ -179,14 +181,39 @@ def get_publisher_model_mock():
multi_deploy_vertex=types.PublisherModel.CallToAction.DeployVertex(
multi_deploy_vertex=[
types.PublisherModel.CallToAction.Deploy(
+ container_spec=types.ModelContainerSpec(
+ image_uri="us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/pytorch-vllm-serve:20241202_0916_RC00",
+ command=["python", "main.py"],
+ args=["--model-id=gemma-2b"],
+ env=[
+ types.EnvVar(name="MODEL_ID", value="gemma-2b")
+ ],
+ ),
dedicated_resources=types.DedicatedResources(
machine_spec=types.MachineSpec(
machine_type="g2-standard-16",
accelerator_type="NVIDIA_L4",
accelerator_count=1,
)
- )
- )
+ ),
+ ),
+ types.PublisherModel.CallToAction.Deploy(
+ container_spec=types.ModelContainerSpec(
+ image_uri="us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/text-generation-inference-cu121.2-1.py310:latest",
+ command=["python", "main.py"],
+ args=["--model-id=gemma-2b"],
+ env=[
+ types.EnvVar(name="MODEL_ID", value="gemma-2b")
+ ],
+ ),
+ dedicated_resources=types.DedicatedResources(
+ machine_spec=types.MachineSpec(
+ machine_type="g2-standard-32",
+ accelerator_type="NVIDIA_L4",
+ accelerator_count=4,
+ )
+ ),
+ ),
]
)
),
@@ -197,14 +224,39 @@ def get_publisher_model_mock():
multi_deploy_vertex=types.PublisherModel.CallToAction.DeployVertex(
multi_deploy_vertex=[
types.PublisherModel.CallToAction.Deploy(
+ container_spec=types.ModelContainerSpec(
+ image_uri="us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/pytorch-vllm-serve:20241202_0916_RC00",
+ command=["python", "main.py"],
+ args=["--model-id=gemma-2b"],
+ env=[
+ types.EnvVar(name="MODEL_ID", value="gemma-2b")
+ ],
+ ),
dedicated_resources=types.DedicatedResources(
machine_spec=types.MachineSpec(
machine_type="g2-standard-16",
accelerator_type="NVIDIA_L4",
accelerator_count=1,
)
- )
- )
+ ),
+ ),
+ types.PublisherModel.CallToAction.Deploy(
+ container_spec=types.ModelContainerSpec(
+ image_uri="us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/text-generation-inference-cu121.2-1.py310:latest",
+ command=["python", "main.py"],
+ args=["--model-id=gemma-2b"],
+ env=[
+ types.EnvVar(name="MODEL_ID", value="gemma-2b")
+ ],
+ ),
+ dedicated_resources=types.DedicatedResources(
+ machine_spec=types.MachineSpec(
+ machine_type="g2-standard-32",
+ accelerator_type="NVIDIA_L4",
+ accelerator_count=4,
+ )
+ ),
+ ),
]
)
),
@@ -398,6 +450,40 @@ def list_publisher_models_mock():
yield list_publisher_models
+@pytest.fixture
+def check_license_agreement_status_mock():
+ """Mocks the check_license_agreement_status method."""
+ with mock.patch.object(
+ model_garden_service.ModelGardenServiceClient,
+ "check_publisher_model_eula_acceptance",
+ ) as check_license_agreement_status:
+ check_license_agreement_status.return_value = (
+ types.PublisherModelEulaAcceptance(
+ project_number=_TEST_PROJECT_NUMBER,
+ publisher_model=_TEST_MODEL_FULL_RESOURCE_NAME,
+ publisher_model_eula_acked=True,
+ )
+ )
+ yield check_license_agreement_status
+
+
+@pytest.fixture
+def accept_model_license_agreement_mock():
+ """Mocks the accept_model_license_agreement method."""
+ with mock.patch.object(
+ model_garden_service.ModelGardenServiceClient,
+ "accept_publisher_model_eula",
+ ) as accept_model_license_agreement:
+ accept_model_license_agreement.return_value = (
+ types.PublisherModelEulaAcceptance(
+ project_number=_TEST_PROJECT_NUMBER,
+ publisher_model=_TEST_MODEL_FULL_RESOURCE_NAME,
+ publisher_model_eula_acked=True,
+ )
+ )
+ yield accept_model_license_agreement
+
+
@pytest.mark.usefixtures(
"google_auth_mock",
"deploy_mock",
@@ -406,9 +492,11 @@ def list_publisher_models_mock():
"export_publisher_model_mock",
"batch_prediction_mock",
"complete_bq_uri_mock",
+ "check_license_agreement_status_mock",
+ "accept_model_license_agreement_mock",
)
-class TestModelGarden:
- """Test cases for ModelGarden class."""
+class TestModelGardenOpenModel:
+ """Test cases for Model Garden OpenModel class."""
def setup_method(self):
importlib.reload(aiplatform.initializer)
@@ -709,6 +797,24 @@ def test_deploy_with_dedicated_endpoint_success(self, deploy_mock):
)
)
+ def test_deploy_with_system_labels_success(self, deploy_mock):
+ """Tests deploying a model with system labels."""
+ aiplatform.init(
+ project=_TEST_PROJECT,
+ location=_TEST_LOCATION,
+ )
+ model = model_garden.OpenModel(model_name=_TEST_MODEL_FULL_RESOURCE_NAME)
+ model.deploy(system_labels={"test-key": "test-value"})
+ deploy_mock.assert_called_once_with(
+ types.DeployRequest(
+ publisher_model_name=_TEST_MODEL_FULL_RESOURCE_NAME,
+ destination=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}",
+ deploy_config=types.DeployRequest.DeployConfig(
+ system_labels={"test-key": "test-value"}
+ ),
+ )
+ )
+
def test_deploy_with_fast_tryout_enabled_success(self, deploy_mock):
"""Tests deploying a model with fast tryout enabled."""
aiplatform.init(
@@ -880,9 +986,8 @@ def test_list_deploy_options(self, get_publisher_model_mock):
)
expected_message = (
- "Model does not support deployment, please use a deploy-able model"
- " instead. You can use the list_deployable_models() method to find out"
- " which ones currently support deployment."
+ "Model does not support deployment. "
+ "Use `list_deployable_models()` to find supported models."
)
with pytest.raises(ValueError) as exception:
model = model_garden.OpenModel(model_name=_TEST_MODEL_FULL_RESOURCE_NAME)
@@ -908,6 +1013,71 @@ def test_list_deploy_options(self, get_publisher_model_mock):
)
)
+ def test_list_deploy_options_concise(self, get_publisher_model_mock):
+ """Tests getting the supported deploy options for a model."""
+ aiplatform.init(
+ project=_TEST_PROJECT,
+ location=_TEST_LOCATION,
+ )
+
+ expected_message = (
+ "Model does not support deployment. "
+ "Use `list_deployable_models()` to find supported models."
+ )
+ with pytest.raises(ValueError) as exception:
+ model = model_garden.OpenModel(model_name=_TEST_MODEL_FULL_RESOURCE_NAME)
+ _ = model.list_deploy_options(concise=True)
+ assert str(exception.value) == expected_message
+
+ result = model.list_deploy_options(concise=True)
+ expected_result = textwrap.dedent(
+ """\
+ [Option 1]
+ serving_container_image_uri="us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/pytorch-vllm-serve:20241202_0916_RC00",
+ machine_type="g2-standard-16",
+ accelerator_type="NVIDIA_L4",
+ accelerator_count=1,
+
+ [Option 2]
+ serving_container_image_uri="us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/text-generation-inference-cu121.2-1.py310:latest",
+ machine_type="g2-standard-32",
+ accelerator_type="NVIDIA_L4",
+ accelerator_count=4,"""
+ )
+ assert result == expected_result
+ get_publisher_model_mock.assert_called_with(
+ types.GetPublisherModelRequest(
+ name=_TEST_MODEL_FULL_RESOURCE_NAME,
+ is_hugging_face_model=False,
+ include_equivalent_model_garden_model_deployment_configs=True,
+ )
+ )
+
+ hf_model = model_garden.OpenModel(_TEST_MODEL_HUGGING_FACE_ID)
+ hf_result = hf_model.list_deploy_options(concise=True)
+ expected_hf_result = textwrap.dedent(
+ """\
+ [Option 1]
+ serving_container_image_uri="us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/pytorch-vllm-serve:20241202_0916_RC00",
+ machine_type="g2-standard-16",
+ accelerator_type="NVIDIA_L4",
+ accelerator_count=1,
+
+ [Option 2]
+ serving_container_image_uri="us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/text-generation-inference-cu121.2-1.py310:latest",
+ machine_type="g2-standard-32",
+ accelerator_type="NVIDIA_L4",
+ accelerator_count=4,"""
+ )
+ assert hf_result == expected_hf_result
+ get_publisher_model_mock.assert_called_with(
+ types.GetPublisherModelRequest(
+ name=_TEST_HUGGING_FACE_MODEL_FULL_RESOURCE_NAME,
+ is_hugging_face_model=True,
+ include_equivalent_model_garden_model_deployment_configs=True,
+ )
+ )
+
def test_list_deployable_models(self, list_publisher_models_mock):
"""Tests getting the supported deploy options for a model."""
aiplatform.init(
@@ -999,3 +1169,43 @@ def test_batch_prediction_success(self, batch_prediction_mock):
batch_prediction_job=expected_gapic_batch_prediction_job,
timeout=None,
)
+
+ def test_check_license_agreement_status_success(
+ self, check_license_agreement_status_mock
+ ):
+ """Tests checking EULA acceptance for a model."""
+ aiplatform.init(
+ project=_TEST_PROJECT,
+ location=_TEST_LOCATION,
+ )
+ model = model_garden.OpenModel(model_name=_TEST_MODEL_FULL_RESOURCE_NAME)
+ eula_acceptance = model.check_license_agreement_status()
+ check_license_agreement_status_mock.assert_called_once_with(
+ types.CheckPublisherModelEulaAcceptanceRequest(
+ parent=f"projects/{_TEST_PROJECT}",
+ publisher_model=_TEST_MODEL_FULL_RESOURCE_NAME,
+ )
+ )
+ assert eula_acceptance
+
+ def test_accept_model_license_agreement_success(
+ self, accept_model_license_agreement_mock
+ ):
+ """Tests accepting EULA for a model."""
+ aiplatform.init(
+ project=_TEST_PROJECT,
+ location=_TEST_LOCATION,
+ )
+ model = model_garden.OpenModel(model_name=_TEST_MODEL_FULL_RESOURCE_NAME)
+ eula_acceptance = model.accept_model_license_agreement()
+ accept_model_license_agreement_mock.assert_called_once_with(
+ types.AcceptPublisherModelEulaRequest(
+ parent=f"projects/{_TEST_PROJECT}",
+ publisher_model=_TEST_MODEL_FULL_RESOURCE_NAME,
+ )
+ )
+ assert eula_acceptance == types.PublisherModelEulaAcceptance(
+ project_number=_TEST_PROJECT_NUMBER,
+ publisher_model=_TEST_MODEL_FULL_RESOURCE_NAME,
+ publisher_model_eula_acked=True,
+ )
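
The reworked Model Garden tests cover several OpenModel capabilities at once: EULA handling, deployment with system labels, and a concise listing of deploy options. A hedged sketch of the corresponding calls; the `model_garden` import path and the truthiness of the EULA check's return value are assumptions, and the model name is the same placeholder the tests use:

    import vertexai
    from vertexai import model_garden  # assumed import path for OpenModel

    vertexai.init(project="my-project", location="us-central1")  # placeholders

    model = model_garden.OpenModel(
        "publishers/google/models/paligemma@paligemma-224-float32"
    )

    # EULA handling (new in these tests).
    if not model.check_license_agreement_status():
        model.accept_model_license_agreement()

    # Human-readable summary of the supported serving configurations.
    print(model.list_deploy_options(concise=True))

    # Deploy with Vertex system labels attached to the deployment config.
    endpoint = model.deploy(system_labels={"team": "ml-platform"})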
diff --git a/vertexai/_genai/__init__.py b/vertexai/_genai/__init__.py
new file mode 100644
index 0000000000..f1d1598d81
--- /dev/null
+++ b/vertexai/_genai/__init__.py
@@ -0,0 +1,23 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""The vertexai module."""
+
+from . import evals
+from .client import Client
+
+__all__ = [
+ "Client",
+ "evals",
+]
diff --git a/vertexai/_genai/client.py b/vertexai/_genai/client.py
new file mode 100644
index 0000000000..35e36a6064
--- /dev/null
+++ b/vertexai/_genai/client.py
@@ -0,0 +1,92 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Optional, Union
+
+import google.auth
+from google.genai import client
+from google.genai import types
+from .evals import Evals
+from .evals import AsyncEvals
+
+
+class AsyncClient:
+ """Async Client for the GenAI SDK."""
+
+ def __init__(self, api_client: client.Client):
+ self._api_client = api_client
+ self._evals = AsyncEvals(self._api_client)
+
+ @property
+ def evals(self) -> AsyncEvals:
+ return self._evals
+
+
+class Client:
+ """Client for the GenAI SDK.
+
+ Use this client to interact with Vertex-specific Gemini features.
+ """
+
+ def __init__(
+ self,
+ *,
+ credentials: Optional[google.auth.credentials.Credentials] = None,
+ project: Optional[str] = None,
+ location: Optional[str] = None,
+ debug_config: Optional[client.DebugConfig] = None,
+ http_options: Optional[Union[types.HttpOptions, types.HttpOptionsDict]] = None,
+ ):
+ """Initializes the client.
+
+ Args:
+ credentials (google.auth.credentials.Credentials): The credentials to use
+ for authentication when calling the Vertex AI APIs. Credentials can be
+ obtained from environment variables and default credentials. For more
+            information, see Set up Application Default Credentials.
+          project (str): The Google Cloud project ID to
+            use for quota. Can be obtained from environment variables (for example,
+            ``GOOGLE_CLOUD_PROJECT``).
+          location (str): The location
+ to send API requests to (for example, ``us-central1``). Can be obtained
+ from environment variables.
+ debug_config (DebugConfig): Config settings that control network behavior
+ of the client. This is typically used when running test code.
+ http_options (Union[HttpOptions, HttpOptionsDict]): Http options to use
+ for the client.
+ """
+
+ self._debug_config = debug_config or client.DebugConfig()
+ if isinstance(http_options, dict):
+ http_options = types.HttpOptions(**http_options)
+
+ self._api_client = client.Client._get_api_client(
+ vertexai=True,
+ credentials=credentials,
+ project=project,
+ location=location,
+ debug_config=self._debug_config,
+ http_options=http_options,
+ )
+
+ self._evals = Evals(self._api_client)
+
+ @property
+ def evals(self) -> Evals:
+ return self._evals
diff --git a/vertexai/_genai/evals.py b/vertexai/_genai/evals.py
new file mode 100644
index 0000000000..0aa2d16bea
--- /dev/null
+++ b/vertexai/_genai/evals.py
@@ -0,0 +1,894 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Code generated by the Google Gen AI SDK generator DO NOT EDIT.
+
+import logging
+from typing import Any, Optional, Union
+from urllib.parse import urlencode
+
+from google.genai import _api_module
+from google.genai import _common
+from google.genai import types as genai_types
+from google.genai._api_client import BaseApiClient
+from google.genai._common import get_value_by_path as getv
+from google.genai._common import set_value_by_path as setv
+
+from . import types
+
+logger = logging.getLogger("google_genai.evals")
+
+
+def _BleuInstance_to_vertex(
+ api_client: BaseApiClient,
+ from_object: Union[dict[str, Any], object],
+ parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+ to_object: dict[str, Any] = {}
+ if getv(from_object, ["prediction"]) is not None:
+ setv(to_object, ["prediction"], getv(from_object, ["prediction"]))
+
+ if getv(from_object, ["reference"]) is not None:
+ setv(to_object, ["reference"], getv(from_object, ["reference"]))
+
+ return to_object
+
+
+def _BleuSpec_to_vertex(
+ api_client: BaseApiClient,
+ from_object: Union[dict[str, Any], object],
+ parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+ to_object: dict[str, Any] = {}
+ if getv(from_object, ["use_effective_order"]) is not None:
+ setv(
+ to_object,
+ ["useEffectiveOrder"],
+ getv(from_object, ["use_effective_order"]),
+ )
+
+ return to_object
+
+
+def _BleuInput_to_vertex(
+ api_client: BaseApiClient,
+ from_object: Union[dict[str, Any], object],
+ parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+ to_object: dict[str, Any] = {}
+ if getv(from_object, ["instances"]) is not None:
+ setv(
+ to_object,
+ ["instances"],
+ [
+ _BleuInstance_to_vertex(api_client, item, to_object)
+ for item in getv(from_object, ["instances"])
+ ],
+ )
+
+ if getv(from_object, ["metric_spec"]) is not None:
+ setv(
+ to_object,
+ ["metricSpec"],
+ _BleuSpec_to_vertex(
+ api_client, getv(from_object, ["metric_spec"]), to_object
+ ),
+ )
+
+ return to_object
+
+
+def _ExactMatchInstance_to_vertex(
+ api_client: BaseApiClient,
+ from_object: Union[dict[str, Any], object],
+ parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+ to_object: dict[str, Any] = {}
+ if getv(from_object, ["prediction"]) is not None:
+ setv(to_object, ["prediction"], getv(from_object, ["prediction"]))
+
+ if getv(from_object, ["reference"]) is not None:
+ setv(to_object, ["reference"], getv(from_object, ["reference"]))
+
+ return to_object
+
+
+def _ExactMatchSpec_to_vertex(
+ api_client: BaseApiClient,
+ from_object: Union[dict[str, Any], object],
+ parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+ to_object: dict[str, Any] = {}
+
+ return to_object
+
+
+def _ExactMatchInput_to_vertex(
+ api_client: BaseApiClient,
+ from_object: Union[dict[str, Any], object],
+ parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+ to_object: dict[str, Any] = {}
+ if getv(from_object, ["instances"]) is not None:
+ setv(
+ to_object,
+ ["instances"],
+ [
+ _ExactMatchInstance_to_vertex(api_client, item, to_object)
+ for item in getv(from_object, ["instances"])
+ ],
+ )
+
+ if getv(from_object, ["metric_spec"]) is not None:
+ setv(
+ to_object,
+ ["metricSpec"],
+ _ExactMatchSpec_to_vertex(
+ api_client, getv(from_object, ["metric_spec"]), to_object
+ ),
+ )
+
+ return to_object
+
+
+def _RougeInstance_to_vertex(
+ api_client: BaseApiClient,
+ from_object: Union[dict[str, Any], object],
+ parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+ to_object: dict[str, Any] = {}
+ if getv(from_object, ["prediction"]) is not None:
+ setv(to_object, ["prediction"], getv(from_object, ["prediction"]))
+
+ if getv(from_object, ["reference"]) is not None:
+ setv(to_object, ["reference"], getv(from_object, ["reference"]))
+
+ return to_object
+
+
+def _RougeSpec_to_vertex(
+ api_client: BaseApiClient,
+ from_object: Union[dict[str, Any], object],
+ parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+ to_object: dict[str, Any] = {}
+ if getv(from_object, ["rouge_type"]) is not None:
+ setv(to_object, ["rougeType"], getv(from_object, ["rouge_type"]))
+
+ if getv(from_object, ["split_summaries"]) is not None:
+ setv(
+ to_object,
+ ["splitSummaries"],
+ getv(from_object, ["split_summaries"]),
+ )
+
+ if getv(from_object, ["use_stemmer"]) is not None:
+ setv(to_object, ["useStemmer"], getv(from_object, ["use_stemmer"]))
+
+ return to_object
+
+
+def _RougeInput_to_vertex(
+ api_client: BaseApiClient,
+ from_object: Union[dict[str, Any], object],
+ parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+ to_object: dict[str, Any] = {}
+ if getv(from_object, ["instances"]) is not None:
+ setv(
+ to_object,
+ ["instances"],
+ [
+ _RougeInstance_to_vertex(api_client, item, to_object)
+ for item in getv(from_object, ["instances"])
+ ],
+ )
+
+ if getv(from_object, ["metric_spec"]) is not None:
+ setv(
+ to_object,
+ ["metricSpec"],
+ _RougeSpec_to_vertex(
+ api_client, getv(from_object, ["metric_spec"]), to_object
+ ),
+ )
+
+ return to_object
+
+
+def _PointwiseMetricInstance_to_vertex(
+ api_client: BaseApiClient,
+ from_object: Union[dict[str, Any], object],
+ parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+ to_object: dict[str, Any] = {}
+ if getv(from_object, ["json_instance"]) is not None:
+ setv(to_object, ["jsonInstance"], getv(from_object, ["json_instance"]))
+
+ return to_object
+
+
+def _PointwiseMetricSpec_to_vertex(
+ api_client: BaseApiClient,
+ from_object: Union[dict[str, Any], object],
+ parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+ to_object: dict[str, Any] = {}
+ if getv(from_object, ["metric_prompt_template"]) is not None:
+ setv(
+ to_object,
+ ["metricPromptTemplate"],
+ getv(from_object, ["metric_prompt_template"]),
+ )
+
+ return to_object
+
+
+def _PointwiseMetricInput_to_vertex(
+ api_client: BaseApiClient,
+ from_object: Union[dict[str, Any], object],
+ parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+ to_object: dict[str, Any] = {}
+ if getv(from_object, ["instance"]) is not None:
+ setv(
+ to_object,
+ ["instance"],
+ _PointwiseMetricInstance_to_vertex(
+ api_client, getv(from_object, ["instance"]), to_object
+ ),
+ )
+
+ if getv(from_object, ["metric_spec"]) is not None:
+ setv(
+ to_object,
+ ["metricSpec"],
+ _PointwiseMetricSpec_to_vertex(
+ api_client, getv(from_object, ["metric_spec"]), to_object
+ ),
+ )
+
+ return to_object
+
+
+def _PairwiseMetricInstance_to_vertex(
+ api_client: BaseApiClient,
+ from_object: Union[dict[str, Any], object],
+ parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+ to_object: dict[str, Any] = {}
+ if getv(from_object, ["json_instance"]) is not None:
+ setv(to_object, ["jsonInstance"], getv(from_object, ["json_instance"]))
+
+ return to_object
+
+
+def _PairwiseMetricSpec_to_vertex(
+ api_client: BaseApiClient,
+ from_object: Union[dict[str, Any], object],
+ parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+ to_object: dict[str, Any] = {}
+ if getv(from_object, ["metric_prompt_template"]) is not None:
+ setv(
+ to_object,
+ ["metricPromptTemplate"],
+ getv(from_object, ["metric_prompt_template"]),
+ )
+
+ return to_object
+
+
+def _PairwiseMetricInput_to_vertex(
+ api_client: BaseApiClient,
+ from_object: Union[dict[str, Any], object],
+ parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+ to_object: dict[str, Any] = {}
+ if getv(from_object, ["instance"]) is not None:
+ setv(
+ to_object,
+ ["instance"],
+ _PairwiseMetricInstance_to_vertex(
+ api_client, getv(from_object, ["instance"]), to_object
+ ),
+ )
+
+ if getv(from_object, ["metric_spec"]) is not None:
+ setv(
+ to_object,
+ ["metricSpec"],
+ _PairwiseMetricSpec_to_vertex(
+ api_client, getv(from_object, ["metric_spec"]), to_object
+ ),
+ )
+
+ return to_object
+
+
+def _ToolCallValidInstance_to_vertex(
+ api_client: BaseApiClient,
+ from_object: Union[dict[str, Any], object],
+ parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+ to_object: dict[str, Any] = {}
+ if getv(from_object, ["prediction"]) is not None:
+ setv(to_object, ["prediction"], getv(from_object, ["prediction"]))
+
+ if getv(from_object, ["reference"]) is not None:
+ setv(to_object, ["reference"], getv(from_object, ["reference"]))
+
+ return to_object
+
+
+def _ToolCallValidSpec_to_vertex(
+ api_client: BaseApiClient,
+ from_object: Union[dict[str, Any], object],
+ parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+ to_object: dict[str, Any] = {}
+
+ return to_object
+
+
+def _ToolCallValidInput_to_vertex(
+ api_client: BaseApiClient,
+ from_object: Union[dict[str, Any], object],
+ parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+ to_object: dict[str, Any] = {}
+ if getv(from_object, ["instances"]) is not None:
+ setv(
+ to_object,
+ ["instances"],
+ [
+ _ToolCallValidInstance_to_vertex(api_client, item, to_object)
+ for item in getv(from_object, ["instances"])
+ ],
+ )
+
+ if getv(from_object, ["metric_spec"]) is not None:
+ setv(
+ to_object,
+ ["metricSpec"],
+ _ToolCallValidSpec_to_vertex(
+ api_client, getv(from_object, ["metric_spec"]), to_object
+ ),
+ )
+
+ return to_object
+
+
+def _ToolNameMatchInstance_to_vertex(
+ api_client: BaseApiClient,
+ from_object: Union[dict[str, Any], object],
+ parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+ to_object: dict[str, Any] = {}
+ if getv(from_object, ["prediction"]) is not None:
+ setv(to_object, ["prediction"], getv(from_object, ["prediction"]))
+
+ if getv(from_object, ["reference"]) is not None:
+ setv(to_object, ["reference"], getv(from_object, ["reference"]))
+
+ return to_object
+
+
+def _ToolNameMatchSpec_to_vertex(
+ api_client: BaseApiClient,
+ from_object: Union[dict[str, Any], object],
+ parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+ to_object: dict[str, Any] = {}
+
+ return to_object
+
+
+def _ToolNameMatchInput_to_vertex(
+ api_client: BaseApiClient,
+ from_object: Union[dict[str, Any], object],
+ parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+ to_object: dict[str, Any] = {}
+ if getv(from_object, ["instances"]) is not None:
+ setv(
+ to_object,
+ ["instances"],
+ [
+ _ToolNameMatchInstance_to_vertex(api_client, item, to_object)
+ for item in getv(from_object, ["instances"])
+ ],
+ )
+
+ if getv(from_object, ["metric_spec"]) is not None:
+ setv(
+ to_object,
+ ["metricSpec"],
+ _ToolNameMatchSpec_to_vertex(
+ api_client, getv(from_object, ["metric_spec"]), to_object
+ ),
+ )
+
+ return to_object
+
+
+def _ToolParameterKeyMatchInstance_to_vertex(
+ api_client: BaseApiClient,
+ from_object: Union[dict[str, Any], object],
+ parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+ to_object: dict[str, Any] = {}
+ if getv(from_object, ["prediction"]) is not None:
+ setv(to_object, ["prediction"], getv(from_object, ["prediction"]))
+
+ if getv(from_object, ["reference"]) is not None:
+ setv(to_object, ["reference"], getv(from_object, ["reference"]))
+
+ return to_object
+
+
+def _ToolParameterKeyMatchSpec_to_vertex(
+ api_client: BaseApiClient,
+ from_object: Union[dict[str, Any], object],
+ parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+ to_object: dict[str, Any] = {}
+
+ return to_object
+
+
+def _ToolParameterKeyMatchInput_to_vertex(
+ api_client: BaseApiClient,
+ from_object: Union[dict[str, Any], object],
+ parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+ to_object: dict[str, Any] = {}
+ if getv(from_object, ["instances"]) is not None:
+ setv(
+ to_object,
+ ["instances"],
+ [
+ _ToolParameterKeyMatchInstance_to_vertex(api_client, item, to_object)
+ for item in getv(from_object, ["instances"])
+ ],
+ )
+
+ if getv(from_object, ["metric_spec"]) is not None:
+ setv(
+ to_object,
+ ["metricSpec"],
+ _ToolParameterKeyMatchSpec_to_vertex(
+ api_client, getv(from_object, ["metric_spec"]), to_object
+ ),
+ )
+
+ return to_object
+
+
+def _ToolParameterKVMatchInstance_to_vertex(
+ api_client: BaseApiClient,
+ from_object: Union[dict[str, Any], object],
+ parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+ to_object: dict[str, Any] = {}
+ if getv(from_object, ["prediction"]) is not None:
+ setv(to_object, ["prediction"], getv(from_object, ["prediction"]))
+
+ if getv(from_object, ["reference"]) is not None:
+ setv(to_object, ["reference"], getv(from_object, ["reference"]))
+
+ return to_object
+
+
+def _ToolParameterKVMatchSpec_to_vertex(
+ api_client: BaseApiClient,
+ from_object: Union[dict[str, Any], object],
+ parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+ to_object: dict[str, Any] = {}
+ if getv(from_object, ["use_strict_string_match"]) is not None:
+ setv(
+ to_object,
+ ["useStrictStringMatch"],
+ getv(from_object, ["use_strict_string_match"]),
+ )
+
+ return to_object
+
+
+def _ToolParameterKVMatchInput_to_vertex(
+ api_client: BaseApiClient,
+ from_object: Union[dict[str, Any], object],
+ parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+ to_object: dict[str, Any] = {}
+ if getv(from_object, ["instances"]) is not None:
+ setv(
+ to_object,
+ ["instances"],
+ [
+ _ToolParameterKVMatchInstance_to_vertex(api_client, item, to_object)
+ for item in getv(from_object, ["instances"])
+ ],
+ )
+
+ if getv(from_object, ["metric_spec"]) is not None:
+ setv(
+ to_object,
+ ["metricSpec"],
+ _ToolParameterKVMatchSpec_to_vertex(
+ api_client, getv(from_object, ["metric_spec"]), to_object
+ ),
+ )
+
+ return to_object
+
+
+def _EvaluateInstancesRequestParameters_to_vertex(
+ api_client: BaseApiClient,
+ from_object: Union[dict[str, Any], object],
+ parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+ to_object: dict[str, Any] = {}
+ if getv(from_object, ["bleu_input"]) is not None:
+ setv(
+ to_object,
+ ["bleuInput"],
+ _BleuInput_to_vertex(
+ api_client, getv(from_object, ["bleu_input"]), to_object
+ ),
+ )
+
+ if getv(from_object, ["exact_match_input"]) is not None:
+ setv(
+ to_object,
+ ["exactMatchInput"],
+ _ExactMatchInput_to_vertex(
+ api_client, getv(from_object, ["exact_match_input"]), to_object
+ ),
+ )
+
+ if getv(from_object, ["rouge_input"]) is not None:
+ setv(
+ to_object,
+ ["rougeInput"],
+ _RougeInput_to_vertex(
+ api_client, getv(from_object, ["rouge_input"]), to_object
+ ),
+ )
+
+ if getv(from_object, ["pointwise_metric_input"]) is not None:
+ setv(
+ to_object,
+ ["pointwiseMetricInput"],
+ _PointwiseMetricInput_to_vertex(
+ api_client,
+ getv(from_object, ["pointwise_metric_input"]),
+ to_object,
+ ),
+ )
+
+ if getv(from_object, ["pairwise_metric_input"]) is not None:
+ setv(
+ to_object,
+ ["pairwiseMetricInput"],
+ _PairwiseMetricInput_to_vertex(
+ api_client,
+ getv(from_object, ["pairwise_metric_input"]),
+ to_object,
+ ),
+ )
+
+ if getv(from_object, ["tool_call_valid_input"]) is not None:
+ setv(
+ to_object,
+ ["toolCallValidInput"],
+ _ToolCallValidInput_to_vertex(
+ api_client,
+ getv(from_object, ["tool_call_valid_input"]),
+ to_object,
+ ),
+ )
+
+ if getv(from_object, ["tool_name_match_input"]) is not None:
+ setv(
+ to_object,
+ ["toolNameMatchInput"],
+ _ToolNameMatchInput_to_vertex(
+ api_client,
+ getv(from_object, ["tool_name_match_input"]),
+ to_object,
+ ),
+ )
+
+ if getv(from_object, ["tool_parameter_key_match_input"]) is not None:
+ setv(
+ to_object,
+ ["toolParameterKeyMatchInput"],
+ _ToolParameterKeyMatchInput_to_vertex(
+ api_client,
+ getv(from_object, ["tool_parameter_key_match_input"]),
+ to_object,
+ ),
+ )
+
+ if getv(from_object, ["tool_parameter_kv_match_input"]) is not None:
+ setv(
+ to_object,
+ ["toolParameterKvMatchInput"],
+ _ToolParameterKVMatchInput_to_vertex(
+ api_client,
+ getv(from_object, ["tool_parameter_kv_match_input"]),
+ to_object,
+ ),
+ )
+
+ if getv(from_object, ["config"]) is not None:
+ setv(to_object, ["config"], getv(from_object, ["config"]))
+
+ return to_object
+
+
+def _EvaluateInstancesResponse_from_vertex(
+ api_client: BaseApiClient,
+ from_object: Union[dict[str, Any], object],
+ parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+ to_object: dict[str, Any] = {}
+
+ if getv(from_object, ["bleuResults"]) is not None:
+ setv(to_object, ["bleu_results"], getv(from_object, ["bleuResults"]))
+
+ if getv(from_object, ["exactMatchResults"]) is not None:
+ setv(
+ to_object,
+ ["exact_match_results"],
+ getv(from_object, ["exactMatchResults"]),
+ )
+
+ if getv(from_object, ["pairwiseMetricResult"]) is not None:
+ setv(
+ to_object,
+ ["pairwise_metric_result"],
+ getv(from_object, ["pairwiseMetricResult"]),
+ )
+
+ if getv(from_object, ["pointwiseMetricResult"]) is not None:
+ setv(
+ to_object,
+ ["pointwise_metric_result"],
+ getv(from_object, ["pointwiseMetricResult"]),
+ )
+
+ if getv(from_object, ["rougeResults"]) is not None:
+ setv(to_object, ["rouge_results"], getv(from_object, ["rougeResults"]))
+
+ if getv(from_object, ["summarizationVerbosityResult"]) is not None:
+ setv(
+ to_object,
+ ["summarization_verbosity_result"],
+ getv(from_object, ["summarizationVerbosityResult"]),
+ )
+
+ if getv(from_object, ["toolCallValidResults"]) is not None:
+ setv(
+ to_object,
+ ["tool_call_valid_results"],
+ getv(from_object, ["toolCallValidResults"]),
+ )
+
+ if getv(from_object, ["toolNameMatchResults"]) is not None:
+ setv(
+ to_object,
+ ["tool_name_match_results"],
+ getv(from_object, ["toolNameMatchResults"]),
+ )
+
+ if getv(from_object, ["toolParameterKeyMatchResults"]) is not None:
+ setv(
+ to_object,
+ ["tool_parameter_key_match_results"],
+ getv(from_object, ["toolParameterKeyMatchResults"]),
+ )
+
+ if getv(from_object, ["toolParameterKvMatchResults"]) is not None:
+ setv(
+ to_object,
+ ["tool_parameter_kv_match_results"],
+ getv(from_object, ["toolParameterKvMatchResults"]),
+ )
+
+ return to_object
+
+
+class Evals(_api_module.BaseModule):
+ def _evaluate_instances(
+ self,
+ *,
+ bleu_input: Optional[types.BleuInputOrDict] = None,
+ exact_match_input: Optional[types.ExactMatchInputOrDict] = None,
+ rouge_input: Optional[types.RougeInputOrDict] = None,
+ pointwise_metric_input: Optional[types.PointwiseMetricInputOrDict] = None,
+ pairwise_metric_input: Optional[types.PairwiseMetricInputOrDict] = None,
+ tool_call_valid_input: Optional[types.ToolCallValidInputOrDict] = None,
+ tool_name_match_input: Optional[types.ToolNameMatchInputOrDict] = None,
+ tool_parameter_key_match_input: Optional[
+ types.ToolParameterKeyMatchInputOrDict
+ ] = None,
+ tool_parameter_kv_match_input: Optional[
+ types.ToolParameterKVMatchInputOrDict
+ ] = None,
+ config: Optional[types.EvaluateInstancesConfigOrDict] = None,
+ ) -> types.EvaluateInstancesResponse:
+ """Evaluates instances based on a given metric."""
+
+ parameter_model = types._EvaluateInstancesRequestParameters(
+ bleu_input=bleu_input,
+ exact_match_input=exact_match_input,
+ rouge_input=rouge_input,
+ pointwise_metric_input=pointwise_metric_input,
+ pairwise_metric_input=pairwise_metric_input,
+ tool_call_valid_input=tool_call_valid_input,
+ tool_name_match_input=tool_name_match_input,
+ tool_parameter_key_match_input=tool_parameter_key_match_input,
+ tool_parameter_kv_match_input=tool_parameter_kv_match_input,
+ config=config,
+ )
+
+ request_url_dict: Optional[dict[str, str]]
+ if not self._api_client.vertexai:
+ raise ValueError("This method is only supported in the Vertex AI client.")
+ else:
+ request_dict = _EvaluateInstancesRequestParameters_to_vertex(
+ self._api_client, parameter_model
+ )
+ request_url_dict = request_dict.get("_url")
+ if request_url_dict:
+ path = ":evaluateInstances".format_map(request_url_dict)
+ else:
+ path = ":evaluateInstances"
+
+ query_params = request_dict.get("_query")
+ if query_params:
+ path = f"{path}?{urlencode(query_params)}"
+ # TODO: remove the hack that pops config.
+ request_dict.pop("config", None)
+
+ http_options: Optional[genai_types.HttpOptions] = None
+ if (
+ parameter_model.config is not None
+ and parameter_model.config.http_options is not None
+ ):
+ http_options = parameter_model.config.http_options
+
+ request_dict = _common.convert_to_dict(request_dict)
+ request_dict = _common.encode_unserializable_types(request_dict)
+
+ response_dict = self._api_client.request(
+ "post", path, request_dict, http_options
+ )
+
+ if self._api_client.vertexai:
+ response_dict = _EvaluateInstancesResponse_from_vertex(
+ self._api_client, response_dict
+ )
+
+ return_value = types.EvaluateInstancesResponse._from_response(
+ response=response_dict, kwargs=parameter_model.model_dump()
+ )
+ self._api_client._verify_response(return_value)
+ return return_value
+
+ def run(self) -> types.EvaluateInstancesResponse:
+ """Evaluates an instance of a model.
+
+        Not yet implemented; this should eventually call _evaluate_instances().
+ """
+ raise NotImplementedError()
+
+ def evaluate_instances(
+ self,
+ *,
+ metric_config: types._EvaluateInstancesRequestParameters,
+ ) -> types.EvaluateInstancesResponse:
+ """Evaluates an instance of a model."""
+
+ if isinstance(metric_config, types._EvaluateInstancesRequestParameters):
+ metric_config = metric_config.model_dump()
+ else:
+ metric_config = dict(metric_config)
+
+ return self._evaluate_instances(
+ **metric_config,
+ )
+
+
+class AsyncEvals(_api_module.BaseModule):
+ async def _evaluate_instances(
+ self,
+ *,
+ bleu_input: Optional[types.BleuInputOrDict] = None,
+ exact_match_input: Optional[types.ExactMatchInputOrDict] = None,
+ rouge_input: Optional[types.RougeInputOrDict] = None,
+ pointwise_metric_input: Optional[types.PointwiseMetricInputOrDict] = None,
+ pairwise_metric_input: Optional[types.PairwiseMetricInputOrDict] = None,
+ tool_call_valid_input: Optional[types.ToolCallValidInputOrDict] = None,
+ tool_name_match_input: Optional[types.ToolNameMatchInputOrDict] = None,
+ tool_parameter_key_match_input: Optional[
+ types.ToolParameterKeyMatchInputOrDict
+ ] = None,
+ tool_parameter_kv_match_input: Optional[
+ types.ToolParameterKVMatchInputOrDict
+ ] = None,
+ config: Optional[types.EvaluateInstancesConfigOrDict] = None,
+ ) -> types.EvaluateInstancesResponse:
+ """Evaluates instances based on a given metric."""
+
+ parameter_model = types._EvaluateInstancesRequestParameters(
+ bleu_input=bleu_input,
+ exact_match_input=exact_match_input,
+ rouge_input=rouge_input,
+ pointwise_metric_input=pointwise_metric_input,
+ pairwise_metric_input=pairwise_metric_input,
+ tool_call_valid_input=tool_call_valid_input,
+ tool_name_match_input=tool_name_match_input,
+ tool_parameter_key_match_input=tool_parameter_key_match_input,
+ tool_parameter_kv_match_input=tool_parameter_kv_match_input,
+ config=config,
+ )
+
+ request_url_dict: Optional[dict[str, str]]
+ if not self._api_client.vertexai:
+ raise ValueError("This method is only supported in the Vertex AI client.")
+ else:
+ request_dict = _EvaluateInstancesRequestParameters_to_vertex(
+ self._api_client, parameter_model
+ )
+ request_url_dict = request_dict.get("_url")
+ if request_url_dict:
+ path = ":evaluateInstances".format_map(request_url_dict)
+ else:
+ path = ":evaluateInstances"
+
+ query_params = request_dict.get("_query")
+ if query_params:
+ path = f"{path}?{urlencode(query_params)}"
+ # TODO: remove the hack that pops config.
+ request_dict.pop("config", None)
+
+ http_options: Optional[genai_types.HttpOptions] = None
+ if (
+ parameter_model.config is not None
+ and parameter_model.config.http_options is not None
+ ):
+ http_options = parameter_model.config.http_options
+
+ request_dict = _common.convert_to_dict(request_dict)
+ request_dict = _common.encode_unserializable_types(request_dict)
+
+ response_dict = await self._api_client.async_request(
+ "post", path, request_dict, http_options
+ )
+
+ if self._api_client.vertexai:
+ response_dict = _EvaluateInstancesResponse_from_vertex(
+ self._api_client, response_dict
+ )
+
+ return_value = types.EvaluateInstancesResponse._from_response(
+ response=response_dict, kwargs=parameter_model.model_dump()
+ )
+ self._api_client._verify_response(return_value)
+ return return_value
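
Reviewer note, not part of the patch: a minimal sketch of how the new `Evals.evaluate_instances` surface might be exercised once it is wired into a client. The `vertexai.Client(...)` entry point and the `client.evals` accessor are assumptions — this diff only adds the module and its request/response converters, not the client wiring.

```python
# Hypothetical usage of the new Evals module. Only the types and the
# evaluate_instances(metric_config=...) signature come from this patch;
# the Client constructor and .evals accessor are assumed.
import vertexai  # assumed to expose a genai-style Client in this release
from vertexai._genai import types

client = vertexai.Client(project="my-project", location="us-central1")

# Build a BLEU evaluation request using the new typed models.
bleu_input = types.BleuInput(
    instances=[
        types.BleuInstance(
            prediction="The quick brown fox jumps over the lazy dog.",
            reference="A quick brown fox jumps over the lazy dog.",
        )
    ],
    metric_spec=types.BleuSpec(use_effective_order=True),
)

# evaluate_instances() dumps the request-parameters model and fans it out
# to the private _evaluate_instances(**kwargs) shown above.
response = client.evals.evaluate_instances(
    metric_config=types._EvaluateInstancesRequestParameters(bleu_input=bleu_input)
)
print(response.bleu_results)
```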
diff --git a/vertexai/_genai/types.py b/vertexai/_genai/types.py
new file mode 100644
index 0000000000..7c02c495db
--- /dev/null
+++ b/vertexai/_genai/types.py
@@ -0,0 +1,1245 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Code generated by the Google Gen AI SDK generator DO NOT EDIT.
+
+import logging
+from typing import Any, Optional, Union
+from google.genai import _common
+from pydantic import Field
+from typing_extensions import TypedDict
+
+logger = logging.getLogger("google_genai.types")
+
+
+class PairwiseChoice(_common.CaseInSensitiveEnum):
+ """Output only. Pairwise metric choice."""
+
+ PAIRWISE_CHOICE_UNSPECIFIED = "PAIRWISE_CHOICE_UNSPECIFIED"
+ BASELINE = "BASELINE"
+ CANDIDATE = "CANDIDATE"
+ TIE = "TIE"
+
+
+class BleuInstance(_common.BaseModel):
+ """Bleu instance."""
+
+ prediction: Optional[str] = Field(
+ default=None, description="""Required. Output of the evaluated model."""
+ )
+ reference: Optional[str] = Field(
+ default=None,
+ description="""Required. Ground truth used to compare against the prediction.""",
+ )
+
+
+class BleuInstanceDict(TypedDict, total=False):
+ """Bleu instance."""
+
+ prediction: Optional[str]
+ """Required. Output of the evaluated model."""
+
+ reference: Optional[str]
+ """Required. Ground truth used to compare against the prediction."""
+
+
+BleuInstanceOrDict = Union[BleuInstance, BleuInstanceDict]
+
+
+class BleuSpec(_common.BaseModel):
+ """Spec for bleu metric."""
+
+ use_effective_order: Optional[bool] = Field(
+ default=None,
+ description="""Optional. Whether to use_effective_order to compute bleu score.""",
+ )
+
+
+class BleuSpecDict(TypedDict, total=False):
+ """Spec for bleu metric."""
+
+ use_effective_order: Optional[bool]
+ """Optional. Whether to use_effective_order to compute bleu score."""
+
+
+BleuSpecOrDict = Union[BleuSpec, BleuSpecDict]
+
+
+class BleuInput(_common.BaseModel):
+
+ instances: Optional[list[BleuInstance]] = Field(
+ default=None, description="""Required. Repeated bleu instances."""
+ )
+ metric_spec: Optional[BleuSpec] = Field(
+ default=None, description="""Required. Spec for bleu score metric."""
+ )
+
+
+class BleuInputDict(TypedDict, total=False):
+
+ instances: Optional[list[BleuInstanceDict]]
+ """Required. Repeated bleu instances."""
+
+ metric_spec: Optional[BleuSpecDict]
+ """Required. Spec for bleu score metric."""
+
+
+BleuInputOrDict = Union[BleuInput, BleuInputDict]
+
+
+class ExactMatchInstance(_common.BaseModel):
+ """Exact match instance."""
+
+ prediction: Optional[str] = Field(
+ default=None, description="""Required. Output of the evaluated model."""
+ )
+ reference: Optional[str] = Field(
+ default=None,
+ description="""Required. Ground truth used to compare against the prediction.""",
+ )
+
+
+class ExactMatchInstanceDict(TypedDict, total=False):
+ """Exact match instance."""
+
+ prediction: Optional[str]
+ """Required. Output of the evaluated model."""
+
+ reference: Optional[str]
+ """Required. Ground truth used to compare against the prediction."""
+
+
+ExactMatchInstanceOrDict = Union[ExactMatchInstance, ExactMatchInstanceDict]
+
+
+class ExactMatchSpec(_common.BaseModel):
+ """Spec for exact match metric."""
+
+ pass
+
+
+class ExactMatchSpecDict(TypedDict, total=False):
+ """Spec for exact match metric."""
+
+ pass
+
+
+ExactMatchSpecOrDict = Union[ExactMatchSpec, ExactMatchSpecDict]
+
+
+class ExactMatchInput(_common.BaseModel):
+
+ instances: Optional[list[ExactMatchInstance]] = Field(
+ default=None,
+ description="""Required. Repeated exact match instances.""",
+ )
+ metric_spec: Optional[ExactMatchSpec] = Field(
+ default=None, description="""Required. Spec for exact match metric."""
+ )
+
+
+class ExactMatchInputDict(TypedDict, total=False):
+
+ instances: Optional[list[ExactMatchInstanceDict]]
+ """Required. Repeated exact match instances."""
+
+ metric_spec: Optional[ExactMatchSpecDict]
+ """Required. Spec for exact match metric."""
+
+
+ExactMatchInputOrDict = Union[ExactMatchInput, ExactMatchInputDict]
+
+
+class RougeInstance(_common.BaseModel):
+ """Rouge instance."""
+
+ prediction: Optional[str] = Field(
+ default=None, description="""Required. Output of the evaluated model."""
+ )
+ reference: Optional[str] = Field(
+ default=None,
+ description="""Required. Ground truth used to compare against the prediction.""",
+ )
+
+
+class RougeInstanceDict(TypedDict, total=False):
+ """Rouge instance."""
+
+ prediction: Optional[str]
+ """Required. Output of the evaluated model."""
+
+ reference: Optional[str]
+ """Required. Ground truth used to compare against the prediction."""
+
+
+RougeInstanceOrDict = Union[RougeInstance, RougeInstanceDict]
+
+
+class RougeSpec(_common.BaseModel):
+ """Spec for rouge metric."""
+
+ rouge_type: Optional[str] = Field(
+ default=None,
+ description="""Optional. Supported rouge types are rougen[1-9], rougeL, and rougeLsum.""",
+ )
+ split_summaries: Optional[bool] = Field(
+ default=None,
+ description="""Optional. Whether to split summaries while using rougeLsum.""",
+ )
+ use_stemmer: Optional[bool] = Field(
+ default=None,
+ description="""Optional. Whether to use stemmer to compute rouge score.""",
+ )
+
+
+class RougeSpecDict(TypedDict, total=False):
+ """Spec for rouge metric."""
+
+ rouge_type: Optional[str]
+ """Optional. Supported rouge types are rougen[1-9], rougeL, and rougeLsum."""
+
+ split_summaries: Optional[bool]
+ """Optional. Whether to split summaries while using rougeLsum."""
+
+ use_stemmer: Optional[bool]
+ """Optional. Whether to use stemmer to compute rouge score."""
+
+
+RougeSpecOrDict = Union[RougeSpec, RougeSpecDict]
+
+
+class RougeInput(_common.BaseModel):
+ """Rouge input."""
+
+ instances: Optional[list[RougeInstance]] = Field(
+ default=None, description="""Required. Repeated rouge instances."""
+ )
+ metric_spec: Optional[RougeSpec] = Field(
+ default=None, description="""Required. Spec for rouge score metric."""
+ )
+
+
+class RougeInputDict(TypedDict, total=False):
+ """Rouge input."""
+
+ instances: Optional[list[RougeInstanceDict]]
+ """Required. Repeated rouge instances."""
+
+ metric_spec: Optional[RougeSpecDict]
+ """Required. Spec for rouge score metric."""
+
+
+RougeInputOrDict = Union[RougeInput, RougeInputDict]
+
+
+class PointwiseMetricInstance(_common.BaseModel):
+ """Pointwise metric instance."""
+
+ json_instance: Optional[str] = Field(
+ default=None,
+ description="""Instance specified as a json string. String key-value pairs are expected in the json_instance to render PointwiseMetricSpec.instance_prompt_template.""",
+ )
+
+
+class PointwiseMetricInstanceDict(TypedDict, total=False):
+ """Pointwise metric instance."""
+
+ json_instance: Optional[str]
+ """Instance specified as a json string. String key-value pairs are expected in the json_instance to render PointwiseMetricSpec.instance_prompt_template."""
+
+
+PointwiseMetricInstanceOrDict = Union[
+ PointwiseMetricInstance, PointwiseMetricInstanceDict
+]
+
+
+class PointwiseMetricSpec(_common.BaseModel):
+ """Spec for pointwise metric."""
+
+ metric_prompt_template: Optional[str] = Field(
+ default=None,
+ description="""Required. Metric prompt template for pointwise metric.""",
+ )
+
+
+class PointwiseMetricSpecDict(TypedDict, total=False):
+ """Spec for pointwise metric."""
+
+ metric_prompt_template: Optional[str]
+ """Required. Metric prompt template for pointwise metric."""
+
+
+PointwiseMetricSpecOrDict = Union[PointwiseMetricSpec, PointwiseMetricSpecDict]
+
+
+class PointwiseMetricInput(_common.BaseModel):
+ """Pointwise metric input."""
+
+ instance: Optional[PointwiseMetricInstance] = Field(
+ default=None, description="""Required. Pointwise metric instance."""
+ )
+ metric_spec: Optional[PointwiseMetricSpec] = Field(
+ default=None, description="""Required. Spec for pointwise metric."""
+ )
+
+
+class PointwiseMetricInputDict(TypedDict, total=False):
+ """Pointwise metric input."""
+
+ instance: Optional[PointwiseMetricInstanceDict]
+ """Required. Pointwise metric instance."""
+
+ metric_spec: Optional[PointwiseMetricSpecDict]
+ """Required. Spec for pointwise metric."""
+
+
+PointwiseMetricInputOrDict = Union[PointwiseMetricInput, PointwiseMetricInputDict]
+
+
+class PairwiseMetricInstance(_common.BaseModel):
+ """Pairwise metric instance."""
+
+ json_instance: Optional[str] = Field(
+ default=None,
+ description="""Instance specified as a json string. String key-value pairs are expected in the json_instance to render PairwiseMetricSpec.instance_prompt_template.""",
+ )
+
+
+class PairwiseMetricInstanceDict(TypedDict, total=False):
+ """Pairwise metric instance."""
+
+ json_instance: Optional[str]
+ """Instance specified as a json string. String key-value pairs are expected in the json_instance to render PairwiseMetricSpec.instance_prompt_template."""
+
+
+PairwiseMetricInstanceOrDict = Union[PairwiseMetricInstance, PairwiseMetricInstanceDict]
+
+
+class PairwiseMetricSpec(_common.BaseModel):
+ """Spec for pairwise metric."""
+
+ metric_prompt_template: Optional[str] = Field(
+ default=None,
+ description="""Required. Metric prompt template for pairwise metric.""",
+ )
+
+
+class PairwiseMetricSpecDict(TypedDict, total=False):
+ """Spec for pairwise metric."""
+
+ metric_prompt_template: Optional[str]
+ """Required. Metric prompt template for pairwise metric."""
+
+
+PairwiseMetricSpecOrDict = Union[PairwiseMetricSpec, PairwiseMetricSpecDict]
+
+
+class PairwiseMetricInput(_common.BaseModel):
+ """Pairwise metric instance."""
+
+ instance: Optional[PairwiseMetricInstance] = Field(
+ default=None, description="""Required. Pairwise metric instance."""
+ )
+ metric_spec: Optional[PairwiseMetricSpec] = Field(
+ default=None, description="""Required. Spec for pairwise metric."""
+ )
+
+
+class PairwiseMetricInputDict(TypedDict, total=False):
+ """Pairwise metric instance."""
+
+ instance: Optional[PairwiseMetricInstanceDict]
+ """Required. Pairwise metric instance."""
+
+ metric_spec: Optional[PairwiseMetricSpecDict]
+ """Required. Spec for pairwise metric."""
+
+
+PairwiseMetricInputOrDict = Union[PairwiseMetricInput, PairwiseMetricInputDict]
+
+
+class ToolCallValidInstance(_common.BaseModel):
+ """Tool call valid instance."""
+
+ prediction: Optional[str] = Field(
+ default=None, description="""Required. Output of the evaluated model."""
+ )
+ reference: Optional[str] = Field(
+ default=None,
+ description="""Required. Ground truth used to compare against the prediction.""",
+ )
+
+
+class ToolCallValidInstanceDict(TypedDict, total=False):
+ """Tool call valid instance."""
+
+ prediction: Optional[str]
+ """Required. Output of the evaluated model."""
+
+ reference: Optional[str]
+ """Required. Ground truth used to compare against the prediction."""
+
+
+ToolCallValidInstanceOrDict = Union[ToolCallValidInstance, ToolCallValidInstanceDict]
+
+
+class ToolCallValidSpec(_common.BaseModel):
+ """Spec for tool call valid metric."""
+
+ pass
+
+
+class ToolCallValidSpecDict(TypedDict, total=False):
+ """Spec for tool call valid metric."""
+
+ pass
+
+
+ToolCallValidSpecOrDict = Union[ToolCallValidSpec, ToolCallValidSpecDict]
+
+
+class ToolCallValidInput(_common.BaseModel):
+ """Tool call valid input."""
+
+ instances: Optional[list[ToolCallValidInstance]] = Field(
+ default=None,
+ description="""Required. Repeated tool call valid instances.""",
+ )
+ metric_spec: Optional[ToolCallValidSpec] = Field(
+ default=None,
+ description="""Required. Spec for tool call valid metric.""",
+ )
+
+
+class ToolCallValidInputDict(TypedDict, total=False):
+ """Tool call valid input."""
+
+ instances: Optional[list[ToolCallValidInstanceDict]]
+ """Required. Repeated tool call valid instances."""
+
+ metric_spec: Optional[ToolCallValidSpecDict]
+ """Required. Spec for tool call valid metric."""
+
+
+ToolCallValidInputOrDict = Union[ToolCallValidInput, ToolCallValidInputDict]
+
+
+class ToolNameMatchInstance(_common.BaseModel):
+ """Tool name match instance."""
+
+ prediction: Optional[str] = Field(
+ default=None, description="""Required. Output of the evaluated model."""
+ )
+ reference: Optional[str] = Field(
+ default=None,
+ description="""Required. Ground truth used to compare against the prediction.""",
+ )
+
+
+class ToolNameMatchInstanceDict(TypedDict, total=False):
+ """Tool name match instance."""
+
+ prediction: Optional[str]
+ """Required. Output of the evaluated model."""
+
+ reference: Optional[str]
+ """Required. Ground truth used to compare against the prediction."""
+
+
+ToolNameMatchInstanceOrDict = Union[ToolNameMatchInstance, ToolNameMatchInstanceDict]
+
+
+class ToolNameMatchSpec(_common.BaseModel):
+ """Spec for tool name match metric."""
+
+ pass
+
+
+class ToolNameMatchSpecDict(TypedDict, total=False):
+ """Spec for tool name match metric."""
+
+ pass
+
+
+ToolNameMatchSpecOrDict = Union[ToolNameMatchSpec, ToolNameMatchSpecDict]
+
+
+class ToolNameMatchInput(_common.BaseModel):
+ """Tool name match input."""
+
+ instances: Optional[list[ToolNameMatchInstance]] = Field(
+ default=None,
+ description="""Required. Repeated tool name match instances.""",
+ )
+ metric_spec: Optional[ToolNameMatchSpec] = Field(
+ default=None,
+ description="""Required. Spec for tool name match metric.""",
+ )
+
+
+class ToolNameMatchInputDict(TypedDict, total=False):
+ """Tool name match input."""
+
+ instances: Optional[list[ToolNameMatchInstanceDict]]
+ """Required. Repeated tool name match instances."""
+
+ metric_spec: Optional[ToolNameMatchSpecDict]
+ """Required. Spec for tool name match metric."""
+
+
+ToolNameMatchInputOrDict = Union[ToolNameMatchInput, ToolNameMatchInputDict]
+
+
+class ToolParameterKeyMatchInstance(_common.BaseModel):
+ """Tool parameter key match instance."""
+
+ prediction: Optional[str] = Field(
+ default=None, description="""Required. Output of the evaluated model."""
+ )
+ reference: Optional[str] = Field(
+ default=None,
+ description="""Required. Ground truth used to compare against the prediction.""",
+ )
+
+
+class ToolParameterKeyMatchInstanceDict(TypedDict, total=False):
+ """Tool parameter key match instance."""
+
+ prediction: Optional[str]
+ """Required. Output of the evaluated model."""
+
+ reference: Optional[str]
+ """Required. Ground truth used to compare against the prediction."""
+
+
+ToolParameterKeyMatchInstanceOrDict = Union[
+ ToolParameterKeyMatchInstance, ToolParameterKeyMatchInstanceDict
+]
+
+
+class ToolParameterKeyMatchSpec(_common.BaseModel):
+ """Spec for tool parameter key match metric."""
+
+ pass
+
+
+class ToolParameterKeyMatchSpecDict(TypedDict, total=False):
+ """Spec for tool parameter key match metric."""
+
+ pass
+
+
+ToolParameterKeyMatchSpecOrDict = Union[
+ ToolParameterKeyMatchSpec, ToolParameterKeyMatchSpecDict
+]
+
+
+class ToolParameterKeyMatchInput(_common.BaseModel):
+ """Tool parameter key match input."""
+
+ instances: Optional[list[ToolParameterKeyMatchInstance]] = Field(
+ default=None,
+ description="""Required. Repeated tool parameter key match instances.""",
+ )
+ metric_spec: Optional[ToolParameterKeyMatchSpec] = Field(
+ default=None,
+ description="""Required. Spec for tool parameter key match metric.""",
+ )
+
+
+class ToolParameterKeyMatchInputDict(TypedDict, total=False):
+ """Tool parameter key match input."""
+
+ instances: Optional[list[ToolParameterKeyMatchInstanceDict]]
+ """Required. Repeated tool parameter key match instances."""
+
+ metric_spec: Optional[ToolParameterKeyMatchSpecDict]
+ """Required. Spec for tool parameter key match metric."""
+
+
+ToolParameterKeyMatchInputOrDict = Union[
+ ToolParameterKeyMatchInput, ToolParameterKeyMatchInputDict
+]
+
+
+class ToolParameterKVMatchInstance(_common.BaseModel):
+ """Tool parameter kv match instance."""
+
+ prediction: Optional[str] = Field(
+ default=None, description="""Required. Output of the evaluated model."""
+ )
+ reference: Optional[str] = Field(
+ default=None,
+ description="""Required. Ground truth used to compare against the prediction.""",
+ )
+
+
+class ToolParameterKVMatchInstanceDict(TypedDict, total=False):
+ """Tool parameter kv match instance."""
+
+ prediction: Optional[str]
+ """Required. Output of the evaluated model."""
+
+ reference: Optional[str]
+ """Required. Ground truth used to compare against the prediction."""
+
+
+ToolParameterKVMatchInstanceOrDict = Union[
+ ToolParameterKVMatchInstance, ToolParameterKVMatchInstanceDict
+]
+
+
+class ToolParameterKVMatchSpec(_common.BaseModel):
+ """Spec for tool parameter kv match metric."""
+
+ use_strict_string_match: Optional[bool] = Field(
+ default=None,
+ description="""Optional. Whether to use STRICT string match on parameter values.""",
+ )
+
+
+class ToolParameterKVMatchSpecDict(TypedDict, total=False):
+ """Spec for tool parameter kv match metric."""
+
+ use_strict_string_match: Optional[bool]
+ """Optional. Whether to use STRICT string match on parameter values."""
+
+
+ToolParameterKVMatchSpecOrDict = Union[
+ ToolParameterKVMatchSpec, ToolParameterKVMatchSpecDict
+]
+
+
+class ToolParameterKVMatchInput(_common.BaseModel):
+ """Tool parameter kv match input."""
+
+ instances: Optional[list[ToolParameterKVMatchInstance]] = Field(
+ default=None,
+ description="""Required. Repeated tool parameter key value match instances.""",
+ )
+ metric_spec: Optional[ToolParameterKVMatchSpec] = Field(
+ default=None,
+ description="""Required. Spec for tool parameter key value match metric.""",
+ )
+
+
+class ToolParameterKVMatchInputDict(TypedDict, total=False):
+ """Tool parameter kv match input."""
+
+ instances: Optional[list[ToolParameterKVMatchInstanceDict]]
+ """Required. Repeated tool parameter key value match instances."""
+
+ metric_spec: Optional[ToolParameterKVMatchSpecDict]
+ """Required. Spec for tool parameter key value match metric."""
+
+
+ToolParameterKVMatchInputOrDict = Union[
+ ToolParameterKVMatchInput, ToolParameterKVMatchInputDict
+]
+
+
+class HttpOptions(_common.BaseModel):
+ """HTTP options to be used in each of the requests."""
+
+ base_url: Optional[str] = Field(
+ default=None,
+ description="""The base URL for the AI platform service endpoint.""",
+ )
+ api_version: Optional[str] = Field(
+ default=None, description="""Specifies the version of the API to use."""
+ )
+ headers: Optional[dict[str, str]] = Field(
+ default=None,
+ description="""Additional HTTP headers to be sent with the request.""",
+ )
+ timeout: Optional[int] = Field(
+ default=None, description="""Timeout for the request in milliseconds."""
+ )
+ client_args: Optional[dict[str, Any]] = Field(
+ default=None, description="""Args passed to the HTTP client."""
+ )
+ async_client_args: Optional[dict[str, Any]] = Field(
+ default=None, description="""Args passed to the async HTTP client."""
+ )
+
+
+class HttpOptionsDict(TypedDict, total=False):
+ """HTTP options to be used in each of the requests."""
+
+ base_url: Optional[str]
+ """The base URL for the AI platform service endpoint."""
+
+ api_version: Optional[str]
+ """Specifies the version of the API to use."""
+
+ headers: Optional[dict[str, str]]
+ """Additional HTTP headers to be sent with the request."""
+
+ timeout: Optional[int]
+ """Timeout for the request in milliseconds."""
+
+ client_args: Optional[dict[str, Any]]
+ """Args passed to the HTTP client."""
+
+ async_client_args: Optional[dict[str, Any]]
+ """Args passed to the async HTTP client."""
+
+
+HttpOptionsOrDict = Union[HttpOptions, HttpOptionsDict]
+
+
+class EvaluateInstancesConfig(_common.BaseModel):
+ """Config for evaluate instances."""
+
+ http_options: Optional[HttpOptions] = Field(
+ default=None, description="""Used to override HTTP request options."""
+ )
+
+
+class EvaluateInstancesConfigDict(TypedDict, total=False):
+ """Config for evaluate instances."""
+
+ http_options: Optional[HttpOptionsDict]
+ """Used to override HTTP request options."""
+
+
+EvaluateInstancesConfigOrDict = Union[
+ EvaluateInstancesConfig, EvaluateInstancesConfigDict
+]
+
+
+class _EvaluateInstancesRequestParameters(_common.BaseModel):
+ """Parameters for evaluating instances."""
+
+ bleu_input: Optional[BleuInput] = Field(default=None, description="""""")
+ exact_match_input: Optional[ExactMatchInput] = Field(
+ default=None, description=""""""
+ )
+ rouge_input: Optional[RougeInput] = Field(default=None, description="""""")
+ pointwise_metric_input: Optional[PointwiseMetricInput] = Field(
+ default=None, description=""""""
+ )
+ pairwise_metric_input: Optional[PairwiseMetricInput] = Field(
+ default=None, description=""""""
+ )
+ tool_call_valid_input: Optional[ToolCallValidInput] = Field(
+ default=None, description=""""""
+ )
+ tool_name_match_input: Optional[ToolNameMatchInput] = Field(
+ default=None, description=""""""
+ )
+ tool_parameter_key_match_input: Optional[ToolParameterKeyMatchInput] = Field(
+ default=None, description=""""""
+ )
+ tool_parameter_kv_match_input: Optional[ToolParameterKVMatchInput] = Field(
+ default=None, description=""""""
+ )
+ config: Optional[EvaluateInstancesConfig] = Field(default=None, description="""""")
+
+
+class _EvaluateInstancesRequestParametersDict(TypedDict, total=False):
+ """Parameters for evaluating instances."""
+
+ bleu_input: Optional[BleuInputDict]
+ """"""
+
+ exact_match_input: Optional[ExactMatchInputDict]
+ """"""
+
+ rouge_input: Optional[RougeInputDict]
+ """"""
+
+ pointwise_metric_input: Optional[PointwiseMetricInputDict]
+ """"""
+
+ pairwise_metric_input: Optional[PairwiseMetricInputDict]
+ """"""
+
+ tool_call_valid_input: Optional[ToolCallValidInputDict]
+ """"""
+
+ tool_name_match_input: Optional[ToolNameMatchInputDict]
+ """"""
+
+ tool_parameter_key_match_input: Optional[ToolParameterKeyMatchInputDict]
+ """"""
+
+ tool_parameter_kv_match_input: Optional[ToolParameterKVMatchInputDict]
+ """"""
+
+ config: Optional[EvaluateInstancesConfigDict]
+ """"""
+
+
+_EvaluateInstancesRequestParametersOrDict = Union[
+ _EvaluateInstancesRequestParameters, _EvaluateInstancesRequestParametersDict
+]
+
+
+class BleuMetricValue(_common.BaseModel):
+ """Bleu metric value for an instance."""
+
+ score: Optional[float] = Field(
+ default=None, description="""Output only. Bleu score."""
+ )
+
+
+class BleuMetricValueDict(TypedDict, total=False):
+ """Bleu metric value for an instance."""
+
+ score: Optional[float]
+ """Output only. Bleu score."""
+
+
+BleuMetricValueOrDict = Union[BleuMetricValue, BleuMetricValueDict]
+
+
+class BleuResults(_common.BaseModel):
+ """Results for bleu metric."""
+
+ bleu_metric_values: Optional[list[BleuMetricValue]] = Field(
+ default=None, description="""Output only. Bleu metric values."""
+ )
+
+
+class BleuResultsDict(TypedDict, total=False):
+ """Results for bleu metric."""
+
+ bleu_metric_values: Optional[list[BleuMetricValueDict]]
+ """Output only. Bleu metric values."""
+
+
+BleuResultsOrDict = Union[BleuResults, BleuResultsDict]
+
+
+class ExactMatchMetricValue(_common.BaseModel):
+ """Exact match metric value for an instance."""
+
+ score: Optional[float] = Field(
+ default=None, description="""Output only. Exact match score."""
+ )
+
+
+class ExactMatchMetricValueDict(TypedDict, total=False):
+ """Exact match metric value for an instance."""
+
+ score: Optional[float]
+ """Output only. Exact match score."""
+
+
+ExactMatchMetricValueOrDict = Union[ExactMatchMetricValue, ExactMatchMetricValueDict]
+
+
+class ExactMatchResults(_common.BaseModel):
+ """Results for exact match metric."""
+
+ exact_match_metric_values: Optional[list[ExactMatchMetricValue]] = Field(
+ default=None, description="""Output only. Exact match metric values."""
+ )
+
+
+class ExactMatchResultsDict(TypedDict, total=False):
+ """Results for exact match metric."""
+
+ exact_match_metric_values: Optional[list[ExactMatchMetricValueDict]]
+ """Output only. Exact match metric values."""
+
+
+ExactMatchResultsOrDict = Union[ExactMatchResults, ExactMatchResultsDict]
+
+
+class PairwiseMetricResult(_common.BaseModel):
+ """Spec for pairwise metric result."""
+
+ explanation: Optional[str] = Field(
+ default=None,
+ description="""Output only. Explanation for pairwise metric score.""",
+ )
+ pairwise_choice: Optional[PairwiseChoice] = Field(
+ default=None, description="""Output only. Pairwise metric choice."""
+ )
+
+
+class PairwiseMetricResultDict(TypedDict, total=False):
+ """Spec for pairwise metric result."""
+
+ explanation: Optional[str]
+ """Output only. Explanation for pairwise metric score."""
+
+ pairwise_choice: Optional[PairwiseChoice]
+ """Output only. Pairwise metric choice."""
+
+
+PairwiseMetricResultOrDict = Union[PairwiseMetricResult, PairwiseMetricResultDict]
+
+
+class PointwiseMetricResult(_common.BaseModel):
+ """Spec for pointwise metric result."""
+
+ explanation: Optional[str] = Field(
+ default=None,
+ description="""Output only. Explanation for pointwise metric score.""",
+ )
+ score: Optional[float] = Field(
+ default=None, description="""Output only. Pointwise metric score."""
+ )
+
+
+class PointwiseMetricResultDict(TypedDict, total=False):
+ """Spec for pointwise metric result."""
+
+ explanation: Optional[str]
+ """Output only. Explanation for pointwise metric score."""
+
+ score: Optional[float]
+ """Output only. Pointwise metric score."""
+
+
+PointwiseMetricResultOrDict = Union[PointwiseMetricResult, PointwiseMetricResultDict]
+
+
+class RougeMetricValue(_common.BaseModel):
+ """Rouge metric value for an instance."""
+
+ score: Optional[float] = Field(
+ default=None, description="""Output only. Rouge score."""
+ )
+
+
+class RougeMetricValueDict(TypedDict, total=False):
+ """Rouge metric value for an instance."""
+
+ score: Optional[float]
+ """Output only. Rouge score."""
+
+
+RougeMetricValueOrDict = Union[RougeMetricValue, RougeMetricValueDict]
+
+
+class RougeResults(_common.BaseModel):
+ """Results for rouge metric."""
+
+ rouge_metric_values: Optional[list[RougeMetricValue]] = Field(
+ default=None, description="""Output only. Rouge metric values."""
+ )
+
+
+class RougeResultsDict(TypedDict, total=False):
+ """Results for rouge metric."""
+
+ rouge_metric_values: Optional[list[RougeMetricValueDict]]
+ """Output only. Rouge metric values."""
+
+
+RougeResultsOrDict = Union[RougeResults, RougeResultsDict]
+
+
+class SummarizationVerbosityResult(_common.BaseModel):
+ """Spec for summarization verbosity result."""
+
+ confidence: Optional[float] = Field(
+ default=None,
+ description="""Output only. Confidence for summarization verbosity score.""",
+ )
+ explanation: Optional[str] = Field(
+ default=None,
+ description="""Output only. Explanation for summarization verbosity score.""",
+ )
+ score: Optional[float] = Field(
+ default=None,
+ description="""Output only. Summarization Verbosity score.""",
+ )
+
+
+class SummarizationVerbosityResultDict(TypedDict, total=False):
+ """Spec for summarization verbosity result."""
+
+ confidence: Optional[float]
+ """Output only. Confidence for summarization verbosity score."""
+
+ explanation: Optional[str]
+ """Output only. Explanation for summarization verbosity score."""
+
+ score: Optional[float]
+ """Output only. Summarization Verbosity score."""
+
+
+SummarizationVerbosityResultOrDict = Union[
+ SummarizationVerbosityResult, SummarizationVerbosityResultDict
+]
+
+
+class ToolCallValidMetricValue(_common.BaseModel):
+ """Tool call valid metric value for an instance."""
+
+ score: Optional[float] = Field(
+ default=None, description="""Output only. Tool call valid score."""
+ )
+
+
+class ToolCallValidMetricValueDict(TypedDict, total=False):
+ """Tool call valid metric value for an instance."""
+
+ score: Optional[float]
+ """Output only. Tool call valid score."""
+
+
+ToolCallValidMetricValueOrDict = Union[
+ ToolCallValidMetricValue, ToolCallValidMetricValueDict
+]
+
+
+class ToolCallValidResults(_common.BaseModel):
+ """Results for tool call valid metric."""
+
+ tool_call_valid_metric_values: Optional[list[ToolCallValidMetricValue]] = Field(
+ default=None,
+ description="""Output only. Tool call valid metric values.""",
+ )
+
+
+class ToolCallValidResultsDict(TypedDict, total=False):
+ """Results for tool call valid metric."""
+
+ tool_call_valid_metric_values: Optional[list[ToolCallValidMetricValueDict]]
+ """Output only. Tool call valid metric values."""
+
+
+ToolCallValidResultsOrDict = Union[ToolCallValidResults, ToolCallValidResultsDict]
+
+
+class ToolNameMatchMetricValue(_common.BaseModel):
+ """Tool name match metric value for an instance."""
+
+ score: Optional[float] = Field(
+ default=None, description="""Output only. Tool name match score."""
+ )
+
+
+class ToolNameMatchMetricValueDict(TypedDict, total=False):
+ """Tool name match metric value for an instance."""
+
+ score: Optional[float]
+ """Output only. Tool name match score."""
+
+
+ToolNameMatchMetricValueOrDict = Union[
+ ToolNameMatchMetricValue, ToolNameMatchMetricValueDict
+]
+
+
+class ToolNameMatchResults(_common.BaseModel):
+ """Results for tool name match metric."""
+
+ tool_name_match_metric_values: Optional[list[ToolNameMatchMetricValue]] = Field(
+ default=None,
+ description="""Output only. Tool name match metric values.""",
+ )
+
+
+class ToolNameMatchResultsDict(TypedDict, total=False):
+ """Results for tool name match metric."""
+
+ tool_name_match_metric_values: Optional[list[ToolNameMatchMetricValueDict]]
+ """Output only. Tool name match metric values."""
+
+
+ToolNameMatchResultsOrDict = Union[ToolNameMatchResults, ToolNameMatchResultsDict]
+
+
+class ToolParameterKeyMatchMetricValue(_common.BaseModel):
+ """Tool parameter key match metric value for an instance."""
+
+ score: Optional[float] = Field(
+ default=None,
+ description="""Output only. Tool parameter key match score.""",
+ )
+
+
+class ToolParameterKeyMatchMetricValueDict(TypedDict, total=False):
+ """Tool parameter key match metric value for an instance."""
+
+ score: Optional[float]
+ """Output only. Tool parameter key match score."""
+
+
+ToolParameterKeyMatchMetricValueOrDict = Union[
+ ToolParameterKeyMatchMetricValue, ToolParameterKeyMatchMetricValueDict
+]
+
+
+class ToolParameterKeyMatchResults(_common.BaseModel):
+ """Results for tool parameter key match metric."""
+
+ tool_parameter_key_match_metric_values: Optional[
+ list[ToolParameterKeyMatchMetricValue]
+ ] = Field(
+ default=None,
+ description="""Output only. Tool parameter key match metric values.""",
+ )
+
+
+class ToolParameterKeyMatchResultsDict(TypedDict, total=False):
+ """Results for tool parameter key match metric."""
+
+ tool_parameter_key_match_metric_values: Optional[
+ list[ToolParameterKeyMatchMetricValueDict]
+ ]
+ """Output only. Tool parameter key match metric values."""
+
+
+ToolParameterKeyMatchResultsOrDict = Union[
+ ToolParameterKeyMatchResults, ToolParameterKeyMatchResultsDict
+]
+
+
+class ToolParameterKVMatchMetricValue(_common.BaseModel):
+ """Tool parameter key value match metric value for an instance."""
+
+ score: Optional[float] = Field(
+ default=None,
+ description="""Output only. Tool parameter key value match score.""",
+ )
+
+
+class ToolParameterKVMatchMetricValueDict(TypedDict, total=False):
+ """Tool parameter key value match metric value for an instance."""
+
+ score: Optional[float]
+ """Output only. Tool parameter key value match score."""
+
+
+ToolParameterKVMatchMetricValueOrDict = Union[
+ ToolParameterKVMatchMetricValue, ToolParameterKVMatchMetricValueDict
+]
+
+
+class ToolParameterKVMatchResults(_common.BaseModel):
+ """Results for tool parameter key value match metric."""
+
+ tool_parameter_kv_match_metric_values: Optional[
+ list[ToolParameterKVMatchMetricValue]
+ ] = Field(
+ default=None,
+ description="""Output only. Tool parameter key value match metric values.""",
+ )
+
+
+class ToolParameterKVMatchResultsDict(TypedDict, total=False):
+ """Results for tool parameter key value match metric."""
+
+ tool_parameter_kv_match_metric_values: Optional[
+ list[ToolParameterKVMatchMetricValueDict]
+ ]
+ """Output only. Tool parameter key value match metric values."""
+
+
+ToolParameterKVMatchResultsOrDict = Union[
+ ToolParameterKVMatchResults, ToolParameterKVMatchResultsDict
+]
+
+
+class EvaluateInstancesResponse(_common.BaseModel):
+ """Result of evaluating an LLM metric."""
+
+ bleu_results: Optional[BleuResults] = Field(
+ default=None, description="""Results for bleu metric."""
+ )
+ exact_match_results: Optional[ExactMatchResults] = Field(
+ default=None,
+ description="""Auto metric evaluation results. Results for exact match metric.""",
+ )
+ pairwise_metric_result: Optional[PairwiseMetricResult] = Field(
+ default=None, description="""Result for pairwise metric."""
+ )
+ pointwise_metric_result: Optional[PointwiseMetricResult] = Field(
+ default=None,
+ description="""Generic metrics. Result for pointwise metric.""",
+ )
+ rouge_results: Optional[RougeResults] = Field(
+ default=None, description="""Results for rouge metric."""
+ )
+ summarization_verbosity_result: Optional[SummarizationVerbosityResult] = Field(
+ default=None,
+ description="""Result for summarization verbosity metric.""",
+ )
+ tool_call_valid_results: Optional[ToolCallValidResults] = Field(
+ default=None,
+ description="""Tool call metrics. Results for tool call valid metric.""",
+ )
+ tool_name_match_results: Optional[ToolNameMatchResults] = Field(
+ default=None, description="""Results for tool name match metric."""
+ )
+ tool_parameter_key_match_results: Optional[ToolParameterKeyMatchResults] = Field(
+ default=None,
+ description="""Results for tool parameter key match metric.""",
+ )
+ tool_parameter_kv_match_results: Optional[ToolParameterKVMatchResults] = Field(
+ default=None,
+ description="""Results for tool parameter key value match metric.""",
+ )
+
+
+class EvaluateInstancesResponseDict(TypedDict, total=False):
+ """Result of evaluating an LLM metric."""
+
+ bleu_results: Optional[BleuResultsDict]
+ """Results for bleu metric."""
+
+ exact_match_results: Optional[ExactMatchResultsDict]
+ """Auto metric evaluation results. Results for exact match metric."""
+
+ pairwise_metric_result: Optional[PairwiseMetricResultDict]
+ """Result for pairwise metric."""
+
+ pointwise_metric_result: Optional[PointwiseMetricResultDict]
+ """Generic metrics. Result for pointwise metric."""
+
+ rouge_results: Optional[RougeResultsDict]
+ """Results for rouge metric."""
+
+ summarization_verbosity_result: Optional[SummarizationVerbosityResultDict]
+ """Result for summarization verbosity metric."""
+
+ tool_call_valid_results: Optional[ToolCallValidResultsDict]
+ """Tool call metrics. Results for tool call valid metric."""
+
+ tool_name_match_results: Optional[ToolNameMatchResultsDict]
+ """Results for tool name match metric."""
+
+ tool_parameter_key_match_results: Optional[ToolParameterKeyMatchResultsDict]
+ """Results for tool parameter key match metric."""
+
+ tool_parameter_kv_match_results: Optional[ToolParameterKVMatchResultsDict]
+ """Results for tool parameter key value match metric."""
+
+
+EvaluateInstancesResponseOrDict = Union[
+ EvaluateInstancesResponse, EvaluateInstancesResponseDict
+]
+
+
+class EvalDataset(_common.BaseModel):
+
+ file: Optional[str] = Field(default=None, description="""""")
+
+
+class EvalDatasetDict(TypedDict, total=False):
+
+ file: Optional[str]
+ """"""
+
+
+EvalDatasetOrDict = Union[EvalDataset, EvalDatasetDict]
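
Reviewer note, not part of the generated file: a short sketch of how the `*OrDict` aliases defined above are meant to be used. Each request field accepts either the pydantic model or the equivalent `TypedDict`-style plain dict, and the `_to_vertex` converters earlier in this patch reshape the snake_case fields into camelCase keys for the EvaluateInstances API.

```python
# Both forms below describe the same ROUGE request; the field names are the
# ones defined in the generated types above.
from vertexai._genai import types

as_model = types.RougeInput(
    instances=[
        types.RougeInstance(prediction="hello world", reference="hello there world")
    ],
    metric_spec=types.RougeSpec(rouge_type="rougeL", use_stemmer=True),
)

as_dict: types.RougeInputDict = {
    "instances": [{"prediction": "hello world", "reference": "hello there world"}],
    "metric_spec": {"rouge_type": "rougeL", "use_stemmer": True},
}

# After conversion by _RougeInput_to_vertex(), either form is sent as roughly:
# {"instances": [...], "metricSpec": {"rougeType": "rougeL", "useStemmer": True}}
```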
diff --git a/vertexai/agent_engines/_agent_engines.py b/vertexai/agent_engines/_agent_engines.py
index 37ba528134..983e3c572a 100644
--- a/vertexai/agent_engines/_agent_engines.py
+++ b/vertexai/agent_engines/_agent_engines.py
@@ -338,8 +338,11 @@ def create(
"""
sys_version = f"{sys.version_info.major}.{sys.version_info.minor}"
_validate_sys_version_or_raise(sys_version)
+ gcs_dir_name = gcs_dir_name or _DEFAULT_GCS_DIR_NAME
+ staging_bucket = initializer.global_config.staging_bucket
if agent_engine is not None:
agent_engine = _validate_agent_engine_or_raise(agent_engine)
+ _validate_staging_bucket_or_raise(staging_bucket)
if agent_engine is None:
if requirements is not None:
raise ValueError("requirements must be None if agent_engine is None.")
@@ -350,12 +353,9 @@ def create(
requirements=requirements,
)
extra_packages = _validate_extra_packages_or_raise(extra_packages)
- gcs_dir_name = gcs_dir_name or _DEFAULT_GCS_DIR_NAME
sdk_resource = cls.__new__(cls)
base.VertexAiResourceNounWithFutureManager.__init__(sdk_resource)
- staging_bucket = initializer.global_config.staging_bucket
- _validate_staging_bucket_or_raise(staging_bucket)
# Prepares the Agent Engine for creation in Vertex AI.
# This involves packaging and uploading the artifacts for
# agent_engine, requirements and extra_packages to
@@ -881,17 +881,18 @@ def _prepare(
gcs_dir_name (str): The GCS bucket directory under `staging_bucket` to
use for staging the artifacts needed.
"""
+ if agent_engine is None:
+ return
gcs_bucket = _get_gcs_bucket(
project=project,
location=location,
staging_bucket=staging_bucket,
)
- if agent_engine is not None:
- _upload_agent_engine(
- agent_engine=agent_engine,
- gcs_bucket=gcs_bucket,
- gcs_dir_name=gcs_dir_name,
- )
+ _upload_agent_engine(
+ agent_engine=agent_engine,
+ gcs_bucket=gcs_bucket,
+ gcs_dir_name=gcs_dir_name,
+ )
if requirements is not None:
_upload_requirements(
requirements=requirements,
@@ -992,7 +993,7 @@ def _generate_update_request_or_raise(
Union[Sequence[str], Dict[str, Union[str, aip_types.SecretRef]]]
] = None,
) -> reasoning_engine_service.UpdateReasoningEngineRequest:
- """Tries to generates the update request for the agent engine."""
+ """Tries to generate the update request for the agent engine."""
is_spec_update = False
update_masks: List[str] = []
agent_engine_spec = aip_types.ReasoningEngineSpec()
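
Reviewer note, not part of the patch: a simplified sketch of the reordered `create()`/`_prepare()` flow in `_agent_engines.py`, using illustrative stand-ins for `_DEFAULT_GCS_DIR_NAME`, `_validate_staging_bucket_or_raise`, and `_get_gcs_bucket`. It shows the behavioral effect of these hunks: staging-bucket validation and GCS staging are now skipped entirely when no `agent_engine` object is supplied.

```python
# Stand-alone sketch of the new control flow; all names except agent_engine,
# staging_bucket, and gcs_dir_name are illustrative stand-ins.
def _prepare_sketch(agent_engine, staging_bucket, gcs_dir_name):
    if agent_engine is None:
        # New early return: nothing to stage, so the GCS bucket is never touched.
        return
    gcs_bucket = f"{staging_bucket}/{gcs_dir_name}"  # stands in for _get_gcs_bucket()
    print(f"would upload agent_engine artifacts to {gcs_bucket}")


def create_sketch(agent_engine=None, staging_bucket=None, gcs_dir_name=None):
    gcs_dir_name = gcs_dir_name or "agent_engine"  # _DEFAULT_GCS_DIR_NAME stand-in
    if agent_engine is not None:
        # The staging bucket is now validated only when there is something to stage.
        if not staging_bucket or not staging_bucket.startswith("gs://"):
            raise ValueError("Please provide a valid GCS staging_bucket.")
    _prepare_sketch(agent_engine, staging_bucket, gcs_dir_name)


create_sketch()  # registration-only path: succeeds without a staging bucket
create_sketch(agent_engine=object(), staging_bucket="gs://my-staging-bucket")
```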
diff --git a/vertexai/model_garden/_model_garden.py b/vertexai/model_garden/_model_garden.py
index 826bc6a7dc..a311f40c8a 100644
--- a/vertexai/model_garden/_model_garden.py
+++ b/vertexai/model_garden/_model_garden.py
@@ -257,7 +257,17 @@ class _ModelGardenClientWithOverride(utils.ClientWithOverride):
class OpenModel:
- """Represents a Model Garden Open model."""
+ """Represents a Model Garden Open model.
+
+ Attributes:
+ model_name: Model Garden model resource name in the format of
+ `publishers/{publisher}/models/{model}@{version}`, or a
+ simplified resource name in the format of
+ `{publisher}/{model}@{version}`, or a Hugging Face model ID in
+ the format of `{organization}/{model}`.
+ """
+
+ __module__ = "vertexai.preview.model_garden"
def __init__(
self,
@@ -373,6 +383,7 @@ def deploy(
reservation_affinity_values: Optional[List[str]] = None,
use_dedicated_endpoint: Optional[bool] = False,
fast_tryout_enabled: Optional[bool] = False,
+ system_labels: Optional[Dict[str, str]] = None,
endpoint_display_name: Optional[str] = None,
model_display_name: Optional[str] = None,
deploy_request_timeout: Optional[float] = None,
@@ -399,139 +410,138 @@ def deploy(
Args:
accept_eula (bool): Whether to accept the End User License Agreement.
hugging_face_access_token (str): The access token to access Hugging Face
- models. Reference: https://huggingface.co/docs/hub/en/security-tokens
- machine_type (str):
- Optional. The type of machine. Not specifying machine type will
- result in model to be deployed with automatic resources.
- min_replica_count (int):
- Optional. The minimum number of machine replicas this deployed
- model will be always deployed on. If traffic against it increases,
- it may dynamically be deployed onto more replicas, and as traffic
- decreases, some of these extra replicas may be freed.
- max_replica_count (int):
- Optional. The maximum number of replicas this deployed model may
- be deployed on when the traffic against it increases. If requested
- value is too large, the deployment will error, but if deployment
- succeeds then the ability to scale the model to that many replicas
- is guaranteed (barring service outages). If traffic against the
- deployed model increases beyond what its replicas at maximum may
- handle, a portion of the traffic will be dropped. If this value
- is not provided, the larger value of min_replica_count or 1 will
- be used. If value provided is smaller than min_replica_count, it
- will automatically be increased to be min_replica_count.
- accelerator_type (str):
- Optional. Hardware accelerator type. Must also set accelerator_count if used.
- One of ACCELERATOR_TYPE_UNSPECIFIED, NVIDIA_TESLA_K80, NVIDIA_TESLA_P100,
- NVIDIA_TESLA_V100, NVIDIA_TESLA_P4, NVIDIA_TESLA_T4
- accelerator_count (int):
- Optional. The number of accelerators to attach to a worker replica.
- spot (bool):
- Optional. Whether to schedule the deployment workload on spot VMs.
- reservation_affinity_type (str):
- Optional. The type of reservation affinity.
- One of NO_RESERVATION, ANY_RESERVATION, SPECIFIC_RESERVATION,
- SPECIFIC_THEN_ANY_RESERVATION, SPECIFIC_THEN_NO_RESERVATION
- reservation_affinity_key (str):
- Optional. Corresponds to the label key of a reservation resource.
- To target a SPECIFIC_RESERVATION by name, use `compute.googleapis.com/reservation-name` as the key
- and specify the name of your reservation as its value.
- reservation_affinity_values (List[str]):
- Optional. Corresponds to the label values of a reservation resource.
- This must be the full resource name of the reservation.
- Format: 'projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}'
- use_dedicated_endpoint (bool):
- Optional. Default value is False. If set to True, the underlying prediction call will be made
- using the dedicated endpoint dns.
- fast_tryout_enabled (bool):
- Optional. Defaults to False.
- If True, model will be deployed using faster deployment path.
- Useful for quick experiments. Not for production workloads. Only
- available for most popular models with certain machine types.
+ models. Reference: https://huggingface.co/docs/hub/en/security-tokens
+ machine_type (str): Optional. The type of machine. Not specifying a
+ machine type will result in the model being deployed with automatic
+ resources.
+ min_replica_count (int): Optional. The minimum number of machine
+ replicas this deployed model will always be deployed on. If traffic
+ against it increases, it may dynamically be deployed onto more
+ replicas, and as traffic decreases, some of these extra replicas may
+ be freed.
+ max_replica_count (int): Optional. The maximum number of replicas this
+ deployed model may be deployed on when the traffic against it
+ increases. If the requested value is too large, the deployment will
+ error, but if deployment succeeds then the ability to scale the model
+ to that many replicas is guaranteed (barring service outages). If
+ traffic against the deployed model increases beyond what its replicas
+ at maximum may handle, a portion of the traffic will be dropped. If
+ this value is not provided, the larger of min_replica_count or 1 will
+ be used. If the value provided is smaller than min_replica_count, it
+ will automatically be increased to min_replica_count.
+ accelerator_type (str): Optional. Hardware accelerator type. Must also
+ set accelerator_count if used. One of ACCELERATOR_TYPE_UNSPECIFIED,
+ NVIDIA_TESLA_K80, NVIDIA_TESLA_P100, NVIDIA_TESLA_V100,
+ NVIDIA_TESLA_P4, NVIDIA_TESLA_T4
+ accelerator_count (int): Optional. The number of accelerators to attach
+ to a worker replica.
+ spot (bool): Optional. Whether to schedule the deployment workload on
+ spot VMs.
+ reservation_affinity_type (str): Optional. The type of reservation
+ affinity. One of NO_RESERVATION, ANY_RESERVATION,
+ SPECIFIC_RESERVATION, SPECIFIC_THEN_ANY_RESERVATION,
+ SPECIFIC_THEN_NO_RESERVATION
+ reservation_affinity_key (str): Optional. Corresponds to the label key
+ of a reservation resource. To target a SPECIFIC_RESERVATION by name,
+ use `compute.googleapis.com/reservation-name` as the key and specify
+ the name of your reservation as its value.
+ reservation_affinity_values (List[str]): Optional. Corresponds to the
+ label values of a reservation resource. This must be the full resource
+ name of the reservation.
+ Format:
+ 'projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}'
+ use_dedicated_endpoint (bool): Optional. Default value is False. If set
+ to True, the underlying prediction call will be made using the
+ dedicated endpoint DNS.
+ fast_tryout_enabled (bool): Optional. Defaults to False. If True, the
+ model will be deployed using a faster deployment path. Useful for
+ quick experiments. Not for production workloads. Only available for
+ the most popular models with certain machine types.
+ system_labels (Dict[str, str]): Optional. System labels for Model Garden
+ deployments. These labels are managed by Google and used for tracking
+ purposes only.
endpoint_display_name: The display name of the created endpoint.
model_display_name: The display name of the uploaded model.
- deploy_request_timeout: The timeout for the deploy request. Default
- is 2 hours.
- serving_container_spec (types.ModelContainerSpec):
- Optional. The container specification for the model instance.
- This specification overrides the default container specification
- and other serving container parameters.
- serving_container_image_uri (str):
- Optional. The URI of the Model serving container. This parameter is required
- if the parameter `local_model` is not specified.
- serving_container_predict_route (str):
- Optional. An HTTP path to send prediction requests to the container, and
- which must be supported by it. If not specified a default HTTP path will
- be used by Vertex AI.
- serving_container_health_route (str):
- Optional. An HTTP path to send health check requests to the container, and which
- must be supported by it. If not specified a standard HTTP path will be
- used by Vertex AI.
- serving_container_command: Optional[Sequence[str]]=None,
- The command with which the container is run. Not executed within a
- shell. The Docker image's ENTRYPOINT is used if this is not provided.
- Variable references $(VAR_NAME) are expanded using the container's
- environment. If a variable cannot be resolved, the reference in the
- input string will be unchanged. The $(VAR_NAME) syntax can be escaped
- with a double $$, ie: $$(VAR_NAME). Escaped references will never be
- expanded, regardless of whether the variable exists or not.
- serving_container_args: Optional[Sequence[str]]=None,
- The arguments to the command. The Docker image's CMD is used if this is
- not provided. Variable references $(VAR_NAME) are expanded using the
- container's environment. If a variable cannot be resolved, the reference
- in the input string will be unchanged. The $(VAR_NAME) syntax can be
- escaped with a double $$, ie: $$(VAR_NAME). Escaped references will
- never be expanded, regardless of whether the variable exists or not.
+ deploy_request_timeout: The timeout for the deploy request. Default is 2
+ hours.
+ serving_container_spec (types.ModelContainerSpec): Optional. The
+ container specification for the model instance. This specification
+ overrides the default container specification and other serving
+ container parameters.
+ serving_container_image_uri (str): Optional. The URI of the Model
+ serving container. This parameter is required if the parameter
+ `local_model` is not specified.
+ serving_container_predict_route (str): Optional. An HTTP path to send
+ prediction requests to the container, and which must be supported by
+ it. If not specified, a default HTTP path will be used by Vertex AI.
+ serving_container_health_route (str): Optional. An HTTP path to send
+ health check requests to the container, and which must be supported by
+ it. If not specified, a standard HTTP path will be used by Vertex AI.
+ serving_container_command: Optional[Sequence[str]]=None, The command
+ with which the container is run. Not executed within a shell. The
+ Docker image's ENTRYPOINT is used if this is not provided. Variable
+ references $(VAR_NAME) are expanded using the container's environment.
+ If a variable cannot be resolved, the reference in the input string
+ will be unchanged. The $(VAR_NAME) syntax can be escaped with a double
+ $$, i.e. $$(VAR_NAME). Escaped references will never be expanded,
+ regardless of whether the variable exists or not.
+ serving_container_args: Optional[Sequence[str]]=None, The arguments to
+ the command. The Docker image's CMD is used if this is not provided.
+ Variable references $(VAR_NAME) are expanded using the container's
+ environment. If a variable cannot be resolved, the reference in the
+ input string will be unchanged. The $(VAR_NAME) syntax can be escaped
+ with a double $$, i.e. $$(VAR_NAME). Escaped references will never be
+ expanded, regardless of whether the variable exists or not.
serving_container_environment_variables: Optional[Dict[str, str]]=None,
- The environment variables that are to be present in the container.
- Should be a dictionary where keys are environment variable names
- and values are environment variable values for those names.
- serving_container_ports: Optional[Sequence[int]]=None,
- Declaration of ports that are exposed by the container. This field is
- primarily informational, it gives Vertex AI information about the
- network connections the container uses. Listing or not a port here has
- no impact on whether the port is actually exposed, any port listening on
- the default "0.0.0.0" address inside a container will be accessible from
- the network.
- serving_container_grpc_ports: Optional[Sequence[int]]=None,
- Declaration of ports that are exposed by the container. Vertex AI sends gRPC
- prediction requests that it receives to the first port on this list. Vertex
- AI also sends liveness and health checks to this port.
- If you do not specify this field, gRPC requests to the container will be
- disabled.
- Vertex AI does not use ports other than the first one listed. This field
- corresponds to the `ports` field of the Kubernetes Containers v1 core API.
- serving_container_deployment_timeout (int):
- Optional. Deployment timeout in seconds.
- serving_container_shared_memory_size_mb (int):
- Optional. The amount of the VM memory to reserve as the shared
- memory for the model in megabytes.
- serving_container_startup_probe_exec (Sequence[str]):
- Optional. Exec specifies the action to take. Used by startup
- probe. An example of this argument would be
- ["cat", "/tmp/healthy"]
- serving_container_startup_probe_period_seconds (int):
- Optional. How often (in seconds) to perform the startup probe.
- Default to 10 seconds. Minimum value is 1.
- serving_container_startup_probe_timeout_seconds (int):
- Optional. Number of seconds after which the startup probe times
- out. Defaults to 1 second. Minimum value is 1.
- serving_container_health_probe_exec (Sequence[str]):
- Optional. Exec specifies the action to take. Used by health
- probe. An example of this argument would be
- ["cat", "/tmp/healthy"]
- serving_container_health_probe_period_seconds (int):
- Optional. How often (in seconds) to perform the health probe.
- Default to 10 seconds. Minimum value is 1.
- serving_container_health_probe_timeout_seconds (int):
- Optional. Number of seconds after which the health probe times
- out. Defaults to 1 second. Minimum value is 1.
+ The environment variables that are to be present in the container.
+ Should be a dictionary where keys are environment variable names and
+ values are environment variable values for those names.
+ serving_container_ports: Optional[Sequence[int]]=None, Declaration of
+ ports that are exposed by the container. This field is primarily
+ informational; it gives Vertex AI information about the network
+ connections the container uses. Whether or not a port is listed here
+ has no impact on whether the port is actually exposed, any port
+ listening on the default "0.0.0.0" address inside a container will be
+ accessible from the network.
+ serving_container_grpc_ports: Optional[Sequence[int]]=None, Declaration
+ of ports that are exposed by the container. Vertex AI sends gRPC
+ prediction requests that it receives to the first port on this list.
+ Vertex AI also sends liveness and health checks to this port. If you
+ do not specify this field, gRPC requests to the container will be
+ disabled. Vertex AI does not use ports other than the first one
+ listed. This field corresponds to the `ports` field of the Kubernetes
+ Containers v1 core API.
+ serving_container_deployment_timeout (int): Optional. Deployment timeout
+ in seconds.
+ serving_container_shared_memory_size_mb (int): Optional. The amount of
+ the VM memory to reserve as the shared memory for the model in
+ megabytes.
+ serving_container_startup_probe_exec (Sequence[str]): Optional. Exec
+ specifies the action to take. Used by startup probe. An example of
+ this argument would be ["cat", "/tmp/healthy"]
+ serving_container_startup_probe_period_seconds (int): Optional. How
+ often (in seconds) to perform the startup probe. Defaults to 10
+ seconds. Minimum value is 1.
+ serving_container_startup_probe_timeout_seconds (int): Optional. Number
+ of seconds after which the startup probe times out. Defaults to 1
+ second. Minimum value is 1.
+ serving_container_health_probe_exec (Sequence[str]): Optional. Exec
+ specifies the action to take. Used by health probe. An example of this
+ argument would be ["cat", "/tmp/healthy"]
+ serving_container_health_probe_period_seconds (int): Optional. How often
+ (in seconds) to perform the health probe. Defaults to 10 seconds.
+ Minimum value is 1.
+ serving_container_health_probe_timeout_seconds (int): Optional. Number
+ of seconds after which the health probe times out. Defaults to 1
+ second. Minimum value is 1.
Returns:
endpoint (aiplatform.Endpoint):
Created endpoint.
Raises:
- ValueError: If ``serving_container_spec`` is specified but ``serving_container_spec.image_uri``
+ ValueError: If ``serving_container_spec`` is specified but
+ ``serving_container_spec.image_uri``
is ``None``, or if ``serving_container_spec`` is specified but other
serving container parameters are specified.
"""
@@ -589,14 +599,19 @@ def deploy(
if fast_tryout_enabled:
request.deploy_config.fast_tryout_enabled = fast_tryout_enabled
+ if system_labels:
+ request.deploy_config.system_labels = system_labels
+
if serving_container_spec:
if not serving_container_spec.image_uri:
raise ValueError(
- "Serving container image uri is required for the serving container spec."
+ "Serving container image uri is required for the serving container"
+ " spec."
)
if serving_container_image_uri:
raise ValueError(
- "Serving container image uri is already set in the serving container spec."
+ "Serving container image uri is already set in the serving"
+ " container spec."
)
request.model_config.container_spec = serving_container_spec
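
The checks above make a full serving_container_spec and a standalone serving_container_image_uri mutually exclusive. A minimal sketch of the two accepted styles, assuming ModelContainerSpec is available from the aiplatform GAPIC types; the model name and image URI are placeholders.

from google.cloud.aiplatform_v1 import types as gapic_types  # assumed location of ModelContainerSpec
from vertexai.preview import model_garden

model = model_garden.OpenModel("google/gemma2@gemma-2-2b-it")  # placeholder model name
image = "us-docker.pkg.dev/my-project/my-repo/serving-image:latest"  # placeholder URI

# Style 1: a complete container spec; its image_uri must be set.
endpoint = model.deploy(
    accept_eula=True,
    serving_container_spec=gapic_types.ModelContainerSpec(image_uri=image),
)

# Style 2: the image URI plus other serving_container_* arguments directly.
# Supplying both styles at once raises the ValueError shown above.
endpoint = model.deploy(
    accept_eula=True,
    serving_container_image_uri=image,
)
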
@@ -640,24 +655,65 @@ def deploy(
def list_deploy_options(
self,
- ) -> Sequence[types.PublisherModel.CallToAction.Deploy]:
- """Lists the verified deploy options for the model."""
+ concise: bool = False,
+ ) -> Union[str, Sequence[types.PublisherModel.CallToAction.Deploy]]:
+ """Lists the verified deploy options for the model.
+
+ Args:
+ concise: If True, returns a human-readable string summarizing the
+ container and machine specs of each verified deploy option.
+
+ Returns:
+ A list of deploy options or a concise formatted string.
+ """
request = types.GetPublisherModelRequest(
name=self._publisher_model_name,
is_hugging_face_model=bool(self._is_hugging_face_model),
include_equivalent_model_garden_model_deployment_configs=True,
)
response = self._us_central1_model_garden_client.get_publisher_model(request)
- multi_deploy = (
+ deploy_options = (
response.supported_actions.multi_deploy_vertex.multi_deploy_vertex
)
- if not multi_deploy:
+
+ if not deploy_options:
raise ValueError(
- "Model does not support deployment, please use a deploy-able model"
- " instead. You can use the list_deployable_models() method"
- " to find out which ones currently support deployment."
+ "Model does not support deployment. "
+ "Use `list_deployable_models()` to find supported models."
+ )
+
+ if not concise:
+ return deploy_options
+
+ def _extract_config(option):
+ container = (
+ option.container_spec.image_uri if option.container_spec else None
+ )
+ machine = (
+ option.dedicated_resources.machine_spec
+ if option.dedicated_resources
+ else None
+ )
+
+ return {
+ "serving_container_image_uri": container,
+ "machine_type": getattr(machine, "machine_type", None),
+ "accelerator_type": getattr(
+ getattr(machine, "accelerator_type", None), "name", None
+ ),
+ "accelerator_count": getattr(machine, "accelerator_count", None),
+ }
+
+ concise_deploy_options = [_extract_config(opt) for opt in deploy_options]
+ return "\n\n".join(
+ f"[Option {i + 1}]\n"
+ + "\n".join(
+ f' {k}="{v}",' if k != "accelerator_count" else f" {k}={v},"
+ for k, v in config.items()
+ if v is not None
)
- return multi_deploy
+ for i, config in enumerate(concise_deploy_options)
+ )
def batch_predict(
self,
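
A short sketch of the new concise mode next to the default behavior. The printed block mirrors the formatting built above; the container image, machine type, and accelerator values are illustrative.

from vertexai.preview import model_garden

model = model_garden.OpenModel("google/gemma2@gemma-2-2b-it")  # placeholder model name

# Default: a sequence of PublisherModel.CallToAction.Deploy messages.
options = model.list_deploy_options()

# Concise: one "[Option N]" block per verified config, for example:
#
# [Option 1]
#   serving_container_image_uri="us-docker.pkg.dev/example/serving-image",
#   machine_type="g2-standard-12",
#   accelerator_type="NVIDIA_L4",
#   accelerator_count=1,
print(model.list_deploy_options(concise=True))
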
@@ -715,3 +771,44 @@ def batch_predict(
starting_replica_count=starting_replica_count,
max_replica_count=max_replica_count,
)
+
+ def check_license_agreement_status(self) -> bool:
+ """Check whether the project has accepted the license agreement of the model.
+
+ EULA (End User License Agreement) is a legal document that the user must
+ accept before using the model. For Models having license restrictions,
+ the user must accept the EULA before using the model. You can check the
+ details of the License in Model Garden.
+
+ Returns:
+ bool : True if the project has accepted the End User License
+ Agreement, False otherwise.
+ """
+ request = types.CheckPublisherModelEulaAcceptanceRequest(
+ parent=f"projects/{self._project}",
+ publisher_model=self._publisher_model_name,
+ )
+ response = self._model_garden_client.check_publisher_model_eula_acceptance(
+ request
+ )
+ return response.publisher_model_eula_acked
+
+ def accept_model_license_agreement(
+ self,
+ ) -> types.model_garden_service.PublisherModelEulaAcceptance:
+ """Accepts the EULA(End User License Agreement) of the model for the project.
+
+ For Models having license restrictions, the user must accept the EULA
+ before using the model. Calling this method will mark the EULA as accepted
+ for the project.
+
+ Returns:
+ types.model_garden_service.PublisherModelEulaAcceptance:
+ The response of the accept_eula call, containing project number,
+ model name and acceptance status.
+ """
+ request = types.AcceptPublisherModelEulaRequest(
+ parent=f"projects/{self._project}",
+ publisher_model=self._publisher_model_name,
+ )
+ return self._model_garden_client.accept_publisher_model_eula(request)
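
A hedged sketch of using the two new EULA helpers together before deploying; the model name is a placeholder and the flow assumes the caller has the required Model Garden permissions on the project.

from vertexai.preview import model_garden

model = model_garden.OpenModel("meta/llama-3.1-8b-instruct")  # placeholder model name

if not model.check_license_agreement_status():
    # Records the EULA acceptance for the current project.
    acceptance = model.accept_model_license_agreement()
    print(acceptance)

# With the acceptance recorded, the model can now be deployed in this project.
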
diff --git a/vertexai/preview/reasoning_engines/templates/adk.py b/vertexai/preview/reasoning_engines/templates/adk.py
index 435cafa6f4..3ef0d10a53 100644
--- a/vertexai/preview/reasoning_engines/templates/adk.py
+++ b/vertexai/preview/reasoning_engines/templates/adk.py
@@ -13,9 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
+
if TYPE_CHECKING:
+ try:
+ from google.genai import types
+
+ ContentDict = types.Content
+ except (ImportError, AttributeError):
+ ContentDict = Dict
+
try:
from google.adk.events.event import Event
@@ -397,6 +405,11 @@ def set_up(self):
)
for key, value in self._tmpl_attrs.get("env_vars").items():
os.environ[key] = value
+ if "GOOGLE_CLOUD_AGENT_ENGINE_ID" in os.environ:
+ self._tmpl_attrs["app_name"] = os.environ.get(
+ "GOOGLE_CLOUD_AGENT_ENGINE_ID",
+ self._tmpl_attrs.get("app_name"),
+ )
artifact_service_builder = self._tmpl_attrs.get("artifact_service_builder")
if artifact_service_builder:
@@ -416,10 +429,6 @@ def set_up(self):
project=project,
location=location,
)
- self._tmpl_attrs["app_name"] = os.environ.get(
- "GOOGLE_CLOUD_AGENT_ENGINE_ID",
- self._tmpl_attrs.get("app_name"),
- )
else:
self._tmpl_attrs["session_service"] = InMemorySessionService()
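
With the override moved earlier in set_up(), GOOGLE_CLOUD_AGENT_ENGINE_ID now takes precedence over the configured app_name regardless of which session service ends up being used. A small illustration, assuming AdkApp is the public wrapper for this template; the agent construction and environment value are illustrative (in the managed runtime the variable is expected to be set for you).

import os

from google.adk.agents import Agent  # illustrative agent definition
from vertexai.preview.reasoning_engines import AdkApp

os.environ["GOOGLE_CLOUD_AGENT_ENGINE_ID"] = "1234567890"  # illustrative value
agent = Agent(name="demo_agent", model="gemini-2.0-flash", instruction="Answer questions.")
app = AdkApp(agent=agent)
app.set_up()
# The template's app_name now resolves to "1234567890" instead of the configured default.
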
@@ -441,7 +450,7 @@ def set_up(self):
def stream_query(
self,
*,
- message: str,
+ message: Union[str, "ContentDict"],
user_id: str,
session_id: Optional[str] = None,
**kwargs,
@@ -465,7 +474,15 @@ def stream_query(
"""
from google.genai import types
- content = types.Content(role="user", parts=[types.Part(text=message)])
+ if isinstance(message, Dict):
+ content = types.Content.model_validate(message)
+ elif isinstance(message, str):
+ content = types.Content(role="user", parts=[types.Part(text=message)])
+ else:
+ raise TypeError(
+ "message must be a string or a dictionary representing a Content object."
+ )
+
if not self._tmpl_attrs.get("runner"):
self.set_up()
if not session_id:
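
With the change above, stream_query accepts either a plain string or a Content-shaped dictionary that google.genai's types.Content.model_validate can parse. A hedged sketch; the agent construction and identifiers are illustrative.

from google.adk.agents import Agent  # illustrative agent definition
from vertexai.preview.reasoning_engines import AdkApp

agent = Agent(name="demo_agent", model="gemini-2.0-flash", instruction="Answer questions.")
app = AdkApp(agent=agent)

# Plain-string form, unchanged.
for event in app.stream_query(message="What is the weather like today?", user_id="user-123"):
    print(event)

# Dict form matching google.genai types.Content, validated via model_validate.
content_dict = {"role": "user", "parts": [{"text": "What is the weather like today?"}]}
for event in app.stream_query(message=content_dict, user_id="user-123"):
    print(event)
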