Skip to content

Commit d4e1759

Browse files
icywang86rui and Rui Wang Napieralski authored
fix: Allow online store only FeatureGroups (#2015)
Co-authored-by: Rui Wang Napieralski <[email protected]>
1 parent 36b5f95 commit d4e1759

File tree

3 files changed

+80
-22
lines changed

3 files changed

+80
-22
lines changed

src/sagemaker/feature_store/feature_group.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ class FeatureGroup:
290290

291291
def create(
292292
self,
293-
s3_uri: str,
293+
s3_uri: Union[str, bool],
294294
record_identifier_name: str,
295295
event_time_feature_name: str,
296296
role_arn: str,
@@ -305,7 +305,8 @@ def create(
305305
"""Create a SageMaker FeatureStore FeatureGroup.
306306
307307
Args:
308-
s3_uri (str): S3 URI of the offline store.
308+
s3_uri (Union[str, bool]): S3 URI of the offline store, set to
309+
``False`` to disable offline store.
309310
record_identifier_name (str): name of the record identifier feature.
310311
event_time_feature_name (str): name of the event time feature.
311312
role_arn (str): ARN of the role used to call CreateFeatureGroup.
@@ -342,15 +343,18 @@ def create(
342343
create_feature_store_args.update({"online_store_config": online_store_config.to_dict()})
343344

344345
# offline store configuration
345-
s3_storage_config = S3StorageConfig(s3_uri=s3_uri)
346-
if offline_store_kms_key_id:
347-
s3_storage_config.kms_key_id = offline_store_kms_key_id
348-
offline_store_config = OfflineStoreConfig(
349-
s3_storage_config=s3_storage_config,
350-
disable_glue_table_creation=disable_glue_table_creation,
351-
data_catalog_config=data_catalog_config,
352-
)
353-
create_feature_store_args.update({"offline_store_config": offline_store_config.to_dict()})
346+
if s3_uri:
347+
s3_storage_config = S3StorageConfig(s3_uri=s3_uri)
348+
if offline_store_kms_key_id:
349+
s3_storage_config.kms_key_id = offline_store_kms_key_id
350+
offline_store_config = OfflineStoreConfig(
351+
s3_storage_config=s3_storage_config,
352+
disable_glue_table_creation=disable_glue_table_creation,
353+
data_catalog_config=data_catalog_config,
354+
)
355+
create_feature_store_args.update(
356+
{"offline_store_config": offline_store_config.to_dict()}
357+
)
354358

355359
return self.sagemaker_session.create_feature_group(**create_feature_store_args)
356360

@@ -367,7 +371,9 @@ def describe(self, next_token: str = None) -> Dict[str, Any]:
367371
Returns:
368372
Response dict from the service.
369373
"""
370-
return self.sagemaker_session.describe_feature_group(self.name, next_token)
374+
return self.sagemaker_session.describe_feature_group(
375+
feature_group_name=self.name, next_token=next_token
376+
)
371377

372378
def load_feature_definitions(
373379
self,

tests/integ/test_feature_store.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,28 @@ def create_table_ddl():
141141
)
142142

143143

144+
def test_create_feature_store_online_only(
145+
feature_store_session,
146+
role,
147+
feature_group_name,
148+
pandas_data_frame,
149+
):
150+
feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session)
151+
feature_group.load_feature_definitions(data_frame=pandas_data_frame)
152+
153+
with cleanup_feature_group(feature_group):
154+
output = feature_group.create(
155+
s3_uri=False,
156+
record_identifier_name="feature1",
157+
event_time_feature_name="feature3",
158+
role_arn=role,
159+
enable_online_store=True,
160+
)
161+
_wait_for_feature_group_create(feature_group)
162+
163+
assert output["FeatureGroupArn"].endswith(f"feature-group/{feature_group_name}")
164+
165+
144166
def test_create_feature_store(
145167
feature_store_session,
146168
role,

tests/unit/sagemaker/feature_store/test_feature_store.py

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -79,36 +79,66 @@ def test_feature_store_create(
7979
role_arn=role_arn,
8080
enable_online_store=True,
8181
)
82-
assert sagemaker_session_mock.create_feature_group.called_with(
82+
sagemaker_session_mock.create_feature_group.assert_called_with(
8383
feature_group_name="MyFeatureGroup",
8484
record_identifier_name="feature1",
8585
event_time_feature_name="feature2",
86+
feature_definitions=[fd.to_dict() for fd in feature_group_dummy_definitions],
8687
role_arn=role_arn,
88+
description=None,
89+
tags=None,
8790
online_store_config={"EnableOnlineStore": True},
91+
offline_store_config={
92+
"DisableGlueTableCreation": False,
93+
"S3StorageConfig": {"S3Uri": s3_uri},
94+
},
95+
)
96+
97+
98+
def test_feature_store_create_online_only(
99+
sagemaker_session_mock, role_arn, feature_group_dummy_definitions
100+
):
101+
feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock)
102+
feature_group.feature_definitions = feature_group_dummy_definitions
103+
feature_group.create(
104+
s3_uri=False,
105+
record_identifier_name="feature1",
106+
event_time_feature_name="feature2",
107+
role_arn=role_arn,
108+
enable_online_store=True,
109+
)
110+
sagemaker_session_mock.create_feature_group.assert_called_with(
111+
feature_group_name="MyFeatureGroup",
112+
record_identifier_name="feature1",
113+
event_time_feature_name="feature2",
88114
feature_definitions=[fd.to_dict() for fd in feature_group_dummy_definitions],
115+
role_arn=role_arn,
116+
description=None,
117+
tags=None,
118+
online_store_config={"EnableOnlineStore": True},
89119
)
90120

91121

92122
def test_feature_store_delete(sagemaker_session_mock):
93123
feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock)
94124
feature_group.delete()
95-
assert sagemaker_session_mock.delete_feature_group.called_with(
125+
sagemaker_session_mock.delete_feature_group.assert_called_with(
96126
feature_group_name="MyFeatureGroup"
97127
)
98128

99129

100130
def test_feature_store_describe(sagemaker_session_mock):
101131
feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock)
102132
feature_group.describe()
103-
assert sagemaker_session_mock.describe_feature_group.called_with(
104-
feature_group_name="MyFeatureGroup"
133+
sagemaker_session_mock.describe_feature_group.assert_called_with(
134+
feature_group_name="MyFeatureGroup", next_token=None
105135
)
106136

107137

108138
def test_put_record(sagemaker_session_mock):
109139
feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock)
110140
feature_group.put_record(record=[])
111-
assert sagemaker_session_mock.put_record.called_with(
141+
sagemaker_session_mock.put_record.assert_called_with(
112142
feature_group_name="MyFeatureGroup", record=[]
113143
)
114144

@@ -268,7 +298,7 @@ def query(sagemaker_session_mock):
268298
def test_athena_query_run(sagemaker_session_mock, query):
269299
sagemaker_session_mock.start_query_execution.return_value = {"QueryExecutionId": "query_id"}
270300
query.run(query_string="query", output_location="s3://some-bucket/some-path")
271-
assert sagemaker_session_mock.start_query_execution.called_with(
301+
sagemaker_session_mock.start_query_execution.assert_called_with(
272302
catalog="catalog",
273303
database="database",
274304
query_string="query",
@@ -283,13 +313,13 @@ def test_athena_query_run(sagemaker_session_mock, query):
283313
def test_athena_query_wait(sagemaker_session_mock, query):
284314
query._current_query_execution_id = "query_id"
285315
query.wait()
286-
assert sagemaker_session_mock.wait_for_athena_query.called_with(query_execution_id="query_id")
316+
sagemaker_session_mock.wait_for_athena_query.assert_called_with(query_execution_id="query_id")
287317

288318

289319
def test_athena_query_get_query_execution(sagemaker_session_mock, query):
290320
query._current_query_execution_id = "query_id"
291321
query.get_query_execution()
292-
assert sagemaker_session_mock.wait_for_athena_query.called_with(query_execution_id="query_id")
322+
sagemaker_session_mock.get_query_execution.assert_called_with(query_execution_id="query_id")
293323

294324

295325
@patch("tempfile.gettempdir", Mock(return_value="tmp"))
@@ -302,13 +332,13 @@ def test_athena_query_as_dataframe(read_csv, sagemaker_session_mock, query):
302332
query._result_bucket = "bucket"
303333
query._result_file_prefix = "prefix"
304334
query.as_dataframe()
305-
assert sagemaker_session_mock.download_athena_query_result.called_with(
335+
sagemaker_session_mock.download_athena_query_result.assert_called_with(
306336
bucket="bucket",
307337
prefix="prefix",
308338
query_execution_id="query_id",
309339
filename="tmp/query_id.csv",
310340
)
311-
assert read_csv.called_with("tmp/query_id.csv", delimiter=",")
341+
read_csv.assert_called_with("tmp/query_id.csv", delimiter=",")
312342

313343

314344
@patch("tempfile.gettempdir", Mock(return_value="tmp"))

0 commit comments

Comments
 (0)