Skip to content

Commit 79f860e

Browse files
committed
Add principal to the policy
1 parent 50ef480 commit 79f860e

File tree

2 files changed

+36
-32
lines changed

2 files changed

+36
-32
lines changed

tests/integ/kms_utils.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import contextlib
1516
import json
1617

1718
from botocore import exceptions
@@ -74,7 +75,9 @@ def _create_kms_key(kms_client,
7475
Description='KMS key for SageMaker Python SDK integ tests',
7576
)
7677
key_arn = response['KeyMetadata']['Arn']
77-
response = kms_client.create_alias(AliasName='alias/' + alias, TargetKeyId=key_arn)
78+
79+
if alias:
80+
kms_client.create_alias(AliasName='alias/' + alias, TargetKeyId=key_arn)
7881
return key_arn
7982

8083

@@ -150,14 +153,13 @@ def get_or_create_kms_key(kms_client,
150153
}"""
151154

152155

153-
def get_or_create_bucket_with_encryption(boto_session, sagemaker_role):
156+
@contextlib.contextmanager
157+
def bucket_with_encryption(boto_session, sagemaker_role):
154158
account = boto_session.client('sts').get_caller_identity()['Account']
155159
role_arn = boto_session.client('sts').get_caller_identity()['Arn']
156-
kms_key_arn = get_or_create_kms_key(boto_session.client('kms'),
157-
account,
158-
role_arn,
159-
alias=KMS_S3_ALIAS,
160-
sagemaker_role=sagemaker_role)
160+
161+
kms_client = boto_session.client('kms')
162+
kms_key_arn = _create_kms_key(kms_client, account, role_arn, sagemaker_role, None)
161163

162164
region = boto_session.region_name
163165
bucket_name = 'sagemaker-{}-{}-with-kms'.format(region, account)
@@ -195,4 +197,6 @@ def get_or_create_bucket_with_encryption(boto_session, sagemaker_role):
195197
Policy=KMS_BUCKET_POLICY % (bucket_name, bucket_name)
196198
)
197199

198-
return 's3://' + bucket_name, kms_key_arn
200+
yield 's3://' + bucket_name, kms_key_arn
201+
202+
kms_client.schedule_key_deletion(KeyId=kms_key_arn, PendingWindowInDays=7)

tests/integ/test_tf_script_mode.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -60,30 +60,30 @@ def test_mnist(sagemaker_session, instance_type):
6060

6161
def test_server_side_encryption(sagemaker_session):
6262

63-
bucket_with_kms, kms_key = kms_utils.get_or_create_bucket_with_encryption(sagemaker_session.boto_session,
64-
ROLE)
65-
66-
output_path = os.path.join(bucket_with_kms, 'test-server-side-encryption', time.strftime('%y%m%d-%H%M'))
67-
68-
estimator = TensorFlow(entry_point=SCRIPT,
69-
role=ROLE,
70-
train_instance_count=1,
71-
train_instance_type='ml.c5.xlarge',
72-
sagemaker_session=sagemaker_session,
73-
py_version='py3',
74-
framework_version='1.11',
75-
base_job_name='test-server-side-encryption',
76-
code_location=output_path,
77-
output_path=output_path,
78-
model_dir='/opt/ml/model',
79-
output_kms_key=kms_key)
80-
81-
inputs = estimator.sagemaker_session.upload_data(
82-
path=os.path.join(RESOURCE_PATH, 'data'),
83-
key_prefix='scriptmode/mnist')
84-
85-
with timeout.timeout(minutes=integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
86-
estimator.fit(inputs)
63+
boto_session = sagemaker_session.boto_session
64+
with kms_utils.bucket_with_encryption(boto_session, ROLE) as (bucket_with_kms, kms_key):
65+
66+
output_path = os.path.join(bucket_with_kms, 'test-server-side-encryption', time.strftime('%y%m%d-%H%M'))
67+
68+
estimator = TensorFlow(entry_point=SCRIPT,
69+
role=ROLE,
70+
train_instance_count=1,
71+
train_instance_type='ml.c5.xlarge',
72+
sagemaker_session=sagemaker_session,
73+
py_version='py3',
74+
framework_version='1.11',
75+
base_job_name='test-server-side-encryption',
76+
code_location=output_path,
77+
output_path=output_path,
78+
model_dir='/opt/ml/model',
79+
output_kms_key=kms_key)
80+
81+
inputs = estimator.sagemaker_session.upload_data(
82+
path=os.path.join(RESOURCE_PATH, 'data'),
83+
key_prefix='scriptmode/mnist')
84+
85+
with timeout.timeout(minutes=integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
86+
estimator.fit(inputs)
8787

8888

8989
@pytest.mark.canary_quick

0 commit comments

Comments
 (0)