Skip to content

Commit 6272f28

Browse files
michaelosthegericardoV94twiecki
committed
Make change_rv_size more robust
Co-authored-by: Ricardo <[email protected]> Co-authored-by: Thomas Wiecki <[email protected]>
1 parent 60347c0 commit 6272f28

File tree

3 files changed

+25
-8
lines changed

3 files changed

+25
-8
lines changed

pymc3/aesaraf.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,12 @@
4545
from aesara.sandbox.rng_mrg import MRG_RandomStream as RandomStream
4646
from aesara.tensor.elemwise import Elemwise
4747
from aesara.tensor.random.op import RandomVariable
48+
from aesara.tensor.shape import SpecifyShape
4849
from aesara.tensor.sharedvar import SharedVariable
4950
from aesara.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
5051
from aesara.tensor.var import TensorVariable
5152

53+
from pymc3.exceptions import ShapeError
5254
from pymc3.vartypes import continuous_types, int_types, isgenerator, typefilter
5355

5456
PotentialShapeType = Union[
@@ -146,6 +148,16 @@ def change_rv_size(
146148
Expand the existing size by `new_size`.
147149
148150
"""
151+
# Check the dimensionality of the `new_size` kwarg
152+
new_size_ndim = np.ndim(new_size)
153+
if new_size_ndim > 1:
154+
raise ShapeError("The `new_size` must be ≤1-dimensional.", actual=new_size_ndim)
155+
elif new_size_ndim == 0:
156+
new_size = (new_size,)
157+
158+
# Extract the RV node that is to be resized, together with its inputs, name and tag
159+
if isinstance(rv_var.owner.op, SpecifyShape):
160+
rv_var = rv_var.owner.inputs[0]
149161
rv_node = rv_var.owner
150162
rng, size, dtype, *dist_params = rv_node.inputs
151163
name = rv_var.name
@@ -154,10 +166,10 @@ def change_rv_size(
154166
if expand:
155167
if rv_node.op.ndim_supp == 0 and at.get_vector_length(size) == 0:
156168
size = rv_node.op._infer_shape(size, dist_params)
157-
new_size = tuple(np.atleast_1d(new_size)) + tuple(size)
169+
new_size = tuple(new_size) + tuple(size)
158170

159-
# Make sure the new size is a tensor. This helps to not unnecessarily pick
160-
# up a `Cast` in some cases
171+
# Make sure the new size is a tensor. This dtype-aware conversion helps
172+
# to not unnecessarily pick up a `Cast` in some cases (see #4652).
161173
new_size = at.as_tensor(new_size, ndim=1, dtype="int64")
162174

163175
new_rv_node = rv_node.op.make_node(rng, new_size, dtype, *dist_params)

pymc3/tests/test_aesaraf.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
take_along_axis,
4242
walk_model,
4343
)
44+
from pymc3.exceptions import ShapeError
4445
from pymc3.vartypes import int_types
4546

4647
FLOATX = str(aesara.config.floatX)
@@ -53,6 +54,11 @@ def test_change_rv_size():
5354
assert rv.ndim == 1
5455
assert rv.eval().shape == (2,)
5556

57+
with pytest.raises(ShapeError, match="must be ≤1-dimensional"):
58+
change_rv_size(rv, new_size=[[2, 3]])
59+
with pytest.raises(ShapeError, match="must be ≤1-dimensional"):
60+
change_rv_size(rv, new_size=at.as_tensor_variable([[2, 3], [4, 5]]))
61+
5662
rv_new = change_rv_size(rv, new_size=(3,), expand=True)
5763
assert rv_new.ndim == 2
5864
assert rv_new.eval().shape == (3, 2)

pymc3/tests/test_distributions_random.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -187,11 +187,10 @@ def get_random_variable(self, shape, with_vector_params=False, name=None):
187187

188188
@staticmethod
189189
def sample_random_variable(random_variable, size):
190-
"""Draws samples from a RandomVariable using its .random() method."""
191-
if size is None:
192-
return random_variable.eval()
193-
else:
194-
return change_rv_size(random_variable, size, expand=True).eval()
190+
""" Draws samples from a RandomVariable. """
191+
if size:
192+
random_variable = change_rv_size(random_variable, size, expand=True)
193+
return random_variable.eval()
195194

196195
@pytest.mark.parametrize("size", [None, (), 1, (1,), 5, (4, 5)], ids=str)
197196
@pytest.mark.parametrize("shape", [None, ()], ids=str)

0 commit comments

Comments
 (0)