Skip to content

Commit 31e817b

Browse files
saimidu, ahsan-z-khan, shreyapandit, and jeniyat
authored
feature: add support for PyTorch 1.9.0 (#2653)
Co-authored-by: Ahsan Khan <[email protected]>
Co-authored-by: Shreya Pandit <[email protected]>
Co-authored-by: Jeniya Tabassum <[email protected]>
1 parent a734f29 commit 31e817b

File tree

5 files changed: 93 additions (+93) and 14 deletions (-14).

src/sagemaker/fw_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
)
6161
SM_DATAPARALLEL_SUPPORTED_FRAMEWORK_VERSIONS = {
6262
"tensorflow": ["2.3", "2.3.1", "2.3.2", "2.4", "2.4.1"],
63-
"pytorch": ["1.6", "1.6.0", "1.7", "1.7.1", "1.8", "1.8.0", "1.8.1"],
63+
"pytorch": ["1.6", "1.6.0", "1.7", "1.7.1", "1.8", "1.8.0", "1.8.1", "1.9", "1.9.0"],
6464
}
6565
SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel", "modelparallel"]
6666

@@ -298,7 +298,7 @@ def framework_name_from_image(image_uri):
298298
(tensorflow|mxnet|chainer|pytorch|scikit-learn|xgboost
299299
|huggingface-tensorflow|huggingface-pytorch)(?:-)?
300300
(scriptmode|training)?
301-
:(.*)-(.*?)-(py2|py3[67]?)(?:.*)$""",
301+
:(.*)-(.*?)-(py2|py3\d*)(?:.*)$""",
302302
re.VERBOSE,
303303
)
304304
name_match = name_pattern.match(sagemaker_match.group(9))
@@ -329,7 +329,7 @@ def framework_version_from_tag(image_tag):
329329
Returns:
330330
str: The framework version.
331331
"""
332-
tag_pattern = re.compile("^(.*)-(cpu|gpu)-(py2|py3[67]?)$")
332+
tag_pattern = re.compile(r"^(.*)-(cpu|gpu)-(py2|py3\d*)$")
333333
tag_match = tag_pattern.match(image_tag)
334334
return None if tag_match is None else tag_match.group(1)
335335

src/sagemaker/image_uri_config/pytorch.json

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@
6262
"1.5": "1.5.0",
6363
"1.6": "1.6.0",
6464
"1.7": "1.7.1",
65-
"1.8": "1.8.1"
65+
"1.8": "1.8.1",
66+
"1.9": "1.9.0"
6667
},
6768
"versions": {
6869
"0.4.0": {
@@ -433,6 +434,39 @@
433434
"us-west-2": "763104351884"
434435
},
435436
"repository": "pytorch-inference"
437+
},
438+
"1.9.0": {
439+
"py_versions": [
440+
"py38"
441+
],
442+
"registries": {
443+
"af-south-1": "626614931356",
444+
"ap-east-1": "871362719292",
445+
"ap-northeast-1": "763104351884",
446+
"ap-northeast-2": "763104351884",
447+
"ap-northeast-3": "364406365360",
448+
"ap-south-1": "763104351884",
449+
"ap-southeast-1": "763104351884",
450+
"ap-southeast-2": "763104351884",
451+
"ca-central-1": "763104351884",
452+
"cn-north-1": "727897471807",
453+
"cn-northwest-1": "727897471807",
454+
"eu-central-1": "763104351884",
455+
"eu-north-1": "763104351884",
456+
"eu-west-1": "763104351884",
457+
"eu-west-2": "763104351884",
458+
"eu-west-3": "763104351884",
459+
"eu-south-1": "692866216735",
460+
"me-south-1": "217643126080",
461+
"sa-east-1": "763104351884",
462+
"us-east-1": "763104351884",
463+
"us-east-2": "763104351884",
464+
"us-gov-west-1": "442386744353",
465+
"us-iso-east-1": "886529160074",
466+
"us-west-1": "763104351884",
467+
"us-west-2": "763104351884"
468+
},
469+
"repository": "pytorch-inference"
436470
}
437471
}
438472
},
@@ -451,7 +485,8 @@
451485
"1.5": "1.5.0",
452486
"1.6": "1.6.0",
453487
"1.7": "1.7.1",
454-
"1.8": "1.8.1"
488+
"1.8": "1.8.1",
489+
"1.9": "1.9.0"
455490
},
456491
"versions": {
457492
"0.4.0": {
@@ -823,6 +858,39 @@
823858
"us-west-2": "763104351884"
824859
},
825860
"repository": "pytorch-training"
861+
},
862+
"1.9.0": {
863+
"py_versions": [
864+
"py38"
865+
],
866+
"registries": {
867+
"af-south-1": "626614931356",
868+
"ap-east-1": "871362719292",
869+
"ap-northeast-1": "763104351884",
870+
"ap-northeast-2": "763104351884",
871+
"ap-northeast-3": "364406365360",
872+
"ap-south-1": "763104351884",
873+
"ap-southeast-1": "763104351884",
874+
"ap-southeast-2": "763104351884",
875+
"ca-central-1": "763104351884",
876+
"cn-north-1": "727897471807",
877+
"cn-northwest-1": "727897471807",
878+
"eu-central-1": "763104351884",
879+
"eu-north-1": "763104351884",
880+
"eu-west-1": "763104351884",
881+
"eu-west-2": "763104351884",
882+
"eu-west-3": "763104351884",
883+
"eu-south-1": "692866216735",
884+
"me-south-1": "217643126080",
885+
"sa-east-1": "763104351884",
886+
"us-east-1": "763104351884",
887+
"us-east-2": "763104351884",
888+
"us-gov-west-1": "442386744353",
889+
"us-iso-east-1": "886529160074",
890+
"us-west-1": "763104351884",
891+
"us-west-2": "763104351884"
892+
},
893+
"repository": "pytorch-training"
826894
}
827895
}
828896
}

tests/conftest.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -184,8 +184,8 @@ def mxnet_eia_latest_py_version():
184184
def pytorch_training_py_version(pytorch_training_version, request):
185185
if Version(pytorch_training_version) < Version("1.5.0"):
186186
return request.param
187-
elif Version(pytorch_training_version) == Version("1.7.1"):
188-
return "py36"
187+
elif Version(pytorch_training_version) >= Version("1.9"):
188+
return "py38"
189189
else:
190190
return "py3"
191191

@@ -194,8 +194,8 @@ def pytorch_training_py_version(pytorch_training_version, request):
194194
def pytorch_inference_py_version(pytorch_inference_version, request):
195195
if Version(pytorch_inference_version) < Version("1.4.0"):
196196
return request.param
197-
elif Version(pytorch_inference_version) == Version("1.7.1"):
198-
return "py36"
197+
elif Version(pytorch_inference_version) >= Version("1.9"):
198+
return "py38"
199199
else:
200200
return "py3"
201201

tests/unit/test_fw_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -653,6 +653,7 @@ def test_validate_smdataparallel_args_not_raises():
653653
("ml.p3.16xlarge", "pytorch", "1.8", "py3", smdataparallel_enabled),
654654
("ml.p3.16xlarge", "tensorflow", "2.4.1", "py3", smdataparallel_enabled_custom_mpi),
655655
("ml.p3.16xlarge", "pytorch", "1.8.0", "py3", smdataparallel_enabled_custom_mpi),
656+
("ml.p3.16xlarge", "pytorch", "1.9.0", "py3", smdataparallel_enabled_custom_mpi),
656657
]
657658
for instance_type, framework_name, framework_version, py_version, distribution in good_args:
658659
fw_utils._validate_smdataparallel_args(

tests/unit/test_processing.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -284,14 +284,20 @@ def test_sklearn_with_all_parameters_via_run_args_called_twice(
284284
@patch("os.path.exists", return_value=True)
285285
@patch("os.path.isfile", return_value=True)
286286
def test_pytorch_processor_with_required_parameters(
287-
exists_mock, isfile_mock, botocore_resolver, sagemaker_session, pytorch_training_version
287+
exists_mock,
288+
isfile_mock,
289+
botocore_resolver,
290+
sagemaker_session,
291+
pytorch_training_version,
292+
pytorch_training_py_version,
288293
):
289294
botocore_resolver.return_value.construct_endpoint.return_value = {"hostname": ECR_HOSTNAME}
290295

291296
processor = PyTorchProcessor(
292297
role=ROLE,
293298
instance_type="ml.m4.xlarge",
294299
framework_version=pytorch_training_version,
300+
py_version=pytorch_training_py_version,
295301
instance_count=1,
296302
sagemaker_session=sagemaker_session,
297303
)
@@ -302,12 +308,16 @@ def test_pytorch_processor_with_required_parameters(
302308

303309
if version.parse(pytorch_training_version) < version.parse("1.2"):
304310
pytorch_image_uri = (
305-
"520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-pytorch:{}-cpu-py3"
306-
).format(pytorch_training_version)
311+
"520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-pytorch:{}-cpu-{}".format(
312+
pytorch_training_version, pytorch_training_py_version
313+
)
314+
)
307315
else:
308316
pytorch_image_uri = (
309-
"763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:{}-cpu-py3"
310-
).format(pytorch_training_version)
317+
"763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:{}-cpu-{}".format(
318+
pytorch_training_version, pytorch_training_py_version
319+
)
320+
)
311321

312322
expected_args["app_specification"]["ImageUri"] = pytorch_image_uri
313323

0 commit comments

Comments
 (0)