@@ -169,7 +169,7 @@ def _from_training_job(cls, init_params, hyperparameters, image, sagemaker_sessi
169
169
raise NotImplementedError ()
170
170
171
171
@classmethod
172
- def attach (cls , training_job_name , sagemaker_session = None ):
172
+ def attach (cls , training_job_name , sagemaker_session = None , job_details = None ):
173
173
"""Attach to an existing training job.
174
174
175
175
Create an Estimator bound to an existing training job, each subclass is responsible to implement
@@ -185,6 +185,7 @@ def attach(cls, training_job_name, sagemaker_session=None):
185
185
sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
186
186
Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
187
187
using the default AWS configuration chain.
188
+ training_job_details (
188
189
189
190
Examples:
190
191
>>> my_estimator.fit(wait=False)
@@ -198,13 +199,10 @@ def attach(cls, training_job_name, sagemaker_session=None):
198
199
"""
199
200
sagemaker_session = sagemaker_session or Session ()
200
201
201
- if training_job_name :
202
- job_details = sagemaker_session .sagemaker_client .describe_training_job (TrainingJobName = training_job_name )
203
- init_params , hp , image = cls ._prepare_estimator_params_from_job_description (job_details )
204
- else :
205
- raise ValueError ('must specify training_job name' )
202
+ job_details = sagemaker_session .sagemaker_client .describe_training_job (TrainingJobName = training_job_name )
203
+ init_params = cls ._prepare_init_params_from_job_description (job_details )
206
204
207
- estimator = cls . _from_training_job ( init_params , hp , image , sagemaker_session )
205
+ estimator = cls ( sagemaker_session = sagemaker_session , ** init_params )
208
206
estimator .latest_training_job = _TrainingJob (sagemaker_session = sagemaker_session ,
209
207
training_job_name = init_params ['base_job_name' ])
210
208
estimator .latest_training_job .wait ()
@@ -257,21 +255,33 @@ def create_model(self, **kwargs):
257
255
"""
258
256
pass
259
257
260
- @staticmethod
261
- def _prepare_estimator_params_from_job_description (job_details ):
262
- estimator_params = dict ()
258
+ @classmethod
259
+ def _prepare_init_params_from_job_description (cls , job_details ):
260
+ """Convert the job description to init params that can be handled by the class constructor
261
+
262
+ Args:
263
+ job_details: the returned job details from a describe_training_job API call.
264
+
265
+ Returns:
266
+ dictionary: The transformed init_params
263
267
264
- estimator_params ['role' ] = job_details ['RoleArn' ]
265
- estimator_params ['train_instance_count' ] = job_details ['ResourceConfig' ]['InstanceCount' ]
266
- estimator_params ['train_instance_type' ] = job_details ['ResourceConfig' ]['InstanceType' ]
267
- estimator_params ['train_volume_size' ] = job_details ['ResourceConfig' ]['VolumeSizeInGB' ]
268
- estimator_params ['train_max_run' ] = job_details ['StoppingCondition' ]['MaxRuntimeInSeconds' ]
269
- estimator_params ['input_mode' ] = job_details ['AlgorithmSpecification' ]['TrainingInputMode' ]
270
- estimator_params ['base_job_name' ] = job_details ['TrainingJobName' ]
271
- estimator_params ['output_path' ] = job_details ['OutputDataConfig' ]['S3OutputPath' ]
272
- estimator_params ['output_kms_key' ] = job_details ['OutputDataConfig' ]['KmsKeyId' ]
268
+ """
269
+ init_params = dict ()
270
+
271
+ init_params ['role' ] = job_details ['RoleArn' ]
272
+ init_params ['train_instance_count' ] = job_details ['ResourceConfig' ]['InstanceCount' ]
273
+ init_params ['train_instance_type' ] = job_details ['ResourceConfig' ]['InstanceType' ]
274
+ init_params ['train_volume_size' ] = job_details ['ResourceConfig' ]['VolumeSizeInGB' ]
275
+ init_params ['train_max_run' ] = job_details ['StoppingCondition' ]['MaxRuntimeInSeconds' ]
276
+ init_params ['input_mode' ] = job_details ['AlgorithmSpecification' ]['TrainingInputMode' ]
277
+ init_params ['base_job_name' ] = job_details ['TrainingJobName' ]
278
+ init_params ['output_path' ] = job_details ['OutputDataConfig' ]['S3OutputPath' ]
279
+ init_params ['output_kms_key' ] = job_details ['OutputDataConfig' ]['KmsKeyId' ]
280
+
281
+ init_params ['hyperparameters' ] = job_details ['HyperParameters' ]
282
+ init_params ['image' ] = job_details ['AlgorithmSpecification' ]['TrainingImage' ]
273
283
274
- return estimator_params , job_details [ 'HyperParameters' ], job_details [ 'AlgorithmSpecification' ][ 'TrainingImage' ]
284
+ return init_params
275
285
276
286
def delete_endpoint (self ):
277
287
"""Delete an Amazon SageMaker ``Endpoint``.
@@ -388,7 +398,8 @@ class Estimator(EstimatorBase):
388
398
389
399
def __init__ (self , image_name , role , train_instance_count , train_instance_type ,
390
400
train_volume_size = 30 , train_max_run = 24 * 60 * 60 , input_mode = 'File' ,
391
- output_path = None , output_kms_key = None , base_job_name = None , sagemaker_session = None ):
401
+ output_path = None , output_kms_key = None , base_job_name = None , sagemaker_session = None ,
402
+ hyperparameters = None ):
392
403
"""Initialize an ``Estimator`` instance.
393
404
394
405
Args:
@@ -420,9 +431,10 @@ def __init__(self, image_name, role, train_instance_count, train_instance_type,
420
431
sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
421
432
Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
422
433
using the default AWS configuration chain.
434
+ hyperparameters (dict): Dictionary containing the hyperparameters to initialize this estimator with.
423
435
"""
424
436
self .image_name = image_name
425
- self .hyperparam_dict = {}
437
+ self .hyperparam_dict = hyperparameters . copy () if hyperparameters else {}
426
438
super (Estimator , self ).__init__ (role , train_instance_count , train_instance_type ,
427
439
train_volume_size , train_max_run , input_mode ,
428
440
output_path , output_kms_key , base_job_name , sagemaker_session )
@@ -478,23 +490,20 @@ def predict_wrapper(endpoint, session):
478
490
predictor_cls = predictor_cls , ** kwargs )
479
491
480
492
@classmethod
481
- def _from_training_job (cls , init_params , hyperparameters , image , sagemaker_session ):
482
- """Create an Estimator from existing training job data.
493
+ def _prepare_init_params_from_job_description (cls , job_details ):
494
+ """Convert the job description to init params that can be handled by the class constructor
483
495
484
496
Args:
485
- init_params (dict): The init_params the training job was created with.
486
- hyperparameters (dict): The hyperparameters the training job was created with.
487
- image (str): Container image (if any) the training job was created with
488
- sagemaker_session (sagemaker.session.Session): A sagemaker Session to pass to the estimator.
497
+ job_details: the returned job details from a describe_training_job API call.
489
498
490
- Returns: An instance of the calling Estimator Class.
499
+ Returns:
500
+ dictionary: The transformed init_params
491
501
492
502
"""
503
+ init_params = super (Estimator , cls )._prepare_init_params_from_job_description (job_details )
493
504
494
- estimator = cls (sagemaker_session = sagemaker_session , ** init_params )
495
- cls .set_hyperparameters (** hyperparameters )
496
-
497
- return estimator
505
+ init_params ['image_name' ] = init_params .pop ('image' )
506
+ return init_params
498
507
499
508
500
509
class Framework (EstimatorBase ):
@@ -602,35 +611,32 @@ def hyperparameters(self):
602
611
return self ._json_encode_hyperparameters (self ._hyperparameters )
603
612
604
613
@classmethod
605
- def _from_training_job (cls , init_params , hyperparameters , image , sagemaker_session ):
606
- """Create an Estimator from existing training job data.
614
+ def _prepare_init_params_from_job_description (cls , job_details ):
615
+ """Convert the job description to init params that can be handled by the class constructor
607
616
608
617
Args:
609
- init_params (dict): The init_params the training job was created with.
610
- hyperparameters (dict): The hyperparameters the training job was created with.
611
- image (str): Container image (if any) the training job was created with
612
- sagemaker_session (sagemaker.session.Session): A sagemaker Session to pass to the estimator.
618
+ job_details: the returned job details from a describe_training_job API call.
613
619
614
- Returns: An instance of the calling Estimator Class.
620
+ Returns:
621
+ dictionary: The transformed init_params
615
622
616
623
"""
624
+ init_params = super (Framework , cls )._prepare_init_params_from_job_description (job_details )
617
625
618
- # parameters for framework classes
619
- framework_init_params = dict ()
620
- framework_init_params ['entry_point' ] = json .loads (hyperparameters .get (SCRIPT_PARAM_NAME ))
621
- framework_init_params ['source_dir' ] = json .loads (hyperparameters .get (DIR_PARAM_NAME ))
622
- framework_init_params ['enable_cloudwatch_metrics' ] = json .loads (
623
- hyperparameters .get (CLOUDWATCH_METRICS_PARAM_NAME ))
624
- framework_init_params ['container_log_level' ] = json .loads (
625
- hyperparameters .get (CONTAINER_LOG_LEVEL_PARAM_NAME ))
626
+ init_params ['entry_point' ] = json .loads (init_params ['hyperparameters' ].get (SCRIPT_PARAM_NAME ))
627
+ init_params ['source_dir' ] = json .loads (init_params ['hyperparameters' ].get (DIR_PARAM_NAME ))
628
+ init_params ['enable_cloudwatch_metrics' ] = json .loads (
629
+ init_params ['hyperparameters' ].get (CLOUDWATCH_METRICS_PARAM_NAME ))
630
+ init_params ['container_log_level' ] = json .loads (
631
+ init_params ['hyperparameters' ].get (CONTAINER_LOG_LEVEL_PARAM_NAME ))
626
632
627
- # drop json and remove other SageMaker specific additions
628
- deserialized_hps = {entry : json .loads (hyperparameters [entry ]) for entry in hyperparameters }
629
- framework_init_params ['hyperparameters' ] = deserialized_hps
633
+ init_params ['hyperparameters' ] = {k : json .loads (v ) for k , v in init_params ['hyperparameters' ].items ()}
630
634
631
- init_params . update ( framework_init_params )
635
+ return init_params
632
636
633
- estimator = cls (sagemaker_session = sagemaker_session , ** init_params )
637
+ @classmethod
638
+ def attach (cls , training_job_name , sagemaker_session = None ):
639
+ estimator = super (Framework , cls ).attach (training_job_name , sagemaker_session )
634
640
estimator .uploaded_code = UploadedCode (estimator .source_dir , estimator .entry_point )
635
641
return estimator
636
642
0 commit comments