Skip to content

Commit 325c6bd

Browse files
Add helper function to sample from statespace models
1 parent 1fb5536 commit 325c6bd

File tree

1 file changed

+48
-0
lines changed

1 file changed

+48
-0
lines changed
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import numpy as np
2+
import pymc as pm
3+
import pytensor
4+
import pytensor.tensor as pt
5+
6+
from pymc_experimental.statespace.core import PyMCStateSpace
7+
from pymc_experimental.statespace.filters.distributions import LinearGaussianStateSpace
8+
from pymc_experimental.statespace.utils.constants import SHORT_NAME_TO_LONG
9+
10+
11+
def compile_statespace(
12+
statespace_model: PyMCStateSpace, steps: int | None = None, **compile_kwargs
13+
):
14+
if steps is None:
15+
steps = pt.iscalar("steps")
16+
17+
x0, _, c, d, T, Z, R, H, Q = statespace_model._unpack_statespace_with_placeholders()
18+
19+
sequence_names = [x.name for x in [c, d] if x.ndim == 2]
20+
sequence_names += [x.name for x in [T, Z, R, H, Q] if x.ndim == 3]
21+
22+
rename_dict = {v: k for k, v in SHORT_NAME_TO_LONG.items()}
23+
sequence_names = list(map(rename_dict.get, sequence_names))
24+
25+
P0 = pt.zeros((x0.shape[0], x0.shape[0]))
26+
27+
outputs = LinearGaussianStateSpace.dist(
28+
x0, P0, c, d, T, Z, R, H, Q, steps=steps, sequence_names=sequence_names
29+
)
30+
31+
inputs = list(pytensor.graph.basic.explicit_graph_inputs(outputs))
32+
33+
_f = pm.compile_pymc(inputs, outputs, on_unused_input="ignore", **compile_kwargs)
34+
35+
def f(*, draws=1, **params):
36+
if isinstance(steps, pt.Variable):
37+
inner_steps = params.get("steps", 100)
38+
else:
39+
inner_steps = steps
40+
41+
output = [np.empty((draws, inner_steps + 1, x.type.shape[-1])) for x in outputs]
42+
for i in range(draws):
43+
draw = _f(**params)
44+
for j, x in enumerate(draw):
45+
output[j][i] = x
46+
return [x.squeeze() for x in output]
47+
48+
return f

0 commit comments

Comments
 (0)