Skip to content

Commit 5bedd29

Browse files
author
Ignacio Quintero
committed
Refactor attach to remove _from_training_job()
_prepare_init_params_from_job_description() is now a classmethod instead of being a static method. Each class is responsible to implement their specific logic to convert a training job description into arguments that can be passed to its own __init__()
1 parent 24ae51e commit 5bedd29

File tree

8 files changed

+104
-108
lines changed

8 files changed

+104
-108
lines changed

src/sagemaker/amazon/amazon_estimator.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -65,28 +65,29 @@ def data_location(self, data_location):
6565
self._data_location = data_location
6666

6767
@classmethod
68-
def _from_training_job(cls, init_params, hyperparameters, image, sagemaker_session):
69-
"""Create an Estimator from existing training job data.
68+
def _prepare_init_params_from_job_description(cls, job_details):
69+
"""Convert the job description to init params that can be handled by the class constructor
7070
7171
Args:
72-
init_params (dict): The init_params the training job was created with.
73-
hyperparameters (dict): The hyperparameters the training job was created with.
74-
image (str): Container image (if any) the training job was created with
75-
sagemaker_session (sagemaker.session.Session): A sagemaker Session to pass to the estimator.
72+
job_details: the returned job details from a describe_training_job API call.
7673
77-
Returns: An instance of the calling Estimator Class.
74+
Returns:
75+
dictionary: The transformed init_params
7876
7977
"""
78+
init_params = super(AmazonAlgorithmEstimatorBase, cls)._prepare_init_params_from_job_description(job_details)
8079

8180
# The hyperparam names may not be the same as the class attribute that holds them,
8281
# for instance: local_lloyd_init_method is called local_init_method. We need to map these
8382
# and pass the correct name to the constructor.
8483
for attribute, value in cls.__dict__.items():
8584
if isinstance(value, hp):
86-
if value.name in hyperparameters:
87-
init_params[attribute] = hyperparameters[value.name]
85+
if value.name in init_params['hyperparameters']:
86+
init_params[attribute] = init_params['hyperparameters'][value.name]
8887

89-
return cls(sagemaker_session=sagemaker_session, **init_params)
88+
del init_params['hyperparameters']
89+
del init_params['image']
90+
return init_params
9091

9192
def fit(self, records, mini_batch_size=None, **kwargs):
9293
"""Fit this Estimator on serialized Record objects, stored in S3.

src/sagemaker/estimator.py

Lines changed: 59 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def _from_training_job(cls, init_params, hyperparameters, image, sagemaker_sessi
169169
raise NotImplementedError()
170170

171171
@classmethod
172-
def attach(cls, training_job_name, sagemaker_session=None):
172+
def attach(cls, training_job_name, sagemaker_session=None, job_details=None):
173173
"""Attach to an existing training job.
174174
175175
Create an Estimator bound to an existing training job, each subclass is responsible to implement
@@ -185,6 +185,7 @@ def attach(cls, training_job_name, sagemaker_session=None):
185185
sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
186186
Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
187187
using the default AWS configuration chain.
188+
training_job_details (
188189
189190
Examples:
190191
>>> my_estimator.fit(wait=False)
@@ -198,13 +199,10 @@ def attach(cls, training_job_name, sagemaker_session=None):
198199
"""
199200
sagemaker_session = sagemaker_session or Session()
200201

201-
if training_job_name:
202-
job_details = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=training_job_name)
203-
init_params, hp, image = cls._prepare_estimator_params_from_job_description(job_details)
204-
else:
205-
raise ValueError('must specify training_job name')
202+
job_details = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=training_job_name)
203+
init_params = cls._prepare_init_params_from_job_description(job_details)
206204

