Skip to content

Commit 75204ad

Browse files
fix: removed field use_future_covariates and related unit tests from TimeSeriesModelConfig
1 parent af8b7de commit 75204ad

File tree

2 files changed

+1
-72
lines changed

2 files changed

+1
-72
lines changed

src/sagemaker/clarify.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,6 @@
347347
SchemaOptional("time_series_predictor_config"): {
348348
"forecast": str,
349349
"forecast_horizon": int,
350-
SchemaOptional("use_future_covariates"): bool,
351350
},
352351
},
353352
}
@@ -805,15 +804,12 @@ def __init__(
805804
self,
806805
forecast: str,
807806
forecast_horizon: int = TS_MODEL_DEFAULT_FORECAST_HORIZON,
808-
use_future_covariates: Optional[bool] = False,
809807
):
810808
"""Initializes model configuration fields for TimeSeries explainability use cases.
811809
812810
Args:
813811
forecast (str): JMESPath expression to extract the forecast result.
814812
forecast_horizon (int): An integer that tells the forecast horizon.
815-
use_future_covariates (None or bool): If set as True, future covariates
816-
included in model input and used for forecasting
817813
818814
Raises:
819815
AssertionError: when either ``forecast`` or ``forecast_horizon`` are not provided
@@ -829,15 +825,10 @@ def __init__(
829825
raise ValueError("Please provide a string JMESPath expression for ``forecast``.")
830826
if not isinstance(forecast_horizon, int):
831827
raise ValueError("Please provide an integer ``forecast_horizon``.")
832-
if use_future_covariates and not isinstance(use_future_covariates, bool):
833-
raise ValueError("Please provide a boolean value for ``use_future_covariates``.")
834828
# add fields to an internal config dictionary
835829
self.predictor_config = dict()
836830
_set(forecast, "forecast", self.predictor_config)
837831
_set(forecast_horizon, "forecast_horizon", self.predictor_config)
838-
_set(
839-
use_future_covariates, "use_future_covariates", self.predictor_config
840-
) # _set() does nothing if a given argument is None
841832

842833
def get_predictor_config(self):
843834
"""Returns TimeSeries predictor config dictionary"""

tests/unit/test_clarify.py

Lines changed: 1 addition & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -897,7 +897,6 @@ def test_time_series_model_config(self):
897897
expected_config = {
898898
"forecast": forecast,
899899
"forecast_horizon": TS_MODEL_DEFAULT_FORECAST_HORIZON,
900-
"use_future_covariates": False,
901900
}
902901
# WHEN
903902
ts_model_config = TimeSeriesModelConfig(
@@ -919,7 +918,6 @@ def test_time_series_model_config_with_forecast_horizon(self):
919918
expected_config = {
920919
"forecast": forecast,
921920
"forecast_horizon": forecast_horizon,
922-
"use_future_covariates": False,
923921
}
924922
# WHEN
925923
ts_model_config = TimeSeriesModelConfig(
@@ -929,97 +927,39 @@ def test_time_series_model_config_with_forecast_horizon(self):
929927
# THEN
930928
assert ts_model_config.predictor_config == expected_config
931929

932-
def test_time_series_model_config_with_future_covariates(self):
933-
"""
934-
GIVEN a valid forecast expression
935-
WHEN a TimeSeriesModelConfig is constructed with it and use_future_covariates is True
936-
THEN the predictor_config dictionary matches the expected
937-
"""
938-
# GIVEN
939-
forecast = "results.[forecast]" # mock JMESPath expression for forecast
940-
# create expected output
941-
expected_config = {
942-
"forecast": forecast,
943-
"forecast_horizon": TS_MODEL_DEFAULT_FORECAST_HORIZON,
944-
"use_future_covariates": True,
945-
}
946-
# WHEN
947-
ts_model_config = TimeSeriesModelConfig(
948-
forecast,
949-
use_future_covariates=True,
950-
)
951-
# THEN
952-
assert ts_model_config.predictor_config == expected_config
953-
954-
def test_time_series_model_config_with_horizon_and_covariates(self):
955-
"""
956-
GIVEN a valid forecast expression and forecast horizon
957-
WHEN a TimeSeriesModelConfig is constructed with it and use_future_covariates is True
958-
THEN the predictor_config dictionary matches the expected
959-
"""
960-
# GIVEN
961-
forecast = "results.[forecast]" # mock JMESPath expression for forecast
962-
forecast_horizon = 25 # non-default forecast horizon
963-
# create expected output
964-
expected_config = {
965-
"forecast": forecast,
966-
"forecast_horizon": forecast_horizon,
967-
"use_future_covariates": True,
968-
}
969-
# WHEN
970-
ts_model_config = TimeSeriesModelConfig(
971-
forecast,
972-
forecast_horizon=forecast_horizon,
973-
use_future_covariates=True,
974-
)
975-
# THEN
976-
assert ts_model_config.predictor_config == expected_config
977-
978930
@pytest.mark.parametrize(
979-
("forecast", "forecast_horizon", "use_future_covariates", "error", "error_message"),
931+
("forecast", "forecast_horizon", "error", "error_message"),
980932
[
981933
(
982934
None,
983935
TS_MODEL_DEFAULT_FORECAST_HORIZON,
984-
None,
985936
AssertionError,
986937
"Please provide ``forecast``, a JMESPath expression to extract the forecast result.",
987938
),
988939
(
989940
"results.[forecast]",
990941
None,
991-
None,
992942
AssertionError,
993943
"Please provide an integer ``forecast_horizon``.",
994944
),
995945
(
996946
123,
997947
TS_MODEL_DEFAULT_FORECAST_HORIZON,
998-
None,
999948
ValueError,
1000949
"Please provide a string JMESPath expression for ``forecast``.",
1001950
),
1002951
(
1003952
"results.[forecast]",
1004953
"Not an int",
1005-
None,
1006954
ValueError,
1007955
"Please provide an integer ``forecast_horizon``.",
1008956
),
1009-
(
1010-
"results.[forecast]",
1011-
TS_MODEL_DEFAULT_FORECAST_HORIZON,
1012-
"Not a bool",
1013-
ValueError,
1014-
"Please provide a boolean value for ``use_future_covariates``.",
1015-
),
1016957
],
1017958
)
1018959
def test_time_series_model_config_invalid(
1019960
self,
1020961
forecast,
1021962
forecast_horizon,
1022-
use_future_covariates,
1023963
error,
1024964
error_message,
1025965
):
@@ -1032,7 +972,6 @@ def test_time_series_model_config_invalid(
1032972
TimeSeriesModelConfig(
1033973
forecast=forecast,
1034974
forecast_horizon=forecast_horizon,
1035-
use_future_covariates=use_future_covariates,
1036975
)
1037976

1038977
def test_model_config_with_time_series(self):
@@ -1064,7 +1003,6 @@ def test_model_config_with_time_series(self):
10641003
mock_ts_model_config_dict = {
10651004
"forecast": forecast,
10661005
"forecast_horizon": forecast_horizon,
1067-
"use_future_covariates": True,
10681006
}
10691007
mock_ts_model_config = Mock()
10701008
mock_ts_model_config.get_predictor_config.return_value = mock_ts_model_config_dict

0 commit comments

Comments
 (0)