Skip to content

Commit 42974a2

Browse files
authored
Add support for hyperparameter tuning jobs
1 parent afb3bbe commit 42974a2

35 files changed

+2816
-372
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,4 @@ doc/_templates
2424
venv/
2525
*~
2626
.pytest_cache/
27+
*.swp

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ CHANGELOG
66
========
77

88
* bug-fix: Estimators: Change max_iterations hyperparameter key for KMeans
9+
* feature: Analytics functions for metrics in Training and HyperparameterTuning jobs
910

1011
1.3.0
1112
=====

src/sagemaker/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from sagemaker.amazon.ntm import NTM, NTMModel, NTMPredictor
2323
from sagemaker.amazon.randomcutforest import RandomCutForest, RandomCutForestModel, RandomCutForestPredictor
2424

25+
from sagemaker.analytics import TrainingJobAnalytics, HyperparameterTuningJobAnalytics
2526
from sagemaker.local.local_session import LocalSession
2627

2728
from sagemaker.model import Model
@@ -39,4 +40,5 @@
3940
'FactorizationMachines', 'FactorizationMachinesModel', 'FactorizationMachinesPredictor',
4041
'RandomCutForest', 'RandomCutForestModel', 'RandomCutForestPredictor',
4142
'Model', 'NTM', 'NTMModel', 'NTMPredictor', 'RealTimePredictor', 'Session', 'LocalSession',
43+
'TrainingJobAnalytics', 'HyperparameterTuningJobAnalytics',
4244
'container_def', 's3_input', 'production_variant', 'get_execution_role']

src/sagemaker/amazon/amazon_estimator.py

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from sagemaker.amazon import validation
2020
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
2121
from sagemaker.amazon.common import write_numpy_to_dense_tensor
22-
from sagemaker.estimator import EstimatorBase
22+
from sagemaker.estimator import EstimatorBase, _TrainingJob
2323
from sagemaker.session import s3_input
2424
from sagemaker.utils import sagemaker_timestamp
2525

@@ -92,11 +92,38 @@ def _prepare_init_params_from_job_description(cls, job_details):
9292
del init_params['image']
9393
return init_params
9494

95-
def fit(self, records, mini_batch_size=None, **kwargs):
95+
def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):
96+
"""Set hyperparameters needed for training.
97+
98+
Args:
99+
* records (:class:`~RecordSet`): The records to train this ``Estimator`` on.
100+
* mini_batch_size (int or None): The size of each mini-batch to use when training. If ``None``, a
101+
default value will be used.
102+
* job_name (str): Name of the training job to be created. If not specified, one is generated,
103+
using the base name given to the constructor if applicable.
104+
"""
105+
super(AmazonAlgorithmEstimatorBase, self)._prepare_for_training(job_name=job_name)
106+
107+
feature_dim = None
108+
109+
if isinstance(records, list):
110+
for record in records:
111+
if record.channel == 'train':
112+
feature_dim = record.feature_dim
113+
break
114+
if feature_dim is None:
115+
raise ValueError('Must provide train channel.')
116+
else:
117+
feature_dim = records.feature_dim
118+
119+
self.feature_dim = feature_dim
120+
self.mini_batch_size = mini_batch_size
121+
122+
def fit(self, records, mini_batch_size=None, wait=True, logs=True, job_name=None):
96123
"""Fit this Estimator on serialized Record objects, stored in S3.
97124
98125
``records`` should be an instance of :class:`~RecordSet`. This defines a collection of
99-
s3 data files to train this ``Estimator`` on.
126+
S3 data files to train this ``Estimator`` on.
100127
101128
Training data is expected to be encoded as dense or sparse vectors in the "values" feature
102129
on each Record. If the data is labeled, the label is expected to be encoded as a list of
@@ -110,15 +137,19 @@ def fit(self, records, mini_batch_size=None, **kwargs):
110137
111138
Args:
112139
records (:class:`~RecordSet`): The records to train this ``Estimator`` on
113-
mini_batch_size (int or None): The size of each mini-batch to use when training. If None, a
140+
mini_batch_size (int or None): The size of each mini-batch to use when training. If ``None``, a
114141
default value will be used.
142+
wait (bool): Whether the call should wait until the job completes (default: True).
143+
logs (bool): Whether to show the logs produced by the job.
144+
Only meaningful when wait is True (default: True).
145+
job_name (str): Training job name. If not specified, the estimator generates a default job name,
146+
based on the training image name and current timestamp.
115147
"""
116-
self.feature_dim = records.feature_dim
117-
self.mini_batch_size = mini_batch_size
148+
self._prepare_for_training(records, job_name=job_name, mini_batch_size=mini_batch_size)
118149

