Skip to content

Commit 50b8010

Browse files
updates with_feature_group def and adds test
1 parent 58fe72a commit 50b8010

File tree

2 files changed

+220
-9
lines changed

2 files changed

+220
-9
lines changed

src/sagemaker/feature_store/dataset_builder.py

Lines changed: 74 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,32 @@ class TableType(Enum):
4343
DATA_FRAME = "DataFrame"
4444

4545

46+
class JoinTypeEnum(Enum):
    """Enum of join types.

    The join type can be "INNER_JOIN", "LEFT_JOIN", "RIGHT_JOIN", or "FULL_JOIN".
    Each member's value is the SQL keyword text for that join.
    """

    # NOTE: deliberately a plain Enum with no ``@attr.s`` decorator. Applying
    # attrs to an Enum that declares no ``attr.ib`` fields installs a generated
    # ``__eq__`` that treats every member as equal and sets ``__hash__`` to
    # None, which breaks comparisons and use as dict keys / set members.
    INNER_JOIN = "JOIN"
    LEFT_JOIN = "LEFT JOIN"
    RIGHT_JOIN = "RIGHT JOIN"
    FULL_JOIN = "FULL JOIN"
56+
57+
58+
class JoinComparatorEnum(Enum):
    """Enum of join comparators.

    The join comparator can be "EQUALS", "GREATER_THAN", "LESS_THAN",
    "GREATER_THAN_OR_EQUAL_TO", or "LESS_THAN_OR_EQUAL_TO". Each member's
    value is the corresponding SQL comparison operator.
    """

    # NOTE: deliberately a plain Enum with no ``@attr.s`` decorator. Applying
    # attrs to an Enum that declares no ``attr.ib`` fields installs a generated
    # ``__eq__`` that treats every member as equal and sets ``__hash__`` to
    # None, which breaks comparisons and use as dict keys / set members.
    EQUALS = "="
    GREATER_THAN = ">"
    GREATER_THAN_OR_EQUAL_TO = ">="
    LESS_THAN = "<"
    LESS_THAN_OR_EQUAL_TO = "<="
70+
71+
4672
@attr.s
4773
class FeatureGroupToBeMerged:
4874
"""FeatureGroup metadata which will be used for SQL join.
@@ -68,6 +94,13 @@ class FeatureGroupToBeMerged:
6894
be used as target join key (default: None).
6995
table_type (TableType): A TableType representing the type of table if it is Feature Group or
7096
Panda Data Frame (default: None).
97+
feature_name_in_target (str): A string representing the feature in the target feature group
98+
that will be compared to the target feature in the base feature group
99+
join_comparator (JoinComparatorEnum): A JoinComparatorEnum representing the comparator used
100+
when joining the target feature in the base feature group and the feature in the target
101+
feature group (default: None).
102+
join_type (JoinTypeEnum): A JoinTypeEnum representing the type of join between the base and
103+
target feature groups. (default: None).
71104
"""
72105

73106
features: List[str] = attr.ib()
@@ -80,12 +113,18 @@ class FeatureGroupToBeMerged:
80113
event_time_identifier_feature: FeatureDefinition = attr.ib()
81114
target_feature_name_in_base: str = attr.ib(default=None)
82115
table_type: TableType = attr.ib(default=None)
116+
feature_name_in_target: str = attr.ib(default=None)
117+
join_comparator: JoinComparatorEnum = attr.ib(default=None)
118+
join_type: JoinTypeEnum = attr.ib(default=None)
83119

84120

85121
def construct_feature_group_to_be_merged(
86-
feature_group: FeatureGroup,
122+
target_feature_group: FeatureGroup,
87123
included_feature_names: List[str],
88124
target_feature_name_in_base: str = None,
125+
feature_name_in_target: str = None,
126+
join_comparator: JoinComparatorEnum = None,
127+
join_type: JoinTypeEnum = None
89128
) -> FeatureGroupToBeMerged:
90129
"""Construct a FeatureGroupToBeMerged object by provided parameters.
91130
@@ -101,12 +140,12 @@ def construct_feature_group_to_be_merged(
101140
Raises:
102141
ValueError: Invalid feature name(s) in included_feature_names.
103142
"""
104-
feature_group_metadata = feature_group.describe()
143+
feature_group_metadata = target_feature_group.describe()
105144
data_catalog_config = feature_group_metadata.get("OfflineStoreConfig", {}).get(
106145
"DataCatalogConfig", None
107146
)
108147
if not data_catalog_config:
109-
raise RuntimeError(f"No metastore is configured with FeatureGroup {feature_group.name}.")
148+
raise RuntimeError(f"No metastore is configured with FeatureGroup {target_feature_group.name}.")
110149

