Skip to content

Commit b6d5abf

Browse files
author
Ignacio Quintero
committed
Fix and add new Unit Tests, Flake 8
1 parent 1c6a8a6 commit b6d5abf

File tree

5 files changed

+113
-32
lines changed

5 files changed

+113
-32
lines changed

src/sagemaker/estimator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import print_function, absolute_import
1414

15-
import os
1615
import json
1716
import logging
1817
from abc import ABCMeta
@@ -382,7 +381,8 @@ def _format_string_uri_input(input):
382381
elif input.startswith('file://'):
383382
return file_input(input)
384383
else:
385-
raise ValueError('Training input data must be a valid S3 or FILE URI and must start with "s3://" or "file://"')
384+
raise ValueError('Training input data must be a valid S3 or FILE URI: must start with "s3://" or '
385+
'"file://"')
386386
elif isinstance(input, s3_input):
387387
return input
388388
elif isinstance(input, file_input):

src/sagemaker/local/local_session.py

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -179,42 +179,25 @@ def logs_for_job(self, job_name, wait=False, poll=5):
179179
# on local mode.
180180
pass
181181

182-
# TODO Naming consistent with session.s3_input. May want to change both
183-
# (e.g. S3Input and FileInput)
182+
184183
class file_input(object):
185184
"""Amazon SageMaker channel configuration for FILE data sources, used in local mode.
186185
187186
Attributes:
188-
config (dict[str, dict]): A SageMaker ``DataSource`` referencing a SageMaker ``S3DataSource``.
187+
config (dict[str, dict]): A SageMaker ``DataSource`` referencing a SageMaker ``FileDataSource``.
189188
"""
190189

191-
def __init__(self, fileUri):
190+
def __init__(self, fileUri, content_type=None):
192191
"""Create a definition for input data used by a SageMaker training job in local mode.
193192
"""
194-
195-
"""TODO Keeping this consistent with s3_input data structure. May be
196-
better to have a Type key under DataSource, but that really would mess
197-
with the standard implementation....
198-
"""
199-
200193
self.config = {
201194
'DataSource': {
202195
'FileDataSource': {
203-
# TODO Ok to hardcode this here or allow input?
204196
'FileDataDistributionType': 'FullyReplicated',
205197
'FileUri': fileUri
206198
}
207199
}
208200
}
209201

210-
# As per docs, leave unset in FILE mode
211-
# if compression is not None:
212-
# self.config['CompressionType'] = compression
213-
214-
# if content_type is not None:
215-
# self.config['ContentType'] = content_type
216-
217-
# As per docs, leave unset in FILE mode
218-
# if record_wrapping is not None:
219-
# self.config['RecordWrapperType'] = record_wrapping
220-
202+
if content_type is not None:
203+
self.config['ContentType'] = content_type

tests/unit/test_estimator.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def test_sagemaker_s3_uri_invalid(sagemaker_session):
101101
t = DummyFramework(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
102102
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE)
103103
t.fit('thisdoesntstartwiths3')
104-
assert 'must be a valid S3 URI' in str(error)
104+
assert 'must be a valid S3 or FILE URI' in str(error)
105105

106106

107107
@patch('time.strftime', return_value=TIMESTAMP)
@@ -427,9 +427,8 @@ def test_unsupported_type():
427427

428428

429429
def test_unsupported_type_in_dict():
430-
with pytest.raises(ValueError) as error:
430+
with pytest.raises(ValueError):
431431
_TrainingJob._format_inputs_to_input_config({'a': 66})
432-
assert 'Expecting one of str or s3_input' in str(error)
433432

434433

435434
#################################################################################

