Commit 42f9de8

Add option for not including estimator metadata in hyperparameter tuning job (#237)
Using an Amazon ML algorithm with the generic Estimator revealed that the class can't be used with an algorithm that won't accept extra (unrecognized) hyperparameters. Since that generic class was created primarily for use with the Amazon ML algorithms that we don't have custom estimators for, this change adds a new kwarg for not injecting the estimator class and module into a hyperparameter tuning job. This change also includes an integ test for the BYO estimator case.
1 parent 6ed1a77 commit 42f9de8
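To make the new behavior concrete, here is a minimal sketch of the opt-out from the caller's side (the image name, role, and S3 paths are placeholders; the tuner arguments mirror the integ test below):

    from sagemaker.estimator import Estimator
    from sagemaker.tuner import HyperparameterTuner, IntegerParameter

    # Generic Estimator wrapping an Amazon ML algorithm image that has no
    # custom estimator class in the SDK (placeholder image/role names).
    estimator = Estimator(image_name='<account>.dkr.ecr.us-west-2.amazonaws.com/factorization-machines:1',
                          role='SageMakerRole', train_instance_count=1,
                          train_instance_type='ml.c4.xlarge')
    estimator.set_hyperparameters(num_factors=10, feature_dim=784,
                                  mini_batch_size=100, predictor_type='binary_classifier')

    tuner = HyperparameterTuner(estimator=estimator,
                                objective_metric_name='test:binary_classification_accuracy',
                                hyperparameter_ranges={'mini_batch_size': IntegerParameter(100, 200)},
                                max_jobs=2, max_parallel_jobs=2)

    # New kwarg: skip injecting sagemaker_estimator_class_name/_module, which
    # strict 1P algorithms would reject as unrecognized hyperparameters.
    tuner.fit({'train': 's3://my_bucket/train', 'test': 's3://my_bucket/test'},
              include_cls_metadata=False)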

File tree

  CHANGELOG.rst
  README.rst
  src/sagemaker/tuner.py
  tests/integ/test_tuner.py
  tests/unit/test_tuner.py

5 files changed: +116 -7 lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@ CHANGELOG
 1.4.3dev
 ========
 * feature: Allow Local Serving of Models in S3
+* enhancement: Allow option for ``HyperparameterTuner`` to not include estimator metadata in job


 1.4.2

README.rst

Lines changed: 8 additions & 0 deletions
@@ -321,6 +321,14 @@ In addition, the ``fit()`` call uses a list of ``RecordSet`` objects instead of
     # Start hyperparameter tuning job
     my_tuner.fit([train_records, test_records])

+To aid with attaching a previously-started hyperparameter tuning job to a ``HyperparameterTuner`` instance, ``fit()`` injects metadata into the hyperparameters by default.
+If the algorithm you are using cannot handle unknown hyperparameters (e.g. an Amazon ML algorithm that does not have a custom estimator in the Python SDK), then you can set ``include_cls_metadata`` to ``False`` when calling ``fit()``:
+
+.. code:: python
+
+    my_tuner.fit({'train': 's3://my_bucket/my_training_data', 'test': 's3://my_bucket/my_testing_data'},
+                 include_cls_metadata=False)
+
 There is also an analytics object associated with each ``HyperparameterTuner`` instance that presents useful information about the hyperparameter tuning job.
 For example, the ``dataframe`` method gets a pandas dataframe summarizing the associated training jobs:

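Related context (not part of this diff): ``attach()`` is the consumer of this metadata, so opting out shifts responsibility to the caller at attach time. A hedged sketch, assuming the ``estimator_cls`` parameter of ``HyperparameterTuner.attach()`` (a dotted class-path string in this SDK version) and a placeholder job name:

    from sagemaker.tuner import HyperparameterTuner

    # 'my-byo-tuning-job' is a placeholder name. Since the job was started with
    # include_cls_metadata=False, its hyperparameters carry no
    # sagemaker_estimator_* entries, so we name the estimator class explicitly.
    tuner = HyperparameterTuner.attach('my-byo-tuning-job',
                                       estimator_cls='sagemaker.estimator.Estimator')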

src/sagemaker/tuner.py

Lines changed: 4 additions & 4 deletions
@@ -204,7 +204,7 @@ def __init__(self, estimator, objective_metric_name, hyperparameter_ranges, metr
         self._current_job_name = None
         self.latest_tuning_job = None

-    def _prepare_for_training(self, job_name=None):
+    def _prepare_for_training(self, job_name=None, include_cls_metadata=True):
         if job_name is not None:
             self._current_job_name = job_name
         else:
@@ -217,12 +217,12 @@ def _prepare_for_training(self, job_name=None):

         # For attach() to know what estimator to use for non-1P algorithms
         # (1P algorithms don't accept extra hyperparameters)
-        if not isinstance(self.estimator, AmazonAlgorithmEstimatorBase):
+        if include_cls_metadata and not isinstance(self.estimator, AmazonAlgorithmEstimatorBase):
             self.static_hyperparameters[self.SAGEMAKER_ESTIMATOR_CLASS_NAME] = json.dumps(
                 self.estimator.__class__.__name__)
             self.static_hyperparameters[self.SAGEMAKER_ESTIMATOR_MODULE] = json.dumps(self.estimator.__module__)

-    def fit(self, inputs, job_name=None, **kwargs):
+    def fit(self, inputs, job_name=None, include_cls_metadata=True, **kwargs):
         """Start a hyperparameter tuning job.

         Args:
@@ -253,7 +253,7 @@ def fit(self, inputs, job_name=None, **kwargs):
         else:
             self.estimator._prepare_for_training(job_name)

-        self._prepare_for_training(job_name=job_name)
+        self._prepare_for_training(job_name=job_name, include_cls_metadata=include_cls_metadata)
         self.latest_tuning_job = _TuningJob.start_new(self, inputs)

     @classmethod
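A toy reproduction of the guarded branch above (not SDK code) showing exactly which keys ``include_cls_metadata`` controls:

    import json

    class MyEstimator(object):
        """Stand-in for a generic estimator; only __class__ and __module__ are read."""

    def inject_metadata(estimator, static_hyperparameters, include_cls_metadata=True):
        # Mirrors HyperparameterTuner._prepare_for_training: only write the two
        # sagemaker_estimator_* keys when the caller hasn't opted out.
        if include_cls_metadata:
            static_hyperparameters['sagemaker_estimator_class_name'] = json.dumps(
                estimator.__class__.__name__)
            static_hyperparameters['sagemaker_estimator_module'] = json.dumps(estimator.__module__)

    hps = {}
    inject_metadata(MyEstimator(), hps)
    print(hps)  # {'sagemaker_estimator_class_name': '"MyEstimator"', 'sagemaker_estimator_module': '"__main__"'}

    hps = {}
    inject_metadata(MyEstimator(), hps, include_cls_metadata=False)
    print(hps)  # {}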

tests/integ/test_tuner.py

Lines changed: 88 additions & 3 deletions
@@ -13,19 +13,24 @@
 from __future__ import absolute_import

 import gzip
+import io
+import json
 import os
 import pickle
 import sys
 import time

+import boto3
 import numpy as np
 import pytest

-from sagemaker import LDA, RandomCutForest
-from sagemaker.amazon.common import read_records
-from sagemaker.amazon.kmeans import KMeans
+from sagemaker import KMeans, LDA, RandomCutForest
+from sagemaker.amazon.amazon_estimator import registry
+from sagemaker.amazon.common import read_records, write_numpy_to_dense_tensor
 from sagemaker.chainer import Chainer
+from sagemaker.estimator import Estimator
 from sagemaker.mxnet.estimator import MXNet
+from sagemaker.predictor import json_deserializer
 from sagemaker.tensorflow import TensorFlow
 from sagemaker.tuner import IntegerParameter, ContinuousParameter, CategoricalParameter, HyperparameterTuner
 from tests.integ import DATA_DIR
@@ -307,3 +312,83 @@ def test_tuning_chainer(sagemaker_session):
         data = np.zeros((batch_size, 28, 28), dtype='float32')
         output = predictor.predict(data)
         assert len(output) == batch_size
+
+
+@pytest.mark.continuous_testing
+def test_tuning_byo_estimator(sagemaker_session):
+    """Use the Factorization Machines algorithm as an example here.
+
+    First we need to prepare data for training. We take a standard data set, convert it to the
+    format that the algorithm can process, and upload it to S3.
+    Then we create the Estimator and set the hyperparameters as required by the algorithm.
+    Next, we can call fit() with the S3 path.
+    Later the trained model is deployed and prediction is called against the endpoint.
+    The default predictor is updated with a JSON serializer and deserializer.
+    """
+    image_name = registry(sagemaker_session.boto_session.region_name) + '/factorization-machines:1'
+
+    with timeout(minutes=15):
+        data_path = os.path.join(DATA_DIR, 'one_p_mnist', 'mnist.pkl.gz')
+        pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'}
+
+        with gzip.open(data_path, 'rb') as f:
+            train_set, _, _ = pickle.load(f, **pickle_args)
+
+        # take 100 examples for faster execution
+        vectors = np.array([t.tolist() for t in train_set[0][:100]]).astype('float32')
+        labels = np.where(np.array([t.tolist() for t in train_set[1][:100]]) == 0, 1.0, 0.0).astype('float32')
+
+        buf = io.BytesIO()
+        write_numpy_to_dense_tensor(buf, vectors, labels)
+        buf.seek(0)

+        bucket = sagemaker_session.default_bucket()
+        prefix = 'test_byo_estimator'
+        key = 'recordio-pb-data'
+        boto3.resource('s3').Bucket(bucket).Object(os.path.join(prefix, 'train', key)).upload_fileobj(buf)
+        s3_train_data = 's3://{}/{}/train/{}'.format(bucket, prefix, key)
+
+        estimator = Estimator(image_name=image_name,
+                              role='SageMakerRole', train_instance_count=1,
+                              train_instance_type='ml.c4.xlarge',
+                              sagemaker_session=sagemaker_session, base_job_name='test-byo')
+
+        estimator.set_hyperparameters(num_factors=10,
+                                      feature_dim=784,
+                                      mini_batch_size=100,
+                                      predictor_type='binary_classifier')
+
+        hyperparameter_ranges = {'mini_batch_size': IntegerParameter(100, 200)}
+
+        tuner = HyperparameterTuner(estimator=estimator, base_tuning_job_name='byo',
+                                    objective_metric_name='test:binary_classification_accuracy',
+                                    hyperparameter_ranges=hyperparameter_ranges,
+                                    max_jobs=2, max_parallel_jobs=2)
+
+        tuner.fit({'train': s3_train_data, 'test': s3_train_data}, include_cls_metadata=False)
+
+        print('Started hyperparameter tuning job with name:' + tuner.latest_tuning_job.name)
+
+        time.sleep(15)
+        tuner.wait()
+
+    best_training_job = tuner.best_training_job()
+    with timeout_and_delete_endpoint_by_name(best_training_job, sagemaker_session):
+        predictor = tuner.deploy(1, 'ml.m4.xlarge', endpoint_name=best_training_job)
+        predictor.serializer = _fm_serializer
+        predictor.content_type = 'application/json'
+        predictor.deserializer = json_deserializer
+
+        result = predictor.predict(train_set[0][:10])
+
+        assert len(result['predictions']) == 10
+        for prediction in result['predictions']:
+            assert prediction['score'] is not None
+
+
+# Serializer for the Factorization Machines predictor (for BYO example)
+def _fm_serializer(data):
+    js = {'instances': []}
+    for row in data:
+        js['instances'].append({'features': row.tolist()})
+    return json.dumps(js)
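As a quick sanity check of the payload format (a sketch calling ``_fm_serializer`` as defined above on a small numpy array):

    import numpy as np

    print(_fm_serializer(np.array([[1.0, 2.0], [3.0, 4.0]])))
    # {"instances": [{"features": [1.0, 2.0]}, {"features": [3.0, 4.0]}]}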

tests/unit/test_tuner.py

Lines changed: 15 additions & 0 deletions
@@ -159,6 +159,21 @@ def test_prepare_for_training(tuner):
     assert tuner.static_hyperparameters['sagemaker_estimator_module'] == module


+def test_prepare_for_training_with_amazon_estimator(tuner, sagemaker_session):
+    tuner.estimator = PCA(ROLE, TRAIN_INSTANCE_COUNT, TRAIN_INSTANCE_TYPE, NUM_COMPONENTS,
+                          sagemaker_session=sagemaker_session)
+
+    tuner._prepare_for_training()
+    assert 'sagemaker_estimator_class_name' not in tuner.static_hyperparameters
+    assert 'sagemaker_estimator_module' not in tuner.static_hyperparameters
+
+
+def test_prepare_for_training_dont_include_estimator_cls(tuner):
+    tuner._prepare_for_training(include_cls_metadata=False)
+    assert 'sagemaker_estimator_class_name' not in tuner.static_hyperparameters
+    assert 'sagemaker_estimator_module' not in tuner.static_hyperparameters
+
+
 def test_prepare_for_training_with_job_name(tuner):
     static_hyperparameters = {'validated': 1, 'another_one': 0}
     tuner.estimator.set_hyperparameters(**static_hyperparameters)
