Skip to content

Commit 9d6dfa9

Browse files
committed
Fix init dist and add tests
1 parent e903875 commit 9d6dfa9

File tree

2 files changed

+69
-19
lines changed

2 files changed

+69
-19
lines changed

pymc/distributions/timeseries.py

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -104,19 +104,20 @@ def rng_fn(
104104

105105
# If size is None then the returned series should be (1+steps,)
106106
if size is None:
107-
init_size = 1
108-
steps_size = steps
107+
bcast_shape = np.broadcast_shapes(
108+
np.asarray(mu).shape,
109+
np.asarray(sigma).shape,
110+
np.asarray(init).shape,
111+
)
112+
dist_shape = (*bcast_shape, int(steps))
109113

110114
# If size is None then the returned series should be (size, 1+steps)
111115
else:
112116
init_size = (*size, 1)
113-
steps_size = (*size, steps)
114-
115-
init = np.reshape(init, init_size)
116-
steps = rng.normal(loc=mu, scale=sigma, size=steps_size)
117-
118-
grw = np.concatenate([init, steps], axis=-1)
117+
dist_shape = (*size, int(steps))
119118

119+
innovations = rng.normal(loc=mu, scale=sigma, size=dist_shape)
120+
grw = np.concatenate([init[..., None], innovations], axis=-1)
120121
return np.cumsum(grw, axis=-1)
121122

122123

@@ -149,7 +150,8 @@ class GaussianRandomWalk(distribution.Continuous):
149150
rv_op = gaussianrandomwalk
150151

151152
def __new__(cls, name, mu=0.0, sigma=1.0, init=None, steps=None, **kwargs):
152-
check_dist_not_registered(init)
153+
if init is not None:
154+
check_dist_not_registered(init)
153155
return super().__new__(cls, name, mu, sigma, init, steps, **kwargs)
154156

155157
@classmethod
@@ -163,14 +165,15 @@ def dist(
163165
raise ValueError("Must specify steps parameter")
164166
steps = at.as_tensor_variable(intX(steps))
165167

166-
if "shape" in kwargs.keys():
167-
shape = kwargs["shape"]
168+
shape = kwargs.get("shape", None)
169+
if size is None and shape is None:
170+
init_size = None
168171
else:
169-
shape = None
172+
init_size = to_tuple(size) if size is not None else to_tuple(shape)[:-1]
170173

171-
# If no scalar distribution is passed then initialize with a Normal of same sd and mu
174+
# If no scalar distribution is passed then initialize with a Normal of same mu and sigma
172175
if init is None:
173-
init = Normal.dist(mu, sigma, size=size)
176+
init = Normal.dist(mu, sigma, size=init_size)
174177
else:
175178
if not (
176179
isinstance(init, at.TensorVariable)
@@ -180,12 +183,12 @@ def dist(
180183
):
181184
raise TypeError("init must be a univariate distribution variable")
182185

183-
if size is not None or shape is not None:
184-
init = change_rv_size(init, to_tuple(size or shape))
186+
if init_size is not None:
187+
init = change_rv_size(init, init_size)
185188
else:
186-
# If not explicit, size is determined by the shape of mu and sigma
187-
mu_ = at.broadcast_arrays(mu, sigma)[0]
188-
init = change_rv_size(init, mu_.shape)
189+
# If not explicit, size is determined by the shapes of mu, sigma, and init
190+
bcast_shape = at.broadcast_arrays(mu, sigma, init)[0].shape
191+
init = change_rv_size(init, bcast_shape)
189192

190193
# Ignores logprob of init var because that's accounted for in the logp method
191194
init.tag.ignore_logprob = True

pymc/tests/test_distributions_timeseries.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
import numpy as np
1515
import pytest
16+
import scipy.stats
1617

1718
import pymc as pm
1819

@@ -77,6 +78,52 @@ def test_gaussianrandomwalk_inference():
7778
np.testing.assert_allclose([mu, sigma], [recovered_mu, recovered_sigma], atol=0.2)
7879

7980

81+
@pytest.mark.parametrize("init", [None, pm.Normal.dist()])
82+
def test_gaussian_random_walk_init_dist_shape(init):
83+
"""Test that init_dist is properly resized"""
84+
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=1, init=init)
85+
assert tuple(grw.owner.inputs[-2].shape.eval()) == ()
86+
87+
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=1, init=init, size=(5,))
88+
assert tuple(grw.owner.inputs[-2].shape.eval()) == (5,)
89+
90+
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=1, init=init, shape=1)
91+
assert tuple(grw.owner.inputs[-2].shape.eval()) == ()
92+
93+
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=1, init=init, shape=(5, 1))
94+
assert tuple(grw.owner.inputs[-2].shape.eval()) == (5,)
95+
96+
grw = pm.GaussianRandomWalk.dist(mu=[0, 0], sigma=1, steps=1, init=init)
97+
assert tuple(grw.owner.inputs[-2].shape.eval()) == (2,)
98+
99+
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=[1, 1], steps=1, init=init)
100+
assert tuple(grw.owner.inputs[-2].shape.eval()) == (2,)
101+
102+
grw = pm.GaussianRandomWalk.dist(mu=np.zeros((3, 1)), sigma=[1, 1], steps=1, init=init)
103+
assert tuple(grw.owner.inputs[-2].shape.eval()) == (3, 2)
104+
105+
106+
def test_gaussianrandomwalk_broadcasted_by_init_dist():
107+
grw = pm.GaussianRandomWalk.dist(mu=0, sigma=1, steps=4, init=pm.Normal.dist(size=(2, 3)))
108+
assert tuple(grw.shape.eval()) == (2, 3, 5)
109+
assert grw.eval().shape == (2, 3, 5)
110+
111+
112+
@pytest.mark.parametrize(
113+
"init",
114+
[
115+
pm.HalfNormal.dist(sigma=2),
116+
pm.StudentT.dist(nu=4, mu=1, sigma=0.5),
117+
],
118+
)
119+
def test_gaussian_random_walk_init_dist_logp(init):
120+
grw = pm.GaussianRandomWalk.dist(init=init, steps=1)
121+
assert np.isclose(
122+
pm.logp(grw, [0, 0]).eval(),
123+
pm.logp(init, 0).eval() + scipy.stats.norm.logpdf(0),
124+
)
125+
126+
80127
@pytest.mark.xfail(reason="Timeseries not refactored")
81128
def test_AR():
82129
# AR1

0 commit comments

Comments
 (0)