|
170 | 170 | }
|
171 | 171 | DISTRIBUTION_SM_DDP_ENABLED = {
|
172 | 172 | "smdistributed": {"dataparallel": {"enabled": True, "custom_mpi_options": "options"}},
|
173 |
| - "torch_distributed": {"enabled": False} |
| 173 | + "torch_distributed": {"enabled": False}, |
174 | 174 | }
|
175 | 175 | DISTRIBUTION_SM_DDP_DISABLED = {
|
176 | 176 | "smdistributed": {"enabled": True},
|
177 |
| - "torch_distributed": {"enabled": False} |
| 177 | + "torch_distributed": {"enabled": False}, |
178 | 178 | }
|
179 | 179 | DISTRIBUTION_SM_TORCH_DIST_AND_DDP_ENABLED = {
|
180 | 180 | "smdistributed": {"dataparallel": {"enabled": True, "custom_mpi_options": "options"}},
|
181 |
| - "torch_distributed": {"enabled": True} |
| 181 | + "torch_distributed": {"enabled": True}, |
182 | 182 | }
|
183 | 183 | DISTRIBUTION_SM_TORCH_DIST_AND_DDP_DISABLED = {
|
184 | 184 | "smdistributed": {"enabled": True},
|
185 |
| - "torch_distributed": {"enabled": True} |
| 185 | + "torch_distributed": {"enabled": True}, |
186 | 186 | }
|
187 | 187 | MOCKED_S3_URI = "s3://mocked_s3_uri_from_source_dir"
|
188 | 188 | _DEFINITION_CONFIG = PipelineDefinitionConfig(use_custom_job_prefix=False)
|
@@ -360,18 +360,18 @@ def test_validate_smdistributed_unsupported_image_raises(sagemaker_session):
|
360 | 360 | def test_validate_smdistributed_p5_raises(sagemaker_session):
|
361 | 361 | # Supported DLC image.
|
362 | 362 | f = DummyFramework(
|
363 |
| - "some_script.py", |
| 363 | + "some_script.py", |
364 | 364 | role="DummyRole",
|
365 |
| - instance_type="ml.p5.48xlarge", |
| 365 | + instance_type="ml.p5.48xlarge", |
366 | 366 | sagemaker_session=sagemaker_session,
|
367 | 367 | output_path="outputpath",
|
368 | 368 | image_uri="some_acceptable_image",
|
369 | 369 | )
|
370 | 370 | # Both fail because instance type is p5 and torch_distributed is off.
|
371 | 371 | with pytest.raises(ValueError):
|
372 |
| - f._distribution_configuration(DISTRIBUTION_SM_DDP_ENABLED) |
| 372 | + f._distribution_configuration(DISTRIBUTION_SM_DDP_ENABLED) |
373 | 373 | with pytest.raises(ValueError):
|
374 |
| - f._distribution_configuration(DISTRIBUTION_SM_DDP_DISABLED) |
| 374 | + f._distribution_configuration(DISTRIBUTION_SM_DDP_DISABLED) |
375 | 375 |
|
376 | 376 |
|
377 | 377 | def test_validate_smdistributed_p5_not_raises(sagemaker_session):
|
|
0 commit comments