Skip to content

Commit 1c6a8a6

Browse files
author
van Roekel, Gertjan
committed
Local mode support for file:// URI as the input for training data, bypassing uploading to/downloading from S3.
1 parent 597a4f5 commit 1c6a8a6

File tree

3 files changed

+92
-13
lines changed

3 files changed

+92
-13
lines changed

src/sagemaker/estimator.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import print_function, absolute_import
1414

15+
import os
1516
import json
1617
import logging
1718
from abc import ABCMeta
@@ -21,13 +22,18 @@
2122
from sagemaker.fw_utils import tar_and_upload_dir
2223
from sagemaker.fw_utils import parse_s3_url
2324
from sagemaker.fw_utils import UploadedCode
24-
from sagemaker.local.local_session import LocalSession
25+
26+
from sagemaker.local.local_session import LocalSession, file_input
27+
2528
from sagemaker.model import Model
2629
from sagemaker.model import (SCRIPT_PARAM_NAME, DIR_PARAM_NAME, CLOUDWATCH_METRICS_PARAM_NAME,
2730
CONTAINER_LOG_LEVEL_PARAM_NAME, JOB_NAME_PARAM_NAME, SAGEMAKER_REGION_PARAM_NAME)
31+
2832
from sagemaker.predictor import RealTimePredictor
33+
2934
from sagemaker.session import Session
3035
from sagemaker.session import s3_input
36+
3137
from sagemaker.utils import base_name_from_image, name_from_base
3238

3339

