Skip to content

Commit d1522a7

Browse files
authored
feat: PT2.1 SM Training/Inference DLC Release (aws#4244)
1 parent 164f935 commit d1522a7

File tree

4 files changed: 131 additions and 8 deletions.

src/sagemaker/fw_utils.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -138,6 +138,7 @@
138138
"1.12.1",
139139
"1.13.1",
140140
"2.0.0",
141+
"2.0.1",
141142
],
142143
}
143144

@@ -153,10 +154,11 @@
153154
"1.13.1",
154155
"2.0.0",
155156
"2.0.1",
157+
"2.1.0",
156158
]
157159

158160

159-
TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS = ["1.13.1", "2.0.0", "2.0.1"]
161+
TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS = ["1.13.1", "2.0.0", "2.0.1", "2.1.0"]
160162

161163
TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed"]
162164
TRAINIUM_SUPPORTED_TORCH_DISTRIBUTED_FRAMEWORK_VERSIONS = [

src/sagemaker/image_uri_config/pytorch.json

Lines changed: 118 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@
7878
"1.11": "1.11.0",
7979
"1.12": "1.12.1",
8080
"1.13": "1.13.1",
81-
"2.0": "2.0.1"
81+
"2.0": "2.0.1",
82+
"2.1": "2.1.0"
8283
},
8384
"versions": {
8485
"0.4.0": {
@@ -931,6 +932,44 @@
931932
"us-west-2": "763104351884"
932933
},
933934
"repository": "pytorch-inference"
935+
},
936+
"2.1.0": {
937+
"py_versions": [
938+
"py310"
939+
],
940+
"registries": {
941+
"af-south-1": "626614931356",
942+
"il-central-1": "780543022126",
943+
"ap-east-1": "871362719292",
944+
"ap-northeast-1": "763104351884",
945+
"ap-northeast-2": "763104351884",
946+
"ap-northeast-3": "364406365360",
947+
"ap-south-1": "763104351884",
948+
"ap-southeast-1": "763104351884",
949+
"ap-southeast-2": "763104351884",
950+
"ap-southeast-3": "907027046896",
951+
"ap-southeast-4": "457447274322",
952+
"ca-central-1": "763104351884",
953+
"cn-north-1": "727897471807",
954+
"cn-northwest-1": "727897471807",
955+
"eu-central-1": "763104351884",
956+
"eu-north-1": "763104351884",
957+
"eu-west-1": "763104351884",
958+
"eu-west-2": "763104351884",
959+
"eu-west-3": "763104351884",
960+
"eu-south-1": "692866216735",
961+
"me-south-1": "217643126080",
962+
"sa-east-1": "763104351884",
963+
"us-east-1": "763104351884",
964+
"us-east-2": "763104351884",
965+
"us-gov-east-1": "446045086412",
966+
"us-gov-west-1": "442386744353",
967+
"us-iso-east-1": "886529160074",
968+
"us-isob-east-1": "094389454867",
969+
"us-west-1": "763104351884",
970+
"us-west-2": "763104351884"
971+
},
972+
"repository": "pytorch-inference"
934973
}
935974
}
936975
},
@@ -940,7 +979,8 @@
940979
],
941980
"version_aliases": {
942981
"1.12": "1.12.1",
943-
"2.0": "2.0.1"
982+
"2.0": "2.0.1",
983+
"2.1": "2.1.0"
944984
},
945985
"versions": {
946986
"1.12.1": {
@@ -1056,6 +1096,42 @@
10561096
},
10571097
"repository": "pytorch-inference-graviton",
10581098
"container_version": {"cpu": "ubuntu20.04"}
1099+
},
1100+
"2.1.0": {
1101+
"py_versions": [
1102+
"py310"
1103+
],
1104+
"registries": {
1105+
"af-south-1": "626614931356",
1106+
"il-central-1": "780543022126",
1107+
"ap-east-1": "871362719292",
1108+
"ap-northeast-1": "763104351884",
1109+
"ap-northeast-2": "763104351884",
1110+
"ap-northeast-3": "364406365360",
1111+
"ap-south-1": "763104351884",
1112+
"ap-south-2": "772153158452",
1113+
"ap-southeast-1": "763104351884",
1114+
"ap-southeast-2": "763104351884",
1115+
"ap-southeast-3": "907027046896",
1116+
"ap-southeast-4": "457447274322",
1117+
"ca-central-1": "763104351884",
1118+
"eu-central-1": "763104351884",
1119+
"eu-central-2": "380420809688",
1120+
"eu-north-1": "763104351884",
1121+
"eu-west-1": "763104351884",
1122+
"eu-west-2": "763104351884",
1123+
"eu-west-3": "763104351884",
1124+
"eu-south-1": "692866216735",
1125+
"eu-south-2": "503227376785",
1126+
"me-south-1": "217643126080",
1127+
"sa-east-1": "763104351884",
1128+
"us-east-1": "763104351884",
1129+
"us-east-2": "763104351884",
1130+
"us-west-1": "763104351884",
1131+
"us-west-2": "763104351884"
1132+
},
1133+
"repository": "pytorch-inference-graviton",
1134+
"container_version": {"cpu": "ubuntu20.04"}
10591135
}
10601136
}
10611137
},
@@ -1080,7 +1156,8 @@
10801156
"1.11": "1.11.0",
10811157
"1.12": "1.12.1",
10821158
"1.13": "1.13.1",
1083-
"2.0": "2.0.1"
1159+
"2.0": "2.0.1",
1160+
"2.1": "2.1.0"
10841161
},
10851162
"versions": {
10861163
"0.4.0": {
@@ -1934,6 +2011,44 @@
19342011
"us-west-2": "763104351884"
19352012
},
19362013
"repository": "pytorch-training"
2014+
},
2015+
"2.1.0": {
2016+
"py_versions": [
2017+
"py310"
2018+
],
2019+
"registries": {
2020+
"af-south-1": "626614931356",
2021+
"il-central-1": "780543022126",
2022+
"ap-east-1": "871362719292",
2023+
"ap-northeast-1": "763104351884",
2024+
"ap-northeast-2": "763104351884",
2025+
"ap-northeast-3": "364406365360",
2026+
"ap-south-1": "763104351884",
2027+
"ap-southeast-1": "763104351884",
2028+
"ap-southeast-2": "763104351884",
2029+
"ap-southeast-3": "907027046896",
2030+
"ap-southeast-4": "457447274322",
2031+
"ca-central-1": "763104351884",
2032+
"cn-north-1": "727897471807",
2033+
"cn-northwest-1": "727897471807",
2034+
"eu-central-1": "763104351884",
2035+
"eu-north-1": "763104351884",
2036+
"eu-west-1": "763104351884",
2037+
"eu-west-2": "763104351884",
2038+
"eu-west-3": "763104351884",
2039+
"eu-south-1": "692866216735",
2040+
"me-south-1": "217643126080",
2041+
"sa-east-1": "763104351884",
2042+
"us-east-1": "763104351884",
2043+
"us-east-2": "763104351884",
2044+
"us-gov-east-1": "446045086412",
2045+
"us-gov-west-1": "442386744353",
2046+
"us-iso-east-1": "886529160074",
2047+
"us-isob-east-1": "094389454867",
2048+
"us-west-1": "763104351884",
2049+
"us-west-2": "763104351884"
2050+
},
2051+
"repository": "pytorch-training"
19372052
}
19382053
}
19392054
}

