Skip to content

change: refactor normalization of args for processing #1861

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 31, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 82 additions & 17 deletions src/sagemaker/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,12 @@ def run(
Please either set wait to True or set logs to False."""
)

self._current_job_name = self._generate_current_job_name(job_name=job_name)

normalized_inputs = self._normalize_inputs(inputs)
normalized_outputs = self._normalize_outputs(outputs)
self.arguments = arguments
normalized_inputs, normalized_outputs = self._normalize_args(
job_name=job_name,
arguments=arguments,
inputs=inputs,
outputs=outputs,
)

self.latest_job = ProcessingJob.start_new(
processor=self,
Expand All @@ -165,6 +166,48 @@ def run(
if wait:
self.latest_job.wait(logs=logs)

def _normalize_args(self, job_name=None, arguments=None, inputs=None, outputs=None, code=None):
    """Prepare the job name, inputs, outputs and arguments for a processing job run.

    Besides returning the normalized inputs and outputs, this stores the generated
    job name on ``self._current_job_name`` and ``arguments`` on ``self.arguments``.

    Args:
        job_name (str): Name of the processing job to be created. If not specified, one
            is generated, using the base name given to the constructor, if applicable
            (default: None).
        arguments (list[str]): A list of string arguments to be passed to a
            processing job (default: None).
        inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
            the processing job. These must be provided as
            :class:`~sagemaker.processing.ProcessingInput` objects (default: None).
        outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
            the processing job. These can be specified as either path strings or
            :class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
        code (str): This can be an S3 URI or a local path to a file with the framework
            script to run (default: None). Handed to ``_include_code_in_inputs``
            together with ``inputs``.

    Returns:
        tuple: ``(normalized_inputs, normalized_outputs)``, i.e. the results of
            ``_normalize_inputs`` and ``_normalize_outputs``.
    """
    # The job name is generated before any input handling.
    # NOTE(review): the helpers below presumably read self._current_job_name -- confirm.
    self._current_job_name = self._generate_current_job_name(job_name=job_name)

    prepared_inputs = self._normalize_inputs(self._include_code_in_inputs(inputs, code))
    prepared_outputs = self._normalize_outputs(outputs)

    # Recorded only after both normalization steps have completed.
    self.arguments = arguments

    return prepared_inputs, prepared_outputs

def _include_code_in_inputs(self, inputs, _code):
"""A no op in the base class to include code in the processing job inputs.

Args:
inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
the processing job. These must be provided as
:class:`~sagemaker.processing.ProcessingInput` objects.
_code (str): This can be an S3 URI or a local path to a file with the framework
script to run (default: None). A no op in the base class.

Returns:
list[:class:`~sagemaker.processing.ProcessingInput`]: inputs
"""
return inputs

def _generate_current_job_name(self, job_name=None):
"""Generates the job name before running a processing job.

Expand Down Expand Up @@ -388,18 +431,13 @@ def run(
Dictionary contains three optional keys:
'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
"""
self._current_job_name = self._generate_current_job_name(job_name=job_name)

user_code_s3_uri = self._handle_user_code_url(code)
user_script_name = self._get_user_code_name(code)

inputs_with_code = self._convert_code_and_add_to_inputs(inputs, user_code_s3_uri)

self._set_entrypoint(self.command, user_script_name)

normalized_inputs = self._normalize_inputs(inputs_with_code)
normalized_outputs = self._normalize_outputs(outputs)
self.arguments = arguments
normalized_inputs, normalized_outputs = self._normalize_args(
job_name=job_name,
arguments=arguments,
inputs=inputs,
outputs=outputs,
code=code,
)

self.latest_job = ProcessingJob.start_new(
processor=self,
Expand All @@ -411,6 +449,33 @@ def run(
if wait:
self.latest_job.wait(logs=logs)

def _include_code_in_inputs(self, inputs, code):
    """Turn the user code into a processing input and add it to the input list.

    Side effects include:
        * uploads code to S3 if the code is a local file.
        * sets the entrypoint attribute based on the command and user script name from code.

    Args:
        inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
            the processing job. These must be provided as
            :class:`~sagemaker.processing.ProcessingInput` objects.
        code (str): This can be an S3 URI or a local path to a file with the framework
            script to run.

    Returns:
        list[:class:`~sagemaker.processing.ProcessingInput`]: inputs together with the
            code as `ProcessingInput`.
    """
    # Resolve the S3 location and the script name from the same ``code`` value.
    code_s3_uri = self._handle_user_code_url(code)
    script_name = self._get_user_code_name(code)

    combined_inputs = self._convert_code_and_add_to_inputs(inputs, code_s3_uri)

    # The entrypoint is set from the configured command plus the script name.
    self._set_entrypoint(self.command, script_name)

    return combined_inputs

def _get_user_code_name(self, code):
"""Gets the basename of the user's code from the URL the customer provided.

Expand Down