|
1 | 1 | from collections.abc import Callable, Iterable
|
2 | 2 | from functools import partial
|
| 3 | +from itertools import repeat, starmap |
| 4 | +from unittest.mock import MagicMock, call, patch |
3 | 5 |
|
4 | 6 | import numpy as np
|
5 | 7 | import pytest
|
@@ -439,12 +441,29 @@ def test_ScalarLoop_Elemwise():
|
439 | 441 | x0 = pt.vector("x0", dtype="float32")
|
440 | 442 | state, done = op(n_steps, x0)
|
441 | 443 |
|
442 |
| - fn = function([n_steps, x0], [state, done], mode=pytorch_mode) |
443 |
| - py_fn = function([n_steps, x0], [state, done]) |
444 |
| - |
| 444 | + f = FunctionGraph([n_steps, x0], [state, done]) |
445 | 445 | args = [np.array(10).astype("int32"), np.arange(0, 5).astype("float32")]
|
446 |
| - torch_states, torch_dones = fn(*args) |
447 |
| - py_states, py_dones = py_fn(*args) |
448 |
| - |
449 |
| - np.testing.assert_allclose(torch_states, py_states) |
450 |
| - np.testing.assert_allclose(torch_dones, py_dones) |
| 446 | + compare_pytorch_and_py(f, args) |
| 447 | + |
| 448 | + |
| 449 | +torch_elemwise = pytest.importorskip("pytensor.link.pytorch.dispatch.elemwise") |
| 450 | + |
| 451 | + |
| 452 | +@pytest.mark.parametrize("input_shapes", [[(5, 1, 1, 8), (3, 1, 1), (8,)]]) |
| 453 | +@patch("pytensor.link.pytorch.dispatch.elemwise.Elemwise") |
| 454 | +def test_ScalarLoop_Elemwise_iteration_logic(_, input_shapes): |
| 455 | + args = [torch.ones(*s) for s in input_shapes[:-1]] + [ |
| 456 | + torch.zeros(*input_shapes[-1]) |
| 457 | + ] |
| 458 | + mock_inner_func = MagicMock() |
| 459 | + ret_value = torch.rand(2, 2).unbind(0) |
| 460 | + mock_inner_func.f.return_value = ret_value |
| 461 | + elemwise_fn = torch_elemwise.elemwise_scalar_loop(mock_inner_func.f, None, None) |
| 462 | + result = elemwise_fn(*args) |
| 463 | + for actual, expected in zip(ret_value, result): |
| 464 | + assert torch.all(torch.eq(*torch.broadcast_tensors(actual, expected))) |
| 465 | + np.testing.assert_equal(mock_inner_func.f.call_count, len(result[0])) |
| 466 | + |
| 467 | + expected_args = torch.FloatTensor([1.0] * (len(input_shapes) - 1) + [0.0]).unbind(0) |
| 468 | + expected_calls = starmap(call, repeat(expected_args, mock_inner_func.f.call_count)) |
| 469 | + mock_inner_func.f.assert_has_calls(expected_calls) |
0 commit comments