Skip to content

Commit 1a64d36

Browse files
fix: handle tags when upsert pipeine (#2488)
* fix: handle tags when upsert pipeine * fix: typo and remove local test changes * fix: typo * fix: black check failure * fix: merge tags instead of delete Co-authored-by: icywang86rui <[email protected]>
1 parent 939fab0 commit 1a64d36

File tree

3 files changed

+46
-3
lines changed

3 files changed

+46
-3
lines changed

src/sagemaker/workflow/pipeline.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,19 @@ def upsert(
182182
and "Pipeline names must be unique within" in error["Message"]
183183
):
184184
response = self.update(role_arn, description)
185+
if tags is not None:
186+
old_tags = self.sagemaker_session.sagemaker_client.list_tags(
187+
ResourceArn=response["PipelineArn"]
188+
)["Tags"]
189+
190+
tag_keys = [tag["Key"] for tag in tags]
191+
for old_tag in old_tags:
192+
if old_tag["Key"] not in tag_keys:
193+
tags.append(old_tag)
194+
195+
self.sagemaker_session.sagemaker_client.add_tags(
196+
ResourceArn=response["PipelineArn"], Tags=tags
197+
)
185198
else:
186199
raise
187200
return response

tests/integ/test_workflow.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -458,20 +458,34 @@ def test_one_step_sklearn_processing_pipeline(
458458
# sagemaker entities. However, the jobs created in the steps themselves execute
459459
# under a potentially different role, often requiring access to S3 and other
460460
# artifacts not required to during creation of the jobs in the pipeline steps.
461-
response = pipeline.create(role)
461+
response = pipeline.create(
462+
role, tags=[{"Key": "foo", "Value": "123"}, {"Key": "bar", "Value": "456"}]
463+
)
462464
create_arn = response["PipelineArn"]
463465
assert re.match(
464466
fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
465467
create_arn,
466468
)
469+
original_tags = sagemaker_session.sagemaker_client.list_tags(ResourceArn=create_arn)
470+
for tag in [{"Key": "foo", "Value": "123"}, {"Key": "bar", "Value": "456"}]:
471+
assert tag in original_tags["Tags"]
467472

468473
pipeline.parameters = [ParameterInteger(name="InstanceCount", default_value=1)]
469-
response = pipeline.update(role)
474+
response = pipeline.upsert(
475+
role, tags=[{"Key": "foo", "Value": "abc"}, {"Key": "baz", "Value": "789"}]
476+
)
470477
update_arn = response["PipelineArn"]
471478
assert re.match(
472479
fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
473480
update_arn,
474481
)
482+
updated_tags = sagemaker_session.sagemaker_client.list_tags(ResourceArn=create_arn)
483+
for tag in [
484+
{"Key": "foo", "Value": "abc"},
485+
{"Key": "bar", "Value": "456"},
486+
{"Key": "baz", "Value": "789"},
487+
]:
488+
assert tag in updated_tags["Tags"]
475489

476490
execution = pipeline.start(parameters={})
477491
assert re.match(

tests/unit/sagemaker/workflow/test_pipeline.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,21 +102,37 @@ def test_pipeline_upsert(sagemaker_session_mock, role_arn):
102102
}
103103
},
104104
),
105+
{"PipelineArn": "mock_pipeline_arn"},
106+
[{"Key": "dummy", "Value": "dummy_tag"}],
105107
{},
106108
]
109+
107110
pipeline = Pipeline(
108111
name="MyPipeline",
109112
parameters=[],
110113
steps=[],
111114
sagemaker_session=sagemaker_session_mock,
112115
)
113-
pipeline.update(role_arn=role_arn)
116+
117+
tags = [
118+
{"Key": "foo", "Value": "abc"},
119+
{"Key": "bar", "Value": "xyz"},
120+
]
121+
pipeline.upsert(role_arn=role_arn, tags=tags)
114122
assert sagemaker_session_mock.sagemaker_client.create_pipeline.called_with(
115123
PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn=role_arn
116124
)
117125
assert sagemaker_session_mock.sagemaker_client.update_pipeline.called_with(
118126
PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn=role_arn
119127
)
128+
assert sagemaker_session_mock.sagemaker_client.list_tags.called_with(
129+
ResourceArn="mock_pipeline_arn"
130+
)
131+
132+
tags.append({"Key": "dummy", "Value": "dummy_tag"})
133+
assert sagemaker_session_mock.sagemaker_client.add_tags.called_with(
134+
ResourceArn="mock_pipeline_arn", Tags=tags
135+
)
120136

121137

122138
def test_pipeline_delete(sagemaker_session_mock):

0 commit comments

Comments
 (0)