Skip to content

Commit 320e7d4

Browse files
authored
Add HyperparameterTuner.attach() (aws#33)
1 parent 5f41e70 commit 320e7d4

File tree

3 files changed

+252
-6
lines changed

3 files changed

+252
-6
lines changed

src/sagemaker/tuner.py

Lines changed: 108 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,29 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import importlib
1516
import inspect
1617
import json
1718

1819
from sagemaker.analytics import HyperparameterTuningJobAnalytics
1920
from sagemaker.estimator import Framework
2021
from sagemaker.job import _Job
22+
from sagemaker.session import Session
2123
from sagemaker.utils import base_name_from_image, name_from_base
2224

25+
# TODO: probably move these somewhere to Amazon Estimator land after
26+
# the circular dependency issue is resolved
27+
AMAZON_ESTIMATOR_MODULE = 'sagemaker'
28+
AMAZON_ESTIMATOR_CLS_NAMES = {
29+
'factorization-machines': 'FactorizationMachines',
30+
'kmeans': 'KMeans',
31+
'lda': 'LDA',
32+
'linear-learner': 'LinearLearner',
33+
'ntm': 'NTM',
34+
'pca': 'PCA',
35+
'randomcutforest': 'RandomCutForest',
36+
}
37+
2338

2439
class _ParameterRange(object):
2540
__all_types__ = ['Continuous', 'Categorical', 'Integer']
@@ -66,8 +81,11 @@ def __init__(self, min_value, max_value):
6681

6782

6883
class HyperparameterTuner(object):
69-
SAGEMAKER_ESTIMATOR_CLASS_NAME = 'sagemaker_estimator_class_name'
7084
SAGEMAKER_ESTIMATOR_MODULE = 'sagemaker_estimator_module'
85+
SAGEMAKER_ESTIMATOR_CLASS_NAME = 'sagemaker_estimator_class_name'
86+
87+
DEFAULT_ESTIMATOR_MODULE = 'sagemaker.estimator'
88+
DEFAULT_ESTIMATOR_CLS_NAME = 'Estimator'
7189

7290
def __init__(self, estimator, objective_metric_name, hyperparameter_ranges, metric_definitions, strategy='Bayesian',
7391
objective_type='Maximize', max_jobs=1, max_parallel_jobs=1, base_tuning_job_name=None):
@@ -100,8 +118,8 @@ def prepare_for_training(self):
100118
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase
101119

102120
if not isinstance(self.estimator, AmazonAlgorithmEstimatorBase):
103-
self.static_hyperparameters[self.SAGEMAKER_ESTIMATOR_CLASS_NAME] = self.estimator.__class__.__name__
104-
self.static_hyperparameters[self.SAGEMAKER_ESTIMATOR_MODULE] = self.estimator.__module__
121+
self.static_hyperparameters[self.SAGEMAKER_ESTIMATOR_CLASS_NAME] = json.dumps(self.estimator.__class__.__name__)
122+
self.static_hyperparameters[self.SAGEMAKER_ESTIMATOR_MODULE] = json.dumps(self.estimator.__module__)
105123

106124
def fit(self, inputs, job_name=None, **kwargs):
107125
"""Start a hyperparameter tuning job.
@@ -124,6 +142,24 @@ def fit(self, inputs, job_name=None, **kwargs):
124142
self.prepare_for_training()
125143
self.latest_tuning_job = _TuningJob.start_new(self, inputs)
126144

145+
@classmethod
146+
def attach(cls, tuning_job_name, sagemaker_session=None, job_details=None, estimator_cls=None):
147+
sagemaker_session = sagemaker_session or Session()
148+
149+
if job_details is None:
150+
job_details = sagemaker_session.sagemaker_client\
151+
.describe_hyper_parameter_tuning_job(HyperParameterTuningJobName=tuning_job_name)
152+
153+
estimator_cls = cls._prepare_estimator_cls(estimator_cls, job_details['TrainingJobDefinition'])
154+
estimator = cls._prepare_estimator_from_job_description(estimator_cls, job_details['TrainingJobDefinition'],
155+
sagemaker_session)
156+
init_params = cls._prepare_init_params_from_job_description(job_details)
157+
158+
tuner = cls(estimator=estimator, **init_params)
159+
tuner.latest_tuning_job = _TuningJob(sagemaker_session=sagemaker_session, tuning_job_name=tuning_job_name)
160+
161+
return tuner
162+
127163
def deploy(self, initial_instance_count, instance_type, endpoint_name=None, **kwargs):
128164
"""Deploy the best trained or user specified model to an Amazon SageMaker endpoint and return a
129165
``sagemaker.RealTimePredictor``
@@ -182,6 +218,75 @@ def _ensure_last_tuning_job(self):
182218
if self.latest_tuning_job is None:
183219
raise ValueError('No tuning job available')
184220

221+
@classmethod
222+
def _prepare_estimator_cls(cls, estimator_cls, training_details):
223+
# Check for customer-specified estimator first
224+
if estimator_cls is not None:
225+
module, cls_name = estimator_cls.rsplit('.', 1)
226+
return getattr(importlib.import_module(module), cls_name)
227+
228+
# Then check for estimator class in hyperparameters
229+
hyperparameters = training_details['StaticHyperParameters']
230+
if cls.SAGEMAKER_ESTIMATOR_CLASS_NAME in hyperparameters and cls.SAGEMAKER_ESTIMATOR_MODULE in hyperparameters:
231+
module = hyperparameters.get(cls.SAGEMAKER_ESTIMATOR_MODULE)
232+
cls_name = hyperparameters.get(cls.SAGEMAKER_ESTIMATOR_CLASS_NAME)
233+
return getattr(importlib.import_module(json.loads(module)), json.loads(cls_name))
234+
235+
# Then try to derive the estimator from the image name for 1P algorithms
236+
image_name = training_details['AlgorithmSpecification']['TrainingImage']
237+
algorithm = image_name[image_name.find('/')+1:image_name.find(':')]
238+
if algorithm in AMAZON_ESTIMATOR_CLS_NAMES:
239+
cls_name = AMAZON_ESTIMATOR_CLS_NAMES[algorithm]
240+
return getattr(importlib.import_module(AMAZON_ESTIMATOR_MODULE), cls_name)
241+
242+
# Default to the BYO estimator
243+
return getattr(importlib.import_module(cls.DEFAULT_ESTIMATOR_MODULE), cls.DEFAULT_ESTIMATOR_CLS_NAME)
244+
245+
@classmethod
246+
def _prepare_estimator_from_job_description(cls, estimator_cls, training_details, sagemaker_session):
247+
# Swap name for static hyperparameters to what an estimator would expect
248+
training_details['HyperParameters'] = training_details['StaticHyperParameters']
249+
del training_details['StaticHyperParameters']
250+
251+
# Remove hyperparameter reserved by SageMaker for tuning jobs
252+
del training_details['HyperParameters']['_tuning_objective_metric']
253+
254+
# Add items expected by the estimator (but aren't needed otherwise)
255+
training_details['TrainingJobName'] = ''
256+
if 'KmsKeyId' not in training_details['OutputDataConfig']:
257+
training_details['OutputDataConfig']['KmsKeyId'] = ''
258+
259+
estimator_init_params = estimator_cls._prepare_init_params_from_job_description(training_details)
260+
return estimator_cls(sagemaker_session=sagemaker_session, **estimator_init_params)
261+
262+
@classmethod
263+
def _prepare_init_params_from_job_description(cls, job_details):
264+
tuning_config = job_details['HyperParameterTuningJobConfig']
265+
return {
266+
'metric_definitions': job_details['TrainingJobDefinition']['AlgorithmSpecification']['MetricDefinitions'],
267+
'objective_metric_name': tuning_config['HyperParameterTuningJobObjective']['MetricName'],
268+
'objective_type': tuning_config['HyperParameterTuningJobObjective']['Type'],
269+
'hyperparameter_ranges': cls._prepare_parameter_ranges(tuning_config['ParameterRanges']),
270+
'strategy': tuning_config['Strategy'],
271+
'max_jobs': tuning_config['ResourceLimits']['MaxNumberOfTrainingJobs'],
272+
'max_parallel_jobs': tuning_config['ResourceLimits']['MaxParallelTrainingJobs'],
273+
}
274+
275+
@classmethod
276+
def _prepare_parameter_ranges(cls, parameter_ranges):
277+
ranges = {}
278+
279+
for parameter in parameter_ranges['CategoricalParameterRanges']:
280+
ranges[parameter['Name']] = CategoricalParameter(parameter['Values'])
281+
282+
for parameter in parameter_ranges['ContinuousParameterRanges']:
283+
ranges[parameter['Name']] = ContinuousParameter(float(parameter['MinValue']), float(parameter['MaxValue']))
284+
285+
for parameter in parameter_ranges['IntegerParameterRanges']:
286+
ranges[parameter['Name']] = IntegerParameter(int(parameter['MinValue']), int(parameter['MaxValue']))
287+
288+
return ranges
289+
185290
def hyperparameter_ranges(self):
186291
"""Return collections of ``ParameterRanges``
187292

tests/integ/test_tuner.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from sagemaker.amazon.kmeans import KMeans
2121
from sagemaker.mxnet.estimator import MXNet
2222
from sagemaker.tuner import IntegerParameter, ContinuousParameter, CategoricalParameter, HyperparameterTuner
23-
from sagemaker.session import s3_input
2423
from tests.integ import DATA_DIR
2524
from tests.integ.timeout import timeout
2625

tests/unit/test_tuner.py

Lines changed: 144 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import copy
16+
import json
17+
1518
import pytest
1619
from mock import Mock
1720

@@ -44,6 +47,80 @@
4447
'blank': CategoricalParameter([0, 5])}
4548
METRIC_DEFINTIONS = 'mock_metric_definitions'
4649

50+
TUNING_JOB_DETAILS = {
51+
'HyperParameterTuningJobConfig': {
52+
'ResourceLimits': {
53+
'MaxParallelTrainingJobs': 1,
54+
'MaxNumberOfTrainingJobs': 1
55+
},
56+
'HyperParameterTuningJobObjective': {
57+
'MetricName': OBJECTIVE_METRIC_NAME,
58+
'Type': 'Minimize'
59+
},
60+
'Strategy': 'Bayesian',
61+
'ParameterRanges': {
62+
'CategoricalParameterRanges': [],
63+
'ContinuousParameterRanges': [],
64+
'IntegerParameterRanges': [
65+
{
66+
'MaxValue': '100',
67+
'Name': 'mini_batch_size',
68+
'MinValue': '10',
69+
},
70+
]
71+
}
72+
},
73+
'HyperParameterTuningJobName': JOB_NAME,
74+
'TrainingJobDefinition': {
75+
'RoleArn': ROLE,
76+
'StaticHyperParameters': {
77+
'num_components': '1',
78+
'_tuning_objective_metric': 'train:throughput',
79+
'feature_dim': '784',
80+
'sagemaker_estimator_module': '"sagemaker.amazon.pca"',
81+
'sagemaker_estimator_class_name': '"PCA"',
82+
},
83+
'ResourceConfig': {
84+
'VolumeSizeInGB': 30,
85+
'InstanceType': 'ml.c4.xlarge',
86+
'InstanceCount': 1
87+
},
88+
'AlgorithmSpecification': {
89+
'TrainingImage': IMAGE_NAME,
90+
'TrainingInputMode': 'File',
91+
'MetricDefinitions': METRIC_DEFINTIONS,
92+
},
93+
'InputDataConfig': [
94+
{
95+
'ChannelName': 'train',
96+
'DataSource': {
97+
'S3DataSource': {
98+
'S3DataDistributionType': 'ShardedByS3Key',
99+
'S3Uri': INPUTS,
100+
'S3DataType': 'ManifestFile'
101+
}
102+
}
103+
}
104+
],
105+
'StoppingCondition': {
106+
'MaxRuntimeInSeconds': 86400
107+
},
108+
'OutputDataConfig': {
109+
'S3OutputPath': BUCKET_NAME,
110+
}
111+
},
112+
'TrainingJobCounters': {
113+
'ClientError': 0,
114+
'Completed': 1,
115+
'InProgress': 0,
116+
'Fault': 0,
117+
'Stopped': 0
118+
},
119+
'HyperParameterTuningEndTime': 1526605831.0,
120+
'CreationTime': 1526605605.0,
121+
'HyperParameterTuningJobArn': 'arn:tuning_job',
122+
}
123+
47124

48125
@pytest.fixture()
49126
def sagemaker_session():
@@ -73,8 +150,11 @@ def test_prepare_for_training(tuner):
73150

74151
assert len(tuner.static_hyperparameters) == 3
75152
assert tuner.static_hyperparameters['another_one'] == '0'
76-
assert tuner.static_hyperparameters['sagemaker_estimator_class_name'] == tuner.estimator.__class__.__name__
77-
assert tuner.static_hyperparameters['sagemaker_estimator_module'] == tuner.estimator.__module__
153+
154+
class_name = json.dumps(tuner.estimator.__class__.__name__)
155+
assert tuner.static_hyperparameters['sagemaker_estimator_class_name'] == class_name
156+
module = json.dumps(tuner.estimator.__module__)
157+
assert tuner.static_hyperparameters['sagemaker_estimator_module'] == module
78158

79159

80160
def test_validate_parameter_ranges_number_validation_error(sagemaker_session):
@@ -127,6 +207,68 @@ def test_fit_1p(sagemaker_session, tuner):
127207
assert tune_kwargs['job_name'].startswith('pca')
128208

129209

210+
def test_attach_tuning_job_with_estimator_from_hyperparameters(sagemaker_session):
211+
job_details = copy.deepcopy(TUNING_JOB_DETAILS)
212+
sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(name='describe_tuning_job',
213+
return_value=job_details)
214+
tuner = HyperparameterTuner.attach(JOB_NAME, sagemaker_session=sagemaker_session)
215+
216+
assert tuner.latest_tuning_job.name == JOB_NAME
217+
assert tuner.objective_metric_name == OBJECTIVE_METRIC_NAME
218+
assert tuner.max_jobs == 1
219+
assert tuner.max_parallel_jobs == 1
220+
assert tuner.metric_definitions == METRIC_DEFINTIONS
221+
assert tuner.strategy == 'Bayesian'
222+
assert tuner.objective_type == 'Minimize'
223+
224+
assert isinstance(tuner.estimator, PCA)
225+
assert tuner.estimator.role == ROLE
226+
assert tuner.estimator.train_instance_count == 1
227+
assert tuner.estimator.train_max_run == 24 * 60 * 60
228+
assert tuner.estimator.input_mode == 'File'
229+
assert tuner.estimator.output_path == BUCKET_NAME
230+
assert tuner.estimator.output_kms_key == ''
231+
232+
assert '_tuning_objective_metric' not in tuner.estimator.hyperparameters()
233+
assert tuner.estimator.hyperparameters()['num_components'] == '1'
234+
235+
236+
def test_attach_tuning_job_with_job_details(sagemaker_session):
237+
job_details = copy.deepcopy(TUNING_JOB_DETAILS)
238+
HyperparameterTuner.attach(JOB_NAME, sagemaker_session=sagemaker_session, job_details=job_details)
239+
sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job.assert_not_called
240+
241+
242+
def test_attach_tuning_job_with_estimator_from_image(sagemaker_session):
243+
job_details = copy.deepcopy(TUNING_JOB_DETAILS)
244+
job_details['TrainingJobDefinition']['AlgorithmSpecification']['TrainingImage'] = '1111.amazonaws.com/pca:1'
245+
sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(name='describe_tuning_job',
246+
return_value=job_details)
247+
248+
tuner = HyperparameterTuner.attach(JOB_NAME, sagemaker_session=sagemaker_session)
249+
assert isinstance(tuner.estimator, PCA)
250+
251+
252+
def test_attach_tuning_job_with_estimator_from_kwarg(sagemaker_session):
253+
job_details = copy.deepcopy(TUNING_JOB_DETAILS)
254+
sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(name='describe_tuning_job',
255+
return_value=job_details)
256+
tuner = HyperparameterTuner.attach(JOB_NAME, sagemaker_session=sagemaker_session,
257+
estimator_cls='sagemaker.estimator.Estimator')
258+
assert isinstance(tuner.estimator, Estimator)
259+
260+
261+
def test_attach_with_no_specified_estimator(sagemaker_session):
262+
job_details = copy.deepcopy(TUNING_JOB_DETAILS)
263+
del job_details['TrainingJobDefinition']['StaticHyperParameters']['sagemaker_estimator_module']
264+
del job_details['TrainingJobDefinition']['StaticHyperParameters']['sagemaker_estimator_class_name']
265+
sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(name='describe_tuning_job',
266+
return_value=job_details)
267+
268+
tuner = HyperparameterTuner.attach(JOB_NAME, sagemaker_session=sagemaker_session)
269+
assert isinstance(tuner.estimator, Estimator)
270+
271+
130272
def test_serialize_parameter_ranges(tuner):
131273
hyperparameter_ranges = tuner.hyperparameter_ranges()
132274

0 commit comments

Comments
 (0)