Skip to content

Add Feature that can go back and forward in rewrite history #874

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ jobs:
python-version: "3.13"
include:
- os: "ubuntu-latest"
part: "--doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link"
part: "--doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link --ignore=pytensor/ipython.py"
python-version: "3.12"
numpy-version: ">=2.0"
fast-compile: 0
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ versionfile_build = "pytensor/_version.py"
tag_prefix = "rel-"

[tool.pytest.ini_options]
addopts = "--durations=50 --doctest-modules --ignore=pytensor/link --ignore=pytensor/misc/check_duplicate_key.py"
addopts = "--durations=50 --doctest-modules --ignore=pytensor/link --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/ipython.py"
testpaths = ["pytensor/", "tests/"]
xfail_strict = true

Expand Down
166 changes: 166 additions & 0 deletions pytensor/graph/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,172 @@
self.history[fgraph] = h


class FullHistory(Feature):
"""Keeps track of all changes in FunctionGraph and allows arbitrary back and forth through intermediate states

.. testcode::
import pytensor
import pytensor.tensor as pt
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.features import FullHistory
from pytensor.graph.rewriting.utils import rewrite_graph

x = pt.scalar("x")
out = pt.log(pt.exp(x) / pt.sum(pt.exp(x)))

fg = FunctionGraph(outputs=[out])
history = FullHistory()
fg.attach_feature(history)

rewrite_graph(fg, clone=False, include=("canonicalize", "stabilize"))

# Replay rewrites
history.start()
pytensor.dprint(fg)
with pytensor.config.change_flags(optimizer_verbose = True):
for i in range(3):
print(">> ", end="")
pytensor.dprint(history.next())

.. testoutput::
Log [id A] 4
└─ True_div [id B] 3
├─ Exp [id C] 2
│ └─ x [id D]
└─ Sum{axes=None} [id E] 1
└─ Exp [id F] 0
└─ x [id D]
>> MergeOptimizer
Log [id A] 3
└─ True_div [id B] 2
├─ Exp [id C] 0
│ └─ x [id D]
└─ Sum{axes=None} [id E] 1
└─ Exp [id C] 0
└─ ···
>> local_mul_canonizer
Log [id A] 1
└─ Softmax{axis=None} [id B] 0
└─ x [id C]
>> local_logsoftmax
LogSoftmax{axis=None} [id A] 0
└─ x [id B]


.. testcode::
# Or in reverse
with pytensor.config.change_flags(optimizer_verbose=True):
for i in range(3):
print(">> ", end="")
pytensor.dprint(history.prev())

.. testoutput::
>> local_logsoftmax
Log [id A] 1
└─ Softmax{axis=None} [id B] 0
└─ x [id C]
>> local_mul_canonizer
Log [id A] 3
└─ True_div [id B] 2
├─ Exp [id C] 0
│ └─ x [id D]
└─ Sum{axes=None} [id E] 1
└─ Exp [id C] 0
└─ ···
>> MergeOptimizer
Log [id A] 4
└─ True_div [id B] 3
├─ Exp [id C] 2
│ └─ x [id D]
└─ Sum{axes=None} [id E] 1
└─ Exp [id F] 0
└─ x [id D]


.. testcode::
# Or go to any step
pytensor.dprint(history.goto(2))

.. testoutput::
Log [id A] 1
└─ Softmax{axis=None} [id B] 0
└─ x [id C]


"""

def __init__(self, callback=None):
self.fw = []
self.bw = []
self.pointer = -1
self.fg = None
self.callback = callback

def on_attach(self, fgraph):
if self.fg is not None:
raise ValueError("Full History already attached to another fgraph")

Check warning on line 544 in pytensor/graph/features.py

View check run for this annotation

Codecov / codecov/patch

pytensor/graph/features.py#L544

Added line #L544 was not covered by tests
self.fg = fgraph

def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
self.bw.append(LambdaExtract(fgraph, node, i, r, reason))
self.fw.append(LambdaExtract(fgraph, node, i, new_r, reason))
self.pointer += 1
if self.callback:
self.callback()

Check warning on line 552 in pytensor/graph/features.py

View check run for this annotation

Codecov / codecov/patch

pytensor/graph/features.py#L552

Added line #L552 was not covered by tests

def goto(self, checkpoint):
"""
Reverts the graph to whatever it was at the provided
checkpoint (undoes all replacements). A checkpoint at any
given time can be obtained using self.checkpoint().

"""
history_len = len(self.bw)
pointer = self.pointer
assert 0 <= checkpoint <= history_len
Copy link
Preview

Copilot AI Jun 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using assert for input validation can be bypassed in optimized runs; consider raising a ValueError for out-of-range checkpoint values instead.

Suggested change
assert 0 <= checkpoint <= history_len
if not (0 <= checkpoint <= history_len):
raise ValueError(f"Checkpoint value {checkpoint} is out of range. It must be between 0 and {history_len}.")

Copilot uses AI. Check for mistakes.

verbose = config.optimizer_verbose

