30
30
from sagemaker .local import LocalSession
31
31
from sagemaker .utils import base_name_from_image , name_from_base
32
32
from sagemaker .session import Session
33
- from sagemaker .network import NetworkConfig # noqa: F401 # pylint: disable=unused-import
34
33
from sagemaker .workflow .properties import Properties
35
34
from sagemaker .workflow .parameters import Parameter
36
35
from sagemaker .workflow .entities import Expression
@@ -219,14 +218,14 @@ def _normalize_args(
219
218
"""
220
219
self ._current_job_name = self ._generate_current_job_name (job_name = job_name )
221
220
222
- inputs_with_code = self ._include_code_in_inputs (inputs , code )
221
+ inputs_with_code = self ._include_code_in_inputs (inputs , code , kms_key )
223
222
normalized_inputs = self ._normalize_inputs (inputs_with_code , kms_key )
224
223
normalized_outputs = self ._normalize_outputs (outputs )
225
224
self .arguments = arguments
226
225
227
226
return normalized_inputs , normalized_outputs
228
227
229
- def _include_code_in_inputs (self , inputs , _code ):
228
+ def _include_code_in_inputs (self , inputs , _code , _kms_key ):
230
229
"""A no op in the base class to include code in the processing job inputs.
231
230
232
231
Args:
@@ -235,6 +234,8 @@ def _include_code_in_inputs(self, inputs, _code):
235
234
:class:`~sagemaker.processing.ProcessingInput` objects.
236
235
_code (str): This can be an S3 URI or a local path to a file with the framework
237
236
script to run (default: None). A no op in the base class.
237
+ kms_key (str): The ARN of the KMS key that is used to encrypt the
238
+ user code file (default: None).
238
239
239
240
Returns:
240
241
list[:class:`~sagemaker.processing.ProcessingInput`]: inputs
@@ -528,7 +529,7 @@ def run(
528
529
if wait :
529
530
self .latest_job .wait (logs = logs )
530
531
531
- def _include_code_in_inputs (self , inputs , code ):
532
+ def _include_code_in_inputs (self , inputs , code , kms_key = None ):
532
533
"""Converts code to appropriate input and includes in input list.
533
534
534
535
Side effects include:
@@ -541,12 +542,14 @@ def _include_code_in_inputs(self, inputs, code):
541
542
:class:`~sagemaker.processing.ProcessingInput` objects.
542
543
code (str): This can be an S3 URI or a local path to a file with the framework
543
544
script to run (default: None).
545
+ kms_key (str): The ARN of the KMS key that is used to encrypt the
546
+ user code file (default: None).
544
547
545
548
Returns:
546
549
list[:class:`~sagemaker.processing.ProcessingInput`]: inputs together with the
547
550
code as `ProcessingInput`.
548
551
"""
549
- user_code_s3_uri = self ._handle_user_code_url (code )
552
+ user_code_s3_uri = self ._handle_user_code_url (code , kms_key )
550
553
user_script_name = self ._get_user_code_name (code )
551
554
552
555
inputs_with_code = self ._convert_code_and_add_to_inputs (inputs , user_code_s3_uri )
@@ -567,14 +570,16 @@ def _get_user_code_name(self, code):
567
570
code_url = urlparse (code )
568
571
return os .path .basename (code_url .path )
569
572
570
- def _handle_user_code_url (self , code ):
573
+ def _handle_user_code_url (self , code , kms_key = None ):
571
574
"""Gets the S3 URL containing the user's code.
572
575
573
576
Inspects the scheme the customer passed in ("s3://" for code in S3, "file://" or nothing
574
577
for absolute or local file paths. Uploads the code to S3 if the code is a local file.
575
578
576
579
Args:
577
580
code (str): A URL to the customer's code.
581
+ kms_key (str): The ARN of the KMS key that is used to encrypt the
582
+ user code file (default: None).
578
583
579
584
Returns:
580
585
str: The S3 URL to the customer's code.
@@ -603,7 +608,7 @@ def _handle_user_code_url(self, code):
603
608
code
604
609
)
605
610
)
606
- user_code_s3_uri = self ._upload_code (code_path )
611
+ user_code_s3_uri = self ._upload_code (code_path , kms_key )
607
612
else :
608
613
raise ValueError (
609
614
"code {} url scheme {} is not recognized. Please pass a file path or S3 url" .format (
@@ -612,11 +617,13 @@ def _handle_user_code_url(self, code):
612
617
)
613
618
return user_code_s3_uri
614
619
615
- def _upload_code (self , code ):
620
+ def _upload_code (self , code , kms_key = None ):
616
621
"""Uploads a code file or directory specified as a string and returns the S3 URI.
617
622
618
623
Args:
619
624
code (str): A file or directory to be uploaded to S3.
625
+ kms_key (str): The ARN of the KMS key that is used to encrypt the
626
+ user code file (default: None).
620
627
621
628
Returns:
622
629
str: The S3 URI of the uploaded file or directory.
@@ -630,7 +637,10 @@ def _upload_code(self, code):
630
637
self ._CODE_CONTAINER_INPUT_NAME ,
631
638
)
632
639
return s3 .S3Uploader .upload (
633
- local_path = code , desired_s3_uri = desired_s3_uri , sagemaker_session = self .sagemaker_session
640
+ local_path = code ,
641
+ desired_s3_uri = desired_s3_uri ,
642
+ kms_key = kms_key ,
643
+ sagemaker_session = self .sagemaker_session ,
634
644
)
635
645
636
646
def _convert_code_and_add_to_inputs (self , inputs , s3_uri ):
@@ -666,7 +676,9 @@ def _set_entrypoint(self, command, user_script_name):
666
676
"""
667
677
user_script_location = str (
668
678
pathlib .PurePosixPath (
669
- self ._CODE_CONTAINER_BASE_PATH , self ._CODE_CONTAINER_INPUT_NAME , user_script_name
679
+ self ._CODE_CONTAINER_BASE_PATH ,
680
+ self ._CODE_CONTAINER_INPUT_NAME ,
681
+ user_script_name ,
670
682
)
671
683
)
672
684
self .entrypoint = command + [user_script_location ]
@@ -1066,7 +1078,10 @@ def _to_request_dict(self):
1066
1078
"""Generates a request dictionary using the parameters provided to the class."""
1067
1079
1068
1080
# Create the request dictionary.
1069
- s3_input_request = {"InputName" : self .input_name , "AppManaged" : self .app_managed }
1081
+ s3_input_request = {
1082
+ "InputName" : self .input_name ,
1083
+ "AppManaged" : self .app_managed ,
1084
+ }
1070
1085
1071
1086
if self .s3_input :
1072
1087
# Check the compression type, then add it to the dictionary.
0 commit comments