
Commit a68faf1

feature: Add support for tags in to_pipeline API for feature processor (#3963)
1 parent 0bc4f41 commit a68faf1

File tree

4 files changed: +107 -4 lines changed

src/sagemaker/feature_store/feature_processor/_constants.py

Lines changed: 5 additions & 0 deletions
@@ -40,3 +40,8 @@
 S3_DATA_DISTRIBUTION_TYPE = "FullyReplicated"
 PIPELINE_CONTEXT_NAME_TAG_KEY = "sm-fs-fe:feature-engineering-pipeline-context-name"
 PIPELINE_VERSION_CONTEXT_NAME_TAG_KEY = "sm-fs-fe:feature-engineering-pipeline-version-context-name"
+TO_PIPELINE_RESERVED_TAG_KEYS = [
+    FEATURE_PROCESSOR_TAG_KEY,
+    PIPELINE_CONTEXT_NAME_TAG_KEY,
+    PIPELINE_VERSION_CONTEXT_NAME_TAG_KEY,
+]
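For context, a minimal standalone sketch of the reserved-key guard these constants feed into is shown below. The literal value of FEATURE_PROCESSOR_TAG_KEY ("sm-fs-fe:created-from") is an assumption inferred from the error message asserted in the unit test further down, not from this file; the other two values appear in the hunk above.

```python
# Sketch only: reserved keys written out as literals for illustration.
# "sm-fs-fe:created-from" is *assumed* to be FEATURE_PROCESSOR_TAG_KEY, based on
# the unit test's expected error message; the other two values come from the diff above.
TO_PIPELINE_RESERVED_TAG_KEYS = [
    "sm-fs-fe:created-from",
    "sm-fs-fe:feature-engineering-pipeline-context-name",
    "sm-fs-fe:feature-engineering-pipeline-version-context-name",
]

user_tags = [("team", "feature-store"), ("sm-fs-fe:created-from", "cli")]
for key, _value in user_tags:
    if key in TO_PIPELINE_RESERVED_TAG_KEYS:
        # Mirrors the rejection performed by the validation helper added below.
        raise ValueError(f"{key} is a reserved tag key for to_pipeline API. Please choose another tag.")
```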

src/sagemaker/feature_store/feature_processor/feature_scheduler.py

Lines changed: 28 additions & 2 deletions
@@ -16,7 +16,7 @@
 import json
 import re
 from datetime import datetime
-from typing import Callable, List, Optional, Dict, Sequence, Union, Any
+from typing import Callable, List, Optional, Dict, Sequence, Union, Any, Tuple

 import pytz
 from botocore.exceptions import ClientError
@@ -58,6 +58,7 @@
     PIPELINE_NAME_MAXIMUM_LENGTH,
     RESOURCE_NOT_FOUND,
     FEATURE_GROUP_ARN_REGEX_PATTERN,
+    TO_PIPELINE_RESERVED_TAG_KEYS,
 )
 from sagemaker.feature_store.feature_processor._feature_processor_config import (
     FeatureProcessorConfig,
@@ -107,6 +108,7 @@ def to_pipeline(
     role: Optional[str] = None,
     transformation_code: Optional[TransformationCode] = None,
     max_retries: Optional[int] = None,
+    tags: Optional[List[Tuple[str, str]]] = None,
     sagemaker_session: Optional[Session] = None,
 ) -> str:
     """Creates a sagemaker pipeline that takes in a callable as a training step.
@@ -127,6 +129,8 @@ def to_pipeline(
             code for Lineage tracking. This code is not used for actual transformation.
         max_retries (Optional[int]): The number of times to retry sagemaker pipeline step.
             If not specified, sagemaker pipeline step will not retry.
+        tags (List[Tuple[str, str]]): A list of tags attached to the pipeline. If not specified,
+            no custom tags will be attached to the pipeline.
         sagemaker_session (Optional[Session]): Session object which manages interactions
             with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
             function creates one using the default AWS configuration chain.
@@ -135,6 +139,8 @@ def to_pipeline(
     """

     _validate_input_for_to_pipeline_api(pipeline_name, step)
+    if tags:
+        _validate_tags_for_to_pipeline_api(tags)

     _sagemaker_session = sagemaker_session or Session()

@@ -200,12 +206,15 @@ def to_pipeline(
         sagemaker_session=_sagemaker_session,
         parameters=[SCHEDULED_TIME_PIPELINE_PARAMETER],
     )
+    pipeline_tags = [dict(Key=FEATURE_PROCESSOR_TAG_KEY, Value=FEATURE_PROCESSOR_TAG_VALUE)]
+    if tags:
+        pipeline_tags.extend([dict(Key=k, Value=v) for k, v in tags])

     pipeline = Pipeline(**pipeline_request_dict)
     logger.info("Creating/Updating sagemaker pipeline %s", pipeline_name)
     pipeline.upsert(
         role_arn=_role,
-        tags=[dict(Key=FEATURE_PROCESSOR_TAG_KEY, Value=FEATURE_PROCESSOR_TAG_VALUE)],
+        tags=pipeline_tags,
     )
     logger.info("Created sagemaker pipeline %s", pipeline_name)

@@ -514,6 +523,23 @@ def _validate_input_for_to_pipeline_api(pipeline_name: str, step: Callable) -> None:
         )


+def _validate_tags_for_to_pipeline_api(tags: List[Tuple[str, str]]) -> None:
+    """Validate tags provided to the to_pipeline API.
+
+    Args:
+        tags (List[Tuple[str, str]]): A list of tags attached to the pipeline.
+
+    Raises (ValueError): raises ValueError when any of the following scenarios happens:
+        1. reserved tag keys are provided to the API.
+    """
+    provided_tag_keys = [tag_key_value_pair[0] for tag_key_value_pair in tags]
+    for reserved_tag_key in TO_PIPELINE_RESERVED_TAG_KEYS:
+        if reserved_tag_key in provided_tag_keys:
+            raise ValueError(
+                f"{reserved_tag_key} is a reserved tag key for to_pipeline API. Please choose another tag."
+            )
+
+
 def _validate_lineage_resources_for_to_pipeline_api(
     feature_processor_config: FeatureProcessorConfig, sagemaker_session: Session
 ) -> None:
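Taken together, these hunks let callers pass custom tags as (key, value) tuples: to_pipeline validates them against the reserved keys, then appends them to the framework tag when the pipeline is upserted. A hedged usage sketch follows; the pipeline name, role ARN, and the transform callable are placeholders and not part of this commit, so the snippet is illustrative rather than directly runnable without real AWS resources.

```python
from sagemaker.feature_store.feature_processor.feature_scheduler import to_pipeline

# "transform" stands in for a @feature_processor-decorated callable; the pipeline
# name and role ARN are placeholders. Only the tags argument is new in this commit.
pipeline_arn = to_pipeline(
    pipeline_name="my-feature-pipeline",
    step=transform,
    role="arn:aws:iam::123456789012:role/MySageMakerExecutionRole",
    max_retries=2,
    tags=[("team", "feature-store"), ("env", "dev")],
)
# The pipeline is upserted with the framework tag (FEATURE_PROCESSOR_TAG_KEY) plus
# the two custom tags above; passing a reserved sm-fs-fe:* key raises ValueError
# before any SageMaker call is made.
```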

tests/integ/sagemaker/feature_store/feature_processor/test_feature_processor.py

Lines changed: 8 additions & 1 deletion
@@ -550,11 +550,18 @@ def transform(raw_s3_data_as_df):
         step=transform,
         role=get_execution_role(sagemaker_session),
         max_retries=2,
+        tags=[("integ_test_tag_key_1", "integ_test_tag_key_2")],
         sagemaker_session=sagemaker_session,
     )
+    _sagemaker_client = get_sagemaker_client(sagemaker_session=sagemaker_session)

     assert pipeline_arn is not None

+    tags = _sagemaker_client.list_tags(ResourceArn=pipeline_arn)["Tags"]
+
+    tag_keys = [tag["Key"] for tag in tags]
+    assert "integ_test_tag_key_1" in tag_keys
+
     pipeline_description = Pipeline(name=pipeline_name).describe()
     assert pipeline_arn == pipeline_description["PipelineArn"]
     assert get_execution_role(sagemaker_session) == pipeline_description["RoleArn"]
@@ -570,7 +577,7 @@ def transform(raw_s3_data_as_df):

     status = _wait_for_pipeline_execution_to_reach_terminal_state(
         pipeline_execution_arn=pipeline_execution_arn,
-        sagemaker_client=get_sagemaker_client(sagemaker_session=sagemaker_session),
+        sagemaker_client=_sagemaker_client,
     )
     assert status == "Succeeded"
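The new assertions rely on the SageMaker ListTags API. Outside the test harness, the same check can be made directly with boto3; the pipeline ARN below is a placeholder, where the test instead uses the value returned by to_pipeline.

```python
import boto3

# Placeholder ARN; in the integration test this comes from to_pipeline's return value.
pipeline_arn = "arn:aws:sagemaker:us-west-2:123456789012:pipeline/my-feature-pipeline"

sagemaker_client = boto3.client("sagemaker")
tags = sagemaker_client.list_tags(ResourceArn=pipeline_arn)["Tags"]

# Tags come back as [{"Key": ..., "Value": ...}, ...]; the custom keys passed to
# to_pipeline should appear alongside the framework-managed tag added on upsert.
tag_keys = [tag["Key"] for tag in tags]
print(tag_keys)
```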

tests/unit/sagemaker/feature_store/feature_processor/test_feature_scheduler.py

Lines changed: 66 additions & 1 deletion
@@ -244,6 +244,7 @@ def test_to_pipeline(
         step=wrapped_func,
         role=EXECUTION_ROLE_ARN,
         max_retries=1,
+        tags=[("tag_key_1", "tag_value_1"), ("tag_key_2", "tag_value_2")],
         sagemaker_session=session,
     )
     assert pipeline_arn == PIPELINE_ARN
@@ -346,7 +347,11 @@ def test_to_pipeline(
         [
             call(
                 role_arn=EXECUTION_ROLE_ARN,
-                tags=[dict(Key=FEATURE_PROCESSOR_TAG_KEY, Value=FEATURE_PROCESSOR_TAG_VALUE)],
+                tags=[
+                    dict(Key=FEATURE_PROCESSOR_TAG_KEY, Value=FEATURE_PROCESSOR_TAG_VALUE),
+                    dict(Key="tag_key_1", Value="tag_value_1"),
+                    dict(Key="tag_key_2", Value="tag_value_2"),
+                ],
             ),
             call(
                 role_arn=EXECUTION_ROLE_ARN,
@@ -527,6 +532,66 @@ def test_to_pipeline_pipeline_name_length_limit_exceeds(
     )


+@patch("sagemaker.remote_function.job.Session", return_value=mock_session())
+@patch(
+    "sagemaker.remote_function.job._JobSettings._get_default_spark_image",
+    return_value="some_image_uri",
+)
+@patch("sagemaker.remote_function.job.get_execution_role", return_value=EXECUTION_ROLE_ARN)
+def test_to_pipeline_used_reserved_tags(get_execution_role, mock_spark_image, session):
+    session.sagemaker_config = None
+    session.boto_region_name = TEST_REGION
+    session.expand_role.return_value = EXECUTION_ROLE_ARN
+    spark_config = SparkConfig(submit_files=["file_a", "file_b", "file_c"])
+    job_settings = _JobSettings(
+        spark_config=spark_config,
+        s3_root_uri=S3_URI,
+        role=EXECUTION_ROLE_ARN,
+        include_local_workdir=True,
+        instance_type="ml.m5.large",
+        encrypt_inter_container_traffic=True,
+        sagemaker_session=session,
+    )
+    jobs_container_entrypoint = [
+        "/bin/bash",
+        f"/opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{ENTRYPOINT_SCRIPT_NAME}",
+    ]
+    jobs_container_entrypoint.extend(["--jars", "path_a"])
+    jobs_container_entrypoint.extend(["--py-files", "path_b"])
+    jobs_container_entrypoint.extend(["--files", "path_c"])
+    jobs_container_entrypoint.extend([SPARK_APP_SCRIPT_PATH])
+    container_args = ["--s3_base_uri", f"{S3_URI}/pipeline_name"]
+    container_args.extend(["--region", session.boto_region_name])
+
+    mock_feature_processor_config = Mock(
+        mode=FeatureProcessorMode.PYSPARK, inputs=[tdh.FEATURE_PROCESSOR_INPUTS], output="some_fg"
+    )
+    mock_feature_processor_config.mode.return_value = FeatureProcessorMode.PYSPARK
+
+    wrapped_func = Mock(
+        Callable,
+        feature_processor_config=mock_feature_processor_config,
+        job_settings=job_settings,
+        wrapped_func=job_function,
+    )
+    wrapped_func.feature_processor_config.return_value = mock_feature_processor_config
+    wrapped_func.job_settings.return_value = job_settings
+    wrapped_func.wrapped_func.return_value = job_function
+
+    with pytest.raises(
+        ValueError,
+        match="sm-fs-fe:created-from is a reserved tag key for to_pipeline API. Please choose another tag.",
+    ):
+        to_pipeline(
+            pipeline_name="pipeline_name",
+            step=wrapped_func,
+            role=EXECUTION_ROLE_ARN,
+            max_retries=1,
+            tags=[("sm-fs-fe:created-from", "random")],
+            sagemaker_session=session,
+        )
+
+
 @patch(
     "sagemaker.feature_store.feature_processor.feature_scheduler._validate_pipeline_lineage_resources",
     return_value=None,
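The mocked test above exercises the full to_pipeline path. The reserved-key check could also be unit-tested in isolation by calling the private helper added in feature_scheduler.py; a small sketch follows, assuming (as the error message above implies) that "sm-fs-fe:created-from" is one of the reserved keys.

```python
import pytest

from sagemaker.feature_store.feature_processor.feature_scheduler import (
    _validate_tags_for_to_pipeline_api,
)


def test_reserved_tag_key_is_rejected():
    # A reserved framework key raises ValueError regardless of the value supplied.
    with pytest.raises(ValueError, match="reserved tag key"):
        _validate_tags_for_to_pipeline_api([("sm-fs-fe:created-from", "random")])


def test_non_reserved_tag_keys_pass():
    # Ordinary custom tags are accepted silently (the helper returns None).
    _validate_tags_for_to_pipeline_api([("team", "feature-store"), ("env", "dev")])
```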
