Skip to content

Commit 601c94e

Browse files
saimidunavinsoni
authored andcommitted
change: Add support for PyTorch 1.9.1
Co-authored-by: Navin Soni <[email protected]>
1 parent 8573440 commit 601c94e

File tree

3 files changed

+72
-4
lines changed

3 files changed

+72
-4
lines changed

src/sagemaker/fw_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
)
6161
SM_DATAPARALLEL_SUPPORTED_FRAMEWORK_VERSIONS = {
6262
"tensorflow": ["2.3", "2.3.1", "2.3.2", "2.4", "2.4.1", "2.4.3", "2.5", "2.5.0", "2.5.1"],
63-
"pytorch": ["1.6", "1.6.0", "1.7", "1.7.1", "1.8", "1.8.0", "1.8.1", "1.9", "1.9.0"],
63+
"pytorch": ["1.6", "1.6.0", "1.7", "1.7.1", "1.8", "1.8.0", "1.8.1", "1.9", "1.9.0", "1.9.1"],
6464
}
6565
SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel", "modelparallel"]
6666

src/sagemaker/image_uri_config/pytorch.json

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
"1.6": "1.6.0",
6464
"1.7": "1.7.1",
6565
"1.8": "1.8.1",
66-
"1.9": "1.9.0"
66+
"1.9": "1.9.1"
6767
},
6868
"versions": {
6969
"0.4.0": {
@@ -467,6 +467,39 @@
467467
"us-west-2": "763104351884"
468468
},
469469
"repository": "pytorch-inference"
470+
},
471+
"1.9.1": {
472+
"py_versions": [
473+
"py38"
474+
],
475+
"registries": {
476+
"af-south-1": "626614931356",
477+
"ap-east-1": "871362719292",
478+
"ap-northeast-1": "763104351884",
479+
"ap-northeast-2": "763104351884",
480+
"ap-northeast-3": "364406365360",
481+
"ap-south-1": "763104351884",
482+
"ap-southeast-1": "763104351884",
483+
"ap-southeast-2": "763104351884",
484+
"ca-central-1": "763104351884",
485+
"cn-north-1": "727897471807",
486+
"cn-northwest-1": "727897471807",
487+
"eu-central-1": "763104351884",
488+
"eu-north-1": "763104351884",
489+
"eu-west-1": "763104351884",
490+
"eu-west-2": "763104351884",
491+
"eu-west-3": "763104351884",
492+
"eu-south-1": "692866216735",
493+
"me-south-1": "217643126080",
494+
"sa-east-1": "763104351884",
495+
"us-east-1": "763104351884",
496+
"us-east-2": "763104351884",
497+
"us-gov-west-1": "442386744353",
498+
"us-iso-east-1": "886529160074",
499+
"us-west-1": "763104351884",
500+
"us-west-2": "763104351884"
501+
},
502+
"repository": "pytorch-inference"
470503
}
471504
}
472505
},
@@ -486,7 +519,7 @@
486519
"1.6": "1.6.0",
487520
"1.7": "1.7.1",
488521
"1.8": "1.8.1",
489-
"1.9": "1.9.0"
522+
"1.9": "1.9.1"
490523
},
491524
"versions": {
492525
"0.4.0": {
@@ -891,6 +924,39 @@
891924
"us-west-2": "763104351884"
892925
},
893926
"repository": "pytorch-training"
927+
},
928+
"1.9.1": {
929+
"py_versions": [
930+
"py38"
931+
],
932+
"registries": {
933+
"af-south-1": "626614931356",
934+
"ap-east-1": "871362719292",
935+
"ap-northeast-1": "763104351884",
936+
"ap-northeast-2": "763104351884",
937+
"ap-northeast-3": "364406365360",
938+
"ap-south-1": "763104351884",
939+
"ap-southeast-1": "763104351884",
940+
"ap-southeast-2": "763104351884",
941+
"ca-central-1": "763104351884",
942+
"cn-north-1": "727897471807",
943+
"cn-northwest-1": "727897471807",
944+
"eu-central-1": "763104351884",
945+
"eu-north-1": "763104351884",
946+
"eu-west-1": "763104351884",
947+
"eu-west-2": "763104351884",
948+
"eu-west-3": "763104351884",
949+
"eu-south-1": "692866216735",
950+
"me-south-1": "217643126080",
951+
"sa-east-1": "763104351884",
952+
"us-east-1": "763104351884",
953+
"us-east-2": "763104351884",
954+
"us-gov-west-1": "442386744353",
955+
"us-iso-east-1": "886529160074",
956+
"us-west-1": "763104351884",
957+
"us-west-2": "763104351884"
958+
},
959+
"repository": "pytorch-training"
894960
}
895961
}
896962
}

tests/unit/test_fw_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -651,9 +651,11 @@ def test_validate_smdataparallel_args_not_raises():
651651
("ml.p3.16xlarge", "pytorch", "1.8.0", "py3", smdataparallel_enabled),
652652
("ml.p3.16xlarge", "pytorch", "1.8.1", "py3", smdataparallel_enabled),
653653
("ml.p3.16xlarge", "pytorch", "1.8", "py3", smdataparallel_enabled),
654+
("ml.p3.16xlarge", "pytorch", "1.9.1", "py38", smdataparallel_enabled),
655+
("ml.p3.16xlarge", "pytorch", "1.9", "py38", smdataparallel_enabled),
654656
("ml.p3.16xlarge", "tensorflow", "2.4.1", "py3", smdataparallel_enabled_custom_mpi),
655657
("ml.p3.16xlarge", "pytorch", "1.8.0", "py3", smdataparallel_enabled_custom_mpi),
656-
("ml.p3.16xlarge", "pytorch", "1.9.0", "py3", smdataparallel_enabled_custom_mpi),
658+
("ml.p3.16xlarge", "pytorch", "1.9.1", "py38", smdataparallel_enabled_custom_mpi),
657659
]
658660
for instance_type, framework_name, framework_version, py_version, distribution in good_args:
659661
fw_utils._validate_smdataparallel_args(

0 commit comments

Comments
 (0)