Skip to content

Commit 850fc98

Browse files
patrickmcarlos authored and Namrata Madan committed
feature: feature store with_feature_group functionality changes (aws#3630)
1 parent d1b3e4b commit 850fc98

File tree

3 files changed

+507
-19
lines changed

3 files changed

+507
-19
lines changed

src/sagemaker/feature_store/dataset_builder.py

Lines changed: 101 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,34 @@ class TableType(Enum):
4343
DATA_FRAME = "DataFrame"
4444

4545

46+
# Deliberately a plain Enum: wrapping an Enum in attr.s would replace Enum's
# own __eq__/__hash__/__repr__ with attrs-generated ones, which for a class
# without attr.ib fields makes all members compare equal and unhashable.
class JoinTypeEnum(Enum):
    """Enum of Join types.

    The Join type can be "INNER_JOIN", "LEFT_JOIN", "RIGHT_JOIN", "FULL_JOIN".
    Each member's value is the join keyword text that is interpolated into the
    generated SQL query.
    """

    INNER_JOIN = "JOIN"
    LEFT_JOIN = "LEFT JOIN"
    RIGHT_JOIN = "RIGHT JOIN"
    FULL_JOIN = "FULL JOIN"
57+
58+
59+
# Deliberately a plain Enum: wrapping an Enum in attr.s would replace Enum's
# own __eq__/__hash__/__repr__ with attrs-generated ones, which for a class
# without attr.ib fields makes all members compare equal and unhashable.
class JoinComparatorEnum(Enum):
    """Enum of Join comparators.

    The Join comparator can be "EQUALS", "GREATER_THAN", "LESS_THAN",
    "GREATER_THAN_OR_EQUAL_TO", or "LESS_THAN_OR_EQUAL_TO".
    Each member's value is the comparison operator text that is interpolated
    into the ON clause of the generated SQL join.
    """

    EQUALS = "="
    GREATER_THAN = ">"
    GREATER_THAN_OR_EQUAL_TO = ">="
    LESS_THAN = "<"
    LESS_THAN_OR_EQUAL_TO = "<="
72+
73+
4674
@attr.s
4775
class FeatureGroupToBeMerged:
4876
"""FeatureGroup metadata which will be used for SQL join.
@@ -55,19 +83,28 @@ class FeatureGroupToBeMerged:
5583
Attributes:
5684
features (List[str]): A list of strings representing feature names of this FeatureGroup.
5785
included_feature_names (List[str]): A list of strings representing features to be
58-
included in the sql join.
86+
included in the SQL join.
5987
projected_feature_names (List[str]): A list of strings representing features to be
6088
included for final projection in output.
6189
catalog (str): A string representing the catalog.
6290
database (str): A string representing the database.
6391
table_name (str): A string representing the Athena table name of this FeatureGroup.
64-
record_dentifier_feature_name (str): A string representing the record identifier feature.
92+
record_identifier_feature_name (str): A string representing the record identifier feature.
6593
event_time_identifier_feature (FeatureDefinition): A FeatureDefinition representing the
6694
event time identifier feature.
6795
target_feature_name_in_base (str): A string representing the feature name in base which will
6896
be used as target join key (default: None).
6997
table_type (TableType): A TableType representing the type of table if it is Feature Group or
7098
pandas DataFrame (default: None).
99+
feature_name_in_target (str): A string representing the feature name in the target feature
100+
group that will be compared to the target feature in the base feature group.
101+
If None is provided, the record identifier feature will be used in the
102+
SQL join. (default: None).
103+
join_comparator (JoinComparatorEnum): A JoinComparatorEnum representing the comparator
104+
used when joining the target feature in the base feature group and the feature
105+
in the target feature group. (default: JoinComparatorEnum.EQUALS).
106+
join_type (JoinTypeEnum): A JoinTypeEnum representing the type of join between
107+
the base and target feature groups. (default: JoinTypeEnum.INNER_JOIN).
71108
"""
72109

73110
features: List[str] = attr.ib()
@@ -80,12 +117,18 @@ class FeatureGroupToBeMerged:
80117
event_time_identifier_feature: FeatureDefinition = attr.ib()
81118
target_feature_name_in_base: str = attr.ib(default=None)
82119
table_type: TableType = attr.ib(default=None)
120+
feature_name_in_target: str = attr.ib(default=None)
121+
join_comparator: JoinComparatorEnum = attr.ib(default=JoinComparatorEnum.EQUALS)
122+
join_type: JoinTypeEnum = attr.ib(default=JoinTypeEnum.INNER_JOIN)
83123

84124

85125
def construct_feature_group_to_be_merged(
86-
feature_group: FeatureGroup,
126+
target_feature_group: FeatureGroup,
87127
included_feature_names: List[str],
88128
target_feature_name_in_base: str = None,
129+
feature_name_in_target: str = None,
130+
join_comparator: JoinComparatorEnum = JoinComparatorEnum.EQUALS,
131+
join_type: JoinTypeEnum = JoinTypeEnum.INNER_JOIN,
89132
) -> FeatureGroupToBeMerged:
90133
"""Construct a FeatureGroupToBeMerged object by provided parameters.
91134
@@ -95,18 +138,29 @@ def construct_feature_group_to_be_merged(
95138
included in the output.
96139
target_feature_name_in_base (str): A string representing the feature name in base which
97140
will be used as target join key (default: None).
141+
feature_name_in_target (str): A string representing the feature name in the target feature
142+
group that will be compared to the target feature in the base feature group.
143+
If None is provided, the record identifier feature will be used in the
144+
SQL join. (default: None).
145+
join_comparator (JoinComparatorEnum): A JoinComparatorEnum representing the comparator
146+
used when joining the target feature in the base feature group and the feature
147+
in the target feature group. (default: JoinComparatorEnum.EQUALS).
148+
join_type (JoinTypeEnum): A JoinTypeEnum representing the type of join between
149+
the base and target feature groups. (default: JoinTypeEnum.INNER_JOIN).
98150
Returns:
99151
A FeatureGroupToBeMerged object.
100152
101153
Raises:
102154
ValueError: Invalid feature name(s) in included_feature_names.
103155
"""
104-
feature_group_metadata = feature_group.describe()
156+
feature_group_metadata = target_feature_group.describe()
105157
data_catalog_config = feature_group_metadata.get("OfflineStoreConfig", {}).get(
106158
"DataCatalogConfig", None
107159
)
108160
if not data_catalog_config:
109-
raise RuntimeError(f"No metastore is configured with FeatureGroup {feature_group.name}.")
161+
raise RuntimeError(
162+
f"No metastore is configured with FeatureGroup {target_feature_group.name}."
163+
)
110164

111165
record_identifier_feature_name = feature_group_metadata.get("RecordIdentifierFeatureName", None)
112166
feature_definitions = feature_group_metadata.get("FeatureDefinitions", [])
@@ -126,10 +180,15 @@ def construct_feature_group_to_be_merged(
126180
catalog = data_catalog_config.get("Catalog", None) if disable_glue else _DEFAULT_CATALOG
127181
features = [feature.get("FeatureName", None) for feature in feature_definitions]
128182

183+
if feature_name_in_target is not None and feature_name_in_target not in features:
184+
raise ValueError(
185+
f"Feature {feature_name_in_target} not found in FeatureGroup {target_feature_group.name}"
186+
)
187+
129188
for included_feature in included_feature_names or []:
130189
if included_feature not in features:
131190
raise ValueError(
132-
f"Feature {included_feature} not found in FeatureGroup {feature_group.name}"
191+
f"Feature {included_feature} not found in FeatureGroup {target_feature_group.name}"
133192
)
134193
if not included_feature_names:
135194
included_feature_names = features
@@ -151,6 +210,9 @@ def construct_feature_group_to_be_merged(
151210
FeatureDefinition(event_time_identifier_feature_name, event_time_identifier_feature_type),
152211
target_feature_name_in_base,
153212
TableType.FEATURE_GROUP,
213+
feature_name_in_target,
214+
join_comparator,
215+
join_type,
154216
)
155217

156218

@@ -227,21 +289,38 @@ def with_feature_group(
227289
feature_group: FeatureGroup,
228290
target_feature_name_in_base: str = None,
229291
included_feature_names: List[str] = None,
292+
feature_name_in_target: str = None,
293+
join_comparator: JoinComparatorEnum = JoinComparatorEnum.EQUALS,
294+
join_type: JoinTypeEnum = JoinTypeEnum.INNER_JOIN,
230295
):
231296
"""Join FeatureGroup with base.
232297
233298
Args:
234-
feature_group (FeatureGroup): A FeatureGroup which will be joined to base.
299+
feature_group (FeatureGroup): A target FeatureGroup which will be joined to base.
235300
target_feature_name_in_base (str): A string representing the feature name in base which
236-
will be used as target join key (default: None).
301+
will be used as a join key (default: None).
237302
included_feature_names (List[str]): A list of strings representing features to be
238303
included in the output (default: None).
239-
Returns:
240-
This DatasetBuilder object.
304+
feature_name_in_target (str): A string representing the feature name in the target
305+
feature group that will be compared to the target feature in the base feature group.
306+
If None is provided, the record identifier feature will be used in the
307+
SQL join. (default: None).
308+
join_comparator (JoinComparatorEnum): A JoinComparatorEnum representing the comparator
309+
used when joining the target feature in the base feature group and the feature
310+
in the target feature group. (default: JoinComparatorEnum.EQUALS).
311+
join_type (JoinTypeEnum): A JoinTypeEnum representing the type of join between
312+
the base and target feature groups. (default: JoinTypeEnum.INNER_JOIN).
313+
Returns:
314+
This DatasetBuilder object.
241315
"""
242316
self._feature_groups_to_be_merged.append(
243317
construct_feature_group_to_be_merged(
244-
feature_group, included_feature_names, target_feature_name_in_base
318+
feature_group,
319+
included_feature_names,
320+
target_feature_name_in_base,
321+
feature_name_in_target,
322+
join_comparator,
323+
join_type,
245324
)
246325
)
247326
return self
@@ -905,10 +984,18 @@ def _construct_join_condition(self, feature_group: FeatureGroupToBeMerged, suffi
905984
Returns:
906985
The JOIN query string.
907986
"""
987+
988+
feature_name_in_target = (
989+
feature_group.feature_name_in_target
990+
if feature_group.feature_name_in_target is not None
991+
else feature_group.record_identifier_feature_name
992+
)
993+
908994
join_condition_string = (
909-
f"\nJOIN fg_{suffix}\n"
910-
+ f'ON fg_base."{feature_group.target_feature_name_in_base}" = '
911-
+ f'fg_{suffix}."{feature_group.record_identifier_feature_name}"'
995+
f"\n{feature_group.join_type.value} fg_{suffix}\n"
996+
+ f'ON fg_base."{feature_group.target_feature_name_in_base}"'
997+
+ f" {feature_group.join_comparator.value} "
998+
+ f'fg_{suffix}."{feature_name_in_target}"'
912999
)
9131000
base_timestamp_cast_function_name = "from_unixtime"
9141001
if self._event_time_identifier_feature_type == FeatureTypeEnum.STRING:

0 commit comments

Comments
 (0)