Skip to content

Commit 2c1170f

Browse files
committed
Added support for feature group schema change and feature parameters
1 parent 0bb007c commit 2c1170f

File tree

7 files changed

+349
-2
lines changed

7 files changed

+349
-2
lines changed

src/sagemaker/feature_store/feature_group.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
OfflineStoreConfig,
5454
DataCatalogConfig,
5555
FeatureValue,
56+
FeatureParameter,
5657
)
5758

5859
logger = logging.getLogger(__name__)
@@ -537,6 +538,66 @@ def describe(self, next_token: str = None) -> Dict[str, Any]:
537538
feature_group_name=self.name, next_token=next_token
538539
)
539540

541+
def update(self, feature_additions: Sequence[FeatureDefinition]) -> Dict[str, Any]:
542+
"""Update a FeatureGroup and add new features from the given feature definitions.
543+
544+
Args:
545+
feature_additions (Sequence[Dict[str, str]): list of feature definitions to be updated.
546+
547+
Returns:
548+
Response dict from service.
549+
"""
550+
551+
return self.sagemaker_session.update_feature_group(
552+
feature_group_name=self.name,
553+
feature_additions=[
554+
feature_addition.to_dict() for feature_addition in feature_additions
555+
],
556+
)
557+
558+
def update_feature_metadata(
559+
self,
560+
feature_name: str,
561+
description: str = None,
562+
parameter_additions: Sequence[FeatureParameter] = None,
563+
parameter_removals: Sequence[str] = None,
564+
) -> Dict[str, Any]:
565+
"""Update a feature metadata and add/remove metadata.
566+
567+
Args:
568+
feature_name (str): name of the feature to update.
569+
description (str): description of the feature to update.
570+
parameter_additions (Sequence[Dict[str, str]): list of feature parameter to be added.
571+
parameter_removals (Sequence[str]): list of feature parameter key to be removed.
572+
573+
Returns:
574+
Response dict from service.
575+
"""
576+
return self.sagemaker_session.update_feature_metadata(
577+
feature_group_name=self.name,
578+
feature_name=feature_name,
579+
description=description,
580+
parameter_additions=[
581+
parameter_addition.to_dict() for parameter_addition in parameter_additions
582+
]
583+
if parameter_additions is not None
584+
else [],
585+
parameter_removals=parameter_removals if parameter_removals is not None else [],
586+
)
587+
588+
def describe_feature_metadata(self, feature_name: str) -> Dict[str, Any]:
589+
"""Describe feature metadata by feature name.
590+
591+
Args:
592+
feature_name (str): name of the feature.
593+
Returns:
594+
Response dict from service.
595+
"""
596+
597+
return self.sagemaker_session.describe_feature_metadata(
598+
feature_group_name=self.name, feature_name=feature_name
599+
)
600+
540601
def load_feature_definitions(
541602
self,
542603
data_frame: DataFrame,

src/sagemaker/feature_store/inputs.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,3 +207,27 @@ def to_dict(self) -> Dict[str, Any]:
207207
FeatureName=self.feature_name,
208208
ValueAsString=self.value_as_string,
209209
)
210+
211+
212+
@attr.s
213+
class FeatureParameter(Config):
214+
"""FeatureParameter for FeatureStore.
215+
216+
Attributes:
217+
key (str): key of the parameter.
218+
value (str): value of the parameter.
219+
"""
220+
221+
key: str = attr.ib(default=None)
222+
value: str = attr.ib(default=None)
223+
224+
def to_dict(self) -> Dict[str, Any]:
225+
"""Construct a dictionary based on the attributes provided.
226+
227+
Returns:
228+
dict represents the attributes.
229+
"""
230+
return Config.construct_dict(
231+
Key=self.key,
232+
Value=self.value,
233+
)

