Skip to content

Commit 8cab930

Browse files
support kms key in processor pack local code
1 parent 5bc3ccf commit 8cab930

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

src/sagemaker/processing.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,7 @@ def _normalize_outputs(self, outputs=None):
392392
output.destination = s3_uri
393393
normalized_outputs.append(output)
394394
return normalized_outputs
395+
return normalized_outputs
395396

396397

397398
class ScriptProcessor(Processor):
@@ -1622,7 +1623,7 @@ def run( # type: ignore[override]
16221623
:class:`~sagemaker.workflow.pipeline_context.PipelineSession`
16231624
"""
16241625
s3_runproc_sh, inputs, job_name = self._pack_and_upload_code(
1625-
code, source_dir, dependencies, git_config, job_name, inputs
1626+
code, source_dir, dependencies, git_config, job_name, inputs, kms_key
16261627
)
16271628

16281629
# Submit a processing job.
@@ -1638,7 +1639,9 @@ def run( # type: ignore[override]
16381639
kms_key=kms_key,
16391640
)
16401641

1641-
def _pack_and_upload_code(self, code, source_dir, dependencies, git_config, job_name, inputs):
1642+
def _pack_and_upload_code(
1643+
self, code, source_dir, dependencies, git_config, job_name, inputs, kms_key=None
1644+
):
16421645
"""Pack local code bundle and upload to Amazon S3."""
16431646
if code.startswith("s3://"):
16441647
return code, inputs, job_name
@@ -1676,6 +1679,7 @@ def _pack_and_upload_code(self, code, source_dir, dependencies, git_config, job_
16761679
s3_runproc_sh = S3Uploader.upload_string_as_file_body(
16771680
self._generate_framework_script(script),
16781681
desired_s3_uri=entrypoint_s3_uri,
1682+
kms_key=kms_key,
16791683
sagemaker_session=self.sagemaker_session,
16801684
)
16811685
logger.info("runproc.sh uploaded to %s", s3_runproc_sh)

tests/unit/sagemaker/workflow/test_pipeline_session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def test_pipeline_session_init(sagemaker_client_config, boto_session):
5050
sagemaker_client=sagemaker_client,
5151
)
5252
assert sess.sagemaker_client is not None
53-
assert sess.default_bucket() is not None
53+
assert sess.default_bucket is not None
5454
assert sess.context is None
5555

5656

0 commit comments

Comments
 (0)