Skip to content

Commit a0bb500

Browse files
author
Andrew Tian
committed
fixing formatting
1 parent 1f0ba5c commit a0bb500

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

src/sagemaker/fw_utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,14 @@
161161
]
162162

163163

164-
TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS = ["1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.2", "2.2.0"]
164+
TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS = [
165+
"1.13.1",
166+
"2.0.0",
167+
"2.0.1",
168+
"2.1.0",
169+
"2.1.2",
170+
"2.2.0",
171+
]
165172

166173
TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed"]
167174
TRAINIUM_SUPPORTED_TORCH_DISTRIBUTED_FRAMEWORK_VERSIONS = [

src/sagemaker/image_uris.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -678,9 +678,11 @@ def get_training_image_uri(
678678
if "modelparallel" in distribution["smdistributed"]:
679679
if distribution["smdistributed"]["modelparallel"].get("enabled", True):
680680
framework = "pytorch-smp"
681-
if "p5" in instance_type or \
682-
"2.1" in framework_version or \
683-
"2.2" in framework_version:
681+
if (
682+
"p5" in instance_type
683+
or "2.1" in framework_version
684+
or "2.2" in framework_version
685+
):
684686
container_version = "cu121"
685687
else:
686688
container_version = "cu118"

0 commit comments

Comments
 (0)