Skip to content

Commit 2a9ffd3

Browse files
author
Ian Schweer
committed
Add single carry test
1 parent 7027c4c commit 2a9ffd3

File tree

1 file changed

+24
-1
lines changed

1 file changed

+24
-1
lines changed

tests/link/pytorch/test_basic.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -425,7 +425,30 @@ def test_ScalarLoop_while():
425425
np.testing.assert_allclose(res[1], np.array(expected[1]))
426426

427427

428-
def test_ScalarLoop_Elemwise():
428+
def test_ScalarLoop_Elemwise_single_carries():
429+
n_steps = int64("n_steps")
430+
x0 = float64("x0")
431+
x = x0 * 2
432+
until = x >= 10
433+
434+
scalarop = ScalarLoop(init=[x0], update=[x], until=until)
435+
op = Elemwise(scalarop)
436+
437+
n_steps = pt.scalar("n_steps", dtype="int32")
438+
x0 = pt.vector("x0", dtype="float32")
439+
state, done = op(n_steps, x0)
440+
441+
f = FunctionGraph([n_steps, x0], [state, done])
442+
args = [
443+
np.array(10).astype("int32"),
444+
np.arange(0, 5).astype("float32"),
445+
]
446+
compare_pytorch_and_py(
447+
f, args, assert_fn=partial(np.testing.assert_allclose, rtol=1e-6)
448+
)
449+
450+
451+
def test_ScalarLoop_Elemwise_multi_carries():
429452
n_steps = int64("n_steps")
430453
x0 = float64("x0")
431454
x1 = float64("x1")

0 commit comments

Comments
 (0)