Skip to content

Commit ee76822

Browse files
authored
Pass kms id as parameter for uploading code with Server side encryption (#693)
* pass kms id as parameter for uploading code with Server side encryption
1 parent a504db4 commit ee76822

File tree

7 files changed

+209
-8
lines changed

7 files changed

+209
-8
lines changed

CHANGELOG.rst

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22
CHANGELOG
33
=========
44

5-
1.18.5dev
6-
======
5+
1.18.5.dev
6+
==========
77

8+
* bug-fix: pass kms id as parameter for uploading code with Server side encryption
89
* feature: ``PipelineModel``: Create a Transformer from a PipelineModel
910
* bug-fix: ``AlgorithmEstimator``: Make SupportedHyperParameters optional
1011

src/sagemaker/estimator.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -862,19 +862,31 @@ def _stage_user_code_in_s3(self):
862862
Returns: s3 uri
863863
864864
"""
865-
if self.code_location is None:
865+
local_mode = self.output_path.startswith('file://')
866+
867+
if self.code_location is None and local_mode:
866868
code_bucket = self.sagemaker_session.default_bucket()
867869
code_s3_prefix = '{}/source'.format(self._current_job_name)
870+
kms_key = None
871+
872+
elif self.code_location is None:
873+
code_bucket, _ = parse_s3_url(self.output_path)
874+
code_s3_prefix = '{}/source'.format(self._current_job_name)
875+
kms_key = self.output_kms_key
868876
else:
869877
code_bucket, key_prefix = parse_s3_url(self.code_location)
870878
code_s3_prefix = '/'.join(filter(None, [key_prefix, self._current_job_name, 'source']))
871879

880+
output_bucket, _ = parse_s3_url(self.output_path)
881+
kms_key = self.output_kms_key if code_bucket == output_bucket else None
882+
872883
return tar_and_upload_dir(session=self.sagemaker_session.boto_session,
873884
bucket=code_bucket,
874885
s3_key_prefix=code_s3_prefix,
875886
script=self.entry_point,
876887
directory=self.source_dir,
877-
dependencies=self.dependencies)
888+
dependencies=self.dependencies,
889+
kms_key=kms_key)
878890

879891
def _model_source_dir(self):
880892
"""Get the appropriate value to pass as source_dir to model constructor on deploying

src/sagemaker/fw_utils.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,8 @@ def validate_source_dir(script, directory):
136136
return True
137137

138138

139-
def tar_and_upload_dir(session, bucket, s3_key_prefix, script, directory, dependencies=None):
139+
def tar_and_upload_dir(session, bucket, s3_key_prefix, script,
140+
directory=None, dependencies=None, kms_key=None):
140141
"""Package source files and upload a compress tar file to S3. The S3 location will be
141142
``s3://<bucket>/s3_key_prefix/sourcedir.tar.gz``.
142143
@@ -159,6 +160,7 @@ def tar_and_upload_dir(session, bucket, s3_key_prefix, script, directory, depend
159160
dependencies (List[str]): Optional. A list of paths to directories (absolute or relative)
160161
containing additional libraries that will be copied into
161162
/opt/ml/lib
163+
kms_key (str): Optional. KMS key ID used to upload objects to the bucket (default: None).
162164
163165
Returns:
164166
sagemaker.fw_utils.UserCode: An object with the S3 bucket and key (S3 prefix) and
@@ -177,7 +179,12 @@ def tar_and_upload_dir(session, bucket, s3_key_prefix, script, directory, depend
177179
tar_file = sagemaker.utils.create_tar_file(source_files,
178180
os.path.join(tmp, _TAR_SOURCE_FILENAME))
179181

180-
session.resource('s3').Object(bucket, key).upload_file(tar_file)
182+
if kms_key:
183+
extra_args = {'ServerSideEncryption': 'aws:kms', 'SSEKMSKeyId': kms_key}
184+
else:
185+
extra_args = None
186+
187+
session.resource('s3').Object(bucket, key).upload_file(tar_file, ExtraArgs=extra_args)
181188
finally:
182189
shutil.rmtree(tmp)
183190

tests/integ/kms_utils.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
from botocore import exceptions
16+
1517
KEY_ALIAS = "SageMakerIntegTestKmsKey"
1618
KEY_POLICY = '''
1719
{{
@@ -22,7 +24,7 @@
2224
"Sid": "Enable IAM User Permissions",
2325
"Effect": "Allow",
2426
"Principal": {{
25-
"AWS": "{account_id}"
27+
"AWS": "*"
2628
}},
2729
"Action": "kms:*",
2830
"Resource": "*"
@@ -56,3 +58,78 @@ def get_or_create_kms_key(kms_client, account_id):
5658
return kms_key_arn
5759
else:
5860
return _create_kms_key(kms_client, account_id)
61+
62+
63+
KMS_BUCKET_POLICY = """{
64+
"Version": "2012-10-17",
65+
"Id": "PutObjPolicy",
66+
"Statement": [
67+
{
68+
"Sid": "DenyIncorrectEncryptionHeader",
69+
"Effect": "Deny",
70+
"Principal": "*",
71+
"Action": "s3:PutObject",
72+
"Resource": "arn:aws:s3:::%s/*",
73+
"Condition": {
74+
"StringNotEquals": {
75+
"s3:x-amz-server-side-encryption": "aws:kms"
76+
}
77+
}
78+
},
79+
{
80+
"Sid": "DenyUnEncryptedObjectUploads",
81+
"Effect": "Deny",
82+
"Principal": "*",
83+
"Action": "s3:PutObject",
84+
"Resource": "arn:aws:s3:::%s/*",
85+
"Condition": {
86+
"Null": {
87+
"s3:x-amz-server-side-encryption": "true"
88+
}
89+
}
90+
}
91+
]
92+
}"""
93+
94+
95+
def get_or_create_bucket_with_encryption(boto_session):
96+
account = boto_session.client('sts').get_caller_identity()['Account']
97+
kms_key_arn = get_or_create_kms_key(boto_session.client('kms'), account)
98+
99+
region = boto_session.region_name
100+
bucket_name = 'sagemaker-{}-{}-with-kms'.format(region, account)
101+
102+
s3 = boto_session.client('s3')
103+
try:
104+
# 'us-east-1' cannot be specified because it is the default region:
105+
# https://github.com/boto/boto3/issues/125
106+
if region == 'us-east-1':
107+
s3.create_bucket(Bucket=bucket_name)
108+
else:
109+
s3.create_bucket(Bucket=bucket_name,
110+
CreateBucketConfiguration={'LocationConstraint': region})
111+
112+
except exceptions.ClientError as e:
113+
if e.response['Error']['Code'] != 'BucketAlreadyOwnedByYou':
114+
raise
115+
116+
s3.put_bucket_encryption(
117+
Bucket=bucket_name,
118+
ServerSideEncryptionConfiguration={
119+
'Rules': [
120+
{
121+
'ApplyServerSideEncryptionByDefault': {
122+
'SSEAlgorithm': 'aws:kms',
123+
'KMSMasterKeyID': kms_key_arn
124+
}
125+
},
126+
]
127+
}
128+
)
129+
130+
s3.put_bucket_policy(
131+
Bucket=bucket_name,
132+
Policy=KMS_BUCKET_POLICY % (bucket_name, bucket_name)
133+
)
134+
135+
return 's3://' + bucket_name, kms_key_arn

tests/integ/test_tf_script_mode.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@
1414

1515
import numpy as np
1616
import os
17-
import pytest
1817
import time
1918

19+
import pytest
20+
2021
import boto3
2122
from sagemaker.tensorflow import TensorFlow
2223
from six.moves.urllib.parse import urlparse
2324
import tests.integ as integ
25+
from tests.integ import kms_utils
2426
import tests.integ.timeout as timeout
2527

2628
RESOURCE_PATH = os.path.join(os.path.dirname(__file__), '..', 'data', 'tensorflow_mnist')
@@ -54,6 +56,33 @@ def test_mnist(sagemaker_session, instance_type):
5456
['graph.pbtxt', 'model.ckpt-0.index', 'model.ckpt-0.meta'])
5557

5658

59+
def test_server_side_encryption(sagemaker_session):
60+
61+
bucket_with_kms, kms_key = kms_utils.get_or_create_bucket_with_encryption(sagemaker_session.boto_session)
62+
63+
output_path = os.path.join(bucket_with_kms, 'test-server-side-encryption', time.strftime('%y%m%d-%H%M'))
64+
65+
estimator = TensorFlow(entry_point=SCRIPT,
66+
role='SageMakerRole',
67+
train_instance_count=1,
68+
train_instance_type='ml.c5.xlarge',
69+
sagemaker_session=sagemaker_session,
70+
py_version='py3',
71+
framework_version='1.11',
72+
base_job_name='test-server-side-encryption',
73+
code_location=output_path,
74+
output_path=output_path,
75+
model_dir='/opt/ml/model',
76+
output_kms_key=kms_key)
77+
78+
inputs = estimator.sagemaker_session.upload_data(
79+
path=os.path.join(RESOURCE_PATH, 'data'),
80+
key_prefix='scriptmode/mnist')
81+
82+
with timeout.timeout(minutes=integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
83+
estimator.fit(inputs)
84+
85+
5786
@pytest.mark.canary_quick
5887
@pytest.mark.skipif(integ.PYTHON_VERSION != 'py3', reason="Script Mode tests are only configured to run with Python 3")
5988
def test_mnist_distributed(sagemaker_session, instance_type):

tests/unit/test_estimator.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -761,6 +761,64 @@ def test_container_log_level(sagemaker_session):
761761
assert train_kwargs['hyperparameters']['sagemaker_container_log_level'] == '10'
762762

763763

764+
@patch('sagemaker.utils')
765+
def test_same_code_location_keeps_kms_key(utils, sagemaker_session):
766+
fw = DummyFramework(entry_point=SCRIPT_PATH,
767+
role='DummyRole',
768+
sagemaker_session=sagemaker_session,
769+
train_instance_count=INSTANCE_COUNT,
770+
train_instance_type=INSTANCE_TYPE,
771+
output_kms_key='kms-key')
772+
773+
fw.fit(wait=False)
774+
775+
extra_args = {'ServerSideEncryption': 'aws:kms', 'SSEKMSKeyId': 'kms-key'}
776+
obj = sagemaker_session.boto_session.resource('s3').Object
777+
778+
obj.assert_called_with('mybucket', '%s/source/sourcedir.tar.gz' % fw._current_job_name)
779+
780+
obj().upload_file.assert_called_with(utils.create_tar_file(), ExtraArgs=extra_args)
781+
782+
783+
@patch('sagemaker.utils')
784+
def test_different_code_location_kms_key(utils, sagemaker_session):
785+
fw = DummyFramework(entry_point=SCRIPT_PATH,
786+
role='DummyRole',
787+
sagemaker_session=sagemaker_session,
788+
code_location='s3://another-location',
789+
train_instance_count=INSTANCE_COUNT,
790+
train_instance_type=INSTANCE_TYPE,
791+
output_kms_key='kms-key')
792+
793+
fw.fit(wait=False)
794+
795+
obj = sagemaker_session.boto_session.resource('s3').Object
796+
797+
obj.assert_called_with('another-location', '%s/source/sourcedir.tar.gz' % fw._current_job_name)
798+
799+
obj().upload_file.assert_called_with(utils.create_tar_file(), ExtraArgs=None)
800+
801+
802+
@patch('sagemaker.utils')
803+
def test_default_code_location_uses_output_path(utils, sagemaker_session):
804+
fw = DummyFramework(entry_point=SCRIPT_PATH,
805+
role='DummyRole',
806+
sagemaker_session=sagemaker_session,
807+
output_path='s3://output_path',
808+
train_instance_count=INSTANCE_COUNT,
809+
train_instance_type=INSTANCE_TYPE,
810+
output_kms_key='kms-key')
811+
812+
fw.fit(wait=False)
813+
814+
obj = sagemaker_session.boto_session.resource('s3').Object
815+
816+
obj.assert_called_with('output_path', '%s/source/sourcedir.tar.gz' % fw._current_job_name)
817+
818+
extra_args = {'ServerSideEncryption': 'aws:kms', 'SSEKMSKeyId': 'kms-key'}
819+
obj().upload_file.assert_called_with(utils.create_tar_file(), ExtraArgs=extra_args)
820+
821+
764822
def test_wait_without_logs(sagemaker_session):
765823
training_job = _TrainingJob(sagemaker_session, JOB_NAME)
766824

tests/unit/test_fw_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,23 @@ def test_tar_and_upload_dir_s3(sagemaker_session):
165165
assert result == fw_utils.UploadedCode('s3://m', 'mnist.py')
166166

167167

168+
@patch('sagemaker.utils')
169+
def test_tar_and_upload_dir_s3_with_kms(utils, sagemaker_session):
170+
171+
result = fw_utils.tar_and_upload_dir(sagemaker_session,
172+
'mybucker',
173+
'something/source',
174+
'mnist.py',
175+
kms_key='kms-key')
176+
177+
assert result == fw_utils.UploadedCode('s3://mybucker/something/source/sourcedir.tar.gz',
178+
'mnist.py')
179+
180+
extra_args = {'ServerSideEncryption': 'aws:kms', 'SSEKMSKeyId': 'kms-key'}
181+
obj = sagemaker_session.resource('s3').Object('', '')
182+
obj.upload_file.assert_called_with(utils.create_tar_file(), ExtraArgs=extra_args)
183+
184+
168185
def test_validate_source_dir_does_not_exits(sagemaker_session):
169186
script = 'mnist.py'
170187
directory = ' !@#$%^&*()path probably in not there.!@#$%^&*()'

0 commit comments

Comments
 (0)