Skip to content

Commit e656965

Browse files
author
Ignacio Quintero
committed
Fix and add new unit tests; fix Flake8 violations
Also resolve several outstanding TODOs.
1 parent 1c6a8a6 commit e656965

File tree

6 files changed

+114
-39
lines changed

6 files changed

+114
-39
lines changed

src/sagemaker/estimator.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import print_function, absolute_import
1414

15-
import os
1615
import json
1716
import logging
1817
from abc import ABCMeta
@@ -22,18 +21,15 @@
2221
from sagemaker.fw_utils import tar_and_upload_dir
2322
from sagemaker.fw_utils import parse_s3_url
2423
from sagemaker.fw_utils import UploadedCode
25-
2624
from sagemaker.local.local_session import LocalSession, file_input
2725

2826
from sagemaker.model import Model
2927
from sagemaker.model import (SCRIPT_PARAM_NAME, DIR_PARAM_NAME, CLOUDWATCH_METRICS_PARAM_NAME,
3028
CONTAINER_LOG_LEVEL_PARAM_NAME, JOB_NAME_PARAM_NAME, SAGEMAKER_REGION_PARAM_NAME)
3129

3230
from sagemaker.predictor import RealTimePredictor
33-
3431
from sagemaker.session import Session
3532
from sagemaker.session import s3_input
36-
3733
from sagemaker.utils import base_name_from_image, name_from_base
3834

3935

@@ -382,7 +378,8 @@ def _format_string_uri_input(input):
382378
elif input.startswith('file://'):
383379
return file_input(input)
384380
else:
385-
raise ValueError('Training input data must be a valid S3 or FILE URI and must start with "s3://" or "file://"')
381+
raise ValueError('Training input data must be a valid S3 or FILE URI: must start with "s3://" or '
382+
'"file://"')
386383
elif isinstance(input, s3_input):
387384
return input
388385
elif isinstance(input, file_input):

src/sagemaker/local/image.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,6 @@ def train(self, input_data_config, hyperparameters):
9393
# mount the local directory to the container. For S3 Data we will download the S3 data
9494
# first.
9595
for channel in input_data_config:
96-
9796
if channel['DataSource'] and 'S3DataSource' in channel['DataSource']:
9897
uri = channel['DataSource']['S3DataSource']['S3Uri']
9998
elif channel['DataSource'] and 'FileDataSource' in channel['DataSource']:
@@ -112,9 +111,7 @@ def train(self, input_data_config, hyperparameters):
112111
bucket_name = parsed_uri.netloc
113112
self._download_folder(bucket_name, key, channel_dir)
114113
elif parsed_uri.scheme == 'file':
115-
# TODO Check why this is file:/... and not file:///...
116-
# TODO use the parsed_uri.xxx and use os.path.join
117-
path = uri.lstrip('file:')
114+
path = parsed_uri.path
118115
volumes.append(_Volume(path, channel=channel_name))
119116
else:
120117
raise ValueError('Unknown URI scheme {}'.format(parsed_uri.scheme))

src/sagemaker/local/local_session.py

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -179,42 +179,25 @@ def logs_for_job(self, job_name, wait=False, poll=5):
179179
# on local mode.
180180
pass
181181

182-
# TODO Naming consistent with session.s3_input. May want to change both
183-
# (e.g. S3Input and FileInput)
182+
184183
class file_input(object):
185184
"""Amazon SageMaker channel configuration for FILE data sources, used in local mode.
186185
187186
Attributes:
188-
config (dict[str, dict]): A SageMaker ``DataSource`` referencing a SageMaker ``S3DataSource``.
187+
config (dict[str, dict]): A SageMaker ``DataSource`` referencing a SageMaker ``FileDataSource``.
189188
"""
190189

191-
def __init__(self, fileUri):
190+
def __init__(self, fileUri, content_type=None):
192191
"""Create a definition for input data used by a SageMaker training job in local mode.
193192
"""
194-
195-
"""TODO Keeping this consistent with s3_input data structure. May be
196-
better to have a Type key under DataSource, but that really would mess
197-
with the standard implementation....
198-
"""
199-
200193
self.config = {
201194
'DataSource': {
202195
'FileDataSource': {
203-
# TODO Ok to hardcode this here or allow input?
204196
'FileDataDistributionType': 'FullyReplicated',
205197
'FileUri': fileUri
206198
}
207199
}
208200
}
209201

210-
# As per docs, leave unset in FILE mode
211-
# if compression is not None:
212-
# self.config['CompressionType'] = compression
213-
214-
# if content_type is not None:
215-
# self.config['ContentType'] = content_type
216-
217-
# As per docs, leave unset in FILE mode
218-
# if record_wrapping is not None:
219-
# self.config['RecordWrapperType'] = record_wrapping
220-
202+
if content_type is not None:
203+
self.config['ContentType'] = content_type

tests/unit/test_estimator.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def test_sagemaker_s3_uri_invalid(sagemaker_session):
101101
t = DummyFramework(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
102102
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE)
103103
t.fit('thisdoesntstartwiths3')
104-
assert 'must be a valid S3 URI' in str(error)
104+
assert 'must be a valid S3 or FILE URI' in str(error)
105105

106106

107107
@patch('time.strftime', return_value=TIMESTAMP)
@@ -427,9 +427,8 @@ def test_unsupported_type():
427427

428428

429429
def test_unsupported_type_in_dict():
430-
with pytest.raises(ValueError) as error:
430+
with pytest.raises(ValueError):
431431
_TrainingJob._format_inputs_to_input_config({'a': 66})
432-
assert 'Expecting one of str or s3_input' in str(error)
433432

