@@ -104,19 +104,20 @@ def rng_fn(
104
104
105
105
# If size is None then the returned series should be (1+steps,)
106
106
if size is None :
107
- init_size = 1
108
- steps_size = steps
107
+ bcast_shape = np .broadcast_shapes (
108
+ np .asarray (mu ).shape ,
109
+ np .asarray (sigma ).shape ,
110
+ np .asarray (init ).shape ,
111
+ )
112
+ dist_shape = (* bcast_shape , int (steps ))
109
113
110
114
# If size is None then the returned series should be (size, 1+steps)
111
115
else :
112
116
init_size = (* size , 1 )
113
- steps_size = (* size , steps )
114
-
115
- init = np .reshape (init , init_size )
116
- steps = rng .normal (loc = mu , scale = sigma , size = steps_size )
117
-
118
- grw = np .concatenate ([init , steps ], axis = - 1 )
117
+ dist_shape = (* size , int (steps ))
119
118
119
+ innovations = rng .normal (loc = mu , scale = sigma , size = dist_shape )
120
+ grw = np .concatenate ([init [..., None ], innovations ], axis = - 1 )
120
121
return np .cumsum (grw , axis = - 1 )
121
122
122
123
@@ -149,7 +150,8 @@ class GaussianRandomWalk(distribution.Continuous):
149
150
rv_op = gaussianrandomwalk
150
151
151
152
def __new__ (cls , name , mu = 0.0 , sigma = 1.0 , init = None , steps = None , ** kwargs ):
152
- check_dist_not_registered (init )
153
+ if init is not None :
154
+ check_dist_not_registered (init )
153
155
return super ().__new__ (cls , name , mu , sigma , init , steps , ** kwargs )
154
156
155
157
@classmethod
@@ -163,14 +165,15 @@ def dist(
163
165
raise ValueError ("Must specify steps parameter" )
164
166
steps = at .as_tensor_variable (intX (steps ))
165
167
166
- if "shape" in kwargs .keys ():
167
- shape = kwargs ["shape" ]
168
+ shape = kwargs .get ("shape" , None )
169
+ if size is None and shape is None :
170
+ init_size = None
168
171
else :
169
- shape = None
172
+ init_size = to_tuple ( size ) if size is not None else to_tuple ( shape )[: - 1 ]
170
173
171
- # If no scalar distribution is passed then initialize with a Normal of same sd and mu
174
+ # If no scalar distribution is passed then initialize with a Normal of same mu and sigma
172
175
if init is None :
173
- init = Normal .dist (mu , sigma , size = size )
176
+ init = Normal .dist (mu , sigma , size = init_size )
174
177
else :
175
178
if not (
176
179
isinstance (init , at .TensorVariable )
@@ -180,12 +183,12 @@ def dist(
180
183
):
181
184
raise TypeError ("init must be a univariate distribution variable" )
182
185
183
- if size is not None or shape is not None :
184
- init = change_rv_size (init , to_tuple ( size or shape ) )
186
+ if init_size is not None :
187
+ init = change_rv_size (init , init_size )
185
188
else :
186
- # If not explicit, size is determined by the shape of mu and sigma
187
- mu_ = at .broadcast_arrays (mu , sigma )[0 ]
188
- init = change_rv_size (init , mu_ . shape )
189
+ # If not explicit, size is determined by the shapes of mu, sigma, and init
190
+ bcast_shape = at .broadcast_arrays (mu , sigma , init )[0 ]. shape
191
+ init = change_rv_size (init , bcast_shape )
189
192
190
193
# Ignores logprob of init var because that's accounted for in the logp method
191
194
init .tag .ignore_logprob = True
0 commit comments