Skip to content

Commit e65b0c5

Browse files
Make OpFromGraph.make_node interface consistent with its Apply nodes
1 parent 6ef1452 commit e65b0c5

File tree

2 files changed

+106
-9
lines changed

2 files changed

+106
-9
lines changed

aesara/compile/builders.py

Lines changed: 66 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,11 @@ def __init__(
375375
self.kwargs = kwargs
376376
self.input_types = [inp.type for inp in inputs]
377377
self.output_types = [out.type for out in outputs]
378+
379+
self.lop_overrides = lop_overrides
380+
self.grad_overrides = grad_overrides
381+
self.rop_overrides = rop_overrides
382+
378383
if lop_overrides != "default":
379384
if grad_overrides != "default":
380385
raise ValueError(
@@ -732,19 +737,71 @@ def R_op(self, inputs, eval_points):
732737
]
733738
return ret_l
734739

740+
def __call__(self, *inputs, **kwargs):
741+
# The user interface doesn't expect the shared variable inputs of the
742+
# inner-graph, but, since `Op.make_node` does (and `Op.__call__`
743+
# dispatches to `Op.make_node`), we need to compensate here
744+
num_expected_inps = len(self.inner_inputs) - len(self.shared_inputs)
745+
746+
if len(inputs) == num_expected_inps:
747+
actual_inputs = inputs + tuple(self.shared_inputs)
748+
return super().__call__(*actual_inputs, **kwargs)
749+
elif len(inputs) == len(self.inner_inputs):
750+
return super().__call__(*inputs, **kwargs)
751+
else:
752+
raise ValueError(f"Expected at least {num_expected_inps} input(s)")
753+
735754
def make_node(self, *inputs):
755+
# The `inputs` received here should correspond to the inputs in the
756+
# `Apply` nodes we produce below
757+
if len(inputs) != len(self.inner_inputs):
758+
raise ValueError(f"Expected {len(self.inner_inputs)} input(s)")
759+
736760
num_expected_inps = len(self.inner_inputs) - len(self.shared_inputs)
737-
if len(inputs) != num_expected_inps:
738-
raise ValueError(
739-
f"Expected {int(num_expected_inps)} inputs, got {len(inputs)}"
740-
)
741-
inputs = [
742-
inp_t.filter_variable(inp) for inp, inp_t in zip(inputs, self.input_types)
761+
non_shared_inputs = inputs[:num_expected_inps]
762+
763+
non_shared_inputs = [
764+
inp_t.filter_variable(inp)
765+
for inp, inp_t in zip(non_shared_inputs, self.input_types)
743766
]
767+
768+
shared_inputs = inputs[num_expected_inps:]
769+
local_shared_inputs = self.inner_inputs[num_expected_inps:]
770+
771+
inner_and_input_shareds = list(zip(local_shared_inputs, shared_inputs))
772+
773+
if not all(inp_s == inn_s for inn_s, inp_s in inner_and_input_shareds):
774+
# The shared variables are not equal to the original shared
775+
# variables, so we construct a new `Op` that uses the new shared
776+
# variables instead
777+
replace = {
778+
old_inp: new_inp for old_inp, new_inp in zip(self.inner_inputs, inputs)
779+
}
780+
replace.update(inner_and_input_shareds)
781+
782+
# If the new shared variables are inconsistent with the inner-graph,
783+
# such errors should arise in this step
784+
new_outputs = clone_replace(
785+
self.inner_outputs, replace=replace, share_inputs=True
786+
)
787+
788+
new_op = type(self)(
789+
inputs=non_shared_inputs,
790+
outputs=new_outputs,
791+
inline=self.is_inline,
792+
lop_overrides=self.lop_overrides,
793+
grad_overrides=self.grad_overrides,
794+
rop_overrides=self.rop_overrides,
795+
connection_pattern=self._connection_pattern,
796+
name=self.name,
797+
)
798+
else:
799+
new_op = self
800+
744801
apply_node = Apply(
745-
self,
746-
list(inputs) + self.shared_inputs,
747-
[type() for type in self.output_types],
802+
new_op,
803+
list(non_shared_inputs) + new_op.shared_inputs,
804+
[type() for type in new_op.output_types],
748805
)
749806
return apply_node
750807

tests/compile/test_builders.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44
import pytest
55

6+
import aesara.tensor as at
67
from aesara.compile import shared
78
from aesara.compile.builders import OpFromGraph
89
from aesara.compile.function import function
@@ -29,6 +30,12 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
2930
def test_valid_input(self):
3031
x, y, z = matrices("xyz")
3132

33+
with pytest.raises(ValueError, match="Expected at least.*"):
34+
OpFromGraph([x], [x])()
35+
36+
with pytest.raises(ValueError, match=r"Expected 1 input\(s\)"):
37+
OpFromGraph([x], [x]).make_node()
38+
3239
with pytest.raises(TypeError):
3340
OpFromGraph((x,), (x,))
3441

@@ -451,6 +458,39 @@ def test_compute_test_value(self):
451458
grad_f = grad(f, y)
452459
assert grad_f.tag.test_value is not None
453460

461+
def test_make_node_shared(self):
462+
"""Make sure we can provide `OpFromGraph.make_node` new shared inputs and get a valid `OpFromGraph`."""
463+
464+
x = at.scalar("x")
465+
y = shared(1.0, name="y")
466+
467+
test_ofg = OpFromGraph([x], [x + y])
468+
assert test_ofg.inputs == [x]
469+
assert test_ofg.shared_inputs == [y]
470+
471+
out = test_ofg(x)
472+
473+
y_clone = y.clone()
474+
assert y_clone != y
475+
y_clone.name = "y_clone"
476+
477+
out_new = test_ofg.make_node(*(out.owner.inputs[:1] + [y_clone])).outputs[0]
478+
479+
assert out_new.owner.op.inputs == [x]
480+
assert out_new.owner.op.shared_inputs == [y_clone]
481+
482+
out_fn = function([x], out_new)
483+
484+
assert np.array_equal(out_fn(1.0), 2.0)
485+
486+
y_clone.set_value(2.0)
487+
488+
assert np.array_equal(out_fn(1.0), 3.0)
489+
490+
# This should also work, because the containers are the same:
491+
# y.set_value(1.0)
492+
# assert np.array_equal(out_fn(1.0), 2.0)
493+
454494

455495
def test_debugprint():
456496
x, y, z = matrices("xyz")

0 commit comments

Comments (0)