@@ -79,36 +79,66 @@ def test_feature_store_create(
79
79
role_arn = role_arn ,
80
80
enable_online_store = True ,
81
81
)
82
- assert sagemaker_session_mock .create_feature_group .called_with (
82
+ sagemaker_session_mock .create_feature_group .assert_called_with (
83
83
feature_group_name = "MyFeatureGroup" ,
84
84
record_identifier_name = "feature1" ,
85
85
event_time_feature_name = "feature2" ,
86
+ feature_definitions = [fd .to_dict () for fd in feature_group_dummy_definitions ],
86
87
role_arn = role_arn ,
88
+ description = None ,
89
+ tags = None ,
87
90
online_store_config = {"EnableOnlineStore" : True },
91
+ offline_store_config = {
92
+ "DisableGlueTableCreation" : False ,
93
+ "S3StorageConfig" : {"S3Uri" : s3_uri },
94
+ },
95
+ )
96
+
97
+
98
+ def test_feature_store_create_online_only (
99
+ sagemaker_session_mock , role_arn , feature_group_dummy_definitions
100
+ ):
101
+ feature_group = FeatureGroup (name = "MyFeatureGroup" , sagemaker_session = sagemaker_session_mock )
102
+ feature_group .feature_definitions = feature_group_dummy_definitions
103
+ feature_group .create (
104
+ s3_uri = False ,
105
+ record_identifier_name = "feature1" ,
106
+ event_time_feature_name = "feature2" ,
107
+ role_arn = role_arn ,
108
+ enable_online_store = True ,
109
+ )
110
+ sagemaker_session_mock .create_feature_group .assert_called_with (
111
+ feature_group_name = "MyFeatureGroup" ,
112
+ record_identifier_name = "feature1" ,
113
+ event_time_feature_name = "feature2" ,
88
114
feature_definitions = [fd .to_dict () for fd in feature_group_dummy_definitions ],
115
+ role_arn = role_arn ,
116
+ description = None ,
117
+ tags = None ,
118
+ online_store_config = {"EnableOnlineStore" : True },
89
119
)
90
120
91
121
92
122
def test_feature_store_delete (sagemaker_session_mock ):
93
123
feature_group = FeatureGroup (name = "MyFeatureGroup" , sagemaker_session = sagemaker_session_mock )
94
124
feature_group .delete ()
95
- assert sagemaker_session_mock .delete_feature_group .called_with (
125
+ sagemaker_session_mock .delete_feature_group .assert_called_with (
96
126
feature_group_name = "MyFeatureGroup"
97
127
)
98
128
99
129
100
130
def test_feature_store_describe (sagemaker_session_mock ):
101
131
feature_group = FeatureGroup (name = "MyFeatureGroup" , sagemaker_session = sagemaker_session_mock )
102
132
feature_group .describe ()
103
- assert sagemaker_session_mock .describe_feature_group .called_with (
104
- feature_group_name = "MyFeatureGroup"
133
+ sagemaker_session_mock .describe_feature_group .assert_called_with (
134
+ feature_group_name = "MyFeatureGroup" , next_token = None
105
135
)
106
136
107
137
108
138
def test_put_record (sagemaker_session_mock ):
109
139
feature_group = FeatureGroup (name = "MyFeatureGroup" , sagemaker_session = sagemaker_session_mock )
110
140
feature_group .put_record (record = [])
111
- assert sagemaker_session_mock .put_record .called_with (
141
+ sagemaker_session_mock .put_record .assert_called_with (
112
142
feature_group_name = "MyFeatureGroup" , record = []
113
143
)
114
144
@@ -268,7 +298,7 @@ def query(sagemaker_session_mock):
268
298
def test_athena_query_run (sagemaker_session_mock , query ):
269
299
sagemaker_session_mock .start_query_execution .return_value = {"QueryExecutionId" : "query_id" }
270
300
query .run (query_string = "query" , output_location = "s3://some-bucket/some-path" )
271
- assert sagemaker_session_mock .start_query_execution .called_with (
301
+ sagemaker_session_mock .start_query_execution .assert_called_with (
272
302
catalog = "catalog" ,
273
303
database = "database" ,
274
304
query_string = "query" ,
@@ -283,13 +313,13 @@ def test_athena_query_run(sagemaker_session_mock, query):
283
313
def test_athena_query_wait (sagemaker_session_mock , query ):
284
314
query ._current_query_execution_id = "query_id"
285
315
query .wait ()
286
- assert sagemaker_session_mock .wait_for_athena_query .called_with (query_execution_id = "query_id" )
316
+ sagemaker_session_mock .wait_for_athena_query .assert_called_with (query_execution_id = "query_id" )
287
317
288
318
289
319
def test_athena_query_get_query_execution (sagemaker_session_mock , query ):
290
320
query ._current_query_execution_id = "query_id"
291
321
query .get_query_execution ()
292
- assert sagemaker_session_mock .wait_for_athena_query . called_with (query_execution_id = "query_id" )
322
+ sagemaker_session_mock .get_query_execution . assert_called_with (query_execution_id = "query_id" )
293
323
294
324
295
325
@patch ("tempfile.gettempdir" , Mock (return_value = "tmp" ))
@@ -302,13 +332,13 @@ def test_athena_query_as_dataframe(read_csv, sagemaker_session_mock, query):
302
332
query ._result_bucket = "bucket"
303
333
query ._result_file_prefix = "prefix"
304
334
query .as_dataframe ()
305
- assert sagemaker_session_mock .download_athena_query_result .called_with (
335
+ sagemaker_session_mock .download_athena_query_result .assert_called_with (
306
336
bucket = "bucket" ,
307
337
prefix = "prefix" ,
308
338
query_execution_id = "query_id" ,
309
339
filename = "tmp/query_id.csv" ,
310
340
)
311
- assert read_csv .called_with ("tmp/query_id.csv" , delimiter = "," )
341
+ read_csv .assert_called_with ("tmp/query_id.csv" , delimiter = "," )
312
342
313
343
314
344
@patch ("tempfile.gettempdir" , Mock (return_value = "tmp" ))
0 commit comments