Skip to content

Commit 2763f9a

Browse files
Othmane796laurenyu
authored andcommitted
fix: fix propagation of tags to SageMaker endpoint (#741)
1 parent 3e5f1bd commit 2763f9a

File tree

8 files changed

+59
-16
lines changed

8 files changed

+59
-16
lines changed

src/sagemaker/local/entities.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -327,15 +327,17 @@ def describe(self):
327327

328328
class _LocalEndpointConfig(object):
329329

330-
def __init__(self, config_name, production_variants):
330+
def __init__(self, config_name, production_variants, tags=None):
331331
self.name = config_name
332332
self.production_variants = production_variants
333+
self.tags = tags
333334
self.creation_time = datetime.datetime.now()
334335

335336
def describe(self):
336337
response = {
337338
'EndpointConfigName': self.name,
338339
'EndpointConfigArn': _UNUSED_ARN,
340+
'Tags': self.tags,
339341
'CreationTime': self.creation_time,
340342
'ProductionVariants': self.production_variants
341343
}
@@ -348,7 +350,7 @@ class _LocalEndpoint(object):
348350
_IN_SERVICE = 'InService'
349351
_FAILED = 'Failed'
350352

351-
def __init__(self, endpoint_name, endpoint_config_name, local_session=None):
353+
def __init__(self, endpoint_name, endpoint_config_name, tags=None, local_session=None):
352354
# runtime import since there is a cyclic dependency between entities and local_session
353355
from sagemaker.local import LocalSession
354356
self.local_session = local_session or LocalSession()
@@ -357,6 +359,7 @@ def __init__(self, endpoint_name, endpoint_config_name, local_session=None):
357359
self.name = endpoint_name
358360
self.endpoint_config = local_client.describe_endpoint_config(endpoint_config_name)
359361
self.production_variant = self.endpoint_config['ProductionVariants'][0]
362+
self.tags = tags
360363

361364
model_name = self.production_variant['ModelName']
362365
self.primary_container = local_client.describe_model(model_name)['PrimaryContainer']
@@ -392,6 +395,7 @@ def describe(self):
392395
'EndpointConfigName': self.endpoint_config['EndpointConfigName'],
393396
'CreationTime': self.create_time,
394397
'ProductionVariants': self.endpoint_config['ProductionVariants'],
398+
'Tags': self.tags,
395399
'EndpointName': self.name,
396400
'EndpointArn': _UNUSED_ARN,
397401
'EndpointStatus': self.state

src/sagemaker/local/local_session.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,9 @@ def describe_endpoint_config(self, EndpointConfigName):
127127
'Code': 'ValidationException', 'Message': 'Could not find local endpoint config'}}
128128
raise ClientError(error_response, 'describe_endpoint_config')
129129

130-
def create_endpoint_config(self, EndpointConfigName, ProductionVariants):
130+
def create_endpoint_config(self, EndpointConfigName, ProductionVariants, Tags=None):
131131
LocalSagemakerClient._endpoint_configs[EndpointConfigName] = _LocalEndpointConfig(
132-
EndpointConfigName, ProductionVariants)
132+
EndpointConfigName, ProductionVariants, Tags)
133133

134134
def describe_endpoint(self, EndpointName):
135135
if EndpointName not in LocalSagemakerClient._endpoints:
@@ -138,8 +138,8 @@ def describe_endpoint(self, EndpointName):
138138
else:
139139
return LocalSagemakerClient._endpoints[EndpointName].describe()
140140

141-
def create_endpoint(self, EndpointName, EndpointConfigName):
142-
endpoint = _LocalEndpoint(EndpointName, EndpointConfigName, self.sagemaker_session)
141+
def create_endpoint(self, EndpointName, EndpointConfigName, Tags=None):
142+
endpoint = _LocalEndpoint(EndpointName, EndpointConfigName, Tags, self.sagemaker_session)
143143
LocalSagemakerClient._endpoints[EndpointName] = endpoint
144144
endpoint.serve()
145145

