Skip to content

feature: Feature Store dataset builder, delete_record, get_record, list_feature_group #3534

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Dec 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
990 changes: 990 additions & 0 deletions src/sagemaker/feature_store/dataset_builder.py

Large diffs are not rendered by default.

45 changes: 41 additions & 4 deletions src/sagemaker/feature_store/feature_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,13 +435,14 @@ class FeatureGroup:
"uint64",
]
_FLOAT_TYPES = ["float_", "float16", "float32", "float64"]
_DTYPE_TO_FEATURE_DEFINITION_CLS_MAP: Dict[str, FeatureTypeEnum] = {
DTYPE_TO_FEATURE_DEFINITION_CLS_MAP: Dict[str, FeatureTypeEnum] = {
type: FeatureTypeEnum.INTEGRAL for type in _INTEGER_TYPES
}
_DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.update(
DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.update(
{type: FeatureTypeEnum.FRACTIONAL for type in _FLOAT_TYPES}
)
_DTYPE_TO_FEATURE_DEFINITION_CLS_MAP["string"] = FeatureTypeEnum.STRING
DTYPE_TO_FEATURE_DEFINITION_CLS_MAP["string"] = FeatureTypeEnum.STRING
DTYPE_TO_FEATURE_DEFINITION_CLS_MAP["object"] = FeatureTypeEnum.STRING

_FEATURE_TYPE_TO_DDL_DATA_TYPE_MAP = {
FeatureTypeEnum.INTEGRAL.value: "INT",
Expand Down Expand Up @@ -629,7 +630,7 @@ def load_feature_definitions(
"""
feature_definitions = []
for column in data_frame:
feature_type = self._DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.get(
feature_type = self.DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.get(
str(data_frame[column].dtype), None
)
if feature_type:
Expand All @@ -644,6 +645,23 @@ def load_feature_definitions(
self.feature_definitions = feature_definitions
return self.feature_definitions

def get_record(
    self, record_identifier_value_as_string: str, feature_names: Sequence[str] = None
) -> Sequence[Dict[str, str]]:
    """Fetch one record from this FeatureGroup by its identifier value.

    Args:
        record_identifier_value_as_string (String):
            a String representing the value of the record identifier.
        feature_names (Sequence[String]):
            a list of Strings representing feature names; when None the
            service returns every feature of the record.

    Returns:
        The ``Record`` entry of the service response (a list of
        feature-name/value dicts), or None when the response has no
        ``Record`` key.
    """
    response = self.sagemaker_session.get_record(
        feature_group_name=self.name,
        record_identifier_value_as_string=record_identifier_value_as_string,
        feature_names=feature_names,
    )
    return response.get("Record")

def put_record(self, record: Sequence[FeatureValue]):
"""Put a single record in the FeatureGroup.

Expand All @@ -654,6 +672,25 @@ def put_record(self, record: Sequence[FeatureValue]):
feature_group_name=self.name, record=[value.to_dict() for value in record]
)

def delete_record(
    self,
    record_identifier_value_as_string: str,
    event_time: str,
):
    """Remove a single record from the FeatureGroup.

    Args:
        record_identifier_value_as_string (String):
            a String representing the value of the record identifier.
        event_time (String):
            a timestamp format String indicating when the deletion event occurred.

    Returns:
        Response dict from the underlying session call.
    """
    request = {
        "feature_group_name": self.name,
        "record_identifier_value_as_string": record_identifier_value_as_string,
        "event_time": event_time,
    }
    return self.sagemaker_session.delete_record(**request)

def ingest(
self,
data_frame: DataFrame,
Expand Down
130 changes: 130 additions & 0 deletions src/sagemaker/feature_store/feature_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Feature Store.

Amazon SageMaker Feature Store is a fully managed, purpose-built repository to store, share, and
manage features for machine learning (ML) models.
"""
from __future__ import absolute_import

import datetime
from typing import Any, Dict, Sequence, Union

import attr
import pandas as pd

from sagemaker import Session
from sagemaker.feature_store.dataset_builder import DatasetBuilder
from sagemaker.feature_store.feature_group import FeatureGroup


@attr.s
class FeatureStore:
    """FeatureStore definition.

    This class instantiates a FeatureStore object that comprises a SageMaker session instance.

    Attributes:
        sagemaker_session (Session): session instance to perform boto calls.
    """

    # BUG FIX: ``attr.ib(default=Session)`` stored the Session *class* itself as
    # the default value (attrs only calls a default when it is wrapped in
    # Factory), so ``self.sagemaker_session.list_feature_groups(...)`` would hit
    # an unbound method and raise TypeError. ``factory=Session`` builds a fresh
    # Session instance per FeatureStore, which is what callers expect.
    sagemaker_session: Session = attr.ib(factory=Session)

    def create_dataset(
        self,
        base: Union[FeatureGroup, pd.DataFrame],
        output_path: str,
        record_identifier_feature_name: str = None,
        event_time_identifier_feature_name: str = None,
        included_feature_names: Sequence[str] = None,
        kms_key_id: str = None,
    ) -> DatasetBuilder:
        """Create a Dataset Builder for generating a Dataset.

        Args:
            base (Union[FeatureGroup, DataFrame]): A base which can be either a FeatureGroup or a
                pandas.DataFrame and will be used to merge other FeatureGroups and generate a
                Dataset.
            output_path (str): An S3 URI which stores the output .csv file.
            record_identifier_feature_name (str): A string representing the record identifier
                feature if base is a DataFrame (default: None).
            event_time_identifier_feature_name (str): A string representing the event time
                identifier feature if base is a DataFrame (default: None).
            included_feature_names (List[str]): A list of features to be included in the output
                (default: None).
            kms_key_id (str): An KMS key id. If set, will be used to encrypt the result file
                (default: None).

        Returns:
            A DatasetBuilder configured with this session and the given base.

        Raises:
            ValueError: Base is a Pandas DataFrame but no record identifier feature name nor event
                time identifier feature name is provided.
        """
        # A raw DataFrame carries no FeatureGroup metadata, so the caller must
        # name the identifying columns explicitly.
        if isinstance(base, pd.DataFrame):
            if record_identifier_feature_name is None or event_time_identifier_feature_name is None:
                raise ValueError(
                    "You must provide a record identifier feature name and an event time "
                    + "identifier feature name if specify DataFrame as base."
                )
        return DatasetBuilder(
            self.sagemaker_session,
            base,
            output_path,
            record_identifier_feature_name,
            event_time_identifier_feature_name,
            included_feature_names,
            kms_key_id,
        )

    def list_feature_groups(
        self,
        name_contains: str = None,
        feature_group_status_equals: str = None,
        offline_store_status_equals: str = None,
        creation_time_after: datetime.datetime = None,
        creation_time_before: datetime.datetime = None,
        sort_order: str = None,
        sort_by: str = None,
        max_results: int = None,
        next_token: str = None,
    ) -> Dict[str, Any]:
        """List all FeatureGroups satisfying given filters.

        Args:
            name_contains (str): A string that partially matches one or more FeatureGroups' names.
                Filters FeatureGroups by name.
            feature_group_status_equals (str): A FeatureGroup status.
                Filters FeatureGroups by FeatureGroup status.
            offline_store_status_equals (str): An OfflineStore status.
                Filters FeatureGroups by OfflineStore status.
            creation_time_after (datetime.datetime): Use this parameter to search for FeatureGroups
                created after a specific date and time.
            creation_time_before (datetime.datetime): Use this parameter to search for FeatureGroups
                created before a specific date and time.
            sort_order (str): The order in which FeatureGroups are listed.
            sort_by (str): The value on which the FeatureGroup list is sorted.
            max_results (int): The maximum number of results returned by ListFeatureGroups.
            next_token (str): A token to resume pagination of ListFeatureGroups results.
        Returns:
            Response dict from service.
        """
        # Thin delegation: the session layer drops None-valued filters before
        # calling the ListFeatureGroups API.
        return self.sagemaker_session.list_feature_groups(
            name_contains=name_contains,
            feature_group_status_equals=feature_group_status_equals,
            offline_store_status_equals=offline_store_status_equals,
            creation_time_after=creation_time_after,
            creation_time_before=creation_time_before,
            sort_order=sort_order,
            sort_by=sort_by,
            max_results=max_results,
            next_token=next_token,
        )
94 changes: 93 additions & 1 deletion src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None):
# For each object key, create the directory on the local machine if needed, and then
# download the file.
for key in keys:
tail_s3_uri_path = os.path.basename(key_prefix)
tail_s3_uri_path = os.path.basename(key)
if not os.path.splitext(key_prefix)[1]:
tail_s3_uri_path = os.path.relpath(key, key_prefix)
destination_path = os.path.join(path, tail_s3_uri_path)
Expand Down Expand Up @@ -4341,6 +4341,56 @@ def update_feature_group(
FeatureGroupName=feature_group_name, FeatureAdditions=feature_additions
)

def list_feature_groups(
    self,
    name_contains,
    feature_group_status_equals,
    offline_store_status_equals,
    creation_time_after,
    creation_time_before,
    sort_order,
    sort_by,
    max_results,
    next_token,
) -> Dict[str, Any]:
    """List all FeatureGroups satisfying given filters.

    Args:
        name_contains (str): A string that partially matches one or more FeatureGroups' names.
            Filters FeatureGroups by name.
        feature_group_status_equals (str): A FeatureGroup status.
            Filters FeatureGroups by FeatureGroup status.
        offline_store_status_equals (str): An OfflineStore status.
            Filters FeatureGroups by OfflineStore status.
        creation_time_after (datetime.datetime): Use this parameter to search for FeatureGroups
            created after a specific date and time.
        creation_time_before (datetime.datetime): Use this parameter to search for FeatureGroups
            created before a specific date and time.
        sort_order (str): The order in which FeatureGroups are listed.
        sort_by (str): The value on which the FeatureGroup list is sorted.
        max_results (int): The maximum number of results returned by ListFeatureGroups.
        next_token (str): A token to resume pagination of ListFeatureGroups results.
    Returns:
        Response dict from service.
    """
    # Map every optional filter to its API parameter name, then drop the
    # unset (None) entries so the boto call only carries explicit filters.
    optional_args = {
        "NameContains": name_contains,
        "FeatureGroupStatusEquals": feature_group_status_equals,
        "OfflineStoreStatusEquals": offline_store_status_equals,
        "CreationTimeAfter": creation_time_after,
        "CreationTimeBefore": creation_time_before,
        "SortOrder": sort_order,
        "SortBy": sort_by,
        "MaxResults": max_results,
        "NextToken": next_token,
    }
    request_args = {key: value for key, value in optional_args.items() if value is not None}
    return self.sagemaker_client.list_feature_groups(**request_args)

def update_feature_metadata(
self,
feature_group_name: str,
Expand Down Expand Up @@ -4408,6 +4458,48 @@ def put_record(
Record=record,
)

def delete_record(
    self,
    feature_group_name: str,
    record_identifier_value_as_string: str,
    event_time: str,
):
    """Deletes a single record from the FeatureGroup.

    Args:
        feature_group_name (str): name of the FeatureGroup.
        record_identifier_value_as_string (str): name of the record identifier.
        event_time (str): a timestamp indicating when the deletion event occurred.

    Returns:
        Response dict from the featurestore-runtime DeleteRecord call.
    """
    request_kwargs = {
        "FeatureGroupName": feature_group_name,
        "RecordIdentifierValueAsString": record_identifier_value_as_string,
        "EventTime": event_time,
    }
    return self.sagemaker_featurestore_runtime_client.delete_record(**request_kwargs)

def get_record(
    self,
    record_identifier_value_as_string: str,
    feature_group_name: str,
    feature_names: Sequence[str],
) -> Dict[str, Sequence[Dict[str, str]]]:
    """Gets a single record in the FeatureGroup.

    Args:
        record_identifier_value_as_string (str): name of the record identifier.
        feature_group_name (str): name of the FeatureGroup.
        feature_names (Sequence[str]): list of feature names; when falsy, the
            FeatureNames parameter is omitted and the full record is returned.

    Returns:
        Response dict from the featurestore-runtime GetRecord call.
    """
    client = self.sagemaker_featurestore_runtime_client
    # Only pass FeatureNames when the caller actually narrowed the request;
    # the API treats an absent parameter as "all features".
    if feature_names:
        return client.get_record(
            FeatureGroupName=feature_group_name,
            RecordIdentifierValueAsString=record_identifier_value_as_string,
            FeatureNames=feature_names,
        )
    return client.get_record(
        FeatureGroupName=feature_group_name,
        RecordIdentifierValueAsString=record_identifier_value_as_string,
    )

def start_query_execution(
self,
catalog: str,
Expand Down
Loading