Skip to content

Commit eaeda5d

Browse files
authored
Merge branch 'master' into env_support_training
2 parents 0d674f0 + b13baee commit eaeda5d

File tree

3 files changed

+9
-4
lines changed

3 files changed

+9
-4
lines changed

doc/frameworks/pytorch/using_pytorch.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ directories ('train' and 'test').
154154
pytorch_estimator = PyTorch('pytorch-train.py',
155155
instance_type='ml.p3.2xlarge',
156156
instance_count=1,
157-
framework_version='1.5.0',
157+
framework_version='1.8.0',
158158
py_version='py3',
159159
hyperparameters = {'epochs': 20, 'batch-size': 64, 'learning-rate': 0.1})
160160
pytorch_estimator.fit({'train': 's3://my-data-bucket/path/to/my/training/data',
@@ -248,7 +248,7 @@ operation.
248248
pytorch_estimator = PyTorch(entry_point='train_and_deploy.py',
249249
instance_type='ml.p3.2xlarge',
250250
instance_count=1,
251-
framework_version='1.5.0',
251+
framework_version='1.8.0',
252252
py_version='py3')
253253
pytorch_estimator.fit('s3://my_bucket/my_training_data/')
254254

src/sagemaker/fw_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@
5959
"local_gpu",
6060
)
6161
SM_DATAPARALLEL_SUPPORTED_FRAMEWORK_VERSIONS = {
62-
"tensorflow": ["2.3.1", "2.3.2", "2.4.1"],
63-
"pytorch": ["1.6.0", "1.7.1", "1.8.0"],
62+
"tensorflow": ["2.3", "2.3.1", "2.3.2", "2.4", "2.4.1"],
63+
"pytorch": ["1.6", "1.6.0", "1.7", "1.7.1", "1.8", "1.8.0"],
6464
}
6565
SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel", "modelparallel"]
6666

tests/unit/test_fw_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -632,10 +632,15 @@ def test_validate_smdataparallel_args_not_raises():
632632
(None, None, None, None, smdataparallel_disabled),
633633
("ml.p3.16xlarge", "tensorflow", "2.3.1", "py3", smdataparallel_enabled),
634634
("ml.p3.16xlarge", "tensorflow", "2.3.2", "py3", smdataparallel_enabled),
635+
("ml.p3.16xlarge", "tensorflow", "2.3", "py3", smdataparallel_enabled),
635636
("ml.p3.16xlarge", "tensorflow", "2.4.1", "py3", smdataparallel_enabled),
637+
("ml.p3.16xlarge", "tensorflow", "2.4", "py3", smdataparallel_enabled),
636638
("ml.p3.16xlarge", "pytorch", "1.6.0", "py3", smdataparallel_enabled),
639+
("ml.p3.16xlarge", "pytorch", "1.6", "py3", smdataparallel_enabled),
637640
("ml.p3.16xlarge", "pytorch", "1.7.1", "py3", smdataparallel_enabled),
641+
("ml.p3.16xlarge", "pytorch", "1.7", "py3", smdataparallel_enabled),
638642
("ml.p3.16xlarge", "pytorch", "1.8.0", "py3", smdataparallel_enabled),
643+
("ml.p3.16xlarge", "pytorch", "1.8", "py3", smdataparallel_enabled),
639644
]
640645
for instance_type, framework_name, framework_version, py_version, distribution in good_args:
641646
fw_utils._validate_smdataparallel_args(

0 commit comments

Comments
 (0)