Skip to content

Commit a906921

Browse files
adtian2Andrew Tian
and
Andrew Tian
authored
SMP PT upgrade to 2.1 (#4404)
* adding pt 2.1 support for smp * adding supported PT versions of 2.1.2 * bug fix in smp tests * adding 2.1.2 * adding 2.1.0 for Herring changes * changed image_uri generation logic for PT version 2.1.2 for smp * fixed this unit test to reflect changes for version PT 2.1 --------- Co-authored-by: Andrew Tian <tinandr@amazon.com>
1 parent 71468e5 commit a906921

File tree

4 files changed

+38
-6
lines changed

4 files changed

+38
-6
lines changed

src/sagemaker/fw_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,8 @@
139139
"1.13.1",
140140
"2.0.0",
141141
"2.0.1",
142+
"2.1.0",
143+
"2.1.2",
142144
],
143145
}
144146

@@ -158,7 +160,7 @@
158160
]
159161

160162

161-
TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS = ["1.13.1", "2.0.0", "2.0.1", "2.1.0"]
163+
TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS = ["1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.2"]
162164

163165
TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed"]
164166
TRAINIUM_SUPPORTED_TORCH_DISTRIBUTED_FRAMEWORK_VERSIONS = [

src/sagemaker/image_uri_config/pytorch-smp.json

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
"gpu"
55
],
66
"version_aliases": {
7-
"2.0": "2.0.1"
7+
"2.0": "2.0.1",
8+
"2.1": "2.1.2"
89
},
910
"versions": {
1011
"2.0.1": {
@@ -31,7 +32,32 @@
3132
"us-west-2": "658645717510"
3233
},
3334
"repository": "smdistributed-modelparallel"
35+
},
36+
"2.1.2": {
37+
"py_versions": [
38+
"py310"
39+
],
40+
"registries": {
41+
"ap-northeast-1": "658645717510",
42+
"ap-northeast-2": "658645717510",
43+
"ap-northeast-3": "658645717510",
44+
"ap-south-1": "658645717510",
45+
"ap-southeast-1": "658645717510",
46+
"ap-southeast-2": "658645717510",
47+
"ca-central-1": "658645717510",
48+
"eu-central-1": "658645717510",
49+
"eu-north-1": "658645717510",
50+
"eu-west-1": "658645717510",
51+
"eu-west-2": "658645717510",
52+
"eu-west-3": "658645717510",
53+
"sa-east-1": "658645717510",
54+
"us-east-1": "658645717510",
55+
"us-east-2": "658645717510",
56+
"us-west-1": "658645717510",
57+
"us-west-2": "658645717510"
58+
},
59+
"repository": "smdistributed-modelparallel"
3460
}
3561
}
3662
}
37-
}
63+
}

src/sagemaker/image_uris.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -672,7 +672,7 @@ def get_training_image_uri(
672672
if "modelparallel" in distribution["smdistributed"]:
673673
if distribution["smdistributed"]["modelparallel"].get("enabled", True):
674674
framework = "pytorch-smp"
675-
if "p5" in instance_type:
675+
if "p5" in instance_type or "2.1" in framework_version:
676676
container_version = "cu121"
677677
else:
678678
container_version = "cu118"

tests/unit/sagemaker/image_uris/test_smp_v2.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from sagemaker import image_uris
1717
from tests.unit.sagemaker.image_uris import expected_uris
1818

19-
CONTAINER_VERSIONS = {"ml.p4d.24xlarge": "cu118", "ml.p5d.24xlarge": "cu121"}
19+
CONTAINER_VERSIONS = {"ml.p4d.24xlarge": "cu118", "ml.p5.24xlarge": "cu121"}
2020

2121

2222
@pytest.mark.parametrize("load_config", ["pytorch-smp.json"], indirect=True)
@@ -34,6 +34,10 @@ def test_smp_v2(load_config):
3434
for py_version in PY_VERSIONS:
3535
for region in ACCOUNTS.keys():
3636
for instance_type in CONTAINER_VERSIONS.keys():
37+
cuda_vers = CONTAINER_VERSIONS[instance_type]
38+
if "2.1" in version:
39+
cuda_vers = "cu121"
40+
3741
uri = image_uris.get_training_image_uri(
3842
region,
3943
framework="pytorch",
@@ -45,7 +49,7 @@ def test_smp_v2(load_config):
4549
expected = expected_uris.framework_uri(
4650
repo="smdistributed-modelparallel",
4751
fw_version=version,
48-
py_version=f"{py_version}-{CONTAINER_VERSIONS[instance_type]}",
52+
py_version=f"{py_version}-{cuda_vers}",
4953
processor=processor,
5054
region=region,
5155
account=ACCOUNTS[region],

0 commit comments

Comments
 (0)