Skip to content

Commit e903875

Browse files
committed
Make steps required explicitly and check it is a scalar
1 parent c67dd8b commit e903875

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

pymc/distributions/timeseries.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,11 @@ class GaussianRandomWalkRV(RandomVariable):
4949
dtype = "floatX"
5050
_print_name = ("GaussianRandomWalk", "\\operatorname{GaussianRandomWalk}")
5151

52+
def make_node(self, rng, size, dtype, mu, sigma, init, steps):
53+
steps = at.as_tensor_variable(steps)
54+
if not steps.ndim == 0 or not steps.dtype.startswith("float"):
55+
raise ValueError("steps must be an integer scalar (ndim=0).")
56+
5257
def _supp_shape_from_params(self, dist_params, reop_param_idx=0, param_shapes=None):
5358
steps = dist_params[3]
5459

@@ -94,8 +99,8 @@ def rng_fn(
9499
ndarray
95100
"""
96101

97-
if steps is None or steps < 1:
98-
raise ValueError("Steps must be None or greater than 0")
102+
if steps < 1:
103+
raise ValueError("Steps must be greater than 0")
99104

100105
# If size is None then the returned series should be (1+steps,)
101106
if size is None:
@@ -143,17 +148,19 @@ class GaussianRandomWalk(distribution.Continuous):
143148

144149
rv_op = gaussianrandomwalk
145150

146-
def __new__(cls, name, mu=0.0, sigma=1.0, init=None, steps: int = 1, **kwargs):
151+
def __new__(cls, name, mu=0.0, sigma=1.0, init=None, steps=None, **kwargs):
147152
check_dist_not_registered(init)
148153
return super().__new__(cls, name, mu, sigma, init, steps, **kwargs)
149154

150155
@classmethod
151156
def dist(
152-
cls, mu=0.0, sigma=1.0, init=None, steps: int = 1, size=None, **kwargs
157+
cls, mu=0.0, sigma=1.0, init=None, steps=None, size=None, **kwargs
153158
) -> at.TensorVariable:
154159

155160
mu = at.as_tensor_variable(floatX(mu))
156161
sigma = at.as_tensor_variable(floatX(sigma))
162+
if steps is None:
163+
raise ValueError("Must specify steps parameter")
157164
steps = at.as_tensor_variable(intX(steps))
158165

159166
if "shape" in kwargs.keys():

pymc/tests/test_distributions_timeseries.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ def check_rv_inferred_size(self):
5454
expected_symbolic = tuple(pymc_rv.shape.eval())
5555
assert expected_symbolic == expected
5656

57+
def test_steps_scalar_check(self):
58+
with pytest.raises(ValueError, match="steps must be an integer scalar"):
59+
self.pymc_dist.dist(steps=[1])
60+
5761

5862
def test_gaussianrandomwalk_inference():
5963
mu, sigma, steps = 2, 1, 1000

0 commit comments

Comments
 (0)