119-
data = {records.channel: s3_input(records.s3_data, distribution='ShardedByS3Key',
120-
s3_data_type=records.s3_data_type)}
121-
super(AmazonAlgorithmEstimatorBase, self).fit(data, **kwargs)
150+
self.latest_training_job = _TrainingJob.start_new(self, records)
151+
if wait:
152+
self.latest_training_job.wait(logs=logs)
122153

123154
def record_set(self, train, labels=None, channel="train"):
124155
"""Build a :class:`~RecordSet` from a numpy :class:`~ndarray` matrix and label vector.
@@ -180,6 +211,14 @@ def __repr__(self):
180211
"""Return an unambiguous representation of this RecordSet"""
181212
return str((RecordSet, self.__dict__))
182213

214+
def data_channel(self):
215+
"""Return a dictionary to represent the training data in a channel for use with ``fit()``"""
216+
return {self.channel: self.records_s3_input()}
217+
218+
def records_s3_input(self):
219+
"""Return an s3_input to represent the training data"""
220+
return s3_input(self.s3_data, distribution='ShardedByS3Key', s3_data_type=self.s3_data_type)
221+
183222

184223
def _build_shards(num_shards, array):
185224
if num_shards < 1:

src/sagemaker/amazon/hyperparameter.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ def validate(self, value):
4646
raise ValueError(error_message)
4747

4848
def __get__(self, obj, objtype):
49-
"""Return the value of this hyperparameter"""
5049
if '_hyperparameters' not in dir(obj) or self.name not in obj._hyperparameters:
5150
raise AttributeError()
5251
return obj._hyperparameters[self.name]

src/sagemaker/amazon/kmeans.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,8 @@ def create_model(self):
108108
S3 model data produced by this Estimator."""
109109
return KMeansModel(self.model_data, self.role, self.sagemaker_session)
110110

111-
def fit(self, records, mini_batch_size=5000, **kwargs):
112-
super(KMeans, self).fit(records, mini_batch_size, **kwargs)
111+
def _prepare_for_training(self, records, mini_batch_size=5000, job_name=None):
112+
super(KMeans, self)._prepare_for_training(records, mini_batch_size=mini_batch_size, job_name=job_name)
113113

114114
def hyperparameters(self):
115115
"""Return the SageMaker hyperparameters for training this KMeans Estimator"""

src/sagemaker/amazon/lda.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,11 +93,12 @@ def create_model(self):
9393

9494
return LDAModel(self.model_data, self.role, sagemaker_session=self.sagemaker_session)
9595

96-
def fit(self, records, mini_batch_size, **kwargs):
96+
def _prepare_for_training(self, records, mini_batch_size, job_name=None):
9797
# mini_batch_size is required, prevent explicit calls with None
9898
if mini_batch_size is None:
9999
raise ValueError("mini_batch_size must be set")
100-
super(LDA, self).fit(records, mini_batch_size, **kwargs)
100+
101+
super(LDA, self)._prepare_for_training(records, mini_batch_size=mini_batch_size, job_name=job_name)
101102

102103

103104
class LDAPredictor(RealTimePredictor):

src/sagemaker/amazon/linear_learner.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -228,12 +228,23 @@ def create_model(self):
228228

229229
return LinearLearnerModel(self.model_data, self.role, self.sagemaker_session)
230230

231-
def fit(self, records, mini_batch_size=None, **kwargs):
231+
def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):
232+
num_records = None
233+
if isinstance(records, list):
234+
for record in records:
235+
if record.channel == 'train':
236+
num_records = record.num_records
237+
break
238+
if num_records is None:
239+
raise ValueError('Must provide train channel.')
240+
else:
241+
num_records = records.num_records
242+
232243
# mini_batch_size can't be greater than number of records or training job fails
233244
default_mini_batch_size = min(self.DEFAULT_MINI_BATCH_SIZE,
234-
max(1, int(records.num_records / self.train_instance_count)))
245+
max(1, int(num_records / self.train_instance_count)))
235246
use_mini_batch_size = mini_batch_size or default_mini_batch_size
236-
super(LinearLearner, self).fit(records, use_mini_batch_size, **kwargs)
247+
super(LinearLearner, self)._prepare_for_training(records, mini_batch_size=use_mini_batch_size, job_name=job_name)
237248

