Skip to content

Commit 292dbdd

Browse files
Add ETS model
1 parent d50742d commit 292dbdd

File tree

6 files changed

+613
-3
lines changed

6 files changed

+613
-3
lines changed

pymc_experimental/statespace/core/statespace.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,19 @@ def _print_data_requirements(self) -> None:
301301
f"{out}"
302302
)
303303

304-
def _unpack_statespace_with_placeholders(self) -> tuple[pt.TensorVariable]:
304+
def _unpack_statespace_with_placeholders(
305+
self,
306+
) -> tuple[
307+
pt.TensorVariable,
308+
pt.TensorVariable,
309+
pt.TensorVariable,
310+
pt.TensorVariable,
311+
pt.TensorVariable,
312+
pt.TensorVariable,
313+
pt.TensorVariable,
314+
pt.TensorVariable,
315+
pt.TensorVariable,
316+
]:
305317
"""
306318
Helper function to quickly obtain all statespace matrices in the standard order. Matrices returned by this
307319
method will include pytensor placeholders.
@@ -445,7 +457,7 @@ def add_default_priors(self) -> None:
445457
raise NotImplementedError("The add_default_priors property has not been implemented!")
446458

447459
def make_and_register_variable(
448-
self, name, shape: int | tuple[int] | None = None, dtype=floatX
460+
self, name, shape: int | tuple[int, ...] | None = None, dtype=floatX
449461
) -> Variable:
450462
"""
451463
Helper function to create a pytensor symbolic variable and register it in the _name_to_variable dictionary

0 commit comments

Comments
 (0)