Skip to content

Commit 7027c4c

Browse files
author
Ian Schweer
committed
Update test
1 parent fd2f192 commit 7027c4c

File tree

2 files changed

+17
-7
lines changed

2 files changed

+17
-7
lines changed

pytensor/link/pytorch/dispatch/scalar.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def cast(x):
6464
def pytorch_funcify_Softplus(op, node, **kwargs):
6565
return torch.nn.Softplus()
6666

67+
6768
@pytorch_funcify.register(ScalarLoop)
6869
def pytorch_funicify_ScalarLoop(op, node, **kwargs):
6970
update = pytorch_funcify(op.fgraph)
@@ -80,7 +81,7 @@ def scalar_loop(steps, *start_and_constants):
8081
*carry, done = update(*carry, *constants)
8182
if torch.any(done):
8283
break
83-
return *carry, done
84+
return *carry, done
8485
else:
8586

8687
def scalar_loop(steps, *start_and_constants):

tests/link/pytorch/test_basic.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -428,16 +428,25 @@ def test_ScalarLoop_while():
428428
def test_ScalarLoop_Elemwise():
429429
n_steps = int64("n_steps")
430430
x0 = float64("x0")
431+
x1 = float64("x1")
431432
x = x0 * 2
433+
x1_n = x1 * 3
432434
until = x >= 10
433435

434-
scalarop = ScalarLoop(init=[x0], update=[x], until=until)
436+
scalarop = ScalarLoop(init=[x0, x1], update=[x, x1_n], until=until)
435437
op = Elemwise(scalarop)
436438

437439
n_steps = pt.scalar("n_steps", dtype="int32")
438440
x0 = pt.vector("x0", dtype="float32")
439-
state, done = op(n_steps, x0)
440-
441-
f = FunctionGraph([n_steps, x0], [state, done])
442-
args = [np.array(10).astype("int32"), np.arange(0, 5).astype("float32")]
443-
compare_pytorch_and_py(f, args)
441+
x1 = pt.tensor("c0", dtype="float32", shape=(7, 3, 1))
442+
*states, done = op(n_steps, x0, x1)
443+
444+
f = FunctionGraph([n_steps, x0, x1], [*states, done])
445+
args = [
446+
np.array(10).astype("int32"),
447+
np.arange(0, 5).astype("float32"),
448+
np.random.rand(7, 3, 1).astype("float32"),
449+
]
450+
compare_pytorch_and_py(
451+
f, args, assert_fn=partial(np.testing.assert_allclose, rtol=1e-6)
452+
)

0 commit comments

Comments
 (0)