Skip to content

Commit ce4ac2d

Browse files
martinRenouakrishna1995
authored andcommitted
More tags formatting and add a test
1 parent 6276eff commit ce4ac2d

31 files changed

+167
-119
lines changed

src/sagemaker/algorithm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,7 @@ def transformer(
392392
if self._is_marketplace():
393393
transform_env = None
394394

395-
tags = tags or self.tags
395+
tags = format_tags(tags) or self.tags
396396
else:
397397
raise RuntimeError("No finished training job found associated with this estimator")
398398

src/sagemaker/apiutils/_base_types.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.apiutils import _boto_functions, _utils
17+
from sagemaker.utils import format_tags
1718

1819

1920
class ApiObject(object):
@@ -194,13 +195,13 @@ def _set_tags(self, resource_arn=None, tags=None):
194195
195196
Args:
196197
resource_arn (str): The arn of the Record
197-
tags (dict): An array of Tag objects that set to Record
198+
tags (Optional[Tags]): An array of Tag objects that set to Record
198199
199200
Returns:
200201
A list of key, value pair objects. i.e. [{"key":"value"}]
201202
"""
202203
tag_list = self.sagemaker_session.sagemaker_client.add_tags(
203-
ResourceArn=resource_arn, Tags=tags
204+
ResourceArn=resource_arn, Tags=format_tags(tags)
204205
)["Tags"]
205206
return tag_list
206207

src/sagemaker/automl/automl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -580,7 +580,7 @@ def deploy(
580580
be selected on each ``deploy``.
581581
endpoint_name (str): The name of the endpoint to create (default:
582582
None). If not specified, a unique endpoint name will be created.
583-
tags (List[dict[str, str]]): The list of tags to attach to this
583+
tags (Optional[Tags]): The list of tags to attach to this
584584
specific endpoint.
585585
wait (bool): Whether the call should wait until the deployment of
586586
model completes (default: True).
@@ -632,7 +632,7 @@ def deploy(
632632
deserializer=deserializer,
633633
endpoint_name=endpoint_name,
634634
kms_key=model_kms_key,
635-
tags=tags,
635+
tags=format_tags(tags),
636636
wait=wait,
637637
volume_size=volume_size,
638638
model_data_download_timeout=model_data_download_timeout,

src/sagemaker/base_predictor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
NumpySerializer,
5454
)
5555
from sagemaker.session import production_variant, Session
56-
from sagemaker.utils import name_from_base, stringify_object
56+
from sagemaker.utils import name_from_base, stringify_object, format_tags
5757

5858
from sagemaker.model_monitor.model_monitoring import DEFAULT_REPOSITORY_NAME
5959

@@ -409,7 +409,7 @@ def update_endpoint(
409409
self.sagemaker_session.create_endpoint_config_from_existing(
410410
current_endpoint_config_name,
411411
new_endpoint_config_name,
412-
new_tags=tags,
412+
new_tags=format_tags(tags),
413413
new_kms_key=kms_key,
414414
new_data_capture_config_dict=data_capture_config_dict,
415415
new_production_variants=production_variants,

src/sagemaker/djl_inference/model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from sagemaker.s3_utils import s3_path_join
3131
from sagemaker.serializers import JSONSerializer, BaseSerializer
3232
from sagemaker.session import Session
33-
from sagemaker.utils import _tmpdir, _create_or_update_code_dir
33+
from sagemaker.utils import _tmpdir, _create_or_update_code_dir, format_tags
3434
from sagemaker.workflow.entities import PipelineVariable
3535
from sagemaker.estimator import Estimator
3636
from sagemaker.s3 import S3Uploader
@@ -610,7 +610,7 @@ def deploy(
610610
default deserializer is set by the ``predictor_cls``.
611611
endpoint_name (str): The name of the endpoint to create (default:
612612
None). If not specified, a unique endpoint name will be created.
613-
tags (List[dict[str, str]]): The list of tags to attach to this
613+
tags (Optional[Tags]): The list of tags to attach to this
614614
specific endpoint.
615615
kms_key (str): The ARN of the KMS key that is used to encrypt the
616616
data on the storage volume attached to the instance hosting the
@@ -651,7 +651,7 @@ def deploy(
651651
serializer=serializer,
652652
deserializer=deserializer,
653653
endpoint_name=endpoint_name,
654-
tags=tags,
654+
tags=format_tags(tags),
655655
kms_key=kms_key,
656656
wait=wait,
657657
data_capture_config=data_capture_config,

src/sagemaker/experiments/experiment.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from sagemaker.apiutils import _base_types
2121
from sagemaker.experiments.trial import _Trial
2222
from sagemaker.experiments.trial_component import _TrialComponent
23+
from sagemaker.utils import format_tags
2324

2425

2526
class Experiment(_base_types.Record):
@@ -111,7 +112,7 @@ def create(
111112
manages interactions with Amazon SageMaker APIs and any other
112113
AWS services needed. If not specified, one is created using the
113114
default AWS configuration chain.
114-
tags (List[Dict[str, str]]): A list of tags to associate with the experiment
115+
tags (Optional[Tags]): A list of tags to associate with the experiment
115116
(default: None).
116117
117118
Returns:
@@ -122,7 +123,7 @@ def create(
122123
experiment_name=experiment_name,
123124
display_name=display_name,
124125
description=description,
125-
tags=tags,
126+
tags=format_tags(tags),
126127
sagemaker_session=sagemaker_session,
127128
)
128129

@@ -149,7 +150,7 @@ def _load_or_create(
149150
manages interactions with Amazon SageMaker APIs and any other
150151
AWS services needed. If not specified, one is created using the
151152
default AWS configuration chain.
152-
tags (List[Dict[str, str]]): A list of tags to associate with the experiment
153+
tags (Optional[Tags]): A list of tags to associate with the experiment
153154
(default: None). This is used only when the given `experiment_name` does not
154155
exist and a new experiment has to be created.
155156
@@ -161,7 +162,7 @@ def _load_or_create(
161162
experiment_name=experiment_name,
162163
display_name=display_name,
163164
description=description,
164-
tags=tags,
165+
tags=format_tags(tags),
165166
sagemaker_session=sagemaker_session,
166167
)
167168
except ClientError as ce:

src/sagemaker/experiments/trial.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from sagemaker.apiutils import _base_types
1919
from sagemaker.experiments import _api_types
2020
from sagemaker.experiments.trial_component import _TrialComponent
21+
from sagemaker.utils import format_tags
2122

2223

2324
class _Trial(_base_types.Record):
@@ -101,7 +102,7 @@ def create(
101102
trial_name: (str): Name of the Trial.
102103
display_name (str): Name of the trial that will appear in UI,
103104
such as SageMaker Studio (default: None).
104-
tags (List[dict]): A list of tags to associate with the trial (default: None).
105+
tags (Optional[Tags]): A list of tags to associate with the trial (default: None).
105106
sagemaker_session (sagemaker.session.Session): Session object which
106107
manages interactions with Amazon SageMaker APIs and any other
107108
AWS services needed. If not specified, one is created using the
@@ -115,7 +116,7 @@ def create(
115116
trial_name=trial_name,
116117
experiment_name=experiment_name,
117118
display_name=display_name,
118-
tags=tags,
119+
tags=format_tags(tags),
119120
sagemaker_session=sagemaker_session,
120121
)
121122
return trial
@@ -259,7 +260,7 @@ def _load_or_create(
259260
display_name (str): Name of the trial that will appear in UI,
260261
such as SageMaker Studio (default: None). This is used only when the given
261262
`trial_name` does not exist and a new trial has to be created.
262-
tags (List[dict]): A list of tags to associate with the trial (default: None).
263+
tags (Optional[Tags]): A list of tags to associate with the trial (default: None).
263264
This is used only when the given `trial_name` does not exist and
264265
a new trial has to be created.
265266
sagemaker_session (sagemaker.session.Session): Session object which
@@ -275,7 +276,7 @@ def _load_or_create(
275276
experiment_name=experiment_name,
276277
trial_name=trial_name,
277278
display_name=display_name,
278-
tags=tags,
279+
tags=format_tags(tags),
279280
sagemaker_session=sagemaker_session,
280281
)
281282
except ClientError as ce:

src/sagemaker/experiments/trial_component.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from sagemaker.apiutils import _base_types
2121
from sagemaker.experiments import _api_types
2222
from sagemaker.experiments._api_types import TrialComponentSearchResult
23+
from sagemaker.utils import format_tags
2324

2425

2526
class _TrialComponent(_base_types.Record):
@@ -191,7 +192,7 @@ def create(cls, trial_component_name, display_name=None, tags=None, sagemaker_se
191192
Args:
192193
trial_component_name (str): The name of the trial component.
193194
display_name (str): Display name of the trial component used by Studio (default: None).
194-
tags (List[Dict[str, str]]): Tags to add to the trial component (default: None).
195+
tags (Optional[Tags]): Tags to add to the trial component (default: None).
195196
sagemaker_session (sagemaker.session.Session): Session object which
196197
manages interactions with Amazon SageMaker APIs and any other
197198
AWS services needed. If not specified, one is created using the
@@ -204,7 +205,7 @@ def create(cls, trial_component_name, display_name=None, tags=None, sagemaker_se
204205
cls._boto_create_method,
205206
trial_component_name=trial_component_name,
206207
display_name=display_name,
207-
tags=tags,
208+
tags=format_tags(tags),
208209
sagemaker_session=sagemaker_session,
209210
)
210211

@@ -316,7 +317,7 @@ def _load_or_create(
316317
display_name (str): Display name of the trial component used by Studio (default: None).
317318
This is used only when the given `trial_component_name` does not
318319
exist and a new trial component has to be created.
319-
tags (List[Dict[str, str]]): Tags to add to the trial component (default: None).
320+
tags (Optional[Tags]): Tags to add to the trial component (default: None).
320321
This is used only when the given `trial_component_name` does not
321322
exist and a new trial component has to be created.
322323
sagemaker_session (sagemaker.session.Session): Session object which
@@ -333,7 +334,7 @@ def _load_or_create(
333334
run_tc = _TrialComponent.create(
334335
trial_component_name=trial_component_name,
335336
display_name=display_name,
336-
tags=tags,
337+
tags=format_tags(tags),
337338
sagemaker_session=sagemaker_session,
338339
)
339340
except ClientError as ce:

src/sagemaker/huggingface/model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from sagemaker.predictor import Predictor
3030
from sagemaker.serializers import JSONSerializer
3131
from sagemaker.session import Session
32-
from sagemaker.utils import to_string
32+
from sagemaker.utils import to_string, format_tags
3333
from sagemaker.workflow import is_pipeline_variable
3434
from sagemaker.workflow.entities import PipelineVariable
3535

@@ -255,7 +255,7 @@ def deploy(
255255
https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html
256256
endpoint_name (str): The name of the endpoint to create (default:
257257
None). If not specified, a unique endpoint name will be created.
258-
tags (List[dict[str, str]]): The list of tags to attach to this
258+
tags (Optional[Tags]): The list of tags to attach to this
259259
specific endpoint.
260260
kms_key (str): The ARN of the KMS key that is used to encrypt the
261261
data on the storage volume attached to the instance hosting the
@@ -319,7 +319,7 @@ def deploy(
319319
deserializer,
320320
accelerator_type,
321321
endpoint_name,
322-
tags,
322+
format_tags(tags),
323323
kms_key,
324324
wait,
325325
data_capture_config,

src/sagemaker/jumpstart/model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,7 @@ def _create_sagemaker_model(
388388
attach to an endpoint for model loading and inference, for
389389
example, 'ml.eia1.medium'. If not specified, no Elastic
390390
Inference accelerator will be attached to the endpoint. (Default: None).
391-
tags (List[dict[str, str]]): Optional. The list of tags to add to
391+
tags (Optional[Tags]): Optional. The list of tags to add to
392392
the model. Example: >>> tags = [{'Key': 'tagname', 'Value':
393393
'tagvalue'}] For more information about tags, see
394394
https://boto3.amazonaws.com/v1/documentation
@@ -402,6 +402,8 @@ def _create_sagemaker_model(
402402
any so they are ignored.
403403
"""
404404

405+
tags = format_tags(tags)
406+
405407
# if the user inputs a model artifact uri, do not use model package arn to create
406408
# inference endpoint.
407409
if self.model_package_arn and not self._model_data_is_set:

src/sagemaker/lineage/action.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from sagemaker.lineage import _api_types, _utils
2222
from sagemaker.lineage._api_types import ActionSource, ActionSummary
2323
from sagemaker.lineage.artifact import Artifact
24+
from sagemaker.utils import format_tags
2425

2526
from sagemaker.lineage.query import (
2627
LineageQuery,
@@ -159,12 +160,12 @@ def set_tags(self, tags=None):
159160
"""Add tags to the object.
160161
161162
Args:
162-
tags ([{key:value}]): list of key value pairs.
163+
tags (Optional[Tags]): list of key value pairs.
163164
164165
Returns:
165166
list({str:str}): a list of key value pairs
166167
"""
167-
return self._set_tags(resource_arn=self.action_arn, tags=tags)
168+
return self._set_tags(resource_arn=self.action_arn, tags=format_tags(tags))
168169

169170
@classmethod
170171
def create(

src/sagemaker/lineage/artifact.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
)
3232
from sagemaker.lineage._utils import _disassociate, get_resource_name_from_arn
3333
from sagemaker.lineage.association import Association
34-
from sagemaker.utils import get_module
34+
from sagemaker.utils import get_module, format_tags
3535

3636
LOGGER = logging.getLogger("sagemaker")
3737

@@ -288,12 +288,12 @@ def set_tags(self, tags=None):
288288
"""Add tags to the object.
289289
290290
Args:
291-
tags ([{key:value}]): list of key value pairs.
291+
tags (Optional[Tags]): list of key value pairs.
292292
293293
Returns:
294294
list({str:str}): a list of key value pairs
295295
"""
296-
return self._set_tags(resource_arn=self.artifact_arn, tags=tags)
296+
return self._set_tags(resource_arn=self.artifact_arn, tags=format_tags(tags))
297297

298298
@classmethod
299299
def create(

src/sagemaker/lineage/association.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from sagemaker.apiutils import _base_types
2121
from sagemaker.lineage import _api_types
2222
from sagemaker.lineage._api_types import AssociationSummary
23+
from sagemaker.utils import format_tags
2324

2425
logger = logging.getLogger(__name__)
2526

@@ -95,7 +96,7 @@ def set_tags(self, tags=None):
9596
"set_tags on Association is deprecated. Use set_tags on the source or destination\
9697
entity instead."
9798
)
98-
return self._set_tags(resource_arn=self.source_arn, tags=tags)
99+
return self._set_tags(resource_arn=self.source_arn, tags=format_tags(tags))
99100

100101
@classmethod
101102
def create(

src/sagemaker/lineage/context.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from sagemaker.lineage.artifact import Artifact
3434
from sagemaker.lineage.action import Action
3535
from sagemaker.lineage.lineage_trial_component import LineageTrialComponent
36+
from sagemaker.utils import format_tags
3637

3738

3839
class Context(_base_types.Record):
@@ -126,7 +127,7 @@ def set_tags(self, tags=None):
126127
Returns:
127128
list({str:str}): a list of key value pairs
128129
"""
129-
return self._set_tags(resource_arn=self.context_arn, tags=tags)
130+
return self._set_tags(resource_arn=self.context_arn, tags=format_tags(tags))
130131

131132
@classmethod
132133
def load(cls, context_name: str, sagemaker_session=None) -> "Context":

src/sagemaker/local/entities.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929
from sagemaker.local.image import _SageMakerContainer
3030
from sagemaker.local.utils import copy_directory_structure, move_to_destination, get_docker_host
31-
from sagemaker.utils import DeferredError, get_config_value
31+
from sagemaker.utils import DeferredError, get_config_value, format_tags
3232
from sagemaker.local.exceptions import StepExecutionException
3333

3434
logger = logging.getLogger(__name__)
@@ -552,7 +552,7 @@ class _LocalEndpointConfig(object):
552552
def __init__(self, config_name, production_variants, tags=None):
553553
self.name = config_name
554554
self.production_variants = production_variants
555-
self.tags = tags
555+
self.tags = format_tags(tags)
556556
self.creation_time = datetime.datetime.now()
557557

558558
def describe(self):
@@ -584,7 +584,7 @@ def __init__(self, endpoint_name, endpoint_config_name, tags=None, local_session
584584
self.name = endpoint_name
585585
self.endpoint_config = local_client.describe_endpoint_config(endpoint_config_name)
586586
self.production_variant = self.endpoint_config["ProductionVariants"][0]
587-
self.tags = tags
587+
self.tags = format_tags(tags)
588588

589589
model_name = self.production_variant["ModelName"]
590590
self.primary_container = local_client.describe_model(model_name)["PrimaryContainer"]

0 commit comments

Comments
 (0)