|
import copy
from dataclasses import dataclass
from typing import Any, Dict, List, NamedTuple, Optional, Type, Union

import pytest
from mock import ANY, MagicMock, Mock, patch
|
21 | 22 |
|
@@ -2574,7 +2575,8 @@ def _build_pdp_config_mock():
|
2574 | 2575 |
|
2575 | 2576 | def _build_asymmetric_shapley_value_config_mock():
|
2576 | 2577 | asym_shap_val_config_dict = {
|
2577 |
| - "explanation_type": "fine_grained", |
| 2578 | + "direction": "chronological", |
| 2579 | + "granularity": "fine_grained", |
2578 | 2580 | "num_samples": 20,
|
2579 | 2581 | }
|
2580 | 2582 | asym_shap_val_config = Mock(spec=AsymmetricShapleyValueConfig)
|
@@ -2613,6 +2615,14 @@ def _build_model_config_mock():
|
2613 | 2615 | return model_config
|
2614 | 2616 |
|
2615 | 2617 |
|
@dataclass
class ValidateTSXBaselineCase:
    """One parametrized scenario for the static covariates baseline validation tests.

    Attributes:
        explainability_config: Asymmetric Shapley value config carrying the baseline
            under test.
        data_config: Data config whose time series settings the baseline is
            validated against.
        error: Exception class the validation is expected to raise, or ``None``
            when the case is expected to pass. This is a class rather than an
            instance because it is handed straight to ``pytest.raises``.
        error_msg: Pattern passed to ``pytest.raises(match=...)``, or ``None``
            when no error is expected.
    """

    explainability_config: AsymmetricShapleyValueConfig
    data_config: DataConfig
    error: Optional[Type[Exception]] = None
    error_msg: Optional[str] = None
