Skip to content

feature: feature store with_feature_group functionality changes #3630

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 101 additions & 14 deletions src/sagemaker/feature_store/dataset_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,34 @@ class TableType(Enum):
DATA_FRAME = "DataFrame"


class JoinTypeEnum(Enum):
    """Enum of SQL join types.

    The join type can be "INNER_JOIN", "LEFT_JOIN", "RIGHT_JOIN" or "FULL_JOIN".
    Each member's value is the SQL keyword sequence that is placed in front of
    the joined table name when the join clause is rendered.
    """

    # NOTE: this class is intentionally NOT decorated with ``@attr.s``. With no
    # ``attr.ib`` fields, attrs generates an ``__eq__`` that treats every member
    # as equal, sets ``__hash__`` to None and replaces the Enum repr, which
    # breaks comparisons and use of members in sets/dict keys.

    # A bare JOIN is an inner join in SQL.
    INNER_JOIN = "JOIN"
    LEFT_JOIN = "LEFT JOIN"
    RIGHT_JOIN = "RIGHT JOIN"
    FULL_JOIN = "FULL JOIN"


class JoinComparatorEnum(Enum):
    """Enum of join comparators.

    The join comparator can be "EQUALS", "GREATER_THAN", "LESS_THAN",
    "GREATER_THAN_OR_EQUAL_TO", or "LESS_THAN_OR_EQUAL_TO". Each member's
    value is the SQL comparison operator placed between the two join keys
    in the ON clause.
    """

    # NOTE: this class is intentionally NOT decorated with ``@attr.s``. With no
    # ``attr.ib`` fields, attrs generates an ``__eq__`` that treats every member
    # as equal, sets ``__hash__`` to None and replaces the Enum repr, which
    # breaks comparisons and use of members in sets/dict keys.

    EQUALS = "="
    GREATER_THAN = ">"
    GREATER_THAN_OR_EQUAL_TO = ">="
    LESS_THAN = "<"
    LESS_THAN_OR_EQUAL_TO = "<="


@attr.s
class FeatureGroupToBeMerged:
"""FeatureGroup metadata which will be used for SQL join.
Expand All @@ -55,19 +83,28 @@ class FeatureGroupToBeMerged:
Attributes:
features (List[str]): A list of strings representing feature names of this FeatureGroup.
included_feature_names (List[str]): A list of strings representing features to be
included in the sql join.
included in the SQL join.
projected_feature_names (List[str]): A list of strings representing features to be
included for final projection in output.
catalog (str): A string representing the catalog.
database (str): A string representing the database.
table_name (str): A string representing the Athena table name of this FeatureGroup.
record_dentifier_feature_name (str): A string representing the record identifier feature.
record_identifier_feature_name (str): A string representing the record identifier feature.
event_time_identifier_feature (FeatureDefinition): A FeatureDefinition representing the
event time identifier feature.
target_feature_name_in_base (str): A string representing the feature name in base which will
be used as target join key (default: None).
table_type (TableType): A TableType representing the type of table if it is Feature Group or
Panda Data Frame (default: None).
feature_name_in_target (str): A string representing the feature name in the target feature
group that will be compared to the target feature in the base feature group.
If None is provided, the record identifier feature will be used in the
SQL join. (default: None).
join_comparator (JoinComparatorEnum): A JoinComparatorEnum representing the comparator
used when joining the target feature in the base feature group and the feature
in the target feature group. (default: JoinComparatorEnum.EQUALS).
join_type (JoinTypeEnum): A JoinTypeEnum representing the type of join between
the base and target feature groups. (default: JoinTypeEnum.INNER_JOIN).
"""

features: List[str] = attr.ib()
Expand All @@ -80,12 +117,18 @@ class FeatureGroupToBeMerged:
event_time_identifier_feature: FeatureDefinition = attr.ib()
target_feature_name_in_base: str = attr.ib(default=None)
table_type: TableType = attr.ib(default=None)
feature_name_in_target: str = attr.ib(default=None)
join_comparator: JoinComparatorEnum = attr.ib(default=JoinComparatorEnum.EQUALS)
join_type: JoinTypeEnum = attr.ib(default=JoinTypeEnum.INNER_JOIN)


