Skip to content

Commit 3140226

Browse files
committed
change order
1 parent 85a2212 commit 3140226

File tree

1 file changed

+12
-24
lines changed

1 file changed

+12
-24
lines changed

src/sagemaker/huggingface/estimator.py

Lines changed: 12 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -213,10 +213,15 @@ def __init__(
213213

214214
self._validate_args(image_uri=image_uri)
215215

216+
if "enable_sagemaker_metrics" not in kwargs:
217+
kwargs["enable_sagemaker_metrics"] = True
218+
219+
kwargs["py_version"] = self.py_version
220+
221+
super(HuggingFace, self).__init__(entry_point, source_dir, hyperparameters, image_uri=image_uri, **kwargs)
222+
216223
self.base_framework_name = "tensorflow" if tensorflow_version is not None else "pytorch"
217-
self.base_framework_version = (
218-
tensorflow_version if tensorflow_version is not None else pytorch_version
219-
)
224+
self.base_framework_version = tensorflow_version if tensorflow_version is not None else pytorch_version
220225

221226
if distribution is not None:
222227
distribution = validate_distribution(
@@ -231,15 +236,6 @@ def __init__(
231236

232237
self.distribution = distribution or {}
233238

234-
if "enable_sagemaker_metrics" not in kwargs:
235-
kwargs["enable_sagemaker_metrics"] = True
236-
237-
kwargs["py_version"] = self.py_version
238-
239-
super(HuggingFace, self).__init__(
240-
entry_point, source_dir, hyperparameters, image_uri=image_uri, **kwargs
241-
)
242-
243239
if compiler_config is not None:
244240
if not isinstance(compiler_config, TrainingCompilerConfig):
245241
error_string = (
@@ -324,18 +320,12 @@ def _huggingface_distribution_configuration(self, distribution):
324320
def hyperparameters(self):
325321
"""Return hyperparameters used by your custom PyTorch code during model training."""
326322
hyperparameters = super(HuggingFace, self).hyperparameters()
327-
additional_hyperparameters = self._huggingface_distribution_configuration(
328-
distribution=self.distribution
329-
)
330-
hyperparameters.update(
331-
EstimatorBase._json_encode_hyperparameters(additional_hyperparameters)
332-
)
323+
additional_hyperparameters = self._huggingface_distribution_configuration(distribution=self.distribution)
324+
hyperparameters.update(EstimatorBase._json_encode_hyperparameters(additional_hyperparameters))
333325

334326
if self.compiler_config:
335327
training_compiler_hyperparameters = self.compiler_config._to_hyperparameter_dict()
336-
hyperparameters.update(
337-
EstimatorBase._json_encode_hyperparameters(training_compiler_hyperparameters)
338-
)
328+
hyperparameters.update(EstimatorBase._json_encode_hyperparameters(training_compiler_hyperparameters))
339329

340330
return hyperparameters
341331

@@ -445,9 +435,7 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
445435

446436
if framework != cls._framework_name:
447437
raise ValueError(
448-
"Training job: {} didn't use image for requested framework".format(
449-
job_details["TrainingJobName"]
450-
)
438+
"Training job: {} didn't use image for requested framework".format(job_details["TrainingJobName"])
451439
)
452440

453441
return init_params

0 commit comments

Comments
 (0)