| 2624 | + |
| 2625 | + |
2616 | 2626 | class TestAnalysisConfigGeneratorForTimeSeriesExplainability:
|
2617 | 2627 | @patch("sagemaker.clarify._AnalysisConfigGenerator._add_methods")
|
2618 | 2628 | @patch("sagemaker.clarify._AnalysisConfigGenerator._add_predictor")
|
@@ -2794,6 +2804,170 @@ def test_merge_explainability_configs_with_timeseries_invalid(
|
2794 | 2804 | explainability_config=mock_config,
|
2795 | 2805 | )
|
2796 | 2806 |
|
| 2807 | + @pytest.mark.parametrize( |
| 2808 | + "case", |
| 2809 | + [ |
| 2810 | + ValidateTSXBaselineCase( |
| 2811 | + explainability_config=AsymmetricShapleyValueConfig( |
| 2812 | + direction="chronological", |
| 2813 | + granularity="timewise", |
| 2814 | + baseline={ |
| 2815 | + "target_time_series": "zero", |
| 2816 | + "related_time_series": "zero", |
| 2817 | + "static_covariates": { |
| 2818 | + "item1": [0.0, 0.5, 1.0], |
| 2819 | + "item2": [0.3, 0.6, 0.9], |
| 2820 | + "item3": [0.0, 1.0, 1.0], |
| 2821 | + "item4": [0.9, 0.6, 0.3], |
| 2822 | + "item5": [1.0, 0.5, 0.0], |
| 2823 | + }, |
| 2824 | + }, |
| 2825 | + ), |
| 2826 | + data_config=DataConfig( |
| 2827 | + s3_data_input_path="s3://data/input", |
| 2828 | + s3_output_path="s3://data/output", |
| 2829 | + headers=["id", "time", "tts", "rts_1", "rts_2", "scv1", "scv2", "scv3"], |
| 2830 | + dataset_type="application/json", |
| 2831 | + time_series_data_config=TimeSeriesDataConfig( |
| 2832 | + item_id="[].id", |
| 2833 | + timestamp="[].temporal[].timestamp", |
| 2834 | + target_time_series="[].temporal[].target", |
| 2835 | + related_time_series=["[].temporal[].rts_1", "[].temporal[].rts_2"], |
| 2836 | + static_covariates=["[].cov_1", "[].cov_2", "[].cov_3"], |
| 2837 | + dataset_format=TimeSeriesJSONDatasetFormat.ITEM_RECORDS, |
| 2838 | + ), |
| 2839 | + ), |
| 2840 | + ), |
| 2841 | + ], |
| 2842 | + ) |
| 2843 | + def test_time_series_baseline_valid_static_covariates(self, case: ValidateTSXBaselineCase): |
| 2844 | + """ |
| 2845 | + GIVEN AsymmetricShapleyValueConfig and TimeSeriesDataConfig are created and a baseline |
| 2846 | + is provided |
| 2847 | + WHEN AnalysisConfigGenerator._validate_time_series_static_covariates_baseline() is called |
| 2848 | + THEN no error is raised |
| 2849 | + """ |
| 2850 | + _AnalysisConfigGenerator._validate_time_series_static_covariates_baseline( |
| 2851 | + explainability_config=case.explainability_config, |
| 2852 | + data_config=case.data_config, |
| 2853 | + ) |
| 2854 | + |
| 2855 | + @pytest.mark.parametrize( |
| 2856 | + "case", |
| 2857 | + [ |
| 2858 | + ValidateTSXBaselineCase( # some item ids are missing baseline values |
| 2859 | + explainability_config=AsymmetricShapleyValueConfig( |
| 2860 | + direction="chronological", |
| 2861 | + granularity="timewise", |
| 2862 | + baseline={ |
| 2863 | + "target_time_series": "zero", |
| 2864 | + "related_time_series": "zero", |
| 2865 | + "static_covariates": { |
| 2866 | + "item1": [0.0, 0.5, 1.0], |
| 2867 | + "item2": [0.3, 0.6, 0.9], |
| 2868 | + "item3": [0.0], |
| 2869 | + "item4": [0.9, 0.6, 0.3], |
| 2870 | + "item5": [1.0], |
| 2871 | + }, |
| 2872 | + }, |
| 2873 | + ), |
| 2874 | + data_config=DataConfig( |
| 2875 | + s3_data_input_path="s3://data/input", |
| 2876 | + s3_output_path="s3://data/output", |
| 2877 | + headers=["id", "time", "tts", "rts_1", "rts_2", "scv1", "scv2", "scv3"], |
| 2878 | + dataset_type="application/json", |
| 2879 | + time_series_data_config=TimeSeriesDataConfig( |
| 2880 | + item_id="[].id", |
| 2881 | + timestamp="[].temporal[].timestamp", |
| 2882 | + target_time_series="[].temporal[].target", |
| 2883 | + related_time_series=["[].temporal[].rts_1", "[].temporal[].rts_2"], |
| 2884 | + static_covariates=["[].cov_1", "[].cov_2", "[].cov_3"], |
| 2885 | + dataset_format=TimeSeriesJSONDatasetFormat.ITEM_RECORDS, |
| 2886 | + ), |
| 2887 | + ), |
| 2888 | + error=AssertionError, |
| 2889 | + error_msg="baseline entry for item3 does not match number", |
| 2890 | + ), |
| 2891 | + ValidateTSXBaselineCase( # no static covariates are in data config |
| 2892 | + explainability_config=AsymmetricShapleyValueConfig( |
| 2893 | + direction="chronological", |
| 2894 | + granularity="timewise", |
| 2895 | + baseline={ |
| 2896 | + "target_time_series": "zero", |
| 2897 | + "related_time_series": "zero", |
| 2898 | + "static_covariates": { |
| 2899 | + "item1": [0.0, 0.5, 1.0], |
| 2900 | + "item2": [0.3, 0.6, 0.9], |
| 2901 | + "item3": [0.0, 1.0, 1.0], |
| 2902 | + "item4": [0.9, 0.6, 0.3], |
| 2903 | + "item5": [1.0, 0.5, 0.0], |
| 2904 | + }, |
| 2905 | + }, |
| 2906 | + ), |
| 2907 | + data_config=DataConfig( |
| 2908 | + s3_data_input_path="s3://data/input", |
| 2909 | + s3_output_path="s3://data/output", |
| 2910 | + headers=["id", "time", "tts", "rts_1", "rts_2"], |
| 2911 | + dataset_type="application/json", |
| 2912 | + time_series_data_config=TimeSeriesDataConfig( |
| 2913 | + item_id="[].id", |
| 2914 | + timestamp="[].temporal[].timestamp", |
| 2915 | + target_time_series="[].temporal[].target", |
| 2916 | + related_time_series=["[].temporal[].rts_1", "[].temporal[].rts_2"], |
| 2917 | + dataset_format=TimeSeriesJSONDatasetFormat.ITEM_RECORDS, |
| 2918 | + ), |
| 2919 | + ), |
| 2920 | + error=ValueError, |
| 2921 | + error_msg="no static covariate columns are provided in TimeSeriesDataConfig", |
| 2922 | + ), |
| 2923 | + ValidateTSXBaselineCase( # some item ids do not have a list as their baseline |
| 2924 | + explainability_config=AsymmetricShapleyValueConfig( |
| 2925 | + direction="chronological", |
| 2926 | + granularity="timewise", |
| 2927 | + baseline={ |
| 2928 | + "target_time_series": "zero", |
| 2929 | + "related_time_series": "zero", |
| 2930 | + "static_covariates": { |
| 2931 | + "item1": [0.0, 0.5, 1.0], |
| 2932 | + "item2": [0.3, 0.6, 0.9], |
| 2933 | + "item3": [0.0, 1.0, 1.0], |
| 2934 | + "item4": [0.9, 0.6, 0.3], |
| 2935 | + "item5": {"cov_1": 1.0, "cov_2": 0.5, "cov_3": 0.0}, |
| 2936 | + }, |
| 2937 | + }, |
| 2938 | + ), |
| 2939 | + data_config=DataConfig( |
| 2940 | + s3_data_input_path="s3://data/input", |
| 2941 | + s3_output_path="s3://data/output", |
| 2942 | + headers=["id", "time", "tts", "rts_1", "rts_2", "scv1", "scv2", "scv3"], |
| 2943 | + dataset_type="application/json", |
| 2944 | + time_series_data_config=TimeSeriesDataConfig( |
| 2945 | + item_id="[].id", |
| 2946 | + timestamp="[].temporal[].timestamp", |
| 2947 | + target_time_series="[].temporal[].target", |
| 2948 | + related_time_series=["[].temporal[].rts_1", "[].temporal[].rts_2"], |
| 2949 | + static_covariates=["[].cov_1", "[].cov_2", "[].cov_3"], |
| 2950 | + dataset_format=TimeSeriesJSONDatasetFormat.ITEM_RECORDS, |
| 2951 | + ), |
| 2952 | + ), |
| 2953 | + error=AssertionError, |
| 2954 | + error_msg="Baseline entry for item5 must be a list", |
| 2955 | + ), |
| 2956 | + ], |
| 2957 | + ) |
| 2958 | + def test_time_series_baseline_invalid_static_covariates(self, case: ValidateTSXBaselineCase): |
| 2959 | + """ |
| 2960 | + GIVEN AsymmetricShapleyValueConfig and TimeSeriesDataConfig are created and a baseline |
| 2961 | + is provided where the static covariates baseline values are misconfigured |
| 2962 | + WHEN AnalysisConfigGenerator._validate_time_series_static_covariates_baseline() is called |
| 2963 | + THEN the appropriate error is raised |
| 2964 | + """ |
| 2965 | + with pytest.raises(case.error, match=case.error_msg): |
| 2966 | + _AnalysisConfigGenerator._validate_time_series_static_covariates_baseline( |
| 2967 | + explainability_config=case.explainability_config, |
| 2968 | + data_config=case.data_config, |
| 2969 | + ) |
| 2970 | + |
2797 | 2971 |
|
2798 | 2972 | class TestProcessingOutputHandler:
|
2799 | 2973 | def test_get_s3_upload_mode_image(self):
|
|
0 commit comments