@@ -461,6 +461,74 @@ def _distr_parameters_for_repr(self):
461
461
return ["mu" , "cov" ]
462
462
463
463
464
+ def random (self , point = None , size = None ):
465
+ """
466
+ Draw random values from MvGaussianRandomWalk.
467
+
468
+ Parameters
469
+ ----------
470
+ point: dict, optional
471
+ Dict of variable values on which random values are to be
472
+ conditioned (uses default point if not specified).
473
+ size: int, optional
474
+ Desired size of random sample (returns one sample if not
475
+ specified).
476
+
477
+ Returns
478
+ -------
479
+ array
480
+ """
481
+
482
+ param_attribute = getattr (self .innov , "chol_cov" if self .innov ._cov_type == "chol" else self .innov ._cov_type )
483
+ mu , param = distribution .draw_values ([self .innov .mu , param_attribute ], point = point , size = size )
484
+ return distribution .generate_samples (
485
+ self ._random ,
486
+ size = size ,
487
+ dist_shape = self .shape ,
488
+ not_broadcast_kwargs = {
489
+ "sample_shape" : to_tuple (size ),
490
+ "param" : param ,
491
+ "mu" : mu ,
492
+ "cov_type" : self .innov ._cov_type
493
+ }
494
+ )
495
+
496
+ def _random (self , mu , param , size , sample_shape , cov_type ):
497
+ """
498
+ Implements the multivariate Gaussian random walk as a cumulative
499
+ sum of i.i.d. multivariate Gaussians.
500
+ Assumes that
501
+ size is of the form (samples, time, dims).
502
+ """
503
+
504
+ if cov_type == "chol" :
505
+ cov = np .matmul (param , param .transpose ())
506
+ elif cov_type == "tau" :
507
+ cov = np .linalg .inv (param )
508
+ else :
509
+ cov = param
510
+
511
+ # time axis comes after the sample axis
512
+ time_axis = len (sample_shape )
513
+
514
+ # spatial axis is last
515
+ spatial_axis = - 1
516
+
517
+ rv = stats .multivariate_normal (mean = mu , cov = cov )
518
+
519
+ # only feed in sample and time dimensions since stats.multivariate_normal
520
+ # automatically adds back in the spatial dimensions to the end when it samples.
521
+ data = rv .rvs (size [:spatial_axis ]).cumsum (axis = time_axis )
522
+
523
+ # shift the walk to start at zero
524
+ if len (data .shape ) > 2 :
525
+ for i in range (size [0 ]):
526
+ data [i ] = data [i ] - data [i ][0 ]
527
+ else :
528
+ data = data - data [0 ]
529
+ return data
530
+
531
+
464
532
class MvStudentTRandomWalk (MvGaussianRandomWalk ):
465
533
r"""
466
534
Multivariate Random Walk with StudentT innovations
0 commit comments