Skip to content

Commit f9838b6

Browse files
Fix multiple channel (aws#45)
1 parent f9c460a commit f9838b6

File tree

9 files changed

+132
-18
lines changed

9 files changed

+132
-18
lines changed

src/sagemaker/amazon/amazon_estimator.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from sagemaker.amazon import validation
2020
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
2121
from sagemaker.amazon.common import write_numpy_to_dense_tensor
22-
from sagemaker.estimator import EstimatorBase
22+
from sagemaker.estimator import EstimatorBase, _TrainingJob
2323
from sagemaker.session import s3_input
2424
from sagemaker.utils import sagemaker_timestamp
2525

@@ -104,10 +104,22 @@ def prepare_for_training(self, records, mini_batch_size=None, job_name=None):
104104
"""
105105
super(AmazonAlgorithmEstimatorBase, self).prepare_for_training(job_name=job_name)
106106

107-
self.feature_dim = records.feature_dim
107+
feature_dim = None
108+
109+
if isinstance(records, list):
110+
for record in records:
111+
if record.channel == 'train':
112+
feature_dim = record.feature_dim
113+
break
114+
if feature_dim is None:
115+
raise ValueError('Must provide train channel.')
116+
else:
117+
feature_dim = records.feature_dim
118+
119+
self.feature_dim = feature_dim
108120
self.mini_batch_size = mini_batch_size
109121

110-
def fit(self, records, mini_batch_size=None, **kwargs):
122+
def fit(self, records, mini_batch_size=None, wait=True, logs=True, job_name=None):
111123
"""Fit this Estimator on serialized Record objects, stored in S3.
112124
113125
``records`` should be an instance of :class:`~RecordSet`. This defines a collection of
@@ -127,9 +139,17 @@ def fit(self, records, mini_batch_size=None, **kwargs):
127139
records (:class:`~RecordSet`): The records to train this ``Estimator`` on
128140
mini_batch_size (int or None): The size of each mini-batch to use when training. If ``None``, a
129141
default value will be used.
142+
wait (bool): Whether the call should wait until the job completes (default: True).
143+
logs (bool): Whether to show the logs produced by the job.
144+
Only meaningful when wait is True (default: True).
145+
job_name (str): Training job name. If not specified, the estimator generates a default job name,
146+
based on the training image name and current timestamp.
130147
"""
131-
super(AmazonAlgorithmEstimatorBase, self).fit(records.data_channel(), records=records,
132-
mini_batch_size=mini_batch_size, **kwargs)
148+
self.prepare_for_training(records, job_name=job_name, mini_batch_size=mini_batch_size)
149+
150+
self.latest_training_job = _TrainingJob.start_new(self, records)
151+
if wait:
152+
self.latest_training_job.wait(logs=logs)
133153

134154
def record_set(self, train, labels=None, channel="train"):
135155
"""Build a :class:`~RecordSet` from a numpy :class:`~ndarray` matrix and label vector.
@@ -193,8 +213,11 @@ def __repr__(self):
193213

194214
def data_channel(self):
195215
"""Return a dictionary to represent the training data in a channel for use with ``fit()``"""
196-
return {self.channel: s3_input(self.s3_data, distribution='ShardedByS3Key',
197-
s3_data_type=self.s3_data_type)}
216+
return {self.channel: self.records_s3_input()}
217+
218+
def records_s3_input(self):
219+
"""Return a s3_input to represent the training data"""
220+
return s3_input(self.s3_data, distribution='ShardedByS3Key', s3_data_type=self.s3_data_type)
198221

199222

200223
def _build_shards(num_shards, array):

src/sagemaker/amazon/pca.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,20 @@ def prepare_for_training(self, records, mini_batch_size=None, job_name=None):
102102
* job_name (str): Name of the training job to be created. If not specified, one is generated,
103103
using the base name given to the constructor if applicable.
104104
"""
105+
num_records = None
106+
if isinstance(records, list):
107+
for record in records:
108+
if record.channel == 'train':
109+
num_records = record.num_records
110+
break
111+
if num_records is None:
112+
raise ValueError('Must provide train channel.')
113+
else:
114+
num_records = records.num_records
115+
105116
# mini_batch_size is a required parameter
106117
default_mini_batch_size = min(self.DEFAULT_MINI_BATCH_SIZE,
107-
max(1, int(records.num_records / self.train_instance_count)))
118+
max(1, int(num_records / self.train_instance_count)))
108119
use_mini_batch_size = mini_batch_size or default_mini_batch_size
109120

110121
super(PCA, self).prepare_for_training(records=records, mini_batch_size=use_mini_batch_size, job_name=job_name)

src/sagemaker/estimator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def prepare_for_training(self, job_name=None):
144144
else:
145145
self.output_path = 's3://{}/'.format(self.sagemaker_session.default_bucket())
146146

147-
def fit(self, inputs, wait=True, logs=True, job_name=None, **kwargs):
147+
def fit(self, inputs, wait=True, logs=True, job_name=None):
148148
"""Train a model using the input training dataset.
149149
150150
The API calls the Amazon SageMaker CreateTrainingJob API to start model training.
@@ -172,7 +172,7 @@ def fit(self, inputs, wait=True, logs=True, job_name=None, **kwargs):
172172
job_name (str): Training job name. If not specified, the estimator generates a default job name,
173173
based on the training image name and current timestamp.
174174
"""
175-
self.prepare_for_training(job_name=job_name, **kwargs)
175+
self.prepare_for_training(job_name=job_name)
176176

177177
self.latest_training_job = _TrainingJob.start_new(self, inputs)
178178
if wait:

src/sagemaker/job.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,11 @@ def _load_config(inputs, estimator):
6868

6969
@staticmethod
7070
def _format_inputs_to_input_config(inputs):
71+
# Circular dependency
72+
from sagemaker.amazon.amazon_estimator import RecordSet
73+
if isinstance(inputs, RecordSet):
74+
inputs = inputs.data_channel()
75+
7176
input_dict = {}
7277
if isinstance(inputs, string_types):
7378
input_dict['training'] = _Job._format_string_uri_input(inputs)
@@ -78,6 +83,15 @@ def _format_inputs_to_input_config(inputs):
7883
elif isinstance(inputs, dict):
7984
for k, v in inputs.items():
8085
input_dict[k] = _Job._format_string_uri_input(v)
86+
elif isinstance(inputs, list):
87+
for record in inputs:
88+
if not isinstance(record, RecordSet):
89+
raise ValueError('List compatible only with RecordSets.')
90+
91+
if record.channel in input_dict:
92+
raise ValueError('Duplicate channels not allowed.')
93+
94+
input_dict[record.channel] = record.records_s3_input()
8195
else:
8296
raise ValueError(
8397
'Cannot format input {}. Expecting one of str, dict or s3_input'.format(inputs))

src/sagemaker/tuner.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,15 +133,14 @@ def fit(self, inputs, job_name=None, **kwargs):
133133
134134
Args:
135135
inputs (str): Parameters used when called :meth:`~sagemaker.estimator.EstimatorBase.fit`.
136-
job_name (str): Job name
136+
job_name (str): Tuning job name. If not specified, the tuner generates a default job name,
137+
based on the training image name and current timestamp.
137138
**kwargs: Other arguments
138139
"""
139-
# 1P estimators require a RecordSet object
140-
if isinstance(inputs, RecordSet):
140+
if isinstance(inputs, list) or isinstance(inputs, RecordSet):
141141
self.estimator.prepare_for_training(inputs, **kwargs)
142-
inputs = inputs.data_channel()
143142
else:
144-
self.estimator.prepare_for_training(**kwargs)
143+
self.estimator.prepare_for_training(job_name)
145144

146145
self.prepare_for_training(job_name=job_name)
147146
self.latest_tuning_job = _TuningJob.start_new(self, inputs)

tests/integ/test_tuner.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ def test_fit_1p(sagemaker_session):
4646
kmeans.half_life_time_size = 1
4747
kmeans.epochs = 1
4848

49-
records = kmeans.record_set(train_set[0][:100], channel='test')
49+
records = kmeans.record_set(train_set[0][:100])
50+
test_records = kmeans.record_set(train_set[0][:100], channel='test')
5051

5152
# specify which hp you want to optimize over
5253
hyperparameter_ranges = {'extra_center_factor': IntegerParameter(1, 10),
@@ -59,7 +60,7 @@ def test_fit_1p(sagemaker_session):
5960
hyperparameter_ranges=hyperparameter_ranges, objective_type='Minimize', max_jobs=2,
6061
max_parallel_jobs=2)
6162

62-
tuner.fit(records)
63+
tuner.fit([records, test_records])
6364

6465
print('Started HPO job with name:' + tuner.latest_tuning_job.name)
6566

tests/unit/test_amazon_estimator.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,31 @@ def test_prepare_for_training():
110110
assert pca.mini_batch_size == 1
111111

112112

113+
def test_prepare_for_training_list():
114+
pca = PCA(num_components=55, **COMMON_ARGS)
115+
116+
train = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 8.0], [44.0, 55.0, 66.0]]
117+
labels = [99, 85, 87, 2]
118+
records = [pca.record_set(np.array(train), np.array(labels))]
119+
120+
pca.prepare_for_training(records, mini_batch_size=1)
121+
assert pca.feature_dim == 3
122+
assert pca.mini_batch_size == 1
123+
124+
125+
def test_prepare_for_training_list_no_train_channel():
126+
pca = PCA(num_components=55, **COMMON_ARGS)
127+
128+
train = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 8.0], [44.0, 55.0, 66.0]]
129+
labels = [99, 85, 87, 2]
130+
records = [pca.record_set(np.array(train), np.array(labels), 'test')]
131+
132+
with pytest.raises(ValueError) as ex:
133+
pca.prepare_for_training(records, mini_batch_size=1)
134+
135+
assert 'Must provide train channel.' in str(ex)
136+
137+
113138
@patch('time.strftime', return_value=TIMESTAMP)
114139
def test_fit_ndarray(time, sagemaker_session):
115140
mock_s3 = Mock()

