Skip to content

Commit 67dc885

Browse files
feature: with_feature_group [feature_store] (#3658)
1 parent 16f5a68 commit 67dc885

File tree

3 files changed

+509
-19
lines changed

3 files changed

+509
-19
lines changed

src/sagemaker/feature_store/dataset_builder.py

Lines changed: 103 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,36 @@ class TableType(Enum):
4343
DATA_FRAME = "DataFrame"
4444

4545

46+
# NOTE: deliberately not decorated with ``@attr.s``. On a field-less Enum, attrs
# installs an ``__eq__`` that returns True for any two members and sets
# ``__hash__`` to None, which makes every join type compare equal and unhashable.
class JoinTypeEnum(Enum):
    """Enum of Join types.

    The Join type can be "INNER_JOIN", "LEFT_JOIN", "RIGHT_JOIN", "FULL_JOIN", or "CROSS_JOIN".
    Each member's value is the SQL join keyword that corresponds to that join type.
    """

    INNER_JOIN = "JOIN"
    LEFT_JOIN = "LEFT JOIN"
    RIGHT_JOIN = "RIGHT JOIN"
    FULL_JOIN = "FULL JOIN"
    CROSS_JOIN = "CROSS JOIN"
58+
59+
60+
# NOTE: deliberately not decorated with ``@attr.s``. On a field-less Enum, attrs
# installs an ``__eq__`` that returns True for any two members and sets
# ``__hash__`` to None, which makes every comparator compare equal and unhashable.
class JoinComparatorEnum(Enum):
    """Enum of Join comparators.

    The Join comparator can be "EQUALS", "GREATER_THAN", "LESS_THAN",
    "GREATER_THAN_OR_EQUAL_TO", "LESS_THAN_OR_EQUAL_TO" or "NOT_EQUAL_TO".
    Each member's value is the SQL comparison operator for that comparator.
    """

    EQUALS = "="
    GREATER_THAN = ">"
    GREATER_THAN_OR_EQUAL_TO = ">="
    LESS_THAN = "<"
    LESS_THAN_OR_EQUAL_TO = "<="
    NOT_EQUAL_TO = "<>"
74+
75+
4676
@attr.s
4777
class FeatureGroupToBeMerged:
4878
"""FeatureGroup metadata which will be used for SQL join.
@@ -55,19 +85,28 @@ class FeatureGroupToBeMerged:
5585
Attributes:
5686
features (List[str]): A list of strings representing feature names of this FeatureGroup.
5787
included_feature_names (List[str]): A list of strings representing features to be
58-
included in the sql join.
88+
included in the SQL join.
5989
projected_feature_names (List[str]): A list of strings representing features to be
6090
included for final projection in output.
6191
catalog (str): A string representing the catalog.
6292
database (str): A string representing the database.
6393
table_name (str): A string representing the Athena table name of this FeatureGroup.
64-
record_dentifier_feature_name (str): A string representing the record identifier feature.
94+
record_identifier_feature_name (str): A string representing the record identifier feature.
6595
event_time_identifier_feature (FeatureDefinition): A FeatureDefinition representing the
6696
event time identifier feature.
6797
target_feature_name_in_base (str): A string representing the feature name in base which will
6898
be used as target join key (default: None).
6999
table_type (TableType): A TableType representing the type of table if it is Feature Group or
70100
Pandas DataFrame (default: None).
101+
feature_name_in_target (str): A string representing the feature name in the target feature
102+
group that will be compared to the target feature in the base feature group.
103+
If None is provided, the record identifier feature will be used in the
104+
SQL join. (default: None).
105+
join_comparator (JoinComparatorEnum): A JoinComparatorEnum representing the comparator
106+
used when joining the target feature in the base feature group and the feature
107+
in the target feature group. (default: JoinComparatorEnum.EQUALS).
108+
join_type (JoinTypeEnum): A JoinTypeEnum representing the type of join between
109+
the base and target feature groups. (default: JoinTypeEnum.INNER_JOIN).
71110
"""
72111

73112
features: List[str] = attr.ib()
@@ -80,12 +119,18 @@ class FeatureGroupToBeMerged:
80119
event_time_identifier_feature: FeatureDefinition = attr.ib()
81120
target_feature_name_in_base: str = attr.ib(default=None)
82121
table_type: TableType = attr.ib(default=None)
122+
feature_name_in_target: str = attr.ib(default=None)
123+
join_comparator: JoinComparatorEnum = attr.ib(default=JoinComparatorEnum.EQUALS)
124+
join_type: JoinTypeEnum = attr.ib(default=JoinTypeEnum.INNER_JOIN)
83125

84126

85127
def construct_feature_group_to_be_merged(
86-
feature_group: FeatureGroup,
128+
target_feature_group: FeatureGroup,
87129
included_feature_names: List[str],
88130
target_feature_name_in_base: str = None,
131+
feature_name_in_target: str = None,
132+
join_comparator: JoinComparatorEnum = JoinComparatorEnum.EQUALS,
133+
join_type: JoinTypeEnum = JoinTypeEnum.INNER_JOIN,
89134
) -> FeatureGroupToBeMerged:
90135
"""Construct a FeatureGroupToBeMerged object by provided parameters.
91136
@@ -95,18 +140,29 @@ def construct_feature_group_to_be_merged(
95140
included in the output.
96141
target_feature_name_in_base (str): A string representing the feature name in base which
97142
will be used as target join key (default: None).
143+
feature_name_in_target (str): A string representing the feature name in the target feature
144+
group that will be compared to the target feature in the base feature group.
145+
If None is provided, the record identifier feature will be used in the
146+
SQL join. (default: None).
147+
join_comparator (JoinComparatorEnum): A JoinComparatorEnum representing the comparator
148+
used when joining the target feature in the base feature group and the feature
149+
in the target feature group. (default: JoinComparatorEnum.EQUALS).
150+
join_type (JoinTypeEnum): A JoinTypeEnum representing the type of join between
151+
the base and target feature groups. (default: JoinTypeEnum.INNER_JOIN).
98152
Returns:
99153
A FeatureGroupToBeMerged object.
100154
101155
Raises:
102156
ValueError: Invalid feature name(s) in included_feature_names.
103157
"""
104-
feature_group_metadata = feature_group.describe()
158+
feature_group_metadata = target_feature_group.describe()
105159
data_catalog_config = feature_group_metadata.get("OfflineStoreConfig", {}).get(
106160
"DataCatalogConfig", None
107161
)
108162
if not data_catalog_config:
109-
raise RuntimeError(f"No metastore is configured with FeatureGroup {feature_group.name}.")
163+
raise RuntimeError(
164+
f"No metastore is configured with FeatureGroup {target_feature_group.name}."
165+
)
110166

111167
record_identifier_feature_name = feature_group_metadata.get("RecordIdentifierFeatureName", None)
112168
feature_definitions = feature_group_metadata.get("FeatureDefinitions", [])
@@ -126,10 +182,15 @@ def construct_feature_group_to_be_merged(
126182
catalog = data_catalog_config.get("Catalog", None) if disable_glue else _DEFAULT_CATALOG
127183
features = [feature.get("FeatureName", None) for feature in feature_definitions]
128184

185+
if feature_name_in_target is not None and feature_name_in_target not in features:
186+
raise ValueError(
187+
f"Feature {feature_name_in_target} not found in FeatureGroup {target_feature_group.name}"
188+
)
189+
129190
for included_feature in included_feature_names or []:
130191
if included_feature not in features:
131192
raise ValueError(
132-
f"Feature {included_feature} not found in FeatureGroup {feature_group.name}"
193+
f"Feature {included_feature} not found in FeatureGroup {target_feature_group.name}"
133194
)
134195
if not included_feature_names:
135196
included_feature_names = features
@@ -151,6 +212,9 @@ def construct_feature_group_to_be_merged(
151212
FeatureDefinition(event_time_identifier_feature_name, event_time_identifier_feature_type),
152213
target_feature_name_in_base,
153214
TableType.FEATURE_GROUP,
215+
feature_name_in_target,
216+
join_comparator,
217+
join_type,
154218
)
155219

156220

@@ -236,21 +300,38 @@ def with_feature_group(
236300
feature_group: FeatureGroup,
237301
target_feature_name_in_base: str = None,
238302
included_feature_names: List[str] = None,
303+
feature_name_in_target: str = None,
304+
join_comparator: JoinComparatorEnum = JoinComparatorEnum.EQUALS,
305+
join_type: JoinTypeEnum = JoinTypeEnum.INNER_JOIN,
239306
):
240307
"""Join FeatureGroup with base.
241308
242309
Args:
243-
feature_group (FeatureGroup): A FeatureGroup which will be joined to base.
310+
feature_group (FeatureGroup): A target FeatureGroup which will be joined to base.
244311
target_feature_name_in_base (str): A string representing the feature name in base which
245-
will be used as target join key (default: None).
312+
will be used as a join key (default: None).
246313
included_feature_names (List[str]): A list of strings representing features to be
247314
included in the output (default: None).
248-
Returns:
249-
This DatasetBuilder object.
315+
feature_name_in_target (str): A string representing the feature name in the target
316+
feature group that will be compared to the target feature in the base feature group.
317+
If None is provided, the record identifier feature will be used in the
318+
SQL join. (default: None).
319+
join_comparator (JoinComparatorEnum): A JoinComparatorEnum representing the comparator
320+
used when joining the target feature in the base feature group and the feature
321+
in the target feature group. (default: JoinComparatorEnum.EQUALS).
322+
join_type (JoinTypeEnum): A JoinTypeEnum representing the type of join between
323+
the base and target feature groups. (default: JoinTypeEnum.INNER_JOIN).
324+
Returns:
325+
This DatasetBuilder object.
250326
"""
251327
self._feature_groups_to_be_merged.append(
252328
construct_feature_group_to_be_merged(
253-
feature_group, included_feature_names, target_feature_name_in_base
329+
feature_group,
330+
included_feature_names,
331+
target_feature_name_in_base,
332+
feature_name_in_target,
333+
join_comparator,
334+
join_type,
254335
)
255336
)
256337
return self
@@ -914,10 +995,18 @@ def _construct_join_condition(self, feature_group: FeatureGroupToBeMerged, suffi
914995
Returns:
915996
The JOIN query string.
916997
"""
998+
999+
feature_name_in_target = (
1000+
feature_group.feature_name_in_target
1001+
if feature_group.feature_name_in_target is not None
1002+
else feature_group.record_identifier_feature_name
1003+
)
1004+
9171005
join_condition_string = (
918-
f"\nJOIN fg_{suffix}\n"
919-
+ f'ON fg_base."{feature_group.target_feature_name_in_base}" = '
920-
+ f'fg_{suffix}."{feature_group.record_identifier_feature_name}"'
1006+
f"\n{feature_group.join_type.value} fg_{suffix}\n"
1007+
+ f'ON fg_base."{feature_group.target_feature_name_in_base}"'
1008+
+ f" {feature_group.join_comparator.value} "
1009+
+ f'fg_{suffix}."{feature_name_in_target}"'
9211010
)
9221011
base_timestamp_cast_function_name = "from_unixtime"
9231012
if self._event_time_identifier_feature_type == FeatureTypeEnum.STRING:

0 commit comments

Comments
 (0)