238249

239250
class LinearLearnerPredictor(RealTimePredictor):

src/sagemaker/amazon/ntm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,10 @@ def create_model(self):
113113

114114
return NTMModel(self.model_data, self.role, sagemaker_session=self.sagemaker_session)
115115

116-
def fit(self, records, mini_batch_size=None, **kwargs):
116+
def _prepare_for_training(self, records, mini_batch_size, job_name=None):
117117
if mini_batch_size is not None and (mini_batch_size < 1 or mini_batch_size > 10000):
118118
raise ValueError("mini_batch_size must be in [1, 10000]")
119-
super(NTM, self).fit(records, mini_batch_size, **kwargs)
119+
super(NTM, self)._prepare_for_training(records, mini_batch_size=mini_batch_size, job_name=job_name)
120120

121121

122122
class NTMPredictor(RealTimePredictor):

src/sagemaker/amazon/pca.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,12 +92,33 @@ def create_model(self):
9292

9393
return PCAModel(self.model_data, self.role, sagemaker_session=self.sagemaker_session)
9494

95-
def fit(self, records, mini_batch_size=None, **kwargs):
95+
def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):
96+
"""Set hyperparameters needed for training.
97+
98+
Args:
99+
* records (:class:`~RecordSet`): The records to train this ``Estimator`` on.
100+
* mini_batch_size (int or None): The size of each mini-batch to use when training. If ``None``, a
101+
default value will be used.
102+
* job_name (str): Name of the training job to be created. If not specified, one is generated,
103+
using the base name given to the constructor if applicable.
104+
"""
105+
num_records = None
106+
if isinstance(records, list):
107+
for record in records:
108+
if record.channel == 'train':
109+
num_records = record.num_records
110+
break
111+
if num_records is None:
112+
raise ValueError('Must provide train channel.')
113+
else:
114+
num_records = records.num_records
115+
96116
# mini_batch_size is a required parameter
97117
default_mini_batch_size = min(self.DEFAULT_MINI_BATCH_SIZE,
98-
max(1, int(records.num_records / self.train_instance_count)))
118+
max(1, int(num_records / self.train_instance_count)))
99119
use_mini_batch_size = mini_batch_size or default_mini_batch_size
100-
super(PCA, self).fit(records, use_mini_batch_size, **kwargs)
120+
121+
super(PCA, self)._prepare_for_training(records=records, mini_batch_size=use_mini_batch_size, job_name=job_name)
101122

102123

103124
class PCAPredictor(RealTimePredictor):

src/sagemaker/amazon/randomcutforest.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,13 +87,11 @@ def create_model(self):
8787

8888
return RandomCutForestModel(self.model_data, self.role, sagemaker_session=self.sagemaker_session)
8989

90-
def fit(self, records, mini_batch_size=None, **kwargs):
91-
if mini_batch_size is None:
92-
mini_batch_size = RandomCutForest.MINI_BATCH_SIZE
93-
elif mini_batch_size != RandomCutForest.MINI_BATCH_SIZE:
90+
def _prepare_for_training(self, records, mini_batch_size=MINI_BATCH_SIZE, job_name=None):
91+
if mini_batch_size != self.MINI_BATCH_SIZE:
9492
raise ValueError("Random Cut Forest uses a fixed mini_batch_size of {}"
95-
.format(RandomCutForest.MINI_BATCH_SIZE))
96-
super(RandomCutForest, self).fit(records, mini_batch_size, **kwargs)
93+
.format(self.MINI_BATCH_SIZE))
94+
super(RandomCutForest, self)._prepare_for_training(records, mini_batch_size=mini_batch_size, job_name=job_name)
9795

9896

9997
class RandomCutForestPredictor(RealTimePredictor):

0 commit comments

Comments
 (0)