Skip to content

Commit f26a6d4

Browse files
SergTogulSergey Togulevahsan-z-khan
authored and
Talia Chopra
committed
feature: add support for PyTorch 1.8.1 (aws#2278)
* feature: add support for PyTorch 1.8.1 * Updated available versions Co-authored-by: Sergey Togulev <[email protected]> Co-authored-by: Ahsan Khan <[email protected]>
1 parent 8f6e085 commit f26a6d4

File tree

3 files changed

+70
-3
lines changed

3 files changed

+70
-3
lines changed

src/sagemaker/fw_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
)
6161
SM_DATAPARALLEL_SUPPORTED_FRAMEWORK_VERSIONS = {
6262
"tensorflow": ["2.3", "2.3.1", "2.3.2", "2.4", "2.4.1"],
63-
"pytorch": ["1.6", "1.6.0", "1.7", "1.7.1", "1.8", "1.8.0"],
63+
"pytorch": ["1.6", "1.6.0", "1.7", "1.7.1", "1.8", "1.8.0", "1.8.1"],
6464
}
6565
SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel", "modelparallel"]
6666

src/sagemaker/image_uri_config/pytorch.json

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
"1.5": "1.5.0",
5757
"1.6": "1.6.0",
5858
"1.7": "1.7.1",
59-
"1.8": "1.8.0"
59+
"1.8": "1.8.1"
6060
},
6161
"versions": {
6262
"0.4.0": {
@@ -386,6 +386,39 @@
386386
"us-west-2": "763104351884"
387387
},
388388
"repository": "pytorch-inference"
389+
},
390+
"1.8.1": {
391+
"py_versions": [
392+
"py3",
393+
"py36"
394+
],
395+
"registries": {
396+
"af-south-1": "626614931356",
397+
"ap-east-1": "871362719292",
398+
"ap-northeast-1": "763104351884",
399+
"ap-northeast-2": "763104351884",
400+
"ap-south-1": "763104351884",
401+
"ap-southeast-1": "763104351884",
402+
"ap-southeast-2": "763104351884",
403+
"ca-central-1": "763104351884",
404+
"cn-north-1": "727897471807",
405+
"cn-northwest-1": "727897471807",
406+
"eu-central-1": "763104351884",
407+
"eu-north-1": "763104351884",
408+
"eu-west-1": "763104351884",
409+
"eu-west-2": "763104351884",
410+
"eu-west-3": "763104351884",
411+
"eu-south-1": "692866216735",
412+
"me-south-1": "217643126080",
413+
"sa-east-1": "763104351884",
414+
"us-east-1": "763104351884",
415+
"us-east-2": "763104351884",
416+
"us-gov-west-1": "442386744353",
417+
"us-iso-east-1": "886529160074",
418+
"us-west-1": "763104351884",
419+
"us-west-2": "763104351884"
420+
},
421+
"repository": "pytorch-inference"
389422
}
390423
}
391424
},
@@ -404,7 +437,7 @@
404437
"1.5": "1.5.0",
405438
"1.6": "1.6.0",
406439
"1.7": "1.7.1",
407-
"1.8": "1.8.0"
440+
"1.8": "1.8.1"
408441
},
409442
"versions": {
410443
"0.4.0": {
@@ -735,6 +768,39 @@
735768
"us-west-2": "763104351884"
736769
},
737770
"repository": "pytorch-training"
771+
},
772+
"1.8.1": {
773+
"py_versions": [
774+
"py3",
775+
"py36"
776+
],
777+
"registries": {
778+
"af-south-1": "626614931356",
779+
"ap-east-1": "871362719292",
780+
"ap-northeast-1": "763104351884",
781+
"ap-northeast-2": "763104351884",
782+
"ap-south-1": "763104351884",
783+
"ap-southeast-1": "763104351884",
784+
"ap-southeast-2": "763104351884",
785+
"ca-central-1": "763104351884",
786+
"cn-north-1": "727897471807",
787+
"cn-northwest-1": "727897471807",
788+
"eu-central-1": "763104351884",
789+
"eu-north-1": "763104351884",
790+
"eu-west-1": "763104351884",
791+
"eu-west-2": "763104351884",
792+
"eu-west-3": "763104351884",
793+
"eu-south-1": "692866216735",
794+
"me-south-1": "217643126080",
795+
"sa-east-1": "763104351884",
796+
"us-east-1": "763104351884",
797+
"us-east-2": "763104351884",
798+
"us-gov-west-1": "442386744353",
799+
"us-iso-east-1": "886529160074",
800+
"us-west-1": "763104351884",
801+
"us-west-2": "763104351884"
802+
},
803+
"repository": "pytorch-training"
738804
}
739805
}
740806
}

tests/unit/test_fw_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,7 @@ def test_validate_smdataparallel_args_not_raises():
642642
("ml.p3.16xlarge", "pytorch", "1.7.1", "py3", smdataparallel_enabled),
643643
("ml.p3.16xlarge", "pytorch", "1.7", "py3", smdataparallel_enabled),
644644
("ml.p3.16xlarge", "pytorch", "1.8.0", "py3", smdataparallel_enabled),
645+
("ml.p3.16xlarge", "pytorch", "1.8.1", "py3", smdataparallel_enabled),
645646
("ml.p3.16xlarge", "pytorch", "1.8", "py3", smdataparallel_enabled),
646647
]
647648
for instance_type, framework_name, framework_version, py_version, distribution in good_args:

0 commit comments

Comments
 (0)