@@ -51,9 +51,9 @@ def wait(self):
51
51
pass
52
52
53
53
@staticmethod
54
- def _load_config (inputs , estimator ):
55
- input_config = _Job ._format_inputs_to_input_config (inputs )
56
- role = estimator .sagemaker_session .expand_role (estimator .role )
54
+ def _load_config (inputs , estimator , expand_role = True , validate_uri = True ):
55
+ input_config = _Job ._format_inputs_to_input_config (inputs , validate_uri )
56
+ role = estimator .sagemaker_session .expand_role (estimator .role ) if expand_role else estimator . role
57
57
output_config = _Job ._prepare_output_config (estimator .output_path , estimator .output_kms_key )
58
58
resource_config = _Job ._prepare_resource_config (estimator .train_instance_count ,
59
59
estimator .train_instance_type ,
@@ -62,7 +62,8 @@ def _load_config(inputs, estimator):
62
62
stop_condition = _Job ._prepare_stop_condition (estimator .train_max_run )
63
63
vpc_config = estimator .get_vpc_config ()
64
64
65
- model_channel = _Job ._prepare_model_channel (input_config , estimator .model_uri , estimator .model_channel_name )
65
+ model_channel = _Job ._prepare_model_channel (input_config , estimator .model_uri , estimator .model_channel_name ,
66
+ validate_uri )
66
67
if model_channel :
67
68
input_config = [] if input_config is None else input_config
68
69
input_config .append (model_channel )
@@ -75,7 +76,7 @@ def _load_config(inputs, estimator):
75
76
'vpc_config' : vpc_config }
76
77
77
78
@staticmethod
78
- def _format_inputs_to_input_config (inputs ):
79
+ def _format_inputs_to_input_config (inputs , validate_uri = True ):
79
80
if inputs is None :
80
81
return None
81
82
@@ -86,14 +87,14 @@ def _format_inputs_to_input_config(inputs):
86
87
87
88
input_dict = {}
88
89
if isinstance (inputs , string_types ):
89
- input_dict ['training' ] = _Job ._format_string_uri_input (inputs )
90
+ input_dict ['training' ] = _Job ._format_string_uri_input (inputs , validate_uri )
90
91
elif isinstance (inputs , s3_input ):
91
92
input_dict ['training' ] = inputs
92
93
elif isinstance (inputs , file_input ):
93
94
input_dict ['training' ] = inputs
94
95
elif isinstance (inputs , dict ):
95
96
for k , v in inputs .items ():
96
- input_dict [k ] = _Job ._format_string_uri_input (v )
97
+ input_dict [k ] = _Job ._format_string_uri_input (v , validate_uri )
97
98
elif isinstance (inputs , list ):
98
99
input_dict = _Job ._format_record_set_list_input (inputs )
99
100
else :
@@ -111,15 +112,16 @@ def _convert_input_to_channel(channel_name, channel_s3_input):
111
112
return channel_config
112
113
113
114
@staticmethod
114
- def _format_string_uri_input (uri_input ):
115
- if isinstance (uri_input , str ):
116
- if uri_input .startswith ('s3://' ):
117
- return s3_input (uri_input )
118
- elif uri_input .startswith ('file://' ):
119
- return file_input (uri_input )
120
- else :
121
- raise ValueError ('Training input data must be a valid S3 or FILE URI: must start with "s3://" or '
122
- '"file://"' )
115
+ def _format_string_uri_input (uri_input , validate_uri = True ):
116
+ if isinstance (uri_input , str ) and validate_uri and uri_input .startswith ('s3://' ):
117
+ return s3_input (uri_input )
118
+ elif isinstance (uri_input , str ) and validate_uri and uri_input .startswith ('file://' ):
119
+ return file_input (uri_input )
120
+ elif isinstance (uri_input , str ) and validate_uri :
121
+ raise ValueError ('Training input data must be a valid S3 or FILE URI: must start with "s3://" or '
122
+ '"file://"' )
123
+ elif isinstance (uri_input , str ):
124
+ return s3_input (uri_input )
123
125
elif isinstance (uri_input , s3_input ):
124
126
return uri_input
125
127
elif isinstance (uri_input , file_input ):
@@ -128,7 +130,7 @@ def _format_string_uri_input(uri_input):
128
130
raise ValueError ('Cannot format input {}. Expecting one of str, s3_input, or file_input' .format (uri_input ))
129
131
130
132
@staticmethod
131
- def _prepare_model_channel (input_config , model_uri = None , model_channel_name = None ):
133
+ def _prepare_model_channel (input_config , model_uri = None , model_channel_name = None , validate_uri = True ):
132
134
if not model_uri :
133
135
return
134
136
elif not model_channel_name :
@@ -139,22 +141,24 @@ def _prepare_model_channel(input_config, model_uri=None, model_channel_name=None
139
141
if channel ['ChannelName' ] == model_channel_name :
140
142
raise ValueError ('Duplicate channels not allowed.' )
141
143
142
- model_input = _Job ._format_model_uri_input (model_uri )
144
+ model_input = _Job ._format_model_uri_input (model_uri , validate_uri )
143
145
model_channel = _Job ._convert_input_to_channel (model_channel_name , model_input )
144
146
145
147
return model_channel
146
148
147
149
@staticmethod
148
- def _format_model_uri_input (model_uri ):
149
- if isinstance (model_uri , string_types ):
150
- if model_uri .startswith ('s3://' ):
151
- return s3_input (model_uri , input_mode = 'File' , distribution = 'FullyReplicated' ,
152
- content_type = 'application/x-sagemaker-model' )
153
- elif model_uri .startswith ('file://' ):
154
- return file_input (model_uri )
155
- else :
156
- raise ValueError ('Model URI must be a valid S3 or FILE URI: must start with "s3://" or '
157
- '"file://' )
150
+ def _format_model_uri_input (model_uri , validate_uri = True ):
151
+ if isinstance (model_uri , string_types )and validate_uri and model_uri .startswith ('s3://' ):
152
+ return s3_input (model_uri , input_mode = 'File' , distribution = 'FullyReplicated' ,
153
+ content_type = 'application/x-sagemaker-model' )
154
+ elif isinstance (model_uri , string_types ) and validate_uri and model_uri .startswith ('file://' ):
155
+ return file_input (model_uri )
156
+ elif isinstance (model_uri , string_types ) and validate_uri :
157
+ raise ValueError ('Model URI must be a valid S3 or FILE URI: must start with "s3://" or '
158
+ '"file://' )
159
+ elif isinstance (model_uri , string_types ):
160
+ return s3_input (model_uri , input_mode = 'File' , distribution = 'FullyReplicated' ,
161
+ content_type = 'application/x-sagemaker-model' )
158
162
else :
159
163
raise ValueError ('Cannot format model URI {}. Expecting str' .format (model_uri ))
160
164
0 commit comments