Skip to content

Commit 512822a

Browse files
committed
Implement random for MvGaussianRandomWalk
Implements the random method for MvGaussianRandomWalk, partially fixing #4337.
1 parent acb8da0 commit 512822a

File tree

1 file changed

+68
-0
lines changed

1 file changed

+68
-0
lines changed

pymc3/distributions/timeseries.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,74 @@ def _distr_parameters_for_repr(self):
461461
return ["mu", "cov"]
462462

463463

464+
def random(self, point=None, size=None):
465+
"""
466+
Draw random values from MvGaussianRandomWalk.
467+
468+
Parameters
469+
----------
470+
point: dict, optional
471+
Dict of variable values on which random values are to be
472+
conditioned (uses default point if not specified).
473+
size: int, optional
474+
Desired size of random sample (returns one sample if not
475+
specified).
476+
477+
Returns
478+
-------
479+
array
480+
"""
481+
482+
param_attribute = getattr(self.innov, "chol_cov" if self.innov._cov_type == "chol" else self.innov._cov_type)
483+
mu, param = distribution.draw_values([self.innov.mu, param_attribute], point=point, size=size)
484+
return distribution.generate_samples(
485+
self._random,
486+
size=size,
487+
dist_shape=self.shape,
488+
not_broadcast_kwargs={
489+
"sample_shape": to_tuple(size),
490+
"param": param,
491+
"mu": mu,
492+
"cov_type": self.innov._cov_type
493+
}
494+
)
495+
496+
def _random(self, mu, param, size, sample_shape, cov_type):
497+
"""
498+
Implements the multivariate Gaussian random walk as a cumulative
499+
sum of i.i.d. multivariate Gaussians.
500+
Assumes that
501+
size is of the form (samples, time, dims).
502+
"""
503+
504+
if cov_type == "chol":
505+
cov = np.matmul(param, param.transpose())
506+
elif cov_type == "tau":
507+
cov = np.linalg.inv(param)
508+
else:
509+
cov = param
510+
511+
# time axis comes after the sample axis
512+
time_axis = len(sample_shape)
513+
514+
# spatial axis is last
515+
spatial_axis = -1
516+
517+
rv = stats.multivariate_normal(mean=mu, cov=cov)
518+
519+
# only feed in sample and time dimensions since stats.multivariate_normal
520+
# automatically adds back in the spatial dimensions to the end when it samples.
521+
data = rv.rvs(size[:spatial_axis]).cumsum(axis=time_axis)
522+
523+
# shift the walk to start at zero
524+
if len(data.shape) > 2:
525+
for i in range(size[0]):
526+
data[i] = data[i] - data[i][0]
527+
else:
528+
data = data - data[0]
529+
return data
530+
531+
464532
class MvStudentTRandomWalk(MvGaussianRandomWalk):
465533
r"""
466534
Multivariate Random Walk with StudentT innovations

0 commit comments

Comments
 (0)