tests/unit/test_job.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import pytest
1616
from mock import Mock
1717

18+
from sagemaker.amazon.amazon_estimator import RecordSet
1819
from sagemaker.estimator import Estimator
1920
from sagemaker.job import _Job
2021
from sagemaker.session import s3_input
@@ -86,6 +87,45 @@ def test_format_inputs_to_input_config_dict():
8687
assert channels[0]['DataSource']['S3DataSource']['S3Uri'] == inputs['train']
8788

8889

90+
def test_format_inputs_to_input_config_record_set():
91+
inputs = RecordSet(s3_data=BUCKET_NAME, num_records=1, feature_dim=1)
92+
93+
channels = _Job._format_inputs_to_input_config(inputs)
94+
95+
assert channels[0]['DataSource']['S3DataSource']['S3Uri'] == inputs.s3_data
96+
assert channels[0]['DataSource']['S3DataSource']['S3DataType'] == inputs.s3_data_type
97+
98+
99+
def test_format_inputs_to_input_config_list():
100+
records = RecordSet(s3_data=BUCKET_NAME, num_records=1, feature_dim=1)
101+
inputs = [records]
102+
103+
channels = _Job._format_inputs_to_input_config(inputs)
104+
105+
assert channels[0]['DataSource']['S3DataSource']['S3Uri'] == records.s3_data
106+
assert channels[0]['DataSource']['S3DataSource']['S3DataType'] == records.s3_data_type
107+
108+
109+
def test_format_inputs_to_input_config_list_not_all_records():
110+
records = RecordSet(s3_data=BUCKET_NAME, num_records=1, feature_dim=1)
111+
inputs = [records, 'mock']
112+
113+
with pytest.raises(ValueError) as ex:
114+
_Job._format_inputs_to_input_config(inputs)
115+
116+
assert 'List compatible only with RecordSets.' in str(ex)
117+
118+
119+
def test_format_inputs_to_input_config_list_duplicate_channel():
120+
record = RecordSet(s3_data=BUCKET_NAME, num_records=1, feature_dim=1)
121+
inputs = [record, record]
122+
123+
with pytest.raises(ValueError) as ex:
124+
_Job._format_inputs_to_input_config(inputs)
125+
126+
assert 'Duplicate channels not allowed.' in str(ex)
127+
128+
89129
def test_format_input_single_unamed_channel():
90130
input_dict = _Job._format_inputs_to_input_config('s3://blah/blah')
91131
assert input_dict == [{

tests/unit/test_tuner.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,14 +207,15 @@ def test_fit_1p(sagemaker_session, tuner):
207207
tuner._hyperparameter_ranges = hyperparameter_ranges
208208

209209
records = RecordSet(s3_data=INPUTS, num_records=1, feature_dim=1)
210-
tuner.fit(records)
210+
tuner.fit(records, mini_batch_size=9999)
211211

212212
_, _, tune_kwargs = sagemaker_session.tune.mock_calls[0]
213213

214214
assert len(tune_kwargs['static_hyperparameters']) == 4
215215
assert tune_kwargs['static_hyperparameters']['extra_components'] == '5'
216216
assert len(tune_kwargs['parameter_ranges']['IntegerParameterRanges']) == 1
217217
assert tune_kwargs['job_name'].startswith('pca')
218+
assert tuner.estimator.mini_batch_size == 9999
218219

219220

220221
def test_attach_tuning_job_with_estimator_from_hyperparameters(sagemaker_session):

0 commit comments

Comments
 (0)