Skip to content

Commit 47571a7

Browse files
committed
Infer steps from shape
1 parent 2171790 commit 47571a7

File tree

2 files changed

+23
-4
lines changed

2 files changed

+23
-4
lines changed

pymc/distributions/timeseries.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,9 @@ class GaussianRandomWalkRV(RandomVariable):
4949
dtype = "floatX"
5050
_print_name = ("GaussianRandomWalk", "\\operatorname{GaussianRandomWalk}")
5151

52-
def _shape_from_params(self, dist_params, reop_param_idx=0, param_shapes=None):
52+
# TODO: Assert steps is a scalar!
53+
54+
def _shape_from_params(self, dist_params, **kwargs):
5355
steps = dist_params[3]
5456

5557
# TODO: Ask ricardo why this is correct. Isn't shape different if size is passed?
@@ -95,6 +97,7 @@ def rng_fn(
9597
ndarray
9698
"""
9799

100+
# TODO: Maybe we can remove this contraint?
98101
if steps is None or steps < 1:
99102
raise ValueError("Steps must be None or greater than 0")
100103

@@ -145,17 +148,26 @@ class GaussianRandomWalk(distribution.Continuous):
145148

146149
rv_op = gaussianrandomwalk
147150

148-
def __new__(cls, name, mu=0.0, sigma=1.0, init=None, steps: int = 1, **kwargs):
151+
def __new__(cls, name, mu=0.0, sigma=1.0, init=None, steps=None, **kwargs):
149152
check_dist_not_registered(init)
150153
return super().__new__(cls, name, mu, sigma, init, steps, **kwargs)
151154

152155
@classmethod
153156
def dist(
154-
cls, mu=0.0, sigma=1.0, init=None, steps: int = 1, size=None, **kwargs
157+
cls, mu=0.0, sigma=1.0, init=None, steps=None, size=None, shape=None, **kwargs
155158
) -> RandomVariable:
156159

157160
mu = at.as_tensor_variable(floatX(mu))
158161
sigma = at.as_tensor_variable(floatX(sigma))
162+
163+
if steps is None:
164+
# We can infer steps from the shape, if it was given
165+
if shape is not None:
166+
steps = to_tuple(shape)[-1] - 1
167+
else:
168+
# TODO: Raise ValueError?
169+
steps = 1
170+
159171
steps = at.as_tensor_variable(intX(steps))
160172

161173
if init is None:
@@ -175,7 +187,7 @@ def dist(
175187
mu_ = at.broadcast_arrays(mu, sigma)[0]
176188
init = change_rv_size(init, mu_.shape)
177189

178-
return super().dist([mu, sigma, init, steps], size=size, **kwargs)
190+
return super().dist([mu, sigma, init, steps], size=size, shape=shape, **kwargs)
179191

180192
def logp(
181193
value: at.Variable,

pymc/tests/test_distributions_timeseries.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,13 @@ def test_init_automatically_resized(self, size):
122122
init = x.owner.inputs[-2]
123123
assert init.eval().shape == size if size is not None else (2,)
124124

125+
@pytest.mark.parametrize("shape", (None, (6,), (3, 6)))
126+
def test_inferred_steps_from_shape(self, shape):
127+
with pm.Model() as m:
128+
x = GaussianRandomWalk("x", shape=shape)
129+
steps = x.owner.inputs[-1]
130+
assert steps.eval() == 5 if shape is not None else 1
131+
125132

126133
@pytest.mark.xfail(reason="Timeseries not refactored")
127134
def test_AR():

0 commit comments

Comments
 (0)