Skip to content

Commit c77f874

Browse files
committed
linting
1 parent a91b81d commit c77f874

File tree

5 files changed

+24
-12
lines changed

5 files changed

+24
-12
lines changed

src/sagemaker/estimator.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2751,16 +2751,26 @@ def _validate_and_set_debugger_configs(self):
27512751
def _validate_mwms_config(self):
27522752
"""Validate Multi Worker Mirrored Strategy configuration."""
27532753
minimum_supported_framework_version = {
2754-
'tensorflow': {'framework_version': '2.9'},
2755-
}
2754+
"tensorflow": {"framework_version": "2.9"},
2755+
}
27562756
if self._framework_name in minimum_supported_framework_version:
27572757
for version_argument in minimum_supported_framework_version[self._framework_name]:
27582758
current = getattr(self, version_argument)
2759-
threshold = minimum_supported_framework_version[self._framework_name][version_argument]
2759+
threshold = minimum_supported_framework_version[self._framework_name][
2760+
version_argument
2761+
]
27602762
if Version(current) in SpecifierSet(f"< {threshold}"):
2761-
raise ValueError("Multi Worker Mirrored Strategy is only supported from {} {} but received {}".format(version_argument, threshold, current))
2763+
raise ValueError(
2764+
"Multi Worker Mirrored Strategy is only supported from {} {} but received {}".format(
2765+
version_argument, threshold, current
2766+
)
2767+
)
27622768
else:
2763-
raise ValueError("Multi Worker Mirrored Strategy is currently only supported with {} frameworks but received {}".format(minimum_supported_framework_version.keys(), self._framework_name))
2769+
raise ValueError(
2770+
"Multi Worker Mirrored Strategy is currently only supported with {} frameworks but received {}".format(
2771+
minimum_supported_framework_version.keys(), self._framework_name
2772+
)
2773+
)
27642774

27652775
def _model_source_dir(self):
27662776
"""Get the appropriate value to pass as ``source_dir`` to a model constructor.

src/sagemaker/tensorflow/training_compiler/config.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,9 @@ def validate(
111111
raise ValueError(error_helper_string)
112112

113113
if estimator.distribution and "multi_worker_mirrored_strategy" in estimator.distribution:
114-
mwms_enabled = estimator.distribution.get("multi_worker_mirrored_strategy").get("enabled", False)
114+
mwms_enabled = estimator.distribution.get("multi_worker_mirrored_strategy").get(
115+
"enabled", False
116+
)
115117
if mwms_enabled:
116118
raise ValueError(
117119
"Multi Worker Mirrored Strategy distributed training configuration "

tests/unit/sagemaker/tensorflow/test_estimator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -571,8 +571,8 @@ def test_fit_mwms_unsupported(time, strftime, sagemaker_session):
571571
inputs = "s3://mybucket/train"
572572
tf.fit(inputs=inputs)
573573

574-
assert 'only supported from' in str(error)
575-
assert 'but received' in str(error)
574+
assert "only supported from" in str(error)
575+
assert "but received" in str(error)
576576

577577

578578
def test_hyperparameters_no_model_dir(

tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def test_mwms(
226226
framework_version=tensorflow_training_version,
227227
enable_sagemaker_metrics=False,
228228
compiler_config=TrainingCompilerConfig(),
229-
distribution={'multi_worker_mirrored_strategy': True},
229+
distribution={"multi_worker_mirrored_strategy": True},
230230
).fit()
231231

232232
def test_python_2(

tests/unit/test_estimator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3521,9 +3521,9 @@ def test_mwms_distribution_configuration(sagemaker_session):
35213521
)
35223522
with pytest.raises(ValueError) as error:
35233523
framework._distribution_configuration(distribution=DISTRIBUTION_MWMS_ENABLED)
3524-
3525-
assert 'only supported with' in str(error)
3526-
assert 'but received' in str(error)
3524+
3525+
assert "only supported with" in str(error)
3526+
assert "but received" in str(error)
35273527

35283528

35293529
def test_image_name_map(sagemaker_session):

0 commit comments

Comments
 (0)