Skip to content

Commit 9a036f6

Browse files
author
Mike Schneider
committed
Added PT 2.0 SM updates
1 parent e05ccd7 commit 9a036f6

File tree

4 files changed

+47
-3
lines changed

4 files changed

+47
-3
lines changed

src/sagemaker/fw_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@
135135
"1.12.0",
136136
"1.12.1",
137137
"1.13.1",
138+
"2.0.0",
138139
],
139140
}
140141

@@ -148,6 +149,7 @@
148149
"1.12.0",
149150
"1.12.1",
150151
"1.13.1",
152+
"2.0.0",
151153
]
152154

153155

@@ -161,6 +163,7 @@
161163
"1.12.0",
162164
"1.12.1",
163165
"1.13.1",
166+
"2.0.0",
164167
]
165168

166169
SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel", "modelparallel"]

src/sagemaker/image_uri_config/pytorch.json

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -992,7 +992,8 @@
992992
"1.10": "1.10.2",
993993
"1.11": "1.11.0",
994994
"1.12": "1.12.1",
995-
"1.13": "1.13.1"
995+
"1.13": "1.13.1",
996+
"2.0": "2.0.0"
996997
},
997998
"versions": {
998999
"0.4.0": {
@@ -1754,6 +1755,43 @@
17541755
"us-west-2": "763104351884"
17551756
},
17561757
"repository": "pytorch-training"
1758+
},
1759+
"2.0.0": {
1760+
"py_versions": [
1761+
"py310"
1762+
],
1763+
"registries": {
1764+
"af-south-1": "626614931356",
1765+
"ap-east-1": "871362719292",
1766+
"ap-northeast-1": "763104351884",
1767+
"ap-northeast-2": "763104351884",
1768+
"ap-northeast-3": "364406365360",
1769+
"ap-south-1": "763104351884",
1770+
"ap-southeast-1": "763104351884",
1771+
"ap-southeast-2": "763104351884",
1772+
"ap-southeast-3": "907027046896",
1773+
"ap-southeast-4": "457447274322",
1774+
"ca-central-1": "763104351884",
1775+
"cn-north-1": "727897471807",
1776+
"cn-northwest-1": "727897471807",
1777+
"eu-central-1": "763104351884",
1778+
"eu-north-1": "763104351884",
1779+
"eu-west-1": "763104351884",
1780+
"eu-west-2": "763104351884",
1781+
"eu-west-3": "763104351884",
1782+
"eu-south-1": "692866216735",
1783+
"me-south-1": "217643126080",
1784+
"sa-east-1": "763104351884",
1785+
"us-east-1": "763104351884",
1786+
"us-east-2": "763104351884",
1787+
"us-gov-east-1": "446045086412",
1788+
"us-gov-west-1": "442386744353",
1789+
"us-iso-east-1": "886529160074",
1790+
"us-isob-east-1": "094389454867",
1791+
"us-west-1": "763104351884",
1792+
"us-west-2": "763104351884"
1793+
},
1794+
"repository": "pytorch-training"
17571795
}
17581796
}
17591797
}

tests/unit/test_fw_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -913,6 +913,7 @@ def test_validate_smdataparallel_args_not_raises():
913913
("ml.p3.16xlarge", "pytorch", "1.12.1", "py38", smdataparallel_enabled),
914914
("ml.p3.16xlarge", "pytorch", "1.12", "py38", smdataparallel_enabled),
915915
("ml.p3.16xlarge", "pytorch", "1.13.1", "py39", smdataparallel_enabled),
916+
("ml.p3.16xlarge", "pytorch", "2.0.0", "py310", smdataparallel_enabled),
916917
("ml.p3.16xlarge", "tensorflow", "2.4.1", "py3", smdataparallel_enabled_custom_mpi),
917918
("ml.p3.16xlarge", "tensorflow", "2.4.1", "py37", smdataparallel_enabled_custom_mpi),
918919
("ml.p3.16xlarge", "tensorflow", "2.4.3", "py3", smdataparallel_enabled_custom_mpi),
@@ -934,6 +935,7 @@ def test_validate_smdataparallel_args_not_raises():
934935
("ml.p3.16xlarge", "pytorch", "1.12.0", "py38", smdataparallel_enabled_custom_mpi),
935936
("ml.p3.16xlarge", "pytorch", "1.12.1", "py38", smdataparallel_enabled_custom_mpi),
936937
("ml.p3.16xlarge", "pytorch", "1.13.1", "py39", smdataparallel_enabled_custom_mpi),
938+
("ml.p3.16xlarge", "pytorch", "2.0.0", "py310", smdataparallel_enabled_custom_mpi),
937939
]
938940
for instance_type, framework_name, framework_version, py_version, distribution in good_args:
939941
fw_utils._validate_smdataparallel_args(
@@ -1034,6 +1036,7 @@ def test_validate_torch_distributed_not_raises():
10341036
torch_distributed_enabled = {"torch_distributed": {"enabled": True}}
10351037
torch_distributed_gpu_supported_fw_versions = [
10361038
"1.13.1",
1039+
"2.0.0",
10371040
]
10381041
for framework_version in torch_distributed_gpu_supported_fw_versions:
10391042
fw_utils.validate_torch_distributed_distribution(

tests/unit/test_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -380,8 +380,8 @@ def test_set_nested_value():
380380

381381

382382
def test_get_short_version():
383-
assert sagemaker.utils.get_short_version("1.13.1") == "1.13"
384-
assert sagemaker.utils.get_short_version("1.13") == "1.13"
383+
assert sagemaker.utils.get_short_version("2.0.0") == "2.0"
384+
assert sagemaker.utils.get_short_version("2.0") == "2.0"
385385

386386

387387
def test_deferred_error():

0 commit comments

Comments
 (0)