Skip to content

Commit 163bffd

Browse files
nadiayajesterhazy
authored andcommitted
Make InputDataConfig optional for training. (#459)
* Make InputDataConfig optional for training. * Update boto3 dependency to make sure boto support no InputDataConfig. * Update changelog. * Add missing assertion for chainer failure script test.
1 parent 868f81b commit 163bffd

File tree

12 files changed

+82
-33
lines changed

12 files changed

+82
-33
lines changed

CHANGELOG.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22
CHANGELOG
33
=========
44

5+
1.13.1.dev
6+
==========
7+
8+
* feature: Estimator: make input channels optional
9+
10+
511
1.13.0
612
======
713

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def read(fname):
5353
],
5454

5555
# Declare minimal set for installation
56-
install_requires=['boto3>=1.4.8', 'numpy>=1.9.0', 'protobuf>=3.1', 'scipy>=0.19.0',
56+
install_requires=['boto3>=1.9.38', 'numpy>=1.9.0', 'protobuf>=3.1', 'scipy>=0.19.0',
5757
'urllib3 >=1.21, <1.23',
5858
'PyYAML>=3.2', 'protobuf3-to-dict>=0.1.5', 'docker-compose>=1.21.0'],
5959

src/sagemaker/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def _prepare_for_training(self, job_name=None):
176176
else:
177177
self.output_path = 's3://{}/'.format(self.sagemaker_session.default_bucket())
178178

179-
def fit(self, inputs, wait=True, logs=True, job_name=None):
179+
def fit(self, inputs=None, wait=True, logs=True, job_name=None):
180180
"""Train a model using the input training dataset.
181181
182182
The API calls the Amazon SageMaker CreateTrainingJob API to start model training.

src/sagemaker/job.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def _load_config(inputs, estimator):
6464

6565
model_channel = _Job._prepare_model_channel(input_config, estimator.model_uri, estimator.model_channel_name)
6666
if model_channel:
67+
input_config = [] if input_config is None else input_config
6768
input_config.append(model_channel)
6869

6970
return {'input_config': input_config,
@@ -75,6 +76,9 @@ def _load_config(inputs, estimator):
7576

7677
@staticmethod
7778
def _format_inputs_to_input_config(inputs):
79+
if inputs is None:
80+
return None
81+
7882
# Deferred import due to circular dependency
7983
from sagemaker.amazon.amazon_estimator import RecordSet
8084
if isinstance(inputs, RecordSet):
@@ -130,9 +134,10 @@ def _prepare_model_channel(input_config, model_uri=None, model_channel_name=None
130134
elif not model_channel_name:
131135
raise ValueError('Expected a pre-trained model channel name if a model URL is specified.')
132136

133-
for channel in input_config:
134-
if channel['ChannelName'] == model_channel_name:
135-
raise ValueError('Duplicate channels not allowed.')
137+
if input_config:
138+
for channel in input_config:
139+
if channel['ChannelName'] == model_channel_name:
140+
raise ValueError('Duplicate channels not allowed.')
136141

137142
model_input = _Job._format_model_uri_input(model_uri)
138143
model_channel = _Job._convert_input_to_channel(model_channel_name, model_input)

src/sagemaker/session.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,14 +257,16 @@ def train(self, image, input_mode, input_config, role, job_name, output_config,
257257
'TrainingImage': image,
258258
'TrainingInputMode': input_mode
259259
},
260-
'InputDataConfig': input_config,
261260
'OutputDataConfig': output_config,
262261
'TrainingJobName': job_name,
263262
'StoppingCondition': stop_condition,
264263
'ResourceConfig': resource_config,
265264
'RoleArn': role,
266265
}
267266

267+
if input_config is not None:
268+
train_request['InputDataConfig'] = input_config
269+
268270
if hyperparameters and len(hyperparameters) > 0:
269271
train_request['HyperParameters'] = hyperparameters
270272

src/sagemaker/tensorflow/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def _validate_requirements_file(self, requirements_file):
207207
if not os.path.exists(os.path.join(self.source_dir, requirements_file)):
208208
raise ValueError('Requirements file {} does not exist.'.format(requirements_file))
209209

210-
def fit(self, inputs, wait=True, logs=True, job_name=None, run_tensorboard_locally=False):
210+
def fit(self, inputs=None, wait=True, logs=True, job_name=None, run_tensorboard_locally=False):
211211
"""Train a model using the input training dataset.
212212
213213
See :func:`~sagemaker.estimator.EstimatorBase.fit` for more details.

tests/integ/test_chainer_train.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -105,18 +105,15 @@ def test_async_fit(sagemaker_session):
105105
def test_failed_training_job(sagemaker_session, chainer_full_version):
106106
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
107107
script_path = os.path.join(DATA_DIR, 'chainer_mnist', 'failure_script.py')
108-
data_path = os.path.join(DATA_DIR, 'chainer_mnist')
109108

110109
chainer = Chainer(entry_point=script_path, role='SageMakerRole',
111110
framework_version=chainer_full_version, py_version=PYTHON_VERSION,
112111
train_instance_count=1, train_instance_type='ml.c4.xlarge',
113112
sagemaker_session=sagemaker_session)
114113

115-
train_input = chainer.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'),
116-
key_prefix='integ-test-data/chainer_mnist/train')
117-
118-
with pytest.raises(ValueError):
119-
chainer.fit(train_input)
114+
with pytest.raises(ValueError) as e:
115+
chainer.fit()
116+
assert 'This failure is expected' in str(e.value)
120117

