12
12
# language governing permissions and limitations under the License.
13
13
from __future__ import print_function , absolute_import
14
14
15
+ import os
15
16
import json
16
17
import logging
17
18
from abc import ABCMeta
21
22
from sagemaker .fw_utils import tar_and_upload_dir
22
23
from sagemaker .fw_utils import parse_s3_url
23
24
from sagemaker .fw_utils import UploadedCode
24
- from sagemaker .local .local_session import LocalSession
25
+
26
+ from sagemaker .local .local_session import LocalSession , file_input
27
+
25
28
from sagemaker .model import Model
26
29
from sagemaker .model import (SCRIPT_PARAM_NAME , DIR_PARAM_NAME , CLOUDWATCH_METRICS_PARAM_NAME ,
27
30
CONTAINER_LOG_LEVEL_PARAM_NAME , JOB_NAME_PARAM_NAME , SAGEMAKER_REGION_PARAM_NAME )
31
+
28
32
from sagemaker .predictor import RealTimePredictor
33
+
29
34
from sagemaker .session import Session
30
35
from sagemaker .session import s3_input
36
+
31
37
from sagemaker .utils import base_name_from_image , name_from_base
32
38
33
39
@@ -321,6 +327,13 @@ def start_new(cls, estimator, inputs):
321
327
sagemaker.estimator.Framework: Constructed object that captures all information about the started job.
322
328
"""
323
329
330
+ local_mode = estimator .local_mode
331
+
332
+ # Allow file:// input only in local mode
333
+ if isinstance (inputs , str ) and inputs .startswith ('file://' ):
334
+ if not local_mode :
335
+ raise ValueError ('File URIs are supported in local mode only. Please use a S3 URI instead.' )
336
+
324
337
input_config = _TrainingJob ._format_inputs_to_input_config (inputs )
325
338
role = estimator .sagemaker_session .expand_role (estimator .role )
326
339
output_config = _TrainingJob ._prepare_output_config (estimator .output_path , estimator .output_kms_key )
@@ -343,12 +356,14 @@ def start_new(cls, estimator, inputs):
343
356
def _format_inputs_to_input_config (inputs ):
344
357
input_dict = {}
345
358
if isinstance (inputs , string_types ):
346
- input_dict ['training' ] = _TrainingJob ._format_s3_uri_input (inputs )
359
+ input_dict ['training' ] = _TrainingJob ._format_string_uri_input (inputs )
347
360
elif isinstance (inputs , s3_input ):
348
361
input_dict ['training' ] = inputs
362
+ elif isinstance (input , file_input ):
363
+ input_dict ['training' ] = inputs
349
364
elif isinstance (inputs , dict ):
350
365
for k , v in inputs .items ():
351
- input_dict [k ] = _TrainingJob ._format_s3_uri_input (v )
366
+ input_dict [k ] = _TrainingJob ._format_string_uri_input (v )
352
367
else :
353
368
raise ValueError ('Cannot format input {}. Expecting one of str, dict or s3_input' .format (inputs ))
354
369
@@ -360,15 +375,20 @@ def _format_inputs_to_input_config(inputs):
360
375
return channels
361
376
362
377
@staticmethod
363
- def _format_s3_uri_input (input ):
378
+ def _format_string_uri_input (input ):
364
379
if isinstance (input , str ):
365
- if not input .startswith ('s3://' ):
366
- raise ValueError ('Training input data must be a valid S3 URI and must start with "s3://"' )
367
- return s3_input (input )
368
- if isinstance (input , s3_input ):
380
+ if input .startswith ('s3://' ):
381
+ return s3_input (input )
382
+ elif input .startswith ('file://' ):
383
+ return file_input (input )
384
+ else :
385
+ raise ValueError ('Training input data must be a valid S3 or FILE URI and must start with "s3://" or "file://"' )
386
+ elif isinstance (input , s3_input ):
387
+ return input
388
+ elif isinstance (input , file_input ):
369
389
return input
370
390
else :
371
- raise ValueError ('Cannot format input {}. Expecting one of str or s3_input ' .format (input ))
391
+ raise ValueError ('Cannot format input {}. Expecting one of str, s3_input, or file_input ' .format (input ))
372
392
373
393
@staticmethod
374
394
def _prepare_output_config (s3_path , kms_key_id ):
0 commit comments