Skip to content

Add batch_get_record and search API for FeatureStore #3580

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions src/sagemaker/feature_store/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@
from sagemaker import Session
from sagemaker.feature_store.dataset_builder import DatasetBuilder
from sagemaker.feature_store.feature_group import FeatureGroup
from sagemaker.feature_store.inputs import (
Filter,
ResourceEnum,
SearchOperatorEnum,
SortOrderEnum,
Identifier,
)


@attr.s
Expand Down Expand Up @@ -114,6 +121,7 @@ def list_feature_groups(
sort_by (str): The value on which the FeatureGroup list is sorted.
max_results (int): The maximum number of results returned by ListFeatureGroups.
next_token (str): A token to resume pagination of ListFeatureGroups results.

Returns:
Response dict from service.
"""
Expand All @@ -128,3 +136,61 @@ def list_feature_groups(
max_results=max_results,
next_token=next_token,
)

def batch_get_record(self, identifiers: Sequence[Identifier]) -> Dict[str, Any]:
"""Get record in batch from FeatureStore

Args:
identifiers (Sequence[Identifier]): A list of identifiers to uniquely identify records
in FeatureStore.

Returns:
Response dict from service.
"""
batch_get_record_identifiers = [identifier.to_dict() for identifier in identifiers]
return self.sagemaker_session.batch_get_record(identifiers=batch_get_record_identifiers)

def search(
self,
resource: ResourceEnum,
filters: Sequence[Filter] = None,
operator: SearchOperatorEnum = None,
sort_by: str = None,
sort_order: SortOrderEnum = None,
next_token: str = None,
max_results: int = None,
) -> Dict[str, Any]:
"""Search for FeatureGroups or FeatureMetadata satisfying given filters.

Args:
resource (ResourceEnum): The name of the Amazon SageMaker resource to search for.
Valid values are ``FeatureGroup`` or ``FeatureMetadata``.
filters (Sequence[Filter]): A list of filter objects (Default: None).
operator (SearchOperatorEnum): A Boolean operator used to evaluate the filters.
Valid values are ``And`` or ``Or``. The default is ``And`` (Default: None).
sort_by (str): The name of the resource property used to sort the ``SearchResults``.
The default is ``LastModifiedTime``.
sort_order (SortOrderEnum): How ``SearchResults`` are ordered.
Valid values are ``Ascending`` or ``Descending``. The default is ``Descending``.
next_token (str): If more than ``MaxResults`` resources match the specified
filters, the response includes a ``NextToken``. The ``NextToken`` can be passed to
the next ``SearchRequest`` to continue retrieving results (Default: None).
max_results (int): The maximum number of results to return (Default: None).

Returns:
Response dict from service.
"""
search_expression = {}
if filters:
search_expression["Filters"] = [filter.to_dict() for filter in filters]
if operator:
search_expression["Operator"] = str(operator)

return self.sagemaker_session.search(
resource=str(resource),
search_expression=search_expression,
sort_by=sort_by,
sort_order=None if not sort_order else str(sort_order),
next_token=next_token,
max_results=max_results,
)
126 changes: 125 additions & 1 deletion src/sagemaker/feature_store/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from __future__ import absolute_import

import abc
from typing import Dict, Any
from typing import Dict, Any, List
from enum import Enum

import attr
Expand Down Expand Up @@ -245,3 +245,127 @@ def to_dict(self) -> Dict[str, Any]:
Key=self.key,
Value=self.value,
)


class ResourceEnum(Enum):
"""Enum of resources.

The data type of resource can be ``FeatureGroup`` or ``FeatureMetadata``.
"""

def __str__(self):
"""Override str method to return enum value."""
return str(self.value)

FEATURE_GROUP = "FeatureGroup"
FEATURE_METADATA = "FeatureMetadata"


class SearchOperatorEnum(Enum):
"""Enum of search operators.

The data type of search operator can be ``And`` or ``Or``.
"""

def __str__(self):
"""Override str method to return enum value."""
return str(self.value)

AND = "And"
OR = "Or"


class SortOrderEnum(Enum):
"""Enum of sort orders.

The data type of sort order can be ``Ascending`` or ``Descending``.
"""

def __str__(self):
"""Override str method to return enum value."""
return str(self.value)

ASCENDING = "Ascending"
DESCENDING = "Descending"


class FilterOperatorEnum(Enum):
"""Enum of filter operators.

The data type of filter operator can be ``Equals``, ``NotEquals``, ``GreaterThan``,
``GreaterThanOrEqualTo``, ``LessThan``, ``LessThanOrEqualTo``, ``Contains``, ``Exists``,
``NotExists``, or ``In``.
"""

def __str__(self):
"""Override str method to return enum value."""
return str(self.value)

EQUALS = "Equals"
NOT_EQUALS = "NotEquals"
GREATER_THAN = "GreaterThan"
GREATER_THAN_OR_EQUAL_TO = "GreaterThanOrEqualTo"
LESS_THAN = "LessThan"
LESS_THAN_OR_EQUAL_TO = "LessThanOrEqualTo"
CONTAINS = "Contains"
EXISTS = "Exists"
NOT_EXISTS = "NotExists"
IN = "In"


@attr.s
class Filter(Config):
"""Filter for FeatureStore search.

Attributes:
name (str): A resource property name.
value (str): A value used with ``Name`` and ``Operator`` to determine which resources
satisfy the filter's condition.
operator (FilterOperatorEnum): A Boolean binary operator that is used to evaluate the
filter. If specify ``Value`` without ``Operator``, Amazon SageMaker uses ``Equals``
(default: None).
"""

name: str = attr.ib()
value: str = attr.ib()
operator: FilterOperatorEnum = attr.ib(default=None)

def to_dict(self) -> Dict[str, Any]:
"""Construct a dictionary based on the attributes provided.

Returns:
dict represents the attributes.
"""
return Config.construct_dict(
Name=self.name,
Value=self.value,
Operator=None if not self.operator else str(self.operator),
)


@attr.s
class Identifier(Config):
"""Identifier of batch get record API.

Attributes:
feature_group_name (str): name of a feature group.
record_identifiers_value_as_string (List[str]): string value of record identifier.
feature_names (List[str]): list of feature names (default: None).
"""

feature_group_name: str = attr.ib()
record_identifiers_value_as_string: List[str] = attr.ib()
feature_names: List[str] = attr.ib(default=None)

def to_dict(self) -> Dict[str, Any]:
"""Construct a dictionary based on the attributes provided.

Returns:
dict represents the attributes.
"""

return Config.construct_dict(
FeatureGroupName=self.feature_group_name,
RecordIdentifiersValueAsString=self.record_identifiers_value_as_string,
FeatureNames=None if not self.feature_names else self.feature_names,
)
56 changes: 56 additions & 0 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -4476,6 +4476,48 @@ def describe_feature_metadata(
FeatureGroupName=feature_group_name, FeatureName=feature_name
)

def search(
self,
resource: str,
search_expression: Dict[str, any] = None,
sort_by: str = None,
sort_order: str = None,
next_token: str = None,
max_results: int = None,
) -> Dict[str, Any]:
"""Search for SageMaker resources satisfying given filters.

Args:
resource (str): The name of the Amazon SageMaker resource to search for.
search_expression (Dict[str, any]): A Boolean conditional statement. Resources must
satisfy this condition to be included in search results.
sort_by (str): The name of the resource property used to sort the ``SearchResults``.
The default is ``LastModifiedTime``.
sort_order (str): How ``SearchResults`` are ordered.
Valid values are ``Ascending`` or ``Descending``. The default is ``Descending``.
next_token (str): If more than ``MaxResults`` resources match the specified
``SearchExpression``, the response includes a ``NextToken``. The ``NextToken`` can
be passed to the next ``SearchRequest`` to continue retrieving results.
max_results (int): The maximum number of results to return.

Returns:
Response dict from service.
"""
search_args = {"Resource": resource}

if search_expression:
search_args["SearchExpression"] = search_expression
if sort_by:
search_args["SortBy"] = sort_by
if sort_order:
search_args["SortOrder"] = sort_order
if next_token:
search_args["NextToken"] = next_token
if max_results:
search_args["MaxResults"] = max_results

return self.sagemaker_client.search(**search_args)

def put_record(
self,
feature_group_name: str,
Expand Down Expand Up @@ -4535,6 +4577,20 @@ def get_record(

return self.sagemaker_featurestore_runtime_client.get_record(**get_record_args)

def batch_get_record(self, identifiers: Sequence[Dict[str, Any]]) -> Dict[str, Any]:
"""Gets a batch of record from FeatureStore.

Args:
identifiers (Sequence[Dict[str, Any]]): list of identifiers to uniquely identify records
in FeatureStore.

Returns:
Response dict from service.
"""
batch_get_record_args = {"Identifiers": identifiers}

return self.sagemaker_featurestore_runtime_client.batch_get_record(**batch_get_record_args)

def start_query_execution(
self,
catalog: str,
Expand Down
51 changes: 49 additions & 2 deletions tests/integ/test_feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,14 @@
from sagemaker.feature_store.feature_definition import FractionalFeatureDefinition
from sagemaker.feature_store.feature_group import FeatureGroup
from sagemaker.feature_store.feature_store import FeatureStore
from sagemaker.feature_store.inputs import FeatureValue, FeatureParameter, TableFormatEnum
from sagemaker.feature_store.inputs import (
FeatureValue,
FeatureParameter,
TableFormatEnum,
Filter,
ResourceEnum,
Identifier,
)
from sagemaker.session import get_execution_role, Session
from tests.integ.timeout import timeout

Expand Down Expand Up @@ -321,7 +328,7 @@ def test_create_feature_group_glue_table_format(
assert table_format == "Glue"


def test_get_record(
def test_get_and_batch_get_record(
feature_store_session,
role,
feature_group_name,
Expand Down Expand Up @@ -367,6 +374,24 @@ def test_get_record(
)
assert retrieved_record is None

# Retrieve data using batch_get_record
feature_store = FeatureStore(sagemaker_session=feature_store_session)
records = feature_store.batch_get_record(
identifiers=[
Identifier(
feature_group_name=feature_group_name,
record_identifiers_value_as_string=[record_identifier_value_as_string],
feature_names=record_names,
)
]
)["Records"]
assert records[0]["FeatureGroupName"] == feature_group_name
assert records[0]["RecordIdentifierValueAsString"] == record_identifier_value_as_string
assert len(records[0]["Record"]) == len(record_names)
for feature in records[0]["Record"]:
assert feature["FeatureName"] in record_names
assert feature["FeatureName"] is not removed_feature_name


def test_delete_record(
feature_store_session,
Expand Down Expand Up @@ -503,6 +528,28 @@ def test_feature_metadata(
assert 1 == len(describe_feature_metadata.get("Parameters"))


def test_search(feature_store_session, role, feature_group_name, pandas_data_frame):
feature_store = FeatureStore(sagemaker_session=feature_store_session)
feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session)
feature_group.load_feature_definitions(data_frame=pandas_data_frame)

with cleanup_feature_group(feature_group):
feature_group.create(
s3_uri=False,
record_identifier_name="feature1",
event_time_feature_name="feature3",
role_arn=role,
enable_online_store=True,
)
_wait_for_feature_group_create(feature_group)
output = feature_store.search(
resource=ResourceEnum.FEATURE_GROUP,
filters=[Filter(name="FeatureGroupName", value=feature_group_name)],
)
print(output)
assert output["Results"][0]["FeatureGroup"]["FeatureGroupName"] == feature_group_name


def test_ingest_without_string_feature(
feature_store_session,
role,
Expand Down
Loading