@@ -218,13 +218,17 @@ def __init__(
218
218
219
219
kwargs ["py_version" ] = self .py_version
220
220
221
- super (HuggingFace , self ).__init__ (entry_point , source_dir , hyperparameters , image_uri = image_uri , ** kwargs )
221
+ super (HuggingFace , self ).__init__ (
222
+ entry_point , source_dir , hyperparameters , image_uri = image_uri , ** kwargs
223
+ )
222
224
223
225
if "entry_point" not in kwargs :
224
226
kwargs ["entry_point" ] = entry_point
225
227
226
228
self .base_framework_name = "tensorflow" if tensorflow_version is not None else "pytorch"
227
- self .base_framework_version = tensorflow_version if tensorflow_version is not None else pytorch_version
229
+ self .base_framework_version = (
230
+ tensorflow_version if tensorflow_version is not None else pytorch_version
231
+ )
228
232
229
233
if distribution is not None :
230
234
distribution = validate_distribution (
@@ -323,12 +327,18 @@ def _huggingface_distribution_configuration(self, distribution):
323
327
def hyperparameters (self ):
324
328
"""Return hyperparameters used by your custom PyTorch code during model training."""
325
329
hyperparameters = super (HuggingFace , self ).hyperparameters ()
326
- additional_hyperparameters = self ._huggingface_distribution_configuration (distribution = self .distribution )
327
- hyperparameters .update (EstimatorBase ._json_encode_hyperparameters (additional_hyperparameters ))
330
+ additional_hyperparameters = self ._huggingface_distribution_configuration (
331
+ distribution = self .distribution
332
+ )
333
+ hyperparameters .update (
334
+ EstimatorBase ._json_encode_hyperparameters (additional_hyperparameters )
335
+ )
328
336
329
337
if self .compiler_config :
330
338
training_compiler_hyperparameters = self .compiler_config ._to_hyperparameter_dict ()
331
- hyperparameters .update (EstimatorBase ._json_encode_hyperparameters (training_compiler_hyperparameters ))
339
+ hyperparameters .update (
340
+ EstimatorBase ._json_encode_hyperparameters (training_compiler_hyperparameters )
341
+ )
332
342
333
343
return hyperparameters
334
344
@@ -438,7 +448,9 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
438
448
439
449
if framework != cls ._framework_name :
440
450
raise ValueError (
441
- "Training job: {} didn't use image for requested framework" .format (job_details ["TrainingJobName" ])
451
+ "Training job: {} didn't use image for requested framework" .format (
452
+ job_details ["TrainingJobName" ]
453
+ )
442
454
)
443
455
444
456
return init_params
0 commit comments