Skip to content

Commit 5611512

Browse files
committed
Support optional input channels in local mode.
1 parent 5201c60 commit 5611512

File tree

4 files changed

+14
-10
lines changed

4 files changed

+14
-10
lines changed

CHANGELOG.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22
CHANGELOG
33
=========
44

5+
1.14.1.dev
6+
==========
7+
8+
* enhancement: Local Mode: support optional input channels
9+
510
1.14.0
611
======
712

src/sagemaker/local/image.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,6 @@ def write_config_files(self, host, hyperparameters, input_data_config):
262262
'hosts': self.hosts
263263
}
264264

265-
print(input_data_config)
266265
json_input_data_config = {}
267266
for c in input_data_config:
268267
channel_name = c['ChannelName']

src/sagemaker/local/local_session.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ def __init__(self, sagemaker_session=None):
5353
"""
5454
self.sagemaker_session = sagemaker_session or LocalSession()
5555

56-
def create_training_job(self, TrainingJobName, AlgorithmSpecification, InputDataConfig, OutputDataConfig,
57-
ResourceConfig, **kwargs):
56+
def create_training_job(self, TrainingJobName, AlgorithmSpecification, OutputDataConfig,
57+
ResourceConfig, InputDataConfig=None, **kwargs):
5858
"""
5959
Create a training job in Local Mode
6060
Args:
@@ -66,7 +66,7 @@ def create_training_job(self, TrainingJobName, AlgorithmSpecification, InputData
6666
HyperParameters (dict) [optional]: Specifies these algorithm-specific parameters to influence the quality of
6767
the final model.
6868
"""
69-
69+
InputDataConfig = InputDataConfig or {}
7070
container = _SageMakerContainer(ResourceConfig['InstanceType'], ResourceConfig['InstanceCount'],
7171
AlgorithmSpecification['TrainingImage'], self.sagemaker_session)
7272
training_job = _LocalTrainingJob(container)

tests/unit/test_local_session.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ def test_create_training_job(train, LocalSession):
6161
resource_config = {'InstanceType': 'local', 'InstanceCount': instance_count}
6262
hyperparameters = {'a': 1, 'b': 'bee'}
6363

64-
local_sagemaker_client.create_training_job('my-training-job', algo_spec, input_data_config,
65-
output_data_config, resource_config, HyperParameters=hyperparameters)
64+
local_sagemaker_client.create_training_job('my-training-job', algo_spec, output_data_config, resource_config,
65+
InputDataConfig=input_data_config, HyperParameters=hyperparameters)
6666

6767
expected = {
6868
'ResourceConfig': {'InstanceCount': instance_count},
@@ -111,8 +111,8 @@ def test_create_training_job_invalid_data_source(train, LocalSession):
111111
hyperparameters = {'a': 1, 'b': 'bee'}
112112

113113
with pytest.raises(ValueError):
114-
local_sagemaker_client.create_training_job('my-training-job', algo_spec, input_data_config,
115-
output_data_config, resource_config, HyperParameters=hyperparameters)
114+
local_sagemaker_client.create_training_job('my-training-job', algo_spec, output_data_config, resource_config,
115+
InputDataConfig=input_data_config, HyperParameters=hyperparameters)
116116

117117

118118
@patch('sagemaker.local.image._SageMakerContainer.train', return_value="/some/path/to/model")
@@ -141,8 +141,8 @@ def test_create_training_job_not_fully_replicated(train, LocalSession):
141141
hyperparameters = {'a': 1, 'b': 'bee'}
142142

143143
with pytest.raises(RuntimeError):
144-
local_sagemaker_client.create_training_job('my-training-job', algo_spec, input_data_config,
145-
output_data_config, resource_config, HyperParameters=hyperparameters)
144+
local_sagemaker_client.create_training_job('my-training-job', algo_spec, output_data_config, resource_config,
145+
InputDataConfig=input_data_config, HyperParameters=hyperparameters)
146146

147147

148148
@patch('sagemaker.local.local_session.LocalSession')

0 commit comments

Comments
 (0)