src/sagemaker/model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,8 @@ def deploy(self, initial_instance_count, instance_type, accelerator_type=None, e
271271
model_name=self.name,
272272
initial_instance_count=initial_instance_count,
273273
instance_type=instance_type,
274-
accelerator_type=accelerator_type)
274+
accelerator_type=accelerator_type,
275+
tags=tags)
275276
self.sagemaker_session.update_endpoint(self.endpoint_name, endpoint_config_name)
276277
else:
277278
self.sagemaker_session.endpoint_from_production_variants(self.endpoint_name, [production_variant], tags)

src/sagemaker/session.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -749,7 +749,7 @@ def create_endpoint_config(self, name, model_name, initial_instance_count, insta
749749
)
750750
return name
751751

752-
def create_endpoint(self, endpoint_name, config_name, wait=True):
752+
def create_endpoint(self, endpoint_name, config_name, tags=None, wait=True):
753753
"""Create an Amazon SageMaker ``Endpoint`` according to the endpoint configuration specified in the request.
754754
755755
Once the ``Endpoint`` is created, client applications can send requests to obtain inferences.
@@ -764,7 +764,10 @@ def create_endpoint(self, endpoint_name, config_name, wait=True):
764764
str: Name of the Amazon SageMaker ``Endpoint`` created.
765765
"""
766766
LOGGER.info('Creating endpoint with name {}'.format(endpoint_name))
767-
self.sagemaker_client.create_endpoint(EndpointName=endpoint_name, EndpointConfigName=config_name)
767+
768+
tags = tags or []
769+
770+
self.sagemaker_client.create_endpoint(EndpointName=endpoint_name, EndpointConfigName=config_name, Tags=tags)
768771
if wait:
769772
self.wait_for_endpoint(endpoint_name)
770773
return endpoint_name
@@ -1052,7 +1055,7 @@ def endpoint_from_production_variants(self, name, production_variants, tags=None
10521055
config_options['Tags'] = tags
10531056

10541057
self.sagemaker_client.create_endpoint_config(**config_options)
1055-
return self.create_endpoint(endpoint_name=name, config_name=name, wait=wait)
1058+
return self.create_endpoint(endpoint_name=name, config_name=name, tags=tags, wait=wait)
10561059

10571060
def expand_role(self, role):
10581061
"""Expand an IAM role name into an ARN.

tests/integ/test_mxnet_train.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,37 @@ def test_deploy_model(mxnet_training_job, sagemaker_session, mxnet_full_version)
7878
assert 'Could not find model' in str(exception.value)
7979

8080

81+
def test_deploy_model_with_tags(mxnet_training_job, sagemaker_session, mxnet_full_version):
82+
endpoint_name = 'test-mxnet-deploy-model-{}'.format(sagemaker_timestamp())
83+
84+
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
85+
desc = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=mxnet_training_job)
86+
model_data = desc['ModelArtifacts']['S3ModelArtifacts']
87+
script_path = os.path.join(DATA_DIR, 'mxnet_mnist', 'mnist.py')
88+
model = MXNetModel(model_data, 'SageMakerRole', entry_point=script_path,
89+
py_version=PYTHON_VERSION, sagemaker_session=sagemaker_session,
90+
framework_version=mxnet_full_version)
91+
tags = [{'Key': 'TagtestKey', 'Value': 'TagtestValue'}]
92+
model.deploy(1, 'ml.m4.xlarge', endpoint_name=endpoint_name, tags=tags)
93+
94+
returned_model = sagemaker_session.describe_model(EndpointName=model.name)
95+
returned_model_tags = sagemaker_session.list_tags(ResourceArn=returned_model['ModelArn'])['Tags']
96+
97+
endpoint = sagemaker_session.describe_endpoint(EndpointName=endpoint_name)
98+
endpoint_tags = sagemaker_session.list_tags(ResourceArn=endpoint['EndpointArn'])['Tags']
99+
100+
endpoint_config = sagemaker_session.describe_endpoint_config(EndpointConfigName=endpoint['EndpointConfigName'])
101+
endpoint_config_tags = sagemaker_session.list_tags(ResourceArn=endpoint_config['EndpointConfigArn'])['Tags']
102+
103+
production_variants = endpoint_config['ProductionVariants']
104+
105+
assert returned_model_tags == tags
106+
assert endpoint_config_tags == tags
107+
assert endpoint_tags == tags
108+
assert production_variants[0]['InstanceType'] == 'ml.m4.xlarge'
109+
assert production_variants[0]['InitialInstanceCount'] == 1
110+
111+
81112
def test_deploy_model_with_update_endpoint(mxnet_training_job, sagemaker_session, mxnet_full_version):
82113
endpoint_name = 'test-mxnet-deploy-model-{}'.format(sagemaker_timestamp())
83114

tests/unit/test_create_deploy_entities.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def test_create_endpoint_no_wait(sagemaker_session):
9696

9797
assert returned_name == ENDPOINT_NAME
9898
sagemaker_session.sagemaker_client.create_endpoint.assert_called_once_with(
99-
EndpointName=ENDPOINT_NAME, EndpointConfigName=ENDPOINT_CONFIG_NAME)
99+
EndpointName=ENDPOINT_NAME, EndpointConfigName=ENDPOINT_CONFIG_NAME, Tags=[])
100100

101101

102102
def test_create_endpoint_wait(sagemaker_session):
@@ -105,5 +105,5 @@ def test_create_endpoint_wait(sagemaker_session):
105105

106106
assert returned_name == ENDPOINT_NAME
107107
sagemaker_session.sagemaker_client.create_endpoint.assert_called_once_with(
108-
EndpointName=ENDPOINT_NAME, EndpointConfigName=ENDPOINT_CONFIG_NAME)
108+
EndpointName=ENDPOINT_NAME, EndpointConfigName=ENDPOINT_CONFIG_NAME, Tags=[])
109109
sagemaker_session.wait_for_endpoint.assert_called_once_with(ENDPOINT_NAME)

tests/unit/test_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,8 @@ def test_deploy_update_endpoint(sagemaker_session, tmpdir):
245245
model_name=model.name,
246246
initial_instance_count=INSTANCE_COUNT,
247247
instance_type=INSTANCE_TYPE,
248-
accelerator_type=ACCELERATOR_TYPE
248+
accelerator_type=ACCELERATOR_TYPE,
249+
tags=None
249250
)
250251
config_name = sagemaker_session.create_endpoint_config(
251252
name=model.name,

tests/unit/test_session.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -910,7 +910,8 @@ def test_endpoint_from_production_variants(sagemaker_session):
910910
ims.sagemaker_client.describe_endpoint_config = Mock(side_effect=ex)
911911
sagemaker_session.endpoint_from_production_variants('some-endpoint', pvs)
912912
sagemaker_session.sagemaker_client.create_endpoint.assert_called_with(EndpointConfigName='some-endpoint',
913-
EndpointName='some-endpoint')
913+
EndpointName='some-endpoint',
914+
Tags=[])
914915
sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with(
915916
EndpointConfigName='some-endpoint',
916917
ProductionVariants=pvs)
@@ -936,7 +937,8 @@ def test_endpoint_from_production_variants_with_tags(sagemaker_session):
936937
tags = [{'ModelName': 'TestModel'}]
937938
sagemaker_session.endpoint_from_production_variants('some-endpoint', pvs, tags)
938939
sagemaker_session.sagemaker_client.create_endpoint.assert_called_with(EndpointConfigName='some-endpoint',
939-
EndpointName='some-endpoint')
940+
EndpointName='some-endpoint',
941+
Tags=tags)
940942
sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with(
941943
EndpointConfigName='some-endpoint',
942944
ProductionVariants=pvs,
@@ -953,7 +955,8 @@ def test_endpoint_from_production_variants_with_accelerator_type(sagemaker_sessi
953955
tags = [{'ModelName': 'TestModel'}]
954956
sagemaker_session.endpoint_from_production_variants('some-endpoint', pvs, tags)
955957
sagemaker_session.sagemaker_client.create_endpoint.assert_called_with(EndpointConfigName='some-endpoint',
956-
EndpointName='some-endpoint')
958+
EndpointName='some-endpoint',
959+
Tags=tags)
957960
sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with(
958961
EndpointConfigName='some-endpoint',
959962
ProductionVariants=pvs,

0 commit comments

Comments
 (0)