Skip to content

Commit b612952

Browse files
committed
make tox happy
1 parent 74f9123 commit b612952

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

src/sagemaker/huggingface/estimator.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -218,13 +218,17 @@ def __init__(
218218

219219
kwargs["py_version"] = self.py_version
220220

221-
super(HuggingFace, self).__init__(entry_point, source_dir, hyperparameters, image_uri=image_uri, **kwargs)
221+
super(HuggingFace, self).__init__(
222+
entry_point, source_dir, hyperparameters, image_uri=image_uri, **kwargs
223+
)
222224

223225
if "entry_point" not in kwargs:
224226
kwargs["entry_point"] = entry_point
225227

226228
self.base_framework_name = "tensorflow" if tensorflow_version is not None else "pytorch"
227-
self.base_framework_version = tensorflow_version if tensorflow_version is not None else pytorch_version
229+
self.base_framework_version = (
230+
tensorflow_version if tensorflow_version is not None else pytorch_version
231+
)
228232

229233
if distribution is not None:
230234
distribution = validate_distribution(
@@ -323,12 +327,18 @@ def _huggingface_distribution_configuration(self, distribution):
323327
def hyperparameters(self):
324328
"""Return hyperparameters used by your custom PyTorch code during model training."""
325329
hyperparameters = super(HuggingFace, self).hyperparameters()
326-
additional_hyperparameters = self._huggingface_distribution_configuration(distribution=self.distribution)
327-
hyperparameters.update(EstimatorBase._json_encode_hyperparameters(additional_hyperparameters))
330+
additional_hyperparameters = self._huggingface_distribution_configuration(
331+
distribution=self.distribution
332+
)
333+
hyperparameters.update(
334+
EstimatorBase._json_encode_hyperparameters(additional_hyperparameters)
335+
)
328336

329337
if self.compiler_config:
330338
training_compiler_hyperparameters = self.compiler_config._to_hyperparameter_dict()
331-
hyperparameters.update(EstimatorBase._json_encode_hyperparameters(training_compiler_hyperparameters))
339+
hyperparameters.update(
340+
EstimatorBase._json_encode_hyperparameters(training_compiler_hyperparameters)
341+
)
332342

333343
return hyperparameters
334344

@@ -438,7 +448,9 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
438448

439449
if framework != cls._framework_name:
440450
raise ValueError(
441-
"Training job: {} didn't use image for requested framework".format(job_details["TrainingJobName"])
451+
"Training job: {} didn't use image for requested framework".format(
452+
job_details["TrainingJobName"]
453+
)
442454
)
443455

444456
return init_params

0 commit comments

Comments
 (0)