Commit 832b211 (1 parent: 867e43d)

[Fix] remove validation on number of tags and added integ tests
3 files changed (+11, -70 lines)

3 files changed

+11
-70
lines changed

src/sagemaker/feature_store/feature_processor/feature_scheduler.py

Lines changed: 3 additions & 7 deletions
```diff
@@ -139,7 +139,8 @@ def to_pipeline(
     """
 
     _validate_input_for_to_pipeline_api(pipeline_name, step)
-    _validate_tags_for_to_pipeline_api(tags)
+    if tags:
+        _validate_tags_for_to_pipeline_api(tags)
 
     _sagemaker_session = sagemaker_session or Session()
 
@@ -529,13 +530,8 @@ def _validate_tags_for_to_pipeline_api(tags: List[Tuple[str, str]]) -> None:
         tags (List[Tuple[str, str]]): A list of tags attached to the pipeline.
 
     Raises (ValueError): raises ValueError when any of the following scenario happen:
-        1. more than 47 tags are provided to API.
-        2. reserved tag keys are provided to API.
+        1. reserved tag keys are provided to API.
     """
-    if len(tags) > 48:
-        raise ValueError(
-            "to_pipeline can only accept up to 47 tags. Please reduce the number of tags provided."
-        )
     provided_tag_keys = [tag_key_value_pair[0] for tag_key_value_pair in tags]
     for reserved_tag_key in TO_PIPELINE_RESERVED_TAG_KEYS:
         if reserved_tag_key in provided_tag_keys:
```
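Net effect of this change: the tag-count cap is gone from `_validate_tags_for_to_pipeline_api`, and `to_pipeline` now skips tag validation altogether when `tags` is `None` or empty, so the validator can assume a non-empty list. Below is a minimal standalone sketch of the resulting behavior; the reserved-key value and the error message are stand-ins, not taken from this commit.

```python
from typing import List, Optional, Tuple

# Stand-in for TO_PIPELINE_RESERVED_TAG_KEYS; the real values are not shown in this diff.
TO_PIPELINE_RESERVED_TAG_KEYS = ["reserved-tag-key"]


def validate_tags(tags: List[Tuple[str, str]]) -> None:
    """Only the reserved-key check remains; the tag-count cap was removed."""
    provided_tag_keys = [tag_key_value_pair[0] for tag_key_value_pair in tags]
    for reserved_tag_key in TO_PIPELINE_RESERVED_TAG_KEYS:
        if reserved_tag_key in provided_tag_keys:
            # Illustrative message, not the SDK's exact wording.
            raise ValueError(f"'{reserved_tag_key}' is a reserved tag key.")


def to_pipeline_like(tags: Optional[List[Tuple[str, str]]] = None) -> None:
    # Mirrors the new guard in to_pipeline: validate only when tags were supplied.
    if tags:
        validate_tags(tags)


to_pipeline_like()                                                 # no tags, no validation
to_pipeline_like([(f"key_{i}", f"value_{i}") for i in range(50)])  # 50 tags now accepted
```

Note that guarding with `if tags:` rather than `if tags is not None:` also skips an empty list, which has nothing to validate anyway.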

tests/integ/sagemaker/feature_store/feature_processor/test_feature_processor.py

Lines changed: 8 additions & 1 deletion
```diff
@@ -550,11 +550,18 @@ def transform(raw_s3_data_as_df):
             step=transform,
             role=get_execution_role(sagemaker_session),
             max_retries=2,
+            tags=[("integ_test_tag_key_1", "integ_test_tag_key_2")],
             sagemaker_session=sagemaker_session,
         )
+        _sagemaker_client = get_sagemaker_client(sagemaker_session=sagemaker_session)
 
         assert pipeline_arn is not None
 
+        tags = _sagemaker_client.list_tags(ResourceArn=pipeline_arn)["Tags"]
+
+        tag_keys = [tag["Key"] for tag in tags]
+        assert "integ_test_tag_key_1" in tag_keys
+
         pipeline_description = Pipeline(name=pipeline_name).describe()
         assert pipeline_arn == pipeline_description["PipelineArn"]
         assert get_execution_role(sagemaker_session) == pipeline_description["RoleArn"]
@@ -570,7 +577,7 @@ def transform(raw_s3_data_as_df):
 
         status = _wait_for_pipeline_execution_to_reach_terminal_state(
             pipeline_execution_arn=pipeline_execution_arn,
-            sagemaker_client=get_sagemaker_client(sagemaker_session=sagemaker_session),
+            sagemaker_client=_sagemaker_client,
         )
         assert status == "Succeeded"
```
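These added assertions are the integration-level replacement for the unit test deleted below: they pass a tag through `to_pipeline`, then read the pipeline's tags back through the SageMaker client. A standalone sketch of that verification step, assuming a boto3 SageMaker client and a placeholder pipeline ARN (in the test, the ARN comes back from `to_pipeline` and the client from `get_sagemaker_client`):

```python
import boto3

# Placeholder ARN; in the integ test this is the value returned by to_pipeline.
pipeline_arn = "arn:aws:sagemaker:us-west-2:111122223333:pipeline/my-pipeline"

sagemaker_client = boto3.client("sagemaker")

# SageMaker's list_tags returns {"Tags": [{"Key": "...", "Value": "..."}, ...]}.
tags = sagemaker_client.list_tags(ResourceArn=pipeline_arn)["Tags"]
tag_keys = [tag["Key"] for tag in tags]
assert "integ_test_tag_key_1" in tag_keys
```

Storing the client in `_sagemaker_client` up front also lets the test reuse it for `_wait_for_pipeline_execution_to_reach_terminal_state`, which is what the second hunk switches to.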

tests/unit/sagemaker/feature_store/feature_processor/test_feature_scheduler.py

Lines changed: 0 additions & 62 deletions
```diff
@@ -532,68 +532,6 @@ def test_to_pipeline_pipeline_name_length_limit_exceeds(
     )
 
 
-@patch("sagemaker.remote_function.job.Session", return_value=mock_session())
-@patch(
-    "sagemaker.remote_function.job._JobSettings._get_default_spark_image",
-    return_value="some_image_uri",
-)
-@patch("sagemaker.remote_function.job.get_execution_role", return_value=EXECUTION_ROLE_ARN)
-def test_to_pipeline_too_many_tags(get_execution_role, mock_spark_image, session):
-    session.sagemaker_config = None
-    session.boto_region_name = TEST_REGION
-    session.expand_role.return_value = EXECUTION_ROLE_ARN
-    spark_config = SparkConfig(submit_files=["file_a", "file_b", "file_c"])
-    job_settings = _JobSettings(
-        spark_config=spark_config,
-        s3_root_uri=S3_URI,
-        role=EXECUTION_ROLE_ARN,
-        include_local_workdir=True,
-        instance_type="ml.m5.large",
-        encrypt_inter_container_traffic=True,
-        sagemaker_session=session,
-    )
-    jobs_container_entrypoint = [
-        "/bin/bash",
-        f"/opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{ENTRYPOINT_SCRIPT_NAME}",
-    ]
-    jobs_container_entrypoint.extend(["--jars", "path_a"])
-    jobs_container_entrypoint.extend(["--py-files", "path_b"])
-    jobs_container_entrypoint.extend(["--files", "path_c"])
-    jobs_container_entrypoint.extend([SPARK_APP_SCRIPT_PATH])
-    container_args = ["--s3_base_uri", f"{S3_URI}/pipeline_name"]
-    container_args.extend(["--region", session.boto_region_name])
-
-    mock_feature_processor_config = Mock(
-        mode=FeatureProcessorMode.PYSPARK, inputs=[tdh.FEATURE_PROCESSOR_INPUTS], output="some_fg"
-    )
-    mock_feature_processor_config.mode.return_value = FeatureProcessorMode.PYSPARK
-
-    wrapped_func = Mock(
-        Callable,
-        feature_processor_config=mock_feature_processor_config,
-        job_settings=job_settings,
-        wrapped_func=job_function,
-    )
-    wrapped_func.feature_processor_config.return_value = mock_feature_processor_config
-    wrapped_func.job_settings.return_value = job_settings
-    wrapped_func.wrapped_func.return_value = job_function
-
-    tags = [("key_" + str(i), "value_" + str(i)) for i in range(50)]
-
-    with pytest.raises(
-        ValueError,
-        match="to_pipeline can only accept up to 47 tags. Please reduce the number of tags provided.",
-    ):
-        to_pipeline(
-            pipeline_name="pipeline_name",
-            step=wrapped_func,
-            role=EXECUTION_ROLE_ARN,
-            max_retries=1,
-            tags=tags,
-            sagemaker_session=session,
-        )
-
-
 @patch("sagemaker.remote_function.job.Session", return_value=mock_session())
 @patch(
     "sagemaker.remote_function.job._JobSettings._get_default_spark_image",
```
