@@ -149,11 +149,12 @@ def run(
149
149
Please either set wait to True or set logs to False."""
150
150
)
151
151
152
- self ._current_job_name = self ._generate_current_job_name (job_name = job_name )
153
-
154
- normalized_inputs = self ._normalize_inputs (inputs )
155
- normalized_outputs = self ._normalize_outputs (outputs )
156
- self .arguments = arguments
152
+ normalized_inputs , normalized_outputs = self ._normalize_args (
153
+ job_name = job_name ,
154
+ arguments = arguments ,
155
+ inputs = inputs ,
156
+ outputs = outputs ,
157
+ )
157
158
158
159
self .latest_job = ProcessingJob .start_new (
159
160
processor = self ,
@@ -165,6 +166,48 @@ def run(
165
166
if wait :
166
167
self .latest_job .wait (logs = logs )
167
168
169
+ def _normalize_args (self , job_name = None , arguments = None , inputs = None , outputs = None , code = None ):
170
+ """Normalizes the arguments so that they can be passed to the job run
171
+
172
+ Args:
173
+ job_name (str): Name of the processing job to be created. If not specified, one
174
+ is generated, using the base name given to the constructor, if applicable
175
+ (default: None).
176
+ arguments (list[str]): A list of string arguments to be passed to a
177
+ processing job (default: None).
178
+ inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
179
+ the processing job. These must be provided as
180
+ :class:`~sagemaker.processing.ProcessingInput` objects (default: None).
181
+ outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
182
+ the processing job. These can be specified as either path strings or
183
+ :class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
184
+ code (str): This can be an S3 URI or a local path to a file with the framework
185
+ script to run (default: None). A no op in the base class.
186
+ """
187
+ self ._current_job_name = self ._generate_current_job_name (job_name = job_name )
188
+
189
+ inputs_with_code = self ._include_code_in_inputs (inputs , code )
190
+ normalized_inputs = self ._normalize_inputs (inputs_with_code )
191
+ normalized_outputs = self ._normalize_outputs (outputs )
192
+ self .arguments = arguments
193
+
194
+ return normalized_inputs , normalized_outputs
195
+
196
+ def _include_code_in_inputs (self , inputs , _code ):
197
+ """A no op in the base class to include code in the processing job inputs.
198
+
199
+ Args:
200
+ inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
201
+ the processing job. These must be provided as
202
+ :class:`~sagemaker.processing.ProcessingInput` objects.
203
+ _code (str): This can be an S3 URI or a local path to a file with the framework
204
+ script to run (default: None). A no op in the base class.
205
+
206
+ Returns:
207
+ list[:class:`~sagemaker.processing.ProcessingInput`]: inputs
208
+ """
209
+ return inputs
210
+
168
211
def _generate_current_job_name (self , job_name = None ):
169
212
"""Generates the job name before running a processing job.
170
213
@@ -388,18 +431,13 @@ def run(
388
431
Dictionary contains three optional keys:
389
432
'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
390
433
"""
391
- self ._current_job_name = self ._generate_current_job_name (job_name = job_name )
392
-
393
- user_code_s3_uri = self ._handle_user_code_url (code )
394
- user_script_name = self ._get_user_code_name (code )
395
-
396
- inputs_with_code = self ._convert_code_and_add_to_inputs (inputs , user_code_s3_uri )
397
-
398
- self ._set_entrypoint (self .command , user_script_name )
399
-
400
- normalized_inputs = self ._normalize_inputs (inputs_with_code )
401
- normalized_outputs = self ._normalize_outputs (outputs )
402
- self .arguments = arguments
434
+ normalized_inputs , normalized_outputs = self ._normalize_args (
435
+ job_name = job_name ,
436
+ arguments = arguments ,
437
+ inputs = inputs ,
438
+ outputs = outputs ,
439
+ code = code ,
440
+ )
403
441
404
442
self .latest_job = ProcessingJob .start_new (
405
443
processor = self ,
@@ -411,6 +449,33 @@ def run(
411
449
if wait :
412
450
self .latest_job .wait (logs = logs )
413
451
452
+ def _include_code_in_inputs (self , inputs , code ):
453
+ """Converts code to appropriate input and includes in input list.
454
+
455
+ Side effects include:
456
+ * uploads code to S3 if the code is a local file.
457
+ * sets the entrypoint attribute based on the command and user script name from code.
458
+
459
+ Args:
460
+ inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
461
+ the processing job. These must be provided as
462
+ :class:`~sagemaker.processing.ProcessingInput` objects.
463
+ code (str): This can be an S3 URI or a local path to a file with the framework
464
+ script to run (default: None).
465
+
466
+ Returns:
467
+ list[:class:`~sagemaker.processing.ProcessingInput`]: inputs together with the
468
+ code as `ProcessingInput`.
469
+ """
470
+ user_code_s3_uri = self ._handle_user_code_url (code )
471
+ user_script_name = self ._get_user_code_name (code )
472
+
473
+ inputs_with_code = self ._convert_code_and_add_to_inputs (inputs , user_code_s3_uri )
474
+
475
+ self ._set_entrypoint (self .command , user_script_name )
476
+
477
+ return inputs_with_code
478
+
414
479
def _get_user_code_name (self , code ):
415
480
"""Gets the basename of the user's code from the URL the customer provided.
416
481
0 commit comments