
Commit 8a040b9

Add feature that keeps track of full rewrite history
1 parent 646a734 commit 8a040b9

2 files changed: +197 −2 lines
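For context on the mechanism this commit builds on: a `Feature` attached with `fg.attach_feature(...)` receives callbacks from its `FunctionGraph`, notably `on_attach` once and `on_change_input` for every input replacement a rewrite performs; that callback stream is exactly what the new `FullHistory` class records. A minimal sketch of the protocol, using only calls that appear in the diffs below (`PrintChanges` itself is hypothetical, not part of PyTensor):

    import pytensor.tensor as pt
    from pytensor.graph.features import Feature
    from pytensor.graph.fg import FunctionGraph
    from pytensor.graph.rewriting.utils import rewrite_graph

    class PrintChanges(Feature):
        # Hypothetical feature: log every input replacement a rewrite performs.
        def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
            print(f"{reason}: input {i} of {node} replaced")

    x = pt.scalar("x")
    out = pt.log(pt.exp(x) / pt.sum(pt.exp(x)))
    fg = FunctionGraph(outputs=[out])
    fg.attach_feature(PrintChanges())
    rewrite_graph(fg, clone=False, include=("canonicalize", "stabilize"))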

pytensor/graph/features.py

Lines changed: 163 additions & 0 deletions
@@ -438,6 +438,169 @@ def revert(self, fgraph, checkpoint):
         self.history[fgraph] = h
 
 
+class FullHistory(Feature):
+    """Keeps track of all changes in a FunctionGraph and allows arbitrary back-and-forth navigation through its intermediate states.
+
+    .. testcode::
+        import pytensor
+        import pytensor.tensor as pt
+        from pytensor.graph.fg import FunctionGraph
+        from pytensor.graph.features import FullHistory
+        from pytensor.graph.rewriting.utils import rewrite_graph
+
+        x = pt.scalar("x")
+        out = pt.log(pt.exp(x) / pt.sum(pt.exp(x)))
+
+        fg = FunctionGraph(outputs=[out])
+        history = FullHistory()
+        fg.attach_feature(history)
+
+        rewrite_graph(fg, clone=False, include=("canonicalize", "stabilize"))
+
+        # Replay rewrites
+        history.start()
+        pytensor.dprint(fg)
+        with pytensor.config.change_flags(optimizer_verbose=True):
+            for i in range(3):
+                print(">> ", end="")
+                pytensor.dprint(history.next())
+
+    .. testoutput::
+        Log [id A] 4
+         └─ True_div [id B] 3
+            ├─ Exp [id C] 2
+            │  └─ x [id D]
+            └─ Sum{axes=None} [id E] 1
+               └─ Exp [id F] 0
+                  └─ x [id D]
+        >> MergeOptimizer
+        Log [id A] 3
+         └─ True_div [id B] 2
+            ├─ Exp [id C] 0
+            │  └─ x [id D]
+            └─ Sum{axes=None} [id E] 1
+               └─ Exp [id C] 0
+                  └─ ···
+        >> local_mul_canonizer
+        Log [id A] 1
+         └─ Softmax{axis=None} [id B] 0
+            └─ x [id C]
+        >> local_logsoftmax
+        LogSoftmax{axis=None} [id A] 0
+         └─ x [id B]
+
+
+    .. testcode::
+        # Or in reverse
+        with pytensor.config.change_flags(optimizer_verbose=True):
+            for i in range(3):
+                print(">> ", end="")
+                pytensor.dprint(history.prev())
+
+    .. testoutput::
+        >> local_logsoftmax
+        Log [id A] 1
+         └─ Softmax{axis=None} [id B] 0
+            └─ x [id C]
+        >> local_mul_canonizer
+        Log [id A] 3
+         └─ True_div [id B] 2
+            ├─ Exp [id C] 0
+            │  └─ x [id D]
+            └─ Sum{axes=None} [id E] 1
+               └─ Exp [id C] 0
+                  └─ ···
+        >> MergeOptimizer
+        Log [id A] 4
+         └─ True_div [id B] 3
+            ├─ Exp [id C] 2
+            │  └─ x [id D]
+            └─ Sum{axes=None} [id E] 1
+               └─ Exp [id F] 0
+                  └─ x [id D]
+
+
+    .. testcode::
+        # Or go to any step
+        pytensor.dprint(history.goto(2))
+
+    .. testoutput::
+        Log [id A] 1
+         └─ Softmax{axis=None} [id B] 0
+            └─ x [id C]
+
+
+    """
+
+    def __init__(self):
+        self.fw = []
+        self.bw = []
+        self.pointer = -1
+        self.fg = None
+
+    def on_attach(self, fgraph):
+        if self.fg is not None:
+            raise ValueError("Full History already attached to another fgraph")
+        self.fg = fgraph
+
+    def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
+        self.bw.append(LambdaExtract(fgraph, node, i, r, reason))
+        self.fw.append(LambdaExtract(fgraph, node, i, new_r, reason))
+        self.pointer += 1
+
+    def goto(self, checkpoint):
+        """
+        Reverts the graph to the state it had after exactly
+        ``checkpoint`` recorded rewrites were applied, undoing or
+        replaying replacements as needed. ``checkpoint`` must lie
+        between 0 (the initial graph) and the number of recorded
+        rewrites.
+
+        """
+        history_len = len(self.bw)
+        pointer = self.pointer
+        assert 0 <= checkpoint <= history_len
+        verbose = config.optimizer_verbose
+
+        # Go backwards
+        while pointer > checkpoint - 1:
+            reverse_fn = self.bw[pointer]
+            if verbose:
+                print(reverse_fn.reason)  # noqa: T201
+            reverse_fn()
+            pointer -= 1
+
+        # Go forward
+        while pointer < checkpoint - 1:
+            pointer += 1
+            forward_fn = self.fw[pointer]
+            if verbose:
+                print(forward_fn.reason)  # noqa: T201
+            forward_fn()
+
+        # Remove the history entries appended by the replay itself
+        # (each forward/backward call triggers on_change_input again)
+        self.bw = self.bw[:history_len]
+        self.fw = self.fw[:history_len]
+        self.pointer = pointer
+        return self.fg
+
+    def start(self):
+        return self.goto(0)
+
+    def end(self):
+        return self.goto(len(self.bw))
+
+    def prev(self):
+        if self.pointer < 0:
+            return self.fg
+        else:
+            return self.goto(self.pointer)
+
+    def next(self):
+        if self.pointer >= len(self.bw) - 1:
+            return self.fg
+        else:
+            return self.goto(self.pointer + 2)
+
+
 class Validator(Feature):
     pickle_rm_attr = ["validate", "consistent"]

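A note on the bookkeeping above: after `goto(k)` the first `k` recorded rewrites are applied and `self.pointer` is left at `k - 1`, which is why `prev()` calls `goto(self.pointer)` (one rewrite fewer than the `pointer + 1` currently applied) while `next()` calls `goto(self.pointer + 2)` (one more). Each rewrite is recorded as a pair of `LambdaExtract` callables; `LambdaExtract` is defined earlier in features.py (the existing `History` feature already uses it) and is not part of this diff, so the sketch below is only a rough guess at its shape — the exact reason tuple and keyword arguments may differ from the real definition:

    # Rough sketch of LambdaExtract (not part of this diff): it freezes one
    # input replacement so that calling the instance later re-applies it.
    class LambdaExtract:
        def __init__(self, fgraph, node, i, r, reason=None):
            self.fgraph = fgraph
            self.node = node
            self.i = i
            self.r = r
            self.reason = reason

        def __call__(self):
            # Assumed to delegate to FunctionGraph.change_node_input, which
            # swaps input i of node back to the stored variable r.
            return self.fgraph.change_node_input(
                self.node, self.i, self.r, reason=("Revert", self.reason), check=False
            )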
tests/graph/test_features.py

Lines changed: 34 additions & 2 deletions
@@ -1,7 +1,9 @@
 import pytest
 
-from pytensor.graph.basic import Apply, Variable
-from pytensor.graph.features import Feature, NodeFinder, ReplaceValidate
+import pytensor.tensor as pt
+from pytensor.graph import rewrite_graph
+from pytensor.graph.basic import Apply, Variable, equal_computations
+from pytensor.graph.features import Feature, FullHistory, NodeFinder, ReplaceValidate
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import Op
 from pytensor.graph.type import Type
@@ -119,3 +121,33 @@ def validate(self, *args):
 
     capres = capsys.readouterr()
     assert "rewriting: validate failed on node Op1.0" in capres.out
+
+
+def test_full_history():
+    x = pt.scalar("x")
+    out = pt.log(pt.exp(x) / pt.sum(pt.exp(x)))
+    fg = FunctionGraph(outputs=[out], clone=True, copy_inputs=False)
+    history = FullHistory()
+    fg.attach_feature(history)
+    rewrite_graph(fg, clone=False, include=("canonicalize", "stabilize"))
+
+    history.start()
+    assert equal_computations(fg.outputs, [out])
+
+    history.end()
+    assert equal_computations(fg.outputs, [pt.special.log_softmax(x)])
+
+    history.prev()
+    assert equal_computations(fg.outputs, [pt.log(pt.special.softmax(x))])
+
+    for i in range(10):
+        history.prev()
+    assert equal_computations(fg.outputs, [out])
+
+    history.goto(2)
+    assert equal_computations(fg.outputs, [pt.log(pt.special.softmax(x))])
+
+    for i in range(10):
+        history.next()
+
+    assert equal_computations(fg.outputs, [pt.special.log_softmax(x)])

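The test drives the history only through `start`/`end`/`prev`/`next`/`goto`, but the recorded steps can also be inspected directly; `fw`, `bw`, and the per-entry `reason` all appear in the diff above. A small usage sketch, reusing the test's `history` object after `rewrite_graph` has run:

    # Print the rewrite responsible for each recorded step, in order.
    for step in history.fw:
        print(step.reason)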