Skip to content

fix: allow kms encryption upload for processing #1897

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Sep 21, 2020
21 changes: 18 additions & 3 deletions src/sagemaker/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def __init__(
def run(
self,
inputs=None,
kms_key=None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe place this later in the arg list. If a customer is relying on ordinal position rather than naming args on invocation, then this would be a breaking change.

outputs=None,
arguments=None,
wait=True,
Expand All @@ -126,6 +127,8 @@ def run(
inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
the processing job. These must be provided as
:class:`~sagemaker.processing.ProcessingInput` objects (default: None).
kms_key (str): The ARN of the KMS key that is used to encrypt the
user code file (default: None).
outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
the processing job. These can be specified as either path strings or
:class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
Expand Down Expand Up @@ -153,6 +156,7 @@ def run(
job_name=job_name,
arguments=arguments,
inputs=inputs,
kms_key=kms_key,
outputs=outputs,
)

Expand All @@ -170,7 +174,9 @@ def _extend_processing_args(self, inputs, outputs, **kwargs): # pylint: disable
"""Extend inputs and outputs based on extra parameters"""
return inputs, outputs

def _normalize_args(self, job_name=None, arguments=None, inputs=None, outputs=None, code=None):
def _normalize_args(
self, job_name=None, arguments=None, inputs=None, kms_key=None, outputs=None, code=None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make `code` and `kms_key` adjacent args, preferably with `kms_key` after `code`, since they are the args that are associated?

Suggested change
self, job_name=None, arguments=None, inputs=None, kms_key=None, outputs=None, code=None
self, job_name=None, arguments=None, inputs=None, outputs=None, code=None, kms_key=None

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it makes more sense to put relevant args together.

):
"""Normalizes the arguments so that they can be passed to the job run

Args:
Expand All @@ -182,6 +188,8 @@ def _normalize_args(self, job_name=None, arguments=None, inputs=None, outputs=No
inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
the processing job. These must be provided as
:class:`~sagemaker.processing.ProcessingInput` objects (default: None).
kms_key (str): The ARN of the KMS key that is used to encrypt the
user code file (default: None).
outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
the processing job. These can be specified as either path strings or
:class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
Expand All @@ -191,7 +199,7 @@ def _normalize_args(self, job_name=None, arguments=None, inputs=None, outputs=No
self._current_job_name = self._generate_current_job_name(job_name=job_name)

inputs_with_code = self._include_code_in_inputs(inputs, code)
normalized_inputs = self._normalize_inputs(inputs_with_code)
normalized_inputs = self._normalize_inputs(inputs_with_code, kms_key)
normalized_outputs = self._normalize_outputs(outputs)
self.arguments = arguments

Expand Down Expand Up @@ -233,13 +241,15 @@ def _generate_current_job_name(self, job_name=None):

return name_from_base(base_name)

def _normalize_inputs(self, inputs=None):
def _normalize_inputs(self, inputs=None, kms_key=None):
"""Ensures that all the ``ProcessingInput`` objects have names and S3 URIs.

Args:
inputs (list[sagemaker.processing.ProcessingInput]): A list of ``ProcessingInput``
objects to be normalized (default: None). If not specified,
an empty list is returned.
kms_key (str): The ARN of the KMS key that is used to encrypt the
user code file (default: None).

Returns:
list[sagemaker.processing.ProcessingInput]: The list of normalized
Expand Down Expand Up @@ -273,6 +283,7 @@ def _normalize_inputs(self, inputs=None):
local_path=file_input.source,
desired_s3_uri=desired_s3_uri,
sagemaker_session=self.sagemaker_session,
kms_key=kms_key,
)
file_input.source = s3_uri
normalized_inputs.append(file_input)
Expand Down Expand Up @@ -406,6 +417,7 @@ def run(
self,
code,
inputs=None,
kms_key=None,
outputs=None,
arguments=None,
wait=True,
Expand All @@ -421,6 +433,8 @@ def run(
inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
the processing job. These must be provided as
:class:`~sagemaker.processing.ProcessingInput` objects (default: None).
kms_key (str): The ARN of the KMS key that is used to encrypt the
user code file (default: None).
outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
the processing job. These can be specified as either path strings or
:class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
Expand All @@ -439,6 +453,7 @@ def run(
job_name=job_name,
arguments=arguments,
inputs=inputs,
kms_key=kms_key,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we push this to after `code`?

outputs=outputs,
code=code,
)
Expand Down
3 changes: 2 additions & 1 deletion src/sagemaker/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ def upload(local_path, desired_s3_uri, kms_key=None, sagemaker_session=None):
sagemaker_session = sagemaker_session or Session()
bucket, key_prefix = parse_s3_url(url=desired_s3_uri)
if kms_key is not None:
extra_args = {"SSEKMSKeyId": kms_key}
extra_args = {"SSEKMSKeyId": kms_key, "ServerSideEncryption": "aws:kms"}

else:
extra_args = None

Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/spark/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def run(
self,
submit_app,
inputs=None,
kms_key=None,
outputs=None,
arguments=None,
wait=True,
Expand All @@ -188,6 +189,8 @@ def run(
inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
the processing job. These must be provided as
:class:`~sagemaker.processing.ProcessingInput` objects (default: None).
kms_key (str): The ARN of the KMS key that is used to encrypt the
user code file (default: None).
outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
the processing job. These can be specified as either path strings or
:class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
Expand All @@ -207,6 +210,7 @@ def run(
super().run(
submit_app,
inputs,
kms_key,
outputs,
arguments,
wait,
Expand Down
14 changes: 14 additions & 0 deletions tests/integ/test_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,18 @@ def volume_kms_key(sagemaker_session):
)


@pytest.fixture(scope="module")
def input_kms_key(sagemaker_session):
    """Module-scoped fixture providing the KMS key for processing inputs.

    Expands ``ROLE`` through the session, builds an alias suffixed with the
    session's boto region name, and returns whatever
    ``get_or_create_kms_key`` yields for that role and alias.
    """
    execution_role_arn = sagemaker_session.expand_role(ROLE)
    region = sagemaker_session.boto_session.region_name
    key_alias = "integ-test-processing-input-kms-key-{}".format(region)
    return get_or_create_kms_key(
        sagemaker_session=sagemaker_session,
        role_arn=execution_role_arn,
        alias=key_alias,
    )


@pytest.fixture(scope="module")
def output_kms_key(sagemaker_session):
role_arn = sagemaker_session.expand_role(ROLE)
Expand Down Expand Up @@ -584,6 +596,7 @@ def test_processor_with_custom_bucket(
image_uri,
cpu_instance_type,
output_kms_key,
input_kms_key,
):
script_path = os.path.join(DATA_DIR, "dummy_script.py")

Expand All @@ -609,6 +622,7 @@ def test_processor_with_custom_bucket(
source=script_path, destination="/opt/ml/processing/input/code/", input_name="code"
)
],
kms_key=input_kms_key,
outputs=[
ProcessingOutput(
source="/opt/ml/processing/output/container/path/",
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/sagemaker/spark/test_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def test_configuration_validation(config, expected, sagemaker_session) -> None:
def test_spark_processor_base_run(mock_super_run, spark_processor_base):
spark_processor_base.run(submit_app="app")

mock_super_run.assert_called_with("app", None, None, None, True, True, None, None)
mock_super_run.assert_called_with("app", None, None, None, None, True, True, None, None)


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_upload_with_kms_key(sagemaker_session):
path="/path/to/app.jar",
bucket=BUCKET_NAME,
key_prefix=os.path.join(CURRENT_JOB_NAME, SOURCE_NAME),
extra_args={"SSEKMSKeyId": KMS_KEY},
extra_args={"SSEKMSKeyId": KMS_KEY, "ServerSideEncryption": "aws:kms"},
)


Expand Down