Skip to content

Commit 7bb613e

Browse files
chuyang-deng (Chuyang Deng) and Chuyang Deng authored
fix: allow kms encryption upload for processing (#1897)
* fix: allow kms encryption upload for processing
* update spark processing with kms keys
* move kms_key to end of arg list

Co-authored-by: Chuyang Deng <[email protected]>
1 parent 233e9b9 commit 7bb613e

File tree

6 files changed

+54
-6
lines changed

6 files changed

+54
-6
lines changed

src/sagemaker/processing.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def run(
119119
logs=True,
120120
job_name=None,
121121
experiment_config=None,
122+
kms_key=None,
122123
):
123124
"""Runs a processing job.
124125
@@ -139,6 +140,8 @@ def run(
139140
experiment_config (dict[str, str]): Experiment management configuration.
140141
Dictionary contains three optional keys:
141142
'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
143+
kms_key (str): The ARN of the KMS key that is used to encrypt the
144+
user code file (default: None).
142145
143146
Raises:
144147
ValueError: if ``logs`` is True but ``wait`` is False.
@@ -153,6 +156,7 @@ def run(
153156
job_name=job_name,
154157
arguments=arguments,
155158
inputs=inputs,
159+
kms_key=kms_key,
156160
outputs=outputs,
157161
)
158162

@@ -170,7 +174,15 @@ def _extend_processing_args(self, inputs, outputs, **kwargs): # pylint: disable
170174
"""Extend inputs and outputs based on extra parameters"""
171175
return inputs, outputs
172176

173-
def _normalize_args(self, job_name=None, arguments=None, inputs=None, outputs=None, code=None):
177+
def _normalize_args(
178+
self,
179+
job_name=None,
180+
arguments=None,
181+
inputs=None,
182+
outputs=None,
183+
code=None,
184+
kms_key=None,
185+
):
174186
"""Normalizes the arguments so that they can be passed to the job run
175187
176188
Args:
@@ -182,6 +194,8 @@ def _normalize_args(self, job_name=None, arguments=None, inputs=None, outputs=No
182194
inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
183195
the processing job. These must be provided as
184196
:class:`~sagemaker.processing.ProcessingInput` objects (default: None).
197+
kms_key (str): The ARN of the KMS key that is used to encrypt the
198+
user code file (default: None).
185199
outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
186200
the processing job. These can be specified as either path strings or
187201
:class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
@@ -191,7 +205,7 @@ def _normalize_args(self, job_name=None, arguments=None, inputs=None, outputs=No
191205
self._current_job_name = self._generate_current_job_name(job_name=job_name)
192206

193207
inputs_with_code = self._include_code_in_inputs(inputs, code)
194-
normalized_inputs = self._normalize_inputs(inputs_with_code)
208+
normalized_inputs = self._normalize_inputs(inputs_with_code, kms_key)
195209
normalized_outputs = self._normalize_outputs(outputs)
196210
self.arguments = arguments
197211

@@ -233,13 +247,15 @@ def _generate_current_job_name(self, job_name=None):
233247

234248
return name_from_base(base_name)
235249

236-
def _normalize_inputs(self, inputs=None):
250+
def _normalize_inputs(self, inputs=None, kms_key=None):
237251
"""Ensures that all the ``ProcessingInput`` objects have names and S3 URIs.
238252
239253
Args:
240254
inputs (list[sagemaker.processing.ProcessingInput]): A list of ``ProcessingInput``
241255
objects to be normalized (default: None). If not specified,
242256
an empty list is returned.
257+
kms_key (str): The ARN of the KMS key that is used to encrypt the
258+
user code file (default: None).
243259
244260
Returns:
245261
list[sagemaker.processing.ProcessingInput]: The list of normalized
@@ -273,6 +289,7 @@ def _normalize_inputs(self, inputs=None):
273289
local_path=file_input.source,
274290
desired_s3_uri=desired_s3_uri,
275291
sagemaker_session=self.sagemaker_session,
292+
kms_key=kms_key,
276293
)
277294
file_input.source = s3_uri
278295
normalized_inputs.append(file_input)
@@ -412,6 +429,7 @@ def run(
412429
logs=True,
413430
job_name=None,
414431
experiment_config=None,
432+
kms_key=None,
415433
):
416434
"""Runs a processing job.
417435
@@ -434,13 +452,16 @@ def run(
434452
experiment_config (dict[str, str]): Experiment management configuration.
435453
Dictionary contains three optional keys:
436454
'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
455+
kms_key (str): The ARN of the KMS key that is used to encrypt the
456+
user code file (default: None).
437457
"""
438458
normalized_inputs, normalized_outputs = self._normalize_args(
439459
job_name=job_name,
440460
arguments=arguments,
441461
inputs=inputs,
442462
outputs=outputs,
443463
code=code,
464+
kms_key=kms_key,
444465
)
445466

446467
self.latest_job = ProcessingJob.start_new(

src/sagemaker/s3.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,8 @@ def upload(local_path, desired_s3_uri, kms_key=None, sagemaker_session=None):
8383
sagemaker_session = sagemaker_session or Session()
8484
bucket, key_prefix = parse_s3_url(url=desired_s3_uri)
8585
if kms_key is not None:
86-
extra_args = {"SSEKMSKeyId": kms_key}
86+
extra_args = {"SSEKMSKeyId": kms_key, "ServerSideEncryption": "aws:kms"}
87+
8788
else:
8889
extra_args = None
8990

src/sagemaker/spark/processing.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ def run(
180180
logs=True,
181181
job_name=None,
182182
experiment_config=None,
183+
kms_key=None,
183184
):
184185
"""Runs a processing job.
185186
@@ -201,6 +202,8 @@ def run(
201202
experiment_config (dict[str, str]): Experiment management configuration.
202203
Dictionary contains three optional keys:
203204
'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
205+
kms_key (str): The ARN of the KMS key that is used to encrypt the
206+
user code file (default: None).
204207
"""
205208
self._current_job_name = self._generate_current_job_name(job_name=job_name)
206209

@@ -213,6 +216,7 @@ def run(
213216
logs,
214217
job_name,
215218
experiment_config,
219+
kms_key,
216220
)
217221

218222
def _extend_processing_args(self, inputs, outputs, **kwargs):
@@ -695,6 +699,7 @@ def run(
695699
experiment_config=None,
696700
configuration=None,
697701
spark_event_logs_s3_uri=None,
702+
kms_key=None,
698703
):
699704
"""Runs a processing job.
700705
@@ -728,6 +733,8 @@ def run(
728733
https://docs.aws.amazon.com/emr/latest/ReleaseGuide/emr-configure-apps.html
729734
spark_event_logs_s3_uri (str): S3 path where spark application events will
730735
be published to.
736+
kms_key (str): The ARN of the KMS key that is used to encrypt the
737+
user code file (default: None).
731738
"""
732739
self._current_job_name = self._generate_current_job_name(job_name=job_name)
733740

@@ -872,6 +879,7 @@ def run(
872879
experiment_config=None,
873880
configuration=None,
874881
spark_event_logs_s3_uri=None,
882+
kms_key=None,
875883
):
876884
"""Runs a processing job.
877885
@@ -905,6 +913,8 @@ def run(
905913
https://docs.aws.amazon.com/emr/latest/ReleaseGuide/emr-configure-apps.html
906914
spark_event_logs_s3_uri (str): S3 path where spark application events will
907915
be published to.
916+
kms_key (str): The ARN of the KMS key that is used to encrypt the
917+
user code file (default: None).
908918
"""
909919
self._current_job_name = self._generate_current_job_name(job_name=job_name)
910920

@@ -930,6 +940,7 @@ def run(
930940
logs=logs,
931941
job_name=self._current_job_name,
932942
experiment_config=experiment_config,
943+
kms_key=kms_key,
933944
)
934945

935946
def _extend_processing_args(self, inputs, outputs, **kwargs):

tests/integ/test_processing.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,18 @@ def volume_kms_key(sagemaker_session):
8585
)
8686

8787

88+
@pytest.fixture(scope="module")
89+
def input_kms_key(sagemaker_session):
90+
role_arn = sagemaker_session.expand_role(ROLE)
91+
return get_or_create_kms_key(
92+
sagemaker_session=sagemaker_session,
93+
role_arn=role_arn,
94+
alias="integ-test-processing-input-kms-key-{}".format(
95+
sagemaker_session.boto_session.region_name
96+
),
97+
)
98+
99+
88100
@pytest.fixture(scope="module")
89101
def output_kms_key(sagemaker_session):
90102
role_arn = sagemaker_session.expand_role(ROLE)
@@ -584,6 +596,7 @@ def test_processor_with_custom_bucket(
584596
image_uri,
585597
cpu_instance_type,
586598
output_kms_key,
599+
input_kms_key,
587600
):
588601
script_path = os.path.join(DATA_DIR, "dummy_script.py")
589602

@@ -609,6 +622,7 @@ def test_processor_with_custom_bucket(
609622
source=script_path, destination="/opt/ml/processing/input/code/", input_name="code"
610623
)
611624
],
625+
kms_key=input_kms_key,
612626
outputs=[
613627
ProcessingOutput(
614628
source="/opt/ml/processing/output/container/path/",

tests/unit/sagemaker/spark/test_processing.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def test_configuration_validation(config, expected, sagemaker_session) -> None:
200200
def test_spark_processor_base_run(mock_super_run, spark_processor_base):
201201
spark_processor_base.run(submit_app="app")
202202

203-
mock_super_run.assert_called_with("app", None, None, None, True, True, None, None)
203+
mock_super_run.assert_called_with("app", None, None, None, True, True, None, None, None)
204204

205205

206206
@pytest.mark.parametrize(
@@ -854,6 +854,7 @@ def test_spark_jar_processor_run(
854854
logs=True,
855855
job_name="jobName",
856856
experiment_config=None,
857+
kms_key=None,
857858
)
858859

859860

tests/unit/test_s3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def test_upload_with_kms_key(sagemaker_session):
6767
path="/path/to/app.jar",
6868
bucket=BUCKET_NAME,
6969
key_prefix=os.path.join(CURRENT_JOB_NAME, SOURCE_NAME),
70-
extra_args={"SSEKMSKeyId": KMS_KEY},
70+
extra_args={"SSEKMSKeyId": KMS_KEY, "ServerSideEncryption": "aws:kms"},
7171
)
7272

7373

0 commit comments

Comments
 (0)