Skip to content

Commit ef3d2b2

Browse files
can-sunjiapinw
andauthored
feat: Feature Processor event based triggers (#1132) (#4202)
Co-authored-by: jiapinw <[email protected]>
1 parent e9617f5 commit ef3d2b2

18 files changed

+1496
-14
lines changed

src/sagemaker/feature_store/feature_processor/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,16 @@
3030
to_pipeline,
3131
schedule,
3232
describe,
33+
put_trigger,
34+
delete_trigger,
35+
enable_trigger,
36+
disable_trigger,
3337
delete_schedule,
3438
list_pipelines,
3539
execute,
3640
TransformationCode,
41+
FeatureProcessorPipelineEvents,
42+
)
43+
from sagemaker.feature_store.feature_processor._enums import ( # noqa: F401
44+
FeatureProcessorPipelineExecutionStatus,
3745
)

src/sagemaker/feature_store/feature_processor/_constants.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
DEFAULT_INSTANCE_TYPE = "ml.m5.xlarge"
1919
DEFAULT_SCHEDULE_STATE = "ENABLED"
20+
DEFAULT_TRIGGER_STATE = "ENABLED"
2021
UNDERSCORE = "_"
2122
RESOURCE_NOT_FOUND_EXCEPTION = "ResourceNotFoundException"
2223
RESOURCE_NOT_FOUND = "ResourceNotFound"
@@ -36,6 +37,8 @@
3637
FEATURE_PROCESSOR_TAG_KEY = "sm-fs-fe:created-from"
3738
FEATURE_PROCESSOR_TAG_VALUE = "fp-to-pipeline"
3839
FEATURE_GROUP_ARN_REGEX_PATTERN = r"arn:(.*?):sagemaker:(.*?):(.*?):feature-group/(.*?)$"
40+
PIPELINE_ARN_REGEX_PATTERN = r"arn:(.*?):sagemaker:(.*?):(.*?):pipeline/(.*?)$"
41+
EVENTBRIDGE_RULE_ARN_REGEX_PATTERN = r"arn:(.*?):events:(.*?):(.*?):rule/(.*?)$"
3942
SAGEMAKER_WHL_FILE_S3_PATH = "s3://ada-private-beta/sagemaker-2.151.1.dev0-py2.py3-none-any.whl"
4043
S3_DATA_DISTRIBUTION_TYPE = "FullyReplicated"
4144
PIPELINE_CONTEXT_NAME_TAG_KEY = "sm-fs-fe:feature-engineering-pipeline-context-name"
@@ -45,3 +48,7 @@
4548
PIPELINE_CONTEXT_NAME_TAG_KEY,
4649
PIPELINE_VERSION_CONTEXT_NAME_TAG_KEY,
4750
]
51+
BASE_EVENT_PATTERN = {
52+
"source": ["aws.sagemaker"],
53+
"detail": {"currentPipelineExecutionStatus": [], "pipelineArn": []},
54+
}

src/sagemaker/feature_store/feature_processor/_enums.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,13 @@ class FeatureProcessorMode(Enum):
2121

