Skip to content

Commit c211e48

Browse files
committed
feature: utilities to faciliate working with Feature Groups
1 parent ca1e535 commit c211e48

File tree

1 file changed

+296
-0
lines changed

1 file changed

+296
-0
lines changed
Lines changed: 296 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,296 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Utilities for working with FeatureGroups and FeatureStores."""
14+
from __future__ import absolute_import
15+
16+
import re
17+
import logging
18+
19+
from typing import Union
20+
from pathlib import Path
21+
22+
import pandas
23+
import boto3
24+
from pandas import DataFrame, Series, read_csv
25+
26+
from sagemaker.feature_store.feature_group import FeatureGroup
27+
from sagemaker.session import Session
28+
29+
logger = logging.getLogger(__name__)
30+
31+
32+
def get_session_from_role(region: str, assume_role: str = None) -> Session:
33+
"""Method use to get the :class:`sagemaker.session.Session` from a role and a region.
34+
Helpful in case it's invoke from a session with a role without permission it can assume
35+
another role temporarily to perform certain tasks.
36+
Args:
37+
assume_role: role name
38+
region: region name
39+
Returns:
40+
"""
41+
boto_session = boto3.Session(region_name=region)
42+
43+
# It will try to assume the role specified
44+
if assume_role:
45+
sts = boto_session.client(
46+
"sts", region_name=region, endpoint_url="https://sts.eu-west-1.amazonaws.com"
47+
)
48+
49+
metadata = sts.assume_role(RoleArn=assume_role, RoleSessionName="SagemakerExecution")
50+
51+
access_key_id = metadata["Credentials"]["AccessKeyId"]
52+
secret_access_key = metadata["Credentials"]["SecretAccessKey"]
53+
session_token = metadata["Credentials"]["SessionToken"]
54+
55+
boto_session = boto3.session.Session(
56+
region_name=region,
57+
aws_access_key_id=access_key_id,
58+
aws_secret_access_key=secret_access_key,
59+
aws_session_token=session_token,
60+
)
61+
62+
# Sessions
63+
sagemaker_client = boto_session.client("sagemaker")
64+
sagemaker_runtime = boto_session.client("sagemaker-runtime")
65+
runtime_client = boto_session.client(service_name="sagemaker-featurestore-runtime")
66+
sagemaker_session = Session(
67+
boto_session=boto_session,
68+
sagemaker_client=sagemaker_client,
69+
sagemaker_runtime_client=sagemaker_runtime,
70+
sagemaker_featurestore_runtime_client=runtime_client,
71+
)
72+
73+
return sagemaker_session
74+
75+
76+
def get_feature_group_as_dataframe(
77+
feature_group_name: str,
78+
athena_bucket: str,
79+
query: str = """SELECT * FROM "sagemaker_featurestore"."#{table}"
80+
WHERE is_deleted=False """,
81+
role: str = None,
82+
region: str = None,
83+
session=None,
84+
event_time_feature_name: str = None,
85+
latest_ingestion: bool = True,
86+
verbose: bool = True,
87+
**pandas_read_csv_kwargs,
88+
) -> DataFrame:
89+
"""Get a :class:`sagemaker.feature_store.feature_group.FeatureGroup` as a pandas.DataFrame
90+
Description:
91+
Method to run an athena query over a Feature Group in a Feature Store
92+
to retrieve its data.It needs the sagemaker.Session linked to a role
93+
or the role and region used to work Feature Stores.Returns a dataframe
94+
with the data.
95+
Args:
96+
region (str): region of the target Feature Store
97+
feature_group_name (str): feature store name
98+
query (str): query to run. By default, it will take the latest ingest with data that
99+
wasn't deleted. If latest_ingestion is False it will take all the data
100+
in the feature group that wasn't deleted. It needs to use the keyword
101+
"#{table}" to refer to the FeatureGroup name. e.g.:
102+
'SELECT * FROM "sagemaker_featurestore"."#{table}"'
103+
athena_bucket (str): Amazon S3 bucket for running the query
104+
role (str): role of the account used to extract data from feature store
105+
session (str): :class:`sagemaker.session.Session`
106+
of SageMaker used to work with the feature store
107+
event_time_feature_name (str): eventTimeId feature. Mandatory only if the
108+
latest ingestion is True
109+
latest_ingestion (bool): if True it will get the data only from the latest ingestion.
110+
If False it will take whatever is specified in the query, or
111+
if not specify it, it will get all the data that wasn't deleted.
112+
verbose (bool): if True show messages, if False is silent.
113+
Returns:
114+
dataset (pandas.DataFrame): dataset with the data retrieved from feature group
115+
"""
116+
117+
logger.setLevel(logging.WARNING)
118+
if verbose:
119+
logger.setLevel(logging.INFO)
120+
121+
if latest_ingestion:
122+
if event_time_feature_name is not None:
123+
query += str(
124+
f"AND {event_time_feature_name}=(SELECT " +
125+
f"MAX({event_time_feature_name}) FROM " +
126+
'"sagemaker_featurestore"."#{table}")'
127+
)
128+
else:
129+
exc = Exception(
130+
"Argument event_time_feature_name must be specified "
131+
"when using latest_ingestion=True."
132+
)
133+
logger.exception(exc)
134+
raise exc
135+
query += ";"
136+
137+
if session is not None:
138+
sagemaker_session = session
139+
elif role is not None and region is not None:
140+
sagemaker_session = get_session_from_role(region=region)
141+
else:
142+
exc = Exception("Argument Session or role and region must be specified.")
143+
logger.exception(exc)
144+
raise exc
145+
146+
msg = f"Feature Group used: {feature_group_name}"
147+
logger.info(msg)
148+
149+
fg = FeatureGroup(name=feature_group_name, sagemaker_session=sagemaker_session)
150+
151+
sample_query = fg.athena_query()
152+
query_string = re.sub(r"#\{(table)\}", sample_query.table_name, query)
153+
154+
msg = f"Running query:\n\t{sample_query} \n\n\t-> Save on bucket {athena_bucket}\n"
155+
logger.info(msg)
156+
157+
sample_query.run(query_string=query_string, output_location=athena_bucket)
158+
159+
sample_query.wait()
160+
161+
# run Athena query. The output is loaded to a Pandas dataframe.
162+
dataset = sample_query.as_dataframe(**pandas_read_csv_kwargs)
163+
164+
msg = f"Data shape retrieve from {feature_group_name}: {dataset.shape}"
165+
logger.info(msg)
166+
167+
return dataset
168+
169+
170+
def _format_column_names(data: pandas.DataFrame) -> pandas.DataFrame:
171+
"""Formats the column names for :class:`sagemaker.feature_store.feature_group.FeatureGroup`
172+
Description:
173+
Module to format correctly the name of the columns of a DataFrame
174+
to later generate the features names of a Feature Group
175+
Args:
176+
data (pandas.DataFrame): dataframe used
177+
Returns:
178+
pandas.DataFrame
179+
"""
180+
data.rename(columns=lambda x: x.replace(" ", "_").replace(".", "").lower()[:62], inplace=True)
181+
return data
182+
183+
184+
def _cast_object_to_string(data_frame: pandas.DataFrame) -> pandas.DataFrame:
185+
"""Cast properly pandas object types to strings
186+
Method to convert 'object' and 'O' column dtypes of a pandas.DataFrame to
187+
a valid string type recognized by Feature Groups.
188+
Args:
189+
data_frame: dataframe used
190+
Returns:
191+
pandas.DataFrame
192+
"""
193+
for label in data_frame.select_dtypes(["object", "O"]).columns.tolist():
194+
data_frame[label] = data_frame[label].astype("str").astype("string")
195+
return data_frame
196+
197+
198+
def prepare_fg_from_dataframe_or_file(
199+
dataframe_or_path: Union[str, Path, pandas.DataFrame],
200+
feature_group_name: str,
201+
role: str = None,
202+
region: str = None,
203+
session=None,
204+
record_id: str = "record_id",
205+
event_id: str = "data_as_of_date",
206+
verbose: bool = False,
207+
**pandas_read_csv_kwargs
208+
) -> FeatureGroup:
209+
"""Prepares a dataframe to create a :class:`sagemaker.feature_store.feature_group.FeatureGroup`
210+
Description:
211+
Function to prepare a dataframe for creating a Feature Group from a pandas.DataFrame
212+
or a path to a file with proper dtypes, feature names and mandatory features (record_id,
213+
event_id). It needs the sagemaker.Session linked to a role or the role and region used
214+
to work Feature Stores. If record_id or event_id are not specified it will create ones
215+
by default with the names 'record_id' and 'data_as_of_date'.
216+
Args:
217+
**pandas_read_csv_kwargs (object):
218+
feature_group_name (str): feature group name
219+
dataframe_or_path (str, Path, pandas.DataFrame) : pandas.DataFrame or path to the data
220+
verbose (bool) : True for displaying messages, False for silent method.
221+
record_id (str, 'record_id'): (Optional) Feature identifier of the rows. If specified each
222+
value of that feature has to be unique. If not specified or
223+
record_id='record_id', then it will create a new feature from
224+
the index of the pandas.DataFrame.
225+
event_id (str) : (Optional) Feature with the time of the creation of data rows.
226+
If not specified it will create one with the current time
227+
called `data_as_of_date`
228+
role (str) : role used to get the session.
229+
region (str) : region used to get the session.
230+
session (str): session of SageMaker used to work with the feature store
231+
Returns:
232+
:class:`sagemaker.feature_store.feature_group.FeatureGroup`: FG prepared with all
233+
the methods and definitions properly defined
234+
"""
235+
236+
logger.setLevel(logging.WARNING)
237+
if verbose:
238+
logger.setLevel(logging.INFO)
239+
240+
if isinstance(dataframe_or_path, DataFrame):
241+
data = dataframe_or_path
242+
elif isinstance(dataframe_or_path, str):
243+
pandas_read_csv_kwargs.pop("filepath_or_buffer", None)
244+
data = read_csv(filepath_or_buffer=dataframe_or_path, **pandas_read_csv_kwargs)
245+
else:
246+
exc = Exception(
247+
str(
248+
f"Invalid type {type(dataframe_or_path)} for "
249+
"argument dataframe_or_path. \nParameter must be"
250+
" of type pandas.DataFrame or string"
251+
)
252+
)
253+
logger.exception(exc)
254+
raise exc
255+
256+
# Formatting cols
257+
data = _format_column_names(data=data)
258+
data = _cast_object_to_string(data_frame=data)
259+
260+
if record_id == "record_id" and record_id not in data.columns:
261+
data[record_id] = data.index
262+
263+
lg_uniq = len(data[record_id].unique())
264+
lg_id = len(data[record_id])
265+
266+
if lg_id != lg_uniq:
267+
exc = Exception(
268+
str(
269+
f"Record identifier {record_id} have {abs(lg_id - lg_uniq)} "
270+
"duplicated rows. \nRecord identifier must be unique"
271+
" in each row."
272+
)
273+
)
274+
logger.exception(exc)
275+
raise exc
276+
277+
if event_id not in data.columns:
278+
import time
279+
280+
current_time_sec = int(round(time.time()))
281+
data[event_id] = Series([current_time_sec] * lg_id, dtype="float64")
282+
283+
if session is not None:
284+
sagemaker_session = session
285+
elif role is not None and region is not None:
286+
sagemaker_session = get_session_from_role(region=region)
287+
else:
288+
exc = Exception("Argument Session or role and region must be specified.")
289+
logger.exception(exc)
290+
raise exc
291+
292+
feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=sagemaker_session)
293+
294+
feature_group.load_feature_definitions(data_frame=data)
295+
296+
return feature_group

0 commit comments

Comments
 (0)