@@ -1027,6 +1027,9 @@ def _kalman_filter_outputs_from_dummy_graph(
1027
1027
provided when the model was built.
1028
1028
data_dims: str or tuple of str, optional
1029
1029
Dimension names associated with the model data. If None, defaults to ("time", "obs_state")
1030
+ scenario: dict[str, pd.DataFrame], optional
1031
+ Dictionary of out-of-sample scenario dataframes. If provided, it must have values for all data variables
1032
+ in the model. pm.set_data is used to replace training data with new values.
1030
1033
1031
1034
Returns
1032
1035
-------
@@ -2060,6 +2063,7 @@ def forecast(
2060
2063
2061
2064
with pm .Model (coords = temp_coords ) as forecast_model :
2062
2065
(_ , _ , * matrices ), grouped_outputs = self ._kalman_filter_outputs_from_dummy_graph (
2066
+ scenario = scenario ,
2063
2067
data_dims = ["data_time" , OBS_STATE_DIM ],
2064
2068
)
2065
2069
@@ -2073,17 +2077,6 @@ def forecast(
2073
2077
"P0_slice" , cov [t0_idx ], dims = cov_dims [1 :] if cov_dims is not None else None
2074
2078
)
2075
2079
2076
- if scenario is not None :
2077
- sub_dict = {
2078
- forecast_model [data_name ]: pt .as_tensor_variable (
2079
- scenario .get (data_name ), name = data_name
2080
- )
2081
- for data_name in self .data_names
2082
- }
2083
-
2084
- matrices = graph_replace (matrices , replace = sub_dict , strict = True )
2085
- [setattr (matrix , "name" , name ) for name , matrix in zip (MATRIX_NAMES [2 :], matrices )]
2086
-
2087
2080
_ = LinearGaussianStateSpace (
2088
2081
"forecast" ,
2089
2082
x0 ,
0 commit comments