12
12
# language governing permissions and limitations under the License.
13
13
from __future__ import absolute_import
14
14
15
+ import contextlib
16
+ import json
17
+
15
18
from botocore import exceptions
16
19
17
- KEY_ALIAS = "SageMakerIntegTestKmsKey"
20
+ PRINCIPAL_TEMPLATE = '["{account_id}", "{role_arn}", ' \
21
+ '"arn:aws:iam::{account_id}:role/{sagemaker_role}"] '
22
+
23
+ KEY_ALIAS = 'SageMakerTestKMSKey'
24
+ KMS_S3_ALIAS = 'SageMakerTestS3KMSKey'
25
+ POLICY_NAME = 'default'
18
26
KEY_POLICY = '''
19
27
{{
20
28
"Version": "2012-10-17",
21
- "Id": "sagemaker-kms-integ-test-policy ",
29
+ "Id": "{id} ",
22
30
"Statement": [
23
31
{{
24
32
"Sid": "Enable IAM User Permissions",
25
33
"Effect": "Allow",
26
34
"Principal": {{
27
- "AWS": "*"
35
+ "AWS": {principal}
28
36
}},
29
37
"Action": "kms:*",
30
38
"Resource": "*"
@@ -42,22 +50,75 @@ def _get_kms_key_arn(kms_client, alias):
42
50
return None
43
51
44
52
45
- def _create_kms_key (kms_client , account_id ):
53
+ def _get_kms_key_id (kms_client , alias ):
54
+ try :
55
+ response = kms_client .describe_key (KeyId = 'alias/' + alias )
56
+ return response ['KeyMetadata' ]['KeyId' ]
57
+ except kms_client .exceptions .NotFoundException :
58
+ return None
59
+
60
+
61
+ def _create_kms_key (kms_client ,
62
+ account_id ,
63
+ role_arn = None ,
64
+ sagemaker_role = 'SageMakerRole' ,
65
+ alias = KEY_ALIAS ):
66
+ if role_arn :
67
+ principal = PRINCIPAL_TEMPLATE .format (account_id = account_id ,
68
+ role_arn = role_arn ,
69
+ sagemaker_role = sagemaker_role )
70
+ else :
71
+ principal = "{account_id}" .format (account_id = account_id )
72
+
46
73
response = kms_client .create_key (
47
- Policy = KEY_POLICY .format (account_id = account_id ),
74
+ Policy = KEY_POLICY .format (id = POLICY_NAME , principal = principal , sagemaker_role = sagemaker_role ),
48
75
Description = 'KMS key for SageMaker Python SDK integ tests' ,
49
76
)
50
77
key_arn = response ['KeyMetadata' ]['Arn' ]
51
- response = kms_client .create_alias (AliasName = 'alias/' + KEY_ALIAS , TargetKeyId = key_arn )
78
+
79
+ if alias :
80
+ kms_client .create_alias (AliasName = 'alias/' + alias , TargetKeyId = key_arn )
52
81
return key_arn
53
82
54
83
55
- def get_or_create_kms_key (kms_client , account_id ):
56
- kms_key_arn = _get_kms_key_arn (kms_client , KEY_ALIAS )
57
- if kms_key_arn is not None :
58
- return kms_key_arn
59
- else :
60
- return _create_kms_key (kms_client , account_id )
84
+ def _add_role_to_policy (kms_client ,
85
+ account_id ,
86
+ role_arn ,
87
+ alias = KEY_ALIAS ,
88
+ sagemaker_role = 'SageMakerRole' ):
89
+ key_id = _get_kms_key_id (kms_client , alias )
90
+ policy = kms_client .get_key_policy (KeyId = key_id , PolicyName = POLICY_NAME )
91
+ policy = json .loads (policy ['Policy' ])
92
+ principal = policy ['Statement' ][0 ]['Principal' ]['AWS' ]
93
+
94
+ if role_arn not in principal or sagemaker_role not in principal :
95
+ principal = PRINCIPAL_TEMPLATE .format (account_id = account_id ,
96
+ role_arn = role_arn ,
97
+ sagemaker_role = sagemaker_role )
98
+
99
+ kms_client .put_key_policy (KeyId = key_id ,
100
+ PolicyName = POLICY_NAME ,
101
+ Policy = KEY_POLICY .format (id = POLICY_NAME , principal = principal ))
102
+
103
+
104
+ def get_or_create_kms_key (kms_client ,
105
+ account_id ,
106
+ role_arn = None ,
107
+ alias = KEY_ALIAS ,
108
+ sagemaker_role = 'SageMakerRole' ):
109
+ kms_key_arn = _get_kms_key_arn (kms_client , alias )
110
+
111
+ if kms_key_arn is None :
112
+ return _create_kms_key (kms_client , account_id , role_arn , sagemaker_role , alias )
113
+
114
+ if role_arn :
115
+ _add_role_to_policy (kms_client ,
116
+ account_id ,
117
+ role_arn ,
118
+ alias ,
119
+ sagemaker_role )
120
+
121
+ return kms_key_arn
61
122
62
123
63
124
KMS_BUCKET_POLICY = """{
@@ -92,9 +153,13 @@ def get_or_create_kms_key(kms_client, account_id):
92
153
}"""
93
154
94
155
95
- def get_or_create_bucket_with_encryption (boto_session ):
156
+ @contextlib .contextmanager
157
+ def bucket_with_encryption (boto_session , sagemaker_role ):
96
158
account = boto_session .client ('sts' ).get_caller_identity ()['Account' ]
97
- kms_key_arn = get_or_create_kms_key (boto_session .client ('kms' ), account )
159
+ role_arn = boto_session .client ('sts' ).get_caller_identity ()['Arn' ]
160
+
161
+ kms_client = boto_session .client ('kms' )
162
+ kms_key_arn = _create_kms_key (kms_client , account , role_arn , sagemaker_role , None )
98
163
99
164
region = boto_session .region_name
100
165
bucket_name = 'sagemaker-{}-{}-with-kms' .format (region , account )
@@ -132,4 +197,6 @@ def get_or_create_bucket_with_encryption(boto_session):
132
197
Policy = KMS_BUCKET_POLICY % (bucket_name , bucket_name )
133
198
)
134
199
135
- return 's3://' + bucket_name , kms_key_arn
200
+ yield 's3://' + bucket_name , kms_key_arn
201
+
202
+ kms_client .schedule_key_deletion (KeyId = kms_key_arn , PendingWindowInDays = 7 )
0 commit comments