File tree Expand file tree Collapse file tree 2 files changed +26
-5
lines changed Expand file tree Collapse file tree 2 files changed +26
-5
lines changed Original file line number Diff line number Diff line change @@ -1088,20 +1088,20 @@ def truncated_graph_inputs(
1088
1088
return truncated_inputs
1089
1089
1090
1090
blockers : Set [Variable ] = set (ancestors_to_include )
1091
+ # variables that go here are under check already, do not repeat the loop for them
1092
+ seen : Set [Variable ] = set ()
1091
1093
# enforce O(1) check for node in ancestors to include
1092
1094
ancestors_to_include = blockers .copy ()
1093
1095
1094
1096
while candidates :
1095
1097
# on any new candidate
1096
1098
node = candidates .pop ()
1097
-
1098
- # There was a repeated reference to this node, we have already investigated it
1099
- if node in truncated_inputs :
1099
+ # we've looked into this node already
1100
+ if node in seen :
1100
1101
continue
1101
-
1102
1102
# check if the node is independent, never go above blockers;
1103
1103
# blockers are independent nodes and ancestors to include
1104
- if node in ancestors_to_include :
1104
+ elif node in ancestors_to_include :
1105
1105
# The case where node is in ancestors to include so we check if it depends on others
1106
1106
# it should be removed from the blockers to check against the rest
1107
1107
dependent = variable_depends_on (node , ancestors_to_include - {node })
@@ -1128,6 +1128,8 @@ def truncated_graph_inputs(
1128
1128
else :
1129
1129
# otherwise, do not search beyond
1130
1130
truncated_inputs .append (node )
1131
+ # add node to seen, no point in checking it once more
1132
+ seen .add (node )
1131
1133
return truncated_inputs
1132
1134
1133
1135
Original file line number Diff line number Diff line change @@ -795,3 +795,22 @@ def test_repeated_nested_input(self):
795
795
o2 .name = "o2"
796
796
797
797
assert truncated_graph_inputs ([o2 ], [trunc_inp ]) == [trunc_inp , x ]
798
+
799
+ def test_single_pass_per_node (self , mocker ):
800
+ import pytensor .graph .basic
801
+
802
+ inspect = mocker .spy (pytensor .graph .basic , "variable_depends_on" )
803
+ variables = [at .scalar (f"v{ i } " ) for i in range (3 )]
804
+ for i in range (20 ):
805
+ variables .append (
806
+ at .add (
807
+ variables [i ],
808
+ variables [(i ** 3 ) % len (variables )],
809
+ variables [(i ** 2 ) % len (variables )],
810
+ )
811
+ )
812
+ truncated_graph_inputs (variables [- 5 :], variables [15 :- 5 :5 ])
813
+ # make sure there were exactly the same calls as unique variables seen by the function
814
+ assert len (inspect .call_args_list ) == len (
815
+ {a for ((a , b ), kw ) in inspect .call_args_list }
816
+ )
You can’t perform that action at this time.
0 commit comments