tests/unit/test_fw_utils.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -937,6 +937,7 @@ def test_validate_smdataparallel_args_not_raises():
937937
("ml.p3.16xlarge", "pytorch", "1.12", "py38", smdataparallel_enabled),
938938
("ml.p3.16xlarge", "pytorch", "1.13.1", "py39", smdataparallel_enabled),
939939
("ml.p3.16xlarge", "pytorch", "2.0.0", "py310", smdataparallel_enabled),
940+
("ml.p3.16xlarge", "pytorch", "2.0.1", "py310", smdataparallel_enabled),
940941
("ml.p3.16xlarge", "tensorflow", "2.4.1", "py3", smdataparallel_enabled_custom_mpi),
941942
("ml.p3.16xlarge", "tensorflow", "2.4.1", "py37", smdataparallel_enabled_custom_mpi),
942943
("ml.p3.16xlarge", "tensorflow", "2.4.3", "py3", smdataparallel_enabled_custom_mpi),
@@ -959,6 +960,7 @@ def test_validate_smdataparallel_args_not_raises():
959960
("ml.p3.16xlarge", "pytorch", "1.12.1", "py38", smdataparallel_enabled_custom_mpi),
960961
("ml.p3.16xlarge", "pytorch", "1.13.1", "py39", smdataparallel_enabled_custom_mpi),
961962
("ml.p3.16xlarge", "pytorch", "2.0.0", "py310", smdataparallel_enabled_custom_mpi),
963+
("ml.p3.16xlarge", "pytorch", "2.0.1", "py310", smdataparallel_enabled_custom_mpi),
962964
]
963965
for instance_type, framework_name, framework_version, py_version, distribution in good_args:
964966
fw_utils._validate_smdataparallel_args(
@@ -995,6 +997,10 @@ def test_validate_pytorchddp_not_raises():
995997
"1.12",
996998
"1.12.0",
997999
"1.12.1",
1000+
"1.13.1",
1001+
"2.0.0",
1002+
"2.0.1",
1003+
"2.1.0",
9981004
]
9991005
for framework_version in pytorchddp_supported_fw_versions:
10001006
fw_utils.validate_pytorch_distribution(
@@ -1057,10 +1063,7 @@ def test_validate_torch_distributed_not_raises():
10571063

10581064
# Case 3: Distribution is torch_distributed enabled, supported framework and instances
10591065
torch_distributed_enabled = {"torch_distributed": {"enabled": True}}
1060-
torch_distributed_gpu_supported_fw_versions = [
1061-
"1.13.1",
1062-
"2.0.0",
1063-
]
1066+
torch_distributed_gpu_supported_fw_versions = ["1.13.1", "2.0.0", "2.0.1", "2.1.0"]
10641067
for framework_version in torch_distributed_gpu_supported_fw_versions:
10651068
fw_utils.validate_torch_distributed_distribution(
10661069
instance_type="ml.p3.8xlarge",

tests/unit/test_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,9 @@ def test_set_nested_value():
384384

385385

386386
def test_get_short_version():
387+
assert sagemaker.utils.get_short_version("2.1.0") == "2.1"
388+
assert sagemaker.utils.get_short_version("2.1") == "2.1"
389+
assert sagemaker.utils.get_short_version("2.0.1") == "2.0"
387390
assert sagemaker.utils.get_short_version("2.0.0") == "2.0"
388391
assert sagemaker.utils.get_short_version("2.0") == "2.0"
389392

0 commit comments

Comments
 (0)