207-
estimator = cls._from_training_job(init_params, hp, image, sagemaker_session)
205+
estimator = cls(sagemaker_session=sagemaker_session, **init_params)
208206
estimator.latest_training_job = _TrainingJob(sagemaker_session=sagemaker_session,
209207
training_job_name=init_params['base_job_name'])
210208
estimator.latest_training_job.wait()
@@ -257,21 +255,33 @@ def create_model(self, **kwargs):
257255
"""
258256
pass
259257

260-
@staticmethod
261-
def _prepare_estimator_params_from_job_description(job_details):
262-
estimator_params = dict()
258+
@classmethod
259+
def _prepare_init_params_from_job_description(cls, job_details):
260+
"""Convert the job description to init params that can be handled by the class constructor
261+
262+
Args:
263+
job_details: the returned job details from a describe_training_job API call.
264+
265+
Returns:
266+
dictionary: The transformed init_params
263267
264-
estimator_params['role'] = job_details['RoleArn']
265-
estimator_params['train_instance_count'] = job_details['ResourceConfig']['InstanceCount']
266-
estimator_params['train_instance_type'] = job_details['ResourceConfig']['InstanceType']
267-
estimator_params['train_volume_size'] = job_details['ResourceConfig']['VolumeSizeInGB']
268-
estimator_params['train_max_run'] = job_details['StoppingCondition']['MaxRuntimeInSeconds']
269-
estimator_params['input_mode'] = job_details['AlgorithmSpecification']['TrainingInputMode']
270-
estimator_params['base_job_name'] = job_details['TrainingJobName']
271-
estimator_params['output_path'] = job_details['OutputDataConfig']['S3OutputPath']
272-
estimator_params['output_kms_key'] = job_details['OutputDataConfig']['KmsKeyId']
268+
"""
269+
init_params = dict()
270+
271+
init_params['role'] = job_details['RoleArn']
272+
init_params['train_instance_count'] = job_details['ResourceConfig']['InstanceCount']
273+
init_params['train_instance_type'] = job_details['ResourceConfig']['InstanceType']
274+
init_params['train_volume_size'] = job_details['ResourceConfig']['VolumeSizeInGB']
275+
init_params['train_max_run'] = job_details['StoppingCondition']['MaxRuntimeInSeconds']
276+
init_params['input_mode'] = job_details['AlgorithmSpecification']['TrainingInputMode']
277+
init_params['base_job_name'] = job_details['TrainingJobName']
278+
init_params['output_path'] = job_details['OutputDataConfig']['S3OutputPath']
279+
init_params['output_kms_key'] = job_details['OutputDataConfig']['KmsKeyId']
280+
281+
init_params['hyperparameters'] = job_details['HyperParameters']
282+
init_params['image'] = job_details['AlgorithmSpecification']['TrainingImage']
273283

274-
return estimator_params, job_details['HyperParameters'], job_details['AlgorithmSpecification']['TrainingImage']
284+
return init_params
275285

276286
def delete_endpoint(self):
277287
"""Delete an Amazon SageMaker ``Endpoint``.
@@ -388,7 +398,8 @@ class Estimator(EstimatorBase):
388398

389399
def __init__(self, image_name, role, train_instance_count, train_instance_type,
390400
train_volume_size=30, train_max_run=24 * 60 * 60, input_mode='File',
391-
output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None):
401+
output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None,
402+
hyperparameters=None):
392403
"""Initialize an ``Estimator`` instance.
393404
394405
Args:
@@ -420,9 +431,10 @@ def __init__(self, image_name, role, train_instance_count, train_instance_type,
420431
sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
421432
Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
422433
using the default AWS configuration chain.
434+
hyperparameters (dict): Dictionary containing the hyperparameters to initialize this estimator with.
423435
"""
424436
self.image_name = image_name
425-
self.hyperparam_dict = {}
437+
self.hyperparam_dict = hyperparameters.copy() if hyperparameters else {}
426438
super(Estimator, self).__init__(role, train_instance_count, train_instance_type,
427439
train_volume_size, train_max_run, input_mode,
428440
output_path, output_kms_key, base_job_name, sagemaker_session)
@@ -478,23 +490,20 @@ def predict_wrapper(endpoint, session):
478490
predictor_cls=predictor_cls, **kwargs)
479491

480492
@classmethod
481-
def _from_training_job(cls, init_params, hyperparameters, image, sagemaker_session):
482-
"""Create an Estimator from existing training job data.
493+
def _prepare_init_params_from_job_description(cls, job_details):
494+
"""Convert the job description to init params that can be handled by the class constructor
483495
484496
Args:
485-
init_params (dict): The init_params the training job was created with.
486-
hyperparameters (dict): The hyperparameters the training job was created with.
487-
image (str): Container image (if any) the training job was created with
488-
sagemaker_session (sagemaker.session.Session): A sagemaker Session to pass to the estimator.
497+
job_details: the returned job details from a describe_training_job API call.
489498
490-
Returns: An instance of the calling Estimator Class.
499+
Returns:
500+
dictionary: The transformed init_params
491501
492502
"""
503+
init_params = super(Estimator, cls)._prepare_init_params_from_job_description(job_details)
493504

494-
estimator = cls(sagemaker_session=sagemaker_session, **init_params)
495-
cls.set_hyperparameters(**hyperparameters)
496-
497-
return estimator
505+
init_params['image_name'] = init_params.pop('image')
506+
return init_params
498507

499508

500509
class Framework(EstimatorBase):
@@ -602,35 +611,32 @@ def hyperparameters(self):
602611
return self._json_encode_hyperparameters(self._hyperparameters)
603612

604613
@classmethod
605-
def _from_training_job(cls, init_params, hyperparameters, image, sagemaker_session):
606-
"""Create an Estimator from existing training job data.
614+
def _prepare_init_params_from_job_description(cls, job_details):
615+
"""Convert the job description to init params that can be handled by the class constructor
607616
608617
Args:
609-
init_params (dict): The init_params the training job was created with.
610-
hyperparameters (dict): The hyperparameters the training job was created with.
611-
image (str): Container image (if any) the training job was created with
612-
sagemaker_session (sagemaker.session.Session): A sagemaker Session to pass to the estimator.
618+
job_details: the returned job details from a describe_training_job API call.
613619
614-
Returns: An instance of the calling Estimator Class.
620+
Returns:
621+
dictionary: The transformed init_params
615622
616623
"""
624+
init_params = super(Framework, cls)._prepare_init_params_from_job_description(job_details)
617625

