Skip to content

Commit 81638c3

Browse files
committed
add seen set
1 parent b2b7e28 commit 81638c3

File tree

2 files changed

+26
-5
lines changed

2 files changed

+26
-5
lines changed

pytensor/graph/basic.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1088,20 +1088,20 @@ def truncated_graph_inputs(
10881088
return truncated_inputs
10891089

10901090
blockers: Set[Variable] = set(ancestors_to_include)
1091+
# variables that go here are under check already, do not repeat the loop for them
1092+
seen: Set[Variable] = set()
10911093
# enforce O(1) check for node in ancestors to include
10921094
ancestors_to_include = blockers.copy()
10931095

10941096
while candidates:
10951097
# on any new candidate
10961098
node = candidates.pop()
1097-
1098-
# There was a repeated reference to this node, we have already investigated it
1099-
if node in truncated_inputs:
1099+
# we've looked into this node already
1100+
if node in seen:
11001101
continue
1101-
11021102
# check if the node is independent, never go above blockers;
11031103
# blockers are independent nodes and ancestors to include
1104-
if node in ancestors_to_include:
1104+
elif node in ancestors_to_include:
11051105
# The case where node is in ancestors to include so we check if it depends on others
11061106
# it should be removed from the blockers to check against the rest
11071107
dependent = variable_depends_on(node, ancestors_to_include - {node})
@@ -1128,6 +1128,8 @@ def truncated_graph_inputs(
11281128
else:
11291129
# otherwise, do not search beyond
11301130
truncated_inputs.append(node)
1131+
# add node to seen, no point in checking it once more
1132+
seen.add(node)
11311133
return truncated_inputs
11321134

11331135

tests/graph/test_basic.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -795,3 +795,22 @@ def test_repeated_nested_input(self):
795795
o2.name = "o2"
796796

797797
assert truncated_graph_inputs([o2], [trunc_inp]) == [trunc_inp, x]
798+
799+
def test_single_pass_per_node(self, mocker):
800+
import pytensor.graph.basic
801+
802+
inspect = mocker.spy(pytensor.graph.basic, "variable_depends_on")
803+
variables = [at.scalar(f"v{i}") for i in range(3)]
804+
for i in range(20):
805+
variables.append(
806+
at.add(
807+
variables[i],
808+
variables[(i**3) % len(variables)],
809+
variables[(i**2) % len(variables)],
810+
)
811+
)
812+
truncated_graph_inputs(variables[-5:], variables[15:-5:5])
813+
# make sure there were exactly the same calls as unique variables seen by the function
814+
assert len(inspect.call_args_list) == len(
815+
{a for ((a, b), kw) in inspect.call_args_list}
816+
)

0 commit comments

Comments
 (0)