Skip to content

Commit 3481e2b

Browse files
committed
Support optional input channels in local mode.
1 parent 5201c60 commit 3481e2b

File tree

5 files changed

+50
-40
lines changed

5 files changed

+50
-40
lines changed

CHANGELOG.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22
CHANGELOG
33
=========
44

5+
1.14.1.dev
6+
==========
7+
8+
* enhancement: Local Mode: support optional input channels
9+
510
1.14.0
611
======
712

src/sagemaker/local/entities.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -47,22 +47,23 @@ def __init__(self, container):
4747
self.end_time = None
4848

4949
def start(self, input_data_config, output_data_config, hyperparameters, job_name):
50-
for channel in input_data_config:
51-
if channel['DataSource'] and 'S3DataSource' in channel['DataSource']:
52-
data_distribution = channel['DataSource']['S3DataSource']['S3DataDistributionType']
53-
data_uri = channel['DataSource']['S3DataSource']['S3Uri']
54-
elif channel['DataSource'] and 'FileDataSource' in channel['DataSource']:
55-
data_distribution = channel['DataSource']['FileDataSource']['FileDataDistributionType']
56-
data_uri = channel['DataSource']['FileDataSource']['FileUri']
57-
else:
58-
raise ValueError('Need channel[\'DataSource\'] to have [\'S3DataSource\'] or [\'FileDataSource\']')
59-
60-
# use a single Data URI - this makes handling S3 and File Data easier down the stack
61-
channel['DataUri'] = data_uri
62-
63-
if data_distribution != 'FullyReplicated':
64-
raise RuntimeError('DataDistribution: %s is not currently supported in Local Mode' %
65-
data_distribution)
50+
if input_data_config:
51+
for channel in input_data_config:
52+
if channel['DataSource'] and 'S3DataSource' in channel['DataSource']:
53+
data_distribution = channel['DataSource']['S3DataSource']['S3DataDistributionType']
54+
data_uri = channel['DataSource']['S3DataSource']['S3Uri']
55+
elif channel['DataSource'] and 'FileDataSource' in channel['DataSource']:
56+
data_distribution = channel['DataSource']['FileDataSource']['FileDataDistributionType']
57+
data_uri = channel['DataSource']['FileDataSource']['FileUri']
58+
else:
59+
raise ValueError('Need channel[\'DataSource\'] to have [\'S3DataSource\'] or [\'FileDataSource\']')
60+
61+
# use a single Data URI - this makes handling S3 and File Data easier down the stack
62+
channel['DataUri'] = data_uri
63+
64+
if data_distribution != 'FullyReplicated':
65+
raise RuntimeError('DataDistribution: %s is not currently supported in Local Mode' %
66+
data_distribution)
6667

6768
self.start = datetime.datetime.now()
6869
self.state = self._TRAINING

src/sagemaker/local/image.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -262,15 +262,17 @@ def write_config_files(self, host, hyperparameters, input_data_config):
262262
'hosts': self.hosts
263263
}
264264

265-
print(input_data_config)
266265
json_input_data_config = {}
267-
for c in input_data_config:
268-
channel_name = c['ChannelName']
269-
json_input_data_config[channel_name] = {
270-
'TrainingInputMode': 'File'
271-
}
272-
if 'ContentType' in c:
273-
json_input_data_config[channel_name]['ContentType'] = c['ContentType']
266+
if input_data_config:
267+
print(input_data_config)
268+
269+
for c in input_data_config:
270+
channel_name = c['ChannelName']
271+
json_input_data_config[channel_name] = {
272+
'TrainingInputMode': 'File'
273+
}
274+
if 'ContentType' in c:
275+
json_input_data_config[channel_name]['ContentType'] = c['ContentType']
274276