618-
# parameters for framework classes
619-
framework_init_params = dict()
620-
framework_init_params['entry_point'] = json.loads(hyperparameters.get(SCRIPT_PARAM_NAME))
621-
framework_init_params['source_dir'] = json.loads(hyperparameters.get(DIR_PARAM_NAME))
622-
framework_init_params['enable_cloudwatch_metrics'] = json.loads(
623-
hyperparameters.get(CLOUDWATCH_METRICS_PARAM_NAME))
624-
framework_init_params['container_log_level'] = json.loads(
625-
hyperparameters.get(CONTAINER_LOG_LEVEL_PARAM_NAME))
626+
init_params['entry_point'] = json.loads(init_params['hyperparameters'].get(SCRIPT_PARAM_NAME))
627+
init_params['source_dir'] = json.loads(init_params['hyperparameters'].get(DIR_PARAM_NAME))
628+
init_params['enable_cloudwatch_metrics'] = json.loads(
629+
init_params['hyperparameters'].get(CLOUDWATCH_METRICS_PARAM_NAME))
630+
init_params['container_log_level'] = json.loads(
631+
init_params['hyperparameters'].get(CONTAINER_LOG_LEVEL_PARAM_NAME))
626632

627-
# drop json and remove other SageMaker specific additions
628-
deserialized_hps = {entry: json.loads(hyperparameters[entry]) for entry in hyperparameters}
629-
framework_init_params['hyperparameters'] = deserialized_hps
633+
init_params['hyperparameters'] = {k: json.loads(v) for k, v in init_params['hyperparameters'].items()}
630634

631-
init_params.update(framework_init_params)
635+
return init_params
632636

633-
estimator = cls(sagemaker_session=sagemaker_session, **init_params)
637+
@classmethod
638+
def attach(cls, training_job_name, sagemaker_session=None):
639+
estimator = super(Framework, cls).attach(training_job_name, sagemaker_session)
634640
estimator.uploaded_code = UploadedCode(estimator.source_dir, estimator.entry_point)
635641
return estimator
636642

src/sagemaker/mxnet/estimator.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -82,24 +82,23 @@ def create_model(self, model_server_workers=None):
8282
sagemaker_session=self.sagemaker_session)
8383

8484
@classmethod
85-
def _from_training_job(cls, init_params, hyperparameters, image, sagemaker_session):
86-
"""Create an Estimator from existing training job data.
85+
def _prepare_init_params_from_job_description(cls, job_details):
86+
"""Convert the job description to init params that can be handled by the class constructor
8787
8888
Args:
89-
init_params (dict): The init_params the training job was created with.
90-
hyperparameters (dict): The hyperparameters the training job was created with.
91-
image (str): Container image (if any) the training job was created with
92-
sagemaker_session (sagemaker.session.Session): A sagemaker Session to pass to the estimator.
89+
job_details: the returned job details from a describe_training_job API call.
9390
94-
Returns: An instance of the calling Estimator Class.
91+
Returns:
92+
dictionary: The transformed init_params
9593
9694
"""
95+
init_params = super(MXNet, cls)._prepare_init_params_from_job_description(job_details)
96+
framework, py_version = framework_name_from_image(init_params.pop('image'))
9797

98-
framework, py_version = framework_name_from_image(image)
99-
init_params.update({'py_version': py_version})
100-
98+
init_params['py_version'] = py_version
10199
training_job_name = init_params['base_job_name']
100+
102101
if framework != cls.__framework_name__:
103102
raise ValueError("Training job: {} didn't use image for requested framework".format(training_job_name))
104103

105-
return super(MXNet, cls)._from_training_job(init_params, hyperparameters, image, sagemaker_session)
104+
return init_params

src/sagemaker/tensorflow/estimator.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,11 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
import logging
14+
import os
1415
import subprocess
1516
import tempfile
1617
import threading
1718

18-
import os
19-
2019
import sagemaker.tensorflow
2120
from sagemaker.estimator import Framework
2221
from sagemaker.fw_utils import create_image_uri, framework_name_from_image
@@ -168,31 +167,32 @@ def fit_super():
168167
fit_super()
169168

170169
@classmethod
171-
def _from_training_job(cls, init_params, hyperparameters, image, sagemaker_session):
172-
"""Create an Estimator from existing training job data.
170+
def _prepare_init_params_from_job_description(cls, job_details):
171+
"""Convert the job description to init params that can be handled by the class constructor
173172
174173
Args:
175-
init_params (dict): The init_params the training job was created with.
176-
hyperparameters (dict): The hyperparameters the training job was created with.
177-
image (str): Container image (if any) the training job was created with
178-
sagemaker_session (sagemaker.session.Session): A sagemaker Session to pass to the estimator.
174+
job_details: the returned job details from a describe_training_job API call.
179175
180-
Returns: An instance of the calling Estimator Class.
176+
Returns:
177+
dictionary: The transformed init_params
181178
182179
"""
180+
init_params = super(TensorFlow, cls)._prepare_init_params_from_job_description(job_details)
183181

