Skip to content

Commit e9a7d7c

Browse files
authored
Speedup truncated_graph_inputs (#394)
* add pytest-mock dependency * rename node to variable
1 parent 673c1ac commit e9a7d7c

File tree

5 files changed

+57
-38
lines changed

5 files changed

+57
-38
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ jobs:
139139
- name: Install dependencies
140140
shell: bash -l {0}
141141
run: |
142-
mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark sympy
142+
mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock sympy
143143
# numba-scipy downgrades the installed scipy to 1.7.3 in Python 3.9, but
144144
# not numpy, even though scipy 1.7 requires numpy<1.23. When installing
145145
# PyTensor next, pip installs a lower version of numpy via the PyPI.

environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ dependencies:
3131
- pytest-cov
3232
- pytest-xdist
3333
- pytest-benchmark
34+
- pytest-mock
3435
# For building docs
3536
- sphinx>=5.1.0,<6
3637
- sphinx_rtd_theme

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ tests = [
8686
"pytest-cov>=2.6.1",
8787
"coverage>=5.1",
8888
"pytest-benchmark",
89+
"pytest-mock",
8990
]
9091
rtd = [
9192
"sphinx>=5.1.0,<6",

pytensor/graph/basic.py

Lines changed: 39 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1003,14 +1003,14 @@ def applys_between(
10031003
def truncated_graph_inputs(
10041004
outputs: Sequence[Variable],
10051005
ancestors_to_include: Optional[Collection[Variable]] = None,
1006-
) -> List[Variable]:
1006+
) -> list[Variable]:
10071007
"""Get the truncated graph inputs.
10081008
10091009
Unlike :func:`graph_inputs` this function will return
1010-
the closest nodes to outputs that do not depend on
1010+
the closest variables to outputs that do not depend on
10111011
``ancestors_to_include``. So given all the returned
1012-
variables provided there is no missing node to
1013-
compute the output and all nodes are independent
1012+
variables provided there is no missing variable to
1013+
compute the output and all variables are independent
10141014
from each other.
10151015
10161016
Parameters
@@ -1027,7 +1027,7 @@ def truncated_graph_inputs(
10271027
10281028
Examples
10291029
--------
1030-
The returned nodes marked in (parenthesis), ancestors nodes are ``c``, output nodes are ``o``
1030+
The returned variables are marked in (parentheses), ancestor variables are ``c``, output variables are ``o``
10311031
10321032
* No ancestors to include
10331033
@@ -1047,7 +1047,7 @@ def truncated_graph_inputs(
10471047
10481048
(c) - (c) - o
10491049
1050-
* Additional nodes are present
1050+
* Additional variables are present
10511051
10521052
.. code-block::
10531053
@@ -1076,58 +1076,60 @@ def truncated_graph_inputs(
10761076
10771077
"""
10781078
# simple case, no additional ancestors to include
1079-
truncated_inputs = list()
1080-
# blockers have known independent nodes and ancestors to include
1079+
truncated_inputs: list[Variable] = list()
1080+
# blockers have known independent variables and ancestors to include
10811081
candidates = list(outputs)
10821082
if not ancestors_to_include: # None or empty
10831083
# just filter out unique variables
1084-
for node in candidates:
1085-
if node not in truncated_inputs:
1086-
truncated_inputs.append(node)
1084+
for variable in candidates:
1085+
if variable not in truncated_inputs:
1086+
truncated_inputs.append(variable)
10871087
# no more actions are needed
10881088
return truncated_inputs
10891089

1090-
blockers: Set[Variable] = set(ancestors_to_include)
1091-
# enforce O(1) check for node in ancestors to include
1090+
blockers: set[Variable] = set(ancestors_to_include)
1091+
# variables that go here are under check already, do not repeat the loop for them
1092+
seen: set[Variable] = set()
1093+
# enforce O(1) check for variable in ancestors to include
10921094
ancestors_to_include = blockers.copy()
10931095

10941096
while candidates:
10951097
# on any new candidate
1096-
node = candidates.pop()
1097-
1098-
# There was a repeated reference to this node, we have already investigated it
1099-
if node in truncated_inputs:
1098+
variable = candidates.pop()
1099+
# we've looked into this variable already
1100+
if variable in seen:
11001101
continue
1101-
1102-
# check if the node is independent, never go above blockers;
1103-
# blockers are independent nodes and ancestors to include
1104-
if node in ancestors_to_include:
1105-
# The case where node is in ancestors to include so we check if it depends on others
1102+
# check if the variable is independent, never go above blockers;
1103+
# blockers are independent variables and ancestors to include
1104+
elif variable in ancestors_to_include:
1105+
# The case where variable is in ancestors to include so we check if it depends on others
11061106
# it should be removed from the blockers to check against the rest
1107-
dependent = variable_depends_on(node, ancestors_to_include - {node})
1107+
dependent = variable_depends_on(variable, ancestors_to_include - {variable})
11081108
# ancestors to include that are present in the graph (not disconnected)
11091109
# should be added to truncated_inputs
1110-
truncated_inputs.append(node)
1110+
truncated_inputs.append(variable)
11111111
if dependent:
11121112
# if the ancestors to include is still dependent we need to go above, the search is not yet finished
1113-
# owner can never be None for a dependent node
1114-
candidates.extend(node.owner.inputs)
1113+
# owner can never be None for a dependent variable
1114+
candidates.extend(n for n in variable.owner.inputs if n not in seen)
11151115
else:
1116-
# A regular node to check
1117-
dependent = variable_depends_on(node, blockers)
1118-
# all regular nodes fall to blockers
1116+
# A regular variable to check
1117+
dependent = variable_depends_on(variable, blockers)
1118+
# all regular variables fall to blockers
11191119
# 1. it is dependent - further search irrelevant
1120-
# 2. it is independent - the search node is inside the closure
1121-
blockers.add(node)
1122-
# if we've found an independent node and it is not in blockers so far
1123-
# it is a new independent node not present in ancestors to include
1120+
# 2. it is independent - the search variable is inside the closure
1121+
blockers.add(variable)
1122+
# if we've found an independent variable and it is not in blockers so far
1123+
# it is a new independent variable not present in ancestors to include
11241124
if dependent:
1125-
# populate search if it's not an independent node
1126-
# owner can never be None for a dependent node
1127-
candidates.extend(node.owner.inputs)
1125+
# populate search if it's not an independent variable
1126+
# owner can never be None for a dependent variable
1127+
candidates.extend(n for n in variable.owner.inputs if n not in seen)
11281128
else:
11291129
# otherwise, do not search beyond
1130-
truncated_inputs.append(node)
1130+
truncated_inputs.append(variable)
1131+
# add variable to seen, no point in checking it once more
1132+
seen.add(variable)
11311133
return truncated_inputs
11321134

11331135

tests/graph/test_basic.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -795,3 +795,18 @@ def test_repeated_nested_input(self):
795795
o2.name = "o2"
796796

797797
assert truncated_graph_inputs([o2], [trunc_inp]) == [trunc_inp, x]
798+
799+
def test_single_pass_per_node(self, mocker):
800+
import pytensor.graph.basic
801+
802+
inspect = mocker.spy(pytensor.graph.basic, "variable_depends_on")
803+
x = at.dmatrix("x")
804+
m = x.shape[0][None, None]
805+
806+
f = x / m
807+
w = x / m - f
808+
truncated_graph_inputs([w], [x])
809+
# make sure there were exactly as many calls as unique variables seen by the function
810+
assert len(inspect.call_args_list) == len(
811+
{a for ((a, b), kw) in inspect.call_args_list}
812+
)

0 commit comments

Comments
 (0)