Skip to content

Commit a8ad1d1

Browse files
feat: add validation for static covariates in tsx baseline
1 parent 6d1312f commit a8ad1d1

File tree

2 files changed

+223
-9
lines changed

2 files changed

+223
-9
lines changed

src/sagemaker/clarify.py

Lines changed: 48 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,7 @@ class DatasetType(Enum):
408408
class TimeSeriesJSONDatasetFormat(Enum):
409409
"""Possible dataset formats for JSON time series data files.
410410
411-
Below is an example ``COLUMNS`` dataset for time series explainability.::
411+
Below is an example ``COLUMNS`` dataset for time series explainability::
412412
413413
{
414414
"ids": [1, 2],
@@ -420,15 +420,15 @@ class TimeSeriesJSONDatasetFormat(Enum):
420420
"scv2": [30, 40]
421421
}
422422
423-
For this example, JMESPaths are specified when creating ``TimeSeriesDataConfig`` as follows.::
423+
For this example, JMESPaths are specified when creating ``TimeSeriesDataConfig`` as follows::
424424
425425
item_id="ids"
426426
timestamp="timestamps"
427427
target_time_series="target_ts"
428428
related_time_series=["rts1", "rts2"]
429429
static_covariates=["scv1", "scv2"]
430430
431-
Below is an example ``ITEM_RECORDS`` dataset for time series explainability.::
431+
Below is an example ``ITEM_RECORDS`` dataset for time series explainability::
432432
433433
[
434434
{
@@ -452,15 +452,15 @@ class TimeSeriesJSONDatasetFormat(Enum):
452452
}
453453
]
454454
455-
For this example, JMESPaths are specified when creating ``TimeSeriesDataConfig`` as follows.::
455+
For this example, JMESPaths are specified when creating ``TimeSeriesDataConfig`` as follows::
456456
457457
item_id="[*].id"
458458
timestamp="[*].timeseries[].timestamp"
459459
target_time_series="[*].timeseries[].target_ts"
460460
related_time_series=["[*].timeseries[].rts1", "[*].timeseries[].rts2"]
461461
static_covariates=["[*].scv1", "[*].scv2"]
462462
463-
Below is an example ``TIMESTAMP_RECORDS`` dataset for time series explainability.::
463+
Below is an example ``TIMESTAMP_RECORDS`` dataset for time series explainability::
464464
465465
[
466466
{"id": 1, "timestamp": 1, "target_ts": 5, "scv1": 10, "rts1": 0.25},
@@ -469,7 +469,7 @@ class TimeSeriesJSONDatasetFormat(Enum):
469469
{"id": 2, "timestamp": 5, "target_ts": 10, "scv1": 20, "rts1": 1}
470470
]
471471
472-
For this example, JMESPaths are specified when creating ``TimeSeriesDataConfig`` as follows.::
472+
For this example, JMESPaths are specified when creating ``TimeSeriesDataConfig`` as follows::
473473
474474
item_id="[*].id"
475475
timestamp="[*].timestamp"
@@ -1784,7 +1784,7 @@ def __init__(
17841784
values will be replaced with the average of a time series. For static data
17851785
(static covariates), a baseline value for each covariate should be provided for
17861786
each possible item_id. An example config follows, where ``item1`` and ``item2``
1787-
are item ids.::
1787+
are item ids::
17881788
17891789
{
17901790
"target_time_series": "zero",
@@ -2548,14 +2548,16 @@ def explainability(
25482548
explainability_config: Union[ExplainabilityConfig, List[ExplainabilityConfig]],
25492549
):
25502550
"""Generates a config for Explainability"""
2551-
# determine if this is a timeseries explainability case by checking
2551+
# determine if this is a time series explainability case by checking
25522552
# if *both* TimeSeriesDataConfig and TimeSeriesModelConfig were given
25532553
ts_data_config_present = "time_series_data_config" in data_config.analysis_config
25542554
ts_model_config_present = "time_series_predictor_config" in model_config.predictor_config
25552555

25562556
if isinstance(explainability_config, AsymmetricShapleyValueConfig):
25572557
assert ts_data_config_present, "Please provide a TimeSeriesDataConfig to DataConfig."
25582558
assert ts_model_config_present, "Please provide a TimeSeriesModelConfig to ModelConfig."
2559+
# Check static covariates baseline matches number of provided static covariate columns
2560+
25592561
else:
25602562
if ts_data_config_present:
25612563
raise ValueError(
@@ -2759,6 +2761,44 @@ def _merge_explainability_configs(
27592761
return explainability_methods
27602762
return explainability_config.get_explainability_config()
27612763

2764+
@classmethod
2765+
def _validate_time_series_static_covariates_baseline(
2766+
cls,
2767+
explainability_config: AsymmetricShapleyValueConfig,
2768+
data_config: DataConfig,
2769+
):
2770+
"""Validates static covariates in baseline for asymmetric shapley value (for time series).
2771+
2772+
Checks that baseline values set for static covariate columns are
2773+
consistent between every item_id and the number of static covariate columns
2774+
provided in DataConfig.
2775+
"""
2776+
baseline = explainability_config.get_explainability_config()[
2777+
"asymmetric_shapley_value"
2778+
].get("baseline")
2779+
if baseline and "static_covariates" in baseline:
2780+
covariate_count = len(
2781+
data_config.get_config()["time_series_data_config"].get("static_covariates", [])
2782+
)
2783+
if covariate_count > 0:
2784+
for item_id in baseline.get("static_covariates", []):
2785+
baseline_entry = baseline["static_covariates"][item_id]
2786+
assert isinstance(baseline_entry, list), (
2787+
f"Baseline entry for {item_id} must be a list, is "
2788+
f"{type(baseline_entry)}."
2789+
)
2790+
assert len(baseline_entry) == covariate_count, (
2791+
f"Length of baseline entry for {item_id} does not match number "
2792+
f"of static covariate columns. Please ensure every covariate "
2793+
f"has a baseline value for every item id."
2794+
)
2795+
else:
2796+
raise ValueError(
2797+
"Static covariate baselines are provided in AsymmetricShapleyValueConfig "
2798+
"when no static covariate columns are provided in TimeSeriesDataConfig. "
2799+
"Please check these configs."
2800+
)
2801+
27622802

27632803
def _upload_analysis_config(analysis_config_file, s3_output_path, sagemaker_session, kms_key):
27642804
"""Uploads the local ``analysis_config_file`` to the ``s3_output_path``.