184-
updated_params = cls._update_init_params(hyperparameters,
185-
['checkpoint_path', 'training_steps', 'evaluation_steps'])
186-
init_params.update(updated_params)
182+
# Move some of the tensorflow specific init params from hyperparameters into the main init params.
183+
for argument in ['checkpoint_path', 'training_steps', 'evaluation_steps']:
184+
value = init_params['hyperparameters'].pop(argument, None)
185+
if value is not None:
186+
init_params[argument] = value
187187

188-
framework, py_version = framework_name_from_image(image)
189-
init_params.update({'py_version': py_version})
190-
training_job_name = init_params['base_job_name']
188+
framework, py_version = framework_name_from_image(init_params.pop('image'))
189+
init_params['py_version'] = py_version
191190

191+
training_job_name = init_params['base_job_name']
192192
if framework != cls.__framework_name__:
193193
raise ValueError("Training job: {} didn't use image for requested framework".format(training_job_name))
194194

195-
return super(TensorFlow, cls)._from_training_job(init_params, hyperparameters, image, sagemaker_session)
195+
return init_params
196196

197197
def train_image(self):
198198
"""Return the Docker image to use for training.

tests/integ/test_byo_estimator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,3 +154,5 @@ def test_async_byo_estimator():
154154
assert len(result['predictions']) == 10
155155
for prediction in result['predictions']:
156156
assert prediction['score'] is not None
157+
158+
assert estimator.train_image() == image_name

tests/unit/test_estimator.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,12 @@ def train_image(self):
7070
def create_model(self):
7171
return DummyFrameworkModel(self.sagemaker_session)
7272

73+
@classmethod
74+
def _prepare_init_params_from_job_description(cls, job_details):
75+
init_params = super(DummyFramework, cls)._prepare_init_params_from_job_description(job_details)
76+
init_params.pop("image", None)
77+
return init_params
78+
7379

7480
class DummyFrameworkModel(FrameworkModel):
7581

@@ -251,12 +257,6 @@ def test_attach_framework(sagemaker_session):
251257
assert framework_estimator.entry_point == 'iris-dnn-classifier.py'
252258

253259

254-
def test_attach_no_job_name_framework(sagemaker_session):
255-
with pytest.raises(ValueError) as error:
256-
Framework.attach(training_job_name=None, sagemaker_session=sagemaker_session)
257-
assert 'must specify training_job name' in str(error)
258-
259-
260260
def test_fit_then_fit_again(sagemaker_session):
261261
fw = DummyFramework(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
262262
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,

tests/unit/test_mxnet.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -201,9 +201,3 @@ def test_attach_wrong_framework(sagemaker_session):
201201
with pytest.raises(ValueError) as error:
202202
MXNet.attach(training_job_name='neo', sagemaker_session=sagemaker_session)
203203
assert "didn't use image for requested framework" in str(error)
204-
205-
206-
def test_attach_no_job_name(sagemaker_session):
207-
with pytest.raises(ValueError) as error:
208-
MXNet.attach(training_job_name=None, sagemaker_session=sagemaker_session)
209-
assert "must specify training_job name" in str(error)

tests/unit/test_tf_estimator.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -379,9 +379,3 @@ def test_attach_wrong_framework(sagemaker_session):
379379
with pytest.raises(ValueError) as error:
380380
TensorFlow.attach(training_job_name='neo', sagemaker_session=sagemaker_session)
381381
assert "didn't use image for requested framework" in str(error)
382-
383-
384-
def test_attach_no_job_name(sagemaker_session):
385-
with pytest.raises(ValueError) as error:
386-
TensorFlow.attach(training_job_name=None, sagemaker_session=sagemaker_session)
387-
assert "must specify training_job name" in str(error)

0 commit comments

Comments
 (0)