Skip to content

Commit 057c8cf

Browse files
jessegrabowskiDekermanjian
authored andcommitted
Use set_data in forecast
1 parent 66caebd commit 057c8cf

File tree

1 file changed

+4
-11
lines changed

1 file changed

+4
-11
lines changed

pymc_extras/statespace/core/statespace.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1027,6 +1027,9 @@ def _kalman_filter_outputs_from_dummy_graph(
10271027
provided when the model was built.
10281028
data_dims: str or tuple of str, optional
10291029
Dimension names associated with the model data. If None, defaults to ("time", "obs_state")
1030+
scenario: dict[str, pd.DataFrame], optional
1031+
Dictionary of out-of-sample scenario dataframes. If provided, it must have values for all data variables
1032+
in the model. pm.set_data is used to replace training data with new values.
10301033
10311034
Returns
10321035
-------
@@ -2060,6 +2063,7 @@ def forecast(
20602063

20612064
with pm.Model(coords=temp_coords) as forecast_model:
20622065
(_, _, *matrices), grouped_outputs = self._kalman_filter_outputs_from_dummy_graph(
2066+
scenario=scenario,
20632067
data_dims=["data_time", OBS_STATE_DIM],
20642068
)
20652069

@@ -2073,17 +2077,6 @@ def forecast(
20732077
"P0_slice", cov[t0_idx], dims=cov_dims[1:] if cov_dims is not None else None
20742078
)
20752079

2076-
if scenario is not None:
2077-
sub_dict = {
2078-
forecast_model[data_name]: pt.as_tensor_variable(
2079-
scenario.get(data_name), name=data_name
2080-
)
2081-
for data_name in self.data_names
2082-
}
2083-
2084-
matrices = graph_replace(matrices, replace=sub_dict, strict=True)
2085-
[setattr(matrix, "name", name) for name, matrix in zip(MATRIX_NAMES[2:], matrices)]
2086-
20872080
_ = LinearGaussianStateSpace(
20882081
"forecast",
20892082
x0,

0 commit comments

Comments
 (0)