Skip to content

Commit 67b94e4

Browse files
use enum types for join type and comparator params
1 parent 09d161f commit 67b94e4

File tree

1 file changed

+26
-37
lines changed

1 file changed

+26
-37
lines changed

src/sagemaker/feature_store/dataset_builder.py

Lines changed: 26 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ class FeatureGroupToBeMerged:
8383
Attributes:
8484
features (List[str]): A list of strings representing feature names of this FeatureGroup.
8585
included_feature_names (List[str]): A list of strings representing features to be
86-
included in the sql join.
86+
included in the SQL join.
8787
projected_feature_names (List[str]): A list of strings representing features to be
8888
included for final projection in output.
8989
catalog (str): A string representing the catalog.
@@ -96,15 +96,15 @@ class FeatureGroupToBeMerged:
9696
be used as target join key (default: None).
9797
table_type (TableType): A TableType representing the type of table if it is Feature Group or
9898
Panda Data Frame (default: None).
99-
feature_name_in_target (str): A string representing the feature in the target feature
99+
feature_name_in_target (str): A string representing the feature name in the target feature
100100
group that will be compared to the target feature in the base feature group.
101-
If None is provided, the record identifier will be used in the
102-
join statement. (default: None).
101+
If None is provided, the record identifier feature will be used in the
102+
SQL join. (default: None).
103103
join_comparator (JoinComparatorEnum): A JoinComparatorEnum representing the comparator
104104
used when joining the target feature in the base feature group and the feature
105-
in the target feature group (default: None).
105+
in the target feature group. (default: JoinComparatorEnum.EQUALS).
106106
join_type (JoinTypeEnum): A JoinTypeEnum representing the type of join between
107-
the base and target feature groups. (default: None).
107+
the base and target feature groups. (default: JoinTypeEnum.INNER_JOIN).
108108
"""
109109

110110
features: List[str] = attr.ib()
@@ -118,17 +118,17 @@ class FeatureGroupToBeMerged:
118118
target_feature_name_in_base: str = attr.ib(default=None)
119119
table_type: TableType = attr.ib(default=None)
120120
feature_name_in_target: str = attr.ib(default=None)
121-
join_comparator: JoinComparatorEnum = attr.ib(default=None)
122-
join_type: JoinTypeEnum = attr.ib(default=None)
121+
join_comparator: JoinComparatorEnum = attr.ib(default=JoinComparatorEnum.EQUALS)
122+
join_type: JoinTypeEnum = attr.ib(default=JoinTypeEnum.INNER_JOIN)
123123

124124

125125
def construct_feature_group_to_be_merged(
126126
target_feature_group: FeatureGroup,
127127
included_feature_names: List[str],
128128
target_feature_name_in_base: str = None,
129129
feature_name_in_target: str = None,
130-
join_comparator: JoinComparatorEnum = None,
131-
join_type: JoinTypeEnum = None,
130+
join_comparator: JoinComparatorEnum = JoinComparatorEnum.EQUALS,
131+
join_type: JoinTypeEnum = JoinTypeEnum.INNER_JOIN,
132132
) -> FeatureGroupToBeMerged:
133133
"""Construct a FeatureGroupToBeMerged object by provided parameters.
134134
@@ -138,15 +138,15 @@ def construct_feature_group_to_be_merged(
138138
included in the output.
139139
target_feature_name_in_base (str): A string representing the feature name in base which
140140
will be used as target join key (default: None).
141-
feature_name_in_target (str): A string representing the feature in the target feature
141+
feature_name_in_target (str): A string representing the feature name in the target feature
142142
group that will be compared to the target feature in the base feature group.
143-
If None is provided, the record identifier will be used in the
144-
join statement. (default: None).
143+
If None is provided, the record identifier feature will be used in the
144+
SQL join. (default: None).
145145
join_comparator (JoinComparatorEnum): A JoinComparatorEnum representing the comparator
146146
used when joining the target feature in the base feature group and the feature
147-
in the target feature group (default: None).
147+
in the target feature group. (default: JoinComparatorEnum.EQUALS).
148148
join_type (JoinTypeEnum): A JoinTypeEnum representing the type of join between
149-
the base and target feature groups. (default: None).
149+
the base and target feature groups. (default: JoinTypeEnum.INNER_JOIN).
150150
Returns:
151151
A FeatureGroupToBeMerged object.
152152
@@ -299,8 +299,8 @@ def with_feature_group(
299299
target_feature_name_in_base: str = None,
300300
included_feature_names: List[str] = None,
301301
feature_name_in_target: str = None,
302-
join_comparator: JoinComparatorEnum = None,
303-
join_type: JoinTypeEnum = None,
302+
join_comparator: JoinComparatorEnum = JoinComparatorEnum.EQUALS,
303+
join_type: JoinTypeEnum = JoinTypeEnum.INNER_JOIN,
304304
):
305305
"""Join FeatureGroup with base.
306306
@@ -310,15 +310,15 @@ def with_feature_group(
310310
will be used as a join key (default: None).
311311
included_feature_names (List[str]): A list of strings representing features to be
312312
included in the output (default: None).
313-
feature_name_in_target (str): A string representing the feature in the target feature
314-
group that will be compared to the target feature in the base feature group.
315-
If None is provided, the record identifier will be used in the
316-
join statement. (default: None).
313+
feature_name_in_target (str): A string representing the feature name in the target
314+
feature group that will be compared to the target feature in the base feature group.
315+
If None is provided, the record identifier feature will be used in the
316+
SQL join. (default: None).
317317
join_comparator (JoinComparatorEnum): A JoinComparatorEnum representing the comparator
318318
used when joining the target feature in the base feature group and the feature
319-
in the target feature group (default: None).
319+
in the target feature group. (default: JoinComparatorEnum.EQUALS).
320320
join_type (JoinTypeEnum): A JoinTypeEnum representing the type of join between
321-
the base and target feature groups. (default: None).
321+
the base and target feature groups. (default: JoinTypeEnum.INNER_JOIN).
322322
Returns:
323323
This DatasetBuilder object.
324324
"""
@@ -994,27 +994,16 @@ def _construct_join_condition(self, feature_group: FeatureGroupToBeMerged, suffi
994994
The JOIN query string.
995995
"""
996996

997-
join_type = (
998-
feature_group.join_type
999-
if feature_group.join_type is not None
1000-
else JoinTypeEnum.INNER_JOIN
1001-
)
1002-
1003-
join_comparator = (
1004-
feature_group.join_comparator
1005-
if feature_group.join_comparator is not None
1006-
else JoinComparatorEnum.EQUALS
1007-
)
1008-
1009997
feature_name_in_target = (
1010998
feature_group.feature_name_in_target
1011999
if feature_group.feature_name_in_target is not None
10121000
else feature_group.record_identifier_feature_name
10131001
)
10141002

10151003
join_condition_string = (
1016-
f"\n{join_type.value} fg_{suffix}\n"
1017-
+ f'ON fg_base."{feature_group.target_feature_name_in_base}" {join_comparator.value} '
1004+
f"\n{feature_group.join_type.value} fg_{suffix}\n"
1005+
+ f'ON fg_base."{feature_group.target_feature_name_in_base}"'
1006+
+ f" {feature_group.join_comparator.value} "
10181007
+ f'fg_{suffix}."{feature_name_in_target}"'
10191008
)
10201009
base_timestamp_cast_function_name = "from_unixtime"

0 commit comments

Comments
 (0)