# Go backwards
while pointer > checkpoint - 1:
reverse_fn = self.bw[pointer]
if verbose:
print(reverse_fn.reason) # noqa: T201
reverse_fn()
pointer -= 1

# Go forward
while pointer < checkpoint - 1:
pointer += 1
forward_fn = self.fw[pointer]
if verbose:
print(forward_fn.reason) # noqa: T201
forward_fn()

# Remove history changes caused by the foward/backward!
self.bw = self.bw[:history_len]
self.fw = self.fw[:history_len]
self.pointer = pointer
return self.fg

def start(self):
return self.goto(0)

def end(self):
return self.goto(len(self.bw))

def prev(self):
if self.pointer < 0:
return self.fg
else:
return self.goto(self.pointer)

def next(self):
if self.pointer >= len(self.bw) - 1:
return self.fg
else:
return self.goto(self.pointer + 2)


class Validator(Feature):
pickle_rm_attr = ["validate", "consistent"]

Expand Down
202 changes: 202 additions & 0 deletions pytensor/ipython.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
import anywidget
import ipywidgets as widgets
import traitlets
from IPython.display import display

Check warning on line 4 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L1-L4

Added lines #L1 - L4 were not covered by tests

from pytensor.graph import FunctionGraph, Variable, rewrite_graph
from pytensor.graph.features import FullHistory

Check warning on line 7 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L6-L7

Added lines #L6 - L7 were not covered by tests


class CodeBlockWidget(anywidget.AnyWidget):

Check warning on line 10 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L10

Added line #L10 was not covered by tests
"""Widget that displays text content as a monospaced code block."""

content = traitlets.Unicode("").tag(sync=True)

Check warning on line 13 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L13

Added line #L13 was not covered by tests

_esm = """

Check warning on line 15 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L15

Added line #L15 was not covered by tests
function render({ model, el }) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think these should be put in a widget.js and a widget.css file instead of hanging out as a giant string.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now I would keep it together, until and if it actually grows larger. It's a single file this way instead of a folder. We can always refactor later

const pre = document.createElement("pre");
pre.style.backgroundColor = "#f5f5f5";
pre.style.padding = "10px";
pre.style.borderRadius = "4px";
pre.style.overflowX = "auto";
pre.style.maxHeight = "500px";

const code = document.createElement("code");
code.textContent = model.get("content");

pre.appendChild(code);
el.appendChild(pre);

model.on("change:content", () => {
code.textContent = model.get("content");
});
}
export default { render };
"""

_css = """

Check warning on line 37 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L37

Added line #L37 was not covered by tests
.jp-RenderedHTMLCommon pre {
font-family: monospace;
white-space: pre;
line-height: 1.4;
}
"""


class InteractiveRewrite:

Check warning on line 46 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L46

Added line #L46 was not covered by tests
"""
A class that wraps a graph history object with interactive widgets
to navigate through history and display the graph at each step.

Includes an option to display the reason for the last change.
"""

def __init__(self, fg, display_reason=True):

Check warning on line 54 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L54

Added line #L54 was not covered by tests
"""
Initialize with a history object that has a goto method
and tracks a FunctionGraph.

Parameters:
-----------
fg : FunctionGraph (or Variables)
The function graph to track
display_reason : bool, optional
Whether to display the reason for each rewrite
"""
self.history = FullHistory(callback=self._history_callback)

Check warning on line 66 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L66

Added line #L66 was not covered by tests
if not isinstance(fg, FunctionGraph):
outs = [fg] if isinstance(fg, Variable) else fg
fg = FunctionGraph(outputs=outs)
fg.attach_feature(self.history)

Check warning on line 70 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L68-L70

Added lines #L68 - L70 were not covered by tests

self.updating_from_callback = False # Flag to prevent recursion
self.code_widget = CodeBlockWidget(content="")
self.display_reason = display_reason

Check warning on line 74 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L72-L74

Added lines #L72 - L74 were not covered by tests

if self.display_reason:
self.reason_label = widgets.HTML(

Check warning on line 77 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L77

Added line #L77 was not covered by tests
value="", description="", style={"description_width": "initial"}
)
self.slider_label = widgets.Label(value="")
self.slider = widgets.IntSlider(

Check warning on line 81 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L80-L81

Added lines #L80 - L81 were not covered by tests
value=self.history.pointer,
min=0,
max=0,
step=1,
description="", # Empty description since we're using a separate label
continuous_update=True,
layout=widgets.Layout(width="300px"),
)
self.prev_button = widgets.Button(description="← Previous")
self.next_button = widgets.Button(description="Next →")
self.slider.observe(self._on_slider_change, names="value")
self.prev_button.on_click(self._on_prev_click)
self.next_button.on_click(self._on_next_click)

Check warning on line 94 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L90-L94

Added lines #L90 - L94 were not covered by tests

