Commit 3117383

Browse files
Ch0ronomato (Ian Schweer) authored and committed
Clean up and add description
1 parent de7c069 commit 3117383

File tree

1 file changed (+55, -25 lines)


pytensor/link/pytorch/dispatch/elemwise.py

Lines changed: 55 additions & 25 deletions
@@ -12,6 +12,7 @@
 @pytorch_funcify.register(Elemwise)
 def pytorch_funcify_Elemwise(op, node, **kwargs):
     scalar_op = op.scalar_op
+
     base_fn = pytorch_funcify(scalar_op, node=node, **kwargs)
 
     if hasattr(scalar_op, "nfunc_spec") and hasattr(torch, scalar_op.nfunc_spec[0]):
@@ -22,31 +23,7 @@ def elemwise_fn(*inputs):
             return base_fn(*inputs)
 
     elif isinstance(scalar_op, ScalarLoop):
-        # note: scalarloop + elemwise is too common
-        # to not work, but @1031, vmap won't allow it.
-        # Instead, we will just successively unbind
-        def elemwise_fn(*inputs):
-            Elemwise._check_runtime_broadcast(node, inputs)
-            shaped_inputs = torch.broadcast_tensors(*inputs)
-            expected_size = shaped_inputs[0].numel()
-            final_inputs = [s.clone() for s in shaped_inputs]
-            for _ in range(shaped_inputs[0].dim() - 1):
-                for i, _ in enumerate(shaped_inputs):
-                    layer = chain.from_iterable([s.unbind(0) for s in final_inputs[i]])
-                    final_inputs[i] = list(layer)
-
-            # make sure we still have the same number of things
-            assert len(final_inputs) == len(shaped_inputs)
-
-            # make sure each group of things are the expected size
-            assert all(len(x) == expected_size for x in final_inputs)
-
-            # make sure they are all single elements
-            assert all(len(x.shape) == 0 for tensor in final_inputs for x in tensor)
-            res = [base_fn(*args) for args in zip(*final_inputs)]
-            states = torch.stack(tuple(out[0] for out in res))
-            done = torch.stack(tuple(out[1] for out in res))
-            return states, done
+        return elemwise_scalar_loop(base_fn, op, node, **kwargs)
 
     else:
@@ -192,3 +169,56 @@ def softmax_grad(dy, sm):
         return dy_times_sm - torch.sum(dy_times_sm, dim=axis, keepdim=True) * sm
 
     return softmax_grad
+
+
+def elemwise_scalar_loop(base_fn, op, node, **kwargs):
+    """
+    ScalarLoop + Elemwise is too common a combination not to support,
+    but per @1031, vmap won't allow it.
+    Instead, we can use the following strategy:
+
+    1. `.unbind(dim)` returns a list of tensors with `dim` "unwrapped", e.g.
+       ```
+       t = torch.ones(3, 4, 2)
+       len(t.unbind(0)) == 3
+       t[0].shape == torch.Size([4, 2])
+       ```
+    2. Each successive application grows the list by the next dimension
+       of the tensor, provided we flatten the previous dimension's result:
+       ```
+       inputs = [torch.ones(3, 4, 2)]
+       level_1 = list(chain.from_iterable(t.unbind(0) for t in inputs))
+       level_2 = list(chain.from_iterable(t.unbind(0) for t in level_1))
+       len(level_2) == 3 * 4
+       ```
+    3. Eventually we reach single-element (0-d) tensors. At that point
+       we can iterate over the inputs element by element and call the
+       scalar function.
+
+    For the scalar loop, we need to broadcast the tensors so all the
+    necessary values are repeated, and we iterate "evenly" through everything.
+    """
+
+    def elemwise_fn(*inputs):
+        Elemwise._check_runtime_broadcast(node, inputs)
+        shaped_inputs = torch.broadcast_tensors(*inputs)
+        expected_size = shaped_inputs[0].numel()
+        final_inputs = [s.clone() for s in shaped_inputs]
+        for _ in range(shaped_inputs[0].dim() - 1):
+            for i, _ in enumerate(shaped_inputs):
+                layer = chain.from_iterable([s.unbind(0) for s in final_inputs[i]])
+                final_inputs[i] = list(layer)
+
+        # make sure we still have the same number of inputs
+        assert len(final_inputs) == len(shaped_inputs)
+
+        # make sure each flattened input has the expected number of elements
+        assert all(len(x) == expected_size for x in final_inputs)
+
+        # make sure every element is a scalar (0-d) tensor
+        assert all(len(x.shape) == 0 for tensor in final_inputs for x in tensor)
+        res = [base_fn(*args) for args in zip(*final_inputs)]
+
+        return [torch.stack(tuple(out[i] for out in res)) for i in range(len(res[0]))]
+
+    return elemwise_fn
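
The flattening strategy described in the new docstring can be checked on its own. Below is a minimal standalone sketch of the idea; the tensor shape and the stand-in `scalar_fn` are illustrative assumptions, not part of the commit:

```python
import torch
from itertools import chain

# A 3-d tensor, as in the docstring example.
t = torch.ones(3, 4, 2)

# Step 1: unbind(0) peels off the leading dimension.
parts = t.unbind(0)
assert len(parts) == 3
assert parts[0].shape == torch.Size([4, 2])

# Step 2: chaining unbind over the previous level flattens one more dimension per pass.
level_1 = list(chain.from_iterable(x.unbind(0) for x in [t]))      # 3 tensors of shape (4, 2)
level_2 = list(chain.from_iterable(x.unbind(0) for x in level_1))  # 12 tensors of shape (2,)
level_3 = list(chain.from_iterable(x.unbind(0) for x in level_2))  # 24 scalar (0-d) tensors
assert len(level_2) == 3 * 4
assert len(level_3) == t.numel() and level_3[0].dim() == 0

# Step 3: apply a scalar function element by element, then re-stack each output,
# mirroring elemwise_fn's return statement. scalar_fn stands in for base_fn and
# returns two outputs, like a ScalarLoop producing (state, done).
def scalar_fn(x):
    return x + 1.0, x > 0

res = [scalar_fn(x) for x in level_3]
outputs = [torch.stack(tuple(out[i] for out in res)) for i in range(len(res[0]))]
assert len(outputs) == 2 and outputs[0].shape == (t.numel(),)
```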

0 commit comments