src/sagemaker/session.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4076,7 +4076,7 @@ def describe_feature_group(
40764076
"""Describe a FeatureGroup by name in FeatureStore service.
40774077
40784078
Args:
4079-
feature_group_name (str): name of the FeatureGroup to descibe.
4079+
feature_group_name (str): name of the FeatureGroup to describe.
40804080
next_token (str): next_token to get next page of features.
40814081
Returns:
40824082
Response dict from service.
@@ -4086,6 +4086,72 @@ def describe_feature_group(
40864086
update_args(kwargs, NextToken=next_token)
40874087
return self.sagemaker_client.describe_feature_group(**kwargs)
40884088

4089+
def update_feature_group(
4090+
self, feature_group_name: str, feature_additions: Sequence[Dict[str, str]]
4091+
) -> Dict[str, Any]:
4092+
"""Update a FeatureGroup and add new features from the given feature definitions.
4093+
4094+
Args:
4095+
feature_group_name (str): name of the FeatureGroup to update.
4096+
feature_additions (Sequence[Dict[str, str]): list of feature definitions to be updated.
4097+
Returns:
4098+
Response dict from service.
4099+
"""
4100+
4101+
return self.sagemaker_client.update_feature_group(
4102+
FeatureGroupName=feature_group_name, FeatureAdditions=feature_additions
4103+
)
4104+
4105+
def update_feature_metadata(
4106+
self,
4107+
feature_group_name: str,
4108+
feature_name: str,
4109+
description: str = None,
4110+
parameter_additions: Sequence[Dict[str, str]] = None,
4111+
parameter_removals: Sequence[str] = None,
4112+
) -> Dict[str, Any]:
4113+
"""Update a feature metadata and add/remove metadata.
4114+
4115+
Args:
4116+
feature_group_name (str): name of the FeatureGroup to update.
4117+
feature_name (str): name of the feature to update.
4118+
description (str): description of the feature to update.
4119+
parameter_additions (Sequence[Dict[str, str]): list of feature parameter to be added.
4120+
parameter_removals (Sequence[Dict[str, str]): list of feature parameter to be removed.
4121+
Returns:
4122+
Response dict from service.
4123+
"""
4124+
4125+
request = {
4126+
"FeatureGroupName": feature_group_name,
4127+
"FeatureName": feature_name,
4128+
}
4129+
4130+
if description is not None:
4131+
request["Description"] = description
4132+
if parameter_additions is not None:
4133+
request["ParameterAdditions"] = parameter_additions
4134+
if parameter_removals is not None:
4135+
request["ParameterRemovals"] = parameter_removals
4136+
4137+
return self.sagemaker_client.update_feature_metadata(**request)
4138+
4139+
def describe_feature_metadata(
4140+
self, feature_group_name: str, feature_name: str
4141+
) -> Dict[str, Any]:
4142+
"""Describe feature metadata by feature name in FeatureStore service.
4143+
4144+
Args:
4145+
feature_group_name (str): name of the FeatureGroup.
4146+
feature_name (str): name of the feature.
4147+
Returns:
4148+
Response dict from service.
4149+
"""
4150+
4151+
return self.sagemaker_client.describe_feature_metadata(
4152+
FeatureGroupName=feature_group_name, FeatureName=feature_name
4153+
)
4154+
40894155
def put_record(
40904156
self,
40914157
feature_group_name: str,

tests/integ/test_feature_store.py

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@
2222
import pytest
2323
from pandas import DataFrame
2424

25+
from sagemaker.feature_store.feature_definition import FractionalFeatureDefinition
2526
from sagemaker.feature_store.feature_group import FeatureGroup
26-
from sagemaker.feature_store.inputs import FeatureValue
27+
from sagemaker.feature_store.inputs import FeatureValue, FeatureParameter
2728
from sagemaker.session import get_execution_role, Session
2829
from tests.integ.timeout import timeout
2930

@@ -237,6 +238,83 @@ def test_create_feature_store(
237238
assert output["FeatureGroupArn"].endswith(f"feature-group/{feature_group_name}")
238239

239240

241+
def test_update_feature_group(
242+
feature_store_session,
243+
role,
244+
feature_group_name,
245+
offline_store_s3_uri,
246+
pandas_data_frame,
247+
):
248+
feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session)
249+
feature_group.load_feature_definitions(data_frame=pandas_data_frame)
250+
251+
with cleanup_feature_group(feature_group):
252+
feature_group.create(
253+
s3_uri=offline_store_s3_uri,
254+
record_identifier_name="feature1",
255+
event_time_feature_name="feature3",
256+
role_arn=role,
257+
enable_online_store=True,
258+
)
259+
_wait_for_feature_group_create(feature_group)
260+
261+
new_feature_name = "new_feature"
262+
new_features = [FractionalFeatureDefinition(feature_name=new_feature_name)]
263+
feature_group.update(new_features)
264+
_wait_for_feature_group_update(feature_group)
265+
feature_definitions = feature_group.describe().get("FeatureDefinitions")
266+
assert any([True for elem in feature_definitions if new_feature_name in elem.values()])
267+
268+
269+
def test_feature_metadata(
270+
feature_store_session,
271+
role,
272+
feature_group_name,
273+
offline_store_s3_uri,
274+
pandas_data_frame,
275+
):
276+
feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session)
277+
feature_group.load_feature_definitions(data_frame=pandas_data_frame)
278+
279+
with cleanup_feature_group(feature_group):
280+
feature_group.create(
281+
s3_uri=offline_store_s3_uri,
282+
record_identifier_name="feature1",
283+
event_time_feature_name="feature3",
284+
role_arn=role,
285+
enable_online_store=True,
286+
)
287+
_wait_for_feature_group_create(feature_group)
288+
289+
parameter_additions = [
290+
FeatureParameter(key="key1", value="value1"),
291+
FeatureParameter(key="key2", value="value2"),
292+
]
293+
description = "test description"
294+
feature_name = "feature1"
295+
feature_group.update_feature_metadata(
296+
feature_name=feature_name,
297+
description=description,
298+
parameter_additions=parameter_additions,
299+
)
300+
describe_feature_metadata = feature_group.describe_feature_metadata(
301+
feature_name=feature_name
302+
)
303+
print(describe_feature_metadata)
304+
assert description == describe_feature_metadata.get("Description")
305+
assert 2 == len(describe_feature_metadata.get("Parameters"))
306+
307+
parameter_removals = ["key1"]
308+
feature_group.update_feature_metadata(
309+
feature_name=feature_name, parameter_removals=parameter_removals
310+
)
311+
describe_feature_metadata = feature_group.describe_feature_metadata(
312+
feature_name=feature_name
313+
)
314+
assert description == describe_feature_metadata.get("Description")
315+
assert 1 == len(describe_feature_metadata.get("Parameters"))
316+
317+
240318
def test_ingest_without_string_feature(
241319
feature_store_session,
242320
role,
@@ -304,6 +382,18 @@ def _wait_for_feature_group_create(feature_group: FeatureGroup):
304382
print(f"FeatureGroup {feature_group.name} successfully created.")
305383

306384

385+
def _wait_for_feature_group_update(feature_group: FeatureGroup):
386+
status = feature_group.describe().get("LastUpdateStatus").get("Status")
387+
while status == "InProgress":
388+
print("Waiting for Feature Group Update")
389+
time.sleep(5)
390+
status = feature_group.describe().get("LastUpdateStatus").get("Status")
391+
if status != "Successful":
392+
print(feature_group.describe())
393+
raise RuntimeError(f"Failed to update feature group {feature_group.name}")
394+
print(f"FeatureGroup {feature_group.name} successfully updated.")
395+
396+
307397
@contextmanager
308398
def cleanup_feature_group(feature_group: FeatureGroup):
309399
try:

tests/unit/sagemaker/feature_store/test_feature_store.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
AthenaQuery,
3232
IngestionError,
3333
)
34+
from sagemaker.feature_store.inputs import FeatureParameter
3435

3536

3637
class PicklableMock(Mock):
@@ -154,6 +155,52 @@ def test_feature_store_describe(sagemaker_session_mock):
154155
)
155156

156157

158+
def test_feature_store_update(sagemaker_session_mock, feature_group_dummy_definitions):
159+
feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock)
160+
feature_group.update(feature_group_dummy_definitions)
161+
sagemaker_session_mock.update_feature_group.assert_called_with(
162+
feature_group_name="MyFeatureGroup",
163+
feature_additions=[fd.to_dict() for fd in feature_group_dummy_definitions],
164+
)
165+
166+
167+
def test_feature_metadata_update(sagemaker_session_mock):
168+
feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock)
169+
170+
parameter_additions = [FeatureParameter(key="key1", value="value1")]
171+
parameter_removals = ["key2"]
172+
173+
feature_group.update_feature_metadata(
174+
feature_name="Feature1",
175+
description="TestDescription",
176+
parameter_additions=parameter_additions,
177+
parameter_removals=parameter_removals,
178+
)
179+
sagemaker_session_mock.update_feature_metadata.assert_called_with(
180+
feature_group_name="MyFeatureGroup",
181+
feature_name="Feature1",
182+
description="TestDescription",
183+
parameter_additions=[pa.to_dict() for pa in parameter_additions],
184+
parameter_removals=parameter_removals,
185+
)
186+
feature_group.update_feature_metadata(feature_name="Feature1", description="TestDescription")
187+
sagemaker_session_mock.update_feature_metadata.assert_called_with(
188+
feature_group_name="MyFeatureGroup",
189+
feature_name="Feature1",
190+
description="TestDescription",
191+
parameter_additions=[],
192+
parameter_removals=[],
193+
)
194+
195+
196+
def test_feature_metadata_describe(sagemaker_session_mock):
197+
feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock)
198+
feature_group.describe_feature_metadata(feature_name="Feature1")
199+
sagemaker_session_mock.describe_feature_metadata.assert_called_with(
200+
feature_group_name="MyFeatureGroup", feature_name="Feature1"
201+
)
202+
203+
157204
def test_put_record(sagemaker_session_mock):
158205
feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock)
159206
feature_group.put_record(record=[])

tests/unit/sagemaker/feature_store/test_inputs.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
S3StorageConfig,
2020
DataCatalogConfig,
2121
OfflineStoreConfig,
22+
FeatureParameter,
2223
)
2324

2425

@@ -83,3 +84,8 @@ def test_offline_data_store_config():
8384
"DisableGlueTableCreation": False,
8485
}
8586
)
87+
88+
89+
def test_feature_metadata():
90+
config = FeatureParameter(key="key", value="value")
91+
assert ordered(config.to_dict()) == ordered({"Key": "key", "Value": "value"})

0 commit comments

Comments
 (0)