self.rewrite_button = widgets.Button(

Check warning on line 96 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L96

Added line #L96 was not covered by tests
description="Apply Rewrites",
button_style="primary", # 'success', 'info', 'warning', 'danger' or ''
tooltip="Apply default rewrites to the current graph",
icon="cogs", # Optional: add an icon (requires font-awesome)
)
self.rewrite_button.on_click(self._on_rewrite_click)

Check warning on line 102 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L102

Added line #L102 was not covered by tests

self.nav_button_box = widgets.HBox([self.prev_button, self.next_button])
self.slider_box = widgets.HBox([self.slider_label, self.slider])
self.control_box = widgets.HBox([self.slider_box, self.rewrite_button])

Check warning on line 106 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L104-L106

Added lines #L104 - L106 were not covered by tests

# Update the display with the initial state
self._update_display()

Check warning on line 109 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L109

Added line #L109 was not covered by tests

def _on_slider_change(self, change):

Check warning on line 111 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L111

Added line #L111 was not covered by tests
"""Handle slider value changes"""
if change["name"] == "value" and not self.updating_from_callback:
self.updating_from_callback = True
index = change["new"]
self.history.goto(index)
self._update_display()
self.updating_from_callback = False

Check warning on line 118 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L114-L118

Added lines #L114 - L118 were not covered by tests

def _on_prev_click(self, b):

Check warning on line 120 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L120

Added line #L120 was not covered by tests
"""Go to previous history item"""
if self.slider.value > 0:
self.slider.value -= 1

Check warning on line 123 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L123

Added line #L123 was not covered by tests

def _on_next_click(self, b):

Check warning on line 125 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L125

Added line #L125 was not covered by tests
"""Go to next history item"""
if self.slider.value < self.slider.max:
self.slider.value += 1

Check warning on line 128 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L128

Added line #L128 was not covered by tests

def _on_rewrite_click(self, b):

Check warning on line 130 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L130

Added line #L130 was not covered by tests
"""Handle rewrite button click"""
self.slider.value = self.slider.max
self.rewrite()

Check warning on line 133 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L132-L133

Added lines #L132 - L133 were not covered by tests

def display(self):

Check warning on line 135 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L135

Added line #L135 was not covered by tests
"""Display the full widget interface"""
display(

Check warning on line 137 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L137

Added line #L137 was not covered by tests
widgets.VBox(
[
self.control_box,
self.nav_button_box,
*((self.reason_label,) if self.display_reason else ()),
self.code_widget,
]
)
)

def _ipython_display_(self):
self.display()

Check warning on line 149 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L148-L149

Added lines #L148 - L149 were not covered by tests

def _history_callback(self):

Check warning on line 151 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L151

Added line #L151 was not covered by tests
"""Callback for history updates that prevents recursion"""
if not self.updating_from_callback:
self.updating_from_callback = True
self._update_display()
self.updating_from_callback = False

Check warning on line 156 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L154-L156

Added lines #L154 - L156 were not covered by tests

def _update_display(self):

Check warning on line 158 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L158

Added line #L158 was not covered by tests
"""Update the code widget with the current graph and reason"""
# Update the reason label if checkbox is checked
if self.display_reason:
if self.history.pointer == -1:
reason = ""

Check warning on line 163 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L163

Added line #L163 was not covered by tests
else:
reason = self.history.fw[self.history.pointer].reason
reason = getattr(reason, "name", str(reason))

Check warning on line 166 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L165-L166

Added lines #L165 - L166 were not covered by tests

self.reason_label.value = f"""

Check warning on line 168 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L168

Added line #L168 was not covered by tests
<div style='padding: 5px; margin-bottom: 10px; background-color: #e6f7ff; border-left: 4px solid #1890ff;'>
<b>Rewrite:</b> {reason}
</div>
"""

# Update the graph display
self.code_widget.content = self.history.fg.dprint(file="str")

Check warning on line 175 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L175

Added line #L175 was not covered by tests

# Update slider range if history length has changed
history_len = len(self.history.fw) + 1

Check warning on line 178 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L178

Added line #L178 was not covered by tests
if history_len != self.slider.max + 1:
self.slider.max = history_len - 1

Check warning on line 180 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L180

Added line #L180 was not covered by tests

# Update slider value without triggering the observer
if not self.updating_from_callback:
with self.slider.hold_trait_notifications():
self.slider.value = self.history.pointer + 1

Check warning on line 185 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L184-L185

Added lines #L184 - L185 were not covered by tests

# Update the slider label to show current position and total (1-based)
self.slider_label.value = (

Check warning on line 188 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L188

Added line #L188 was not covered by tests
f"History: {self.history.pointer + 1}/{history_len - 1}"
)

def rewrite(self, *args, include=("fast_run",), exclude=("inplace",), **kwargs):

Check warning on line 192 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L192

Added line #L192 was not covered by tests
"""Apply rewrites to the current graph"""
rewrite_graph(

Check warning on line 194 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L194

Added line #L194 was not covered by tests
self.history.fg,
*args,
include=include,
exclude=exclude,
**kwargs,
clone=False,
)
self._update_display()

Check warning on line 202 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L202

Added line #L202 was not covered by tests
Loading