434433

435434
#################################################################################

tests/unit/test_image.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,9 @@
2727
{
2828
'ChannelName': 'a',
2929
'DataSource': {
30-
'S3DataSource': {
31-
'S3DataDistributionType': 'FullyReplicated',
32-
'S3DataType': 'S3Prefix',
33-
'S3Uri': '/tmp/source1'
30+
'FileDataSource': {
31+
'FileDataDistributionType': 'FullyReplicated',
32+
'FileUri': 'file:///tmp/source1'
3433
}
3534
}
3635
},

tests/unit/test_local_session.py

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,15 @@ def test_create_training_job(train, LocalSession):
3535
image = "my-docker-image:1.0"
3636

3737
algo_spec = {'TrainingImage': image}
38-
input_data_config = {}
38+
input_data_config = [{
39+
'ChannelName': 'a',
40+
'DataSource': {
41+
'S3DataSource': {
42+
'S3DataDistributionType': 'FullyReplicated',
43+
'S3Uri': 's3://my_bucket/tmp/source1'
44+
}
45+
}
46+
}]
3947
output_data_config = {}
4048
resource_config = {'InstanceType': 'local', 'InstanceCount': instance_count}
4149
hyperparameters = {'a': 1, 'b': 'bee'}
@@ -61,6 +69,67 @@ def test_create_training_job(train, LocalSession):
6169
assert response['ModelArtifacts']['S3ModelArtifacts'] == expected['ModelArtifacts']['S3ModelArtifacts']
6270

6371

72+
@patch('sagemaker.local.image._SageMakerContainer.train', return_value="/some/path/to/model")
73+
@patch('sagemaker.local.local_session.LocalSession')
74+
def test_create_training_job_invalid_data_source(train, LocalSession):
75+
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
76+
77+
instance_count = 2
78+
image = "my-docker-image:1.0"
79+
80+
algo_spec = {'TrainingImage': image}
81+
82+
# InvalidDataSource is not supported. S3DataSource and FileDataSource are currently the only
83+
# valid Data Sources. We expect a ValueError if we pass this input data config.
84+
input_data_config = [{
85+
'ChannelName': 'a',
86+
'DataSource': {
87+
'InvalidDataSource': {
88+
'FileDataDistributionType': 'FullyReplicated',
89+
'FileUri': 'ftp://myserver.com/tmp/source1'
90+
}
91+
}
92+
}]
93+
94+
output_data_config = {}
95+
resource_config = {'InstanceType': 'local', 'InstanceCount': instance_count}
96+
hyperparameters = {'a': 1, 'b': 'bee'}
97+
98+
with pytest.raises(ValueError):
99+
local_sagemaker_client.create_training_job("my-training-job", algo_spec, 'arn:my-role', input_data_config,
100+
output_data_config, resource_config, None, hyperparameters)
101+
102+
103+
@patch('sagemaker.local.image._SageMakerContainer.train', return_value="/some/path/to/model")
104+
@patch('sagemaker.local.local_session.LocalSession')
105+
def test_create_training_job_not_fully_replicated(train, LocalSession):
106+
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
107+
108+
instance_count = 2
109+
image = "my-docker-image:1.0"
110+
111+
algo_spec = {'TrainingImage': image}
112+
113+
# Local Mode only supports FullyReplicated as Data Distribution type.
114+
input_data_config = [{
115+
'ChannelName': 'a',
116+
'DataSource': {
117+
'S3DataSource': {
118+
'S3DataDistributionType': 'ShardedByS3Key',
119+
'S3Uri': 's3://my_bucket/tmp/source1'
120+
}
121+
}
122+
}]
123+
124+
output_data_config = {}
125+
resource_config = {'InstanceType': 'local', 'InstanceCount': instance_count}
126+
hyperparameters = {'a': 1, 'b': 'bee'}
127+
128+
with pytest.raises(RuntimeError):
129+
local_sagemaker_client.create_training_job("my-training-job", algo_spec, 'arn:my-role', input_data_config,
130+
output_data_config, resource_config, None, hyperparameters)
131+
132+
64133
@patch('sagemaker.local.local_session.LocalSession')
65134
def test_create_model(LocalSession):
66135
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
@@ -130,3 +199,34 @@ def test_create_endpoint_fails(serve, request, LocalSession):
130199

131200
with pytest.raises(RuntimeError):
132201
local_sagemaker_client.create_endpoint('my-endpoint', 'some-endpoint-config')
202+
203+
204+
def test_file_input_all_defaults():
205+
prefix = 'pre'
206+
actual = sagemaker.local.local_session.file_input(fileUri=prefix)
207+
expected = \
208+
{
209+
'DataSource': {
210+
'FileDataSource': {
211+
'FileDataDistributionType': 'FullyReplicated',
212+
'FileUri': prefix
213+
}
214+
}
215+
}
216+
assert actual.config == expected
217+
218+
219+
def test_file_input_content_type():
220+
prefix = 'pre'
221+
actual = sagemaker.local.local_session.file_input(fileUri=prefix, content_type='text/csv')
222+
expected = \
223+
{
224+
'DataSource': {
225+
'FileDataSource': {
226+
'FileDataDistributionType': 'FullyReplicated',
227+
'FileUri': prefix
228+
}
229+
},
230+
'ContentType': 'text/csv'
231+
}
232+
assert actual.config == expected

0 commit comments

Comments
 (0)