Commit c1a4dcc

Refactor to ravel method

Author: Ian Schweer (committed)
1 parent c63846a · commit c1a4dcc

File tree: 2 files changed (+25 additions, -84 deletions)

pytensor/link/pytorch/dispatch/elemwise.py (24 additions, 44 deletions)

```diff
@@ -1,5 +1,3 @@
-from itertools import chain
-
 import torch
 
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
@@ -175,50 +173,32 @@ def elemwise_scalar_loop(base_fn, op, node, **kwargs):
     """
     ScalarLoop + Elemwise is too common
     to not work, but @1031, vmap won't allow it.
-    Instead, we can do the following strategy
-    1. `.unbind(dim)` will return a list of tensors
-    representing `dim` but "unwrapped". e.x.
-    ```
-    t = torch.ones(3, 4, 2)
-    len(t.unbind(0)) == 3
-    t[0].shape == torch.Size[4, 2]
-    2. If we successfully apply, the length of the list will grow
-    by the next dimension in the tensor if we flatten the previous
-    dimension result
-    ```
-    inputs = [torch.ones(3, 4, 2)]
-    level_1 = chain.from_iterable(t.unbind(0) for t in inputs)
-    level_2 = chain.from_iterable(t.unbind(0) for t in level_1)
-    len(level_2) == 3 * 4
-    ```
-    3. Eventually we'll reach single dimension tensors. At that point
-    we can iterate over each input in an element by element manner
-    and call some function
-
-    For scalar loop, we need to broadcast the tensors so all
-    the necessary values are repeated, and we "evenly" iterate through everything
+    Instead, we can ravel all the inputs, broadcasted
+    according to torch
     """
 
+    n_outputs = len(node.outputs)
+
     def elemwise_fn(*inputs):
-        Elemwise._check_runtime_broadcast(node, inputs)
-        shaped_inputs = torch.broadcast_tensors(*inputs)
-        expected_size = shaped_inputs[0].numel()
-        final_inputs = [s.clone() for s in shaped_inputs]
-        for _ in range(shaped_inputs[0].dim() - 1):
-            for i, _ in enumerate(shaped_inputs):
-                layer = chain.from_iterable([s.unbind(0) for s in final_inputs[i]])
-                final_inputs[i] = list(layer)
-
-        # make sure we still have the same number of things
-        assert len(final_inputs) == len(shaped_inputs)
-
-        # make sure each group of things are the expected size
-        assert all(len(x) == expected_size for x in final_inputs)
-
-        # make sure they are all single elements
-        assert all(len(x.shape) == 0 for tensor in final_inputs for x in tensor)
-        res = [base_fn(*args) for args in zip(*final_inputs)]
-
-        return [torch.stack(tuple(out[i] for out in res)) for i in range(len(res[0]))]
+        bcasted_inputs = torch.broadcast_tensors(*inputs)
+        raveled_inputs = [inp.ravel() for inp in bcasted_inputs]
+
+        out_shape = bcasted_inputs[0].size()
+        out_size = out_shape.numel()
+        raveled_outputs = [torch.zeros(out_size) for out in node.outputs]
+
+        for i in range(out_size):
+            core_outs = base_fn(*(inp[i] for inp in raveled_inputs))
+            if n_outputs == 1:
+                raveled_outputs[0][i] = core_outs
+            else:
+                for o in range(n_outputs):
+                    raveled_outputs[o][i] = core_outs[o]
+
+        outputs = tuple(out.view(out_shape) for out in raveled_outputs)
+        if n_outputs == 1:
+            return outputs[0]
+        else:
+            return outputs
 
     return elemwise_fn
```
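For context on the refactor, here is a minimal, self-contained sketch contrasting the removed `unbind`-based flattening with the new broadcast-and-ravel loop. `scalar_add` and the input shapes are hypothetical stand-ins for the compiled ScalarLoop inner function and the real graph inputs; the point is that both strategies visit elements in the same row-major order, but the raveled version avoids building nested Python lists of 0-d tensors.

```python
import torch
from itertools import chain


def scalar_add(a, b):
    # Hypothetical stand-in for the compiled ScalarLoop inner
    # function, which maps scalar inputs to scalar outputs.
    return a + b


x = torch.ones(3, 1)
y = torch.arange(4.0)
bcasted = torch.broadcast_tensors(x, y)  # both become shape (3, 4)

# Old strategy: per input, repeatedly unbind dim 0 until only
# 0-d (scalar) tensors remain, then zip the scalars together.
unbound = []
for t in bcasted:
    level = [t]
    for _ in range(t.dim()):
        level = list(chain.from_iterable(s.unbind(0) for s in level))
    unbound.append(level)
old_res = torch.stack([scalar_add(*args) for args in zip(*unbound)])

# New strategy: ravel each broadcast input to 1-D, apply the scalar
# function element by element, and reshape the flat result at the end.
raveled = [t.ravel() for t in bcasted]
out_shape = bcasted[0].size()
new_res = torch.zeros(out_shape.numel())
for i in range(out_shape.numel()):
    new_res[i] = scalar_add(*(r[i] for r in raveled))

assert torch.equal(old_res, new_res)  # same element order
assert torch.equal(new_res.view(out_shape), x + y)
```

Note that the committed `elemwise_fn` additionally handles multiple outputs (via `n_outputs`), allocating one flat buffer per graph output before the loop and reshaping each back to the broadcast shape at the end.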

tests/link/pytorch/test_basic.py (1 addition, 40 deletions)

```diff
@@ -1,7 +1,5 @@
 from collections.abc import Callable, Iterable
 from functools import partial
-from itertools import repeat, starmap
-from unittest.mock import MagicMock, call, patch
 
 import numpy as np
 import pytest
@@ -408,25 +406,11 @@ def test_ScalarLoop_while():
     for res, expected in zip(
         [fn(n_steps=20, x0=0), fn(n_steps=20, x0=1), fn(n_steps=5, x0=1)],
         [[10, True], [10, True], [6, False]],
+        strict=True,
     ):
         np.testing.assert_allclose(res[0], np.array(expected[0]))
         np.testing.assert_allclose(res[1], np.array(expected[1]))
 
-def test_pytorch_OpFromGraph():
-    x, y, z = matrices("xyz")
-    ofg_1 = OpFromGraph([x, y], [x + y])
-    ofg_2 = OpFromGraph([x, y], [x * y, x - y])
-
-    o1, o2 = ofg_2(y, z)
-    out = ofg_1(x, o1) + o2
-
-    xv = np.ones((2, 2), dtype=config.floatX)
-    yv = np.ones((2, 2), dtype=config.floatX) * 3
-    zv = np.ones((2, 2), dtype=config.floatX) * 5
-
-    f = FunctionGraph([x, y, z], [out])
-    compare_pytorch_and_py(f, [xv, yv, zv])
-
 
 
 def test_ScalarLoop_Elemwise():
@@ -444,26 +428,3 @@ def test_ScalarLoop_Elemwise():
     f = FunctionGraph([n_steps, x0], [state, done])
     args = [np.array(10).astype("int32"), np.arange(0, 5).astype("float32")]
     compare_pytorch_and_py(f, args)
-
-
-torch_elemwise = pytest.importorskip("pytensor.link.pytorch.dispatch.elemwise")
-
-
-@pytest.mark.parametrize("input_shapes", [[(5, 1, 1, 8), (3, 1, 1), (8,)]])
-@patch("pytensor.link.pytorch.dispatch.elemwise.Elemwise")
-def test_ScalarLoop_Elemwise_iteration_logic(_, input_shapes):
-    args = [torch.ones(*s) for s in input_shapes[:-1]] + [
-        torch.zeros(*input_shapes[-1])
-    ]
-    mock_inner_func = MagicMock()
-    ret_value = torch.rand(2, 2).unbind(0)
-    mock_inner_func.f.return_value = ret_value
-    elemwise_fn = torch_elemwise.elemwise_scalar_loop(mock_inner_func.f, None, None)
-    result = elemwise_fn(*args)
-    for actual, expected in zip(ret_value, result):
-        assert torch.all(torch.eq(*torch.broadcast_tensors(actual, expected)))
-    np.testing.assert_equal(mock_inner_func.f.call_count, len(result[0]))
-
-    expected_args = torch.FloatTensor([1.0] * (len(input_shapes) - 1) + [0.0]).unbind(0)
-    expected_calls = starmap(call, repeat(expected_args, mock_inner_func.f.call_count))
-    mock_inner_func.f.assert_has_calls(expected_calls)
```
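A side note on the one surviving test change: passing `strict=True` to `zip` (available since Python 3.10) turns a silent truncation into an error when the result and expectation lists fall out of step. A small illustration with made-up lists:

```python
results = [10, 10, 6]
expected = [[10, True], [10, True]]  # deliberately one entry short

try:
    for res, exp in zip(results, expected, strict=True):
        pass  # with strict=True the mismatch cannot pass silently
except ValueError as err:
    print(err)  # zip() argument 2 is shorter than argument 1
```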

Comments (0)