Skip to content

Commit 3bc4298

Browse files
fix: handle tags when upsert pipeine
1 parent 4c0d3cf commit 3bc4298

File tree

3 files changed

+43
-5
lines changed

3 files changed

+43
-5
lines changed

src/sagemaker/workflow/pipeline.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def describe(self) -> Dict[str, Any]:
143143
"""
144144
return self.sagemaker_session.sagemaker_client.describe_pipeline(PipelineName=self.name)
145145

146-
def update(self, role_arn: str, description: str = None) -> Dict[str, Any]:
146+
def update(self, role_arn: str, description: str = None, ) -> Dict[str, Any]:
147147
"""Updates a Pipeline in the Workflow service.
148148
149149
Args:
@@ -182,6 +182,20 @@ def upsert(
182182
and "Pipeline names must be unique within" in error["Message"]
183183
):
184184
response = self.update(role_arn, description)
185+
if tags is not None:
186+
old_tags = self.sagemaker_session.sagemaker_client.list_tags(
187+
ResourceArn=response["PipelineArn"])["Tags"]
188+
189+
tag_keys = [tag["Key"] for tag in old_tags]
190+
191+
self.sagemaker_session.sagemaker_client.delete_tags(
192+
ResourceArn=response["PipelineArn"],
193+
TagKeys=tag_keys
194+
)
195+
self.sagemaker_session.sagemaker_client.add_tags(
196+
ResourceArn=response["PipelineArn"],
197+
Tags=tags
198+
)
185199
else:
186200
raise
187201
return response

tests/integ/test_workflow.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -410,13 +410,13 @@ def test_three_step_definition(
410410

411411
def test_one_step_sklearn_processing_pipeline(
412412
sagemaker_session,
413-
role,
414413
sklearn_latest_version,
415414
cpu_instance_type,
416415
pipeline_name,
417416
region_name,
418417
athena_dataset_definition,
419418
):
419+
role = "arn:aws:iam::734680132978:role/SageMakerFullAccessRole"
420420
instance_count = ParameterInteger(name="InstanceCount", default_value=2)
421421
script_path = os.path.join(DATA_DIR, "dummy_script.py")
422422
input_file_path = os.path.join(DATA_DIR, "dummy_input.txt")
@@ -458,20 +458,26 @@ def test_one_step_sklearn_processing_pipeline(
458458
# sagemaker entities. However, the jobs created in the steps themselves execute
459459
# under a potentially different role, often requiring access to S3 and other
460460
# artifacts not required to during creation of the jobs in the pipeline steps.
461-
response = pipeline.create(role)
461+
response = pipeline.create(role, tags=[{"Key": "foo", "Value": "abc"}, {"Key": "bar", "Value": "xyz"}])
462462
create_arn = response["PipelineArn"]
463463
assert re.match(
464464
fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
465465
create_arn,
466466
)
467+
original_tags = sagemaker_session.sagemaker_client.list_tags(ResourceArn=create_arn)
468+
for tag in [{"Key": "foo", "Value": "abc"}, {"Key": "bar", "Value": "xyz"}]:
469+
assert tag in original_tags["Tags"]
467470

468471
pipeline.parameters = [ParameterInteger(name="InstanceCount", default_value=1)]
469-
response = pipeline.update(role)
472+
response = pipeline.upsert(role, tags=[{"Key": "abc", "Value": "foo"}, {"Key": "xyz", "Value": "bar"}])
470473
update_arn = response["PipelineArn"]
471474
assert re.match(
472475
fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
473476
update_arn,
474477
)
478+
updated_tags = sagemaker_session.sagemaker_client.list_tags(ResourceArn=create_arn)
479+
for tag in [{"Key": "abc", "Value": "foo"}, {"Key": "xyz", "Value": "bar"}]:
480+
assert tag in updated_tags["Tags"]
475481

476482
execution = pipeline.start(parameters={})
477483
assert re.match(

tests/unit/sagemaker/workflow/test_pipeline.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,21 +102,39 @@ def test_pipeline_upsert(sagemaker_session_mock, role_arn):
102102
}
103103
},
104104
),
105+
{"PipelineArn": "mock_pipeline_arn"},
106+
[
107+
{"Key": "dummy", "Value": "dummy_tag"}
108+
],
109+
{},
105110
{},
106111
]
112+
107113
pipeline = Pipeline(
108114
name="MyPipeline",
109115
parameters=[],
110116
steps=[],
111117
sagemaker_session=sagemaker_session_mock,
112118
)
113-
pipeline.update(role_arn=role_arn)
119+
120+
tags = [
121+
{"Key": "foo", "Value": "abc"},
122+
{"Key": "bar", "Value": "xyz"},
123+
]
124+
pipeline.upsert(role_arn=role_arn, tags=tags)
114125
assert sagemaker_session_mock.sagemaker_client.create_pipeline.called_with(
115126
PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn=role_arn
116127
)
117128
assert sagemaker_session_mock.sagemaker_client.update_pipeline.called_with(
118129
PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn=role_arn
119130
)
131+
assert sagemaker_session_mock.sagemaker_client.list_tags.called_with(ResourceArn="mock_pipeline_arn")
132+
assert sagemaker_session_mock.sagemaker_client.delete_tags(
133+
ResourceArn="mock_pipeline_arn", TagKeys=["dummy"]
134+
)
135+
assert sagemaker_session_mock.sagemaker_client.add_tags.called_with(
136+
ResourceArn="mock_pipeline_arn", Tags=tags
137+
)
120138

121139

122140
def test_pipeline_delete(sagemaker_session_mock):

0 commit comments

Comments
 (0)