Skip to content

Commit 31f0215

Browse files
committed
feature: add support for PyTorch 1.9.0
1 parent 6a782d9 commit 31f0215

File tree

3 files changed

+72
-3
lines changed

3 files changed

+72
-3
lines changed

src/sagemaker/fw_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
)
6161
SM_DATAPARALLEL_SUPPORTED_FRAMEWORK_VERSIONS = {
6262
"tensorflow": ["2.3", "2.3.1", "2.3.2", "2.4", "2.4.1"],
63-
"pytorch": ["1.6", "1.6.0", "1.7", "1.7.1", "1.8", "1.8.0", "1.8.1"],
63+
"pytorch": ["1.6", "1.6.0", "1.7", "1.7.1", "1.8", "1.8.0", "1.8.1", "1.9", "1.9.0"],
6464
}
6565
SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel", "modelparallel"]
6666

src/sagemaker/image_uri_config/pytorch.json

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@
6262
"1.5": "1.5.0",
6363
"1.6": "1.6.0",
6464
"1.7": "1.7.1",
65-
"1.8": "1.8.1"
65+
"1.8": "1.8.1",
66+
"1.9": "1.9.0"
6667
},
6768
"versions": {
6869
"0.4.0": {
@@ -433,6 +434,39 @@
433434
"us-west-2": "763104351884"
434435
},
435436
"repository": "pytorch-inference"
437+
},
438+
"1.9.0": {
439+
"py_versions": [
440+
"py38"
441+
],
442+
"registries": {
443+
"af-south-1": "626614931356",
444+
"ap-east-1": "871362719292",
445+
"ap-northeast-1": "763104351884",
446+
"ap-northeast-2": "763104351884",
447+
"ap-northeast-3": "364406365360",
448+
"ap-south-1": "763104351884",
449+
"ap-southeast-1": "763104351884",
450+
"ap-southeast-2": "763104351884",
451+
"ca-central-1": "763104351884",
452+
"cn-north-1": "727897471807",
453+
"cn-northwest-1": "727897471807",
454+
"eu-central-1": "763104351884",
455+
"eu-north-1": "763104351884",
456+
"eu-west-1": "763104351884",
457+
"eu-west-2": "763104351884",
458+
"eu-west-3": "763104351884",
459+
"eu-south-1": "692866216735",
460+
"me-south-1": "217643126080",
461+
"sa-east-1": "763104351884",
462+
"us-east-1": "763104351884",
463+
"us-east-2": "763104351884",
464+
"us-gov-west-1": "442386744353",
465+
"us-iso-east-1": "886529160074",
466+
"us-west-1": "763104351884",
467+
"us-west-2": "763104351884"
468+
},
469+
"repository": "pytorch-inference"
436470
}
437471
}
438472
},
@@ -451,7 +485,8 @@
451485
"1.5": "1.5.0",
452486
"1.6": "1.6.0",
453487
"1.7": "1.7.1",
454-
"1.8": "1.8.1"
488+
"1.8": "1.8.1",
489+
"1.9": "1.9.0"
455490
},
456491
"versions": {
457492
"0.4.0": {
@@ -823,6 +858,39 @@
823858
"us-west-2": "763104351884"
824859
},
825860
"repository": "pytorch-training"
861+
},
862+
"1.9.0": {
863+
"py_versions": [
864+
"py38"
865+
],
866+
"registries": {
867+
"af-south-1": "626614931356",
868+
"ap-east-1": "871362719292",
869+
"ap-northeast-1": "763104351884",
870+
"ap-northeast-2": "763104351884",
871+
"ap-northeast-3": "364406365360",
872+
"ap-south-1": "763104351884",
873+
"ap-southeast-1": "763104351884",
874+
"ap-southeast-2": "763104351884",
875+
"ca-central-1": "763104351884",
876+
"cn-north-1": "727897471807",
877+
"cn-northwest-1": "727897471807",
878+
"eu-central-1": "763104351884",
879+
"eu-north-1": "763104351884",
880+
"eu-west-1": "763104351884",
881+
"eu-west-2": "763104351884",
882+
"eu-west-3": "763104351884",
883+
"eu-south-1": "692866216735",
884+
"me-south-1": "217643126080",
885+
"sa-east-1": "763104351884",
886+
"us-east-1": "763104351884",
887+
"us-east-2": "763104351884",
888+
"us-gov-west-1": "442386744353",
889+
"us-iso-east-1": "886529160074",
890+
"us-west-1": "763104351884",
891+
"us-west-2": "763104351884"
892+
},
893+
"repository": "pytorch-inference"
826894
}
827895
}
828896
}

tests/unit/test_fw_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -653,6 +653,7 @@ def test_validate_smdataparallel_args_not_raises():
653653
("ml.p3.16xlarge", "pytorch", "1.8", "py3", smdataparallel_enabled),
654654
("ml.p3.16xlarge", "tensorflow", "2.4.1", "py3", smdataparallel_enabled_custom_mpi),
655655
("ml.p3.16xlarge", "pytorch", "1.8.0", "py3", smdataparallel_enabled_custom_mpi),
656+
("ml.p3.16xlarge", "pytorch", "1.9.0", "py3", smdataparallel_enabled_custom_mpi),
656657
]
657658
for instance_type, framework_name, framework_version, py_version, distribution in good_args:
658659
fw_utils._validate_smdataparallel_args(

0 commit comments

Comments
 (0)