Skip to content

Speedup truncated_graph_inputs #394

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ jobs:
- name: Install dependencies
shell: bash -l {0}
run: |
mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark sympy
mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock sympy
# numba-scipy downgrades the installed scipy to 1.7.3 in Python 3.9, but
# not numpy, even though scipy 1.7 requires numpy<1.23. When installing
# PyTensor next, pip installs a lower version of numpy via the PyPI.
Expand Down
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ dependencies:
- pytest-cov
- pytest-xdist
- pytest-benchmark
- pytest-mock
# For building docs
- sphinx>=5.1.0,<6
- sphinx_rtd_theme
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ tests = [
"pytest-cov>=2.6.1",
"coverage>=5.1",
"pytest-benchmark",
"pytest-mock",
]
rtd = [
"sphinx>=5.1.0,<6",
Expand Down
76 changes: 39 additions & 37 deletions pytensor/graph/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,14 +1003,14 @@ def applys_between(
def truncated_graph_inputs(
outputs: Sequence[Variable],
ancestors_to_include: Optional[Collection[Variable]] = None,
) -> List[Variable]:
) -> list[Variable]:
"""Get the truncated graph inputs.

Unlike :func:`graph_inputs` this function will return
the closest nodes to outputs that do not depend on
the closest variables to outputs that do not depend on
``ancestors_to_include``. So given all the returned
variables provided there is no missing node to
compute the output and all nodes are independent
variables provided there is no missing variable to
compute the output and all variables are independent
from each other.

Parameters
Expand All @@ -1027,7 +1027,7 @@ def truncated_graph_inputs(

Examples
--------
The returned nodes marked in (parenthesis), ancestors nodes are ``c``, output nodes are ``o``
The returned variables are marked in (parentheses); ancestor variables are ``c``, output variables are ``o``

* No ancestors to include

Expand All @@ -1047,7 +1047,7 @@ def truncated_graph_inputs(

(c) - (c) - o

* Additional nodes are present
* Additional variables are present

.. code-block::

Expand Down Expand Up @@ -1076,58 +1076,60 @@ def truncated_graph_inputs(

"""
# simple case, no additional ancestors to include
truncated_inputs = list()
# blockers have known independent nodes and ancestors to include
truncated_inputs: list[Variable] = list()
# blockers have known independent variables and ancestors to include
candidates = list(outputs)
if not ancestors_to_include: # None or empty
# just filter out unique variables
for node in candidates:
if node not in truncated_inputs:
truncated_inputs.append(node)
for variable in candidates:
if variable not in truncated_inputs:
truncated_inputs.append(variable)
# no more actions are needed
return truncated_inputs

blockers: Set[Variable] = set(ancestors_to_include)
# enforce O(1) check for node in ancestors to include
blockers: set[Variable] = set(ancestors_to_include)
# variables that go here are under check already, do not repeat the loop for them
seen: set[Variable] = set()
# enforce O(1) check for variable in ancestors to include
ancestors_to_include = blockers.copy()

while candidates:
# on any new candidate
node = candidates.pop()

# There was a repeated reference to this node, we have already investigated it
if node in truncated_inputs:
variable = candidates.pop()
# we've looked into this variable already
if variable in seen:
continue

# check if the node is independent, never go above blockers;
# blockers are independent nodes and ancestors to include
if node in ancestors_to_include:
# The case where node is in ancestors to include so we check if it depends on others
# check if the variable is independent, never go above blockers;
# blockers are independent variables and ancestors to include
elif variable in ancestors_to_include:
# The case where variable is in ancestors to include so we check if it depends on others
# it should be removed from the blockers to check against the rest
dependent = variable_depends_on(node, ancestors_to_include - {node})
dependent = variable_depends_on(variable, ancestors_to_include - {variable})
# ancestors to include that are present in the graph (not disconnected)
# should be added to truncated_inputs
truncated_inputs.append(node)
truncated_inputs.append(variable)
if dependent:
# if the ancestors to include is still dependent we need to go above, the search is not yet finished
# owner can never be None for a dependent node
candidates.extend(node.owner.inputs)
# owner can never be None for a dependent variable
candidates.extend(n for n in variable.owner.inputs if n not in seen)
else:
# A regular node to check
dependent = variable_depends_on(node, blockers)
# all regular nodes fall to blockers
# A regular variable to check
dependent = variable_depends_on(variable, blockers)
# all regular variables fall to blockers
# 1. it is dependent - further search irrelevant
# 2. it is independent - the search node is inside the closure
blockers.add(node)
# if we've found an independent node and it is not in blockers so far
# it is a new independent node not present in ancestors to include
# 2. it is independent - the search variable is inside the closure
blockers.add(variable)
# if we've found an independent variable and it is not in blockers so far
# it is a new independent variable not present in ancestors to include
if dependent:
# populate search if it's not an independent node
# owner can never be None for a dependent node
candidates.extend(node.owner.inputs)
# populate search if it's not an independent variable
# owner can never be None for a dependent variable
candidates.extend(n for n in variable.owner.inputs if n not in seen)
else:
# otherwise, do not search beyond
truncated_inputs.append(node)
truncated_inputs.append(variable)
# add variable to seen, no point in checking it once more
seen.add(variable)
return truncated_inputs


Expand Down
15 changes: 15 additions & 0 deletions tests/graph/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,3 +795,18 @@ def test_repeated_nested_input(self):
o2.name = "o2"

assert truncated_graph_inputs([o2], [trunc_inp]) == [trunc_inp, x]

def test_single_pass_per_node(self, mocker):
# Regression test: truncated_graph_inputs should call variable_depends_on
# at most once per distinct variable it examines.
import pytensor.graph.basic

# Spy on the module-level function so every call made during the traversal is recorded.
inspect = mocker.spy(pytensor.graph.basic, "variable_depends_on")
x = at.dmatrix("x")
m = x.shape[0][None, None]

# x and m are each used by more than one expression below, so the traversal
# can reach the same variable through several paths.
f = x / m
w = x / m - f
truncated_graph_inputs([w], [x])
# The total number of recorded calls must equal the number of distinct first
# positional arguments, i.e. no variable was passed to variable_depends_on twice.
# Each recorded call is unpacked as ((first_arg, second_arg), kwargs).
assert len(inspect.call_args_list) == len(
{a for ((a, b), kw) in inspect.call_args_list}
)