Skip to content

Commit d9697b4

Browse files
committed
Add tags on endpoints
1 parent 555cfe7 commit d9697b4

File tree

5 files changed

+61
-7
lines changed

5 files changed

+61
-7
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ CHANGELOG
88
* enhancement: Let Framework models reuse code uploaded by Framework estimators
99
* enhancement: Unify generation of model uploaded code location
1010
* feature: Change minimum required scipy from 1.0.0 to 0.19.0
11+
* feature: Option to add Tags on SageMaker Endpoints
1112

1213
1.5.0
1314
=====

src/sagemaker/model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def prepare_container_def(self, instance_type):
6666
"""
6767
return sagemaker.container_def(self.image, self.model_data, self.env)
6868

69-
def deploy(self, initial_instance_count, instance_type, endpoint_name=None):
69+
def deploy(self, initial_instance_count, instance_type, endpoint_name=None, tags=None):
7070
"""Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``.
7171
7272
Create a SageMaker ``Model`` and ``EndpointConfig``, and deploy an ``Endpoint`` from this ``Model``.
@@ -82,6 +82,7 @@ def deploy(self, initial_instance_count, instance_type, endpoint_name=None):
8282
``Endpoint`` created from this ``Model``.
8383
endpoint_name (str): The name of the endpoint to create (default: None).
8484
If not specified, a unique endpoint name will be created.
85+
tags (list[dict[str, str]]): A list of key-value pairs for tagging the endpoint (default: None).
8586
8687
Returns:
8788
callable[string, sagemaker.session.Session] or None: Invocation of ``self.predictor_cls`` on
@@ -98,7 +99,7 @@ def deploy(self, initial_instance_count, instance_type, endpoint_name=None):
9899
self.sagemaker_session.create_model(model_name, self.role, container_def)
99100
production_variant = sagemaker.production_variant(model_name, instance_type, initial_instance_count)
100101
self.endpoint_name = endpoint_name or model_name
101-
self.sagemaker_session.endpoint_from_production_variants(self.endpoint_name, [production_variant])
102+
self.sagemaker_session.endpoint_from_production_variants(self.endpoint_name, [production_variant], tags)
102103
if self.predictor_cls:
103104
return self.predictor_cls(self.endpoint_name, self.sagemaker_session)
104105

