Skip to content

Commit cd2b23d

Browse files
authored
fix: remove unrestrictive principal * from KMS policy tests. (#712)
* pass kms id as parameter for uploading code with Server side encryption
1 parent c2bac8f commit cd2b23d

File tree

2 files changed

+106
-37
lines changed

2 files changed

+106
-37
lines changed

tests/integ/kms_utils.py

Lines changed: 82 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,27 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import contextlib
16+
import json
17+
1518
from botocore import exceptions
1619

17-
KEY_ALIAS = "SageMakerIntegTestKmsKey"
20+
PRINCIPAL_TEMPLATE = '["{account_id}", "{role_arn}", ' \
21+
'"arn:aws:iam::{account_id}:role/{sagemaker_role}"] '
22+
23+
KEY_ALIAS = 'SageMakerTestKMSKey'
24+
KMS_S3_ALIAS = 'SageMakerTestS3KMSKey'
25+
POLICY_NAME = 'default'
1826
KEY_POLICY = '''
1927
{{
2028
"Version": "2012-10-17",
21-
"Id": "sagemaker-kms-integ-test-policy",
29+
"Id": "{id}",
2230
"Statement": [
2331
{{
2432
"Sid": "Enable IAM User Permissions",
2533
"Effect": "Allow",
2634
"Principal": {{
27-
"AWS": "*"
35+
"AWS": {principal}
2836
}},
2937
"Action": "kms:*",
3038
"Resource": "*"
@@ -42,22 +50,75 @@ def _get_kms_key_arn(kms_client, alias):
4250
return None
4351

4452

45-
def _create_kms_key(kms_client, account_id):
53+
def _get_kms_key_id(kms_client, alias):
54+
try:
55+
response = kms_client.describe_key(KeyId='alias/' + alias)
56+
return response['KeyMetadata']['KeyId']
57+
except kms_client.exceptions.NotFoundException:
58+
return None
59+
60+
61+
def _create_kms_key(kms_client,
62+
account_id,
63+
role_arn=None,
64+
sagemaker_role='SageMakerRole',
65+
alias=KEY_ALIAS):
66+
if role_arn:
67+
principal = PRINCIPAL_TEMPLATE.format(account_id=account_id,
68+
role_arn=role_arn,
69+
sagemaker_role=sagemaker_role)
70+
else:
71+
principal = "{account_id}".format(account_id=account_id)
72+
4673
response = kms_client.create_key(
47-
Policy=KEY_POLICY.format(account_id=account_id),
74+
Policy=KEY_POLICY.format(id=POLICY_NAME, principal=principal, sagemaker_role=sagemaker_role),
4875
Description='KMS key for SageMaker Python SDK integ tests',
4976
)
5077
key_arn = response['KeyMetadata']['Arn']
51-
response = kms_client.create_alias(AliasName='alias/' + KEY_ALIAS, TargetKeyId=key_arn)
78+
79+
if alias:
80+
kms_client.create_alias(AliasName='alias/' + alias, TargetKeyId=key_arn)
5281
return key_arn
5382

5483

55-
def get_or_create_kms_key(kms_client, account_id):
56-
kms_key_arn = _get_kms_key_arn(kms_client, KEY_ALIAS)
57-
if kms_key_arn is not None:
58-
return kms_key_arn
59-
else:
60-
return _create_kms_key(kms_client, account_id)
84+
def _add_role_to_policy(kms_client,
85+
account_id,
86+
role_arn,
87+
alias=KEY_ALIAS,
88+
sagemaker_role='SageMakerRole'):
89+
key_id = _get_kms_key_id(kms_client, alias)
90+
policy = kms_client.get_key_policy(KeyId=key_id, PolicyName=POLICY_NAME)
91+
policy = json.loads(policy['Policy'])
92+
principal = policy['Statement'][0]['Principal']['AWS']
93+
94+
if role_arn not in principal or sagemaker_role not in principal:
95+
principal = PRINCIPAL_TEMPLATE.format(account_id=account_id,
96+
role_arn=role_arn,
97+
sagemaker_role=sagemaker_role)
98+
99+
kms_client.put_key_policy(KeyId=key_id,
100+
PolicyName=POLICY_NAME,
101+
Policy=KEY_POLICY.format(id=POLICY_NAME, principal=principal))
102+
103+
104+
def get_or_create_kms_key(kms_client,
105+
account_id,
106+
role_arn=None,
107+
alias=KEY_ALIAS,
108+
sagemaker_role='SageMakerRole'):
109+
kms_key_arn = _get_kms_key_arn(kms_client, alias)
110+
111+
if kms_key_arn is None:
112+
return _create_kms_key(kms_client, account_id, role_arn, sagemaker_role, alias)
113+
114+
if role_arn:
115+
_add_role_to_policy(kms_client,
116+
account_id,
117+
role_arn,
118+
alias,
119+
sagemaker_role)
120+
121+
return kms_key_arn
61122

62123

63124
KMS_BUCKET_POLICY = """{
@@ -92,9 +153,13 @@ def get_or_create_kms_key(kms_client, account_id):
92153
}"""
93154

94155

95-
def get_or_create_bucket_with_encryption(boto_session):
156+
@contextlib.contextmanager
157+
def bucket_with_encryption(boto_session, sagemaker_role):
96158
account = boto_session.client('sts').get_caller_identity()['Account']
97-
kms_key_arn = get_or_create_kms_key(boto_session.client('kms'), account)
159+
role_arn = boto_session.client('sts').get_caller_identity()['Arn']
160+
161+
kms_client = boto_session.client('kms')
162+
kms_key_arn = _create_kms_key(kms_client, account, role_arn, sagemaker_role, None)
98163

99164
region = boto_session.region_name
100165
bucket_name = 'sagemaker-{}-{}-with-kms'.format(region, account)
@@ -132,4 +197,6 @@ def get_or_create_bucket_with_encryption(boto_session):
132197
Policy=KMS_BUCKET_POLICY % (bucket_name, bucket_name)
133198
)
134199

135-
return 's3://' + bucket_name, kms_key_arn
200+
yield 's3://' + bucket_name, kms_key_arn
201+
202+
kms_client.schedule_key_deletion(KeyId=kms_key_arn, PendingWindowInDays=7)

tests/integ/test_tf_script_mode.py

Lines changed: 24 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
from tests.integ import kms_utils
2626
import tests.integ.timeout as timeout
2727

28+
ROLE = 'SageMakerRole'
29+
2830
RESOURCE_PATH = os.path.join(os.path.dirname(__file__), '..', 'data', 'tensorflow_mnist')
2931
SCRIPT = os.path.join(RESOURCE_PATH, 'mnist.py')
3032
PARAMETER_SERVER_DISTRIBUTION = {'parameter_server': {'enabled': True}}
@@ -56,39 +58,39 @@ def test_mnist(sagemaker_session, instance_type):
5658
['graph.pbtxt', 'model.ckpt-0.index', 'model.ckpt-0.meta'])
5759

5860

59-
@pytest.mark.skip('this test is broken')
6061
def test_server_side_encryption(sagemaker_session):
6162

62-
bucket_with_kms, kms_key = kms_utils.get_or_create_bucket_with_encryption(sagemaker_session.boto_session)
63+
boto_session = sagemaker_session.boto_session
64+
with kms_utils.bucket_with_encryption(boto_session, ROLE) as (bucket_with_kms, kms_key):
6365

64-
output_path = os.path.join(bucket_with_kms, 'test-server-side-encryption', time.strftime('%y%m%d-%H%M'))
66+
output_path = os.path.join(bucket_with_kms, 'test-server-side-encryption', time.strftime('%y%m%d-%H%M'))
6567

66-
estimator = TensorFlow(entry_point=SCRIPT,
67-
role='SageMakerRole',
68-
train_instance_count=1,
69-
train_instance_type='ml.c5.xlarge',
70-
sagemaker_session=sagemaker_session,
71-
py_version='py3',
72-
framework_version='1.11',
73-
base_job_name='test-server-side-encryption',
74-
code_location=output_path,
75-
output_path=output_path,
76-
model_dir='/opt/ml/model',
77-
output_kms_key=kms_key)
68+
estimator = TensorFlow(entry_point=SCRIPT,
69+
role=ROLE,
70+
train_instance_count=1,
71+
train_instance_type='ml.c5.xlarge',
72+
sagemaker_session=sagemaker_session,
73+
py_version='py3',
74+
framework_version='1.11',
75+
base_job_name='test-server-side-encryption',
76+
code_location=output_path,
77+
output_path=output_path,
78+
model_dir='/opt/ml/model',
79+
output_kms_key=kms_key)
7880

79-
inputs = estimator.sagemaker_session.upload_data(
80-
path=os.path.join(RESOURCE_PATH, 'data'),
81-
key_prefix='scriptmode/mnist')
81+
inputs = estimator.sagemaker_session.upload_data(
82+
path=os.path.join(RESOURCE_PATH, 'data'),
83+
key_prefix='scriptmode/mnist')
8284

83-
with timeout.timeout(minutes=integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
84-
estimator.fit(inputs)
85+
with timeout.timeout(minutes=integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
86+
estimator.fit(inputs)
8587

8688

8789
@pytest.mark.canary_quick
8890
@pytest.mark.skipif(integ.PYTHON_VERSION != 'py3', reason="Script Mode tests are only configured to run with Python 3")
8991
def test_mnist_distributed(sagemaker_session, instance_type):
9092
estimator = TensorFlow(entry_point=SCRIPT,
91-
role='SageMakerRole',
93+
role=ROLE,
9294
train_instance_count=2,
9395
# TODO: change train_instance_type to instance_type once the test is passing consistently
9496
train_instance_type='ml.c5.xlarge',
@@ -110,7 +112,7 @@ def test_mnist_distributed(sagemaker_session, instance_type):
110112

111113
def test_mnist_async(sagemaker_session):
112114
estimator = TensorFlow(entry_point=SCRIPT,
113-
role='SageMakerRole',
115+
role=ROLE,
114116
train_instance_count=1,
115117
train_instance_type='ml.c5.4xlarge',
116118
sagemaker_session=sagemaker_session,

0 commit comments

Comments
 (0)