tests/unit/test_clarify.py

Lines changed: 175 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import copy
1717

1818
import pytest
19+
from dataclasses import dataclass
1920
from mock import ANY, MagicMock, Mock, patch
2021
from typing import Any, Dict, List, NamedTuple, Optional, Union
2122

@@ -2574,7 +2575,8 @@ def _build_pdp_config_mock():
25742575

25752576
def _build_asymmetric_shapley_value_config_mock():
25762577
asym_shap_val_config_dict = {
2577-
"explanation_type": "fine_grained",
2578+
"direction": "chronological",
2579+
"granularity": "fine_grained",
25782580
"num_samples": 20,
25792581
}
25802582
asym_shap_val_config = Mock(spec=AsymmetricShapleyValueConfig)
@@ -2613,6 +2615,14 @@ def _build_model_config_mock():
26132615
return model_config
26142616

26152617

2618+
@dataclass
2619+
class ValidateTSXBaselineCase:
2620+
explainability_config: AsymmetricShapleyValueConfig
2621+
data_config: DataConfig
2622+
error: Optional[Exception] = None
2623+
error_msg: Optional[str] = None
2624+
2625+
26162626
class TestAnalysisConfigGeneratorForTimeSeriesExplainability:
26172627
@patch("sagemaker.clarify._AnalysisConfigGenerator._add_methods")
26182628
@patch("sagemaker.clarify._AnalysisConfigGenerator._add_predictor")
@@ -2794,6 +2804,170 @@ def test_merge_explainability_configs_with_timeseries_invalid(
27942804
explainability_config=mock_config,
27952805
)
27962806

