Skip to content

Commit 2a5b419

Browse files
committed
Flag Ops whose output types depend on input values
These nodes must always be rebuilt in non-strict mode
1 parent c36f731 commit 2a5b419

File tree

7 files changed

+62
-2
lines changed

7 files changed

+62
-2
lines changed

pytensor/graph/basic.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,8 +266,15 @@ def clone_with_new_inputs(
266266
assert isinstance(inputs, (list, tuple))
267267
remake_node = False
268268
new_inputs: List["Variable"] = list(inputs)
269+
270+
# Some Ops like Alloc require the node to always be rebuilt
271+
# as the output type depends on the input values and not just their types
272+
output_type_depends_on_input_value = getattr(
273+
self.op, "_output_type_depends_on_input_value", False
274+
)
275+
269276
for i, (curr, new) in enumerate(zip(self.inputs, new_inputs)):
270-
if curr.type != new.type:
277+
if (curr.type != new.type) or output_type_depends_on_input_value:
271278
if strict:
272279
new_i = curr.type.filter_variable(new)
273280
new_inputs[i] = new_i

pytensor/tensor/basic.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1418,6 +1418,8 @@ class Alloc(COp):
14181418
"""
14191419

14201420
_f16_ok = True
1421+
_output_type_depends_on_input_value = True
1422+
14211423
__props__ = ()
14221424

14231425
def make_node(self, value, *shape):
@@ -3817,6 +3819,8 @@ def perform(self, node, inputs, outputs):
38173819
class AllocEmpty(COp):
38183820
"""Implement Alloc on the cpu, but without initializing memory."""
38193821

3822+
_output_type_depends_on_input_value = True
3823+
38203824
__props__ = ("dtype",)
38213825
params_type = ParamsType(typecode=int32)
38223826

pytensor/tensor/extra_ops.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1630,6 +1630,8 @@ def broadcast_shape_iter(
16301630
class BroadcastTo(COp):
16311631
"""An `Op` for `numpy.broadcast_to`."""
16321632

1633+
_output_type_depends_on_input_value = True
1634+
16331635
__props__ = ()
16341636

16351637
view_map = {0: [0]}

pytensor/tensor/random/op.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@ class RandomVariable(Op):
9191
9292
"""
9393

94+
_output_type_depends_on_input_value = True
95+
9496
__props__ = ("name", "ndim_supp", "ndims_params", "dtype", "inplace")
9597
default_output = 1
9698

tests/tensor/random/test_basic.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from pytensor.graph.basic import Constant, Variable, graph_inputs
1515
from pytensor.graph.fg import FunctionGraph
1616
from pytensor.graph.op import get_test_value
17+
from pytensor.graph.replace import clone_replace
1718
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
1819
from pytensor.tensor.random.basic import (
1920
bernoulli,
@@ -57,7 +58,7 @@
5758
weibull,
5859
)
5960
from pytensor.tensor.rewriting.shape import ShapeFeature
60-
from pytensor.tensor.type import iscalar, scalar, tensor
61+
from pytensor.tensor.type import iscalar, scalar, vector, tensor
6162
from tests.unittest_tools import create_pytensor_param
6263

6364

@@ -1422,3 +1423,17 @@ def test_pickle():
14221423
a_unpkl = pickle.loads(a_pkl)
14231424

14241425
assert a_unpkl.owner.op._props() == sample_a.owner.op._props()
1426+
1427+
1428+
def test_rebuild():
1429+
x = vector(shape=(50,))
1430+
x_test = np.zeros((50,))
1431+
y = normal(size=x.shape)
1432+
assert y.shape.eval({x: x_test}) == (50,)
1433+
assert y.eval({x: x_test}).shape == (50,)
1434+
1435+
x_new = vector(shape=(100,))
1436+
x_new_test = np.zeros((100,))
1437+
y_new = clone_replace(y, {x: x_new}, rebuild_strict=True)
1438+
assert y_new.shape.eval({x_new: x_new_test}) == (100,)
1439+
assert y_new.eval({x_new: x_new_test}).shape == (100,)

tests/tensor/test_basic.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from pytensor.gradient import grad, hessian
1717
from pytensor.graph.basic import Apply
1818
from pytensor.graph.op import Op
19+
from pytensor.graph.replace import clone_replace
1920
from pytensor.misc.safe_asarray import _asarray
2021
from pytensor.raise_op import Assert
2122
from pytensor.scalar import autocast_float, autocast_float_as
@@ -818,6 +819,20 @@ def test_full(self):
818819
res = pytensor.function([], full_at, mode=self.mode)()
819820
assert np.array_equal(res, np.full((2, 3), 3, dtype="int64"))
820821

822+
@pytest.mark.parametrize("func", (at.zeros, at.empty))
823+
def test_rebuild(self, func):
824+
x = vector(shape=(50,))
825+
x_test = np.zeros((50,))
826+
y = func(x.shape)
827+
assert y.shape.eval({x: x_test}) == (50,)
828+
assert y.eval({x: x_test}).shape == (50,)
829+
830+
x_new = vector(shape=(100,))
831+
x_new_test = np.zeros((100,))
832+
y_new = clone_replace(y, {x: x_new}, rebuild_strict=False)
833+
assert y_new.shape.eval({x_new: x_new_test}) == (100,)
834+
assert y_new.eval({x_new: x_new_test}).shape == (100,)
835+
821836

822837
def test_infer_shape():
823838
with pytest.raises(TypeError, match="^Shapes must be scalar integers.*"):

tests/tensor/test_extra_ops.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from pytensor.configdefaults import config
1111
from pytensor.graph.basic import Constant, applys_between
1212
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
13+
from pytensor.graph.replace import clone_replace
1314
from pytensor.raise_op import Assert
1415
from pytensor.tensor.elemwise import DimShuffle
1516
from pytensor.tensor.extra_ops import (
@@ -1393,6 +1394,20 @@ def test_inplace(self):
13931394

13941395
assert advincsub_node.op.inplace is False
13951396

1397+
def test_rebuild(self):
1398+
x = vector(shape=(50,))
1399+
x_test = np.zeros((50,))
1400+
i = 0
1401+
y = broadcast_to(i, x.shape)
1402+
assert y.shape.eval({x: x_test}) == (50,)
1403+
assert y.eval({x: x_test}).shape == (50,)
1404+
1405+
x_new = vector(shape=(100,))
1406+
x_new_test = np.zeros((100,))
1407+
y_new = clone_replace(y, {x: x_new}, rebuild_strict=False)
1408+
assert y_new.shape.eval({x_new: x_new_test}) == (100,)
1409+
assert y_new.eval({x_new: x_new_test}).shape == (100,)
1410+
13961411

13971412
def test_broadcast_arrays():
13981413
x, y = at.dvector(), at.dmatrix()

0 commit comments

Comments
 (0)