111150
record_identifier_feature_name = feature_group_metadata.get("RecordIdentifierFeatureName", None)
112151
feature_definitions = feature_group_metadata.get("FeatureDefinitions", [])
@@ -126,10 +165,15 @@ def construct_feature_group_to_be_merged(
126165
catalog = data_catalog_config.get("Catalog", None) if disable_glue else _DEFAULT_CATALOG
127166
features = [feature.get("FeatureName", None) for feature in feature_definitions]
128167

168+
if (feature_name_in_target is not None and feature_name_in_target not in features):
169+
raise ValueError(
170+
f"Feature {feature_name_in_target} not found in FeatureGroup {target_feature_group.name}"
171+
)
172+
129173
for included_feature in included_feature_names or []:
130174
if included_feature not in features:
131175
raise ValueError(
132-
f"Feature {included_feature} not found in FeatureGroup {feature_group.name}"
176+
f"Feature {included_feature} not found in FeatureGroup {target_feature_group.name}"
133177
)
134178
if not included_feature_names:
135179
included_feature_names = features
@@ -151,6 +195,9 @@ def construct_feature_group_to_be_merged(
151195
FeatureDefinition(event_time_identifier_feature_name, event_time_identifier_feature_type),
152196
target_feature_name_in_base,
153197
TableType.FEATURE_GROUP,
198+
feature_name_in_target,
199+
join_comparator,
200+
join_type
154201
)
155202

156203

@@ -227,6 +274,9 @@ def with_feature_group(
227274
feature_group: FeatureGroup,
228275
target_feature_name_in_base: str = None,
229276
included_feature_names: List[str] = None,
277+
feature_name_in_target: str = None,
278+
join_comparator: JoinComparatorEnum = None,
279+
join_type: JoinTypeEnum = None
230280
):
231281
"""Join FeatureGroup with base.
232282
@@ -241,7 +291,11 @@ def with_feature_group(
241291
"""
242292
self._feature_groups_to_be_merged.append(
243293
construct_feature_group_to_be_merged(
244-
feature_group, included_feature_names, target_feature_name_in_base
294+
feature_group, included_feature_names,
295+
target_feature_name_in_base,
296+
feature_name_in_target,
297+
join_comparator,
298+
join_type
245299
)
246300
)
247301
return self
@@ -905,10 +959,22 @@ def _construct_join_condition(self, feature_group: FeatureGroupToBeMerged, suffi
905959
Returns:
906960
The JOIN query string.
907961
"""
962+
963+
join_type = (feature_group.join_type if feature_group.join_type is not None
964+
else JoinTypeEnum.INNER_JOIN)
965+
966+
join_comparator = (feature_group.join_comparator
967+
if feature_group.join_comparator is not None
968+
else JoinComparatorEnum.EQUALS)
969+
970+
feature_name_in_target = (feature_group.feature_name_in_target
971+
if feature_group.feature_name_in_target is not None
972+
else feature_group.record_identifier_feature_name)
973+
908974
join_condition_string = (
909-
f"\nJOIN fg_{suffix}\n"
910-
+ f'ON fg_base."{feature_group.target_feature_name_in_base}" = '
911-
+ f'fg_{suffix}."{feature_group.record_identifier_feature_name}"'
975+
f"\n{join_type.value} fg_{suffix}\n"
976+
+ f'ON fg_base."{feature_group.target_feature_name_in_base}" {join_comparator.value} '
977+
+ f'fg_{suffix}."{feature_name_in_target}"'
912978
)
913979
base_timestamp_cast_function_name = "from_unixtime"
914980
if self._event_time_identifier_feature_type == FeatureTypeEnum.STRING:

tests/unit/sagemaker/feature_store/test_dataset_builder.py

Lines changed: 146 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,13 @@
2323
DatasetBuilder,
2424
FeatureGroupToBeMerged,
2525
TableType,
26+
JoinComparatorEnum,
27+
JoinTypeEnum
2628
)
2729
from sagemaker.feature_store.feature_group import (
2830
FeatureDefinition,
2931
FeatureGroup,
30-
FeatureTypeEnum,
32+
FeatureTypeEnum
3133
)
3234

3335

@@ -528,6 +530,149 @@ def test_construct_query_string(sagemaker_session_mock):
528530
+ ")\n"
529531
)
530532

533+
def test_with_feature_group_with_optional_params_query_string(sagemaker_session_mock):
    """Check the query built when with_feature_group gets its optional join arguments.

    feature_name_in_target, join_comparator and join_type are all supplied, and the
    final join clause is expected to read
    ``FULL JOIN fg_0 ON fg_base."target-feature" > fg_0."feature-2"``.
    """
    base_fg = FeatureGroup(name="base_feature_group", sagemaker_session=sagemaker_session_mock)
    target_fg = FeatureGroup(
        name="target_feature_group", sagemaker_session=sagemaker_session_mock
    )

    # Canned describe_feature_group response for the target feature group.
    described_target = {
        "OfflineStoreConfig": {
            "DataCatalogConfig": {"TableName": "table-name", "Database": "database"}
        },
        "RecordIdentifierFeatureName": "feature-1",
        "EventTimeFeatureName": "feature-2",
        "FeatureDefinitions": [
            {"FeatureName": "feature-1", "FeatureType": "String"},
            {"FeatureName": "feature-2", "FeatureType": "String"},
        ],
    }

    builder = DatasetBuilder(
        sagemaker_session=sagemaker_session_mock,
        base=base_fg,
        output_path="file/to/path",
        record_identifier_feature_name="target-feature",
    )
    builder._event_time_identifier_feature_name = "target-feature"

    sagemaker_session_mock.describe_feature_group.return_value = described_target

    builder.with_feature_group(
        target_fg,
        target_feature_name_in_base="target-feature",
        included_feature_names=["feature-1", "feature-2"],
        feature_name_in_target="feature-2",
        join_comparator=JoinComparatorEnum.GREATER_THAN,
        join_type=JoinTypeEnum.FULL_JOIN,
    )

    actual_query = builder._construct_query_string(BASE)

    expected_query = (
        # fg_base: de-duplicated base table
        "WITH fg_base AS (WITH table_base AS (\n"
        "SELECT *\n"
        "FROM (\n"
        "SELECT *, row_number() OVER (\n"
        'PARTITION BY origin_base."target-feature", origin_base."other-feature"\n'
        'ORDER BY origin_base."api_invocation_time" DESC, origin_base."write_time" DESC\n'
        ") AS dedup_row_base\n"
        'FROM "database"."base-table" origin_base\n'
        ")\n"
        "WHERE dedup_row_base = 1\n"
        "),\n"
        # fg_base: deleted records of the base table
        "deleted_base AS (\n"
        "SELECT *\n"
        "FROM (\n"
        "SELECT *, row_number() OVER (\n"
        'PARTITION BY origin_base."target-feature"\n'
        'ORDER BY origin_base."other-feature" DESC, origin_base."api_invocation_time" '
        'DESC, origin_base."write_time" DESC\n'
        ") AS deleted_row_base\n"
        'FROM "database"."base-table" origin_base\n'
        "WHERE is_deleted\n"
        ")\n"
        "WHERE deleted_row_base = 1\n"
        ")\n"
        # fg_base: selection combining table_base and deleted_base
        'SELECT table_base."target-feature", table_base."other-feature"\n'
        "FROM (\n"
        'SELECT table_base."target-feature", table_base."other-feature", '
        'table_base."write_time"\n'
        "FROM table_base\n"
        "LEFT JOIN deleted_base\n"
        'ON table_base."target-feature" = deleted_base."target-feature"\n'
        'WHERE deleted_base."target-feature" IS NULL\n'
        "UNION ALL\n"
        'SELECT table_base."target-feature", table_base."other-feature", '
        'table_base."write_time"\n'
        "FROM deleted_base\n"
        "JOIN table_base\n"
        'ON table_base."target-feature" = deleted_base."target-feature"\n'
        "AND (\n"
        'table_base."other-feature" > deleted_base."other-feature"\n'
        'OR (table_base."other-feature" = deleted_base."other-feature" AND '
        'table_base."api_invocation_time" > deleted_base."api_invocation_time")\n'
        'OR (table_base."other-feature" = deleted_base."other-feature" AND '
        'table_base."api_invocation_time" = deleted_base."api_invocation_time" AND '
        'table_base."write_time" > deleted_base."write_time")\n'
        ")\n"
        ") AS table_base\n"
        "),\n"
        # fg_0: de-duplicated target table
        "fg_0 AS (WITH table_0 AS (\n"
        "SELECT *\n"
        "FROM (\n"
        "SELECT *, row_number() OVER (\n"
        'PARTITION BY origin_0."feature-1", origin_0."feature-2"\n'
        'ORDER BY origin_0."api_invocation_time" DESC, origin_0."write_time" DESC\n'
        ") AS dedup_row_0\n"
        'FROM "database"."table-name" origin_0\n'
        ")\n"
        "WHERE dedup_row_0 = 1\n"
        "),\n"
        # fg_0: deleted records of the target table
        "deleted_0 AS (\n"
        "SELECT *\n"
        "FROM (\n"
        "SELECT *, row_number() OVER (\n"
        'PARTITION BY origin_0."feature-1"\n'
        'ORDER BY origin_0."feature-2" DESC, origin_0."api_invocation_time" DESC, '
        'origin_0."write_time" DESC\n'
        ") AS deleted_row_0\n"
        'FROM "database"."table-name" origin_0\n'
        "WHERE is_deleted\n"
        ")\n"
        "WHERE deleted_row_0 = 1\n"
        ")\n"
        # fg_0: selection combining table_0 and deleted_0
        'SELECT table_0."feature-1", table_0."feature-2"\n'
        "FROM (\n"
        'SELECT table_0."feature-1", table_0."feature-2", table_0."write_time"\n'
        "FROM table_0\n"
        "LEFT JOIN deleted_0\n"
        'ON table_0."feature-1" = deleted_0."feature-1"\n'
        'WHERE deleted_0."feature-1" IS NULL\n'
        "UNION ALL\n"
        'SELECT table_0."feature-1", table_0."feature-2", table_0."write_time"\n'
        "FROM deleted_0\n"
        "JOIN table_0\n"
        'ON table_0."feature-1" = deleted_0."feature-1"\n'
        "AND (\n"
        'table_0."feature-2" > deleted_0."feature-2"\n'
        'OR (table_0."feature-2" = deleted_0."feature-2" AND '
        'table_0."api_invocation_time" > deleted_0."api_invocation_time")\n'
        'OR (table_0."feature-2" = deleted_0."feature-2" AND '
        'table_0."api_invocation_time" = deleted_0."api_invocation_time" AND '
        'table_0."write_time" > deleted_0."write_time")\n'
        ")\n"
        ") AS table_0\n"
        ")\n"
        # Final select: this is where the custom join type and comparator appear
        'SELECT target-feature, other-feature, "feature-1.1", "feature-2.1"\n'
        "FROM (\n"
        'SELECT fg_base.target-feature, fg_base.other-feature, fg_0."feature-1" as '
        '"feature-1.1", fg_0."feature-2" as "feature-2.1", row_number() OVER (\n'
        'PARTITION BY fg_base."target-feature"\n'
        'ORDER BY fg_base."other-feature" DESC, fg_0."feature-2" DESC\n'
        ") AS row_recent\n"
        "FROM fg_base\n"
        "FULL JOIN fg_0\n"
        'ON fg_base."target-feature" > fg_0."feature-2"\n'
        ")\n"
    )

    assert actual_query == expected_query
674+
675+
531676

532677
def test_create_temp_table(sagemaker_session_mock):
533678
dataframe = pd.DataFrame({"feature-1": [420, 380, 390], "feature-2": [50, 40, 45]})

0 commit comments

Comments
 (0)