121118

122119
def _run_mnist_training_job(sagemaker_session, instance_type, instance_count,

tests/integ/test_mxnet_train.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -105,15 +105,11 @@ def test_async_fit(sagemaker_session, mxnet_full_version):
105105
def test_failed_training_job(sagemaker_session, mxnet_full_version):
106106
with timeout():
107107
script_path = os.path.join(DATA_DIR, 'mxnet_mnist', 'failure_script.py')
108-
data_path = os.path.join(DATA_DIR, 'mxnet_mnist')
109108

110109
mx = MXNet(entry_point=script_path, role='SageMakerRole', framework_version=mxnet_full_version,
111110
py_version=PYTHON_VERSION, train_instance_count=1, train_instance_type='ml.c4.xlarge',
112111
sagemaker_session=sagemaker_session)
113112

114-
train_input = mx.sagemaker_session.upload_data(path=os.path.join(data_path, 'train'),
115-
key_prefix='integ-test-data/mxnet_mnist/train-failure')
116-
117113
with pytest.raises(ValueError) as e:
118-
mx.fit(train_input)
114+
mx.fit()
119115
assert 'This failure is expected' in str(e.value)

tests/integ/test_pytorch_train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def test_failed_training_job(sagemaker_session, pytorch_full_version):
106106
pytorch = _get_pytorch_estimator(sagemaker_session, pytorch_full_version, entry_point=script_path)
107107

108108
with pytest.raises(ValueError) as e:
109-
pytorch.fit(_upload_training_data(pytorch))
109+
pytorch.fit()
110110
assert 'This failure is expected' in str(e.value)
111111

112112

tests/integ/test_tf.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,6 @@ def test_failed_tf_training(sagemaker_session, tf_full_version):
160160
train_instance_type='ml.c4.xlarge',
161161
sagemaker_session=sagemaker_session)
162162

163-
inputs = estimator.sagemaker_session.upload_data(path=DATA_PATH, key_prefix='integ-test-data/tf-failure')
164-
165163
with pytest.raises(ValueError) as e:
166-
estimator.fit(inputs)
164+
estimator.fit()
167165
assert 'This failure is expected' in str(e.value)

tests/unit/test_estimator.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -706,19 +706,10 @@ def test_unsupported_type_in_dict():
706706
#################################################################################
707707
# Tests for the generic Estimator class
708708

