@@ -870,3 +870,93 @@ def test_forecast_with_exog_data(rng, exog_ss_mod, idata_exog, start):
870
870
regression_effect_expected = (betas * scenario_xr ).sum (dim = ["state" ])
871
871
872
872
assert_allclose (regression_effect , regression_effect_expected )
873
+
874
+
875
@pytest.mark.filterwarnings("ignore:Provided data contains missing values.")
@pytest.mark.filterwarnings("ignore:The RandomType SharedVariables")
def test_foreacast_valid_index(rng):
    """Forecast with ``use_scenario_index=True`` from a date inside the sample.

    Regression test for https://github.com/pymc-devs/pymc-extras/issues/424

    Builds a structural model (trend, weekly and annual seasonality, AR(1),
    measurement error and a two-column holiday regression) on daily data that
    contains missing observations, draws from the prior, and forecasts using
    the scenario frame's own DatetimeIndex as the forecast index.
    """
    index = pd.date_range(start="2023-05-01", end="2025-01-29", freq="D")
    n_obs, n_holidays = len(index), 2

    # Sparse 0/1 indicators: ten random dates, each flagged in a random column.
    holiday_flags = np.zeros((n_obs, n_holidays))
    flag_rows = rng.choice(n_obs, size=10, replace=False)
    flag_cols = rng.choice(n_holidays, size=10, replace=True)
    holiday_flags[flag_rows, flag_cols] = 1
    df_holidays = pd.DataFrame(holiday_flags, index=index, columns=["Holiday 1", "Holiday 2"])

    # Observed series with ten randomly placed missing values.
    sales = rng.normal(size=(n_obs, 1))
    nan_locs = rng.choice(n_obs, size=10, replace=False)
    sales[nan_locs] = np.nan
    y = pd.DataFrame(sales, index=index, columns=["sales"])

    level_trend = st.LevelTrendComponent(order=1, innovations_order=[0])
    weekly_seasonality = st.TimeSeasonality(
        season_length=7,
        state_names=["Sun", "Mon", "Tues", "Wed", "Thu", "Fri", "Sat"],
        innovations=True,
        remove_first_state=False,
    )
    annual_seasonality = st.FrequencySeasonality(season_length=365, n=2, innovations=True)
    ar1 = st.AutoregressiveComponent(order=1)
    me = st.MeasurementError()
    exog = st.RegressionComponent(
        name="exog",
        k_exog=2,  # one regressor per holiday column
        innovations=False,
        state_names=df_holidays.columns.tolist(),
    )

    combined_model = level_trend + weekly_seasonality + annual_seasonality + me + ar1 + exog
    ss_mod = combined_model.build()

    # The priors below are registered on the model by name, so most of them do
    # not need a Python binding.
    with pm.Model(coords=ss_mod.coords):
        P0_diag = pm.Gamma("P0_diag", alpha=2, beta=10, dims=["state"])
        pm.Deterministic("P0", pt.diag(P0_diag), dims=["state", "state_aux"])

        pm.Normal("initial_trend", mu=[0], sigma=[0.005], dims=["trend_state"])

        # Day-of-week deviations from the weekly mean, and their innovation scale.
        pm.ZeroSumNormal("Seasonal[s=7]_coefs", sigma=0.5, dims=["Seasonal[s=7]_state"])
        pm.Gamma("sigma_Seasonal[s=7]", alpha=2, beta=1)

        # Coefficients of the frequency-domain seasonal terms, and their innovation scale.
        pm.Normal(
            "Frequency[s=365, n=2]", mu=0, sigma=0.5, dims=["Frequency[s=365, n=2]_state"]
        )
        pm.Gamma("sigma_Frequency[s=365, n=2]", alpha=2, beta=1)

        pm.Laplace("ar_params", mu=0, b=0.2, dims=["ar_lag"])
        pm.Gamma("sigma_ar", alpha=2, beta=1)

        pm.HalfStudentT("sigma_MeasurementError", nu=3, sigma=1)

        pm.Data("data_exog", df_holidays.values, dims=["time", "exog_state"])
        pm.Normal("beta_exog", mu=0, sigma=1, dims=["exog_state"])

        ss_mod.build_statespace_graph(y, mode="JAX")

        idata = pm.sample_prior_predictive()

    # Result is not inspected; the call only needs to complete without error.
    ss_mod.sample_conditional_prior(idata)

    # Forecast window starts inside the observed sample.
    start_date, n_periods = pd.to_datetime("2024-4-15"), 8

    # Exogenous data for the forecast period, indexed by the forecast dates.
    scenario = {
        "data_exog": pd.DataFrame(
            df_holidays.loc[start_date:].iloc[:n_periods], columns=df_holidays.columns
        )
    }

    forecasts = ss_mod.forecast(idata.prior, scenario=scenario, use_scenario_index=True)

    # The forecast must contain both outputs, on exactly the scenario's dates.
    assert "forecast_latent" in forecasts
    assert "forecast_observed" in forecasts
    assert np.array_equal(
        forecasts.coords["time"].values, scenario["data_exog"].index.values
    )
0 commit comments