@@ -321,6 +327,13 @@ def start_new(cls, estimator, inputs):
321327
sagemaker.estimator.Framework: Constructed object that captures all information about the started job.
322328
"""
323329

330+
local_mode = estimator.local_mode
331+
332+
# Allow file:// input only in local mode
333+
if isinstance(inputs, str) and inputs.startswith('file://'):
334+
if not local_mode:
335+
raise ValueError('File URIs are supported in local mode only. Please use a S3 URI instead.')
336+
324337
input_config = _TrainingJob._format_inputs_to_input_config(inputs)
325338
role = estimator.sagemaker_session.expand_role(estimator.role)
326339
output_config = _TrainingJob._prepare_output_config(estimator.output_path, estimator.output_kms_key)
@@ -343,12 +356,14 @@ def start_new(cls, estimator, inputs):
343356
def _format_inputs_to_input_config(inputs):
344357
input_dict = {}
345358
if isinstance(inputs, string_types):
346-
input_dict['training'] = _TrainingJob._format_s3_uri_input(inputs)
359+
input_dict['training'] = _TrainingJob._format_string_uri_input(inputs)
347360
elif isinstance(inputs, s3_input):
348361
input_dict['training'] = inputs
362+
elif isinstance(input, file_input):
363+
input_dict['training'] = inputs
349364
elif isinstance(inputs, dict):
350365
for k, v in inputs.items():
351-
input_dict[k] = _TrainingJob._format_s3_uri_input(v)
366+
input_dict[k] = _TrainingJob._format_string_uri_input(v)
352367
else:
353368
raise ValueError('Cannot format input {}. Expecting one of str, dict or s3_input'.format(inputs))
354369

@@ -360,15 +375,20 @@ def _format_inputs_to_input_config(inputs):
360375
return channels
361376

362377
@staticmethod
363-
def _format_s3_uri_input(input):
378+
def _format_string_uri_input(input):
364379
if isinstance(input, str):
365-
if not input.startswith('s3://'):
366-
raise ValueError('Training input data must be a valid S3 URI and must start with "s3://"')
367-
return s3_input(input)
368-
if isinstance(input, s3_input):
380+
if input.startswith('s3://'):
381+
return s3_input(input)
382+
elif input.startswith('file://'):
383+
return file_input(input)
384+
else:
385+
raise ValueError('Training input data must be a valid S3 or FILE URI and must start with "s3://" or "file://"')
386+
elif isinstance(input, s3_input):
387+
return input
388+
elif isinstance(input, file_input):
369389
return input
370390
else:
371-
raise ValueError('Cannot format input {}. Expecting one of str or s3_input'.format(input))
391+
raise ValueError('Cannot format input {}. Expecting one of str, s3_input, or file_input'.format(input))
372392

373393
@staticmethod
374394
def _prepare_output_config(s3_path, kms_key_id):

src/sagemaker/local/image.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,14 @@ def train(self, input_data_config, hyperparameters):
9393
# mount the local directory to the container. For S3 Data we will download the S3 data
9494
# first.
9595
for channel in input_data_config:
96-
uri = channel['DataSource']['S3DataSource']['S3Uri']
96+
97+
if channel['DataSource'] and 'S3DataSource' in channel['DataSource']:
98+
uri = channel['DataSource']['S3DataSource']['S3Uri']
99+
elif channel['DataSource'] and 'FileDataSource' in channel['DataSource']:
100+
uri = channel['DataSource']['FileDataSource']['FileUri']
101+
else:
102+
raise ValueError('Need channel[\'DataSource\'] to have [\'S3DataSource\'] or [\'FileDataSource\']')
103+
97104
parsed_uri = urlparse(uri)
98105
key = parsed_uri.path.lstrip('/')
99106

@@ -104,8 +111,13 @@ def train(self, input_data_config, hyperparameters):
104111
if parsed_uri.scheme == 's3':
105112
bucket_name = parsed_uri.netloc
106113
self._download_folder(bucket_name, key, channel_dir)
114+
elif parsed_uri.scheme == 'file':
115+
# TODO Check why this is file:/... and not file:///...
116+
# TODO use the parsed_uri.xxx and use os.path.join
117+
path = uri.lstrip('file:')
118+
volumes.append(_Volume(path, channel=channel_name))
107119
else:
108-
volumes.append(_Volume(uri, channel=channel_name))
120+
raise ValueError('Unknown URI scheme {}'.format(parsed_uri.scheme))
109121

110122
# Create the configuration files for each container that we will create
111123
# Each container will map the additional local volumes (if any).

src/sagemaker/local/local_session.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,15 +56,22 @@ def create_training_job(self, TrainingJobName, AlgorithmSpecification, RoleArn,
5656
AlgorithmSpecification['TrainingImage'], self.sagemaker_session)
5757

5858
for channel in InputDataConfig:
59-
data_distribution = channel['DataSource']['S3DataSource']['S3DataDistributionType']
59+
60+
if channel['DataSource'] and 'S3DataSource' in channel['DataSource']:
61+
data_distribution = channel['DataSource']['S3DataSource']['S3DataDistributionType']
62+
elif channel['DataSource'] and 'FileDataSource' in channel['DataSource']:
63+
data_distribution = channel['DataSource']['FileDataSource']['FileDataDistributionType']
64+
else:
65+
raise ValueError('Need channel[\'DataSource\'] to have [\'S3DataSource\'] or [\'FileDataSource\']')
66+
6067
if data_distribution != 'FullyReplicated':
6168
raise RuntimeError("DataDistribution: %s is not currently supported in Local Mode" %
6269
data_distribution)
6370

6471
self.s3_model_artifacts = self.train_container.train(InputDataConfig, HyperParameters)
6572

6673
def describe_training_job(self, TrainingJobName):
67-
"""Describe a local traininig job.
74+
"""Describe a local training job.
6875
6976
Args:
7077
TrainingJobName (str): Not used in this implmentation.
@@ -171,3 +178,43 @@ def logs_for_job(self, job_name, wait=False, poll=5):
171178
# override logs_for_job() as it doesn't need to perform any action
172179
# on local mode.
173180
pass
181+
182+
# TODO Naming consistent with session.s3_input. May want to change both
183+
# (e.g. S3Input and FileInput)
184+
class file_input(object):
185+
"""Amazon SageMaker channel configuration for FILE data sources, used in local mode.
186+
187+
Attributes:
188+
config (dict[str, dict]): A SageMaker ``DataSource`` referencing a SageMaker ``S3DataSource``.
189+
"""
190+
191+
def __init__(self, fileUri):
192+
"""Create a definition for input data used by an SageMaker training job in local mode.
193+
"""
194+
195+
"""TODO Keeping this consistent with s3_input data structure. May be
196+
better to have a Type key under DataSource, but that really would mess
197+
with the standard implementation....
198+
"""
199+
200+
self.config = {
201+
'DataSource': {
202+
'FileDataSource': {
203+
# TODO Ok to hardcode this here or allow input?
204+
'FileDataDistributionType': 'FullyReplicated',
205+
'FileUri': fileUri
206+
}
207+
}
208+
}
209+
210+
# As per docs, leave unset in FILE mode
211+
# if compression is not None:
212+
# self.config['CompressionType'] = compression
213+
214+
# if content_type is not None:
215+
# self.config['ContentType'] = content_type
216+
217+
# As per docs, leave unset in FILE mode
218+
# if record_wrapping is not None:
219+
# self.config['RecordWrapperType'] = record_wrapping
220+

0 commit comments

Comments
 (0)