Skip to content

Commit c06ed9b

Browse files
imingtsou authored and mizanfiu committed
feat: Handle merge and timestamp filters (aws#727)
1 parent b01c3ae commit c06ed9b

File tree

1 file changed

+78
-16
lines changed

1 file changed

+78
-16
lines changed

src/sagemaker/feature_store/dataset_builder.py

Lines changed: 78 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -152,6 +152,8 @@ def with_feature_group(
152152
table_name = data_catalog_config.get("TableName", None)
153153
database = data_catalog_config.get("Database", None)
154154
features = [feature.feature_name for feature in feature_group.feature_definitions]
155+
if not target_feature_name_in_base:
156+
target_feature_name_in_base = self._record_identifier_feature_name
155157
self._feature_groups_to_be_merged.append(
156158
FeatureGroupToBeMerged(
157159
features,
@@ -312,6 +314,37 @@ def to_csv(self):
312314
), query_result.get("QueryExecution", None).get("Query", None)
313315
raise ValueError("Base must be either a FeatureGroup or a DataFrame.")
314316

317+
def _construct_where_query_string(self, suffix: str, event_time_identifier_feature_name: str):
318+
"""Internal method for constructing SQL WHERE query string by parameters.
319+
320+
Args:
321+
suffix (str): A temp identifier of the FeatureGroup.
322+
event_time_identifier_feature_name (str): A string representing the event time
323+
identifier feature.
324+
Returns:
325+
The WHERE query string.
326+
"""
327+
where_conditions = []
328+
if not self._include_deleted_records:
329+
where_conditions.append("NOT is_deleted")
330+
if self._write_time_ending_timestamp:
331+
where_conditions.append(
332+
f'table_{suffix}."write_time" <= {self._write_time_ending_timestamp}'
333+
)
334+
if self._event_time_starting_timestamp:
335+
where_conditions.append(
336+
f'table_{suffix}."{event_time_identifier_feature_name}" >= '
337+
+ str(self._event_time_starting_timestamp)
338+
)
339+
if self._event_time_ending_timestamp:
340+
where_conditions.append(
341+
f'table_{suffix}."{event_time_identifier_feature_name}" <= '
342+
+ str(self._event_time_ending_timestamp)
343+
)
344+
if len(where_conditions) == 0:
345+
return ""
346+
return "WHERE " + "\nAND ".join(where_conditions)
347+
315348
def _construct_table_query(self, feature_group: FeatureGroupToBeMerged, suffix: str):
316349
"""Internal method for constructing SQL query string by parameters.
317350
@@ -333,8 +366,6 @@ def _construct_table_query(self, feature_group: FeatureGroupToBeMerged, suffix:
333366
query_string += (
334367
f'FROM "{feature_group.database}"."{feature_group.table_name}" table_{suffix}\n'
335368
)
336-
if not self._include_deleted_records:
337-
query_string += "WHERE NOT is_deleted\n"
338369
else:
339370
features = feature_group.features
340371
features.remove(feature_group.event_time_identifier_feature_name)
@@ -351,9 +382,9 @@ def _construct_table_query(self, feature_group: FeatureGroupToBeMerged, suffix:
351382
+ f") AS table_{suffix}\n"
352383
+ f"WHERE row_{suffix} = 1\n"
353384
)
354-
if not self._include_deleted_records:
355-
query_string += "AND NOT is_deleted\n"
356-
return query_string
385+
return query_string + self._construct_where_query_string(
386+
suffix, feature_group.event_time_identifier_feature_name
387+
)
357388

358389
def _construct_query_string(
359390
self, base_table_name: str, database: str, base_features: list
@@ -367,16 +398,6 @@ def _construct_query_string(
367398
Returns:
368399
The query string.
369400
"""
370-
query_string = ""
371-
if len(self._feature_groups_to_be_merged) > 0:
372-
with_subquery_string = ",\n".join(
373-
[
374-
f"fg_{i} AS ({self._construct_table_query(feature_group, str(i))})"
375-
for i, feature_group in enumerate(self._feature_groups_to_be_merged)
376-
]
377-
)
378-
query_string = f"WITH {with_subquery_string}\n"
379-
380401
base = FeatureGroupToBeMerged(
381402
base_features,
382403
self._included_feature_names,
@@ -385,7 +406,48 @@ def _construct_query_string(
385406
self._record_identifier_feature_name,
386407
self._event_time_identifier_feature_name,
387408
)
388-
return query_string + self._construct_table_query(base, "base")
409+
base_table_query_string = self._construct_table_query(base, "base")
410+
query_string = f"WITH fg_base AS ({base_table_query_string})"
411+
if len(self._feature_groups_to_be_merged) > 0:
412+
with_subquery_string = "".join(
413+
[
414+
f",\nfg_{i} AS ({self._construct_table_query(feature_group, str(i))})"
415+
for i, feature_group in enumerate(self._feature_groups_to_be_merged)
416+
]
417+
)
418+
query_string += with_subquery_string
419+
query_string += "SELECT *\nFROM fg_base"
420+
if len(self._feature_groups_to_be_merged) > 0:
421+
join_subquery_string = "".join(
422+
[
423+
self._construct_join_condition(feature_group, str(i))
424+
for i, feature_group in enumerate(self._feature_groups_to_be_merged)
425+
]
426+
)
427+
query_string += join_subquery_string
428+
return query_string
429+
430+
def _construct_join_condition(self, feature_group: FeatureGroupToBeMerged, suffix: str):
431+
"""Internal method for constructing SQL JOIN query string by parameters.
432+
433+
Args:
434+
feature_group (FeatureGroupToBeMerged): A FeatureGroupToBeMerged object which has the
435+
FeatureGroup metadata.
436+
suffix (str): A temp identifier of the FeatureGroup.
437+
Returns:
438+
The JOIN query string.
439+
"""
440+
join_condition_string = (
441+
f"\nJOIN fg_{suffix}\n"
442+
+ f'ON fg_base."{feature_group.target_feature_name_in_base}" = '
443+
+ f'fg_{suffix}."{feature_group.record_identifier_feature_name}"'
444+
)
445+
if self._point_in_time_accurate_join:
446+
join_condition_string += (
447+
f'\nAND fg_base."{self._event_time_identifier_feature_name}" >= '
448+
+ f'fg_{suffix}."{feature_group.event_time_identifier_feature_name}"'
449+
)
450+
return join_condition_string
389451

390452
def _create_temp_table(self, temp_table_name: str, desired_s3_folder: str):
391453
"""Internal method for creating a temp Athena table for the base pandas.Dataframe.

0 commit comments

Comments
 (0)