Skip to content

fix: allow kms encryption upload for processing #1897

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Sep 21, 2020
27 changes: 24 additions & 3 deletions src/sagemaker/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def run(
logs=True,
job_name=None,
experiment_config=None,
kms_key=None,
):
"""Runs a processing job.

Expand All @@ -139,6 +140,8 @@ def run(
experiment_config (dict[str, str]): Experiment management configuration.
Dictionary contains three optional keys:
'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
kms_key (str): The ARN of the KMS key that is used to encrypt the
user code file (default: None).

Raises:
ValueError: if ``logs`` is True but ``wait`` is False.
Expand All @@ -153,6 +156,7 @@ def run(
job_name=job_name,
arguments=arguments,
inputs=inputs,
kms_key=kms_key,
outputs=outputs,
)

Expand All @@ -170,7 +174,15 @@ def _extend_processing_args(self, inputs, outputs, **kwargs): # pylint: disable
"""Extend inputs and outputs based on extra parameters"""
return inputs, outputs

def _normalize_args(self, job_name=None, arguments=None, inputs=None, outputs=None, code=None):
def _normalize_args(
self,
job_name=None,
arguments=None,
inputs=None,
outputs=None,
code=None,
kms_key=None,
):
"""Normalizes the arguments so that they can be passed to the job run