2807+
@pytest.mark.parametrize(
2808+
"case",
2809+
[
2810+
ValidateTSXBaselineCase(
2811+
explainability_config=AsymmetricShapleyValueConfig(
2812+
direction="chronological",
2813+
granularity="timewise",
2814+
baseline={
2815+
"target_time_series": "zero",
2816+
"related_time_series": "zero",
2817+
"static_covariates": {
2818+
"item1": [0.0, 0.5, 1.0],
2819+
"item2": [0.3, 0.6, 0.9],
2820+
"item3": [0.0, 1.0, 1.0],
2821+
"item4": [0.9, 0.6, 0.3],
2822+
"item5": [1.0, 0.5, 0.0],
2823+
},
2824+
},
2825+
),
2826+
data_config=DataConfig(
2827+
s3_data_input_path="s3://data/input",
2828+
s3_output_path="s3://data/output",
2829+
headers=["id", "time", "tts", "rts_1", "rts_2", "scv1", "scv2", "scv3"],
2830+
dataset_type="application/json",
2831+
time_series_data_config=TimeSeriesDataConfig(
2832+
item_id="[].id",
2833+
timestamp="[].temporal[].timestamp",
2834+
target_time_series="[].temporal[].target",
2835+
related_time_series=["[].temporal[].rts_1", "[].temporal[].rts_2"],
2836+
static_covariates=["[].cov_1", "[].cov_2", "[].cov_3"],
2837+
dataset_format=TimeSeriesJSONDatasetFormat.ITEM_RECORDS,
2838+
),
2839+
),
2840+
),
2841+
],
2842+
)
2843+
def test_time_series_baseline_valid_static_covariates(self, case: ValidateTSXBaselineCase):
2844+
"""
2845+
GIVEN AsymmetricShapleyValueConfig and TimeSeriesDataConfig are created and a baseline
2846+
is provided
2847+
WHEN AnalysisConfigGenerator._validate_time_series_static_covariates_baseline() is called
2848+
THEN no error is raised
2849+
"""
2850+
_AnalysisConfigGenerator._validate_time_series_static_covariates_baseline(
2851+
explainability_config=case.explainability_config,
2852+
data_config=case.data_config,
2853+
)
2854+
2855+
@pytest.mark.parametrize(
2856+
"case",
2857+
[
2858+
ValidateTSXBaselineCase( # some item ids are missing baseline values
2859+
explainability_config=AsymmetricShapleyValueConfig(
2860+
direction="chronological",
2861+
granularity="timewise",
2862+
baseline={
2863+
"target_time_series": "zero",
2864+
"related_time_series": "zero",
2865+
"static_covariates": {
2866+
"item1": [0.0, 0.5, 1.0],
2867+
"item2": [0.3, 0.6, 0.9],
2868+
"item3": [0.0],
2869+
"item4": [0.9, 0.6, 0.3],
2870+
"item5": [1.0],
2871+
},
2872+
},
2873+
),
2874+
data_config=DataConfig(
2875+
s3_data_input_path="s3://data/input",
2876+
s3_output_path="s3://data/output",
2877+
headers=["id", "time", "tts", "rts_1", "rts_2", "scv1", "scv2", "scv3"],
2878+
dataset_type="application/json",
2879+
time_series_data_config=TimeSeriesDataConfig(
2880+
item_id="[].id",
2881+
timestamp="[].temporal[].timestamp",
2882+
target_time_series="[].temporal[].target",
2883+
related_time_series=["[].temporal[].rts_1", "[].temporal[].rts_2"],
2884+
static_covariates=["[].cov_1", "[].cov_2", "[].cov_3"],
2885+
dataset_format=TimeSeriesJSONDatasetFormat.ITEM_RECORDS,
2886+
),
2887+
),
2888+
error=AssertionError,
2889+
error_msg="baseline entry for item3 does not match number",
2890+
),
2891+
ValidateTSXBaselineCase( # no static covariates are in data config
2892+
explainability_config=AsymmetricShapleyValueConfig(
2893+
direction="chronological",
2894+
granularity="timewise",
2895+
baseline={
2896+
"target_time_series": "zero",
2897+
"related_time_series": "zero",
2898+
"static_covariates": {
2899+
"item1": [0.0, 0.5, 1.0],
2900+
"item2": [0.3, 0.6, 0.9],
2901+
"item3": [0.0, 1.0, 1.0],
2902+
"item4": [0.9, 0.6, 0.3],
2903+
"item5": [1.0, 0.5, 0.0],
2904+
},
2905+
},
2906+
),
2907+
data_config=DataConfig(
2908+
s3_data_input_path="s3://data/input",
2909+
s3_output_path="s3://data/output",
2910+
headers=["id", "time", "tts", "rts_1", "rts_2"],
2911+
dataset_type="application/json",
2912+
time_series_data_config=TimeSeriesDataConfig(
2913+
item_id="[].id",
2914+
timestamp="[].temporal[].timestamp",
2915+
target_time_series="[].temporal[].target",
2916+
related_time_series=["[].temporal[].rts_1", "[].temporal[].rts_2"],
2917+
dataset_format=TimeSeriesJSONDatasetFormat.ITEM_RECORDS,
2918+
),
2919+
),
2920+
error=ValueError,
2921+
error_msg="no static covariate columns are provided in TimeSeriesDataConfig",
2922+
),
2923+
ValidateTSXBaselineCase( # some item ids do not have a list as their baseline
2924+
explainability_config=AsymmetricShapleyValueConfig(
2925+
direction="chronological",
2926+
granularity="timewise",
2927+
baseline={
2928+
"target_time_series": "zero",
2929+
"related_time_series": "zero",
2930+
"static_covariates": {
2931+
"item1": [0.0, 0.5, 1.0],
2932+
"item2": [0.3, 0.6, 0.9],
2933+
"item3": [0.0, 1.0, 1.0],
2934+
"item4": [0.9, 0.6, 0.3],
2935+
"item5": {"cov_1": 1.0, "cov_2": 0.5, "cov_3": 0.0},
2936+
},
2937+
},
2938+
),
2939+
data_config=DataConfig(
2940+
s3_data_input_path="s3://data/input",
2941+
s3_output_path="s3://data/output",
2942+
headers=["id", "time", "tts", "rts_1", "rts_2", "scv1", "scv2", "scv3"],
2943+
dataset_type="application/json",
2944+
time_series_data_config=TimeSeriesDataConfig(
2945+
item_id="[].id",
2946+
timestamp="[].temporal[].timestamp",
2947+
target_time_series="[].temporal[].target",
2948+
related_time_series=["[].temporal[].rts_1", "[].temporal[].rts_2"],
2949+
static_covariates=["[].cov_1", "[].cov_2", "[].cov_3"],
2950+
dataset_format=TimeSeriesJSONDatasetFormat.ITEM_RECORDS,
2951+
),
2952+
),
2953+
error=AssertionError,
2954+
error_msg="Baseline entry for item5 must be a list",
2955+
),
2956+
],
2957+
)
2958+
def test_time_series_baseline_invalid_static_covariates(self, case: ValidateTSXBaselineCase):
2959+
"""
2960+
GIVEN AsymmetricShapleyValueConfig and TimeSeriesDataConfig are created and a baseline
2961+
is provided where the static covariates baseline values are misconfigured
2962+
WHEN AnalysisConfigGenerator._validate_time_series_static_covariates_baseline() is called
2963+
THEN the appropriate error is raised
2964+
"""
2965+
with pytest.raises(case.error, match=case.error_msg):
2966+
_AnalysisConfigGenerator._validate_time_series_static_covariates_baseline(
2967+
explainability_config=case.explainability_config,
2968+
data_config=case.data_config,
2969+
)
2970+
27972971

27982972
class TestProcessingOutputHandler:
27992973
def test_get_s3_upload_mode_image(self):

0 commit comments

Comments
 (0)