Skip to content

Commit bbeba31

Browse files
feat: validation for asymmetric shapley value config baseline
doc: fix baseline doc to be sphinx-compliant
1 parent ef9659f commit bbeba31

File tree

2 files changed

+72 additions, -16 deletions

src/sagemaker/clarify.py

Lines changed: 36 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -332,8 +332,24 @@
332332
SchemaOptional("baseline"): Or(
333333
str,
334334
{
335-
SchemaOptional("target_time_series", default="zero"): str,
336-
SchemaOptional("related_time_series"): str,
335+
SchemaOptional("target_time_series", default="zero"): And(
336+
str,
337+
Use(str.lower),
338+
lambda s: s
339+
in (
340+
"zero",
341+
"mean",
342+
),
343+
),
344+
SchemaOptional("related_time_series"): And(
345+
str,
346+
Use(str.lower),
347+
lambda s: s
348+
in (
349+
"zero",
350+
"mean",
351+
),
352+
),
337353
SchemaOptional("static_covariates"): {Or(str, int): [Or(str, int, float)]},
338354
},
339355
),
@@ -1769,13 +1785,14 @@ def __init__(
17691785
(static covariates), a baseline value for each covariate should be provided for
17701786
each possible item_id. An example config follows, where ``item1`` and ``item2``
17711787
are item ids.::
1788+
17721789
{
1790+
"target_time_series": "zero",
17731791
"related_time_series": "zero",
17741792
"static_covariates": {
17751793
"item1": [1, 1],
17761794
"item2": [0, 1]
1777-
},
1778-
"target_time_series": "zero"
1795+
}
17791796
}
17801797
17811798
Raises:
@@ -1803,13 +1820,27 @@ def __init__(
18031820
), f"{direction} and {granularity} granularity are not supported together."
18041821
elif num_samples: # validate num_samples is not provided when unnecessary
18051822
raise ValueError("``num_samples`` is only used for fine-grained explanations.")
1823+
# validate baseline if provided as a dictionary
1824+
if isinstance(baseline, dict):
1825+
temporal_baselines = ["zero", "mean"] # possible baseline options for temporal fields
1826+
if "target_time_series" in baseline:
1827+
target_baseline = baseline.get("target_time_series")
1828+
assert target_baseline in temporal_baselines, (
1829+
f"Provided value {target_baseline} for ``target_time_series`` is "
1830+
f"invalid. Please select one of {temporal_baselines}."
1831+
)
1832+
if "related_time_series" in baseline:
1833+
related_baseline = baseline.get("related_time_series")
1834+
assert related_baseline in temporal_baselines, (
1835+
f"Provided value {related_baseline} for ``related_time_series`` is "
1836+
f"invalid. Please select one of {temporal_baselines}."
1837+
)
18061838
# set explanation type and (if provided) num_samples in internal config dictionary
18071839
_set(direction, "direction", self.asymmetric_shapley_value_config)
18081840
_set(granularity, "granularity", self.asymmetric_shapley_value_config)
18091841
_set(
18101842
num_samples, "num_samples", self.asymmetric_shapley_value_config
18111843
) # _set() does nothing if a given argument is None
1812-
# TODO: add sdk-side validation to baseline
18131844
_set(baseline, "baseline", self.asymmetric_shapley_value_config)
18141845

18151846
def get_explainability_config(self):

tests/unit/test_clarify.py

Lines changed: 36 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717

1818
import pytest
1919
from mock import ANY, MagicMock, Mock, patch
20-
from typing import List, NamedTuple, Optional, Union
20+
from typing import Any, Dict, List, NamedTuple, Optional, Union
2121

