Skip to content

Commit d23458e

Browse files
committed
Allow init to be a distribution again
1 parent 31a37e6 commit d23458e

File tree

3 files changed

+53
-15
lines changed

3 files changed

+53
-15
lines changed

pymc/distributions/timeseries.py

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from aesara import scan
2121
from aesara.tensor.random.op import RandomVariable
2222

23-
from pymc.aesaraf import floatX, intX
23+
from pymc.aesaraf import change_rv_size, floatX, intX
2424
from pymc.distributions import distribution, logprob, multivariate
2525
from pymc.distributions.continuous import Flat, Normal, get_tau_sigma
2626
from pymc.distributions.shape_utils import to_tuple
@@ -35,6 +35,8 @@
3535
"MvStudentTRandomWalk",
3636
]
3737

38+
from pymc.util import check_dist_not_registered
39+
3840

3941
class GaussianRandomWalkRV(RandomVariable):
4042
"""
@@ -107,10 +109,10 @@ def rng_fn(
107109
init_size = (*size, 1)
108110
steps_size = (*size, steps)
109111

110-
init_val = rng.normal(init, sigma, size=init_size)
112+
init = np.reshape(init, init_size)
111113
steps = rng.normal(loc=mu, scale=sigma, size=steps_size)
112114

113-
grw = np.concatenate([init_val, steps], axis=-1)
115+
grw = np.concatenate([init, steps], axis=-1)
114116

115117
return np.cumsum(grw, axis=-1)
116118

@@ -132,8 +134,9 @@ class GaussianRandomWalk(distribution.Continuous):
132134
innovation drift, defaults to 0.0
133135
sigma: tensor_like of float, optional
134136
sigma > 0, innovation standard deviation, defaults to 0.0
135-
init: tensor_like of float, optional
136-
Mean value of initialization, defaults to 0.0
137+
init: Scalar PyMC distribution
138+
Scalar distribution of the initial value, created with the `.dist()` API. Defaults to
139+
Normal with same `mu` and `sigma` as the GaussianRandomWalk
137140
steps: int
138141
Number of steps in Gaussian Random Walks
139142
size: int
@@ -142,14 +145,37 @@ class GaussianRandomWalk(distribution.Continuous):
142145

143146
rv_op = gaussianrandomwalk
144147

148+
def __new__(cls, name, mu=0.0, sigma=1.0, init=None, steps: int = 1, **kwargs):
149+
check_dist_not_registered(init)
150+
return super().__new__(cls, name, mu, sigma, init, steps, **kwargs)
151+
145152
@classmethod
146-
def dist(cls, mu=0.0, sigma=1.0, *, steps: int, init=0.0, **kwargs) -> RandomVariable:
153+
def dist(
154+
cls, mu=0.0, sigma=1.0, init=None, steps: int = 1, size=None, **kwargs
155+
) -> RandomVariable:
156+
157+
mu = at.as_tensor_variable(floatX(mu))
158+
sigma = at.as_tensor_variable(floatX(sigma))
159+
steps = at.as_tensor_variable(intX(steps))
147160

148-
params = [at.as_tensor_variable(floatX(param)) for param in (mu, sigma, init)] + [
149-
at.as_tensor_variable(intX(steps))
150-
]
161+
if init is None:
162+
init = Normal.dist(mu, sigma, size=size)
163+
else:
164+
if not (
165+
isinstance(init, at.TensorVariable)
166+
and init.owner is not None
167+
and isinstance(init.owner.op, RandomVariable)
168+
and init.owner.op.ndim_supp == 0
169+
):
170+
raise TypeError("init must be a scalar distribution variable")
171+
if size is not None or shape is not None:
172+
init = change_rv_size(init, to_tuple(size or shape))
173+
else:
174+
# If not explicit, size is determined by the shape of mu and sigma
175+
mu_ = at.broadcast_arrays(mu, sigma)[0]
176+
init = change_rv_size(init, mu_.shape)
151177

152-
return super().dist(params, **kwargs)
178+
return super().dist([mu, sigma, init, steps], size=size, **kwargs)
153179

154180
def logp(
155181
value: at.Variable,
@@ -174,11 +200,13 @@ def logp(
174200
"""
175201

176202
# Calculate initialization logp
203+
init_logp = logprob.logp(init, value[..., 0])
204+
177205
# Make time series stationary around the mean value
178-
stationary_series = at.diff(value)
206+
stationary_series = at.diff(value, axis=-1)
179207
series_logp = logprob.logp(Normal.dist(mu, sigma), stationary_series)
180208

181-
return series_logp
209+
return init_logp + series_logp.sum(axis=-1)
182210

183211

184212
class AR1(distribution.Continuous):

pymc/tests/test_distributions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2622,7 +2622,7 @@ def test_grw(self):
26222622
pm.GaussianRandomWalk,
26232623
R,
26242624
{"mu": R, "sigma": Rplus, "steps": Nat},
2625-
lambda value, mu, sigma: sp.norm.logpdf(value, mu, sigma).cumsum(),
2625+
lambda value, mu, sigma: sp.norm.logpdf(value, mu, sigma).cumsum().sum(),
26262626
decimal=select_by_precision(float64=6, float32=1),
26272627
)
26282628

pymc/tests/test_distributions_timeseries.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def test_grw_logp(self):
6161
vals = [0, 1, 2]
6262
mu = 1
6363
sigma = 1
64-
init = 0
64+
init = pm.Normal.dist(mu, sigma)
6565

6666
with pm.Model():
6767
grw = GaussianRandomWalk("grw", mu, sigma, init=init, steps=2)
@@ -88,7 +88,7 @@ def test_grw_inference(self):
8888
with pm.Model():
8989
_mu = pm.Uniform("mu", -10, 10)
9090
_sigma = pm.Uniform("sigma", 0, 10)
91-
grw = GaussianRandomWalk("grw", _mu, _sigma, init=0, steps=steps, observed=obs)
91+
grw = GaussianRandomWalk("grw", _mu, _sigma, steps=steps, observed=obs)
9292

9393
with pytest.raises(TypeError) as err:
9494
trace = pm.sample()
@@ -113,6 +113,16 @@ def test_grw_shape(self, steps, size, expected):
113113
expected_symbolic = tuple(grw_dist.shape.eval())
114114
assert expected_symbolic == expected
115115

116+
@pytest.mark.parametrize("size", (None, (1, 2), (10, 2), (3, 100, 2)))
117+
def test_init_automatically_resized(self, size):
118+
x = GaussianRandomWalk.dist(mu=[0, 1], init=pm.Normal.dist(), size=size)
119+
init = x.owner.inputs[-2]
120+
assert init.eval().shape == size if size is not None else (2,)
121+
122+
x = GaussianRandomWalk.dist(mu=[0, 1], init=pm.Normal.dist(size=5), shape=size)
123+
init = x.owner.inputs[-2]
124+
assert init.eval().shape == size if size is not None else (2,)
125+
116126

117127
@pytest.mark.xfail(reason="Timeseries not refactored")
118128
def test_AR():

0 commit comments

Comments
 (0)