Skip to content

Commit 42a7452

Browse files
committed
feature: Add support of collection types in feature store. (#1106)
1 parent dd89c3c commit 42a7452

File tree

7 files changed

+341
-17
lines changed

7 files changed

+341
-17
lines changed

doc/api/prep_data/feature_store.rst

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,25 @@ Feature Definition
4141
:members:
4242
:show-inheritance:
4343

44+
.. autoclass:: sagemaker.feature_store.feature_definition.CollectionTypeEnum
45+
:members:
46+
:show-inheritance:
47+
48+
.. autoclass:: sagemaker.feature_store.feature_definition.CollectionType
49+
:members:
50+
:show-inheritance:
51+
52+
.. autoclass:: sagemaker.feature_store.feature_definition.ListCollectionType
53+
:members:
54+
:show-inheritance:
55+
56+
.. autoclass:: sagemaker.feature_store.feature_definition.SetCollectionType
57+
:members:
58+
:show-inheritance:
59+
60+
.. autoclass:: sagemaker.feature_store.feature_definition.VectorCollectionType
61+
:members:
62+
:show-inheritance:
4463

4564
Inputs
4665
******

src/sagemaker/feature_store/feature_definition.py

Lines changed: 107 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,87 @@ class FeatureTypeEnum(Enum):
3939
STRING = "String"
4040

4141

42+
class CollectionTypeEnum(Enum):
    """Enum of collection types.

    The collection type of a feature can be List, Set or Vector. Each member's
    value is the string form of that collection type.
    """

    LIST = "List"
    SET = "Set"
    VECTOR = "Vector"
51+
52+
53+
@attr.s
class CollectionType(Config):
    """Collection type and its configuration.

    This instantiates a CollectionType object, where CollectionType is a subclass of Config.

    Attributes:
        collection_type (CollectionTypeEnum): The type of the collection.
        collection_config (Dict[str, Any]): The configuration for the collection.
            The List and Set subclasses in this module pass ``None`` here.
    """

    collection_type: CollectionTypeEnum = attr.ib()
    collection_config: Dict[str, Any] = attr.ib()

    def to_dict(self) -> Dict[str, Any]:
        """Construct a dictionary based on each attribute.

        Returns:
            Dict[str, Any]: the result of ``Config.construct_dict`` called with
            ``CollectionType`` (the enum member's string value) and
            ``CollectionConfig``.
        """
        return Config.construct_dict(
            CollectionType=self.collection_type.value,
            CollectionConfig=self.collection_config,
        )
72+
73+
74+
class ListCollectionType(CollectionType):
    """List collection type.

    This class instantiates a ListCollectionType object, as a subclass of CollectionType
    where the collection type is defined as List.
    """

    def __init__(self):
        """Construct an instance of ListCollectionType."""
        # No configuration is supplied for a List collection, so collection_config is None.
        super(ListCollectionType, self).__init__(CollectionTypeEnum.LIST, None)
85+
86+
87+
class SetCollectionType(CollectionType):
    """Set collection type.

    This class instantiates a SetCollectionType object, as a subclass of CollectionType
    where the collection type is defined as Set.
    """

    def __init__(self):
        """Construct an instance of SetCollectionType."""
        # No configuration is supplied for a Set collection, so collection_config is None.
        super(SetCollectionType, self).__init__(CollectionTypeEnum.SET, None)
98+
99+
100+
class VectorCollectionType(CollectionType):
    """Vector collection type.

    This class instantiates a VectorCollectionType object, as a subclass of CollectionType
    where the collection type is defined as Vector.

    The dimension passed to the constructor is not stored as its own attribute;
    it is kept inside ``collection_config`` under ``["VectorConfig"]["Dimension"]``.
    """

    def __init__(self, dimension: int):
        """Construct an instance of VectorCollectionType.

        Args:
            dimension (int): The dimension size for the Vector.
        """
        # Nested layout: {"VectorConfig": {"Dimension": <dimension>}}
        collection_config: Dict[str, Any] = {"VectorConfig": {"Dimension": dimension}}
        super(VectorCollectionType, self).__init__(CollectionTypeEnum.VECTOR, collection_config)
121+
122+
42123
@attr.s
43124
class FeatureDefinition(Config):
44125
"""Feature definition.
@@ -48,15 +129,25 @@ class FeatureDefinition(Config):
48129
Attributes:
49130
feature_name (str): The name of the feature
50131
feature_type (FeatureTypeEnum): The type of the feature
132+
collection_type (CollectionType): The type of collection for the feature
51133
"""
52134

53135
feature_name: str = attr.ib()
54136
feature_type: FeatureTypeEnum = attr.ib()
137+
collection_type: CollectionType = attr.ib(default=None)
55138

56139
def to_dict(self) -> Dict[str, Any]:
57140
"""Construct a dictionary based on each attribute."""
141+
58142
return Config.construct_dict(
59-
FeatureName=self.feature_name, FeatureType=self.feature_type.value
143+
FeatureName=self.feature_name,
144+
FeatureType=self.feature_type.value,
145+
CollectionType=(
146+
self.collection_type.collection_type.value if self.collection_type else None
147+
),
148+
CollectionConfig=(
149+
self.collection_type.collection_config if self.collection_type else None
150+
),
60151
)
61152

62153

@@ -69,15 +160,18 @@ class FractionalFeatureDefinition(FeatureDefinition):
69160
Attributes:
70161
feature_name (str): The name of the feature
71162
feature_type (FeatureTypeEnum): A `FeatureTypeEnum.FRACTIONAL` type
163+
collection_type (CollectionType): The type of collection for the feature
72164
"""
73165

74-
def __init__(self, feature_name: str):
166+
def __init__(self, feature_name: str, collection_type: CollectionType = None):
75167
"""Construct an instance of FractionalFeatureDefinition.
76168
77169
Args:
78170
feature_name (str): the name of the feature.
79171
"""
80-
super(FractionalFeatureDefinition, self).__init__(feature_name, FeatureTypeEnum.FRACTIONAL)
172+
super(FractionalFeatureDefinition, self).__init__(
173+
feature_name, FeatureTypeEnum.FRACTIONAL, collection_type
174+
)
81175

82176

83177
class IntegralFeatureDefinition(FeatureDefinition):
@@ -89,15 +183,18 @@ class IntegralFeatureDefinition(FeatureDefinition):
89183
Attributes:
90184
feature_name (str): the name of the feature.
91185
feature_type (FeatureTypeEnum): a `FeatureTypeEnum.INTEGRAL` type.
186+
collection_type (CollectionType): The type of collection for the feature.
92187
"""
93188

94-
def __init__(self, feature_name: str):
189+
def __init__(self, feature_name: str, collection_type: CollectionType = None):
95190
"""Construct an instance of IntegralFeatureDefinition.
96191
97192
Args:
98193
feature_name (str): the name of the feature.
99194
"""
100-
super(IntegralFeatureDefinition, self).__init__(feature_name, FeatureTypeEnum.INTEGRAL)
195+
super(IntegralFeatureDefinition, self).__init__(
196+
feature_name, FeatureTypeEnum.INTEGRAL, collection_type
197+
)
101198

102199

103200
class StringFeatureDefinition(FeatureDefinition):
@@ -109,12 +206,15 @@ class StringFeatureDefinition(FeatureDefinition):
109206
Attributes:
110207
feature_name (str): the name of the feature.
111208
feature_type (FeatureTypeEnum): a `FeatureTypeEnum.STRING` type.
209+
collection_type (CollectionType): The type of collection for the feature.
112210
"""
113211

114-
def __init__(self, feature_name: str):
212+
def __init__(self, feature_name: str, collection_type: CollectionType = None):
115213
"""Construct an instance of StringFeatureDefinition.
116214
117215
Args:
118216
feature_name (str): the name of the feature.
119217
"""
120-
super(StringFeatureDefinition, self).__init__(feature_name, FeatureTypeEnum.STRING)
218+
super(StringFeatureDefinition, self).__init__(
219+
feature_name, FeatureTypeEnum.STRING, collection_type
220+
)

src/sagemaker/feature_store/inputs.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,10 +266,13 @@ class FeatureValue(Config):
266266
Attributes:
267267
feature_name (str): name of the Feature.
268268
value_as_string (str): value of the Feature in string form.
269+
value_as_string_list (List[str]): value of the Feature in string list
270+
form used for collection type.
269271
"""
270272

271273
feature_name: str = attr.ib(default=None)
272274
value_as_string: str = attr.ib(default=None)
275+
value_as_string_list: List[str] = attr.ib(default=None)
273276

274277
def to_dict(self) -> Dict[str, Any]:
275278
"""Construct a dictionary based on the attributes provided.
@@ -280,6 +283,7 @@ def to_dict(self) -> Dict[str, Any]:
280283
return Config.construct_dict(
281284
FeatureName=self.feature_name,
282285
ValueAsString=self.value_as_string,
286+
ValueAsStringList=self.value_as_string_list,
283287
)
284288

285289

tests/integ/test_feature_store.py

Lines changed: 85 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,11 @@
2525
from pandas import DataFrame
2626

2727
from sagemaker.feature_store.feature_utils import get_feature_group_as_dataframe
28-
from sagemaker.feature_store.feature_definition import FractionalFeatureDefinition
28+
from sagemaker.feature_store.feature_definition import (
29+
FractionalFeatureDefinition,
30+
StringFeatureDefinition,
31+
ListCollectionType,
32+
)
2933
from sagemaker.feature_store.feature_group import FeatureGroup
3034
from sagemaker.feature_store.feature_store import FeatureStore
3135
from sagemaker.feature_store.inputs import (
@@ -186,6 +190,16 @@ def record():
186190
]
187191

188192

193+
@pytest.fixture
def collection_type_record():
    """Return a record with three single-string features and one string-list feature.

    ``feature1``..``feature3`` carry ``value_as_string``; ``feature4`` carries
    ``value_as_string_list`` and is the collection-type entry.
    """
    return [
        FeatureValue(feature_name="feature1", value_as_string="10.0"),
        FeatureValue(feature_name="feature2", value_as_string="10"),
        FeatureValue(feature_name="feature3", value_as_string="2020-10-30T03:43:21Z"),
        FeatureValue(feature_name="feature4", value_as_string_list=["val1", "val2"]),
    ]
201+
202+
189203
@pytest.fixture
190204
def create_table_ddl():
191205
return (
@@ -664,6 +678,60 @@ def test_get_and_batch_get_record(
664678
assert feature["FeatureName"] is not removed_feature_name
665679

666680

681+
def test_put_and_get_collection_type_record(
    feature_store_session,
    role,
    feature_group_name,
    pandas_data_frame,
    collection_type_record,
):
    """Put a record containing a List-collection feature and read it back.

    Creates an online-only, in-memory feature group whose ``feature4`` is a
    String feature with a List collection type, ingests ``collection_type_record``,
    fetches it by its record identifier and checks every value round-trips.

    Note: ``pandas_data_frame`` is requested as a fixture but is not referenced
    in the body; it is kept so the test's signature is unchanged.
    """
    feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session)
    feature_definition_with_collection = [
        StringFeatureDefinition(feature_name="feature1"),
        StringFeatureDefinition(feature_name="feature2"),
        StringFeatureDefinition(feature_name="feature3"),
        StringFeatureDefinition(feature_name="feature4", collection_type=ListCollectionType()),
    ]
    feature_group.feature_definitions = feature_definition_with_collection
    # feature1 is the record identifier, so its value is the lookup key for get_record.
    record_identifier_value_as_string = collection_type_record[0].value_as_string
    with cleanup_feature_group(feature_group):
        feature_group.create(
            s3_uri=False,
            record_identifier_name="feature1",
            event_time_feature_name="feature3",
            role_arn=role,
            enable_online_store=True,
            online_store_storage_type=OnlineStoreStorageTypeEnum.IN_MEMORY,
        )
        _wait_for_feature_group_create(feature_group)
        # Ingest data
        feature_group.put_record(record=collection_type_record)
        # Retrieve data
        retrieved_record = feature_group.get_record(
            record_identifier_value_as_string=record_identifier_value_as_string,
        )

        assert retrieved_record is not None
        record_names = [r.feature_name for r in collection_type_record]
        assert len(retrieved_record) == len(record_names)

        # Map each returned feature name to whichever value form it carries:
        # the single string when present, otherwise the string list.
        retrieved_feature_map = {}
        for feature in retrieved_record:
            assert feature["FeatureName"] in record_names
            retrieved_feature_map[feature["FeatureName"]] = (
                feature.get("ValueAsStringList")
                if feature.get("ValueAsString") is None
                else feature.get("ValueAsString")
            )

        assert collection_type_record[0].value_as_string == retrieved_feature_map.get("feature1")
        assert collection_type_record[1].value_as_string == retrieved_feature_map.get("feature2")
        assert collection_type_record[2].value_as_string == retrieved_feature_map.get("feature3")
        assert collection_type_record[3].value_as_string_list == retrieved_feature_map.get(
            "feature4"
        )
733+
734+
667735
def test_soft_delete_record(
668736
feature_store_session,
669737
role,
@@ -949,7 +1017,12 @@ def test_create_dataset_with_feature_group_base(
9491017
base, base_dataframe, offline_store_s3_uri, "base_id", "base_time", role
9501018
)
9511019
_create_feature_group_and_ingest_data(
952-
feature_group, feature_group_dataframe, offline_store_s3_uri, "fg_id", "fg_time", role
1020+
feature_group,
1021+
feature_group_dataframe,
1022+
offline_store_s3_uri,
1023+
"fg_id",
1024+
"fg_time",
1025+
role,
9531026
)
9541027
base_table_name = _get_athena_table_name_after_data_replication(
9551028
feature_store_session, base, offline_store_s3_uri
@@ -1131,7 +1204,12 @@ def test_create_dataset_with_feature_group_base_with_additional_params(
11311204
base, base_dataframe, offline_store_s3_uri, "base_id", "base_time", role
11321205
)
11331206
_create_feature_group_and_ingest_data(
1134-
feature_group, feature_group_dataframe, offline_store_s3_uri, "fg_id", "fg_time", role
1207+
feature_group,
1208+
feature_group_dataframe,
1209+
offline_store_s3_uri,
1210+
"fg_id",
1211+
"fg_time",
1212+
role,
11351213
)
11361214
base_table_name = _get_athena_table_name_after_data_replication(
11371215
feature_store_session, base, offline_store_s3_uri
@@ -1157,7 +1235,10 @@ def test_create_dataset_with_feature_group_base_with_additional_params(
11571235
)
11581236
sorted_df = df.sort_values(by=list(df.columns)).reset_index(drop=True)
11591237
merged_df = base_dataframe.merge(
1160-
feature_group_dataframe, left_on="base_time", right_on="fg_time", how="outer"
1238+
feature_group_dataframe,
1239+
left_on="base_time",
1240+
right_on="fg_time",
1241+
how="outer",
11611242
)
11621243

11631244
expect_df = merged_df.sort_values(by=list(merged_df.columns)).reset_index(drop=True)

0 commit comments

Comments
 (0)