Skip to content

Commit fe0a7ec

Browse files
committed
fix: make elemwise test check against dtype in the graph
1 parent 584e506 commit fe0a7ec

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

tests/tensor/utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -508,15 +508,17 @@ def test_good(self):
508508
if not isinstance(expecteds, list | tuple):
509509
expecteds = (expecteds,)
510510

511-
for i, (variable, expected) in enumerate(zip(variables, expecteds)):
511+
for i, (variable, expected, out_symbol) in enumerate(
512+
zip(variables, expecteds, node.outputs)
513+
):
512514
condition = (
513-
variable.dtype != expected.dtype
515+
variable.dtype != out_symbol.type.dtype
514516
or variable.shape != expected.shape
515517
or not np.allclose(variable, expected, atol=eps, rtol=eps)
516518
)
517519
assert not condition, (
518520
f"Test {self.op}::{testname}: Output {i} gave the wrong"
519-
f" value. With inputs {inputs}, expected {expected} (dtype {expected.dtype}),"
521+
f" value. With inputs {inputs}, expected {expected} (dtype {out_symbol.type.dtype}),"
520522
f" got {variable} (dtype {variable.dtype}). eps={eps:f}"
521523
f" np.allclose returns {np.allclose(variable, expected, atol=eps, rtol=eps)} {np.allclose(variable, expected)}"
522524
)

0 commit comments

Comments
 (0)