Skip to content

Commit b4978cd

Browse files
vishwakarianavinsoni
authored andcommitted
fix unit tests
1 parent f2d7d07 commit b4978cd

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

src/sagemaker/fw_utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -804,14 +804,16 @@ def validate_pytorch_distribution(
804804
`py_version` is not python3 or
805805
`framework_version` is not in PYTORCHDDP_SUPPORTED_FRAMEWORK_VERSIONS
806806
"""
807-
if framework_name != "pytorch":
807+
if framework_name and framework_name != "pytorch":
808808
# We need to validate only for PyTorch framework
809809
return
810+
811+
pytorch_ddp_enabled = False
810812
if "pytorchddp" in distribution:
811813
pytorch_ddp_enabled = distribution.get("pytorchddp").get("enabled", False)
812-
if not pytorch_ddp_enabled:
813-
# Distribution strategy other than pytorchddp is selected
814-
return
814+
if not pytorch_ddp_enabled:
815+
# Distribution strategy other than pytorchddp is selected
816+
return
815817

816818
err_msg = ""
817819
if not image_uri:

0 commit comments

Comments
 (0)