Skip to content

Commit 848aa3c

Browse files
jessegrabowskiDekermanjian
authored andcommitted
Tracking down data bug
1 parent 7be2415 commit 848aa3c

File tree

2 files changed

+94
-2
lines changed

2 files changed

+94
-2
lines changed

pymc_extras/statespace/core/statespace.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1570,8 +1570,10 @@ def _validate_forecast_args(
15701570
raise ValueError(
15711571
"Integer start must be within the range of the data index used to fit the model."
15721572
)
1573-
if periods is None and end is None:
1574-
raise ValueError("Must specify one of either periods or end")
1573+
if periods is None and end is None and not use_scenario_index:
1574+
raise ValueError(
1575+
"Must specify one of either periods or end unless use_scenario_index=True"
1576+
)
15751577
if periods is not None and end is not None:
15761578
raise ValueError("Must specify exactly one of either periods or end")
15771579
if scenario is None and use_scenario_index:

tests/statespace/test_statespace.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -870,3 +870,93 @@ def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start):
870870
regression_effect_expected = (betas * scenario_xr).sum(dim=["state"])
871871

872872
assert_allclose(regression_effect, regression_effect_expected)
873+
874+
875+
@pytest.mark.filterwarnings("ignore:Provided data contains missing values.")
876+
@pytest.mark.filterwarnings("ignore:The RandomType SharedVariables")
877+
def test_foreacast_valid_index(rng):
878+
# Regression test for issue reported at https://github.com/pymc-devs/pymc-extras/issues/424
879+
880+
index = pd.date_range(start="2023-05-01", end="2025-01-29", freq="D")
881+
T, k = len(index), 2
882+
data = np.zeros((T, k))
883+
idx = rng.choice(T, size=10, replace=False)
884+
cols = rng.choice(k, size=10, replace=True)
885+
886+
data[idx, cols] = 1
887+
888+
df_holidays = pd.DataFrame(data, index=index, columns=["Holiday 1", "Holiday 2"])
889+
890+
data = rng.normal(size=(T, 1))
891+
nan_locs = rng.choice(T, size=10, replace=False)
892+
data[nan_locs] = np.nan
893+
y = pd.DataFrame(data, index=index, columns=["sales"])
894+
895+
level_trend = st.LevelTrendComponent(order=1, innovations_order=[0])
896+
weekly_seasonality = st.TimeSeasonality(
897+
season_length=7,
898+
state_names=["Sun", "Mon", "Tues", "Wed", "Thu", "Fri", "Sat"],
899+
innovations=True,
900+
remove_first_state=False,
901+
)
902+
quarterly_seasonality = st.FrequencySeasonality(season_length=365, n=2, innovations=True)
903+
ar1 = st.AutoregressiveComponent(order=1)
904+
me = st.MeasurementError()
905+
906+
exog = st.RegressionComponent(
907+
name="exog", # Name of this exogenous variable component
908+
k_exog=2, # Only one exogenous variable now
909+
innovations=False, # Typically fixed effect (no stochastic evolution)
910+
state_names=df_holidays.columns.tolist(),
911+
)
912+
913+
combined_model = level_trend + weekly_seasonality + quarterly_seasonality + me + ar1 + exog
914+
ss_mod = combined_model.build()
915+
916+
with pm.Model(coords=ss_mod.coords) as struct_model:
917+
P0_diag = pm.Gamma("P0_diag", alpha=2, beta=10, dims=["state"])
918+
P0 = pm.Deterministic("P0", pt.diag(P0_diag), dims=["state", "state_aux"])
919+
920+
initial_trend = pm.Normal("initial_trend", mu=[0], sigma=[0.005], dims=["trend_state"])
921+
# sigma_trend = pm.Gamma("sigma_trend", alpha=2, beta=1, dims=["trend_shock"]) # Applied to the level only
922+
923+
Seasonal_coefs = pm.ZeroSumNormal(
924+
"Seasonal[s=7]_coefs", sigma=0.5, dims=["Seasonal[s=7]_state"]
925+
) # DOW dev. from weekly mean
926+
sigma_Seasonal = pm.Gamma(
927+
"sigma_Seasonal[s=7]", alpha=2, beta=1
928+
) # How much this dev. can dev.
929+
930+
Frequency_coefs = pm.Normal(
931+
"Frequency[s=365, n=2]", mu=0, sigma=0.5, dims=["Frequency[s=365, n=2]_state"]
932+
) # amplitudes in short-term (weekly noise culprit)
933+
sigma_Frequency = pm.Gamma(
934+
"sigma_Frequency[s=365, n=2]", alpha=2, beta=1
935+
) # smoothness & adaptability over time
936+
937+
ar_params = pm.Laplace("ar_params", mu=0, b=0.2, dims=["ar_lag"])
938+
sigma_ar = pm.Gamma("sigma_ar", alpha=2, beta=1)
939+
940+
sigma_measurement_error = pm.HalfStudentT("sigma_MeasurementError", nu=3, sigma=1)
941+
942+
data_exog = pm.Data("data_exog", df_holidays.values, dims=["time", "exog_state"])
943+
beta_exog = pm.Normal("beta_exog", mu=0, sigma=1, dims=["exog_state"])
944+
945+
ss_mod.build_statespace_graph(y, mode="JAX")
946+
947+
idata = pm.sample_prior_predictive()
948+
949+
post = ss_mod.sample_conditional_prior(idata)
950+
951+
# Define start date and forecast period
952+
start_date, n_periods = pd.to_datetime("2024-4-15"), 8
953+
954+
# Extract exogenous data for the forecast period
955+
scenario = {
956+
"data_exog": pd.DataFrame(
957+
df_holidays.loc[start_date:].iloc[:n_periods], columns=df_holidays.columns
958+
)
959+
}
960+
961+
# Generate the forecast
962+
forecasts = ss_mod.forecast(idata.prior, scenario=scenario, use_scenario_index=True)

0 commit comments

Comments
 (0)