@@ -438,6 +438,169 @@ def revert(self, fgraph, checkpoint):
438
438
self .history [fgraph ] = h
439
439
440
440
441
+ class FullHistory (Feature ):
442
+ """Keeps track of all changes in FunctionGraph and allows arbitrary back and forth through intermediate states
443
+
444
+ .. testcode::
445
+ import pytensor
446
+ import pytensor.tensor as pt
447
+ from pytensor.graph.fg import FunctionGraph
448
+ from pytensor.graph.features import FullHistory
449
+ from pytensor.graph.rewriting.utils import rewrite_graph
450
+
451
+ x = pt.scalar("x")
452
+ out = pt.log(pt.exp(x) / pt.sum(pt.exp(x)))
453
+
454
+ fg = FunctionGraph(outputs=[out])
455
+ history = FullHistory()
456
+ fg.attach_feature(history)
457
+
458
+ rewrite_graph(fg, clone=False, include=("canonicalize", "stabilize"))
459
+
460
+ # Replay rewrites
461
+ history.start()
462
+ pytensor.dprint(fg)
463
+ with pytensor.config.change_flags(optimizer_verbose = True):
464
+ for i in range(3):
465
+ print(">> ", end="")
466
+ pytensor.dprint(history.next())
467
+
468
+ .. testoutput::
469
+ Log [id A] 4
470
+ └─ True_div [id B] 3
471
+ ├─ Exp [id C] 2
472
+ │ └─ x [id D]
473
+ └─ Sum{axes=None} [id E] 1
474
+ └─ Exp [id F] 0
475
+ └─ x [id D]
476
+ >> MergeOptimizer
477
+ Log [id A] 3
478
+ └─ True_div [id B] 2
479
+ ├─ Exp [id C] 0
480
+ │ └─ x [id D]
481
+ └─ Sum{axes=None} [id E] 1
482
+ └─ Exp [id C] 0
483
+ └─ ···
484
+ >> local_mul_canonizer
485
+ Log [id A] 1
486
+ └─ Softmax{axis=None} [id B] 0
487
+ └─ x [id C]
488
+ >> local_logsoftmax
489
+ LogSoftmax{axis=None} [id A] 0
490
+ └─ x [id B]
491
+
492
+
493
+ .. testcode::
494
+ # Or in reverse
495
+ with pytensor.config.change_flags(optimizer_verbose=True):
496
+ for i in range(3):
497
+ print(">> ", end="")
498
+ pytensor.dprint(history.prev())
499
+
500
+ .. testoutput::
501
+ >> local_logsoftmax
502
+ Log [id A] 1
503
+ └─ Softmax{axis=None} [id B] 0
504
+ └─ x [id C]
505
+ >> local_mul_canonizer
506
+ Log [id A] 3
507
+ └─ True_div [id B] 2
508
+ ├─ Exp [id C] 0
509
+ │ └─ x [id D]
510
+ └─ Sum{axes=None} [id E] 1
511
+ └─ Exp [id C] 0
512
+ └─ ···
513
+ >> MergeOptimizer
514
+ Log [id A] 4
515
+ └─ True_div [id B] 3
516
+ ├─ Exp [id C] 2
517
+ │ └─ x [id D]
518
+ └─ Sum{axes=None} [id E] 1
519
+ └─ Exp [id F] 0
520
+ └─ x [id D]
521
+
522
+
523
+ .. testcode::
524
+ # Or go to any step
525
+ pytensor.dprint(history.goto(2))
526
+
527
+ .. testoutput::
528
+ Log [id A] 1
529
+ └─ Softmax{axis=None} [id B] 0
530
+ └─ x [id C]
531
+
532
+
533
+ """
534
+
535
+ def __init__ (self ):
536
+ self .fw = []
537
+ self .bw = []
538
+ self .pointer = - 1
539
+ self .fg = None
540
+
541
+ def on_attach (self , fgraph ):
542
+ if self .fg is not None :
543
+ raise ValueError ("Full History already attached to another fgraph" )
544
+ self .fg = fgraph
545
+
546
+ def on_change_input (self , fgraph , node , i , r , new_r , reason = None ):
547
+ self .bw .append (LambdaExtract (fgraph , node , i , r , reason ))
548
+ self .fw .append (LambdaExtract (fgraph , node , i , new_r , reason ))
549
+ self .pointer += 1
550
+
551
+ def goto (self , checkpoint ):
552
+ """
553
+ Reverts the graph to whatever it was at the provided
554
+ checkpoint (undoes all replacements). A checkpoint at any
555
+ given time can be obtained using self.checkpoint().
556
+
557
+ """
558
+ history_len = len (self .bw )
559
+ pointer = self .pointer
560
+ assert 0 <= checkpoint <= history_len
561
+ verbose = config .optimizer_verbose
562
+
563
+ # Go backwards
564
+ while pointer > checkpoint - 1 :
565
+ reverse_fn = self .bw [pointer ]
566
+ if verbose :
567
+ print (reverse_fn .reason ) # noqa: T201
568
+ reverse_fn ()
569
+ pointer -= 1
570
+
571
+ # Go forward
572
+ while pointer < checkpoint - 1 :
573
+ pointer += 1
574
+ forward_fn = self .fw [pointer ]
575
+ if verbose :
576
+ print (forward_fn .reason ) # noqa: T201
577
+ forward_fn ()
578
+
579
+ # Remove history changes caused by the foward/backward!
580
+ self .bw = self .bw [:history_len ]
581
+ self .fw = self .fw [:history_len ]
582
+ self .pointer = pointer
583
+ return self .fg
584
+
585
+ def start (self ):
586
+ return self .goto (0 )
587
+
588
+ def end (self ):
589
+ return self .goto (len (self .bw ))
590
+
591
+ def prev (self ):
592
+ if self .pointer < 0 :
593
+ return self .fg
594
+ else :
595
+ return self .goto (self .pointer )
596
+
597
+ def next (self ):
598
+ if self .pointer >= len (self .bw ) - 1 :
599
+ return self .fg
600
+ else :
601
+ return self .goto (self .pointer + 2 )
602
+
603
+
441
604
class Validator (Feature ):
442
605
pickle_rm_attr = ["validate" , "consistent" ]
443
606
0 commit comments