Skip to content

Commit c13fc6d

Browse files
feature: support PyTorch 1.7.1 training, inference and data parallel (#2185)
Co-authored-by: ChoiByungWook <[email protected]>
1 parent 9f343c1 commit c13fc6d

File tree

6 files changed

+82
-8
lines changed

6 files changed

+82
-8
lines changed

src/sagemaker/fw_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
)
6161
SM_DATAPARALLEL_SUPPORTED_FRAMEWORK_VERSIONS = {
6262
"tensorflow": ["2.3.0", "2.3.1"],
63-
"pytorch": ["1.6.0"],
63+
"pytorch": ["1.6.0", "1.7.1"],
6464
}
6565
SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel", "modelparallel"]
6666

src/sagemaker/image_uri_config/pytorch.json

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@
5454
"1.3": "1.3.1",
5555
"1.4": "1.4.0",
5656
"1.5": "1.5.0",
57-
"1.6": "1.6.0"
57+
"1.6": "1.6.0",
58+
"1.7": "1.7.1"
5859
},
5960
"versions": {
6061
"0.4.0": {
@@ -318,6 +319,39 @@
318319
"us-west-2": "763104351884"
319320
},
320321
"repository": "pytorch-inference"
322+
},
323+
"1.7.1": {
324+
"py_versions": [
325+
"py3",
326+
"py36"
327+
],
328+
"registries": {
329+
"af-south-1": "626614931356",
330+
"ap-east-1": "871362719292",
331+
"ap-northeast-1": "763104351884",
332+
"ap-northeast-2": "763104351884",
333+
"ap-south-1": "763104351884",
334+
"ap-southeast-1": "763104351884",
335+
"ap-southeast-2": "763104351884",
336+
"ca-central-1": "763104351884",
337+
"cn-north-1": "727897471807",
338+
"cn-northwest-1": "727897471807",
339+
"eu-central-1": "763104351884",
340+
"eu-north-1": "763104351884",
341+
"eu-west-1": "763104351884",
342+
"eu-west-2": "763104351884",
343+
"eu-west-3": "763104351884",
344+
"eu-south-1": "692866216735",
345+
"me-south-1": "217643126080",
346+
"sa-east-1": "763104351884",
347+
"us-east-1": "763104351884",
348+
"us-east-2": "763104351884",
349+
"us-gov-west-1": "442386744353",
350+
"us-iso-east-1": "886529160074",
351+
"us-west-1": "763104351884",
352+
"us-west-2": "763104351884"
353+
},
354+
"repository": "pytorch-inference"
321355
}
322356
}
323357
},
@@ -334,7 +368,8 @@
334368
"1.3": "1.3.1",
335369
"1.4": "1.4.0",
336370
"1.5": "1.5.0",
337-
"1.6": "1.6.0"
371+
"1.6": "1.6.0",
372+
"1.7": "1.7.1"
338373
},
339374
"versions": {
340375
"0.4.0": {
@@ -599,6 +634,39 @@
599634
"us-west-2": "763104351884"
600635
},
601636
"repository": "pytorch-training"
637+
},
638+
"1.7.1": {
639+
"py_versions": [
640+
"py3",
641+
"py36"
642+
],
643+
"registries": {
644+
"af-south-1": "626614931356",
645+
"ap-east-1": "871362719292",
646+
"ap-northeast-1": "763104351884",
647+
"ap-northeast-2": "763104351884",
648+
"ap-south-1": "763104351884",
649+
"ap-southeast-1": "763104351884",
650+
"ap-southeast-2": "763104351884",
651+
"ca-central-1": "763104351884",
652+
"cn-north-1": "727897471807",
653+
"cn-northwest-1": "727897471807",
654+
"eu-central-1": "763104351884",
655+
"eu-north-1": "763104351884",
656+
"eu-west-1": "763104351884",
657+
"eu-west-2": "763104351884",
658+
"eu-west-3": "763104351884",
659+
"eu-south-1": "692866216735",
660+
"me-south-1": "217643126080",
661+
"sa-east-1": "763104351884",
662+
"us-east-1": "763104351884",
663+
"us-east-2": "763104351884",
664+
"us-gov-west-1": "442386744353",
665+
"us-iso-east-1": "886529160074",
666+
"us-west-1": "763104351884",
667+
"us-west-2": "763104351884"
668+
},
669+
"repository": "pytorch-training"
602670
}
603671
}
604672
}

tests/conftest.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,8 @@ def mxnet_eia_latest_py_version():
173173
def pytorch_training_py_version(pytorch_training_version, request):
174174
if Version(pytorch_training_version) < Version("1.5.0"):
175175
return request.param
176+
elif Version(pytorch_training_version) == Version("1.7.1"):
177+
return "py36"
176178
else:
177179
return "py3"
178180

@@ -181,6 +183,8 @@ def pytorch_training_py_version(pytorch_training_version, request):
181183
def pytorch_inference_py_version(pytorch_inference_version, request):
182184
if Version(pytorch_inference_version) < Version("1.4.0"):
183185
return request.param
186+
elif Version(pytorch_inference_version) == Version("1.7.1"):
187+
return "py36"
184188
else:
185189
return "py3"
186190

tests/data/smdistributed_dataparallel/mnist_pt.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from __future__ import print_function
1414

1515
import argparse
16+
import os
1617
import time
1718
import torch
1819
import torch.nn as nn
@@ -150,8 +151,8 @@ def main():
150151
parser.add_argument(
151152
"--data-path",
152153
type=str,
153-
default="/tmp/data",
154-
help="Path for downloading " "the MNIST dataset",
154+
default=os.environ["SM_CHANNEL_TRAINING"],
155+
help="Path for downloading the MNIST dataset",
155156
)
156157

157158
args = parser.parse_args()
@@ -186,7 +187,7 @@ def main():
186187
train_dataset = datasets.MNIST(
187188
data_path,
188189
train=True,
189-
download=True,
190+
download=False, # True sets a dependency on an external site for our tests.
190191
transform=transforms.Compose(
191192
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
192193
),

tests/integ/test_smdataparallel_pt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from sagemaker.pytorch import PyTorch
2323
from tests.integ import timeout
24-
24+
from tests.integ.test_pytorch import _upload_training_data
2525

2626
smdataparallel_dir = os.path.join(
2727
os.path.dirname(__file__), "..", "data", "smdistributed_dataparallel"
@@ -51,4 +51,4 @@ def test_smdataparallel_pt_mnist(
5151
)
5252

5353
with timeout.timeout(minutes=integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
54-
estimator.fit(job_name=job_name)
54+
estimator.fit({"training": _upload_training_data(estimator)}, job_name=job_name)

tests/unit/test_fw_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -632,6 +632,7 @@ def test_validate_smdataparallel_args_not_raises():
632632
(None, None, None, None, smdataparallel_disabled),
633633
("ml.p3.16xlarge", "tensorflow", "2.3.1", "py3", smdataparallel_enabled),
634634
("ml.p3.16xlarge", "pytorch", "1.6.0", "py3", smdataparallel_enabled),
635+
("ml.p3.16xlarge", "pytorch", "1.7.1", "py3", smdataparallel_enabled),
635636
]
636637
for instance_type, framework_name, framework_version, py_version, distribution in good_args:
637638
fw_utils._validate_smdataparallel_args(

0 commit comments

Comments
 (0)