Commit de7c069 (parent: 2dce8e4)
Author: Ian Schweer

Do iteration instead of vmap for elemwise
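
Context for the change: torch.vmap cannot batch a function whose control flow depends on runtime values, and a ScalarLoop with an `until` condition is exactly such a function (the diff below refers to this as @1031). A minimal sketch of the failure mode outside PyTensor; `loop_until` is a hypothetical stand-in for the compiled scalar loop, and the exact error message varies by PyTorch version:

```python
import torch


def loop_until(x):
    # Data-dependent control flow: the trip count depends on x's value.
    while (x * 2 < 10).item():
        x = x * 2
    return x


print(loop_until(torch.tensor(1.0)))  # works on a single scalar: tensor(8.)

try:
    # vmap runs the function with batched tensors, so .item() (and hence
    # any value-dependent loop condition) is rejected at runtime.
    torch.vmap(loop_until)(torch.arange(1.0, 5.0))
except RuntimeError as err:
    print(f"vmap failed: {err}")
```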

File tree: 2 files changed (+47, -5 lines)

pytensor/link/pytorch/dispatch/elemwise.py
Lines changed: 32 additions & 0 deletions

@@ -1,6 +1,9 @@
+from itertools import chain
+
 import torch
 
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
+from pytensor.scalar import ScalarLoop
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.math import All, Any, Max, Min, Prod, Sum
 from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
@@ -17,6 +20,34 @@ def pytorch_funcify_Elemwise(op, node, **kwargs):
         def elemwise_fn(*inputs):
             Elemwise._check_runtime_broadcast(node, inputs)
             return base_fn(*inputs)
+
+    elif isinstance(scalar_op, ScalarLoop):
+        # Note: ScalarLoop + Elemwise is too common a combination
+        # not to support, but per @1031 vmap won't allow it.
+        # Instead, successively unbind the inputs down to scalars.
+        def elemwise_fn(*inputs):
+            Elemwise._check_runtime_broadcast(node, inputs)
+            shaped_inputs = torch.broadcast_tensors(*inputs)
+            expected_size = shaped_inputs[0].numel()
+            final_inputs = [s.clone() for s in shaped_inputs]
+            for _ in range(shaped_inputs[0].dim() - 1):
+                for i, _ in enumerate(shaped_inputs):
+                    layer = chain.from_iterable([s.unbind(0) for s in final_inputs[i]])
+                    final_inputs[i] = list(layer)
+
+            # Make sure we still have the same number of input groups
+            assert len(final_inputs) == len(shaped_inputs)
+
+            # Make sure each group of elements has the expected size
+            assert all(len(x) == expected_size for x in final_inputs)
+
+            # Make sure all elements are scalars (0-d tensors)
+            assert all(len(x.shape) == 0 for tensor in final_inputs for x in tensor)
+            res = [base_fn(*args) for args in zip(*final_inputs)]
+            states = torch.stack(tuple(out[0] for out in res))
+            done = torch.stack(tuple(out[1] for out in res))
+            return states, done
+
     else:
 
         def elemwise_fn(*inputs):
@@ -26,6 +57,7 @@ def elemwise_fn(*inputs):
             for _ in range(broadcast_inputs[0].dim()):
                 ufunc = torch.vmap(ufunc)
             return ufunc(*broadcast_inputs)
+            return base_fn(*inputs)
 
     return elemwise_fn

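The ScalarLoop branch above can be read in isolation: broadcast the inputs, peel off leading dimensions with unbind until only 0-d tensors remain, run the compiled scalar function once per element, and stack the results. A simplified sketch of the same idea, with illustrative names (`iterate_elemwise` and `scalar_fn` are not part of the commit), which additionally reshapes the output back to the broadcast shape, a step the commit leaves to the caller:

```python
from itertools import chain

import torch


def iterate_elemwise(scalar_fn, *inputs):
    # Broadcast all inputs to a common shape, as the commit does.
    shaped = torch.broadcast_tensors(*inputs)
    flat = [[t] for t in shaped]
    # Each pass strips one leading dimension via unbind(0);
    # after dim() passes every element is a 0-d tensor.
    for _ in range(shaped[0].dim()):
        flat = [
            list(chain.from_iterable(t.unbind(0) for t in group))
            for group in flat
        ]
    assert all(len(group) == shaped[0].numel() for group in flat)
    # Apply the scalar function pointwise and stack the results back up.
    out = torch.stack([scalar_fn(*args) for args in zip(*flat)])
    return out.reshape(shaped[0].shape)


# Example: an elementwise add built from a scalar function.
x = torch.arange(6.0).reshape(2, 3)
y = torch.ones(3)
print(torch.equal(iterate_elemwise(torch.add, x, y), x + y))  # True
```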
tests/link/pytorch/test_basic.py
Lines changed: 15 additions & 5 deletions

@@ -4,6 +4,7 @@
 import numpy as np
 import pytest
 
+import pytensor.tensor as pt
 import pytensor.tensor.basic as ptb
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.function import function
@@ -431,10 +432,19 @@ def test_ScalarLoop_Elemwise():
     x = x0 * 2
     until = x >= 10
 
-    op = ScalarLoop(init=[x0], update=[x], until=until)
-    fn = function([n_steps, x0], Elemwise(op)(n_steps, x0), mode=pytorch_mode)
+    scalarop = ScalarLoop(init=[x0], update=[x], until=until)
+    op = Elemwise(scalarop)
+
+    n_steps = pt.scalar("n_steps", dtype="int32")
+    x0 = pt.vector("x0", dtype="float32")
+    state, done = op(n_steps, x0)
+
+    fn = function([n_steps, x0], [state, done], mode=pytorch_mode)
+    py_fn = function([n_steps, x0], [state, done])
 
-    states, dones = fn(10, np.array(range(5)))
+    args = [np.array(10).astype("int32"), np.arange(0, 5).astype("float32")]
+    torch_states, torch_dones = fn(*args)
+    py_states, py_dones = py_fn(*args)
 
-    np.testing.assert_allclose(states, [0, 4, 8, 12, 16])
-    np.testing.assert_allclose(dones, [False, False, False, True, True])
+    np.testing.assert_allclose(torch_states, py_states)
+    np.testing.assert_allclose(torch_dones, py_dones)
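
The rewritten test compares the torch backend against the default Python backend rather than hard-coded expected values. A self-contained version can be run outside the test suite; a sketch, assuming the torch backend is installed and that the mode string "PYTORCH" resolves to the same mode as the test's `pytorch_mode` (the `*_t` variable names are mine):

```python
import numpy as np

import pytensor.tensor as pt
from pytensor.compile.function import function
from pytensor.scalar import ScalarLoop, float32
from pytensor.tensor.elemwise import Elemwise

# Scalar loop: double x each step, stopping once x >= 10.
x0 = float32("x0")
x = x0 * 2
until = x >= 10

scalarop = ScalarLoop(init=[x0], update=[x], until=until)
op = Elemwise(scalarop)

n_steps_t = pt.scalar("n_steps", dtype="int32")
x0_t = pt.vector("x0", dtype="float32")
state, done = op(n_steps_t, x0_t)

torch_fn = function([n_steps_t, x0_t], [state, done], mode="PYTORCH")
py_fn = function([n_steps_t, x0_t], [state, done])

args = [np.array(10, dtype="int32"), np.arange(5, dtype="float32")]
for torch_out, py_out in zip(torch_fn(*args), py_fn(*args)):
    np.testing.assert_allclose(torch_out, py_out)
```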
