Skip to content

Commit ff59ff6

Browse files
authored
feat: jumpstart instance specific hyperparameters (#4180)
1 parent 59f7d11 commit ff59ff6

File tree

7 files changed

+622
-19
lines changed

7 files changed

+622
-19
lines changed

src/sagemaker/hyperparameters.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def retrieve_default(
3131
region: Optional[str] = None,
3232
model_id: Optional[str] = None,
3333
model_version: Optional[str] = None,
34+
instance_type: Optional[str] = None,
3435
include_container_hyperparameters: bool = False,
3536
tolerate_vulnerable_model: bool = False,
3637
tolerate_deprecated_model: bool = False,
@@ -45,6 +46,8 @@ def retrieve_default(
4546
retrieve the default hyperparameters. (Default: None).
4647
model_version (str): The version of the model for which to retrieve the
4748
default hyperparameters. (Default: None).
49+
instance_type (str): An instance type to optionally supply in order to get hyperparameters
50+
specific for the instance type. (Default: None).
4851
include_container_hyperparameters (bool): ``True`` if the container hyperparameters
4952
should be returned. Container hyperparameters are not used to tune
5053
the specific algorithm. They are used by SageMaker Training jobs to set up
@@ -75,12 +78,13 @@ def retrieve_default(
7578
)
7679

7780
return artifacts._retrieve_default_hyperparameters(
78-
model_id,
79-
model_version,
80-
region,
81-
include_container_hyperparameters,
82-
tolerate_vulnerable_model,
83-
tolerate_deprecated_model,
81+
model_id=model_id,
82+
model_version=model_version,
83+
instance_type=instance_type,
84+
region=region,
85+
include_container_hyperparameters=include_container_hyperparameters,
86+
tolerate_vulnerable_model=tolerate_vulnerable_model,
87+
tolerate_deprecated_model=tolerate_deprecated_model,
8488
sagemaker_session=sagemaker_session,
8589
)
8690

src/sagemaker/jumpstart/artifacts/hyperparameters.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def _retrieve_default_hyperparameters(
3535
tolerate_vulnerable_model: bool = False,
3636
tolerate_deprecated_model: bool = False,
3737
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
38+
instance_type: Optional[str] = None,
3839
):
3940
"""Retrieves the training hyperparameters for the model matching the given arguments.
4041
@@ -63,6 +64,8 @@ def _retrieve_default_hyperparameters(
6364
object, used for SageMaker interactions. If not
6465
specified, one is created using the default AWS configuration
6566
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
67+
instance_type (str): An instance type to optionally supply in order to get hyperparameters
68+
specific for the instance type. (Default: None).
6669
Returns:
6770
dict: the hyperparameters to use for the model.
6871
"""
@@ -86,4 +89,19 @@ def _retrieve_default_hyperparameters(
8689
include_container_hyperparameters and hyperparameter.scope == VariableScope.CONTAINER
8790
) or hyperparameter.scope == VariableScope.ALGORITHM:
8891
default_hyperparameters[hyperparameter.name] = str(hyperparameter.default)
92+
93+
instance_specific_hyperparameters = (
94+
model_specs.training_instance_type_variants.get_instance_specific_hyperparameters(
95+
instance_type
96+
)
97+
if instance_type
98+
and getattr(model_specs, "training_instance_type_variants", None) is not None
99+
else []
100+
)
101+
102+
for instance_specific_hyperparameter in instance_specific_hyperparameters:
103+
default_hyperparameters[instance_specific_hyperparameter.name] = str(
104+
instance_specific_hyperparameter.default
105+
)
106+
89107
return default_hyperparameters

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -599,6 +599,7 @@ def _add_hyperparameters_to_kwargs(
599599
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
600600
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
601601
sagemaker_session=kwargs.sagemaker_session,
602+
instance_type=kwargs.instance_type,
602603
)
603604

604605
for key, value in default_hyperparameters.items():

src/sagemaker/jumpstart/types.py

Lines changed: 53 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,50 @@ def to_json(self) -> Dict[str, Any]:
403403
json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)}
404404
return json_obj
405405

406+
def get_instance_specific_hyperparameters(
    self, instance_type: str
) -> List[JumpStartHyperparameter]:
    """Collect the hyperparameters declared for ``instance_type``.

    Two entries of ``self.variants`` are consulted: the one keyed by the exact
    instance type and the one keyed by the value that
    ``get_instance_type_family`` returns for it. When both entries define a
    hyperparameter with the same name, the exact-instance-type definition is
    kept and the family-level one is dropped.

    Args:
        instance_type (str): Instance type to look up in the variants.

    Returns:
        List[JumpStartHyperparameter]: Exact-match hyperparameters first,
        followed by the family-level ones that were not overridden. The list
        is empty when ``self.variants`` is ``None`` or nothing matches.
    """
    if self.variants is None:
        return []

    def _hyperparameters_for(variant_key):
        # Missing variant / "properties" / "hyperparameters" keys all fall
        # back to an empty collection, yielding no hyperparameters.
        raw_specs = (
            self.variants.get(variant_key, {}).get("properties", {}).get("hyperparameters", [])
        )
        return [JumpStartHyperparameter(raw_spec) for raw_spec in raw_specs]

    exact_hyperparameters: List[JumpStartHyperparameter] = _hyperparameters_for(instance_type)

    family_key = get_instance_type_family(instance_type)

    # An empty or missing family key means there is no family entry to consult.
    if family_key in {"", None}:
        family_hyperparameters: List[JumpStartHyperparameter] = []
    else:
        family_hyperparameters = _hyperparameters_for(family_key)

    overridden_names = {entry.name for entry in exact_hyperparameters}

    merged = deepcopy(exact_hyperparameters)
    merged.extend(
        entry for entry in family_hyperparameters if entry.name not in overridden_names
    )
    return merged
449+
406450
def get_instance_specific_environment_variables(self, instance_type: str) -> Dict[str, str]:
407451
"""Returns instance specific environment variables.
408452
@@ -867,10 +911,10 @@ def __init__(
867911
self.instance_type = instance_type
868912
self.region = region
869913
self.image_uri = image_uri
870-
self.model_data = model_data
914+
self.model_data = deepcopy(model_data)
871915
self.source_dir = source_dir
872916
self.entry_point = entry_point
873-
self.env = env
917+
self.env = deepcopy(env)
874918
self.predictor_cls = predictor_cls
875919
self.role = role
876920
self.name = name
@@ -963,7 +1007,7 @@ def __init__(
9631007
self.deserializer = deserializer
9641008
self.accelerator_type = accelerator_type
9651009
self.endpoint_name = endpoint_name
966-
self.tags = tags
1010+
self.tags = deepcopy(tags)
9671011
self.kms_key = kms_key
9681012
self.wait = wait
9691013
self.data_capture_config = data_capture_config
@@ -1111,8 +1155,8 @@ def __init__(
11111155
self.model_uri = model_uri
11121156
self.source_dir = source_dir
11131157
self.entry_point = entry_point
1114-
self.hyperparameters = hyperparameters
1115-
self.metric_definitions = metric_definitions
1158+
self.hyperparameters = deepcopy(hyperparameters)
1159+
self.metric_definitions = deepcopy(metric_definitions)
11161160
self.role = role
11171161
self.keep_alive_period_in_seconds = keep_alive_period_in_seconds
11181162
self.volume_size = volume_size
@@ -1123,7 +1167,7 @@ def __init__(
11231167
self.output_kms_key = output_kms_key
11241168
self.base_job_name = base_job_name
11251169
self.sagemaker_session = sagemaker_session
1126-
self.tags = tags
1170+
self.tags = deepcopy(tags)
11271171
self.subnets = subnets
11281172
self.security_group_ids = security_group_ids
11291173
self.model_channel_name = model_channel_name
@@ -1139,7 +1183,7 @@ def __init__(
11391183
self.enable_sagemaker_metrics = enable_sagemaker_metrics
11401184
self.profiler_config = profiler_config
11411185
self.disable_profiler = disable_profiler
1142-
self.environment = environment
1186+
self.environment = deepcopy(environment)
11431187
self.max_retry_attempts = max_retry_attempts
11441188
self.git_config = git_config
11451189
self.container_log_level = container_log_level
@@ -1319,13 +1363,13 @@ def __init__(
13191363
self.image_uri = image_uri
13201364
self.source_dir = source_dir
13211365
self.entry_point = entry_point
1322-
self.env = env
1366+
self.env = deepcopy(env)
13231367
self.predictor_cls = predictor_cls
13241368
self.serializer = serializer
13251369
self.deserializer = deserializer
13261370
self.accelerator_type = accelerator_type
13271371
self.endpoint_name = endpoint_name
1328-
self.tags = tags
1372+
self.tags = deepcopy(tags)
13291373
self.kms_key = kms_key
13301374
self.wait = wait
13311375
self.data_capture_config = data_capture_config

tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from sagemaker import hyperparameters
2121

22-
from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec
22+
from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec
2323

2424

2525
mock_client = boto3.client("s3")
@@ -116,3 +116,74 @@ def test_jumpstart_default_hyperparameters(patched_get_model_specs):
116116
hyperparameters.retrieve_default(
117117
model_id=model_id,
118118
)
119+
120+
121+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
def test_jumpstart_sdk_hyperparameters_instance_type_overrides(patched_get_model_specs):
    """Check how ``instance_type`` changes the result of ``retrieve_default``.

    Model specs are served by ``get_special_model_spec`` for the
    ``variant-model`` model id, and three instance types are exercised.
    """
    patched_get_model_specs.side_effect = get_special_model_spec

    def retrieve_for(instance_type):
        # Every call uses the same model/region; only the instance type varies.
        return hyperparameters.retrieve_default(
            region="us-west-2",
            model_id="variant-model",
            model_version="*",
            sagemaker_session=mock_session,
            instance_type=instance_type,
        )

    # Instance-specific hyperparameters can be added on top of the defaults.
    assert retrieve_for("ml.p2.48xlarge") == {
        "adam-learning-rate": "0.05",
        "batch-size": "4",
        "epochs": "3",
        "num_bag_sets": "5",
        "num_stack_levels": "6",
        "refit_full": "False",
        "sagemaker_container_log_level": "20",
        "sagemaker_program": "transfer_learning.py",
        "sagemaker_submit_directory": "/opt/ml/input/data/code/sourcedir.tar.gz",
        "save_space": "False",
        "set_best_to_refit_full": "False",
        "verbosity": "2",
    }

    # Default hyperparameters can be overridden (instance family + instance
    # type specific values).
    assert retrieve_for("ml.p2.12xlarge") == {
        "adam-learning-rate": "0.05",
        "batch-size": "1",
        "epochs": "3",
        "num_bag_sets": "1",
        "num_stack_levels": "0",
        "refit_full": "False",
        "eval_metric": "auto",
        "num_bag_folds": "0",
        "presets": "medium_quality",
        "auto_stack": "False",
        "sagemaker_container_log_level": "20",
        "sagemaker_program": "transfer_learning.py",
        "sagemaker_submit_directory": "/opt/ml/input/data/code/sourcedir.tar.gz",
        "save_space": "False",
        "set_best_to_refit_full": "False",
        "verbosity": "2",
    }

    # An unrecognized instance type yields only the default hyperparameters.
    assert retrieve_for("ml.p9999.48xlarge") == {
        "epochs": "3",
        "adam-learning-rate": "0.05",
        "batch-size": "4",
    }

0 commit comments

Comments
 (0)