Skip to content

Commit c63846a

Browse files
Ch0ronomatoIan Schweer
authored andcommitted
Add unit test to verify iteration
1 parent 3117383 commit c63846a

File tree

1 file changed

+27
-8
lines changed

1 file changed

+27
-8
lines changed

tests/link/pytorch/test_basic.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from collections.abc import Callable, Iterable
22
from functools import partial
3+
from itertools import repeat, starmap
4+
from unittest.mock import MagicMock, call, patch
35

46
import numpy as np
57
import pytest
@@ -439,12 +441,29 @@ def test_ScalarLoop_Elemwise():
439441
x0 = pt.vector("x0", dtype="float32")
440442
state, done = op(n_steps, x0)
441443

442-
fn = function([n_steps, x0], [state, done], mode=pytorch_mode)
443-
py_fn = function([n_steps, x0], [state, done])
444-
444+
f = FunctionGraph([n_steps, x0], [state, done])
445445
args = [np.array(10).astype("int32"), np.arange(0, 5).astype("float32")]
446-
torch_states, torch_dones = fn(*args)
447-
py_states, py_dones = py_fn(*args)
448-
449-
np.testing.assert_allclose(torch_states, py_states)
450-
np.testing.assert_allclose(torch_dones, py_dones)
446+
compare_pytorch_and_py(f, args)
447+
448+
449+
torch_elemwise = pytest.importorskip("pytensor.link.pytorch.dispatch.elemwise")
450+
451+
452+
@pytest.mark.parametrize("input_shapes", [[(5, 1, 1, 8), (3, 1, 1), (8,)]])
453+
@patch("pytensor.link.pytorch.dispatch.elemwise.Elemwise")
454+
def test_ScalarLoop_Elemwise_iteration_logic(_, input_shapes):
455+
args = [torch.ones(*s) for s in input_shapes[:-1]] + [
456+
torch.zeros(*input_shapes[-1])
457+
]
458+
mock_inner_func = MagicMock()
459+
ret_value = torch.rand(2, 2).unbind(0)
460+
mock_inner_func.f.return_value = ret_value
461+
elemwise_fn = torch_elemwise.elemwise_scalar_loop(mock_inner_func.f, None, None)
462+
result = elemwise_fn(*args)
463+
for actual, expected in zip(ret_value, result):
464+
assert torch.all(torch.eq(*torch.broadcast_tensors(actual, expected)))
465+
np.testing.assert_equal(mock_inner_func.f.call_count, len(result[0]))
466+
467+
expected_args = torch.FloatTensor([1.0] * (len(input_shapes) - 1) + [0.0]).unbind(0)
468+
expected_calls = starmap(call, repeat(expected_args, mock_inner_func.f.call_count))
469+
mock_inner_func.f.assert_has_calls(expected_calls)

0 commit comments

Comments
 (0)