17
17
18
18
import pytest
19
19
from mock import ANY , MagicMock , Mock , patch
20
- from typing import List , NamedTuple , Optional , Union
20
+ from typing import Any , Dict , List , NamedTuple , Optional , Union
21
21
22
22
from sagemaker import Processor , image_uris
23
23
from sagemaker .clarify import (
@@ -1283,9 +1283,10 @@ def test_shap_config_no_parameters():
1283
1283
class AsymmetricShapleyValueConfigCase (NamedTuple ):
1284
1284
direction : str
1285
1285
granularity : str
1286
- num_samples : Optional [int ]
1287
- error : Exception
1288
- error_message : str
1286
+ num_samples : Optional [int ] = None
1287
+ baseline : Optional [Union [str , Dict [str , Any ]]] = None
1288
+ error : Exception = None
1289
+ error_message : str = None
1289
1290
1290
1291
1291
1292
class TestAsymmetricShapleyValueConfig :
@@ -1296,22 +1297,28 @@ class TestAsymmetricShapleyValueConfig:
1296
1297
direction = direction ,
1297
1298
granularity = "timewise" ,
1298
1299
num_samples = None ,
1299
- error = None ,
1300
- error_message = None ,
1301
1300
)
1302
1301
for direction in ASYM_SHAP_VAL_EXPLANATION_DIRECTIONS
1303
1302
]
1304
1303
+ [
1305
- AsymmetricShapleyValueConfigCase ( # cases for fine_grained granularity
1304
+ AsymmetricShapleyValueConfigCase ( # case for fine_grained granularity
1306
1305
direction = "chronological" ,
1307
1306
granularity = "fine_grained" ,
1308
1307
num_samples = 1 ,
1309
- error = None ,
1310
- error_message = None ,
1311
- )
1308
+ ),
1309
+ AsymmetricShapleyValueConfigCase ( # case for target time series baseline
1310
+ direction = "chronological" ,
1311
+ granularity = "timewise" ,
1312
+ baseline = {"target_time_series" : "mean" },
1313
+ ),
1314
+ AsymmetricShapleyValueConfigCase ( # case for related time series baseline
1315
+ direction = "chronological" ,
1316
+ granularity = "timewise" ,
1317
+ baseline = {"related_time_series" : "zero" },
1318
+ ),
1312
1319
],
1313
1320
)
1314
- def test_asymmetric_shapley_value_config (self , test_case ):
1321
+ def test_asymmetric_shapley_value_config (self , test_case : AsymmetricShapleyValueConfigCase ):
1315
1322
"""
1316
1323
GIVEN valid arguments for an AsymmetricShapleyValueConfig object
1317
1324
WHEN AsymmetricShapleyValueConfig object is instantiated with those arguments
@@ -1325,11 +1332,14 @@ def test_asymmetric_shapley_value_config(self, test_case):
1325
1332
}
1326
1333
if test_case .granularity == "fine_grained" :
1327
1334
expected_config ["num_samples" ] = test_case .num_samples
1335
+ if test_case .baseline :
1336
+ expected_config ["baseline" ] = test_case .baseline
1328
1337
# WHEN
1329
1338
asym_shap_val_config = AsymmetricShapleyValueConfig (
1330
1339
direction = test_case .direction ,
1331
1340
granularity = test_case .granularity ,
1332
1341
num_samples = test_case .num_samples ,
1342
+ baseline = test_case .baseline ,
1333
1343
)
1334
1344
# THEN
1335
1345
assert asym_shap_val_config .asymmetric_shapley_value_config == expected_config
@@ -1380,6 +1390,20 @@ def test_asymmetric_shapley_value_config(self, test_case):
1380
1390
error = AssertionError ,
1381
1391
error_message = "not supported together." ,
1382
1392
),
1393
+ AsymmetricShapleyValueConfigCase ( # case for unsupported target time series baseline value
1394
+ direction = "chronological" ,
1395
+ granularity = "timewise" ,
1396
+ baseline = {"target_time_series" : "median" },
1397
+ error = AssertionError ,
1398
+ error_message = "for ``target_time_series`` is invalid." ,
1399
+ ),
1400
+ AsymmetricShapleyValueConfigCase ( # case for unsupported related time series baseline value
1401
+ direction = "chronological" ,
1402
+ granularity = "timewise" ,
1403
+ baseline = {"related_time_series" : "mode" },
1404
+ error = AssertionError ,
1405
+ error_message = "for ``related_time_series`` is invalid." ,
1406
+ ),
1383
1407
],
1384
1408
)
1385
1409
def test_asymmetric_shapley_value_config_invalid (self , test_case ):
@@ -1394,6 +1418,7 @@ def test_asymmetric_shapley_value_config_invalid(self, test_case):
1394
1418
direction = test_case .direction ,
1395
1419
granularity = test_case .granularity ,
1396
1420
num_samples = test_case .num_samples ,
1421
+ baseline = test_case .baseline ,
1397
1422
)
1398
1423
1399
1424
0 commit comments