Skip to content

Commit 08d23ab

Browse files
committed
Parametrized dtype in tests for Eye Op in PyTorch
1 parent daa86c4 commit 08d23ab

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

tests/link/pytorch/test_basic.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,12 +277,16 @@ def test_pytorch_Join():
277277
)
278278

279279

280-
def test_eye():
280+
@pytest.mark.parametrize(
281+
"dtype",
282+
["int64", config.floatX],
283+
)
284+
def test_eye(dtype):
281285
N = scalar("N", dtype="int64")
282286
M = scalar("M", dtype="int64")
283287
k = scalar("k", dtype="int64")
284288

285-
out = eye(N, M, k, dtype="float32")
289+
out = eye(N, M, k, dtype=dtype)
286290

287291
fn = function([N, M, k], out, mode=pytorch_mode)
288292

0 commit comments

Comments
 (0)