@@ -49,7 +49,9 @@ class GaussianRandomWalkRV(RandomVariable):
49
49
dtype = "floatX"
50
50
_print_name = ("GaussianRandomWalk" , "\\ operatorname{GaussianRandomWalk}" )
51
51
52
- def _shape_from_params (self , dist_params , reop_param_idx = 0 , param_shapes = None ):
52
+ # TODO: Assert steps is a scalar!
53
+
54
+ def _shape_from_params (self , dist_params , ** kwargs ):
53
55
steps = dist_params [3 ]
54
56
55
57
# TODO: Ask ricardo why this is correct. Isn't shape different if size is passed?
@@ -95,6 +97,7 @@ def rng_fn(
95
97
ndarray
96
98
"""
97
99
100
+ # TODO: Maybe we can remove this constraint?
98
101
if steps is None or steps < 1 :
99
102
raise ValueError ("Steps must be None or greater than 0" )
100
103
@@ -145,17 +148,26 @@ class GaussianRandomWalk(distribution.Continuous):
145
148
146
149
rv_op = gaussianrandomwalk
147
150
148
- def __new__ (cls , name , mu = 0.0 , sigma = 1.0 , init = None , steps : int = 1 , ** kwargs ):
151
+ def __new__ (cls , name , mu = 0.0 , sigma = 1.0 , init = None , steps = None , ** kwargs ):
149
152
check_dist_not_registered (init )
150
153
return super ().__new__ (cls , name , mu , sigma , init , steps , ** kwargs )
151
154
152
155
@classmethod
153
156
def dist (
154
- cls , mu = 0.0 , sigma = 1.0 , init = None , steps : int = 1 , size = None , ** kwargs
157
+ cls , mu = 0.0 , sigma = 1.0 , init = None , steps = None , size = None , shape = None , ** kwargs
155
158
) -> RandomVariable :
156
159
157
160
mu = at .as_tensor_variable (floatX (mu ))
158
161
sigma = at .as_tensor_variable (floatX (sigma ))
162
+
163
+ if steps is None :
164
+ # We can infer steps from the shape, if it was given
165
+ if shape is not None :
166
+ steps = to_tuple (shape )[- 1 ] - 1
167
+ else :
168
+ # TODO: Raise ValueError?
169
+ steps = 1
170
+
159
171
steps = at .as_tensor_variable (intX (steps ))
160
172
161
173
if init is None :
@@ -175,7 +187,7 @@ def dist(
175
187
mu_ = at .broadcast_arrays (mu , sigma )[0 ]
176
188
init = change_rv_size (init , mu_ .shape )
177
189
178
- return super ().dist ([mu , sigma , init , steps ], size = size , ** kwargs )
190
+ return super ().dist ([mu , sigma , init , steps ], size = size , shape = shape , ** kwargs )
179
191
180
192
def logp (
181
193
value : at .Variable ,
0 commit comments