@@ -435,7 +435,7 @@ def __init__(
435
435
436
436
self .init = init
437
437
self .innovArgs = (mu , cov , tau , chol , lower )
438
- self .innov = multivariate .MvNormal .dist (* self .innovArgs )
438
+ self .innov = multivariate .MvNormal .dist (* self .innovArgs , shape = self . shape [ - 1 ] )
439
439
self .mean = tt .as_tensor_variable (0.0 )
440
440
441
441
def logp (self , x ):
@@ -469,7 +469,7 @@ def random(self, point=None, size=None):
469
469
point: dict, optional
470
470
Dict of variable values on which random values are to be
471
471
conditioned (uses default point if not specified).
472
- size: int, optional
472
+ size: int or tuple of ints , optional
473
473
Desired size of random sample (returns one sample if not
474
474
specified).
475
475
@@ -484,61 +484,44 @@ def random(self, point=None, size=None):
484
484
485
485
mu = np.array([1.0, 0.0])
486
486
cov = np.array([[1.0, 0.0], [0.0, 2.0]])
487
- sample = MvGaussianRandomWalk(mu, cov, shape=(10, 2)).random(size=1)
488
- """
487
+ sample = MvGaussianRandomWalk(mu, cov, shape=(10, 2)).random()
489
488
490
- param_attribute = getattr (
491
- self .innov , "chol_cov" if self .innov ._cov_type == "chol" else self .innov ._cov_type
492
- )
493
- mu , param = distribution .draw_values (
494
- [self .innov .mu , param_attribute ], point = point , size = size
495
- )
496
- return distribution .generate_samples (
497
- self ._random ,
498
- size = size ,
499
- dist_shape = self .shape ,
500
- not_broadcast_kwargs = {
501
- "sample_shape" : to_tuple (size ),
502
- "param" : param ,
503
- "mu" : mu ,
504
- "cov_type" : self .innov ._cov_type ,
505
- },
506
- )
489
+ Create three samples from a 2-dimensional Gaussian random walk with 10 timesteps::
507
490
508
- def _random (self , mu , param , size , sample_shape , cov_type ):
509
- """
510
- Implements the multivariate Gaussian random walk as a cumulative
511
- sum of i.i.d. multivariate Gaussians.
512
- Assumes that
513
- size is of the form (samples, time, dims).
491
+ mu = np.array([1.0, 0.0])
492
+ cov = np.array([[1.0, 0.0], [0.0, 2.0]])
493
+ sample = MvGaussianRandomWalk(mu, cov, shape=(10, 2)).random(size=3)
494
+
495
+ Create four samples from a 2-dimensional Gaussian random walk with 10
496
+ timesteps, indexed with a (2, 2) array::
497
+
498
+ mu = np.array([1.0, 0.0])
499
+ cov = np.array([[1.0, 0.0], [0.0, 2.0]])
500
+ sample = MvGaussianRandomWalk(mu, cov, shape=(10, 2)).random(size=(2, 2))
514
501
"""
515
502
516
- if cov_type == "chol" :
517
- cov = np .matmul (param , param .transpose ())
518
- elif cov_type == "tau" :
519
- cov = np .linalg .inv (param )
520
- else :
521
- cov = param
503
+ time_steps = self .shape [0 ]
504
+ size = to_tuple (size )
522
505
523
- # time axis comes after the sample axis
524
- time_axis = len (sample_shape )
506
+ # for each draw specified by the size input, we need to draw time_steps many
507
+ # samples from MvNormal.
508
+ size_time_steps = size + to_tuple (time_steps )
525
509
526
- # spatial axis is last
527
- spatial_axis = - 1
510
+ multivariate_samples = self . innov . random ( point = point , size = size_time_steps )
511
+ # this has shape (size, time_steps, MvNormal_shape)
528
512
529
- rv = stats . multivariate_normal ( mean = mu , cov = cov )
513
+ time_axis = len ( size )
530
514
531
- # only feed in sample and time dimensions since stats.multivariate_normal
532
- # automatically adds back in the spatial dimensions to the end when it samples.
533
- data = rv .rvs (size [:spatial_axis ]).cumsum (axis = time_axis )
515
+ multivariate_samples = multivariate_samples .cumsum (axis = time_axis )
534
516
535
517
# shift the walk to start at zero
536
- if len (data .shape ) > 2 :
537
- for i in range (size [0 ]):
538
- data [i ] = data [i ] - data [i ][0 ]
518
+ if len (multivariate_samples .shape ) > 2 :
519
+ # this for loop covers the case where size is a tuple
520
+ for idx in np .ndindex (size ):
521
+ multivariate_samples [idx ] = multivariate_samples [idx ] - multivariate_samples [idx ][0 ]
539
522
else :
540
- data = data - data [0 ]
541
- return data
523
+ multivariate_samples = multivariate_samples - multivariate_samples [0 ]
524
+ return multivariate_samples
542
525
543
526
544
527
class MvStudentTRandomWalk (MvGaussianRandomWalk ):
0 commit comments