709-
BASE_TRAIN_CALL = {
709+
NO_INPUT_TRAIN_CALL = {
710710
'hyperparameters': {},
711711
'image': IMAGE_NAME,
712-
'input_config': [{
713-
'DataSource': {
714-
'S3DataSource': {
715-
'S3DataDistributionType': 'FullyReplicated',
716-
'S3DataType': 'S3Prefix',
717-
'S3Uri': 's3://bucket/training-prefix'
718-
}
719-
},
720-
'ChannelName': 'train'
721-
}],
712+
'input_config': None,
722713
'input_mode': 'File',
723714
'output_config': {'S3OutputPath': OUTPUT_PATH},
724715
'resource_config': {
@@ -731,12 +722,43 @@ def test_unsupported_type_in_dict():
731722
'vpc_config': None
732723
}
733724

725+
INPUT_CONFIG = [{
726+
'DataSource': {
727+
'S3DataSource': {
728+
'S3DataDistributionType': 'FullyReplicated',
729+
'S3DataType': 'S3Prefix',
730+
'S3Uri': 's3://bucket/training-prefix'
731+
}
732+
},
733+
'ChannelName': 'train'
734+
}]
735+
736+
BASE_TRAIN_CALL = dict(NO_INPUT_TRAIN_CALL)
737+
BASE_TRAIN_CALL.update({'input_config': INPUT_CONFIG})
738+
734739
HYPERPARAMS = {'x': 1, 'y': 'hello'}
735740
STRINGIFIED_HYPERPARAMS = dict([(x, str(y)) for x, y in HYPERPARAMS.items()])
736741
HP_TRAIN_CALL = dict(BASE_TRAIN_CALL)
737742
HP_TRAIN_CALL.update({'hyperparameters': STRINGIFIED_HYPERPARAMS})
738743

739744

745+
def test_generic_to_fit_no_input(sagemaker_session):
746+
e = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, output_path=OUTPUT_PATH,
747+
sagemaker_session=sagemaker_session)
748+
749+
e.fit()
750+
751+
sagemaker_session.train.assert_called_once()
752+
assert len(sagemaker_session.train.call_args[0]) == 0
753+
args = sagemaker_session.train.call_args[1]
754+
assert args['job_name'].startswith(IMAGE_NAME)
755+
756+
args.pop('job_name')
757+
args.pop('role')
758+
759+
assert args == NO_INPUT_TRAIN_CALL
760+
761+
740762
def test_generic_to_fit_no_hps(sagemaker_session):
741763
e = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, output_path=OUTPUT_PATH,
742764
sagemaker_session=sagemaker_session)

tests/unit/test_job.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,29 @@ def test_load_config_with_model_channel(estimator):
8686
assert config['stop_condition']['MaxRuntimeInSeconds'] == MAX_RUNTIME
8787

8888

89+
def test_load_config_with_model_channel_no_inputs(estimator):
90+
estimator.model_uri = MODEL_URI
91+
estimator.model_channel_name = CHANNEL_NAME
92+
93+
config = _Job._load_config(inputs=None, estimator=estimator)
94+
95+
assert config['input_config'][0]['DataSource']['S3DataSource']['S3Uri'] == MODEL_URI
96+
assert config['input_config'][0]['ChannelName'] == CHANNEL_NAME
97+
assert config['role'] == ROLE
98+
assert config['output_config']['S3OutputPath'] == S3_OUTPUT_PATH
99+
assert 'KmsKeyId' not in config['output_config']
100+
assert config['resource_config']['InstanceCount'] == INSTANCE_COUNT
101+
assert config['resource_config']['InstanceType'] == INSTANCE_TYPE
102+
assert config['resource_config']['VolumeSizeInGB'] == VOLUME_SIZE
103+
assert config['stop_condition']['MaxRuntimeInSeconds'] == MAX_RUNTIME
104+
105+
106+
def test_format_inputs_none():
107+
channels = _Job._format_inputs_to_input_config(inputs=None)
108+
109+
assert channels is None
110+
111+
89112
def test_format_inputs_to_input_config_string():
90113
inputs = BUCKET_NAME
91114

0 commit comments

Comments
 (0)