Skip to content

Commit 09549d3

Browse files
martinRenouakrishna1995
authored andcommitted
Change: More pythonic tags
1 parent 06b3ef0 commit 09549d3

File tree

31 files changed

+225
-179
lines changed

31 files changed

+225
-179
lines changed

src/sagemaker/algorithm.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from sagemaker.session import Session
2929
from sagemaker.workflow.entities import PipelineVariable
3030
from sagemaker.workflow.pipeline_context import runnable_by_pipeline
31+
from sagemaker.utils import format_tags, Tags
3132

3233
from sagemaker.workflow import is_pipeline_variable
3334

@@ -58,7 +59,7 @@ def __init__(
5859
base_job_name: Optional[str] = None,
5960
sagemaker_session: Optional[Session] = None,
6061
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
61-
tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
62+
tags: Optional[Tags] = None,
6263
subnets: Optional[List[Union[str, PipelineVariable]]] = None,
6364
security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None,
6465
model_uri: Optional[str] = None,
@@ -121,7 +122,7 @@ def __init__(
121122
interactions with Amazon SageMaker APIs and any other AWS services needed. If
122123
not specified, the estimator creates one using the default
123124
AWS configuration chain.
124-
tags (list[dict[str, str] or list[dict[str, PipelineVariable]]): List of tags for
125+
tags (Union[Tags]): Tags for
125126
labeling a training job. For more, see
126127
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
127128
subnets (list[str] or list[PipelineVariable]): List of subnet ids. If not specified
@@ -170,7 +171,7 @@ def __init__(
170171
output_kms_key=output_kms_key,
171172
base_job_name=base_job_name,
172173
sagemaker_session=sagemaker_session,
173-
tags=tags,
174+
tags=format_tags(tags),
174175
subnets=subnets,
175176
security_group_ids=security_group_ids,
176177
model_uri=model_uri,

src/sagemaker/automl/automl.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
)
2929
from sagemaker.job import _Job
3030
from sagemaker.session import Session
31-
from sagemaker.utils import name_from_base, resolve_value_from_config
31+
from sagemaker.utils import name_from_base, resolve_value_from_config, format_tags, Tags
3232
from sagemaker.workflow.entities import PipelineVariable
3333
from sagemaker.workflow.pipeline_context import runnable_by_pipeline
3434

@@ -127,7 +127,7 @@ def __init__(
127127
total_job_runtime_in_seconds: Optional[int] = None,
128128
job_objective: Optional[Dict[str, str]] = None,
129129
generate_candidate_definitions_only: Optional[bool] = False,
130-
tags: Optional[List[Dict[str, str]]] = None,
130+
tags: Optional[Tags] = None,
131131
content_type: Optional[str] = None,
132132
s3_data_type: Optional[str] = None,
133133
feature_specification_s3_uri: Optional[str] = None,
@@ -167,8 +167,7 @@ def __init__(
167167
In the format of: {"MetricName": str}
168168
generate_candidate_definitions_only (bool): Whether to generates
169169
possible candidates without training the models.
170-
tags (List[dict[str, str]]): The list of tags to attach to this
171-
specific endpoint.
170+
tags (Optional[Tags]): Tags to attach to this specific endpoint.
172171
content_type (str): The content type of the data from the input source.
173172
s3_data_type (str): The data type for S3 data source.
174173
Valid values: ManifestFile or S3Prefix.
@@ -203,7 +202,7 @@ def __init__(
203202
self.target_attribute_name = target_attribute_name
204203
self.job_objective = job_objective
205204
self.generate_candidate_definitions_only = generate_candidate_definitions_only
206-
self.tags = tags
205+
self.tags = format_tags(tags)
207206
self.content_type = content_type
208207
self.s3_data_type = s3_data_type
209208
self.feature_specification_s3_uri = feature_specification_s3_uri

src/sagemaker/clarify.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from sagemaker.session import Session
3434
from sagemaker.network import NetworkConfig
3535
from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor
36+
from sagemaker.utils import format_tags, Tags
3637

3738
logger = logging.getLogger(__name__)
3839

@@ -1417,7 +1418,7 @@ def __init__(
14171418
max_runtime_in_seconds: Optional[int] = None,
14181419
sagemaker_session: Optional[Session] = None,
14191420
env: Optional[Dict[str, str]] = None,
1420-
tags: Optional[List[Dict[str, str]]] = None,
1421+
tags: Optional[Tags] = None,
14211422
network_config: Optional[NetworkConfig] = None,
14221423
job_name_prefix: Optional[str] = None,
14231424
version: Optional[str] = None,
@@ -1454,7 +1455,7 @@ def __init__(
14541455
using the default AWS configuration chain.
14551456
env (dict[str, str]): Environment variables to be passed to
14561457
the processing jobs (default: None).
1457-
tags (list[dict]): List of tags to be passed to the processing job
1458+
tags (Optional[Tags]): Tags to be passed to the processing job
14581459
(default: None). For more, see
14591460
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
14601461
network_config (:class:`~sagemaker.network.NetworkConfig`):
@@ -1482,7 +1483,7 @@ def __init__(
14821483
None, # We set method-specific job names below.
14831484
sagemaker_session,
14841485
env,
1485-
tags,
1486+
format_tags(tags),
14861487
network_config,
14871488
)
14881489

src/sagemaker/estimator.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@
9898
to_string,
9999
check_and_get_run_experiment_config,
100100
resolve_value_from_config,
101+
format_tags,
102+
Tags,
101103
)
102104
from sagemaker.workflow import is_pipeline_variable
103105
from sagemaker.workflow.entities import PipelineVariable
@@ -144,7 +146,7 @@ def __init__(
144146
output_kms_key: Optional[Union[str, PipelineVariable]] = None,
145147
base_job_name: Optional[str] = None,
146148
sagemaker_session: Optional[Session] = None,
147-
tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
149+
tags: Optional[Tags] = None,
148150
subnets: Optional[List[Union[str, PipelineVariable]]] = None,
149151
security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None,
150152
model_uri: Optional[str] = None,
@@ -270,8 +272,8 @@ def __init__(
270272
manages interactions with Amazon SageMaker APIs and any other
271273
AWS services needed. If not specified, the estimator creates one
272274
using the default AWS configuration chain.
273-
tags (list[dict[str, str] or list[dict[str, PipelineVariable]]):
274-
List of tags for labeling a training job. For more, see
275+
tags (Optional[Tags]):
276+
Tags for labeling a training job. For more, see
275277
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
276278
subnets (list[str] or list[PipelineVariable]): List of subnet ids. If not
277279
specified training job will be created without VPC config.
@@ -604,6 +606,7 @@ def __init__(
604606
else:
605607
self.sagemaker_session = sagemaker_session or Session()
606608

609+
tags = format_tags(tags)
607610
self.tags = (
608611
add_jumpstart_uri_tags(
609612
tags=tags, training_model_uri=self.model_uri, training_script_uri=self.source_dir
@@ -1352,7 +1355,7 @@ def compile_model(
13521355
framework=None,
13531356
framework_version=None,
13541357
compile_max_run=15 * 60,
1355-
tags=None,
1358+
tags: Optional[Tags] = None,
13561359
target_platform_os=None,
13571360
target_platform_arch=None,
13581361
target_platform_accelerator=None,
@@ -1378,7 +1381,7 @@ def compile_model(
13781381
compile_max_run (int): Timeout in seconds for compilation (default:
13791382
15 * 60). After this amount of time Amazon SageMaker Neo
13801383
terminates the compilation job regardless of its current status.
1381-
tags (list[dict]): List of tags for labeling a compilation job. For
1384+
tags (list[dict]): Tags for labeling a compilation job. For
13821385
more, see
13831386
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
13841387
target_platform_os (str): Target Platform OS, for example: 'LINUX'.
@@ -1420,7 +1423,7 @@ def compile_model(
14201423
input_shape,
14211424
output_path,
14221425
self.role,
1423-
tags,
1426+
format_tags(tags),
14241427
self._compilation_job_name(),
14251428
compile_max_run,
14261429
framework=framework,
@@ -1532,7 +1535,7 @@ def deploy(
15321535
model_name=None,
15331536
kms_key=None,
15341537
data_capture_config=None,
1535-
tags=None,
1538+
tags: Optional[Tags] = None,
15361539
serverless_inference_config=None,
15371540
async_inference_config=None,
15381541
volume_size=None,
@@ -1601,8 +1604,10 @@ def deploy(
16011604
empty object passed through, will use pre-defined values in
16021605
``ServerlessInferenceConfig`` class to deploy serverless endpoint. Deploy an
16031606
instance based endpoint if it's None. (default: None)
1604-
tags(List[dict[str, str]]): Optional. The list of tags to attach to this specific
1607+
tags(Optional[Tags]): Optional. Tags to attach to this specific
16051608
endpoint. Example:
1609+
>>> tags = {'tagname', 'tagvalue'}
1610+
Or
16061611
>>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
16071612
For more information about tags, see
16081613
https://boto3.amazonaws.com/v1/documentation\
@@ -1664,7 +1669,7 @@ def deploy(
16641669
model.name = model_name
16651670

16661671
tags = update_inference_tags_with_jumpstart_training_tags(
1667-
inference_tags=tags, training_tags=self.tags
1672+
inference_tags=format_tags(tags), training_tags=self.tags
16681673
)
16691674

16701675
return model.deploy(
@@ -2017,7 +2022,7 @@ def transformer(
20172022
env=None,
20182023
max_concurrent_transforms=None,
20192024
max_payload=None,
2020-
tags=None,
2025+
tags: Optional[Tags] = None,
20212026
role=None,
20222027
volume_kms_key=None,
20232028
vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT,
@@ -2051,7 +2056,7 @@ def transformer(
20512056
to be made to each individual transform container at one time.
20522057
max_payload (int): Maximum size of the payload in a single HTTP
20532058
request to the container in MB.
2054-
tags (list[dict]): List of tags for labeling a transform job. If
2059+
tags (Optional[Tags]): Tags for labeling a transform job. If
20552060
none specified, then the tags used for the training job are used
20562061
for the transform job.
20572062
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``,
@@ -2078,7 +2083,7 @@ def transformer(
20782083
model. If not specified, the estimator generates a default job name
20792084
based on the training image name and current timestamp.
20802085
"""
2081-
tags = tags or self.tags
2086+
tags = format_tags(tags) or self.tags
20822087
model_name = self._get_or_create_name(model_name)
20832088

20842089
if self.latest_training_job is None:
@@ -2717,7 +2722,7 @@ def __init__(
27172722
base_job_name: Optional[str] = None,
27182723
sagemaker_session: Optional[Session] = None,
27192724
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
2720-
tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
2725+
tags: Optional[Tags] = None,
27212726
subnets: Optional[List[Union[str, PipelineVariable]]] = None,
27222727
security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None,
27232728
model_uri: Optional[str] = None,
@@ -2847,7 +2852,7 @@ def __init__(
28472852
hyperparameters. SageMaker rejects the training job request and returns an
28482853
validation error for detected credentials, if such user input is found.
28492854
2850-
tags (list[dict[str, str] or list[dict[str, PipelineVariable]]): List of tags for
2855+
tags (Optional[Tags]): Tags for
28512856
labeling a training job. For more, see
28522857
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
28532858
subnets (list[str] or list[PipelineVariable]): List of subnet ids.
@@ -3130,7 +3135,7 @@ def __init__(
31303135
output_kms_key,
31313136
base_job_name,
31323137
sagemaker_session,
3133-
tags,
3138+
format_tags(tags),
31343139
subnets,
31353140
security_group_ids,
31363141
model_uri=model_uri,
@@ -3762,7 +3767,7 @@ def transformer(
37623767
env=None,
37633768
max_concurrent_transforms=None,
37643769
max_payload=None,
3765-
tags=None,
3770+
tags: Optional[Tags] = None,
37663771
role=None,
37673772
model_server_workers=None,
37683773
volume_kms_key=None,
@@ -3798,7 +3803,7 @@ def transformer(
37983803
to be made to each individual transform container at one time.
37993804
max_payload (int): Maximum size of the payload in a single HTTP
38003805
request to the container in MB.
3801-
tags (list[dict]): List of tags for labeling a transform job. If
3806+
tags (Optional[Tags]): Tags for labeling a transform job. If
38023807
none specified, then the tags used for the training job are used
38033808
for the transform job.
38043809
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``,
@@ -3837,7 +3842,7 @@ def transformer(
38373842
SageMaker Batch Transform job.
38383843
"""
38393844
role = role or self.role
3840-
tags = tags or self.tags
3845+
tags = format_tags(tags) or self.tags
38413846
model_name = self._get_or_create_name(model_name)
38423847

38433848
if self.latest_training_job is not None:

src/sagemaker/experiments/run.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@
4444
from sagemaker.utils import (
4545
get_module,
4646
unique_name_from_base,
47+
format_tags,
48+
Tags,
49+
TagsDict,
4750
)
4851

4952
from sagemaker.experiments._utils import (
@@ -97,7 +100,7 @@ def __init__(
97100
run_name: Optional[str] = None,
98101
experiment_display_name: Optional[str] = None,
99102
run_display_name: Optional[str] = None,
100-
tags: Optional[List[Dict[str, str]]] = None,
103+
tags: Optional[Tags] = None,
101104
sagemaker_session: Optional["Session"] = None,
102105
artifact_bucket: Optional[str] = None,
103106
artifact_prefix: Optional[str] = None,
@@ -152,7 +155,7 @@ def __init__(
152155
run_display_name (str): The display name of the run used in UI (default: None).
153156
This display name is used in a create run call. If a run with the
154157
specified name already exists, this display name won't take effect.
155-
tags (List[Dict[str, str]]): A list of tags to be used for all create calls,
158+
tags (Optional[Tags]): Tags to be used for all create calls,
156159
e.g. to create an experiment, a run group, etc. (default: None).
157160
sagemaker_session (sagemaker.session.Session): Session object which
158161
manages interactions with Amazon SageMaker APIs and any other
@@ -172,6 +175,8 @@ def __init__(
172175
# avoid confusion due to mis-match in casing between run name and TC name
173176
self.run_name = self.run_name.lower()
174177

178+
tags = format_tags(tags)
179+
175180
trial_component_name = Run._generate_trial_component_name(
176181
run_name=self.run_name, experiment_name=self.experiment_name
177182
)
@@ -676,11 +681,11 @@ def _extract_run_name_from_tc_name(trial_component_name: str, experiment_name: s
676681
)
677682

678683
@staticmethod
679-
def _append_run_tc_label_to_tags(tags: Optional[List[Dict[str, str]]] = None) -> list:
684+
def _append_run_tc_label_to_tags(tags: Optional[List[TagsDict]] = None) -> list:
680685
"""Append the run trial component label to tags used to create a trial component.
681686
682687
Args:
683-
tags (List[Dict[str, str]]): The tags supplied by users to initialize a Run object.
688+
tags (List[TagsDict]): The tags supplied by users to initialize a Run object.
684689
685690
Returns:
686691
list: The updated tags with the appended run trial component label.

src/sagemaker/feature_store/feature_group.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
import tempfile
2929
from concurrent.futures import as_completed
3030
from concurrent.futures import ThreadPoolExecutor
31-
from typing import Sequence, List, Dict, Any, Union
31+
from typing import Optional, Sequence, List, Dict, Any, Union
3232
from urllib.parse import urlparse
3333

3434
from multiprocessing.pool import AsyncResult
@@ -65,7 +65,7 @@
6565
OnlineStoreConfigUpdate,
6666
OnlineStoreStorageTypeEnum,
6767
)
68-
from sagemaker.utils import resolve_value_from_config
68+
from sagemaker.utils import resolve_value_from_config, format_tags, Tags
6969

7070
logger = logging.getLogger(__name__)
7171

@@ -538,7 +538,7 @@ def create(
538538
disable_glue_table_creation: bool = False,
539539
data_catalog_config: DataCatalogConfig = None,
540540
description: str = None,
541-
tags: List[Dict[str, str]] = None,
541+
tags: Optional[Tags] = None,
542542
table_format: TableFormatEnum = None,
543543
online_store_storage_type: OnlineStoreStorageTypeEnum = None,
544544
) -> Dict[str, Any]:
@@ -566,7 +566,7 @@ def create(
566566
data_catalog_config (DataCatalogConfig): configuration for
567567
Metadata store (default: None).
568568
description (str): description of the FeatureGroup (default: None).
569-
tags (List[Dict[str, str]]): list of tags for labeling a FeatureGroup (default: None).
569+
tags (Optional[Tags]): Tags for labeling a FeatureGroup (default: None).
570570
table_format (TableFormatEnum): format of the offline store table (default: None).
571571
online_store_storage_type (OnlineStoreStorageTypeEnum): storage type for the
572572
online store (default: None).
@@ -602,7 +602,7 @@ def create(
602602
],
603603
role_arn=role_arn,
604604
description=description,
605-
tags=tags,
605+
tags=format_tags(tags),
606606
)
607607

608608
# online store configuration

src/sagemaker/feature_store/feature_processor/_event_bridge_rule_helper.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from sagemaker.feature_store.feature_processor._enums import (
3333
FeatureProcessorPipelineExecutionStatus,
3434
)
35+
from sagemaker.utils import TagsDict
3536

3637
logger = logging.getLogger("sagemaker")
3738

@@ -175,7 +176,7 @@ def disable_rule(self, rule_name: str) -> None:
175176
self.event_bridge_rule_client.disable_rule(Name=rule_name)
176177
logger.info("Disabled EventBridge Rule for pipeline %s.", rule_name)
177178

178-
def add_tags(self, rule_arn: str, tags: List[Dict[str, str]]) -> None:
179+
def add_tags(self, rule_arn: str, tags: List[TagsDict]) -> None:
179180
"""Adds tags to the EventBridge Rule.
180181
181182
Args:

0 commit comments

Comments
 (0)