fix: FrameworkProcessor S3 uploads #3493

Merged: 27 commits, Dec 8, 2022
40 changes: 32 additions & 8 deletions src/sagemaker/processing.py
@@ -1704,6 +1704,8 @@ def _pack_and_upload_code(
self, code, source_dir, dependencies, git_config, job_name, inputs, kms_key=None
):
"""Pack local code bundle and upload to Amazon S3."""
from sagemaker.workflow.utilities import _pipeline_config, hash_object

if code.startswith("s3://"):
return code, inputs, job_name

@@ -1737,12 +1739,29 @@ def _pack_and_upload_code(
"runproc.sh",
)
script = estimator.uploaded_code.script_name
s3_runproc_sh = S3Uploader.upload_string_as_file_body(
self._generate_framework_script(script),
desired_s3_uri=entrypoint_s3_uri,
kms_key=kms_key,
sagemaker_session=self.sagemaker_session,
)

# If we are leveraging a pipeline session with optimized s3 artifact paths,
# we need to hash and upload the runproc.sh file to a separate location.
if _pipeline_config and _pipeline_config.pipeline_name:
Reviewer comment (Contributor): Nit: it may be good to extract this to a helper function, in keeping with the single-responsibility principle (SRP).

runproc_file_str = self._generate_framework_script(script)
runproc_file_hash = hash_object(runproc_file_str)
s3_uri = (
f"s3://{self.sagemaker_session.default_bucket()}/{_pipeline_config.pipeline_name}/"
f"code/{runproc_file_hash}/runproc.sh"
)
s3_runproc_sh = S3Uploader.upload_string_as_file_body(
runproc_file_str,
desired_s3_uri=s3_uri,
kms_key=kms_key,
sagemaker_session=self.sagemaker_session,
)
else:
s3_runproc_sh = S3Uploader.upload_string_as_file_body(
self._generate_framework_script(script),
desired_s3_uri=entrypoint_s3_uri,
kms_key=kms_key,
sagemaker_session=self.sagemaker_session,
)
logger.info("runproc.sh uploaded to %s", s3_runproc_sh)

return s3_runproc_sh, inputs, job_name
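
Following up on the reviewer's nit above about extracting a helper: a hypothetical sketch of what such an extraction could look like. This is not part of the PR; the method name _upload_runproc_script is an assumption, and the body simply wraps the pipeline-vs-default upload branch shown in the diff above.

    def _upload_runproc_script(self, script, entrypoint_s3_uri, kms_key):
        """Upload the generated runproc.sh and return its S3 URI (hypothetical helper)."""
        from sagemaker.s3 import S3Uploader
        from sagemaker.workflow.utilities import _pipeline_config, hash_object

        runproc_file_str = self._generate_framework_script(script)
        if _pipeline_config and _pipeline_config.pipeline_name:
            # Pipeline session with optimized S3 artifact paths: key the upload by a
            # content hash so identical scripts resolve to the same S3 location.
            runproc_file_hash = hash_object(runproc_file_str)
            desired_s3_uri = (
                f"s3://{self.sagemaker_session.default_bucket()}/"
                f"{_pipeline_config.pipeline_name}/code/{runproc_file_hash}/runproc.sh"
            )
        else:
            desired_s3_uri = entrypoint_s3_uri
        return S3Uploader.upload_string_as_file_body(
            runproc_file_str,
            desired_s3_uri=desired_s3_uri,
            kms_key=kms_key,
            sagemaker_session=self.sagemaker_session,
        )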
@@ -1827,14 +1846,19 @@ def _patch_inputs_with_payload(self, inputs, s3_payload) -> List[ProcessingInput
# a7399455f5386d83ddc5cb15c0db00c04bd518ec/src/sagemaker/processing.py#L425-L426
if inputs is None:
inputs = []
inputs.append(

# make a shallow copy of user inputs
patched_inputs = []
for user_input in inputs:
patched_inputs.append(user_input)
patched_inputs.append(
ProcessingInput(
input_name="code",
source=s3_payload,
destination="/opt/ml/processing/input/code/",
)
)
return inputs
return patched_inputs

def _set_entrypoint(self, command, user_script_name):
"""Framework processor override for setting processing job entrypoint.
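The shallow-copy change in _patch_inputs_with_payload above keeps the caller's inputs list untouched instead of appending the generated "code" input to it. A minimal standalone illustration of the side effect that copying avoids (hypothetical, not from the PR):

    # Hypothetical illustration (not from the PR) of why the method copies the list.
    def patch_without_copy(inputs):
        inputs.append("code")          # mutates the list the caller passed in
        return inputs

    def patch_with_copy(inputs):
        patched = list(inputs)         # shallow copy; the caller's list is untouched
        patched.append("code")
        return patched

    user_inputs = []
    patch_without_copy(user_inputs)
    patch_without_copy(user_inputs)
    print(user_inputs)                 # ['code', 'code'] -> duplicates accumulate on reuse

    user_inputs = []
    patch_with_copy(user_inputs)
    patch_with_copy(user_inputs)
    print(user_inputs)                 # [] -> original list is unchanged
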
44 changes: 37 additions & 7 deletions src/sagemaker/spark/processing.py
@@ -279,6 +279,12 @@ def run(
def _extend_processing_args(self, inputs, outputs, **kwargs):
"""Extends processing job args such as inputs."""

# make a copy of user outputs
outputs = outputs or []
extended_outputs = []
for user_output in outputs:
extended_outputs.append(user_output)

if kwargs.get("spark_event_logs_s3_uri"):
spark_event_logs_s3_uri = kwargs.get("spark_event_logs_s3_uri")
self._validate_s3_uri(spark_event_logs_s3_uri)
@@ -297,16 +303,23 @@ def _extend_processing_args(self, inputs, outputs, **kwargs):
s3_upload_mode="Continuous",
)

outputs = outputs or []
outputs.append(output)
extended_outputs.append(output)

# make a copy of user inputs
inputs = inputs or []
extended_inputs = []
for user_input in inputs:
extended_inputs.append(user_input)

if kwargs.get("configuration"):
configuration = kwargs.get("configuration")
self._validate_configuration(configuration)
inputs = inputs or []
inputs.append(self._stage_configuration(configuration))
extended_inputs.append(self._stage_configuration(configuration))

return inputs, outputs
return (
extended_inputs if extended_inputs else None,
extended_outputs if extended_outputs else None,
)

def start_history_server(self, spark_event_logs_s3_uri=None):
"""Starts a Spark history server.
@@ -940,9 +953,18 @@ def _extend_processing_args(self, inputs, outputs, **kwargs):
outputs: Processing outputs.
kwargs: Additional keyword arguments passed to `super()`.
"""

if inputs is None:
inputs = []

# make a shallow copy of user inputs
extended_inputs = []
for user_input in inputs:
extended_inputs.append(user_input)

self.command = [_SparkProcessorBase._default_command]
extended_inputs = self._handle_script_dependencies(
inputs, kwargs.get("submit_py_files"), FileType.PYTHON
extended_inputs, kwargs.get("submit_py_files"), FileType.PYTHON
)
extended_inputs = self._handle_script_dependencies(
extended_inputs, kwargs.get("submit_jars"), FileType.JAR
@@ -1199,8 +1221,16 @@ def _extend_processing_args(self, inputs, outputs, **kwargs):
else:
raise ValueError("submit_class is required")

if inputs is None:
inputs = []

# make a shallow copy of user inputs
extended_inputs = []
for user_input in inputs:
extended_inputs.append(user_input)

extended_inputs = self._handle_script_dependencies(
inputs, kwargs.get("submit_jars"), FileType.JAR
extended_inputs, kwargs.get("submit_jars"), FileType.JAR
)
extended_inputs = self._handle_script_dependencies(
extended_inputs, kwargs.get("submit_files"), FileType.FILE
7 changes: 4 additions & 3 deletions src/sagemaker/workflow/utilities.py
@@ -114,11 +114,12 @@ def get_code_hash(step: Entity) -> str:
if isinstance(step, ProcessingStep) and step.step_args:
kwargs = step.step_args.func_kwargs
source_dir = kwargs.get("source_dir")
submit_class = kwargs.get("submit_class")
dependencies = get_processing_dependencies(
[
kwargs.get("dependencies"),
kwargs.get("submit_py_files"),
kwargs.get("submit_class"),
[submit_class] if submit_class else None,
kwargs.get("submit_jars"),
kwargs.get("submit_files"),
]
@@ -168,7 +169,7 @@ def get_processing_code_hash(code: str, source_dir: str, dependencies: List[str]
str: A hash string representing the unique code artifact(s) for the step
"""

# FrameworkProcessor
# If FrameworkProcessor contains source_dir
if source_dir:
source_dir_url = urlparse(source_dir)
if source_dir_url.scheme == "" or source_dir_url.scheme == "file":
@@ -400,5 +401,5 @@ def execute_job_functions(step_args: _StepArguments):
"""

chained_args = step_args.func(*step_args.func_args, **step_args.func_kwargs)
if chained_args:
if isinstance(chained_args, _StepArguments):
execute_job_functions(chained_args)
5 changes: 5 additions & 0 deletions tests/data/framework_processor_data/evaluate.py
@@ -0,0 +1,5 @@
"""
Integ test file evaluate.py
Reviewer comment (Contributor): Nit: since these test data are used only by pipeline tests, let's move them to the data/pipeline or data/workflow folder. In the future we should merge these two folders into one.

Reviewer comment (Contributor): Name it "data/pipeline/test_source_dir".

"""

print("test evaluate script")
5 changes: 5 additions & 0 deletions tests/data/framework_processor_data/preprocess.py
@@ -0,0 +1,5 @@
"""
Integ test file preprocess.py
"""

print("test preprocess script")
5 changes: 5 additions & 0 deletions tests/data/framework_processor_data/query_data.py
@@ -0,0 +1,5 @@
"""
Integ test file query_data.py
"""

print("test query data script")
5 changes: 5 additions & 0 deletions tests/data/framework_processor_data/train_test_split.py
@@ -0,0 +1,5 @@
"""
Integ test file train_test_split.py
"""

print("test train, test, split script")
Binary file added tests/data/spark/code/java/TestJarFile.jar
Binary file not shown.