1
+ from itertools import chain
2
+
1
3
import torch
2
4
3
5
from pytensor .link .pytorch .dispatch .basic import pytorch_funcify
6
+ from pytensor .scalar import ScalarLoop
4
7
from pytensor .tensor .elemwise import DimShuffle , Elemwise
5
8
from pytensor .tensor .math import All , Any , Max , Min , Prod , Sum
6
9
from pytensor .tensor .special import LogSoftmax , Softmax , SoftmaxGrad
@@ -17,6 +20,34 @@ def pytorch_funcify_Elemwise(op, node, **kwargs):
17
20
def elemwise_fn (* inputs ):
18
21
Elemwise ._check_runtime_broadcast (node , inputs )
19
22
return base_fn (* inputs )
23
+
24
+ elif isinstance (scalar_op , ScalarLoop ):
25
+ # note: scalarloop + elemwise is too common
26
+ # to not work, but @1031, vmap won't allow it.
27
+ # Instead, we will just successively unbind
28
+ def elemwise_fn (* inputs ):
29
+ Elemwise ._check_runtime_broadcast (node , inputs )
30
+ shaped_inputs = torch .broadcast_tensors (* inputs )
31
+ expected_size = shaped_inputs [0 ].numel ()
32
+ final_inputs = [s .clone () for s in shaped_inputs ]
33
+ for _ in range (shaped_inputs [0 ].dim () - 1 ):
34
+ for i , _ in enumerate (shaped_inputs ):
35
+ layer = chain .from_iterable ([s .unbind (0 ) for s in final_inputs [i ]])
36
+ final_inputs [i ] = list (layer )
37
+
38
+ # make sure we still have the same number of things
39
+ assert len (final_inputs ) == len (shaped_inputs )
40
+
41
+ # make sure each group of things are the expected size
42
+ assert all (len (x ) == expected_size for x in final_inputs )
43
+
44
+ # make sure they are all single elements
45
+ assert all (len (x .shape ) == 0 for tensor in final_inputs for x in tensor )
46
+ res = [base_fn (* args ) for args in zip (* final_inputs )]
47
+ states = torch .stack (tuple (out [0 ] for out in res ))
48
+ done = torch .stack (tuple (out [1 ] for out in res ))
49
+ return states , done
50
+
20
51
else :
21
52
22
53
def elemwise_fn (* inputs ):
@@ -26,6 +57,7 @@ def elemwise_fn(*inputs):
26
57
for _ in range (broadcast_inputs [0 ].dim ()):
27
58
ufunc = torch .vmap (ufunc )
28
59
return ufunc (* broadcast_inputs )
60
+ return base_fn (* inputs )
29
61
30
62
return elemwise_fn
31
63
0 commit comments