Skip to content

Commit b27c59d

Browse files
committed
Simplify expand_empty
1 parent c70a886 commit b27c59d

File tree

2 files changed

+105
-110
lines changed

2 files changed

+105
-110
lines changed

pytensor/scan/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,8 +231,8 @@ def expand_empty(tensor_var, size):
231231

232232
if size == 0:
233233
return tensor_var
234-
shapes = [tensor_var.shape[x] for x in range(tensor_var.ndim)]
235-
new_shape = [size + shapes[0]] + shapes[1:]
234+
shapes = tuple(tensor_var.shape)
235+
new_shape = (size + shapes[0], *shapes[1:])
236236
empty = AllocEmpty(tensor_var.dtype)(*new_shape)
237237

238238
ret = set_subtensor(empty[: shapes[0]], tensor_var)

tests/scan/test_printing.py

Lines changed: 103 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -44,25 +44,24 @@ def test_debugprint_sitsot():
4444
│ │ │ │ │ │ └─ 1.0 [id O]
4545
│ │ │ │ │ └─ 0 [id P]
4646
│ │ │ │ └─ Subtensor{i} [id Q]
47-
│ │ │ │ ├─ Shape [id R]
48-
│ │ │ │ │ └─ Unbroadcast{0} [id J]
49-
│ │ │ │ │ └─ ···
50-
│ │ │ │ └─ 1 [id S]
47+
│ │ │ │ ├─ Shape [id I]
48+
│ │ │ │ │ └─ ···
49+
│ │ │ │ └─ 1 [id R]
5150
│ │ │ ├─ Unbroadcast{0} [id J]
5251
│ │ │ │ └─ ···
53-
│ │ │ └─ ScalarFromTensor [id T]
52+
│ │ │ └─ ScalarFromTensor [id S]
5453
│ │ │ └─ Subtensor{i} [id H]
5554
│ │ │ └─ ···
5655
│ │ └─ A [id M] (outer_in_non_seqs-0)
57-
│ └─ 1 [id U]
58-
└─ -1 [id V]
56+
│ └─ 1 [id T]
57+
└─ -1 [id U]
5958
6059
Inner graphs:
6160
6261
Scan{scan_fn, while_loop=False, inplace=none} [id C]
63-
← Mul [id W] (inner_out_sit_sot-0)
64-
├─ *0-<Vector(float64, shape=(?,))> [id X] -> [id E] (inner_in_sit_sot-0)
65-
└─ *1-<Vector(float64, shape=(?,))> [id Y] -> [id M] (inner_in_non_seqs-0)
62+
← Mul [id V] (inner_out_sit_sot-0)
63+
├─ *0-<Vector(float64, shape=(?,))> [id W] -> [id E] (inner_in_sit_sot-0)
64+
└─ *1-<Vector(float64, shape=(?,))> [id X] -> [id M] (inner_in_non_seqs-0)
6665
"""
6766

6867
for truth, out in zip(expected_output.split("\n"), lines, strict=True):
@@ -103,25 +102,24 @@ def test_debugprint_sitsot_no_extra_info():
103102
│ │ │ │ │ │ └─ 1.0 [id O]
104103
│ │ │ │ │ └─ 0 [id P]
105104
│ │ │ │ └─ Subtensor{i} [id Q]
106-
│ │ │ │ ├─ Shape [id R]
107-
│ │ │ │ │ └─ Unbroadcast{0} [id J]
108-
│ │ │ │ │ └─ ···
109-
│ │ │ │ └─ 1 [id S]
105+
│ │ │ │ ├─ Shape [id I]
106+
│ │ │ │ │ └─ ···
107+
│ │ │ │ └─ 1 [id R]
110108
│ │ │ ├─ Unbroadcast{0} [id J]
111109
│ │ │ │ └─ ···
112-
│ │ │ └─ ScalarFromTensor [id T]
110+
│ │ │ └─ ScalarFromTensor [id S]
113111
│ │ │ └─ Subtensor{i} [id H]
114112
│ │ │ └─ ···
115113
│ │ └─ A [id M]
116-
│ └─ 1 [id U]
117-
└─ -1 [id V]
114+
│ └─ 1 [id T]
115+
└─ -1 [id U]
118116
119117
Inner graphs:
120118
121119
Scan{scan_fn, while_loop=False, inplace=none} [id C]
122-
← Mul [id W]
123-
├─ *0-<Vector(float64, shape=(?,))> [id X] -> [id E]
124-
└─ *1-<Vector(float64, shape=(?,))> [id Y] -> [id M]
120+
← Mul [id V]
121+
├─ *0-<Vector(float64, shape=(?,))> [id W] -> [id E]
122+
└─ *1-<Vector(float64, shape=(?,))> [id X] -> [id M]
125123
"""
126124

