Skip to content

Commit fa1b292

Browse files
qvoglund (Quentin R. Voglund) and ahsan-z-khan
authored
fix: add kms key for processing job code upload (#2329)
Co-authored-by: Quentin R. Voglund <[email protected]> Co-authored-by: Ahsan Khan <[email protected]>
1 parent cca8476 commit fa1b292

File tree

1 file changed

+26
-11
lines changed

1 file changed

+26
-11
lines changed

src/sagemaker/processing.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
from sagemaker.local import LocalSession
3131
from sagemaker.utils import base_name_from_image, name_from_base
3232
from sagemaker.session import Session
33-
from sagemaker.network import NetworkConfig # noqa: F401 # pylint: disable=unused-import
3433
from sagemaker.workflow.properties import Properties
3534
from sagemaker.workflow.parameters import Parameter
3635
from sagemaker.workflow.entities import Expression
@@ -219,14 +218,14 @@ def _normalize_args(
219218
"""
220219
self._current_job_name = self._generate_current_job_name(job_name=job_name)
221220

222-
inputs_with_code = self._include_code_in_inputs(inputs, code)
221+
inputs_with_code = self._include_code_in_inputs(inputs, code, kms_key)
223222
normalized_inputs = self._normalize_inputs(inputs_with_code, kms_key)
224223
normalized_outputs = self._normalize_outputs(outputs)
225224
self.arguments = arguments
226225

227226
return normalized_inputs, normalized_outputs
228227

229-
def _include_code_in_inputs(self, inputs, _code):
228+
def _include_code_in_inputs(self, inputs, _code, _kms_key):
230229
"""A no op in the base class to include code in the processing job inputs.
231230
232231
Args:
@@ -235,6 +234,8 @@ def _include_code_in_inputs(self, inputs, _code):
235234
:class:`~sagemaker.processing.ProcessingInput` objects.
236235
_code (str): This can be an S3 URI or a local path to a file with the framework
237236
script to run (default: None). A no op in the base class.
237+
_kms_key (str): The ARN of the KMS key that is used to encrypt the
238+
user code file (default: None).
238239
239240
Returns:
240241
list[:class:`~sagemaker.processing.ProcessingInput`]: inputs
@@ -528,7 +529,7 @@ def run(
528529
if wait:
529530
self.latest_job.wait(logs=logs)
530531

531-
def _include_code_in_inputs(self, inputs, code):
532+
def _include_code_in_inputs(self, inputs, code, kms_key=None):
532533
"""Converts code to appropriate input and includes in input list.
533534
534535
Side effects include:
@@ -541,12 +542,14 @@ def _include_code_in_inputs(self, inputs, code):
541542
:class:`~sagemaker.processing.ProcessingInput` objects.
542543
code (str): This can be an S3 URI or a local path to a file with the framework
543544
script to run (default: None).
545+
kms_key (str): The ARN of the KMS key that is used to encrypt the
546+
user code file (default: None).
544547
545548
Returns:
546549
list[:class:`~sagemaker.processing.ProcessingInput`]: inputs together with the
547550
code as `ProcessingInput`.
548551
"""
549-
user_code_s3_uri = self._handle_user_code_url(code)
552+
user_code_s3_uri = self._handle_user_code_url(code, kms_key)
550553
user_script_name = self._get_user_code_name(code)
551554

552555
inputs_with_code = self._convert_code_and_add_to_inputs(inputs, user_code_s3_uri)
@@ -567,14 +570,16 @@ def _get_user_code_name(self, code):
567570
code_url = urlparse(code)
568571
return os.path.basename(code_url.path)
569572

570-
def _handle_user_code_url(self, code):
573+
def _handle_user_code_url(self, code, kms_key=None):
571574
"""Gets the S3 URL containing the user's code.
572575
573576
Inspects the scheme the customer passed in ("s3://" for code in S3, "file://" or nothing
574577
for absolute or local file paths). Uploads the code to S3 if the code is a local file.
575578
576579
Args:
577580
code (str): A URL to the customer's code.
581+
kms_key (str): The ARN of the KMS key that is used to encrypt the
582+
user code file (default: None).
578583
579584
Returns:
580585
str: The S3 URL to the customer's code.
@@ -603,7 +608,7 @@ def _handle_user_code_url(self, code):
603608
code
604609
)
605610
)
606-
user_code_s3_uri = self._upload_code(code_path)
611+
user_code_s3_uri = self._upload_code(code_path, kms_key)
607612
else:
608613
raise ValueError(
609614
"code {} url scheme {} is not recognized. Please pass a file path or S3 url".format(
@@ -612,11 +617,13 @@ def _handle_user_code_url(self, code):
612617
)
613618
return user_code_s3_uri
614619

615-
def _upload_code(self, code):
620+
def _upload_code(self, code, kms_key=None):
616621
"""Uploads a code file or directory specified as a string and returns the S3 URI.
617622
618623
Args:
619624
code (str): A file or directory to be uploaded to S3.
625+
kms_key (str): The ARN of the KMS key that is used to encrypt the
626+
user code file (default: None).
620627
621628
Returns:
622629
str: The S3 URI of the uploaded file or directory.
@@ -630,7 +637,10 @@ def _upload_code(self, code):
630637
self._CODE_CONTAINER_INPUT_NAME,
631638
)
632639
return s3.S3Uploader.upload(
633-
local_path=code, desired_s3_uri=desired_s3_uri, sagemaker_session=self.sagemaker_session
640+
local_path=code,
641+
desired_s3_uri=desired_s3_uri,
642+
kms_key=kms_key,
643+
sagemaker_session=self.sagemaker_session,
634644
)
635645

636646
def _convert_code_and_add_to_inputs(self, inputs, s3_uri):
@@ -666,7 +676,9 @@ def _set_entrypoint(self, command, user_script_name):
666676
"""
667677
user_script_location = str(
668678
pathlib.PurePosixPath(
669-
self._CODE_CONTAINER_BASE_PATH, self._CODE_CONTAINER_INPUT_NAME, user_script_name
679+
self._CODE_CONTAINER_BASE_PATH,
680+
self._CODE_CONTAINER_INPUT_NAME,
681+
user_script_name,
670682
)
671683
)
672684
self.entrypoint = command + [user_script_location]
@@ -1066,7 +1078,10 @@ def _to_request_dict(self):
10661078
"""Generates a request dictionary using the parameters provided to the class."""
10671079

10681080
# Create the request dictionary.
1069-
s3_input_request = {"InputName": self.input_name, "AppManaged": self.app_managed}
1081+
s3_input_request = {
1082+
"InputName": self.input_name,
1083+
"AppManaged": self.app_managed,
1084+
}
10701085

10711086
if self.s3_input:
10721087
# Check the compression type, then add it to the dictionary.

0 commit comments

Comments
 (0)