2222
from sagemaker import Processor, image_uris
2323
from sagemaker.clarify import (
@@ -1283,9 +1283,10 @@ def test_shap_config_no_parameters():
12831283
class AsymmetricShapleyValueConfigCase(NamedTuple):
12841284
direction: str
12851285
granularity: str
1286-
num_samples: Optional[int]
1287-
error: Exception
1288-
error_message: str
1286+
num_samples: Optional[int] = None
1287+
baseline: Optional[Union[str, Dict[str, Any]]] = None
1288+
error: Exception = None
1289+
error_message: str = None
12891290

12901291

12911292
class TestAsymmetricShapleyValueConfig:
@@ -1296,22 +1297,28 @@ class TestAsymmetricShapleyValueConfig:
12961297
direction=direction,
12971298
granularity="timewise",
12981299
num_samples=None,
1299-
error=None,
1300-
error_message=None,
13011300
)
13021301
for direction in ASYM_SHAP_VAL_EXPLANATION_DIRECTIONS
13031302
]
13041303
+ [
1305-
AsymmetricShapleyValueConfigCase( # cases for fine_grained granularity
1304+
AsymmetricShapleyValueConfigCase( # case for fine_grained granularity
13061305
direction="chronological",
13071306
granularity="fine_grained",
13081307
num_samples=1,
1309-
error=None,
1310-
error_message=None,
1311-
)
1308+
),
1309+
AsymmetricShapleyValueConfigCase( # case for target time series baseline
1310+
direction="chronological",
1311+
granularity="timewise",
1312+
baseline={"target_time_series": "mean"},
1313+
),
1314+
AsymmetricShapleyValueConfigCase( # case for related time series baseline
1315+
direction="chronological",
1316+
granularity="timewise",
1317+
baseline={"related_time_series": "zero"},
1318+
),
13121319
],
13131320
)
1314-
def test_asymmetric_shapley_value_config(self, test_case):
1321+
def test_asymmetric_shapley_value_config(self, test_case: AsymmetricShapleyValueConfigCase):
13151322
"""
13161323
GIVEN valid arguments for an AsymmetricShapleyValueConfig object
13171324
WHEN AsymmetricShapleyValueConfig object is instantiated with those arguments
@@ -1325,11 +1332,14 @@ def test_asymmetric_shapley_value_config(self, test_case):
13251332
}
13261333
if test_case.granularity == "fine_grained":
13271334
expected_config["num_samples"] = test_case.num_samples
1335+
if test_case.baseline:
1336+
expected_config["baseline"] = test_case.baseline
13281337
# WHEN
13291338
asym_shap_val_config = AsymmetricShapleyValueConfig(
13301339
direction=test_case.direction,
13311340
granularity=test_case.granularity,
13321341
num_samples=test_case.num_samples,
1342+
baseline=test_case.baseline,
13331343
)
13341344
# THEN
13351345
assert asym_shap_val_config.asymmetric_shapley_value_config == expected_config
@@ -1380,6 +1390,20 @@ def test_asymmetric_shapley_value_config(self, test_case):
13801390
error=AssertionError,
13811391
error_message="not supported together.",
13821392
),
1393+
AsymmetricShapleyValueConfigCase( # case for unsupported target time series baseline value
1394+
direction="chronological",
1395+
granularity="timewise",
1396+
baseline={"target_time_series": "median"},
1397+
error=AssertionError,
1398+
error_message="for ``target_time_series`` is invalid.",
1399+
),
1400+
AsymmetricShapleyValueConfigCase( # case for unsupported related time series baseline value
1401+
direction="chronological",
1402+
granularity="timewise",
1403+
baseline={"related_time_series": "mode"},
1404+
error=AssertionError,
1405+
error_message="for ``related_time_series`` is invalid.",
1406+
),
13831407
],
13841408
)
13851409
def test_asymmetric_shapley_value_config_invalid(self, test_case):
@@ -1394,6 +1418,7 @@ def test_asymmetric_shapley_value_config_invalid(self, test_case):
13941418
direction=test_case.direction,
13951419
granularity=test_case.granularity,
13961420
num_samples=test_case.num_samples,
1421+
baseline=test_case.baseline,
13971422
)
13981423

13991424

0 commit comments

Comments
 (0)