Skip to content

Commit c3cc70c

Browse files
committed
Modify MvGaussianRandomWalk.random
Improves the implementation by using MvNormal.random as suggested by @Sayam753. Also updated its docstring to add more examples.
1 parent fe33fe5 commit c3cc70c

File tree

1 file changed

+29
-46
lines changed

1 file changed

+29
-46
lines changed

pymc3/distributions/timeseries.py

Lines changed: 29 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,7 @@ def __init__(
435435

436436
self.init = init
437437
self.innovArgs = (mu, cov, tau, chol, lower)
438-
self.innov = multivariate.MvNormal.dist(*self.innovArgs)
438+
self.innov = multivariate.MvNormal.dist(*self.innovArgs, shape=self.shape[-1])
439439
self.mean = tt.as_tensor_variable(0.0)
440440

441441
def logp(self, x):
@@ -469,7 +469,7 @@ def random(self, point=None, size=None):
469469
point: dict, optional
470470
Dict of variable values on which random values are to be
471471
conditioned (uses default point if not specified).
472-
size: int, optional
472+
size: int or tuple of ints, optional
473473
Desired size of random sample (returns one sample if not
474474
specified).
475475
@@ -484,61 +484,44 @@ def random(self, point=None, size=None):
484484
485485
mu = np.array([1.0, 0.0])
486486
cov = np.array([[1.0, 0.0], [0.0, 2.0]])
487-
sample = MvGaussianRandomWalk(mu, cov, shape=(10, 2)).random(size=1)
488-
"""
487+
sample = MvGaussianRandomWalk(mu, cov, shape=(10, 2)).random()
489488
490-
param_attribute = getattr(
491-
self.innov, "chol_cov" if self.innov._cov_type == "chol" else self.innov._cov_type
492-
)
493-
mu, param = distribution.draw_values(
494-
[self.innov.mu, param_attribute], point=point, size=size
495-
)
496-
return distribution.generate_samples(
497-
self._random,
498-
size=size,
499-
dist_shape=self.shape,
500-
not_broadcast_kwargs={
501-
"sample_shape": to_tuple(size),
502-
"param": param,
503-
"mu": mu,
504-
"cov_type": self.innov._cov_type,
505-
},
506-
)
489+
Create three samples from a 2-dimensional Gaussian random walk with 10 timesteps::
507490
508-
def _random(self, mu, param, size, sample_shape, cov_type):
509-
"""
510-
Implements the multivariate Gaussian random walk as a cumulative
511-
sum of i.i.d. multivariate Gaussians.
512-
Assumes that
513-
size is of the form (samples, time, dims).
491+
mu = np.array([1.0, 0.0])
492+
cov = np.array([[1.0, 0.0], [0.0, 2.0]])
493+
sample = MvGaussianRandomWalk(mu, cov, shape=(10, 2)).random(size=3)
494+
495+
Create four samples from a 2-dimensional Gaussian random walk with 10
496+
timesteps, indexed with a (2, 2) array::
497+
498+
mu = np.array([1.0, 0.0])
499+
cov = np.array([[1.0, 0.0], [0.0, 2.0]])
500+
sample = MvGaussianRandomWalk(mu, cov, shape=(10, 2)).random(size=(2, 2))
514501
"""
515502

516-
if cov_type == "chol":
517-
cov = np.matmul(param, param.transpose())
518-
elif cov_type == "tau":
519-
cov = np.linalg.inv(param)
520-
else:
521-
cov = param
503+
time_steps = self.shape[0]
504+
size = to_tuple(size)
522505

523-
# time axis comes after the sample axis
524-
time_axis = len(sample_shape)
506+
# for each draw specified by the size input, we need to draw time_steps many
507+
# samples from MvNormal.
508+
size_time_steps = size + to_tuple(time_steps)
525509

526-
# spatial axis is last
527-
spatial_axis = -1
510+
multivariate_samples = self.innov.random(point=point, size=size_time_steps)
511+
# this has shape (size, time_steps, MvNormal_shape)
528512

529-
rv = stats.multivariate_normal(mean=mu, cov=cov)
513+
time_axis = len(size)
530514

531-
# only feed in sample and time dimensions since stats.multivariate_normal
532-
# automatically adds back in the spatial dimensions to the end when it samples.
533-
data = rv.rvs(size[:spatial_axis]).cumsum(axis=time_axis)
515+
multivariate_samples = multivariate_samples.cumsum(axis=time_axis)
534516

535517
# shift the walk to start at zero
536-
if len(data.shape) > 2:
537-
for i in range(size[0]):
538-
data[i] = data[i] - data[i][0]
518+
if len(multivariate_samples.shape) > 2:
519+
# this for loop covers the case where size is a tuple
520+
for idx in np.ndindex(size):
521+
multivariate_samples[idx] = multivariate_samples[idx] - multivariate_samples[idx][0]
539522
else:
540-
data = data - data[0]
541-
return data
523+
multivariate_samples = multivariate_samples - multivariate_samples[0]
524+
return multivariate_samples
542525

543526

544527
class MvStudentTRandomWalk(MvGaussianRandomWalk):

0 commit comments

Comments
 (0)