127125
for truth, out in zip(expected_output.split("\n"), lines, strict=True):
@@ -288,25 +286,24 @@ def compute_A_k(A, k):
288286
│ │ │ │ │ │ │ └─ 1.0 [id BQ]
289287
│ │ │ │ │ │ └─ 0 [id BR]
290288
│ │ │ │ │ └─ Subtensor{i} [id BS]
291-
│ │ │ │ │ ├─ Shape [id BT]
292-
│ │ │ │ │ │ └─ Unbroadcast{0} [id BL]
293-
│ │ │ │ │ │ └─ ···
294-
│ │ │ │ │ └─ 1 [id BU]
289+
│ │ │ │ │ ├─ Shape [id BK]
290+
│ │ │ │ │ │ └─ ···
291+
│ │ │ │ │ └─ 1 [id BT]
295292
│ │ │ │ ├─ Unbroadcast{0} [id BL]
296293
│ │ │ │ │ └─ ···
297-
│ │ │ │ └─ ScalarFromTensor [id BV]
294+
│ │ │ │ └─ ScalarFromTensor [id BU]
298295
│ │ │ │ └─ Subtensor{i} [id BJ]
299296
│ │ │ │ └─ ···
300297
│ │ │ └─ *2-<Vector(float64, shape=(?,))> [id BO] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
301-
│ │ └─ 1 [id BW]
302-
│ └─ -1 [id BX]
303-
└─ ExpandDims{axis=0} [id BY]
304-
└─ *1-<Scalar(int64, shape=())> [id BZ] -> [id U] (inner_in_seqs-1)
298+
│ │ └─ 1 [id BV]
299+
│ └─ -1 [id BW]
300+
└─ ExpandDims{axis=0} [id BX]
301+
└─ *1-<Scalar(int64, shape=())> [id BY] -> [id U] (inner_in_seqs-1)
305302
306303
Scan{scan_fn, while_loop=False, inplace=none} [id BE]
307-
← Mul [id CA] (inner_out_sit_sot-0)
308-
├─ *0-<Vector(float64, shape=(?,))> [id CB] -> [id BG] (inner_in_sit_sot-0)
309-
└─ *1-<Vector(float64, shape=(?,))> [id CC] -> [id BO] (inner_in_non_seqs-0)
304+
← Mul [id BZ] (inner_out_sit_sot-0)
305+
├─ *0-<Vector(float64, shape=(?,))> [id CA] -> [id BG] (inner_in_sit_sot-0)
306+
└─ *1-<Vector(float64, shape=(?,))> [id CB] -> [id BO] (inner_in_non_seqs-0)
310307
"""
311308

312309
for truth, out in zip(expected_output.split("\n"), lines, strict=True):
@@ -386,27 +383,26 @@ def compute_A_k(A, k):
386383
│ │ │ │ │ │ │ └─ 1.0 [id BR]
387384
│ │ │ │ │ │ └─ 0 [id BS]
388385
│ │ │ │ │ └─ Subtensor{i} [id BT]
389-
│ │ │ │ │ ├─ Shape [id BU]
390-
│ │ │ │ │ │ └─ Unbroadcast{0} [id BN]
391-
│ │ │ │ │ │ └─ ···
392-
│ │ │ │ │ └─ 1 [id BV]
386+
│ │ │ │ │ ├─ Shape [id BM]
387+
│ │ │ │ │ │ └─ ···
388+
│ │ │ │ │ └─ 1 [id BU]
393389
│ │ │ │ ├─ Unbroadcast{0} [id BN]
394390
│ │ │ │ │ └─ ···
395-
│ │ │ │ └─ ScalarFromTensor [id BW]
391+
│ │ │ │ └─ ScalarFromTensor [id BV]
396392
│ │ │ │ └─ Subtensor{i} [id BL]
397393
│ │ │ │ └─ ···
398394
│ │ │ └─ *2-<Vector(float64, shape=(?,))> [id BA] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
399-
│ │ └─ 1 [id BX]
400-
│ └─ -1 [id BY]
401-
└─ ExpandDims{axis=0} [id BZ]
395+
│ │ └─ 1 [id BW]
396+
│ └─ -1 [id BX]
397+
└─ ExpandDims{axis=0} [id BY]
402398
└─ *1-<Scalar(int64, shape=())> [id Z] (inner_in_seqs-1)
403399
404400
Scan{scan_fn, while_loop=False, inplace=none} [id BH]
405-
→ *0-<Vector(float64, shape=(?,))> [id CA] -> [id BI] (inner_in_sit_sot-0)
406-
→ *1-<Vector(float64, shape=(?,))> [id CB] -> [id BA] (inner_in_non_seqs-0)
407-
← Mul [id CC] (inner_out_sit_sot-0)
408-
├─ *0-<Vector(float64, shape=(?,))> [id CA] (inner_in_sit_sot-0)
409-
└─ *1-<Vector(float64, shape=(?,))> [id CB] (inner_in_non_seqs-0)
401+
→ *0-<Vector(float64, shape=(?,))> [id BZ] -> [id BI] (inner_in_sit_sot-0)
402+
→ *1-<Vector(float64, shape=(?,))> [id CA] -> [id BA] (inner_in_non_seqs-0)
403+
← Mul [id CB] (inner_out_sit_sot-0)
404+
├─ *0-<Vector(float64, shape=(?,))> [id BZ] (inner_in_sit_sot-0)
405+
└─ *1-<Vector(float64, shape=(?,))> [id CA] (inner_in_non_seqs-0)
410406
"""
411407

