@@ -213,10 +213,15 @@ def __init__(
213
213
214
214
self ._validate_args (image_uri = image_uri )
215
215
216
+ if "enable_sagemaker_metrics" not in kwargs :
217
+ kwargs ["enable_sagemaker_metrics" ] = True
218
+
219
+ kwargs ["py_version" ] = self .py_version
220
+
221
+ super (HuggingFace , self ).__init__ (entry_point , source_dir , hyperparameters , image_uri = image_uri , ** kwargs )
222
+
216
223
self .base_framework_name = "tensorflow" if tensorflow_version is not None else "pytorch"
217
- self .base_framework_version = (
218
- tensorflow_version if tensorflow_version is not None else pytorch_version
219
- )
224
+ self .base_framework_version = tensorflow_version if tensorflow_version is not None else pytorch_version
220
225
221
226
if distribution is not None :
222
227
distribution = validate_distribution (
@@ -231,15 +236,6 @@ def __init__(
231
236
232
237
self .distribution = distribution or {}
233
238
234
- if "enable_sagemaker_metrics" not in kwargs :
235
- kwargs ["enable_sagemaker_metrics" ] = True
236
-
237
- kwargs ["py_version" ] = self .py_version
238
-
239
- super (HuggingFace , self ).__init__ (
240
- entry_point , source_dir , hyperparameters , image_uri = image_uri , ** kwargs
241
- )
242
-
243
239
if compiler_config is not None :
244
240
if not isinstance (compiler_config , TrainingCompilerConfig ):
245
241
error_string = (
@@ -324,18 +320,12 @@ def _huggingface_distribution_configuration(self, distribution):
324
320
def hyperparameters (self ):
325
321
"""Return hyperparameters used by your custom PyTorch code during model training."""
326
322
hyperparameters = super (HuggingFace , self ).hyperparameters ()
327
- additional_hyperparameters = self ._huggingface_distribution_configuration (
328
- distribution = self .distribution
329
- )
330
- hyperparameters .update (
331
- EstimatorBase ._json_encode_hyperparameters (additional_hyperparameters )
332
- )
323
+ additional_hyperparameters = self ._huggingface_distribution_configuration (distribution = self .distribution )
324
+ hyperparameters .update (EstimatorBase ._json_encode_hyperparameters (additional_hyperparameters ))
333
325
334
326
if self .compiler_config :
335
327
training_compiler_hyperparameters = self .compiler_config ._to_hyperparameter_dict ()
336
- hyperparameters .update (
337
- EstimatorBase ._json_encode_hyperparameters (training_compiler_hyperparameters )
338
- )
328
+ hyperparameters .update (EstimatorBase ._json_encode_hyperparameters (training_compiler_hyperparameters ))
339
329
340
330
return hyperparameters
341
331
@@ -445,9 +435,7 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
445
435
446
436
if framework != cls ._framework_name :
447
437
raise ValueError (
448
- "Training job: {} didn't use image for requested framework" .format (
449
- job_details ["TrainingJobName" ]
450
- )
438
+ "Training job: {} didn't use image for requested framework" .format (job_details ["TrainingJobName" ])
451
439
)
452
440
453
441
return init_params
0 commit comments