
Commit 867e43d

feature: support tags in to_pipeline for feature processor

1 parent: c1b2465

3 files changed (+165 additions, −3 deletions)

src/sagemaker/feature_store/feature_processor/_constants.py

Lines changed: 5 additions & 0 deletions
@@ -40,3 +40,8 @@
 S3_DATA_DISTRIBUTION_TYPE = "FullyReplicated"
 PIPELINE_CONTEXT_NAME_TAG_KEY = "sm-fs-fe:feature-engineering-pipeline-context-name"
 PIPELINE_VERSION_CONTEXT_NAME_TAG_KEY = "sm-fs-fe:feature-engineering-pipeline-version-context-name"
+TO_PIPELINE_RESERVED_TAG_KEYS = [
+    FEATURE_PROCESSOR_TAG_KEY,
+    PIPELINE_CONTEXT_NAME_TAG_KEY,
+    PIPELINE_VERSION_CONTEXT_NAME_TAG_KEY,
+]
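These three keys are the ones to_pipeline reserves for itself; the validation added in feature_scheduler.py below rejects any user-supplied tag that reuses them. A minimal standalone sketch of that check, not part of this commit (the helper name is illustrative; the import path follows the file layout above, and the "sm-fs-fe:created-from" key comes from the tests further down):

```python
from sagemaker.feature_store.feature_processor._constants import (
    TO_PIPELINE_RESERVED_TAG_KEYS,
)


def uses_reserved_tag_key(tags):
    """Return True if any (key, value) pair reuses a reserved to_pipeline tag key."""
    provided_keys = [key for key, _ in tags]
    return any(reserved_key in provided_keys for reserved_key in TO_PIPELINE_RESERVED_TAG_KEYS)


# "sm-fs-fe:created-from" (FEATURE_PROCESSOR_TAG_KEY) is reserved, so the first list
# would be rejected by to_pipeline; the second would be accepted.
print(uses_reserved_tag_key([("sm-fs-fe:created-from", "random")]))  # True
print(uses_reserved_tag_key([("team", "feature-store")]))            # False
```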

src/sagemaker/feature_store/feature_processor/feature_scheduler.py

Lines changed: 32 additions & 2 deletions
@@ -16,7 +16,7 @@
 import json
 import re
 from datetime import datetime
-from typing import Callable, List, Optional, Dict, Sequence, Union, Any
+from typing import Callable, List, Optional, Dict, Sequence, Union, Any, Tuple
 
 import pytz
 from botocore.exceptions import ClientError
@@ -58,6 +58,7 @@
     PIPELINE_NAME_MAXIMUM_LENGTH,
     RESOURCE_NOT_FOUND,
     FEATURE_GROUP_ARN_REGEX_PATTERN,
+    TO_PIPELINE_RESERVED_TAG_KEYS,
 )
 from sagemaker.feature_store.feature_processor._feature_processor_config import (
     FeatureProcessorConfig,
@@ -107,6 +108,7 @@ def to_pipeline(
     role: Optional[str] = None,
     transformation_code: Optional[TransformationCode] = None,
     max_retries: Optional[int] = None,
+    tags: Optional[List[Tuple[str, str]]] = None,
     sagemaker_session: Optional[Session] = None,
 ) -> str:
     """Creates a sagemaker pipeline that takes in a callable as a training step.
@@ -127,6 +129,8 @@ def to_pipeline(
             code for Lineage tracking. This code is not used for actual transformation.
         max_retries (Optional[int]): The number of times to retry sagemaker pipeline step.
             If not specified, sagemaker pipeline step will not retry.
+        tags (List[Tuple[str, str]]): A list of tags attached to the pipeline. If not specified,
+            no custom tags will be attached to the pipeline.
         sagemaker_session (Optional[Session]): Session object which manages interactions
             with Amazon SageMaker APIs and any other AWS services needed. If not specified, the
             function creates one using the default AWS configuration chain.
@@ -135,6 +139,7 @@
     """
 
     _validate_input_for_to_pipeline_api(pipeline_name, step)
+    _validate_tags_for_to_pipeline_api(tags)
 
     _sagemaker_session = sagemaker_session or Session()
 
@@ -200,12 +205,15 @@
         sagemaker_session=_sagemaker_session,
         parameters=[SCHEDULED_TIME_PIPELINE_PARAMETER],
     )
+    pipeline_tags = [dict(Key=FEATURE_PROCESSOR_TAG_KEY, Value=FEATURE_PROCESSOR_TAG_VALUE)]
+    if tags:
+        pipeline_tags.extend([dict(Key=k, Value=v) for k, v in tags])
 
     pipeline = Pipeline(**pipeline_request_dict)
     logger.info("Creating/Updating sagemaker pipeline %s", pipeline_name)
     pipeline.upsert(
         role_arn=_role,
-        tags=[dict(Key=FEATURE_PROCESSOR_TAG_KEY, Value=FEATURE_PROCESSOR_TAG_VALUE)],
+        tags=pipeline_tags,
     )
     logger.info("Created sagemaker pipeline %s", pipeline_name)
 
@@ -514,6 +522,28 @@ def _validate_input_for_to_pipeline_api(pipeline_name: str, step: Callable) -> N
         )
 
 
+def _validate_tags_for_to_pipeline_api(tags: List[Tuple[str, str]]) -> None:
+    """Validate tags provided to to_pipeline API.
+
+    Args:
+        tags (List[Tuple[str, str]]): A list of tags attached to the pipeline.
+
+    Raises (ValueError): raises ValueError when any of the following scenarios happen:
+        1. more than 47 tags are provided to the API.
+        2. reserved tag keys are provided to the API.
+    """
+    if len(tags) > 47:
+        raise ValueError(
+            "to_pipeline can only accept up to 47 tags. Please reduce the number of tags provided."
+        )
+    provided_tag_keys = [tag_key_value_pair[0] for tag_key_value_pair in tags]
+    for reserved_tag_key in TO_PIPELINE_RESERVED_TAG_KEYS:
+        if reserved_tag_key in provided_tag_keys:
+            raise ValueError(
+                f"{reserved_tag_key} is a reserved tag key for to_pipeline API. Please choose another tag."
+            )
+
+
 def _validate_lineage_resources_for_to_pipeline_api(
     feature_processor_config: FeatureProcessorConfig, sagemaker_session: Session
 ) -> None:
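Taken together, the change accepts tags as (key, value) tuples, validates them against the limit and the reserved keys, converts them to the Key/Value dict form that Pipeline.upsert expects, and always prepends the feature-processor tag. A hypothetical call sketch, not part of the commit (the pipeline name, role ARN, tag values, and the decorated transform function are placeholders):

```python
from sagemaker.feature_store.feature_processor.feature_scheduler import to_pipeline

# `transform` stands in for a function already decorated as a feature processor
# elsewhere; it is assumed here, not defined by this commit.
pipeline_arn = to_pipeline(
    pipeline_name="my-feature-pipeline",                     # placeholder name
    step=transform,
    role="arn:aws:iam::111122223333:role/MySageMakerRole",   # placeholder role ARN
    max_retries=1,
    tags=[("team", "feature-store"), ("cost-center", "1234")],  # custom, non-reserved keys
)
# The upserted pipeline carries the sm-fs-fe:created-from tag plus the two custom tags.
```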

tests/unit/sagemaker/feature_store/feature_processor/test_feature_scheduler.py

Lines changed: 128 additions & 1 deletion
@@ -244,6 +244,7 @@ def test_to_pipeline(
         step=wrapped_func,
         role=EXECUTION_ROLE_ARN,
         max_retries=1,
+        tags=[("tag_key_1", "tag_value_1"), ("tag_key_2", "tag_value_2")],
         sagemaker_session=session,
     )
     assert pipeline_arn == PIPELINE_ARN
@@ -346,7 +347,11 @@
         [
             call(
                 role_arn=EXECUTION_ROLE_ARN,
-                tags=[dict(Key=FEATURE_PROCESSOR_TAG_KEY, Value=FEATURE_PROCESSOR_TAG_VALUE)],
+                tags=[
+                    dict(Key=FEATURE_PROCESSOR_TAG_KEY, Value=FEATURE_PROCESSOR_TAG_VALUE),
+                    dict(Key="tag_key_1", Value="tag_value_1"),
+                    dict(Key="tag_key_2", Value="tag_value_2"),
+                ],
             ),
             call(
                 role_arn=EXECUTION_ROLE_ARN,
@@ -527,6 +532,128 @@ def test_to_pipeline_pipeline_name_length_limit_exceeds(
     )
 
 
+@patch("sagemaker.remote_function.job.Session", return_value=mock_session())
+@patch(
+    "sagemaker.remote_function.job._JobSettings._get_default_spark_image",
+    return_value="some_image_uri",
+)
+@patch("sagemaker.remote_function.job.get_execution_role", return_value=EXECUTION_ROLE_ARN)
+def test_to_pipeline_too_many_tags(get_execution_role, mock_spark_image, session):
+    session.sagemaker_config = None
+    session.boto_region_name = TEST_REGION
+    session.expand_role.return_value = EXECUTION_ROLE_ARN
+    spark_config = SparkConfig(submit_files=["file_a", "file_b", "file_c"])
+    job_settings = _JobSettings(
+        spark_config=spark_config,
+        s3_root_uri=S3_URI,
+        role=EXECUTION_ROLE_ARN,
+        include_local_workdir=True,
+        instance_type="ml.m5.large",
+        encrypt_inter_container_traffic=True,
+        sagemaker_session=session,
+    )
+    jobs_container_entrypoint = [
+        "/bin/bash",
+        f"/opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{ENTRYPOINT_SCRIPT_NAME}",
+    ]
+    jobs_container_entrypoint.extend(["--jars", "path_a"])
+    jobs_container_entrypoint.extend(["--py-files", "path_b"])
+    jobs_container_entrypoint.extend(["--files", "path_c"])
+    jobs_container_entrypoint.extend([SPARK_APP_SCRIPT_PATH])
+    container_args = ["--s3_base_uri", f"{S3_URI}/pipeline_name"]
+    container_args.extend(["--region", session.boto_region_name])
+
+    mock_feature_processor_config = Mock(
+        mode=FeatureProcessorMode.PYSPARK, inputs=[tdh.FEATURE_PROCESSOR_INPUTS], output="some_fg"
+    )
+    mock_feature_processor_config.mode.return_value = FeatureProcessorMode.PYSPARK
+
+    wrapped_func = Mock(
+        Callable,
+        feature_processor_config=mock_feature_processor_config,
+        job_settings=job_settings,
+        wrapped_func=job_function,
+    )
+    wrapped_func.feature_processor_config.return_value = mock_feature_processor_config
+    wrapped_func.job_settings.return_value = job_settings
+    wrapped_func.wrapped_func.return_value = job_function
+
+    tags = [("key_" + str(i), "value_" + str(i)) for i in range(50)]
+
+    with pytest.raises(
+        ValueError,
+        match="to_pipeline can only accept up to 47 tags. Please reduce the number of tags provided.",
+    ):
+        to_pipeline(
+            pipeline_name="pipeline_name",
+            step=wrapped_func,
+            role=EXECUTION_ROLE_ARN,
+            max_retries=1,
+            tags=tags,
+            sagemaker_session=session,
+        )
+
+
+@patch("sagemaker.remote_function.job.Session", return_value=mock_session())
+@patch(
+    "sagemaker.remote_function.job._JobSettings._get_default_spark_image",
+    return_value="some_image_uri",
+)
+@patch("sagemaker.remote_function.job.get_execution_role", return_value=EXECUTION_ROLE_ARN)
+def test_to_pipeline_used_reserved_tags(get_execution_role, mock_spark_image, session):
+    session.sagemaker_config = None
+    session.boto_region_name = TEST_REGION
+    session.expand_role.return_value = EXECUTION_ROLE_ARN
+    spark_config = SparkConfig(submit_files=["file_a", "file_b", "file_c"])
+    job_settings = _JobSettings(
+        spark_config=spark_config,
+        s3_root_uri=S3_URI,
+        role=EXECUTION_ROLE_ARN,
+        include_local_workdir=True,
+        instance_type="ml.m5.large",
+        encrypt_inter_container_traffic=True,
+        sagemaker_session=session,
+    )
+    jobs_container_entrypoint = [
+        "/bin/bash",
+        f"/opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{ENTRYPOINT_SCRIPT_NAME}",
+    ]
+    jobs_container_entrypoint.extend(["--jars", "path_a"])
+    jobs_container_entrypoint.extend(["--py-files", "path_b"])
+    jobs_container_entrypoint.extend(["--files", "path_c"])
+    jobs_container_entrypoint.extend([SPARK_APP_SCRIPT_PATH])
+    container_args = ["--s3_base_uri", f"{S3_URI}/pipeline_name"]
+    container_args.extend(["--region", session.boto_region_name])
+
+    mock_feature_processor_config = Mock(
+        mode=FeatureProcessorMode.PYSPARK, inputs=[tdh.FEATURE_PROCESSOR_INPUTS], output="some_fg"
+    )
+    mock_feature_processor_config.mode.return_value = FeatureProcessorMode.PYSPARK
+
+    wrapped_func = Mock(
+        Callable,
+        feature_processor_config=mock_feature_processor_config,
+        job_settings=job_settings,
+        wrapped_func=job_function,
+    )
+    wrapped_func.feature_processor_config.return_value = mock_feature_processor_config
+    wrapped_func.job_settings.return_value = job_settings
+    wrapped_func.wrapped_func.return_value = job_function
+
+    with pytest.raises(
+        ValueError,
+        match="sm-fs-fe:created-from is a reserved tag key for to_pipeline API. Please choose another tag.",
+    ):
+        to_pipeline(
+            pipeline_name="pipeline_name",
+            step=wrapped_func,
+            role=EXECUTION_ROLE_ARN,
+            max_retries=1,
+            tags=[("sm-fs-fe:created-from", "random")],
+            sagemaker_session=session,
+        )
+
+
 @patch(
     "sagemaker.feature_store.feature_processor.feature_scheduler._validate_pipeline_lineage_resources",
     return_value=None,
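The expected tag list in the updated test_to_pipeline assertion is just the tuple-to-dict conversion performed inside to_pipeline before Pipeline.upsert. A standalone sketch of that conversion, not part of the SDK or this commit (the helper name is illustrative, and the tag value constant below is a placeholder; the real one lives in _constants.py):

```python
from typing import Dict, List, Optional, Tuple

FEATURE_PROCESSOR_TAG_KEY = "sm-fs-fe:created-from"  # key taken from the tests above
FEATURE_PROCESSOR_TAG_VALUE = "some-value"           # placeholder value for illustration


def build_pipeline_tags(tags: Optional[List[Tuple[str, str]]] = None) -> List[Dict[str, str]]:
    """Mirror to_pipeline's tag handling: always attach the feature-processor tag,
    then append any custom (key, value) pairs as Key/Value dicts for Pipeline.upsert."""
    pipeline_tags = [dict(Key=FEATURE_PROCESSOR_TAG_KEY, Value=FEATURE_PROCESSOR_TAG_VALUE)]
    if tags:
        pipeline_tags.extend([dict(Key=k, Value=v) for k, v in tags])
    return pipeline_tags


assert build_pipeline_tags([("tag_key_1", "tag_value_1"), ("tag_key_2", "tag_value_2")]) == [
    {"Key": FEATURE_PROCESSOR_TAG_KEY, "Value": FEATURE_PROCESSOR_TAG_VALUE},
    {"Key": "tag_key_1", "Value": "tag_value_1"},
    {"Key": "tag_key_2", "Value": "tag_value_2"},
]
```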