def construct_feature_group_to_be_merged(
feature_group: FeatureGroup,
target_feature_group: FeatureGroup,
included_feature_names: List[str],
target_feature_name_in_base: str = None,
feature_name_in_target: str = None,
join_comparator: JoinComparatorEnum = JoinComparatorEnum.EQUALS,
join_type: JoinTypeEnum = JoinTypeEnum.INNER_JOIN,
) -> FeatureGroupToBeMerged:
"""Construct a FeatureGroupToBeMerged object by provided parameters.

Expand All @@ -95,18 +138,29 @@ def construct_feature_group_to_be_merged(
included in the output.
target_feature_name_in_base (str): A string representing the feature name in base which
will be used as target join key (default: None).
feature_name_in_target (str): A string representing the feature name in the target feature
group that will be compared to the target feature in the base feature group.
If None is provided, the record identifier feature will be used in the
SQL join. (default: None).
join_comparator (JoinComparatorEnum): A JoinComparatorEnum representing the comparator
used when joining the target feature in the base feature group and the feature
in the target feature group. (default: JoinComparatorEnum.EQUALS).
join_type (JoinTypeEnum): A JoinTypeEnum representing the type of join between
the base and target feature groups. (default: JoinTypeEnum.INNER_JOIN).
Returns:
A FeatureGroupToBeMerged object.

Raises:
ValueError: Invalid feature name(s) in included_feature_names.
"""
feature_group_metadata = feature_group.describe()
feature_group_metadata = target_feature_group.describe()
data_catalog_config = feature_group_metadata.get("OfflineStoreConfig", {}).get(
"DataCatalogConfig", None
)
if not data_catalog_config:
raise RuntimeError(f"No metastore is configured with FeatureGroup {feature_group.name}.")
raise RuntimeError(
f"No metastore is configured with FeatureGroup {target_feature_group.name}."
)

record_identifier_feature_name = feature_group_metadata.get("RecordIdentifierFeatureName", None)
feature_definitions = feature_group_metadata.get("FeatureDefinitions", [])
Expand All @@ -126,10 +180,15 @@ def construct_feature_group_to_be_merged(
catalog = data_catalog_config.get("Catalog", None) if disable_glue else _DEFAULT_CATALOG
features = [feature.get("FeatureName", None) for feature in feature_definitions]

if feature_name_in_target is not None and feature_name_in_target not in features:
raise ValueError(
f"Feature {feature_name_in_target} not found in FeatureGroup {target_feature_group.name}"
)

for included_feature in included_feature_names or []:
if included_feature not in features:
raise ValueError(
f"Feature {included_feature} not found in FeatureGroup {feature_group.name}"
f"Feature {included_feature} not found in FeatureGroup {target_feature_group.name}"
)
if not included_feature_names:
included_feature_names = features
Expand All @@ -151,6 +210,9 @@ def construct_feature_group_to_be_merged(
FeatureDefinition(event_time_identifier_feature_name, event_time_identifier_feature_type),
target_feature_name_in_base,
TableType.FEATURE_GROUP,
feature_name_in_target,
join_comparator,
join_type,
)


Expand Down Expand Up @@ -227,21 +289,38 @@ def with_feature_group(
feature_group: FeatureGroup,
target_feature_name_in_base: str = None,
included_feature_names: List[str] = None,
feature_name_in_target: str = None,
join_comparator: JoinComparatorEnum = JoinComparatorEnum.EQUALS,
join_type: JoinTypeEnum = JoinTypeEnum.INNER_JOIN,
):
"""Join FeatureGroup with base.

Args:
feature_group (FeatureGroup): A FeatureGroup which will be joined to base.
feature_group (FeatureGroup): A target FeatureGroup which will be joined to base.
target_feature_name_in_base (str): A string representing the feature name in base which
will be used as target join key (default: None).
will be used as a join key (default: None).
included_feature_names (List[str]): A list of strings representing features to be
included in the output (default: None).
Returns:
This DatasetBuilder object.
feature_name_in_target (str): A string representing the feature name in the target
feature group that will be compared to the target feature in the base feature group.
If None is provided, the record identifier feature will be used in the
SQL join. (default: None).
join_comparator (JoinComparatorEnum): A JoinComparatorEnum representing the comparator
used when joining the target feature in the base feature group and the feature
in the target feature group. (default: JoinComparatorEnum.EQUALS).
join_type (JoinTypeEnum): A JoinTypeEnum representing the type of join between
the base and target feature groups. (default: JoinTypeEnum.INNER_JOIN).
Returns:
This DatasetBuilder object.
"""
self._feature_groups_to_be_merged.append(
construct_feature_group_to_be_merged(
feature_group, included_feature_names, target_feature_name_in_base
feature_group,
included_feature_names,
target_feature_name_in_base,
feature_name_in_target,
join_comparator,
join_type,
)
)
return self
Expand Down Expand Up @@ -905,10 +984,18 @@ def _construct_join_condition(self, feature_group: FeatureGroupToBeMerged, suffi
Returns:
The JOIN query string.
"""

feature_name_in_target = (
feature_group.feature_name_in_target
if feature_group.feature_name_in_target is not None
else feature_group.record_identifier_feature_name
)

join_condition_string = (
f"\nJOIN fg_{suffix}\n"
+ f'ON fg_base."{feature_group.target_feature_name_in_base}" = '
+ f'fg_{suffix}."{feature_group.record_identifier_feature_name}"'
f"\n{feature_group.join_type.value} fg_{suffix}\n"
+ f'ON fg_base."{feature_group.target_feature_name_in_base}"'
+ f" {feature_group.join_comparator.value} "
+ f'fg_{suffix}."{feature_name_in_target}"'
)
base_timestamp_cast_function_name = "from_unixtime"
if self._event_time_identifier_feature_type == FeatureTypeEnum.STRING:
Expand Down
Loading