Args:
Expand All @@ -182,6 +194,8 @@ def _normalize_args(self, job_name=None, arguments=None, inputs=None, outputs=No
inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): Input files for
the processing job. These must be provided as
:class:`~sagemaker.processing.ProcessingInput` objects (default: None).
kms_key (str): The ARN of the KMS key that is used to encrypt the
user code file (default: None).
outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): Outputs for
the processing job. These can be specified as either path strings or
:class:`~sagemaker.processing.ProcessingOutput` objects (default: None).
Expand All @@ -191,7 +205,7 @@ def _normalize_args(self, job_name=None, arguments=None, inputs=None, outputs=No
self._current_job_name = self._generate_current_job_name(job_name=job_name)

inputs_with_code = self._include_code_in_inputs(inputs, code)
normalized_inputs = self._normalize_inputs(inputs_with_code)
normalized_inputs = self._normalize_inputs(inputs_with_code, kms_key)
normalized_outputs = self._normalize_outputs(outputs)
self.arguments = arguments

Expand Down Expand Up @@ -233,13 +247,15 @@ def _generate_current_job_name(self, job_name=None):

return name_from_base(base_name)

def _normalize_inputs(self, inputs=None):
def _normalize_inputs(self, inputs=None, kms_key=None):
"""Ensures that all the ``ProcessingInput`` objects have names and S3 URIs.

Args:
inputs (list[sagemaker.processing.ProcessingInput]): A list of ``ProcessingInput``
objects to be normalized (default: None). If not specified,
an empty list is returned.
kms_key (str): The ARN of the KMS key that is used to encrypt the
user code file (default: None).

Returns:
list[sagemaker.processing.ProcessingInput]: The list of normalized
Expand Down Expand Up @@ -273,6 +289,7 @@ def _normalize_inputs(self, inputs=None):
local_path=file_input.source,
desired_s3_uri=desired_s3_uri,
sagemaker_session=self.sagemaker_session,
kms_key=kms_key,
)
file_input.source = s3_uri
normalized_inputs.append(file_input)
Expand Down Expand Up @@ -412,6 +429,7 @@ def run(
logs=True,
job_name=None,
experiment_config=None,
kms_key=None,
):
"""Runs a processing job.

Expand All @@ -434,13 +452,16 @@ def run(
experiment_config (dict[str, str]): Experiment management configuration.
Dictionary contains three optional keys:
'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
kms_key (str): The ARN of the KMS key that is used to encrypt the
user code file (default: None).
"""
normalized_inputs, normalized_outputs = self._normalize_args(
job_name=job_name,
arguments=arguments,
inputs=inputs,
outputs=outputs,
code=code,
kms_key=kms_key,
)

self.latest_job = ProcessingJob.start_new(
Expand Down
3 changes: 2 additions & 1 deletion src/sagemaker/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ def upload(local_path, desired_s3_uri, kms_key=None, sagemaker_session=None):
sagemaker_session = sagemaker_session or Session()
bucket, key_prefix = parse_s3_url(url=desired_s3_uri)
if kms_key is not None:
extra_args = {"SSEKMSKeyId": kms_key}
extra_args = {"SSEKMSKeyId": kms_key, "ServerSideEncryption": "aws:kms"}

else:
extra_args = None

Expand Down
11 changes: 11 additions & 0 deletions src/sagemaker/spark/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def run(
logs=True,
job_name=None,
experiment_config=None,
kms_key=None,
):
"""Runs a processing job.

Expand All @@ -201,6 +202,8 @@ def run(
experiment_config (dict[str, str]): Experiment management configuration.
Dictionary contains three optional keys:
'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
kms_key (str): The ARN of the KMS key that is used to encrypt the
user code file (default: None).
"""
self._current_job_name = self._generate_current_job_name(job_name=job_name)

Expand All @@ -213,6 +216,7 @@ def run(
logs,
job_name,
experiment_config,
kms_key,
)

def _extend_processing_args(self, inputs, outputs, **kwargs):
Expand Down Expand Up @@ -695,6 +699,7 @@ def run(
experiment_config=None,
configuration=None,
spark_event_logs_s3_uri=None,
kms_key=None,
):
"""Runs a processing job.

Expand Down Expand Up @@ -728,6 +733,8 @@ def run(
https://docs.aws.amazon.com/emr/latest/ReleaseGuide/emr-configure-apps.html
spark_event_logs_s3_uri (str): S3 path where spark application events will
be published to.
kms_key (str): The ARN of the KMS key that is used to encrypt the
user code file (default: None).
"""
self._current_job_name = self._generate_current_job_name(job_name=job_name)

Expand Down Expand Up @@ -872,6 +879,7 @@ def run(
experiment_config=None,
configuration=None,
spark_event_logs_s3_uri=None,
kms_key=None,
):
"""Runs a processing job.

Expand Down Expand Up @@ -905,6 +913,8 @@ def run(
https://docs.aws.amazon.com/emr/latest/ReleaseGuide/emr-configure-apps.html
spark_event_logs_s3_uri (str): S3 path where spark application events will
be published to.
kms_key (str): The ARN of the KMS key that is used to encrypt the
user code file (default: None).
"""
self._current_job_name = self._generate_current_job_name(job_name=job_name)

Expand All @@ -930,6 +940,7 @@ def run(
logs=logs,
job_name=self._current_job_name,
experiment_config=experiment_config,
kms_key=kms_key,
)

def _extend_processing_args(self, inputs, outputs, **kwargs):
Expand Down
14 changes: 14 additions & 0 deletions tests/integ/test_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,18 @@ def volume_kms_key(sagemaker_session):
)


@pytest.fixture(scope="module")
def input_kms_key(sagemaker_session):
role_arn = sagemaker_session.expand_role(ROLE)
return get_or_create_kms_key(
sagemaker_session=sagemaker_session,
role_arn=role_arn,
alias="integ-test-processing-input-kms-key-{}".format(
sagemaker_session.boto_session.region_name
),
)


@pytest.fixture(scope="module")
def output_kms_key(sagemaker_session):
role_arn = sagemaker_session.expand_role(ROLE)
Expand Down Expand Up @@ -584,6 +596,7 @@ def test_processor_with_custom_bucket(
image_uri,
cpu_instance_type,
output_kms_key,
input_kms_key,
):
script_path = os.path.join(DATA_DIR, "dummy_script.py")

Expand All @@ -609,6 +622,7 @@ def test_processor_with_custom_bucket(
source=script_path, destination="/opt/ml/processing/input/code/", input_name="code"
)
],
kms_key=input_kms_key,
outputs=[
ProcessingOutput(
source="/opt/ml/processing/output/container/path/",
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/sagemaker/spark/test_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def test_configuration_validation(config, expected, sagemaker_session) -> None:
def test_spark_processor_base_run(mock_super_run, spark_processor_base):
spark_processor_base.run(submit_app="app")

mock_super_run.assert_called_with("app", None, None, None, True, True, None, None)
mock_super_run.assert_called_with("app", None, None, None, True, True, None, None, None)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -854,6 +854,7 @@ def test_spark_jar_processor_run(
logs=True,
job_name="jobName",
experiment_config=None,
kms_key=None,
)


Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_upload_with_kms_key(sagemaker_session):
path="/path/to/app.jar",
bucket=BUCKET_NAME,
key_prefix=os.path.join(CURRENT_JOB_NAME, SOURCE_NAME),
extra_args={"SSEKMSKeyId": KMS_KEY},
extra_args={"SSEKMSKeyId": KMS_KEY, "ServerSideEncryption": "aws:kms"},
)


Expand Down