@@ -12,6 +12,7 @@
 @pytorch_funcify.register(Elemwise)
 def pytorch_funcify_Elemwise(op, node, **kwargs):
     scalar_op = op.scalar_op
+
     base_fn = pytorch_funcify(scalar_op, node=node, **kwargs)
 
     if hasattr(scalar_op, "nfunc_spec") and hasattr(torch, scalar_op.nfunc_spec[0]):
@@ -22,31 +23,7 @@ def elemwise_fn(*inputs):
             return base_fn(*inputs)
 
     elif isinstance(scalar_op, ScalarLoop):
-        # note: scalarloop + elemwise is too common
-        # to not work, but @1031, vmap won't allow it.
-        # Instead, we will just successively unbind
-        def elemwise_fn(*inputs):
-            Elemwise._check_runtime_broadcast(node, inputs)
-            shaped_inputs = torch.broadcast_tensors(*inputs)
-            expected_size = shaped_inputs[0].numel()
-            final_inputs = [s.clone() for s in shaped_inputs]
-            for _ in range(shaped_inputs[0].dim() - 1):
-                for i, _ in enumerate(shaped_inputs):
-                    layer = chain.from_iterable([s.unbind(0) for s in final_inputs[i]])
-                    final_inputs[i] = list(layer)
-
-            # make sure we still have the same number of things
-            assert len(final_inputs) == len(shaped_inputs)
-
-            # make sure each group of things are the expected size
-            assert all(len(x) == expected_size for x in final_inputs)
-
-            # make sure they are all single elements
-            assert all(len(x.shape) == 0 for tensor in final_inputs for x in tensor)
-            res = [base_fn(*args) for args in zip(*final_inputs)]
-            states = torch.stack(tuple(out[0] for out in res))
-            done = torch.stack(tuple(out[1] for out in res))
-            return states, done
+        return elemwise_scalar_loop(base_fn, op, node, **kwargs)
 
     else:
 
@@ -192,3 +169,56 @@ def softmax_grad(dy, sm):
         return dy_times_sm - torch.sum(dy_times_sm, dim=axis, keepdim=True) * sm
 
     return softmax_grad
+
+
+def elemwise_scalar_loop(base_fn, op, node, **kwargs):
+    """
+    ScalarLoop + Elemwise is too common a combination to leave unsupported,
+    but per @1031, vmap won't allow it. Instead, we use the following
+    strategy:
+
+    1. `.unbind(dim)` returns a tuple of tensors with `dim` "unwrapped", e.g.
+
+    ```
+    t = torch.ones(3, 4, 2)
+    len(t.unbind(0)) == 3
+    t[0].shape == torch.Size([4, 2])
+    ```
+
+    2. Unbinding each result again and flattening grows the list by the
+    next dimension of the tensor:
+
+    ```
+    inputs = [torch.ones(3, 4, 2)]
+    level_1 = chain.from_iterable(t.unbind(0) for t in inputs)
+    level_2 = chain.from_iterable(t.unbind(0) for t in level_1)
+    len(list(level_2)) == 3 * 4
+    ```
+
+    3. Eventually we reach 0-d (scalar) tensors. At that point we can
+    iterate over the inputs element by element and call the scalar function.
+
+    For ScalarLoop we first broadcast the tensors, so all the necessary
+    values are repeated and we iterate through everything "evenly".
+    """
+
+    def elemwise_fn(*inputs):
+        Elemwise._check_runtime_broadcast(node, inputs)
+        shaped_inputs = torch.broadcast_tensors(*inputs)
+        expected_size = shaped_inputs[0].numel()
+        final_inputs = [s.clone() for s in shaped_inputs]
+        # successively unbind until each input is a flat list of 0-d tensors
+        for _ in range(shaped_inputs[0].dim() - 1):
+            for i, _ in enumerate(shaped_inputs):
+                layer = chain.from_iterable([s.unbind(0) for s in final_inputs[i]])
+                final_inputs[i] = list(layer)
+
+        # we still have one flattened list per input
+        assert len(final_inputs) == len(shaped_inputs)
+
+        # each flattened list has one entry per broadcast element
+        assert all(len(x) == expected_size for x in final_inputs)
+
+        # every entry is a 0-d (scalar) tensor
+        assert all(len(x.shape) == 0 for tensor in final_inputs for x in tensor)
+        res = [base_fn(*args) for args in zip(*final_inputs)]
+
+        # stack the per-element results back into one tensor per output
+        return [torch.stack(tuple(out[i] for out in res)) for i in range(len(res[0]))]
+
+    return elemwise_fn
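As a quick sanity check of the approach (not part of the diff): a minimal, self-contained sketch of the successive-unbind flattening that `elemwise_scalar_loop` relies on. `flatten_to_scalars` is an illustrative name, not a function in this PR.

```python
# Minimal sketch (assumed names, not part of the PR) of the
# successive-unbind strategy used by `elemwise_scalar_loop`.
from itertools import chain

import torch


def flatten_to_scalars(*inputs):
    # broadcast so every input repeats its values to a common shape
    shaped = torch.broadcast_tensors(*inputs)
    outs = [s.clone() for s in shaped]
    # each pass peels one dimension off every tensor via unbind(0)
    for _ in range(shaped[0].dim() - 1):
        for i, _ in enumerate(shaped):
            outs[i] = list(chain.from_iterable(s.unbind(0) for s in outs[i]))
    return outs


a = torch.arange(6.0).reshape(3, 2)
b = torch.ones(2)  # broadcasts against `a`
flat_a, flat_b = flatten_to_scalars(a, b)
assert len(flat_a) == len(flat_b) == a.numel()
assert all(t.dim() == 0 for t in flat_a)

# element-by-element application, then reassembly, mirroring the helper
res = torch.stack([x + y for x, y in zip(flat_a, flat_b)]).reshape(a.shape)
assert torch.equal(res, a + b)
```

Unbinding preserves row-major order, which is why a plain `torch.stack` (plus a reshape here) reproduces the broadcast elementwise result.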