2222
PYSPARK = "pyspark" # Execute a pyspark job.
2323
PYTHON = "python" # Execute a regular python script.
24+
25+
26+
class FeatureProcessorPipelineExecutionStatus(Enum):
27+
"""Enum of feature_processor pipeline execution status."""
28+
29+
EXECUTING = "Executing"
30+
STOPPING = "Stopping"
31+
STOPPED = "Stopped"
32+
FAILED = "Failed"
33+
SUCCEEDED = "Succeeded"
Lines changed: 304 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,304 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Contains classes for EventBridge Schedule management for a feature processor."""
14+
from __future__ import absolute_import
15+
16+
import json
17+
import logging
18+
import re
19+
from typing import Dict, List, Tuple, Optional, Any
20+
import attr
21+
from botocore.exceptions import ClientError
22+
from botocore.paginate import PageIterator
23+
from sagemaker import Session
24+
from sagemaker.feature_store.feature_processor._feature_processor_pipeline_events import (
25+
FeatureProcessorPipelineEvents,
26+
)
27+
from sagemaker.feature_store.feature_processor._constants import (
28+
RESOURCE_NOT_FOUND_EXCEPTION,
29+
PIPELINE_ARN_REGEX_PATTERN,
30+
BASE_EVENT_PATTERN,
31+
)
32+
from sagemaker.feature_store.feature_processor._enums import (
33+
FeatureProcessorPipelineExecutionStatus,
34+
)
35+
36+
logger = logging.getLogger("sagemaker")
37+
38+
39+
@attr.s
40+
class EventBridgeRuleHelper:
41+
"""Contains helper methods for managing EventBridge rules for a feature processor."""
42+
43+
sagemaker_session: Session = attr.ib()
44+
event_bridge_rule_client = attr.ib()
45+
46+
def put_rule(
47+
self,
48+
source_pipeline_events: List[FeatureProcessorPipelineEvents],
49+
target_pipeline: str,
50+
event_pattern: str,
51+
state: str,
52+
) -> str:
53+
"""Creates an EventBridge Rule for a given target pipeline.
54+
55+
Args:
56+
source_pipeline_events: The list of pipeline events that trigger the EventBridge Rule.
57+
target_pipeline: The name of the pipeline that is triggered by the EventBridge Rule.
58+
event_pattern: The EventBridge EventPattern that triggers the EventBridge Rule.
59+
If specified, will override source_pipeline_events.
60+
state: Indicates whether the rule is enabled or disabled.
61+
62+
Returns:
63+
The Amazon Resource Name (ARN) of the rule.
64+
"""
65+
self._validate_feature_processor_pipeline_events(source_pipeline_events)
66+
rule_name = target_pipeline
67+
_event_patterns = (
68+
event_pattern
69+
or self._generate_event_pattern_from_feature_processor_pipeline_events(
70+
source_pipeline_events
71+
)
72+
)
73+
rule_arn = self.event_bridge_rule_client.put_rule(
74+
Name=rule_name, EventPattern=_event_patterns, State=state
75+
)["RuleArn"]
76+
return rule_arn
77+
78+
def put_target(
79+
self,
80+
rule_name: str,
81+
target_pipeline: str,
82+
target_pipeline_parameters: Dict[str, str],
83+
role_arn: str,
84+
) -> None:
85+
"""Attach target pipeline to an event based trigger.
86+
87+
Args:
88+
rule_name: The name of the EventBridge Rule.
89+
target_pipeline: The name of the pipeline that is triggered by the EventBridge Rule.
90+
target_pipeline_parameters: The list of parameters to start execution of a pipeline.
91+
role_arn: The Amazon Resource Name (ARN) of the IAM role associated with the rule.
92+
"""
93+
target_pipeline_arn_and_name = self._generate_pipeline_arn_and_name(target_pipeline)
94+
target_pipeline_name = target_pipeline_arn_and_name["pipeline_name"]
95+
target_pipeline_arn = target_pipeline_arn_and_name["pipeline_arn"]
96+
target_request_dict = {
97+
"Id": target_pipeline_name,
98+
"Arn": target_pipeline_arn,
99+
"RoleArn": role_arn,
100+
}
101+
if target_pipeline_parameters:
102+
target_request_dict["SageMakerPipelineParameters"] = {
103+
"PipelineParameterList": target_pipeline_parameters
104+
}
105+
put_targets_response = self.event_bridge_rule_client.put_targets(
106+
Rule=rule_name,
107+
Targets=[target_request_dict],
108+
)
109+
if put_targets_response["FailedEntryCount"] != 0:
110+
error_msg = put_targets_response["FailedEntries"][0]["ErrorMessage"]
111+
raise Exception(f"Failed to add target pipeline to rule. Failure reason: {error_msg}")
112+
113+
def delete_rule(self, rule_name: str) -> None:
114+
"""Deletes an EventBridge Rule of a given pipeline if there is one.
115+
116+
Args:
117+
rule_name: The name of the EventBridge Rule.
118+
"""
119+
self.event_bridge_rule_client.delete_rule(Name=rule_name)
120+
121+
def remove_targets(self, rule_name: str, ids: List[str]) -> None:
122+
"""Deletes an EventBridge Targets of a given rule if there is one.
123+
124+
Args:
125+
rule_name: The name of the EventBridge Rule.
126+
ids: The ids of the EventBridge Target.
127+
"""
128+
self.event_bridge_rule_client.remove_targets(Rule=rule_name, Ids=ids)
129+
130+
def list_targets_by_rule(self, rule_name: str) -> PageIterator:
131+
"""List EventBridge Targets of a given rule.
132+
133+
Args:
134+
rule_name: The name of the EventBridge Rule.
135+
136+
Returns:
137+
The page iterator of list_targets_by_rule call.
138+
"""
139+
return self.event_bridge_rule_client.get_paginator("list_targets_by_rule").paginate(
140+
Rule=rule_name
141+
)
142+
143+
def describe_rule(self, rule_name: str) -> Optional[Dict[str, Any]]:
144+
"""Describe the EventBridge Rule ARN corresponding to a sagemaker pipeline
145+
146+
Args:
147+
rule_name: The name of the EventBridge Rule.
148+
Returns:
149+
Optional[Dict[str, str]] : Describe EventBridge Rule response if exists.
150+
"""
151+
try:
152+
event_bridge_rule_response = self.event_bridge_rule_client.describe_rule(Name=rule_name)
153+
return event_bridge_rule_response
154+
except ClientError as e:
155+
if RESOURCE_NOT_FOUND_EXCEPTION == e.response["Error"]["Code"]:
156+
logger.info("No EventBridge Rule found for pipeline %s.", rule_name)
157+
return None
158+
raise e
159+
160+
def enable_rule(self, rule_name: str) -> None:
161+
"""Enables an EventBridge Rule of a given pipeline if there is one.
162+
163+
Args:
164+
rule_name: The name of the EventBridge Rule.
165+
"""
166+
self.event_bridge_rule_client.enable_rule(Name=rule_name)
167+
logger.info("Enabled EventBridge Rule for pipeline %s.", rule_name)
168+
169+
def disable_rule(self, rule_name: str) -> None:
170+
"""Disables an EventBridge Rule of a given pipeline if there is one.
171+
172+
Args:
173+
rule_name: The name of the EventBridge Rule.
174+
"""
175+
self.event_bridge_rule_client.disable_rule(Name=rule_name)
176+
logger.info("Disabled EventBridge Rule for pipeline %s.", rule_name)
177+
178+
def add_tags(self, rule_arn: str, tags: List[Dict[str, str]]) -> None:
179+
"""Adds tags to the EventBridge Rule.
180+
181+
Args:
182+
rule_arn: The ARN of the EventBridge Rule.
183+
tags: List of tags to be added.
184+
"""
185+
self.event_bridge_rule_client.tag_resource(ResourceARN=rule_arn, Tags=tags)
186+
187+
def _generate_event_pattern_from_feature_processor_pipeline_events(
188+
self, pipeline_events: List[FeatureProcessorPipelineEvents]
189+
) -> str:
190+
"""Generates the event pattern json string from the pipeline events.
191+
192+
Args:
193+
pipeline_events: List of pipeline events.
194+
Returns:
195+
str: The event pattern json string.
196+
197+
Raises:
198+
ValueError: If pipeline events contain duplicate pipeline names.
199+
"""
200+
201+
result_event_pattern = {
202+
"detail-type": ["SageMaker Model Building Pipeline Execution Status Change"],
203+
}
204+
filters = []
205+
desired_status_to_pipeline_names_map = (
206+
self._aggregate_pipeline_events_with_same_desired_status(pipeline_events)
207+
)
208+
for desired_status in desired_status_to_pipeline_names_map:
209+
pipeline_arns = [
210+
self._generate_pipeline_arn_and_name(pipeline_name)["pipeline_arn"]
211+
for pipeline_name in desired_status_to_pipeline_names_map[desired_status]
212+
]
213+
curr_filter = BASE_EVENT_PATTERN.copy()
214+
curr_filter["detail"]["pipelineArn"] = pipeline_arns
215+
curr_filter["detail"]["currentPipelineExecutionStatus"] = [
216+
status_enum.value for status_enum in desired_status
217+
]
218+
filters.append(curr_filter)
219+
if len(filters) > 1:
220+
result_event_pattern["$or"] = filters
221+
else:
222+
result_event_pattern.update(filters[0])
223+
return json.dumps(result_event_pattern)
224+
225+
def _validate_feature_processor_pipeline_events(
226+
self, pipeline_events: List[FeatureProcessorPipelineEvents]
227+
) -> None:
228+
"""Validates the pipeline events.
229+
230+
Args:
231+
pipeline_events: List of pipeline events.
232+
Raises:
233+
ValueError: If pipeline events contain duplicate pipeline names.
234+
"""
235+
236+
unique_pipelines = {event.pipeline_name for event in pipeline_events}
237+
potential_infinite_loop = []
238+
if len(unique_pipelines) != len(pipeline_events):
239+
raise ValueError("Pipeline names in pipeline_events must be unique.")
240+
241+
for event in pipeline_events:
242+
if FeatureProcessorPipelineExecutionStatus.EXECUTING in event.pipeline_execution_status:
243+
potential_infinite_loop.append(event.pipeline_name)
244+
if potential_infinite_loop:
245+
logger.warning(
246+
"Potential infinite loop detected for pipelines %s. "
247+
"Setting pipeline_execution_status to EXECUTING might cause infinite loop. "
248+
"Please consider a terminal status instead.",
249+
potential_infinite_loop,
250+
)
251+
252+
def _aggregate_pipeline_events_with_same_desired_status(
253+
self, pipeline_events: List[FeatureProcessorPipelineEvents]
254+
) -> Dict[Tuple, List[str]]:
255+
"""Aggregate pipeline events with same desired status.
256+
257+
e.g.
258+
{
259+
(FeatureProcessorPipelineExecutionStatus.FAILED,
260+
FeatureProcessorPipelineExecutionStatus.STOPPED):
261+
["pipeline_name_1", "pipeline_name_2"],
262+
(FeatureProcessorPipelineExecutionStatus.STOPPED,
263+
FeatureProcessorPipelineExecutionStatus.STOPPED):
264+
["pipeline_name_3"],
265+
}
266+
Args:
267+
pipeline_events: List of pipeline events.
268+
Returns:
269+
Dict[Tuple, List[str]]: A dictionary of desired status keys and corresponding pipeline
270+
names.
271+
"""
272+
events_by_desired_status = {}
273+
274+
for event in pipeline_events:
275+
sorted_execution_status = sorted(event.pipeline_execution_status, key=lambda x: x.value)
276+
desired_status_keys = tuple(sorted_execution_status)
277+
278+
if desired_status_keys not in events_by_desired_status:
279+
events_by_desired_status[desired_status_keys] = []
280+
events_by_desired_status[desired_status_keys].append(event.pipeline_name)
281+
282+
return events_by_desired_status
283+
284+
def _generate_pipeline_arn_and_name(self, pipeline_uri: str) -> Dict[str, str]:
285+
"""Generate pipeline arn and pipeline name from pipeline uri.
286+
287+
Args:
288+
pipeline_uri: The name or arn of the pipeline.
289+
Returns:
290+
Dict[str, str]: The arn and name of the pipeline.
291+
"""
292+
match = re.match(PIPELINE_ARN_REGEX_PATTERN, pipeline_uri)
293+
pipeline_arn = ""
294+
pipeline_name = ""
295+
if not match:
296+
pipeline_name = pipeline_uri
297+
describe_pipeline_response = self.sagemaker_session.sagemaker_client.describe_pipeline(
298+
PipelineName=pipeline_name
299+
)
300+
pipeline_arn = describe_pipeline_response["PipelineArn"]
301+
else:
302+
pipeline_arn = pipeline_uri
303+
pipeline_name = match.group(4)
304+
return dict(pipeline_arn=pipeline_arn, pipeline_name=pipeline_name)
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Contains data classes for the Feature Processor Pipeline Events."""
14+
from __future__ import absolute_import
15+
16+
from typing import List
17+
import attr
18+
from sagemaker.feature_store.feature_processor._enums import FeatureProcessorPipelineExecutionStatus
19+
20+
21+
@attr.s(frozen=True)
22+
class FeatureProcessorPipelineEvents:
23+
"""Immutable data class containing the execution events for a FeatureProcessor pipeline.
24+
25+
This class is used for creating event based triggers for feature processor pipelines.
26+
"""
27+
28+
pipeline_name: str = attr.ib()
29+
pipeline_execution_status: List[FeatureProcessorPipelineExecutionStatus] = attr.ib()

0 commit comments

Comments
 (0)