
Commit 9b9272b

Minimal changes to demonstrate creation of HPO job. (aws#8)

1 parent 22d3d07 commit 9b9272b
9 files changed: +330 -19 lines

src/sagemaker/amazon/amazon_estimator.py

Lines changed: 3 additions & 0 deletions
@@ -53,6 +53,9 @@ def train_image(self):
     def hyperparameters(self):
         return hp.serialize_all(self)

+    def hpo_hyperparameters(self):
+        return hp.serialize_all_hpo(self)
+
     @property
     def data_location(self):
         return self._data_location

src/sagemaker/amazon/hyperparameter.py

Lines changed: 1 addition & 1 deletion
@@ -112,5 +112,5 @@ def serialize_all_hpo(obj):
     for range_type in _HpoParameter.__all_types__:
         parameter_range = [param.as_hpo_range(p_name)
                            for p_name, param in obj._hpo_parameters.items() if param.__name__ == range_type]
-        parameter_ranges[range_type+'ParameterRange'] = parameter_range
+        parameter_ranges[range_type+'ParameterRanges'] = parameter_range
     return parameter_ranges
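The rename matters because serialize_all_hpo() feeds the 'ParameterRanges' section of the CreateHyperParameterTuningJob request, whose keys are plural ('IntegerParameterRanges', and so on). As a rough illustration of the shape it produces, assuming an estimator with one integer and one categorical range set as in the integration test below (the parameter names are examples):

# Illustrative only: approximate return value of estimator.hpo_hyperparameters()
# after the fix, for one Integer and one Categorical range.
parameter_ranges = {
    'ContinuousParameterRanges': [],
    'CategoricalParameterRanges': [
        {'Name': 'local_init_method', 'Values': ['kmeans++', 'random']}
    ],
    'IntegerParameterRanges': [
        {'Name': 'mini_batch_size', 'MinValue': '10', 'MaxValue': '100'}
    ]
}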

src/sagemaker/amazon/kmeans.py

Lines changed: 26 additions & 1 deletion
@@ -89,7 +89,32 @@ def __init__(self, role, train_instance_count, train_instance_type, k, init_meth
                 the score shall be reported in terms of all requested metrics.
             **kwargs: base class keyword argument values.
         """
-        super(KMeans, self).__init__(role, train_instance_count, train_instance_type, **kwargs)
+        # TODO: shouldn't be defined here; delete this once HPO fixes validation
+        metric_definitions = [
+            {
+                "Name": "test:msd",
+                "Regex": "#quality_metric: host=\\S+, test msd <loss>=(\\S+)"
+            },
+            {
+                "Name": "test:ssd",
+                "Regex": "#quality_metric: host=\\S+, test ssd <loss>=(\\S+)"
+            },
+            {
+                "Name": "train:msd",
+                "Regex": "#quality_metric: host=\\S+, train msd <loss>=(\\S+)"
+            },
+            {
+                "Name": "train:progress",
+                "Regex": "#progress_metric: host=\\S+, completed (\\S+) %"
+            },
+            # updated based on the current log format
+            {
+                "Name": "train:throughput",
+                "Regex": "#throughput_metric: train throughput in records/second: (\\S+)"
+            }
+        ]
+        super(KMeans, self).__init__(role, train_instance_count, train_instance_type,
+                                     metric_definitions=metric_definitions, **kwargs)
         self.k = k
         self.init_method = init_method
         self.max_iterations = max_iterations
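Each 'Regex' here is applied to the algorithm's training log output, and the first capture group becomes the reported metric value. A minimal sketch of what that extraction amounts to, using a made-up log line in the format the 'train:throughput' pattern expects:

import re

# Hypothetical log line; only the numeric capture group matters.
log_line = "#throughput_metric: train throughput in records/second: 10543.2"
pattern = "#throughput_metric: train throughput in records/second: (\\S+)"

match = re.search(pattern, log_line)
if match:
    print(match.group(1))  # -> 10543.2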

src/sagemaker/estimator.py

Lines changed: 4 additions & 1 deletion
@@ -43,7 +43,8 @@ class EstimatorBase(with_metaclass(ABCMeta, object)):

     def __init__(self, role, train_instance_count, train_instance_type,
                  train_volume_size=30, train_max_run=24 * 60 * 60, input_mode='File',
-                 output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None):
+                 output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None,
+                 metric_definitions=None):
         """Initialize an ``EstimatorBase`` instance.

         Args:
@@ -72,6 +73,7 @@ def __init__(self, role, train_instance_count, train_instance_type,
             sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
                 Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
                 using the default AWS configuration chain.
+            metric_definitions (list[dict]): Metric definitions, each a dictionary with 'Name' and 'Regex' keys.
         """
         self.role = role
         self.train_instance_count = train_instance_count
@@ -95,6 +97,7 @@ def __init__(self, role, train_instance_count, train_instance_type,
         self.output_path = output_path
         self.output_kms_key = output_kms_key
         self.latest_training_job = None
+        self.metric_definitions = metric_definitions

     @abstractmethod
     def train_image(self):
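Because EstimatorBase is abstract, the new keyword is only reachable through subclasses. A hedged sketch of how a subclass that forwards **kwargs could let callers supply their own definitions (MyEstimator and the regex are placeholders, not part of this commit; note that KMeans above currently hard-codes its own list):

# Sketch only: MyEstimator stands in for any EstimatorBase subclass that
# forwards **kwargs to EstimatorBase.__init__.
metric_definitions = [
    {"Name": "validation:accuracy", "Regex": "validation accuracy=(\\S+)"}  # example metric
]

estimator = MyEstimator(role='SageMakerRole', train_instance_count=1,
                        train_instance_type='ml.c4.xlarge',
                        metric_definitions=metric_definitions)
print(estimator.metric_definitions)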

src/sagemaker/hpo.py

Lines changed: 80 additions & 2 deletions
@@ -11,6 +11,9 @@
 # ANY KIND, either express or implied. See the License for the specific
 # language governing permissions and limitations under the License.

+from sagemaker.estimator import _TrainingJob
+from sagemaker.utils import base_name_from_image, name_from_base
+

 class _HpoParameter(object):
     __all_types__ = ['Continuous', 'Categorical', 'Integer']
@@ -21,7 +24,6 @@ def __init__(self, min_value, max_value):

     def as_hpo_range(self, name):
         return {'Name': name,
-                'Type': self.__name__,
                 'MinValue': str(self.min_value),
                 'MaxValue': str(self.max_value)}

@@ -44,7 +46,6 @@ def __init__(self, values):

     def as_hpo_range(self, name):
         return {'Name': name,
-                'Type': self.__name__,
                 'Values': self.values}


@@ -53,3 +54,80 @@ class IntegerParameter(_HpoParameter):

     def __init__(self, min_value, max_value):
         super(IntegerParameter, self).__init__(min_value, max_value)
+
+
+class HyperparameterTuner(object):
+    __objectives__ = ['Minimize', 'Maximize']
+
+    def __init__(self, objective='Maximize', max_jobs=1, max_parallel_jobs=1):
+        if objective not in HyperparameterTuner.__objectives__:
+            raise ValueError("Unsupported 'objective' value")
+        self.strategy = 'Bayesian'
+        self.objective = objective
+        self.max_jobs = max_jobs
+        self.max_parallel_jobs = max_parallel_jobs
+
+    def tune(self, estimator, inputs, metric_name):  # TODO: accept an explicit 'hyperparameters' argument
+        # self.optimize_hp = hyperparameters
+        self.optimize_metric_name = metric_name
+        self.estimator = estimator
+
+        self.latest_tuning_job = _TuningJob.start_new(self, inputs)
+
+
+class _TuningJob(_TrainingJob):
+    def __init__(self, sagemaker_session, tuning_job_name):
+        self.sagemaker_session = sagemaker_session
+        self.tuning_job_name = tuning_job_name
+
+    @classmethod
+    def start_new(cls, tuner, inputs):
+        """Create a new Amazon SageMaker HPO tuning job from the HyperparameterTuner.
+
+        Args:
+            tuner (sagemaker.hpo.HyperparameterTuner): Tuner object created by the user.
+            inputs (str): Parameters used when calling :meth:`~sagemaker.estimator.EstimatorBase.fit`.
+
+        Returns:
+            sagemaker.hpo._TuningJob: Constructed object that captures all information about the started job.
+        """
+        input_config = _TrainingJob._format_inputs_to_input_config(inputs)
+        role = tuner.estimator.sagemaker_session.expand_role(tuner.estimator.role)
+        output_config = _TrainingJob._prepare_output_config(tuner.estimator.output_path,
+                                                            tuner.estimator.output_kms_key)
+        resource_config = _TrainingJob._prepare_resource_config(tuner.estimator.train_instance_count,
+                                                                tuner.estimator.train_instance_type,
+                                                                tuner.estimator.train_volume_size)
+        stop_condition = _TrainingJob._prepare_stopping_condition(tuner.estimator.train_max_run)
+
+        if tuner.estimator.hyperparameters() is None:
+            raise ValueError('Cannot tune estimator without hyperparameters')
+
+        # TODO: the current code path only works for first-party (1P) algorithms; for others, split the
+        # hyperparameters defined on the estimator into static and HPO-controlled parts, something like:
+        # static_hp = {str(k): str(v) for (k, v) in tuner.estimator.hyperparameters().items()}
+        # for hp_name in tuner.optimize_hp.keys():
+        #     del static_hp[hp_name]

+        # make sure the job name is unique for each invocation; honor a supplied base_job_name or generate one
+        # TODO: should the tuner have separate logic/code for the base name?
+        base_name = tuner.estimator.base_job_name or base_name_from_image(tuner.estimator.train_image())
+        hpo_job_name = name_from_base(base_name)
+
+        tuner.estimator.sagemaker_session.tune(job_name=hpo_job_name, strategy=tuner.strategy,
+                                               objective=tuner.objective, metric_name=tuner.optimize_metric_name,
+                                               max_jobs=tuner.max_jobs, max_parallel_jobs=tuner.max_parallel_jobs,
+                                               parameter_ranges=tuner.estimator.hpo_hyperparameters(),
+                                               static_hp=tuner.estimator.hyperparameters(),
+                                               image=tuner.estimator.train_image(),
+                                               input_mode=tuner.estimator.input_mode,
+                                               metric_definitions=tuner.estimator.metric_definitions,
+                                               role=role, input_config=input_config,
+                                               output_config=output_config, resource_config=resource_config,
+                                               stop_condition=stop_condition)
+
+        return cls(tuner.estimator.sagemaker_session, hpo_job_name)
+
+    @property
+    def name(self):
+        return self.tuning_job_name
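Put together, the intended call pattern is small: declare ranges on a configured estimator, build a tuner, and name the objective metric. A minimal sketch, assuming `estimator` and the `data` channel dict are already set up as in the integration test below:

from sagemaker.hpo import HyperparameterTuner, IntegerParameter

# Assumes `estimator` is a configured first-party estimator (e.g. KMeans) and
# `data` maps the channel name to an s3_input, as in tests/integ/test_hpo.py.
estimator.mini_batch_size = IntegerParameter(10, 100)  # range to search over

tuner = HyperparameterTuner(objective='Maximize', max_jobs=8, max_parallel_jobs=2)
tuner.tune(estimator, data, metric_name='train:throughput')

print(tuner.latest_tuning_job.name)  # unique name generated via name_from_base()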

src/sagemaker/session.py

Lines changed: 99 additions & 0 deletions
@@ -243,6 +243,82 @@ def train(self, image, input_mode, input_config, role, job_name, output_config,
         LOGGER.debug('train request: {}'.format(json.dumps(train_request, indent=4)))
         self.sagemaker_client.create_training_job(**train_request)

+    def tune(self, job_name, strategy, objective, metric_name,
+             max_jobs, max_parallel_jobs, parameter_ranges,
+             static_hp, image, input_mode, metric_definitions,
+             role, input_config, output_config, resource_config, stop_condition):
+        """Create an Amazon SageMaker HPO job.
+
+        Args:
+            job_name (str): Name of the tuning job being created.
+            strategy (str): Strategy to be used for the tuning job, e.g. 'Bayesian'.
+            objective (str): Objective type, either 'Minimize' or 'Maximize'.
+            metric_name (str): Name of the metric to use when evaluating training jobs.
+            max_jobs (int): Maximum total number of training jobs to start.
+            max_parallel_jobs (int): Maximum number of parallel training jobs to start.
+            parameter_ranges (dict): Parameter ranges keyed by type: Continuous, Integer and Categorical.
+            static_hp (dict): Hyperparameters for model training. The hyperparameters are made accessible as
+                a dict[str, str] to the training code on SageMaker. For convenience, this accepts other types for
+                keys and values, but ``str()`` will be called to convert them before training.
+            image (str): Docker image containing training code.
+            input_mode (str): The input mode that the algorithm supports. Valid modes:
+
+                * 'File' - Amazon SageMaker copies the training dataset from the S3 location to
+                    a directory in the Docker container.
+                * 'Pipe' - Amazon SageMaker streams data directly from S3 to the container via a Unix-named pipe.
+
+            metric_definitions (list[dict]): Metric definitions emitted by the training jobs, each with
+                'Name' and 'Regex' keys.
+            role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and APIs
+                that create Amazon SageMaker endpoints use this role to access training data and model artifacts.
+                You must grant sufficient permissions to this role.
+            input_config (list): A list of Channel objects. Each channel is a named input source. Please refer to
+                the format details described:
+                https://botocore.readthedocs.io/en/latest/reference/services/sagemaker.html#SageMaker.Client.create_training_job
+            output_config (dict): The S3 URI where you want to store the training results and optional KMS key ID.
+            resource_config (dict): Contains values for ResourceConfig:
+                instance_count (int): Number of EC2 instances to use for training.
+                instance_type (str): Type of EC2 instance to use for training, for example, 'ml.c4.xlarge'.
+            stop_condition (dict): Defines when training shall finish. Contains entries that can be understood by the
+                service like ``MaxRuntimeInSeconds``.
+        """
+
+        tune_request = {
+            'HyperParameterTuningJobName': job_name,
+            'HyperParameterTuningJobConfig': {
+                'Strategy': strategy,
+                'HyperParameterTuningJobObjective': {
+                    'Type': objective,
+                    'MetricName': metric_name,
+                },
+                'ResourceLimits': {
+                    'MaxNumberOfTrainingJobs': max_jobs,
+                    'MaxParallelTrainingJobs': max_parallel_jobs
+                },
+                'ParameterRanges': parameter_ranges
+            },
+            'TrainingJobDefinition': {
+                'StaticHyperParameters': static_hp,
+                'AlgorithmSpecification': {
+                    'TrainingImage': image,
+                    'TrainingInputMode': input_mode
+                },
+                'RoleArn': role,
+                'InputDataConfig': input_config,
+                'OutputDataConfig': output_config,
+                'ResourceConfig': resource_config,
+                'StoppingCondition': stop_condition,
+            }
+        }
+
+        if metric_definitions is not None:
+            tune_request['TrainingJobDefinition']['AlgorithmSpecification']['MetricDefinitions'] = metric_definitions
+
+        LOGGER.info('Creating tuning-job with name: {}'.format(job_name))
+        LOGGER.debug('tune request: {}'.format(json.dumps(tune_request, indent=4)))
+        self.sagemaker_client.create_hyper_parameter_tuning_job(**tune_request)
+
     def create_model(self, name, role, primary_container):
         """Create an Amazon SageMaker ``Model``.

@@ -800,6 +876,29 @@ def _train_done(sagemaker_client, job_name):
     return desc


+def _tune_done(sagemaker_client, job_name):
+    tuning_status_codes = {
+        'Completed': '!',
+        'InProgress': '.',
+        'Failed': '*',
+        'Stopped': 's',
+        'Stopping': '_'
+    }
+    in_progress_statuses = ['InProgress', 'Stopping']
+
+    desc = sagemaker_client.describe_hyper_parameter_tuning_job(HyperParameterTuningJobName=job_name)
+    status = desc['HyperParameterTuningJobStatus']
+
+    print(tuning_status_codes.get(status, '?'), end='')
+    sys.stdout.flush()
+
+    if status in in_progress_statuses:
+        return None
+
+    print('')
+    return desc
+
+
 def _deploy_done(sagemaker_client, endpoint_name):
     hosting_status_codes = {
         "OutOfService": "x",

tests/integ/test_hpo.py

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
+# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+import gzip
+import os
+import pickle
+import sys
+
+from sagemaker.amazon.kmeans import KMeans
+from sagemaker.hpo import IntegerParameter, ContinuousParameter, CategoricalParameter, HyperparameterTuner
+from sagemaker.session import s3_input
+
+from tests.integ import DATA_DIR
+
+
+def test_hpo(sagemaker_session):
+    tuner = HyperparameterTuner(objective='Minimize', max_jobs=8, max_parallel_jobs=2)
+
+    data_path = os.path.join(DATA_DIR, 'one_p_mnist', 'mnist.pkl.gz')
+    pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'}
+
+    # Load the data into memory as numpy arrays
+    with gzip.open(data_path, 'rb') as f:
+        train_set, _, _ = pickle.load(f, **pickle_args)
+
+    kmeans = KMeans(role='SageMakerRole', train_instance_count=1,
+                    train_instance_type='ml.c4.xlarge',
+                    k=10, sagemaker_session=sagemaker_session, base_job_name='tk',
+                    output_path='s3://{}/'.format(sagemaker_session.default_bucket()))
+
+    # set KMeans-specific hyperparameters
+    kmeans.init_method = 'random'
+    kmeans.max_iterations = 1
+    kmeans.tol = 1
+    kmeans.num_trials = 1
+    kmeans.local_init_method = 'kmeans++'
+    kmeans.half_life_time_size = 1
+    kmeans.epochs = 1
+
+    records = kmeans.record_set(train_set[0][:100])
+
+    # TODO: this is done during fit(); needs to be refactored
+    kmeans.mini_batch_size = 5000
+    kmeans.feature_dim = records.feature_dim
+
+    # specify which hyperparameters to optimize over
+    kmeans.center_factor = IntegerParameter(1, 10)
+    kmeans.mini_batch_size = IntegerParameter(10, 100)
+    kmeans.tol = ContinuousParameter(1.0, 2.0)
+    kmeans.local_init_method = CategoricalParameter(['kmeans++', 'random'])
+
+    data = {records.channel: s3_input(records.s3_data, distribution='ShardedByS3Key',
+                                      s3_data_type=records.s3_data_type)}
+    tuner.tune(kmeans, data, 'train:throughput')
+
+    print('Started HPO job with name: ' + tuner.latest_tuning_job.name)
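The test only starts the tuning job and prints its name; it never waits on it. A hedged follow-up for checking on the job afterwards, using the standard DescribeHyperParameterTuningJob API through the session's low-level client (the check itself is not part of the test):

# Optional verification (not part of the test): describe the job that was started.
desc = sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job(
    HyperParameterTuningJobName=tuner.latest_tuning_job.name)
print(desc['HyperParameterTuningJobStatus'])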
