Skip to content

Commit 5067bdd

Browse files
committed
Add batch_get_record and search API for FeatureStore
1 parent 37c0d3a commit 5067bdd

File tree

7 files changed

+498
-3
lines changed

7 files changed

+498
-3
lines changed

src/sagemaker/feature_store/feature_store.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,13 @@
2626
from sagemaker import Session
2727
from sagemaker.feature_store.dataset_builder import DatasetBuilder
2828
from sagemaker.feature_store.feature_group import FeatureGroup
29+
from sagemaker.feature_store.inputs import (
30+
Filter,
31+
ResourceEnum,
32+
SearchOperatorEnum,
33+
SortOrderEnum,
34+
Identifier,
35+
)
2936

3037

3138
@attr.s
@@ -114,6 +121,7 @@ def list_feature_groups(
114121
sort_by (str): The value on which the FeatureGroup list is sorted.
115122
max_results (int): The maximum number of results returned by ListFeatureGroups.
116123
next_token (str): A token to resume pagination of ListFeatureGroups results.
124+
117125
Returns:
118126
Response dict from service.
119127
"""
@@ -128,3 +136,61 @@ def list_feature_groups(
128136
max_results=max_results,
129137
next_token=next_token,
130138
)
139+
140+
def batch_get_record(self, identifiers: Sequence[Identifier]) -> Dict[str, Any]:
    """Fetch several records from FeatureStore in one request.

    Args:
        identifiers (Sequence[Identifier]): Identifiers that uniquely identify
            the records to fetch from FeatureStore.

    Returns:
        Response dict from service.
    """
    # The session layer is handed plain dicts, so serialize every Identifier first.
    serialized_identifiers = []
    for entry in identifiers:
        serialized_identifiers.append(entry.to_dict())

    session = self.sagemaker_session
    return session.batch_get_record(identifiers=serialized_identifiers)
152+
153+
def search(
    self,
    resource: ResourceEnum,
    filters: Sequence[Filter] = None,
    operator: SearchOperatorEnum = None,
    sort_by: str = None,
    sort_order: SortOrderEnum = None,
    next_token: str = None,
    max_results: int = None,
) -> Dict[str, Any]:
    """Search for FeatureGroups or FeatureMetadata that satisfy the given filters.

    Args:
        resource (ResourceEnum): The name of the Amazon SageMaker resource to search for.
            Valid values are ``FeatureGroup`` or ``FeatureMetadata``.
        filters (Sequence[Filter]): A list of filter objects (Default: None).
        operator (SearchOperatorEnum): A Boolean operator used to evaluate the filters.
            Valid values are ``And`` or ``Or``. The default is ``And`` (Default: None).
        sort_by (str): The name of the resource property used to sort the ``SearchResults``.
            The default is ``LastModifiedTime``.
        sort_order (SortOrderEnum): How ``SearchResults`` are ordered.
            Valid values are ``Ascending`` or ``Descending``. The default is ``Descending``.
        next_token (str): If more than ``MaxResults`` resources match the specified
            filters, the response includes a ``NextToken``. The ``NextToken`` can be passed to
            the next ``SearchRequest`` to continue retrieving results (Default: None).
        max_results (int): The maximum number of results to return (Default: None).

    Returns:
        Response dict from service.
    """
    # Only the pieces the caller supplied end up in the expression; it may stay empty.
    expression = {}
    if filters:
        serialized_filters = []
        for search_filter in filters:
            serialized_filters.append(search_filter.to_dict())
        expression["Filters"] = serialized_filters
    if operator:
        expression["Operator"] = str(operator)

    resource_name = str(resource)
    # An unset sort order is forwarded as None rather than the string "None".
    if sort_order:
        sort_order_name = str(sort_order)
    else:
        sort_order_name = None

    return self.sagemaker_session.search(
        resource=resource_name,
        search_expression=expression,
        sort_by=sort_by,
        sort_order=sort_order_name,
        next_token=next_token,
        max_results=max_results,
    )

src/sagemaker/feature_store/inputs.py

Lines changed: 125 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from __future__ import absolute_import
3030

3131
import abc
32-
from typing import Dict, Any
32+
from typing import Dict, Any, List
3333
from enum import Enum
3434

3535
import attr
@@ -245,3 +245,127 @@ def to_dict(self) -> Dict[str, Any]:
245245
Key=self.key,
246246
Value=self.value,
247247
)
248+
249+
250+
class ResourceEnum(Enum):
    """Resource kinds accepted by FeatureStore search.

    A member's value is either ``FeatureGroup`` or ``FeatureMetadata``.
    """

    FEATURE_GROUP = "FeatureGroup"
    FEATURE_METADATA = "FeatureMetadata"

    def __str__(self):
        """Return the member's value instead of the default ``Class.MEMBER`` form."""
        member_value = self.value
        return str(member_value)
262+
263+
264+
class SearchOperatorEnum(Enum):
    """Boolean operators used to combine search filters.

    A member's value is either ``And`` or ``Or``.
    """

    AND = "And"
    OR = "Or"

    def __str__(self):
        """Return the member's value instead of the default ``Class.MEMBER`` form."""
        member_value = self.value
        return str(member_value)
276+
277+
278+
class SortOrderEnum(Enum):
    """Orderings available for search results.

    A member's value is either ``Ascending`` or ``Descending``.
    """

    ASCENDING = "Ascending"
    DESCENDING = "Descending"

    def __str__(self):
        """Return the member's value instead of the default ``Class.MEMBER`` form."""
        member_value = self.value
        return str(member_value)
290+
291+
292+
class FilterOperatorEnum(Enum):
    """Comparison operators that a search ``Filter`` can apply.

    A member's value is one of ``Equals``, ``NotEquals``, ``GreaterThan``,
    ``GreaterThanOrEqualTo``, ``LessThan``, ``LessThanOrEqualTo``, ``Contains``,
    ``Exists``, ``NotExists``, or ``In``.
    """

    EQUALS = "Equals"
    NOT_EQUALS = "NotEquals"
    GREATER_THAN = "GreaterThan"
    GREATER_THAN_OR_EQUAL_TO = "GreaterThanOrEqualTo"
    LESS_THAN = "LessThan"
    LESS_THAN_OR_EQUAL_TO = "LessThanOrEqualTo"
    CONTAINS = "Contains"
    EXISTS = "Exists"
    NOT_EXISTS = "NotExists"
    IN = "In"

    def __str__(self):
        """Return the member's value instead of the default ``Class.MEMBER`` form."""
        member_value = self.value
        return str(member_value)
314+
315+
316+
@attr.s
class Filter(Config):
    """A single filter condition for FeatureStore search.

    Attributes:
        name (str): A resource property name.
        value (str): A value used with ``Name`` and ``Operator`` to determine which
            resources satisfy the filter's condition.
        operator (FilterOperatorEnum): A Boolean binary operator that is used to evaluate
            the filter. If ``Value`` is specified without ``Operator``, Amazon SageMaker
            uses ``Equals`` (default: None).
    """

    name: str = attr.ib()
    value: str = attr.ib()
    operator: FilterOperatorEnum = attr.ib(default=None)

    def to_dict(self) -> Dict[str, Any]:
        """Build the dictionary form of this filter.

        Returns:
            dict representing the attributes.
        """
        # An unset operator is passed on as None, not as the string "None".
        if self.operator:
            operator_name = str(self.operator)
        else:
            operator_name = None

        return Config.construct_dict(
            Name=self.name,
            Value=self.value,
            Operator=operator_name,
        )
344+
345+
346+
@attr.s
class Identifier(Config):
    """Identifier used by the batch get record API.

    Attributes:
        feature_group_name (str): name of a feature group.
        record_identifiers_value_as_string (List[str]): string values of the record
            identifiers.
        feature_names (List[str]): list of feature names (default: None).
    """

    feature_group_name: str = attr.ib()
    record_identifiers_value_as_string: List[str] = attr.ib()
    feature_names: List[str] = attr.ib(default=None)

    def to_dict(self) -> Dict[str, Any]:
        """Build the dictionary form of this identifier.

        Returns:
            dict representing the attributes.
        """
        # An empty or missing feature list is normalized to None.
        requested_features = self.feature_names or None

        return Config.construct_dict(
            FeatureGroupName=self.feature_group_name,
            RecordIdentifiersValueAsString=self.record_identifiers_value_as_string,
            FeatureNames=requested_features,
        )

src/sagemaker/session.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4473,6 +4473,48 @@ def describe_feature_metadata(
44734473
FeatureGroupName=feature_group_name, FeatureName=feature_name
44744474
)
44754475

4476+
def search(
    self,
    resource: str,
    search_expression: Dict[str, Any] = None,
    sort_by: str = None,
    sort_order: str = None,
    next_token: str = None,
    max_results: int = None,
) -> Dict[str, Any]:
    """Search for SageMaker resources satisfying given filters.

    Args:
        resource (str): The name of the Amazon SageMaker resource to search for.
        search_expression (Dict[str, Any]): A Boolean conditional statement. Resources must
            satisfy this condition to be included in search results.
        sort_by (str): The name of the resource property used to sort the ``SearchResults``.
            The default is ``LastModifiedTime``.
        sort_order (str): How ``SearchResults`` are ordered.
            Valid values are ``Ascending`` or ``Descending``. The default is ``Descending``.
        next_token (str): If more than ``MaxResults`` resources match the specified
            ``SearchExpression``, the response includes a ``NextToken``. The ``NextToken`` can
            be passed to the next ``SearchRequest`` to continue retrieving results.
        max_results (int): The maximum number of results to return.

    Returns:
        Response dict from service.
    """
    search_args = {"Resource": resource}

    optional_args = {
        "SearchExpression": search_expression,
        "SortBy": sort_by,
        "SortOrder": sort_order,
        "NextToken": next_token,
        "MaxResults": max_results,
    }
    # Unset or empty options are left out of the request entirely
    # (an empty search expression is treated the same as none).
    search_args.update({key: value for key, value in optional_args.items() if value})

    return self.sagemaker_client.search(**search_args)
4517+
44764518
def put_record(
44774519
self,
44784520
feature_group_name: str,
@@ -4532,6 +4574,20 @@ def get_record(
45324574

45334575
return self.sagemaker_featurestore_runtime_client.get_record(**get_record_args)
45344576

4577+
def batch_get_record(self, identifiers: Sequence[Dict[str, Any]]) -> Dict[str, Any]:
    """Gets a batch of records from FeatureStore.

    Args:
        identifiers (Sequence[Dict[str, Any]]): list of identifiers to uniquely identify
            records in FeatureStore.

    Returns:
        Response dict from service.
    """
    runtime_client = self.sagemaker_featurestore_runtime_client
    return runtime_client.batch_get_record(Identifiers=identifiers)
4590+
45354591
def start_query_execution(
45364592
self,
45374593
catalog: str,

tests/integ/test_feature_store.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,14 @@
2626
from sagemaker.feature_store.feature_definition import FractionalFeatureDefinition
2727
from sagemaker.feature_store.feature_group import FeatureGroup
2828
from sagemaker.feature_store.feature_store import FeatureStore
29-
from sagemaker.feature_store.inputs import FeatureValue, FeatureParameter, TableFormatEnum
29+
from sagemaker.feature_store.inputs import (
30+
FeatureValue,
31+
FeatureParameter,
32+
TableFormatEnum,
33+
Filter,
34+
ResourceEnum,
35+
Identifier,
36+
)
3037
from sagemaker.session import get_execution_role, Session
3138
from tests.integ.timeout import timeout
3239

@@ -321,7 +328,7 @@ def test_create_feature_group_glue_table_format(
321328
assert table_format == "Glue"
322329

323330

324-
def test_get_record(
331+
def test_get_and_batch_get_record(
325332
feature_store_session,
326333
role,
327334
feature_group_name,
@@ -367,6 +374,24 @@ def test_get_record(
367374
)
368375
assert retrieved_record is None
369376

377+
# Retrieve data using batch_get_record
feature_store = FeatureStore(sagemaker_session=feature_store_session)
records = feature_store.batch_get_record(
    identifiers=[
        Identifier(
            feature_group_name=feature_group_name,
            record_identifiers_value_as_string=[record_identifier_value_as_string],
            feature_names=record_names,
        )
    ]
)["Records"]
assert records[0]["FeatureGroupName"] == feature_group_name
assert records[0]["RecordIdentifierValueAsString"] == record_identifier_value_as_string
assert len(records[0]["Record"]) == len(record_names)
for feature in records[0]["Record"]:
    assert feature["FeatureName"] in record_names
    # Compare by value: `is not` tests object identity, which can hold for two
    # equal strings and would let a returned removed feature slip through.
    assert feature["FeatureName"] != removed_feature_name
394+
370395

371396
def test_delete_record(
372397
feature_store_session,
@@ -503,6 +528,28 @@ def test_feature_metadata(
503528
assert 1 == len(describe_feature_metadata.get("Parameters"))
504529

505530

531+
def test_search(feature_store_session, role, feature_group_name, pandas_data_frame):
    """Create a feature group, then check that search finds it by name."""
    feature_store = FeatureStore(sagemaker_session=feature_store_session)
    feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session)
    feature_group.load_feature_definitions(data_frame=pandas_data_frame)

    with cleanup_feature_group(feature_group):
        feature_group.create(
            s3_uri=False,
            record_identifier_name="feature1",
            event_time_feature_name="feature3",
            role_arn=role,
            enable_online_store=True,
        )
        _wait_for_feature_group_create(feature_group)
        output = feature_store.search(
            resource=ResourceEnum.FEATURE_GROUP,
            filters=[Filter(name="FeatureGroupName", value=feature_group_name)],
        )
        results = output["Results"]
        # Fail with a readable message instead of an IndexError when nothing matches.
        assert results, "search returned no results for the created feature group"
        assert results[0]["FeatureGroup"]["FeatureGroupName"] == feature_group_name
551+
552+
506553
def test_ingest_without_string_feature(
507554
feature_store_session,
508555
role,

0 commit comments

Comments
 (0)