Skip to content

Commit 5f5be92

Browse files
committed
Add interactive rewrite ipython widget
1 parent 8a040b9 commit 5f5be92

File tree

4 files changed

+208
-3
lines changed

4 files changed

+208
-3
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ jobs:
107107
python-version: "3.13"
108108
include:
109109
- os: "ubuntu-latest"
110-
part: "--doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link"
110+
part: "--doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link --ignore=pytensor/ipython.py"
111111
python-version: "3.12"
112112
numpy-version: ">=2.0"
113113
fast-compile: 0

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ versionfile_build = "pytensor/_version.py"
118118
tag_prefix = "rel-"
119119

120120
[tool.pytest.ini_options]
121-
addopts = "--durations=50 --doctest-modules --ignore=pytensor/link --ignore=pytensor/misc/check_duplicate_key.py"
121+
addopts = "--durations=50 --doctest-modules --ignore=pytensor/link --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/ipython.py"
122122
testpaths = ["pytensor/", "tests/"]
123123
xfail_strict = true
124124

pytensor/graph/features.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -532,11 +532,12 @@ class FullHistory(Feature):
532532
533533
"""
534534

535-
def __init__(self):
535+
def __init__(self, callback=None):
536536
self.fw = []
537537
self.bw = []
538538
self.pointer = -1
539539
self.fg = None
540+
self.callback = callback
540541

541542
def on_attach(self, fgraph):
542543
if self.fg is not None:
@@ -547,6 +548,8 @@ def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
547548
self.bw.append(LambdaExtract(fgraph, node, i, r, reason))
548549
self.fw.append(LambdaExtract(fgraph, node, i, new_r, reason))
549550
self.pointer += 1
551+
if self.callback:
552+
self.callback()
550553

551554
def goto(self, checkpoint):
552555
"""

