Skip to content

Commit ec04711

Browse files
authored
add hub_arn support for accept_types, content_types, serializers, deserializers, and predictor (#4463)
1 parent b0c43d5 commit ec04711

File tree

20 files changed

+203
-73
lines changed

20 files changed

+203
-73
lines changed

src/sagemaker/accept_types.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def retrieve_options(
2323
region: Optional[str] = None,
2424
model_id: Optional[str] = None,
2525
model_version: Optional[str] = None,
26+
hub_arn: Optional[str] = None,
2627
tolerate_vulnerable_model: bool = False,
2728
tolerate_deprecated_model: bool = False,
2829
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -36,6 +37,8 @@ def retrieve_options(
3637
retrieve the supported accept types. (Default: None).
3738
model_version (str): The version of the model for which to retrieve the
3839
supported accept types. (Default: None).
40+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
41+
model details from. (Default: None).
3942
tolerate_vulnerable_model (bool): True if vulnerable versions of model
4043
specifications should be tolerated (exception not raised). If False, raises an
4144
exception if the script used by this version of the model has dependencies with known
@@ -59,11 +62,12 @@ def retrieve_options(
5962
)
6063

6164
return artifacts._retrieve_supported_accept_types(
62-
model_id,
63-
model_version,
64-
region,
65-
tolerate_vulnerable_model,
66-
tolerate_deprecated_model,
65+
model_id=model_id,
66+
model_version=model_version,
67+
hub_arn=hub_arn,
68+
region=region,
69+
tolerate_vulnerable_model=tolerate_vulnerable_model,
70+
tolerate_deprecated_model=tolerate_deprecated_model,
6771
sagemaker_session=sagemaker_session,
6872
)
6973

@@ -72,6 +76,7 @@ def retrieve_default(
7276
region: Optional[str] = None,
7377
model_id: Optional[str] = None,
7478
model_version: Optional[str] = None,
79+
hub_arn: Optional[str] = None,
7580
tolerate_vulnerable_model: bool = False,
7681
tolerate_deprecated_model: bool = False,
7782
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -85,6 +90,8 @@ def retrieve_default(
8590
retrieve the default accept type. (Default: None).
8691
model_version (str): The version of the model for which to retrieve the
8792
default accept type. (Default: None).
93+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
94+
model details from. (Default: None).
8895
tolerate_vulnerable_model (bool): True if vulnerable versions of model
8996
specifications should be tolerated (exception not raised). If False, raises an
9097
exception if the script used by this version of the model has dependencies with known
@@ -108,10 +115,11 @@ def retrieve_default(
108115
)
109116

110117
return artifacts._retrieve_default_accept_type(
111-
model_id,
112-
model_version,
113-
region,
114-
tolerate_vulnerable_model,
115-
tolerate_deprecated_model,
118+
model_id=model_id,
119+
model_version=model_version,
120+
hub_arn=hub_arn,
121+
region=region,
122+
tolerate_vulnerable_model=tolerate_vulnerable_model,
123+
tolerate_deprecated_model=tolerate_deprecated_model,
116124
sagemaker_session=sagemaker_session,
117125
)

src/sagemaker/content_types.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def retrieve_options(
2323
region: Optional[str] = None,
2424
model_id: Optional[str] = None,
2525
model_version: Optional[str] = None,
26+
hub_arn: Optional[str] = None,
2627
tolerate_vulnerable_model: bool = False,
2728
tolerate_deprecated_model: bool = False,
2829
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -36,6 +37,8 @@ def retrieve_options(
3637
retrieve the supported content types. (Default: None).
3738
model_version (str): The version of the model for which to retrieve the
3839
supported content types. (Default: None).
40+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
41+
model details from. (Default: None).
3942
tolerate_vulnerable_model (bool): True if vulnerable versions of model
4043
specifications should be tolerated (exception not raised). If False, raises an
4144
exception if the script used by this version of the model has dependencies with known
@@ -59,11 +62,12 @@ def retrieve_options(
5962
)
6063

6164
return artifacts._retrieve_supported_content_types(
62-
model_id,
63-
model_version,
64-
region,
65-
tolerate_vulnerable_model,
66-
tolerate_deprecated_model,
65+
model_id=model_id,
66+
model_version=model_version,
67+
hub_arn=hub_arn,
68+
region=region,
69+
tolerate_vulnerable_model=tolerate_vulnerable_model,
70+
tolerate_deprecated_model=tolerate_deprecated_model,
6771
sagemaker_session=sagemaker_session,
6872
)
6973

@@ -72,6 +76,7 @@ def retrieve_default(
7276
region: Optional[str] = None,
7377
model_id: Optional[str] = None,
7478
model_version: Optional[str] = None,
79+
hub_arn: Optional[str] = None,
7580
tolerate_vulnerable_model: bool = False,
7681
tolerate_deprecated_model: bool = False,
7782
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -85,6 +90,8 @@ def retrieve_default(
8590
retrieve the default content type. (Default: None).
8691
model_version (str): The version of the model for which to retrieve the
8792
default content type. (Default: None).
93+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
94+
model details from. (default: None).
8895
tolerate_vulnerable_model (bool): True if vulnerable versions of model
8996
specifications should be tolerated (exception not raised). If False, raises an
9097
exception if the script used by this version of the model has dependencies with known
@@ -108,11 +115,12 @@ def retrieve_default(
108115
)
109116

110117
return artifacts._retrieve_default_content_type(
111-
model_id,
112-
model_version,
113-
region,
114-
tolerate_vulnerable_model,
115-
tolerate_deprecated_model,
118+
model_id=model_id,
119+
model_version=model_version,
120+
hub_arn=hub_arn,
121+
region=region,
122+
tolerate_vulnerable_model=tolerate_vulnerable_model,
123+
tolerate_deprecated_model=tolerate_deprecated_model,
116124
sagemaker_session=sagemaker_session,
117125
)
118126

src/sagemaker/deserializers.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def retrieve_options(
4242
region: Optional[str] = None,
4343
model_id: Optional[str] = None,
4444
model_version: Optional[str] = None,
45+
hub_arn: Optional[str] = None,
4546
tolerate_vulnerable_model: bool = False,
4647
tolerate_deprecated_model: bool = False,
4748
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -55,6 +56,8 @@ def retrieve_options(
5556
retrieve the supported deserializers. (Default: None).
5657
model_version (str): The version of the model for which to retrieve the
5758
supported deserializers. (Default: None).
59+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
60+
model details from. (Default: None).
5861
tolerate_vulnerable_model (bool): True if vulnerable versions of model
5962
specifications should be tolerated (exception not raised). If False, raises an
6063
exception if the script used by this version of the model has dependencies with known
@@ -79,11 +82,12 @@ def retrieve_options(
7982
)
8083

8184
return artifacts._retrieve_deserializer_options(
82-
model_id,
83-
model_version,
84-
region,
85-
tolerate_vulnerable_model,
86-
tolerate_deprecated_model,
85+
model_id=model_id,
86+
model_version=model_version,
87+
hub_arn=hub_arn,
88+
region=region,
89+
tolerate_vulnerable_model=tolerate_vulnerable_model,
90+
tolerate_deprecated_model=tolerate_deprecated_model,
8791
sagemaker_session=sagemaker_session,
8892
)
8993

@@ -92,6 +96,7 @@ def retrieve_default(
9296
region: Optional[str] = None,
9397
model_id: Optional[str] = None,
9498
model_version: Optional[str] = None,
99+
hub_arn: Optional[str] = None,
95100
tolerate_vulnerable_model: bool = False,
96101
tolerate_deprecated_model: bool = False,
97102
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
@@ -105,6 +110,8 @@ def retrieve_default(
105110
retrieve the default deserializer. (Default: None).
106111
model_version (str): The version of the model for which to retrieve the
107112
default deserializer. (Default: None).
113+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
114+
model details from. (Default: None).
108115
tolerate_vulnerable_model (bool): True if vulnerable versions of model
109116
specifications should be tolerated (exception not raised). If False, raises an
110117
exception if the script used by this version of the model has dependencies with known
@@ -129,10 +136,11 @@ def retrieve_default(
129136
)
130137

131138
return artifacts._retrieve_default_deserializer(
132-
model_id,
133-
model_version,
134-
region,
135-
tolerate_vulnerable_model,
136-
tolerate_deprecated_model,
139+
model_id=model_id,
140+
model_version=model_version,
141+
hub_arn=hub_arn,
142+
region=region,
143+
tolerate_vulnerable_model=tolerate_vulnerable_model,
144+
tolerate_deprecated_model=tolerate_deprecated_model,
137145
sagemaker_session=sagemaker_session,
138146
)

src/sagemaker/jumpstart/artifacts/kwargs.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
def _retrieve_model_init_kwargs(
3232
model_id: str,
3333
model_version: str,
34+
hub_arn: Optional[str] = None,
3435
region: Optional[str] = None,
3536
tolerate_vulnerable_model: bool = False,
3637
tolerate_deprecated_model: bool = False,
@@ -43,6 +44,8 @@ def _retrieve_model_init_kwargs(
4344
retrieve the kwargs.
4445
model_version (str): Version of the JumpStart model for which to retrieve the
4546
kwargs.
47+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
48+
model details from. (default: None).
4649
region (Optional[str]): Region for which to retrieve kwargs.
4750
(Default: None).
4851
tolerate_vulnerable_model (bool): True if vulnerable versions of model
@@ -66,6 +69,7 @@ def _retrieve_model_init_kwargs(
6669
model_specs = verify_model_region_and_return_specs(
6770
model_id=model_id,
6871
version=model_version,
72+
hub_arn=hub_arn,
6973
scope=JumpStartScriptScope.INFERENCE,
7074
region=region,
7175
tolerate_vulnerable_model=tolerate_vulnerable_model,
@@ -85,6 +89,7 @@ def _retrieve_model_deploy_kwargs(
8589
model_id: str,
8690
model_version: str,
8791
instance_type: str,
92+
hub_arn: Optional[str] = None,
8893
region: Optional[str] = None,
8994
tolerate_vulnerable_model: bool = False,
9095
tolerate_deprecated_model: bool = False,
@@ -99,6 +104,8 @@ def _retrieve_model_deploy_kwargs(
99104
kwargs.
100105
instance_type (str): Instance type of the hosting endpoint, to determine if volume size
101106
is supported.
107+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
108+
model details from. (default: None).
102109
region (Optional[str]): Region for which to retrieve kwargs.
103110
(Default: None).
104111
tolerate_vulnerable_model (bool): True if vulnerable versions of model
@@ -123,6 +130,7 @@ def _retrieve_model_deploy_kwargs(
123130
model_specs = verify_model_region_and_return_specs(
124131
model_id=model_id,
125132
version=model_version,
133+
hub_arn=hub_arn,
126134
scope=JumpStartScriptScope.INFERENCE,
127135
region=region,
128136
tolerate_vulnerable_model=tolerate_vulnerable_model,

src/sagemaker/jumpstart/artifacts/model_packages.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def _retrieve_model_package_arn(
3131
model_version: str,
3232
instance_type: Optional[str],
3333
region: Optional[str],
34+
hub_arn: Optional[str] = None,
3435
scope: Optional[str] = None,
3536
tolerate_vulnerable_model: bool = False,
3637
tolerate_deprecated_model: bool = False,
@@ -46,6 +47,8 @@ def _retrieve_model_package_arn(
4647
instance_type (Optional[str]): An instance type to optionally supply in order to get an arn
4748
specific for the instance type.
4849
region (Optional[str]): Region for which to retrieve the model package arn.
50+
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
51+
model details from. (default: None).
4952
scope (Optional[str]): Scope for which to retrieve the model package arn.
5053
tolerate_vulnerable_model (bool): True if vulnerable versions of model
5154
specifications should be tolerated (exception not raised). If False, raises an
@@ -69,6 +72,7 @@ def _retrieve_model_package_arn(
6972
model_specs = verify_model_region_and_return_specs(
7073
model_id=model_id,
7174
version=model_version,
75+
hub_arn=hub_arn,
7276
scope=scope,
7377
region=region,
7478
tolerate_vulnerable_model=tolerate_vulnerable_model,

0 commit comments

Comments
 (0)