Skip to content

Commit 2b633fb

Browse files
author
Sirut Buasai
committed
edit unit tests for pt2.0.1 and pt2.1
1 parent 83bf6f0 commit 2b633fb

File tree

3 files changed

+14
-1
lines changed

3 files changed

+14
-1
lines changed

src/sagemaker/fw_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@
138138
"1.12.1",
139139
"1.13.1",
140140
"2.0.0",
141+
"2.0.1"
141142
],
142143
}
143144

@@ -153,10 +154,11 @@
153154
"1.13.1",
154155
"2.0.0",
155156
"2.0.1",
157+
"2.1.0"
156158
]
157159

158160

159-
TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS = ["1.13.1", "2.0.0", "2.0.1"]
161+
TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS = ["1.13.1", "2.0.0", "2.0.1", "2.1.0"]
160162

161163
TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed"]
162164
TRAINIUM_SUPPORTED_TORCH_DISTRIBUTED_FRAMEWORK_VERSIONS = [

tests/unit/test_fw_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -937,6 +937,7 @@ def test_validate_smdataparallel_args_not_raises():
937937
("ml.p3.16xlarge", "pytorch", "1.12", "py38", smdataparallel_enabled),
938938
("ml.p3.16xlarge", "pytorch", "1.13.1", "py39", smdataparallel_enabled),
939939
("ml.p3.16xlarge", "pytorch", "2.0.0", "py310", smdataparallel_enabled),
940+
("ml.p3.16xlarge", "pytorch", "2.0.1", "py310", smdataparallel_enabled),
940941
("ml.p3.16xlarge", "tensorflow", "2.4.1", "py3", smdataparallel_enabled_custom_mpi),
941942
("ml.p3.16xlarge", "tensorflow", "2.4.1", "py37", smdataparallel_enabled_custom_mpi),
942943
("ml.p3.16xlarge", "tensorflow", "2.4.3", "py3", smdataparallel_enabled_custom_mpi),
@@ -959,6 +960,7 @@ def test_validate_smdataparallel_args_not_raises():
959960
("ml.p3.16xlarge", "pytorch", "1.12.1", "py38", smdataparallel_enabled_custom_mpi),
960961
("ml.p3.16xlarge", "pytorch", "1.13.1", "py39", smdataparallel_enabled_custom_mpi),
961962
("ml.p3.16xlarge", "pytorch", "2.0.0", "py310", smdataparallel_enabled_custom_mpi),
963+
("ml.p3.16xlarge", "pytorch", "2.0.1", "py310", smdataparallel_enabled_custom_mpi),
962964
]
963965
for instance_type, framework_name, framework_version, py_version, distribution in good_args:
964966
fw_utils._validate_smdataparallel_args(
@@ -995,6 +997,10 @@ def test_validate_pytorchddp_not_raises():
995997
"1.12",
996998
"1.12.0",
997999
"1.12.1",
1000+
"1.13.1",
1001+
"2.0.0",
1002+
"2.0.1",
1003+
"2.1.0"
9981004
]
9991005
for framework_version in pytorchddp_supported_fw_versions:
10001006
fw_utils.validate_pytorch_distribution(
@@ -1060,6 +1066,8 @@ def test_validate_torch_distributed_not_raises():
10601066
torch_distributed_gpu_supported_fw_versions = [
10611067
"1.13.1",
10621068
"2.0.0",
1069+
"2.0.1",
1070+
"2.1.0"
10631071
]
10641072
for framework_version in torch_distributed_gpu_supported_fw_versions:
10651073
fw_utils.validate_torch_distributed_distribution(

tests/unit/test_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,9 @@ def test_set_nested_value():
384384

385385

386386
def test_get_short_version():
387+
assert sagemaker.utils.get_short_version("2.1.0") == "2.1"
388+
assert sagemaker.utils.get_short_version("2.1") == "2.1"
389+
assert sagemaker.utils.get_short_version("2.0.1") == "2.0"
387390
assert sagemaker.utils.get_short_version("2.0.0") == "2.0"
388391
assert sagemaker.utils.get_short_version("2.0") == "2.0"
389392

0 commit comments

Comments
 (0)