Skip to content

Commit 90bafe0

Browse files
feat: add (early version of) baseline config to asym shap val config
1 parent a849d1c commit 90bafe0

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

src/sagemaker/clarify.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,14 @@
327327
),
328328
),
329329
SchemaOptional("num_samples"): int,
330+
SchemaOptional("baseline"): Or(
331+
str,
332+
{
333+
SchemaOptional("target_ts", default="zero"): str,
334+
SchemaOptional("related_ts"): str,
335+
SchemaOptional("static_covariates"): [Or(str, int, float)],
336+
},
337+
),
330338
},
331339
},
332340
SchemaOptional("predictor"): {
@@ -1661,6 +1669,7 @@ def __init__(
16611669
"fine_grained",
16621670
] = ASYM_SHAP_VAL_DEFAULT_EXPLANATION_GRANULARITY,
16631671
num_samples: Optional[int] = None,
1672+
baseline: Optional[Union[str, Dict[str, Any]]] = None,
16641673
):
16651674
"""Initialises config for time series explainability with Asymmetric Shapley Values.
16661675
@@ -1675,6 +1684,8 @@ def __init__(
16751684
num_samples (None or int): Number of samples to be used in the Asymmetric Shapley
16761685
Value forecasting algorithm. Only applicable when using ``"fine_grained"``
16771686
explanations.
1687+
baseline (str or dict): Link to a baseline configuration or a dictionary for it.
1688+
# TODO: improve above.
16781689
16791690
Raises:
16801691
AssertionError: when ``direction`` or ``granularity`` are not valid,
@@ -1707,6 +1718,8 @@ def __init__(
17071718
_set(
17081719
num_samples, "num_samples", self.asymmetric_shapley_value_config
17091720
) # _set() does nothing if a given argument is None
1721+
# TODO: add sdk-side validation to baseline
1722+
_set(baseline, "baseline", self.asymmetric_shapley_value_config)
17101723

17111724
def get_explainability_config(self):
17121725
"""Returns an asymmetric shap config dictionary."""

0 commit comments

Comments
 (0)