Skip to content

Commit 7fa9f32

Browse files
authored
change: add integ test for tagging (#735)
* change: add integ test for tagging * Minor changes based on pr comments * use unique_name_from_base in test_server_side_encryption
1 parent 2091718 commit 7fa9f32

File tree

1 file changed

+30
-4
lines changed

1 file changed

+30
-4
lines changed

tests/integ/test_tf_script_mode.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import boto3
2222
from sagemaker.tensorflow import TensorFlow
2323
from six.moves.urllib.parse import urlparse
24+
from sagemaker.utils import unique_name_from_base
2425
import tests.integ as integ
2526
from tests.integ import kms_utils
2627
import tests.integ.timeout as timeout
@@ -31,6 +32,7 @@
3132
SCRIPT = os.path.join(RESOURCE_PATH, 'mnist.py')
3233
PARAMETER_SERVER_DISTRIBUTION = {'parameter_server': {'enabled': True}}
3334
MPI_DISTRIBUTION = {'mpi': {'enabled': True}}
35+
TAGS = [{'Key': 'some-key', 'Value': 'some-value'}]
3436

3537

3638
@pytest.fixture(scope='session', params=['ml.c5.xlarge', 'ml.p2.xlarge'])
@@ -48,7 +50,7 @@ def test_mnist(sagemaker_session, instance_type):
4850
py_version='py3',
4951
framework_version=TensorFlow.LATEST_VERSION,
5052
metric_definitions=[{'Name': 'train:global_steps', 'Regex': r'global_step\/sec:\s(.*)'}],
51-
base_job_name='test-tf-sm-mnist')
53+
base_job_name=unique_name_from_base('test-tf-sm-mnist'))
5254
inputs = estimator.sagemaker_session.upload_data(
5355
path=os.path.join(RESOURCE_PATH, 'data'),
5456
key_prefix='scriptmode/mnist')
@@ -76,7 +78,7 @@ def test_server_side_encryption(sagemaker_session):
7678
sagemaker_session=sagemaker_session,
7779
py_version='py3',
7880
framework_version='1.11',
79-
base_job_name='test-server-side-encryption',
81+
base_job_name=unique_name_from_base('test-server-side-encryption'),
8082
code_location=output_path,
8183
output_path=output_path,
8284
model_dir='/opt/ml/model',
@@ -103,7 +105,7 @@ def test_mnist_distributed(sagemaker_session, instance_type):
103105
script_mode=True,
104106
framework_version=TensorFlow.LATEST_VERSION,
105107
distributions=PARAMETER_SERVER_DISTRIBUTION,
106-
base_job_name='test-tf-sm-mnist')
108+
base_job_name=unique_name_from_base('test-tf-sm-mnist'))
107109
inputs = estimator.sagemaker_session.upload_data(
108110
path=os.path.join(RESOURCE_PATH, 'data'),
109111
key_prefix='scriptmode/distributed_mnist')
@@ -122,21 +124,25 @@ def test_mnist_async(sagemaker_session):
122124
sagemaker_session=sagemaker_session,
123125
py_version='py3',
124126
framework_version=TensorFlow.LATEST_VERSION,
125-
base_job_name='test-tf-sm-mnist')
127+
base_job_name=unique_name_from_base('test-tf-sm-mnist'),
128+
tags=TAGS)
126129
inputs = estimator.sagemaker_session.upload_data(
127130
path=os.path.join(RESOURCE_PATH, 'data'),
128131
key_prefix='scriptmode/mnist')
129132
estimator.fit(inputs, wait=False)
130133
training_job_name = estimator.latest_training_job.name
131134
time.sleep(20)
132135
endpoint_name = training_job_name
136+
_assert_training_job_tags_match(sagemaker_session.sagemaker_client, estimator.latest_training_job.name, TAGS)
133137
with timeout.timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
134138
estimator = TensorFlow.attach(training_job_name=training_job_name, sagemaker_session=sagemaker_session)
135139
predictor = estimator.deploy(initial_instance_count=1, instance_type='ml.c4.xlarge',
136140
endpoint_name=endpoint_name)
137141

138142
result = predictor.predict(np.zeros(784))
139143
print('predict result: {}'.format(result))
144+
_assert_endpoint_tags_match(sagemaker_session.sagemaker_client, predictor.endpoint, TAGS)
145+
_assert_model_tags_match(sagemaker_session.sagemaker_client, estimator.latest_training_job.name, TAGS)
140146

141147

142148
def _assert_s3_files_exist(s3_url, files):
@@ -147,3 +153,23 @@ def _assert_s3_files_exist(s3_url, files):
147153
found = [x['Key'] for x in contents if x['Key'].endswith(f)]
148154
if not found:
149155
raise ValueError('File {} is not found under {}'.format(f, s3_url))
156+
157+
158+
def _assert_tags_match(sagemaker_client, resource_arn, tags):
159+
actual_tags = sagemaker_client.list_tags(ResourceArn=resource_arn)['Tags']
160+
assert actual_tags == tags
161+
162+
163+
def _assert_model_tags_match(sagemaker_client, model_name, tags):
164+
model_description = sagemaker_client.describe_model(ModelName=model_name)
165+
_assert_tags_match(sagemaker_client, model_description['ModelArn'], tags)
166+
167+
168+
def _assert_endpoint_tags_match(sagemaker_client, endpoint_name, tags):
169+
endpoint_description = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
170+
_assert_tags_match(sagemaker_client, endpoint_description['EndpointArn'], tags)
171+
172+
173+
def _assert_training_job_tags_match(sagemaker_client, training_job_name, tags):
174+
training_job_description = sagemaker_client.describe_training_job(TrainingJobName=training_job_name)
175+
_assert_tags_match(sagemaker_client, training_job_description['TrainingJobArn'], tags)

0 commit comments

Comments
 (0)