pytensor/ipython.py

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
import anywidget
2+
import ipywidgets as widgets
3+
import traitlets
4+
from IPython.display import display
5+
6+
from pytensor.graph import FunctionGraph, Variable, rewrite_graph
7+
from pytensor.graph.features import FullHistory
8+
9+
10+
class CodeBlockWidget(anywidget.AnyWidget):
11+
"""Widget that displays text content as a monospaced code block."""
12+
13+
content = traitlets.Unicode("").tag(sync=True)
14+
15+
_esm = """
16+
function render({ model, el }) {
17+
const pre = document.createElement("pre");
18+
pre.style.backgroundColor = "#f5f5f5";
19+
pre.style.padding = "10px";
20+
pre.style.borderRadius = "4px";
21+
pre.style.overflowX = "auto";
22+
pre.style.maxHeight = "500px";
23+
24+
const code = document.createElement("code");
25+
code.textContent = model.get("content");
26+
27+
pre.appendChild(code);
28+
el.appendChild(pre);
29+
30+
model.on("change:content", () => {
31+
code.textContent = model.get("content");
32+
});
33+
}
34+
export default { render };
35+
"""
36+
37+
_css = """
38+
.jp-RenderedHTMLCommon pre {
39+
font-family: monospace;
40+
white-space: pre;
41+
line-height: 1.4;
42+
}
43+
"""
44+
45+
46+
class InteractiveRewrite:
47+
"""
48+
A class that wraps a graph history object with interactive widgets
49+
to navigate through history and display the graph at each step.
50+
51+
Includes an option to display the reason for the last change.
52+
"""
53+
54+
def __init__(self, fg, display_reason=True):
55+
"""
56+
Initialize with a history object that has a goto method
57+
and tracks a FunctionGraph.
58+
59+
Parameters:
60+
-----------
61+
fg : FunctionGraph (or Variables)
62+
The function graph to track
63+
display_reason : bool, optional
64+
Whether to display the reason for each rewrite
65+
"""
66+
self.history = FullHistory(callback=self._history_callback)
67+
if not isinstance(fg, FunctionGraph):
68+
outs = [fg] if isinstance(fg, Variable) else fg
69+
fg = FunctionGraph(outputs=outs)
70+
fg.attach_feature(self.history)
71+
72+
self.updating_from_callback = False # Flag to prevent recursion
73+
self.code_widget = CodeBlockWidget(content="")
74+
self.display_reason = display_reason
75+
76+
if self.display_reason:
77+
self.reason_label = widgets.HTML(
78+
value="", description="", style={"description_width": "initial"}
79+
)
80+
self.slider_label = widgets.Label(value="")
81+
self.slider = widgets.IntSlider(
82+
value=self.history.pointer,
83+
min=0,
84+
max=0,
85+
step=1,
86+
description="", # Empty description since we're using a separate label
87+
continuous_update=True,
88+
layout=widgets.Layout(width="300px"),
89+
)
90+
self.prev_button = widgets.Button(description="← Previous")
91+
self.next_button = widgets.Button(description="Next →")
92+
self.slider.observe(self._on_slider_change, names="value")
93+
self.prev_button.on_click(self._on_prev_click)
94+
self.next_button.on_click(self._on_next_click)
95+
96+
self.rewrite_button = widgets.Button(
97+
description="Apply Rewrites",
98+
button_style="primary", # 'success', 'info', 'warning', 'danger' or ''
99+
tooltip="Apply default rewrites to the current graph",
100+
icon="cogs", # Optional: add an icon (requires font-awesome)
101+
)
102+
self.rewrite_button.on_click(self._on_rewrite_click)
103+
104+
self.nav_button_box = widgets.HBox([self.prev_button, self.next_button])
105+
self.slider_box = widgets.HBox([self.slider_label, self.slider])
106+
self.control_box = widgets.HBox([self.slider_box, self.rewrite_button])
107+
108+
# Update the display with the initial state
109+
self._update_display()
110+
111+
def _on_slider_change(self, change):
112+
"""Handle slider value changes"""
113+
if change["name"] == "value" and not self.updating_from_callback:
114+
self.updating_from_callback = True
115+
index = change["new"]
116+
self.history.goto(index)
117+
self._update_display()
118+
self.updating_from_callback = False
119+
120+
def _on_prev_click(self, b):
121+
"""Go to previous history item"""
122+
if self.slider.value > 0:
123+
self.slider.value -= 1
124+
125+
def _on_next_click(self, b):
126+
"""Go to next history item"""
127+
if self.slider.value < self.slider.max:
128+
self.slider.value += 1
129+
130+
def _on_rewrite_click(self, b):
131+
"""Handle rewrite button click"""
132+
self.slider.value = self.slider.max
133+
self.rewrite()
134+
135+
def display(self):
136+
"""Display the full widget interface"""
137+
display(
138+
widgets.VBox(
139+
[
140+
self.control_box,
141+
self.nav_button_box,
142+
*((self.reason_label,) if self.display_reason else ()),
143+
self.code_widget,
144+
]
145+
)
146+
)
147+
148+
def _ipython_display_(self):
149+
self.display()
150+
151+
def _history_callback(self):
152+
"""Callback for history updates that prevents recursion"""
153+
if not self.updating_from_callback:
154+
self.updating_from_callback = True
155+
self._update_display()
156+
self.updating_from_callback = False
157+
158+
def _update_display(self):
159+
"""Update the code widget with the current graph and reason"""
160+
# Update the reason label if checkbox is checked
161+
if self.display_reason:
162+
if self.history.pointer == -1:
163+
reason = ""
164+
else:
165+
reason = self.history.fw[self.history.pointer].reason
166+
reason = getattr(reason, "name", str(reason))
167+
168+
self.reason_label.value = f"""
169+
<div style='padding: 5px; margin-bottom: 10px; background-color: #e6f7ff; border-left: 4px solid #1890ff;'>
170+
<b>Rewrite:</b> {reason}
171+
</div>
172+
"""
173+
174+
# Update the graph display
175+
self.code_widget.content = self.history.fg.dprint(file="str")
176+
177+
# Update slider range if history length has changed
178+
history_len = len(self.history.fw) + 1
179+
if history_len != self.slider.max + 1:
180+
self.slider.max = history_len - 1
181+
182+
# Update slider value without triggering the observer
183+
if not self.updating_from_callback:
184+
with self.slider.hold_trait_notifications():
185+
self.slider.value = self.history.pointer + 1
186+
187+
# Update the slider label to show current position and total (1-based)
188+
self.slider_label.value = (
189+
f"History: {self.history.pointer + 1}/{history_len - 1}"
190+
)
191+
192+
def rewrite(self, *args, include=("fast_run",), exclude=("inplace",), **kwargs):
193+
"""Apply rewrites to the current graph"""
194+
rewrite_graph(
195+
self.history.fg,
196+
*args,
197+
include=include,
198+
exclude=exclude,
199+
**kwargs,
200+
clone=False,
201+
)
202+
self._update_display()

0 commit comments

Comments
 (0)