src/sagemaker/session.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -646,12 +646,13 @@ def endpoint_from_model_data(self, model_s3_location, deployment_image, initial_
646646
self.create_endpoint(endpoint_name=name, config_name=name, wait=wait)
647647
return name
648648

649-
def endpoint_from_production_variants(self, name, production_variants, wait=True):
649+
def endpoint_from_production_variants(self, name, production_variants, tags=None, wait=True):
650650
"""Create an SageMaker ``Endpoint`` from a list of production variants.
651651
652652
Args:
653653
name (str): The name of the ``Endpoint`` to create.
654654
production_variants (list[dict[str, str]]): The list of production variants to deploy.
655+
tags (list[dict[str, str]]): A list of key-value pairs for tagging the endpoint (default: None).
655656
wait (bool): Whether to wait for the endpoint deployment to complete before returning (default: True).
656657
657658
Returns:
@@ -660,8 +661,12 @@ def endpoint_from_production_variants(self, name, production_variants, wait=True
660661

661662
if not _deployment_entity_exists(
662663
lambda: self.sagemaker_client.describe_endpoint_config(EndpointConfigName=name)):
663-
self.sagemaker_client.create_endpoint_config(
664-
EndpointConfigName=name, ProductionVariants=production_variants)
664+
if tags:
665+
self.sagemaker_client.create_endpoint_config(
666+
EndpointConfigName=name, ProductionVariants=production_variants, Tags=tags)
667+
else:
668+
self.sagemaker_client.create_endpoint_config(
669+
EndpointConfigName=name, ProductionVariants=production_variants)
665670
return self.create_endpoint(endpoint_name=name, config_name=name, wait=wait)
666671

667672
def expand_role(self, role):

tests/unit/test_model.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,8 @@ def test_deploy(tfo, time, sagemaker_session):
100100
'ModelName': 'mi-2017-10-10-14-14-15',
101101
'InstanceType': INSTANCE_TYPE,
102102
'InitialInstanceCount': 1,
103-
'VariantName': 'AllTraffic'}])
103+
'VariantName': 'AllTraffic'}],
104+
None)
104105

105106

106107
@patch('tarfile.open')
@@ -114,7 +115,24 @@ def test_deploy_endpoint_name(tfo, time, sagemaker_session):
114115
'ModelName': 'mi-2017-10-10-14-14-15',
115116
'InstanceType': INSTANCE_TYPE,
116117
'InitialInstanceCount': 55,
117-
'VariantName': 'AllTraffic'}])
118+
'VariantName': 'AllTraffic'}],
119+
None)
120+
121+
122+
@patch('tarfile.open')
123+
@patch('time.strftime', return_value=TIMESTAMP)
124+
def test_deploy_tags(tfo, time, sagemaker_session):
125+
model = DummyFrameworkModel(sagemaker_session)
126+
tags = [{'ModelName': 'TestModel'}]
127+
model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1, tags=tags)
128+
sagemaker_session.endpoint_from_production_variants.assert_called_with(
129+
'mi-2017-10-10-14-14-15',
130+
[{'InitialVariantWeight': 1,
131+
'ModelName': 'mi-2017-10-10-14-14-15',
132+
'InstanceType': INSTANCE_TYPE,
133+
'InitialInstanceCount': 1,
134+
'VariantName': 'AllTraffic'}],
135+
tags)
118136

119137

120138
@patch('sagemaker.model.Session')

tests/unit/test_session.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,35 @@ def test_endpoint_from_production_variants(sagemaker_session):
487487
'VariantName': 'AllTraffic'}])
488488

489489

490+
def test_endpoint_from_production_variants_with_tags(sagemaker_session):
491+
ims = sagemaker_session
492+
ims.sagemaker_client.describe_endpoint = Mock(return_value={'EndpointStatus': 'InService'})
493+
pvs = [sagemaker.production_variant('A', 'ml.p2.xlarge'), sagemaker.production_variant('B', 'p299.4096xlarge')]
494+
ex = ClientError({'Error': {'Code': 'ValidationException', 'Message': 'Could not find your thing'}}, 'b')
495+
ims.sagemaker_client.describe_endpoint_config = Mock(side_effect=ex)
496+
tags = [{'ModelName': 'TestModel'}]
497+
sagemaker_session.endpoint_from_production_variants('some-endpoint', pvs, tags)
498+
sagemaker_session.sagemaker_client.create_endpoint.assert_called_with(EndpointConfigName='some-endpoint',
499+
EndpointName='some-endpoint')
500+
sagemaker_session.sagemaker_client.create_endpoint_config.assert_called_with(
501+
EndpointConfigName='some-endpoint',
502+
ProductionVariants=[
503+
{
504+
'InstanceType': 'ml.p2.xlarge',
505+
'ModelName': 'A',
506+
'InitialVariantWeight': 1,
507+
'InitialInstanceCount': 1,
508+
'VariantName': 'AllTraffic'
509+
},
510+
{
511+
'InstanceType': 'p299.4096xlarge',
512+
'ModelName': 'B',
513+
'InitialVariantWeight': 1,
514+
'InitialInstanceCount': 1,
515+
'VariantName': 'AllTraffic'}],
516+
Tags=tags)
517+
518+
490519
def test_wait_for_tuning_job(sagemaker_session):
491520
hyperparameter_tuning_job_desc = {'HyperParameterTuningJobStatus': 'Completed'}
492521
sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(

0 commit comments

Comments
 (0)