Skip to content

Commit 6de4fca

Browse files
metrizableDan Choi
authored andcommitted
change: add project tags to creates (aws#534)
1 parent 07dd11a commit 6de4fca

File tree

3 files changed

+19
-3
lines changed

3 files changed

+19
-3
lines changed

src/sagemaker/_studio.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
logger = logging.getLogger(__name__)
2424

2525

26-
def _append_project_tags(working_dir=None, tags=None):
26+
def _append_project_tags(tags=None, working_dir=None):
2727
"""Appends the project tag to the list of tags, if it exists.
2828
2929
Args:

src/sagemaker/session.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import sagemaker.logs
3232
from sagemaker import vpc_utils
3333

34+
from sagemaker._studio import _append_project_tags
3435
from sagemaker.deprecations import deprecated_class
3536
from sagemaker.inputs import ShuffleConfig, TrainingInput
3637
from sagemaker.user_agent import prepend_user_agent
@@ -534,6 +535,7 @@ def train( # noqa: C901
534535
Returns:
535536
str: ARN of the training job, if it is created.
536537
"""
538+
tags = _append_project_tags(tags)
537539
train_request = self._get_train_request(
538540
input_mode=input_mode,
539541
input_config=input_config,
@@ -779,6 +781,7 @@ def process(
779781
three optional keys, 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
780782
(default: ``None``)
781783
"""
784+
tags = _append_project_tags(tags)
782785
process_request = self._get_process_request(
783786
inputs=inputs,
784787
output_config=output_config,
@@ -1019,6 +1022,7 @@ def create_monitoring_schedule(
10191022
"NetworkConfig"
10201023
] = network_config
10211024

1025+
tags = _append_project_tags(tags)
10221026
if tags is not None:
10231027
monitoring_schedule_request["Tags"] = tags
10241028

@@ -1527,6 +1531,8 @@ def auto_ml(
15271531
auto_ml_job_request["AutoMLJobObjective"] = job_objective
15281532
if problem_type is not None:
15291533
auto_ml_job_request["ProblemType"] = problem_type
1534+
1535+
tags = _append_project_tags(tags)
15301536
if tags is not None:
15311537
auto_ml_job_request["Tags"] = tags
15321538

@@ -1719,6 +1725,7 @@ def compile_model(
17191725
"CompilationJobName": job_name,
17201726
}
17211727

1728+
tags = _append_project_tags(tags)
17221729
if tags is not None:
17231730
compilation_job_request["Tags"] = tags
17241731

@@ -1868,6 +1875,7 @@ def tune( # noqa: C901
18681875
if warm_start_config is not None:
18691876
tune_request["WarmStartConfig"] = warm_start_config
18701877

1878+
tags = _append_project_tags(tags)
18711879
if tags is not None:
18721880
tune_request["Tags"] = tags
18731881

@@ -1925,6 +1933,7 @@ def create_tuning_job(
19251933
if warm_start_config is not None:
19261934
tune_request["WarmStartConfig"] = warm_start_config
19271935

1936+
tags = _append_project_tags(tags)
19281937
if tags is not None:
19291938
tune_request["Tags"] = tags
19301939

@@ -2315,6 +2324,7 @@ def transform(
23152324
job. Dictionary contains two optional keys,
23162325
'InvocationsTimeoutInSeconds', and 'InvocationsMaxRetries'.
23172326
"""
2327+
tags = _append_project_tags(tags)
23182328
transform_request = self._get_transform_request(
23192329
job_name=job_name,
23202330
model_name=model_name,
@@ -2430,6 +2440,7 @@ def create_model(
24302440
Returns:
24312441
str: Name of the Amazon SageMaker ``Model`` created.
24322442
"""
2443+
tags = _append_project_tags(tags)
24332444
create_model_request = self._create_model_request(
24342445
name=name,
24352446
role=role,
@@ -2754,6 +2765,7 @@ def create_endpoint_config(
27542765
],
27552766
}
27562767

2768+
tags = _append_project_tags(tags)
27572769
if tags is not None:
27582770
request["Tags"] = tags
27592771

@@ -2823,6 +2835,7 @@ def create_endpoint_config_from_existing(
28232835
request_tags = new_tags or self.list_tags(
28242836
existing_endpoint_config_desc["EndpointConfigArn"]
28252837
)
2838+
request_tags = _append_project_tags(request_tags)
28262839
if request_tags:
28272840
request["Tags"] = request_tags
28282841

@@ -2857,6 +2870,7 @@ def create_endpoint(self, endpoint_name, config_name, tags=None, wait=True):
28572870
LOGGER.info("Creating endpoint with name %s", endpoint_name)
28582871

28592872
tags = tags or []
2873+
tags = _append_project_tags(tags)
28602874

28612875
self.sagemaker_client.create_endpoint(
28622876
EndpointName=endpoint_name, EndpointConfigName=config_name, Tags=tags
@@ -3336,6 +3350,7 @@ def endpoint_from_production_variants(
33363350
lambda: self.sagemaker_client.describe_endpoint_config(EndpointConfigName=name)
33373351
):
33383352
config_options = {"EndpointConfigName": name, "ProductionVariants": production_variants}
3353+
tags = _append_project_tags(tags)
33393354
if tags:
33403355
config_options["Tags"] = tags
33413356
if kms_key:
@@ -3728,6 +3743,7 @@ def create_feature_group(
37283743
Returns:
37293744
Response dict from service.
37303745
"""
3746+
tags = _append_project_tags(tags)
37313747
kwargs = dict(
37323748
FeatureGroupName=feature_group_name,
37333749
RecordIdentifierFeatureName=record_identifier_name,

tests/unit/sagemaker/test_studio.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,13 +82,13 @@ def test_append_project_tags(tmpdir):
8282
config.write('{"sagemakerProjectId": "proj-1234", "sagemakerProjectName": "proj-name"}')
8383
working_dir = tmpdir.mkdir("sub")
8484

85-
tags = _append_project_tags(working_dir, None)
85+
tags = _append_project_tags(None, working_dir)
8686
assert tags == [
8787
{"Key": "sagemaker:project-id", "Value": "proj-1234"},
8888
{"Key": "sagemaker:project-name", "Value": "proj-name"},
8989
]
9090

91-
tags = _append_project_tags(working_dir, [{"Key": "a", "Value": "b"}])
91+
tags = _append_project_tags([{"Key": "a", "Value": "b"}], working_dir)
9292
assert tags == [
9393
{"Key": "a", "Value": "b"},
9494
{"Key": "sagemaker:project-id", "Value": "proj-1234"},

0 commit comments

Comments
 (0)