275277
_write_json_file(os.path.join(config_path, 'hyperparameters.json'), hyperparameters)
276278
_write_json_file(os.path.join(config_path, 'resourceconfig.json'), resource_config)
@@ -285,14 +287,16 @@ def _prepare_training_volumes(self, data_dir, input_data_config, hyperparameters
285287
# Set up the channels for the containers. For local data we will
286288
# mount the local directory to the container. For S3 Data we will download the S3 data
287289
# first.
288-
for channel in input_data_config:
289-
uri = channel['DataUri']
290-
channel_name = channel['ChannelName']
291-
channel_dir = os.path.join(data_dir, channel_name)
292-
os.mkdir(channel_dir)
293-
294-
data_source = sagemaker.local.data.get_data_source_instance(uri, self.sagemaker_session)
295-
volumes.append(_Volume(data_source.get_root_dir(), channel=channel_name))
290+
291+
if input_data_config:
292+
for channel in input_data_config:
293+
uri = channel['DataUri']
294+
channel_name = channel['ChannelName']
295+
channel_dir = os.path.join(data_dir, channel_name)
296+
os.mkdir(channel_dir)
297+
298+
data_source = sagemaker.local.data.get_data_source_instance(uri, self.sagemaker_session)
299+
volumes.append(_Volume(data_source.get_root_dir(), channel=channel_name))
296300

297301
# If there is a training script directory and it is a local directory,
298302
# mount it to the container.

src/sagemaker/local/local_session.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ def __init__(self, sagemaker_session=None):
5353
"""
5454
self.sagemaker_session = sagemaker_session or LocalSession()
5555

56-
def create_training_job(self, TrainingJobName, AlgorithmSpecification, InputDataConfig, OutputDataConfig,
57-
ResourceConfig, **kwargs):
56+
def create_training_job(self, TrainingJobName, AlgorithmSpecification, OutputDataConfig,
57+
ResourceConfig, InputDataConfig=None, **kwargs):
5858
"""
5959
Create a training job in Local Mode
6060
Args:

tests/unit/test_local_session.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ def test_create_training_job(train, LocalSession):
6161
resource_config = {'InstanceType': 'local', 'InstanceCount': instance_count}
6262
hyperparameters = {'a': 1, 'b': 'bee'}
6363

64-
local_sagemaker_client.create_training_job('my-training-job', algo_spec, input_data_config,
65-
output_data_config, resource_config, HyperParameters=hyperparameters)
64+
local_sagemaker_client.create_training_job('my-training-job', algo_spec, output_data_config, resource_config,
65+
InputDataConfig=input_data_config, HyperParameters=hyperparameters)
6666

6767
expected = {
6868
'ResourceConfig': {'InstanceCount': instance_count},
@@ -111,8 +111,8 @@ def test_create_training_job_invalid_data_source(train, LocalSession):
111111
hyperparameters = {'a': 1, 'b': 'bee'}
112112

113113
with pytest.raises(ValueError):
114-
local_sagemaker_client.create_training_job('my-training-job', algo_spec, input_data_config,
115-
output_data_config, resource_config, HyperParameters=hyperparameters)
114+
local_sagemaker_client.create_training_job('my-training-job', algo_spec, output_data_config, resource_config,
115+
InputDataConfig=input_data_config, HyperParameters=hyperparameters)
116116

117117

118118
@patch('sagemaker.local.image._SageMakerContainer.train', return_value="/some/path/to/model")
@@ -141,8 +141,8 @@ def test_create_training_job_not_fully_replicated(train, LocalSession):
141141
hyperparameters = {'a': 1, 'b': 'bee'}
142142

143143
with pytest.raises(RuntimeError):
144-
local_sagemaker_client.create_training_job('my-training-job', algo_spec, input_data_config,
145-
output_data_config, resource_config, HyperParameters=hyperparameters)
144+
local_sagemaker_client.create_training_job('my-training-job', algo_spec, output_data_config, resource_config,
145+
InputDataConfig=input_data_config, HyperParameters=hyperparameters)
146146

147147

148148
@patch('sagemaker.local.local_session.LocalSession')

0 commit comments

Comments
 (0)