45
45
from aesara .sandbox .rng_mrg import MRG_RandomStream as RandomStream
46
46
from aesara .tensor .elemwise import Elemwise
47
47
from aesara .tensor .random .op import RandomVariable
48
+ from aesara .tensor .shape import SpecifyShape
48
49
from aesara .tensor .sharedvar import SharedVariable
49
50
from aesara .tensor .subtensor import AdvancedIncSubtensor , AdvancedIncSubtensor1
50
51
from aesara .tensor .var import TensorVariable
51
52
53
+ from pymc3 .exceptions import ShapeError
52
54
from pymc3 .vartypes import continuous_types , int_types , isgenerator , typefilter
53
55
54
56
PotentialShapeType = Union [
@@ -146,6 +148,16 @@ def change_rv_size(
146
148
Expand the existing size by `new_size`.
147
149
148
150
"""
151
+ # Check the dimensionality of the `new_size` kwarg
152
+ new_size_ndim = np .ndim (new_size )
153
+ if new_size_ndim > 1 :
154
+ raise ShapeError ("The `new_size` must be ≤1-dimensional." , actual = new_size_ndim )
155
+ elif new_size_ndim == 0 :
156
+ new_size = (new_size ,)
157
+
158
+ # Extract the RV node that is to be resized, together with its inputs, name and tag
159
+ if isinstance (rv_var .owner .op , SpecifyShape ):
160
+ rv_var = rv_var .owner .inputs [0 ]
149
161
rv_node = rv_var .owner
150
162
rng , size , dtype , * dist_params = rv_node .inputs
151
163
name = rv_var .name
@@ -154,10 +166,10 @@ def change_rv_size(
154
166
if expand :
155
167
if rv_node .op .ndim_supp == 0 and at .get_vector_length (size ) == 0 :
156
168
size = rv_node .op ._infer_shape (size , dist_params )
157
- new_size = tuple (np . atleast_1d ( new_size ) ) + tuple (size )
169
+ new_size = tuple (new_size ) + tuple (size )
158
170
159
- # Make sure the new size is a tensor. This helps to not unnecessarily pick
160
- # up a `Cast` in some cases
171
+ # Make sure the new size is a tensor. This dtype-aware conversion helps
172
+ # to not unnecessarily pick up a `Cast` in some cases (see #4652).
161
173
new_size = at .as_tensor (new_size , ndim = 1 , dtype = "int64" )
162
174
163
175
new_rv_node = rv_node .op .make_node (rng , new_size , dtype , * dist_params )
0 commit comments