Skip to content

Commit 41cde1e

Browse files
use enum types for join type and comparator params
1 parent 282366f commit 41cde1e

File tree

1 file changed

+26
-37
lines changed

1 file changed

+26
-37
lines changed

src/sagemaker/feature_store/dataset_builder.py

Lines changed: 26 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ class FeatureGroupToBeMerged:
8383
Attributes:
8484
features (List[str]): A list of strings representing feature names of this FeatureGroup.
8585
included_feature_names (List[str]): A list of strings representing features to be
86-
included in the sql join.
86+
included in the SQL join.
8787
projected_feature_names (List[str]): A list of strings representing features to be
8888
included for final projection in output.
8989
catalog (str): A string representing the catalog.
@@ -96,15 +96,15 @@ class FeatureGroupToBeMerged:
9696
be used as target join key (default: None).
9797
table_type (TableType): A TableType representing the type of table if it is Feature Group or
9898
pandas DataFrame (default: None).
99-
feature_name_in_target (str): A string representing the feature in the target feature
99+
feature_name_in_target (str): A string representing the feature name in the target feature
100100
group that will be compared to the target feature in the base feature group.
101-
If None is provided, the record identifier will be used in the
102-
join statement. (default: None).
101+
If None is provided, the record identifier feature will be used in the
102+
SQL join. (default: None).
103103
join_comparator (JoinComparatorEnum): A JoinComparatorEnum representing the comparator
104104
used when joining the target feature in the base feature group and the feature
105-
in the target feature group (default: None).
105+
in the target feature group. (default: JoinComparatorEnum.EQUALS).
106106
join_type (JoinTypeEnum): A JoinTypeEnum representing the type of join between
107-
the base and target feature groups. (default: None).
107+
the base and target feature groups. (default: JoinTypeEnum.INNER_JOIN).
108108
"""
109109

110110
features: List[str] = attr.ib()
@@ -118,17 +118,17 @@ class FeatureGroupToBeMerged:
118118
target_feature_name_in_base: str = attr.ib(default=None)
119119
table_type: TableType = attr.ib(default=None)
120120
feature_name_in_target: str = attr.ib(default=None)
121-
join_comparator: JoinComparatorEnum = attr.ib(default=None)
122-
join_type: JoinTypeEnum = attr.ib(default=None)
121+
join_comparator: JoinComparatorEnum = attr.ib(default=JoinComparatorEnum.EQUALS)
122+
join_type: JoinTypeEnum = attr.ib(default=JoinTypeEnum.INNER_JOIN)
123123

124124

125125
def construct_feature_group_to_be_merged(
126126
target_feature_group: FeatureGroup,
127127
included_feature_names: List[str],
128128
target_feature_name_in_base: str = None,
129129
feature_name_in_target: str = None,
130-
join_comparator: JoinComparatorEnum = None,
131-
join_type: JoinTypeEnum = None,
130+
join_comparator: JoinComparatorEnum = JoinComparatorEnum.EQUALS,
131+
join_type: JoinTypeEnum = JoinTypeEnum.INNER_JOIN,
132132
) -> FeatureGroupToBeMerged:
133133
"""Construct a FeatureGroupToBeMerged object by provided parameters.
134134
@@ -138,15 +138,15 @@ def construct_feature_group_to_be_merged(
138138
included in the output.
139139
target_feature_name_in_base (str): A string representing the feature name in base which
140140
will be used as target join key (default: None).
141-
feature_name_in_target (str): A string representing the feature in the target feature
141+
feature_name_in_target (str): A string representing the feature name in the target feature
142142
group that will be compared to the target feature in the base feature group.
143-
If None is provided, the record identifier will be used in the
144-
join statement. (default: None).
143+
If None is provided, the record identifier feature will be used in the
144+
SQL join. (default: None).
145145
join_comparator (JoinComparatorEnum): A JoinComparatorEnum representing the comparator
146146
used when joining the target feature in the base feature group and the feature
147-
in the target feature group (default: None).
147+
in the target feature group. (default: JoinComparatorEnum.EQUALS).
148148
join_type (JoinTypeEnum): A JoinTypeEnum representing the type of join between
149-
the base and target feature groups. (default: None).
149+
the base and target feature groups. (default: JoinTypeEnum.INNER_JOIN).
150150
Returns:
151151
A FeatureGroupToBeMerged object.
152152
@@ -290,8 +290,8 @@ def with_feature_group(
290290
target_feature_name_in_base: str = None,
291291
included_feature_names: List[str] = None,
292292
feature_name_in_target: str = None,
293-
join_comparator: JoinComparatorEnum = None,
294-
join_type: JoinTypeEnum = None,
293+
join_comparator: JoinComparatorEnum = JoinComparatorEnum.EQUALS,
294+
join_type: JoinTypeEnum = JoinTypeEnum.INNER_JOIN,
295295
):
296296
"""Join FeatureGroup with base.
297297
@@ -301,15 +301,15 @@ def with_feature_group(
301301
will be used as a join key (default: None).
302302
included_feature_names (List[str]): A list of strings representing features to be
303303
included in the output (default: None).
304-
feature_name_in_target (str): A string representing the feature in the target feature
305-
group that will be compared to the target feature in the base feature group.
306-
If None is provided, the record identifier will be used in the
307-
join statement. (default: None).
304+
feature_name_in_target (str): A string representing the feature name in the target
305+
feature group that will be compared to the target feature in the base feature group.
306+
If None is provided, the record identifier feature will be used in the
307+
SQL join. (default: None).
308308
join_comparator (JoinComparatorEnum): A JoinComparatorEnum representing the comparator
309309
used when joining the target feature in the base feature group and the feature
310-
in the target feature group (default: None).
310+
in the target feature group. (default: JoinComparatorEnum.EQUALS).
311311
join_type (JoinTypeEnum): A JoinTypeEnum representing the type of join between
312-
the base and target feature groups. (default: None).
312+
the base and target feature groups. (default: JoinTypeEnum.INNER_JOIN).
313313
Returns:
314314
This DatasetBuilder object.
315315
"""
@@ -985,27 +985,16 @@ def _construct_join_condition(self, feature_group: FeatureGroupToBeMerged, suffi
985985
The JOIN query string.
986986
"""
987987

988-
join_type = (
989-
feature_group.join_type
990-
if feature_group.join_type is not None
991-
else JoinTypeEnum.INNER_JOIN
992-
)
993-
994-
join_comparator = (
995-
feature_group.join_comparator
996-
if feature_group.join_comparator is not None
997-
else JoinComparatorEnum.EQUALS
998-
)
999-
1000988
feature_name_in_target = (
1001989
feature_group.feature_name_in_target
1002990
if feature_group.feature_name_in_target is not None
1003991
else feature_group.record_identifier_feature_name
1004992
)
1005993

1006994
join_condition_string = (
1007-
f"\n{join_type.value} fg_{suffix}\n"
1008-
+ f'ON fg_base."{feature_group.target_feature_name_in_base}" {join_comparator.value} '
995+
f"\n{feature_group.join_type.value} fg_{suffix}\n"
996+
+ f'ON fg_base."{feature_group.target_feature_name_in_base}"'
997+
+ f" {feature_group.join_comparator.value} "
1009998
+ f'fg_{suffix}."{feature_name_in_target}"'
1010999
)
10111000
base_timestamp_cast_function_name = "from_unixtime"

0 commit comments

Comments
 (0)