Skip to content

Commit 4dd0abb

Browse files
viclzhu authored and bhupendrasingh committed
fix: fix formatting
1 parent 20cd14b commit 4dd0abb

File tree

2 files changed: +22 −13 lines changed (22 additions and 13 deletions in total)

src/sagemaker/estimator.py

Lines changed: 14 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -3852,13 +3852,22 @@ def _distribution_configuration(self, distribution):
38523852
torch_distributed_enabled = distribution.get("torch_distributed").get("enabled", False)
38533853
smdistributed = distribution["smdistributed"]
38543854
smdataparallel_enabled = smdistributed.get("dataparallel", {}).get("enabled", False)
3855-
p5_enabled = bool("p5.48xlarge" in self.instance_type)
3855+
p5_enabled = "p5.48xlarge" in self.instance_type
38563856
img_uri = "" if self.image_uri is None else self.image_uri
38573857
for unsupported_image in Framework.UNSUPPORTED_DLC_IMAGE_FOR_SM_PARALLELISM:
3858-
if unsupported_image in img_uri and not torch_distributed_enabled: #disabling DLC images with CUDA12
3859-
raise ValueError(f"SMDistributed is currently incompatible with DLC image: {img_uri}. (Could be due to CUDA version being greater than 11.)")
3860-
if not torch_distributed_enabled and p5_enabled: #disabling p5 when torch distributed is disabled
3861-
raise ValueError("SMModelParallel and SMDataParallel currently do not support p5 instances.")
3858+
if (
3859+
unsupported_image in img_uri and not torch_distributed_enabled
3860+
): # disabling DLC images with CUDA12
3861+
raise ValueError(
3862+
f"SMDistributed is currently incompatible with DLC image: {img_uri}. "
3863+
"(Could be due to CUDA version being greater than 11.)"
3864+
)
3865+
if (
3866+
not torch_distributed_enabled and p5_enabled
3867+
): # disabling p5 when torch distributed is disabled
3868+
raise ValueError(
3869+
"SMModelParallel and SMDataParallel currently do not support p5 instances."
3870+
)
38623871
# smdistributed strategy selected with supported instance type
38633872
distribution_config[self.LAUNCH_SM_DDP_ENV_NAME] = smdataparallel_enabled
38643873
distribution_config[self.INSTANCE_TYPE] = self.instance_type

tests/unit/test_estimator.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -170,19 +170,19 @@
170170
}
171171
DISTRIBUTION_SM_DDP_ENABLED = {
172172
"smdistributed": {"dataparallel": {"enabled": True, "custom_mpi_options": "options"}},
173-
"torch_distributed": {"enabled": False}
173+
"torch_distributed": {"enabled": False},
174174
}
175175
DISTRIBUTION_SM_DDP_DISABLED = {
176176
"smdistributed": {"enabled": True},
177-
"torch_distributed": {"enabled": False}
177+
"torch_distributed": {"enabled": False},
178178
}
179179
DISTRIBUTION_SM_TORCH_DIST_AND_DDP_ENABLED = {
180180
"smdistributed": {"dataparallel": {"enabled": True, "custom_mpi_options": "options"}},
181-
"torch_distributed": {"enabled": True}
181+
"torch_distributed": {"enabled": True},
182182
}
183183
DISTRIBUTION_SM_TORCH_DIST_AND_DDP_DISABLED = {
184184
"smdistributed": {"enabled": True},
185-
"torch_distributed": {"enabled": True}
185+
"torch_distributed": {"enabled": True},
186186
}
187187
MOCKED_S3_URI = "s3://mocked_s3_uri_from_source_dir"
188188
_DEFINITION_CONFIG = PipelineDefinitionConfig(use_custom_job_prefix=False)
@@ -360,18 +360,18 @@ def test_validate_smdistributed_unsupported_image_raises(sagemaker_session):
360360
def test_validate_smdistributed_p5_raises(sagemaker_session):
361361
# Supported DLC image.
362362
f = DummyFramework(
363-
"some_script.py",
363+
"some_script.py",
364364
role="DummyRole",
365-
instance_type="ml.p5.48xlarge",
365+
instance_type="ml.p5.48xlarge",
366366
sagemaker_session=sagemaker_session,
367367
output_path="outputpath",
368368
image_uri="some_acceptable_image",
369369
)
370370
# Both fail because instance type is p5 and torch_distributed is off.
371371
with pytest.raises(ValueError):
372-
f._distribution_configuration(DISTRIBUTION_SM_DDP_ENABLED)
372+
f._distribution_configuration(DISTRIBUTION_SM_DDP_ENABLED)
373373
with pytest.raises(ValueError):
374-
f._distribution_configuration(DISTRIBUTION_SM_DDP_DISABLED)
374+
f._distribution_configuration(DISTRIBUTION_SM_DDP_DISABLED)
375375

376376

377377
def test_validate_smdistributed_p5_not_raises(sagemaker_session):

0 commit comments

Comments
 (0)