412408
for truth, out in zip(expected_output.split("\n"), lines, strict=True):
@@ -528,98 +524,97 @@ def test_debugprint_mitmot():
528524
│ │ │ │ │ │ │ │ └─ 1.0 [id R]
529525
│ │ │ │ │ │ │ └─ 0 [id S]
530526
│ │ │ │ │ │ └─ Subtensor{i} [id T]
531-
│ │ │ │ │ │ ├─ Shape [id U]
532-
│ │ │ │ │ │ │ └─ Unbroadcast{0} [id M]
533-
│ │ │ │ │ │ │ └─ ···
534-
│ │ │ │ │ │ └─ 1 [id V]
527+
│ │ │ │ │ │ ├─ Shape [id L]
528+
│ │ │ │ │ │ │ └─ ···
529+
│ │ │ │ │ │ └─ 1 [id U]
535530
│ │ │ │ │ ├─ Unbroadcast{0} [id M]
536531
│ │ │ │ │ │ └─ ···
537-
│ │ │ │ │ └─ ScalarFromTensor [id W]
532+
│ │ │ │ │ └─ ScalarFromTensor [id V]
538533
│ │ │ │ │ └─ Subtensor{i} [id K]
539534
│ │ │ │ │ └─ ···
540535
│ │ │ │ └─ A [id P] (outer_in_non_seqs-0)
541-
│ │ │ └─ 0 [id X]
542-
│ │ └─ 1 [id Y]
543-
│ ├─ Subtensor{:stop} [id Z] (outer_in_seqs-0)
544-
│ │ ├─ Subtensor{::step} [id BA]
545-
│ │ │ ├─ Subtensor{:stop} [id BB]
536+
│ │ │ └─ 0 [id W]
537+
│ │ └─ 1 [id X]
538+
│ ├─ Subtensor{:stop} [id Y] (outer_in_seqs-0)
539+
│ │ ├─ Subtensor{::step} [id Z]
540+
│ │ │ ├─ Subtensor{:stop} [id BA]
546541
│ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
547542
│ │ │ │ │ └─ ···
548-
│ │ │ │ └─ -1 [id BC]
549-
│ │ │ └─ -1 [id BD]
550-
│ │ └─ ScalarFromTensor [id BE]
543+
│ │ │ │ └─ -1 [id BB]
544+
│ │ │ └─ -1 [id BC]
545+
│ │ └─ ScalarFromTensor [id BD]
551546
│ │ └─ Sub [id C]
552547
│ │ └─ ···
553-
│ ├─ Subtensor{:stop} [id BF] (outer_in_seqs-1)
554-
│ │ ├─ Subtensor{:stop} [id BG]
555-
│ │ │ ├─ Subtensor{::step} [id BH]
548+
│ ├─ Subtensor{:stop} [id BE] (outer_in_seqs-1)
549+
│ │ ├─ Subtensor{:stop} [id BF]
550+
│ │ │ ├─ Subtensor{::step} [id BG]
556551
│ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
557552
│ │ │ │ │ └─ ···
558-
│ │ │ │ └─ -1 [id BI]
559-
│ │ │ └─ -1 [id BJ]
560-
│ │ └─ ScalarFromTensor [id BK]
553+
│ │ │ │ └─ -1 [id BH]
554+
│ │ │ └─ -1 [id BI]
555+
│ │ └─ ScalarFromTensor [id BJ]
561556
│ │ └─ Sub [id C]
562557
│ │ └─ ···
563-
│ ├─ Subtensor{::step} [id BL] (outer_in_mit_mot-0)
564-
│ │ ├─ IncSubtensor{start:} [id BM]
565-
│ │ │ ├─ Second [id BN]
558+
│ ├─ Subtensor{::step} [id BK] (outer_in_mit_mot-0)
559+
│ │ ├─ IncSubtensor{start:} [id BL]
560+
│ │ │ ├─ Second [id BM]
566561
│ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
567562
│ │ │ │ │ └─ ···
568-
│ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BO]
569-
│ │ │ │ └─ 0.0 [id BP]
570-
│ │ │ ├─ IncSubtensor{i} [id BQ]
571-
│ │ │ │ ├─ Second [id BR]
572-
│ │ │ │ │ ├─ Subtensor{start:} [id BS]
563+
│ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BN]
564+
│ │ │ │ └─ 0.0 [id BO]
565+
│ │ │ ├─ IncSubtensor{i} [id BP]
566+
│ │ │ │ ├─ Second [id BQ]
567+
│ │ │ │ │ ├─ Subtensor{start:} [id BR]
573568
│ │ │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
574569
│ │ │ │ │ │ │ └─ ···
575-
│ │ │ │ │ │ └─ 1 [id BT]
576-
│ │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BU]
577-
│ │ │ │ │ └─ 0.0 [id BV]
578-
│ │ │ │ ├─ Second [id BW]
579-
│ │ │ │ │ ├─ Subtensor{i} [id BX]
580-
│ │ │ │ │ │ ├─ Subtensor{start:} [id BS]
570+
│ │ │ │ │ │ └─ 1 [id BS]
571+
│ │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BT]
572+
│ │ │ │ │ └─ 0.0 [id BU]
573+
│ │ │ │ ├─ Second [id BV]
574+
│ │ │ │ │ ├─ Subtensor{i} [id BW]
575+
│ │ │ │ │ │ ├─ Subtensor{start:} [id BR]
581576
│ │ │ │ │ │ │ └─ ···
582-
│ │ │ │ │ │ └─ -1 [id BY]
583-
│ │ │ │ │ └─ ExpandDims{axis=0} [id BZ]
584-
│ │ │ │ │ └─ Second [id CA]
585-
│ │ │ │ │ ├─ Sum{axes=None} [id CB]
586-
│ │ │ │ │ │ └─ Subtensor{i} [id BX]
577+
│ │ │ │ │ │ └─ -1 [id BX]
578+
│ │ │ │ │ └─ ExpandDims{axis=0} [id BY]
579+
│ │ │ │ │ └─ Second [id BZ]
580+
│ │ │ │ │ ├─ Sum{axes=None} [id CA]
581+
│ │ │ │ │ │ └─ Subtensor{i} [id BW]
587582
│ │ │ │ │ │ └─ ···
588-
│ │ │ │ │ └─ 1.0 [id CC]
589-
│ │ │ │ └─ -1 [id BY]
590-
│ │ │ └─ 1 [id BT]
591-
│ │ └─ -1 [id CD]
592-
│ ├─ Alloc [id CE] (outer_in_sit_sot-0)
593-
│ │ ├─ 0.0 [id CF]
594-
│ │ ├─ Add [id CG]
583+
│ │ │ │ │ └─ 1.0 [id CB]
584+
│ │ │ │ └─ -1 [id BX]
585+
│ │ │ └─ 1 [id BS]
586+
│ │ └─ -1 [id CC]
587+
│ ├─ Alloc [id CD] (outer_in_sit_sot-0)
588+
│ │ ├─ 0.0 [id CE]
589+
│ │ ├─ Add [id CF]
595590
│ │ │ ├─ Sub [id C]
596591
│ │ │ │ └─ ···
597-
│ │ │ └─ 1 [id CH]
598-
│ │ └─ Subtensor{i} [id CI]
599-
│ │ ├─ Shape [id CJ]
592+
│ │ │ └─ 1 [id CG]
593+
│ │ └─ Subtensor{i} [id CH]
594+
│ │ ├─ Shape [id CI]
600595
│ │ │ └─ A [id P]
601-
│ │ └─ 0 [id CK]
596+
│ │ └─ 0 [id CJ]
602597
│ └─ A [id P] (outer_in_non_seqs-0)
603-
└─ -1 [id CL]
598+
└─ -1 [id CK]
604599
605600
Inner graphs:
606601
607602
Scan{grad_of_scan_fn, while_loop=False, inplace=none} [id B]
608-
← Add [id CM] (inner_out_mit_mot-0-0)
609-
├─ Mul [id CN]
610-
│ ├─ *2-<Vector(float64, shape=(?,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0)
611-
│ └─ *5-<Vector(float64, shape=(?,))> [id CP] -> [id P] (inner_in_non_seqs-0)
612-
└─ *3-<Vector(float64, shape=(?,))> [id CQ] -> [id BL] (inner_in_mit_mot-0-1)
613-
← Add [id CR] (inner_out_sit_sot-0)
614-
├─ Mul [id CS]
615-
│ ├─ *2-<Vector(float64, shape=(?,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0)
616-
│ └─ *0-<Vector(float64, shape=(?,))> [id CT] -> [id Z] (inner_in_seqs-0)
617-
└─ *4-<Vector(float64, shape=(?,))> [id CU] -> [id CE] (inner_in_sit_sot-0)
603+
← Add [id CL] (inner_out_mit_mot-0-0)
604+
├─ Mul [id CM]
605+
│ ├─ *2-<Vector(float64, shape=(?,))> [id CN] -> [id BK] (inner_in_mit_mot-0-0)
606+
│ └─ *5-<Vector(float64, shape=(?,))> [id CO] -> [id P] (inner_in_non_seqs-0)
607+
└─ *3-<Vector(float64, shape=(?,))> [id CP] -> [id BK] (inner_in_mit_mot-0-1)
608+
← Add [id CQ] (inner_out_sit_sot-0)
609+
├─ Mul [id CR]
610+
│ ├─ *2-<Vector(float64, shape=(?,))> [id CN] -> [id BK] (inner_in_mit_mot-0-0)
611+
│ └─ *0-<Vector(float64, shape=(?,))> [id CS] -> [id Y] (inner_in_seqs-0)
612+
└─ *4-<Vector(float64, shape=(?,))> [id CT] -> [id CD] (inner_in_sit_sot-0)
618613
619614
Scan{scan_fn, while_loop=False, inplace=none} [id F]
620-
← Mul [id CV] (inner_out_sit_sot-0)
621-
├─ *0-<Vector(float64, shape=(?,))> [id CT] -> [id H] (inner_in_sit_sot-0)
622-
└─ *1-<Vector(float64, shape=(?,))> [id CW] -> [id P] (inner_in_non_seqs-0)
615+
← Mul [id CU] (inner_out_sit_sot-0)
616+
├─ *0-<Vector(float64, shape=(?,))> [id CS] -> [id H] (inner_in_sit_sot-0)
617+
└─ *1-<Vector(float64, shape=(?,))> [id CV] -> [id P] (inner_in_non_seqs-0)
623618
"""
624619

625620
for truth, out in zip(expected_output.split("\n"), lines, strict=True):

0 commit comments

Comments
 (0)