Skip to content

Commit 893cd96

Browse files
author
Ian Schweer
committed
Update test
1 parent 7dd9edb commit 893cd96

File tree

2 files changed

+17
-7
lines changed

2 files changed

+17
-7
lines changed

pytensor/link/pytorch/dispatch/scalar.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def cast(x):
5252

5353
return cast
5454

55+
5556
@pytorch_funcify.register(ScalarLoop)
5657
def pytorch_funicify_ScalarLoop(op, node, **kwargs):
5758
update = pytorch_funcify(op.fgraph)
@@ -68,7 +69,7 @@ def scalar_loop(steps, *start_and_constants):
6869
*carry, done = update(*carry, *constants)
6970
if torch.any(done):
7071
break
71-
return *carry, done
72+
return *carry, done
7273
else:
7374

7475
def scalar_loop(steps, *start_and_constants):

tests/link/pytorch/test_basic.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -415,16 +415,25 @@ def test_ScalarLoop_while():
415415
def test_ScalarLoop_Elemwise():
416416
n_steps = int64("n_steps")
417417
x0 = float64("x0")
418+
x1 = float64("x1")
418419
x = x0 * 2
420+
x1_n = x1 * 3
419421
until = x >= 10
420422

421-
scalarop = ScalarLoop(init=[x0], update=[x], until=until)
423+
scalarop = ScalarLoop(init=[x0, x1], update=[x, x1_n], until=until)
422424
op = Elemwise(scalarop)
423425

424426
n_steps = pt.scalar("n_steps", dtype="int32")
425427
x0 = pt.vector("x0", dtype="float32")
426-
state, done = op(n_steps, x0)
427-
428-
f = FunctionGraph([n_steps, x0], [state, done])
429-
args = [np.array(10).astype("int32"), np.arange(0, 5).astype("float32")]
430-
compare_pytorch_and_py(f, args)
428+
x1 = pt.tensor("c0", dtype="float32", shape=(7, 3, 1))
429+
*states, done = op(n_steps, x0, x1)
430+
431+
f = FunctionGraph([n_steps, x0, x1], [*states, done])
432+
args = [
433+
np.array(10).astype("int32"),
434+
np.arange(0, 5).astype("float32"),
435+
np.random.rand(7, 3, 1).astype("float32"),
436+
]
437+
compare_pytorch_and_py(
438+
f, args, assert_fn=partial(np.testing.assert_allclose, rtol=1e-6)
439+
)

0 commit comments

Comments
 (0)