327
327
),
328
328
),
329
329
SchemaOptional ("num_samples" ): int ,
330
+ SchemaOptional ("baseline" ): Or (
331
+ str ,
332
+ {
333
+ SchemaOptional ("target_ts" , default = "zero" ): str ,
334
+ SchemaOptional ("related_ts" ): str ,
335
+ SchemaOptional ("static_covariates" ): [Or (str , int , float )],
336
+ },
337
+ ),
330
338
},
331
339
},
332
340
SchemaOptional ("predictor" ): {
@@ -1661,6 +1669,7 @@ def __init__(
1661
1669
"fine_grained" ,
1662
1670
] = ASYM_SHAP_VAL_DEFAULT_EXPLANATION_GRANULARITY ,
1663
1671
num_samples : Optional [int ] = None ,
1672
+ baseline : Optional [Union [str , Dict [str , Any ]]] = None ,
1664
1673
):
1665
1674
"""Initialises config for time series explainability with Asymmetric Shapley Values.
1666
1675
@@ -1675,6 +1684,8 @@ def __init__(
1675
1684
num_samples (None or int): Number of samples to be used in the Asymmetric Shapley
1676
1685
Value forecasting algorithm. Only applicable when using ``"fine_grained"``
1677
1686
explanations.
1687
+ baseline (str or dict): Link to a baseline configuration or a dictionary for it.
1688
+ # TODO: improve above.
1678
1689
1679
1690
Raises:
1680
1691
AssertionError: when ``direction`` or ``granularity`` are not valid,
@@ -1707,6 +1718,8 @@ def __init__(
1707
1718
_set (
1708
1719
num_samples , "num_samples" , self .asymmetric_shapley_value_config
1709
1720
) # _set() does nothing if a given argument is None
1721
+ # TODO: add sdk-side validation to baseline
1722
+ _set (baseline , "baseline" , self .asymmetric_shapley_value_config )
1710
1723
1711
1724
def get_explainability_config (self ):
1712
1725
"""Returns an asymmetric shap config dictionary."""
0 commit comments