tests/unit/test_image.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,9 @@
2727
{
2828
'ChannelName': 'a',
2929
'DataSource': {
30-
'S3DataSource': {
31-
'S3DataDistributionType': 'FullyReplicated',
32-
'S3DataType': 'S3Prefix',
33-
'S3Uri': '/tmp/source1'
30+
'FileDataSource': {
31+
'FileDataDistributionType': 'FullyReplicated',
32+
'FileUri': 'file:///tmp/source1'
3433
}
3534
}
3635
},

tests/unit/test_local_session.py

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,15 @@ def test_create_training_job(train, LocalSession):
3535
image = "my-docker-image:1.0"
3636

3737
algo_spec = {'TrainingImage': image}
38-
input_data_config = {}
38+
input_data_config = [{
39+
'ChannelName': 'a',
40+
'DataSource': {
41+
'S3DataSource': {
42+
'S3DataDistributionType': 'FullyReplicated',
43+
'S3Uri': 's3://my_bucket/tmp/source1'
44+
}
45+
}
46+
}]
3947
output_data_config = {}
4048
resource_config = {'InstanceType': 'local', 'InstanceCount': instance_count}
4149
hyperparameters = {'a': 1, 'b': 'bee'}
@@ -61,6 +69,67 @@ def test_create_training_job(train, LocalSession):
6169
assert response['ModelArtifacts']['S3ModelArtifacts'] == expected['ModelArtifacts']['S3ModelArtifacts']
6270

6371

72+
@patch('sagemaker.local.image._SageMakerContainer.train', return_value="/some/path/to/model")
73+
@patch('sagemaker.local.local_session.LocalSession')
74+
def test_create_training_job_invalid_data_source(train, LocalSession):
75+
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
76+
77+
instance_count = 2
78+
image = "my-docker-image:1.0"
79+
80+
algo_spec = {'TrainingImage': image}
81+
82+
# InvalidDataSource is not supported. S3DataSource and FileDataSource are currently the only
83+
# valid Data Sources. We expect a ValueError if we pass this input data config.
84+
input_data_config = [{
85+
'ChannelName': 'a',
86+
'DataSource': {
87+
'InvalidDataSource': {
88+
'FileDataDistributionType': 'FullyReplicated',
89+
'FileUri': 'ftp://myserver.com/tmp/source1'
90+
}
91+
}
92+
}]
93+
94+
output_data_config = {}
95+
resource_config = {'InstanceType': 'local', 'InstanceCount': instance_count}
96+
hyperparameters = {'a': 1, 'b': 'bee'}
97+
98+
with pytest.raises(ValueError):
99+
local_sagemaker_client.create_training_job("my-training-job", algo_spec, 'arn:my-role', input_data_config,
100+
output_data_config, resource_config, None, hyperparameters)
101+
102+
103+
@patch('sagemaker.local.image._SageMakerContainer.train', return_value="/some/path/to/model")
104+
@patch('sagemaker.local.local_session.LocalSession')
105+
def test_create_training_job_not_fully_replicated(train, LocalSession):
106+
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
107+
108+
instance_count = 2
109+
image = "my-docker-image:1.0"
110+
111+
algo_spec = {'TrainingImage': image}
112+
113+
# Local Mode only supports FullyReplicated as the Data Distribution type.
114+
input_data_config = [{
115+
'ChannelName': 'a',
116+
'DataSource': {
117+
'S3DataSource': {
118+
'S3DataDistributionType': 'ShardedByS3Key',
119+
'S3Uri': 's3://my_bucket/tmp/source1'
120+
}
121+
}
122+
}]
123+
124+
output_data_config = {}
125+
resource_config = {'InstanceType': 'local', 'InstanceCount': instance_count}
126+
hyperparameters = {'a': 1, 'b': 'bee'}
127+
128+
with pytest.raises(RuntimeError):
129+
local_sagemaker_client.create_training_job("my-training-job", algo_spec, 'arn:my-role', input_data_config,
130+
output_data_config, resource_config, None, hyperparameters)
131+
132+
64133
@patch('sagemaker.local.local_session.LocalSession')
65134
def test_create_model(LocalSession):
66135
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
@@ -130,3 +199,34 @@ def test_create_endpoint_fails(serve, request, LocalSession):
130199

131200
with pytest.raises(RuntimeError):
132201
local_sagemaker_client.create_endpoint('my-endpoint', 'some-endpoint-config')
202+
203+
204+
def test_file_input_all_defaults():
205+
prefix = 'pre'
206+
actual = sagemaker.local.local_session.file_input(fileUri=prefix)
207+
expected = \
208+
{
209+
'DataSource': {
210+
'FileDataSource': {
211+
'FileDataDistributionType': 'FullyReplicated',
212+
'FileUri': prefix
213+
}
214+
}
215+
}
216+
assert actual.config == expected
217+
218+
219+
def test_file_input_content_type():
220+
prefix = 'pre'
221+
actual = sagemaker.local.local_session.file_input(fileUri=prefix, content_type='text/csv')
222+
expected = \
223+
{
224+
'DataSource': {
225+
'FileDataSource': {
226+
'FileDataDistributionType': 'FullyReplicated',
227+
'FileUri': prefix
228+
}
229+
},
230+
'ContentType': 'text/csv'
231+
}
232+
assert actual.config == expected

0 commit comments

Comments
 (0)