21
21
import boto3
22
22
from sagemaker .tensorflow import TensorFlow
23
23
from six .moves .urllib .parse import urlparse
24
+ from sagemaker .utils import unique_name_from_base
24
25
import tests .integ as integ
25
26
from tests .integ import kms_utils
26
27
import tests .integ .timeout as timeout
31
32
SCRIPT = os .path .join (RESOURCE_PATH , 'mnist.py' )
32
33
PARAMETER_SERVER_DISTRIBUTION = {'parameter_server' : {'enabled' : True }}
33
34
MPI_DISTRIBUTION = {'mpi' : {'enabled' : True }}
35
+ TAGS = [{'Key' : 'some-key' , 'Value' : 'some-value' }]
34
36
35
37
36
38
@pytest .fixture (scope = 'session' , params = ['ml.c5.xlarge' , 'ml.p2.xlarge' ])
@@ -48,7 +50,7 @@ def test_mnist(sagemaker_session, instance_type):
48
50
py_version = 'py3' ,
49
51
framework_version = TensorFlow .LATEST_VERSION ,
50
52
metric_definitions = [{'Name' : 'train:global_steps' , 'Regex' : r'global_step\/sec:\s(.*)' }],
51
- base_job_name = 'test-tf-sm-mnist' )
53
+ base_job_name = unique_name_from_base ( 'test-tf-sm-mnist' ) )
52
54
inputs = estimator .sagemaker_session .upload_data (
53
55
path = os .path .join (RESOURCE_PATH , 'data' ),
54
56
key_prefix = 'scriptmode/mnist' )
@@ -76,7 +78,7 @@ def test_server_side_encryption(sagemaker_session):
76
78
sagemaker_session = sagemaker_session ,
77
79
py_version = 'py3' ,
78
80
framework_version = '1.11' ,
79
- base_job_name = 'test-server-side-encryption' ,
81
+ base_job_name = unique_name_from_base ( 'test-server-side-encryption' ) ,
80
82
code_location = output_path ,
81
83
output_path = output_path ,
82
84
model_dir = '/opt/ml/model' ,
@@ -103,7 +105,7 @@ def test_mnist_distributed(sagemaker_session, instance_type):
103
105
script_mode = True ,
104
106
framework_version = TensorFlow .LATEST_VERSION ,
105
107
distributions = PARAMETER_SERVER_DISTRIBUTION ,
106
- base_job_name = 'test-tf-sm-mnist' )
108
+ base_job_name = unique_name_from_base ( 'test-tf-sm-mnist' ) )
107
109
inputs = estimator .sagemaker_session .upload_data (
108
110
path = os .path .join (RESOURCE_PATH , 'data' ),
109
111
key_prefix = 'scriptmode/distributed_mnist' )
@@ -122,21 +124,25 @@ def test_mnist_async(sagemaker_session):
122
124
sagemaker_session = sagemaker_session ,
123
125
py_version = 'py3' ,
124
126
framework_version = TensorFlow .LATEST_VERSION ,
125
- base_job_name = 'test-tf-sm-mnist' )
127
+ base_job_name = unique_name_from_base ('test-tf-sm-mnist' ),
128
+ tags = TAGS )
126
129
inputs = estimator .sagemaker_session .upload_data (
127
130
path = os .path .join (RESOURCE_PATH , 'data' ),
128
131
key_prefix = 'scriptmode/mnist' )
129
132
estimator .fit (inputs , wait = False )
130
133
training_job_name = estimator .latest_training_job .name
131
134
time .sleep (20 )
132
135
endpoint_name = training_job_name
136
+ _assert_training_job_tags_match (sagemaker_session .sagemaker_client , estimator .latest_training_job .name , TAGS )
133
137
with timeout .timeout_and_delete_endpoint_by_name (endpoint_name , sagemaker_session ):
134
138
estimator = TensorFlow .attach (training_job_name = training_job_name , sagemaker_session = sagemaker_session )
135
139
predictor = estimator .deploy (initial_instance_count = 1 , instance_type = 'ml.c4.xlarge' ,
136
140
endpoint_name = endpoint_name )
137
141
138
142
result = predictor .predict (np .zeros (784 ))
139
143
print ('predict result: {}' .format (result ))
144
+ _assert_endpoint_tags_match (sagemaker_session .sagemaker_client , predictor .endpoint , TAGS )
145
+ _assert_model_tags_match (sagemaker_session .sagemaker_client , estimator .latest_training_job .name , TAGS )
140
146
141
147
142
148
def _assert_s3_files_exist (s3_url , files ):
@@ -147,3 +153,23 @@ def _assert_s3_files_exist(s3_url, files):
147
153
found = [x ['Key' ] for x in contents if x ['Key' ].endswith (f )]
148
154
if not found :
149
155
raise ValueError ('File {} is not found under {}' .format (f , s3_url ))
156
+
157
+
158
+ def _assert_tags_match (sagemaker_client , resource_arn , tags ):
159
+ actual_tags = sagemaker_client .list_tags (ResourceArn = resource_arn )['Tags' ]
160
+ assert actual_tags == tags
161
+
162
+
163
+ def _assert_model_tags_match (sagemaker_client , model_name , tags ):
164
+ model_description = sagemaker_client .describe_model (ModelName = model_name )
165
+ _assert_tags_match (sagemaker_client , model_description ['ModelArn' ], tags )
166
+
167
+
168
+ def _assert_endpoint_tags_match (sagemaker_client , endpoint_name , tags ):
169
+ endpoint_description = sagemaker_client .describe_endpoint (EndpointName = endpoint_name )
170
+ _assert_tags_match (sagemaker_client , endpoint_description ['EndpointArn' ], tags )
171
+
172
+
173
+ def _assert_training_job_tags_match (sagemaker_client , training_job_name , tags ):
174
+ training_job_description = sagemaker_client .describe_training_job (TrainingJobName = training_job_name )
175
+ _assert_tags_match (sagemaker_client , training_job_description ['TrainingJobArn' ], tags )
0 commit comments