@@ -44,25 +44,24 @@ def test_debugprint_sitsot():
44
44
│ │ │ │ │ │ └─ 1.0 [id O]
45
45
│ │ │ │ │ └─ 0 [id P]
46
46
│ │ │ │ └─ Subtensor{i} [id Q]
47
- │ │ │ │ ├─ Shape [id R]
48
- │ │ │ │ │ └─ Unbroadcast{0} [id J]
49
- │ │ │ │ │ └─ ···
50
- │ │ │ │ └─ 1 [id S]
47
+ │ │ │ │ ├─ Shape [id I]
48
+ │ │ │ │ │ └─ ···
49
+ │ │ │ │ └─ 1 [id R]
51
50
│ │ │ ├─ Unbroadcast{0} [id J]
52
51
│ │ │ │ └─ ···
53
- │ │ │ └─ ScalarFromTensor [id T ]
52
+ │ │ │ └─ ScalarFromTensor [id S ]
54
53
│ │ │ └─ Subtensor{i} [id H]
55
54
│ │ │ └─ ···
56
55
│ │ └─ A [id M] (outer_in_non_seqs-0)
57
- │ └─ 1 [id U ]
58
- └─ -1 [id V ]
56
+ │ └─ 1 [id T ]
57
+ └─ -1 [id U ]
59
58
60
59
Inner graphs:
61
60
62
61
Scan{scan_fn, while_loop=False, inplace=none} [id C]
63
- ← Mul [id W ] (inner_out_sit_sot-0)
64
- ├─ *0-<Vector(float64, shape=(?,))> [id X ] -> [id E] (inner_in_sit_sot-0)
65
- └─ *1-<Vector(float64, shape=(?,))> [id Y ] -> [id M] (inner_in_non_seqs-0)
62
+ ← Mul [id V ] (inner_out_sit_sot-0)
63
+ ├─ *0-<Vector(float64, shape=(?,))> [id W ] -> [id E] (inner_in_sit_sot-0)
64
+ └─ *1-<Vector(float64, shape=(?,))> [id X ] -> [id M] (inner_in_non_seqs-0)
66
65
"""
67
66
68
67
for truth , out in zip (expected_output .split ("\n " ), lines , strict = True ):
@@ -103,25 +102,24 @@ def test_debugprint_sitsot_no_extra_info():
103
102
│ │ │ │ │ │ └─ 1.0 [id O]
104
103
│ │ │ │ │ └─ 0 [id P]
105
104
│ │ │ │ └─ Subtensor{i} [id Q]
106
- │ │ │ │ ├─ Shape [id R]
107
- │ │ │ │ │ └─ Unbroadcast{0} [id J]
108
- │ │ │ │ │ └─ ···
109
- │ │ │ │ └─ 1 [id S]
105
+ │ │ │ │ ├─ Shape [id I]
106
+ │ │ │ │ │ └─ ···
107
+ │ │ │ │ └─ 1 [id R]
110
108
│ │ │ ├─ Unbroadcast{0} [id J]
111
109
│ │ │ │ └─ ···
112
- │ │ │ └─ ScalarFromTensor [id T ]
110
+ │ │ │ └─ ScalarFromTensor [id S ]
113
111
│ │ │ └─ Subtensor{i} [id H]
114
112
│ │ │ └─ ···
115
113
│ │ └─ A [id M]
116
- │ └─ 1 [id U ]
117
- └─ -1 [id V ]
114
+ │ └─ 1 [id T ]
115
+ └─ -1 [id U ]
118
116
119
117
Inner graphs:
120
118
121
119
Scan{scan_fn, while_loop=False, inplace=none} [id C]
122
- ← Mul [id W ]
123
- ├─ *0-<Vector(float64, shape=(?,))> [id X ] -> [id E]
124
- └─ *1-<Vector(float64, shape=(?,))> [id Y ] -> [id M]
120
+ ← Mul [id V ]
121
+ ├─ *0-<Vector(float64, shape=(?,))> [id W ] -> [id E]
122
+ └─ *1-<Vector(float64, shape=(?,))> [id X ] -> [id M]
125
123
"""
126
124
127
125
for truth , out in zip (expected_output .split ("\n " ), lines , strict = True ):
@@ -288,25 +286,24 @@ def compute_A_k(A, k):
288
286
│ │ │ │ │ │ │ └─ 1.0 [id BQ]
289
287
│ │ │ │ │ │ └─ 0 [id BR]
290
288
│ │ │ │ │ └─ Subtensor{i} [id BS]
291
- │ │ │ │ │ ├─ Shape [id BT]
292
- │ │ │ │ │ │ └─ Unbroadcast{0} [id BL]
293
- │ │ │ │ │ │ └─ ···
294
- │ │ │ │ │ └─ 1 [id BU]
289
+ │ │ │ │ │ ├─ Shape [id BK]
290
+ │ │ │ │ │ │ └─ ···
291
+ │ │ │ │ │ └─ 1 [id BT]
295
292
│ │ │ │ ├─ Unbroadcast{0} [id BL]
296
293
│ │ │ │ │ └─ ···
297
- │ │ │ │ └─ ScalarFromTensor [id BV ]
294
+ │ │ │ │ └─ ScalarFromTensor [id BU ]
298
295
│ │ │ │ └─ Subtensor{i} [id BJ]
299
296
│ │ │ │ └─ ···
300
297
│ │ │ └─ *2-<Vector(float64, shape=(?,))> [id BO] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
301
- │ │ └─ 1 [id BW ]
302
- │ └─ -1 [id BX ]
303
- └─ ExpandDims{axis=0} [id BY ]
304
- └─ *1-<Scalar(int64, shape=())> [id BZ ] -> [id U] (inner_in_seqs-1)
298
+ │ │ └─ 1 [id BV ]
299
+ │ └─ -1 [id BW ]
300
+ └─ ExpandDims{axis=0} [id BX ]
301
+ └─ *1-<Scalar(int64, shape=())> [id BY ] -> [id U] (inner_in_seqs-1)
305
302
306
303
Scan{scan_fn, while_loop=False, inplace=none} [id BE]
307
- ← Mul [id CA ] (inner_out_sit_sot-0)
308
- ├─ *0-<Vector(float64, shape=(?,))> [id CB ] -> [id BG] (inner_in_sit_sot-0)
309
- └─ *1-<Vector(float64, shape=(?,))> [id CC ] -> [id BO] (inner_in_non_seqs-0)
304
+ ← Mul [id BZ ] (inner_out_sit_sot-0)
305
+ ├─ *0-<Vector(float64, shape=(?,))> [id CA ] -> [id BG] (inner_in_sit_sot-0)
306
+ └─ *1-<Vector(float64, shape=(?,))> [id CB ] -> [id BO] (inner_in_non_seqs-0)
310
307
"""
311
308
312
309
for truth , out in zip (expected_output .split ("\n " ), lines , strict = True ):
@@ -386,27 +383,26 @@ def compute_A_k(A, k):
386
383
│ │ │ │ │ │ │ └─ 1.0 [id BR]
387
384
│ │ │ │ │ │ └─ 0 [id BS]
388
385
│ │ │ │ │ └─ Subtensor{i} [id BT]
389
- │ │ │ │ │ ├─ Shape [id BU]
390
- │ │ │ │ │ │ └─ Unbroadcast{0} [id BN]
391
- │ │ │ │ │ │ └─ ···
392
- │ │ │ │ │ └─ 1 [id BV]
386
+ │ │ │ │ │ ├─ Shape [id BM]
387
+ │ │ │ │ │ │ └─ ···
388
+ │ │ │ │ │ └─ 1 [id BU]
393
389
│ │ │ │ ├─ Unbroadcast{0} [id BN]
394
390
│ │ │ │ │ └─ ···
395
- │ │ │ │ └─ ScalarFromTensor [id BW ]
391
+ │ │ │ │ └─ ScalarFromTensor [id BV ]
396
392
│ │ │ │ └─ Subtensor{i} [id BL]
397
393
│ │ │ │ └─ ···
398
394
│ │ │ └─ *2-<Vector(float64, shape=(?,))> [id BA] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
399
- │ │ └─ 1 [id BX ]
400
- │ └─ -1 [id BY ]
401
- └─ ExpandDims{axis=0} [id BZ ]
395
+ │ │ └─ 1 [id BW ]
396
+ │ └─ -1 [id BX ]
397
+ └─ ExpandDims{axis=0} [id BY ]
402
398
└─ *1-<Scalar(int64, shape=())> [id Z] (inner_in_seqs-1)
403
399
404
400
Scan{scan_fn, while_loop=False, inplace=none} [id BH]
405
- → *0-<Vector(float64, shape=(?,))> [id CA ] -> [id BI] (inner_in_sit_sot-0)
406
- → *1-<Vector(float64, shape=(?,))> [id CB ] -> [id BA] (inner_in_non_seqs-0)
407
- ← Mul [id CC ] (inner_out_sit_sot-0)
408
- ├─ *0-<Vector(float64, shape=(?,))> [id CA ] (inner_in_sit_sot-0)
409
- └─ *1-<Vector(float64, shape=(?,))> [id CB ] (inner_in_non_seqs-0)
401
+ → *0-<Vector(float64, shape=(?,))> [id BZ ] -> [id BI] (inner_in_sit_sot-0)
402
+ → *1-<Vector(float64, shape=(?,))> [id CA ] -> [id BA] (inner_in_non_seqs-0)
403
+ ← Mul [id CB ] (inner_out_sit_sot-0)
404
+ ├─ *0-<Vector(float64, shape=(?,))> [id BZ ] (inner_in_sit_sot-0)
405
+ └─ *1-<Vector(float64, shape=(?,))> [id CA ] (inner_in_non_seqs-0)
410
406
"""
411
407
412
408
for truth , out in zip (expected_output .split ("\n " ), lines , strict = True ):
@@ -528,98 +524,97 @@ def test_debugprint_mitmot():
528
524
│ │ │ │ │ │ │ │ └─ 1.0 [id R]
529
525
│ │ │ │ │ │ │ └─ 0 [id S]
530
526
│ │ │ │ │ │ └─ Subtensor{i} [id T]
531
- │ │ │ │ │ │ ├─ Shape [id U]
532
- │ │ │ │ │ │ │ └─ Unbroadcast{0} [id M]
533
- │ │ │ │ │ │ │ └─ ···
534
- │ │ │ │ │ │ └─ 1 [id V]
527
+ │ │ │ │ │ │ ├─ Shape [id L]
528
+ │ │ │ │ │ │ │ └─ ···
529
+ │ │ │ │ │ │ └─ 1 [id U]
535
530
│ │ │ │ │ ├─ Unbroadcast{0} [id M]
536
531
│ │ │ │ │ │ └─ ···
537
- │ │ │ │ │ └─ ScalarFromTensor [id W ]
532
+ │ │ │ │ │ └─ ScalarFromTensor [id V ]
538
533
│ │ │ │ │ └─ Subtensor{i} [id K]
539
534
│ │ │ │ │ └─ ···
540
535
│ │ │ │ └─ A [id P] (outer_in_non_seqs-0)
541
- │ │ │ └─ 0 [id X ]
542
- │ │ └─ 1 [id Y ]
543
- │ ├─ Subtensor{:stop} [id Z ] (outer_in_seqs-0)
544
- │ │ ├─ Subtensor{::step} [id BA ]
545
- │ │ │ ├─ Subtensor{:stop} [id BB ]
536
+ │ │ │ └─ 0 [id W ]
537
+ │ │ └─ 1 [id X ]
538
+ │ ├─ Subtensor{:stop} [id Y ] (outer_in_seqs-0)
539
+ │ │ ├─ Subtensor{::step} [id Z ]
540
+ │ │ │ ├─ Subtensor{:stop} [id BA ]
546
541
│ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
547
542
│ │ │ │ │ └─ ···
548
- │ │ │ │ └─ -1 [id BC ]
549
- │ │ │ └─ -1 [id BD ]
550
- │ │ └─ ScalarFromTensor [id BE ]
543
+ │ │ │ │ └─ -1 [id BB ]
544
+ │ │ │ └─ -1 [id BC ]
545
+ │ │ └─ ScalarFromTensor [id BD ]
551
546
│ │ └─ Sub [id C]
552
547
│ │ └─ ···
553
- │ ├─ Subtensor{:stop} [id BF ] (outer_in_seqs-1)
554
- │ │ ├─ Subtensor{:stop} [id BG ]
555
- │ │ │ ├─ Subtensor{::step} [id BH ]
548
+ │ ├─ Subtensor{:stop} [id BE ] (outer_in_seqs-1)
549
+ │ │ ├─ Subtensor{:stop} [id BF ]
550
+ │ │ │ ├─ Subtensor{::step} [id BG ]
556
551
│ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
557
552
│ │ │ │ │ └─ ···
558
- │ │ │ │ └─ -1 [id BI ]
559
- │ │ │ └─ -1 [id BJ ]
560
- │ │ └─ ScalarFromTensor [id BK ]
553
+ │ │ │ │ └─ -1 [id BH ]
554
+ │ │ │ └─ -1 [id BI ]
555
+ │ │ └─ ScalarFromTensor [id BJ ]
561
556
│ │ └─ Sub [id C]
562
557
│ │ └─ ···
563
- │ ├─ Subtensor{::step} [id BL ] (outer_in_mit_mot-0)
564
- │ │ ├─ IncSubtensor{start:} [id BM ]
565
- │ │ │ ├─ Second [id BN ]
558
+ │ ├─ Subtensor{::step} [id BK ] (outer_in_mit_mot-0)
559
+ │ │ ├─ IncSubtensor{start:} [id BL ]
560
+ │ │ │ ├─ Second [id BM ]
566
561
│ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
567
562
│ │ │ │ │ └─ ···
568
- │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BO ]
569
- │ │ │ │ └─ 0.0 [id BP ]
570
- │ │ │ ├─ IncSubtensor{i} [id BQ ]
571
- │ │ │ │ ├─ Second [id BR ]
572
- │ │ │ │ │ ├─ Subtensor{start:} [id BS ]
563
+ │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BN ]
564
+ │ │ │ │ └─ 0.0 [id BO ]
565
+ │ │ │ ├─ IncSubtensor{i} [id BP ]
566
+ │ │ │ │ ├─ Second [id BQ ]
567
+ │ │ │ │ │ ├─ Subtensor{start:} [id BR ]
573
568
│ │ │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
574
569
│ │ │ │ │ │ │ └─ ···
575
- │ │ │ │ │ │ └─ 1 [id BT ]
576
- │ │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BU ]
577
- │ │ │ │ │ └─ 0.0 [id BV ]
578
- │ │ │ │ ├─ Second [id BW ]
579
- │ │ │ │ │ ├─ Subtensor{i} [id BX ]
580
- │ │ │ │ │ │ ├─ Subtensor{start:} [id BS ]
570
+ │ │ │ │ │ │ └─ 1 [id BS ]
571
+ │ │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BT ]
572
+ │ │ │ │ │ └─ 0.0 [id BU ]
573
+ │ │ │ │ ├─ Second [id BV ]
574
+ │ │ │ │ │ ├─ Subtensor{i} [id BW ]
575
+ │ │ │ │ │ │ ├─ Subtensor{start:} [id BR ]
581
576
│ │ │ │ │ │ │ └─ ···
582
- │ │ │ │ │ │ └─ -1 [id BY ]
583
- │ │ │ │ │ └─ ExpandDims{axis=0} [id BZ ]
584
- │ │ │ │ │ └─ Second [id CA ]
585
- │ │ │ │ │ ├─ Sum{axes=None} [id CB ]
586
- │ │ │ │ │ │ └─ Subtensor{i} [id BX ]
577
+ │ │ │ │ │ │ └─ -1 [id BX ]
578
+ │ │ │ │ │ └─ ExpandDims{axis=0} [id BY ]
579
+ │ │ │ │ │ └─ Second [id BZ ]
580
+ │ │ │ │ │ ├─ Sum{axes=None} [id CA ]
581
+ │ │ │ │ │ │ └─ Subtensor{i} [id BW ]
587
582
│ │ │ │ │ │ └─ ···
588
- │ │ │ │ │ └─ 1.0 [id CC ]
589
- │ │ │ │ └─ -1 [id BY ]
590
- │ │ │ └─ 1 [id BT ]
591
- │ │ └─ -1 [id CD ]
592
- │ ├─ Alloc [id CE ] (outer_in_sit_sot-0)
593
- │ │ ├─ 0.0 [id CF ]
594
- │ │ ├─ Add [id CG ]
583
+ │ │ │ │ │ └─ 1.0 [id CB ]
584
+ │ │ │ │ └─ -1 [id BX ]
585
+ │ │ │ └─ 1 [id BS ]
586
+ │ │ └─ -1 [id CC ]
587
+ │ ├─ Alloc [id CD ] (outer_in_sit_sot-0)
588
+ │ │ ├─ 0.0 [id CE ]
589
+ │ │ ├─ Add [id CF ]
595
590
│ │ │ ├─ Sub [id C]
596
591
│ │ │ │ └─ ···
597
- │ │ │ └─ 1 [id CH ]
598
- │ │ └─ Subtensor{i} [id CI ]
599
- │ │ ├─ Shape [id CJ ]
592
+ │ │ │ └─ 1 [id CG ]
593
+ │ │ └─ Subtensor{i} [id CH ]
594
+ │ │ ├─ Shape [id CI ]
600
595
│ │ │ └─ A [id P]
601
- │ │ └─ 0 [id CK ]
596
+ │ │ └─ 0 [id CJ ]
602
597
│ └─ A [id P] (outer_in_non_seqs-0)
603
- └─ -1 [id CL ]
598
+ └─ -1 [id CK ]
604
599
605
600
Inner graphs:
606
601
607
602
Scan{grad_of_scan_fn, while_loop=False, inplace=none} [id B]
608
- ← Add [id CM ] (inner_out_mit_mot-0-0)
609
- ├─ Mul [id CN ]
610
- │ ├─ *2-<Vector(float64, shape=(?,))> [id CO ] -> [id BL ] (inner_in_mit_mot-0-0)
611
- │ └─ *5-<Vector(float64, shape=(?,))> [id CP ] -> [id P] (inner_in_non_seqs-0)
612
- └─ *3-<Vector(float64, shape=(?,))> [id CQ ] -> [id BL ] (inner_in_mit_mot-0-1)
613
- ← Add [id CR ] (inner_out_sit_sot-0)
614
- ├─ Mul [id CS ]
615
- │ ├─ *2-<Vector(float64, shape=(?,))> [id CO ] -> [id BL ] (inner_in_mit_mot-0-0)
616
- │ └─ *0-<Vector(float64, shape=(?,))> [id CT ] -> [id Z ] (inner_in_seqs-0)
617
- └─ *4-<Vector(float64, shape=(?,))> [id CU ] -> [id CE ] (inner_in_sit_sot-0)
603
+ ← Add [id CL ] (inner_out_mit_mot-0-0)
604
+ ├─ Mul [id CM ]
605
+ │ ├─ *2-<Vector(float64, shape=(?,))> [id CN ] -> [id BK ] (inner_in_mit_mot-0-0)
606
+ │ └─ *5-<Vector(float64, shape=(?,))> [id CO ] -> [id P] (inner_in_non_seqs-0)
607
+ └─ *3-<Vector(float64, shape=(?,))> [id CP ] -> [id BK ] (inner_in_mit_mot-0-1)
608
+ ← Add [id CQ ] (inner_out_sit_sot-0)
609
+ ├─ Mul [id CR ]
610
+ │ ├─ *2-<Vector(float64, shape=(?,))> [id CN ] -> [id BK ] (inner_in_mit_mot-0-0)
611
+ │ └─ *0-<Vector(float64, shape=(?,))> [id CS ] -> [id Y ] (inner_in_seqs-0)
612
+ └─ *4-<Vector(float64, shape=(?,))> [id CT ] -> [id CD ] (inner_in_sit_sot-0)
618
613
619
614
Scan{scan_fn, while_loop=False, inplace=none} [id F]
620
- ← Mul [id CV ] (inner_out_sit_sot-0)
621
- ├─ *0-<Vector(float64, shape=(?,))> [id CT ] -> [id H] (inner_in_sit_sot-0)
622
- └─ *1-<Vector(float64, shape=(?,))> [id CW ] -> [id P] (inner_in_non_seqs-0)
615
+ ← Mul [id CU ] (inner_out_sit_sot-0)
616
+ ├─ *0-<Vector(float64, shape=(?,))> [id CS ] -> [id H] (inner_in_sit_sot-0)
617
+ └─ *1-<Vector(float64, shape=(?,))> [id CV ] -> [id P] (inner_in_non_seqs-0)
623
618
"""
624
619
625
620
for truth , out in zip (expected_output .split ("\n " ), lines , strict = True ):
0 commit comments