Skip to content

Commit 42f744a

Browse files
pytorchbotswolchok
andauthored
Fix ATen mode op_logit_test
Was broken, now it's not. Differential Revision: [D68929577](https://our.internmc.facebook.com/intern/diff/D68929577/) ghstack-source-id: 263962875 Pull Request resolved: #8081 Co-authored-by: Scott Wolchok <[email protected]>
1 parent f56d66a commit 42f744a

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

kernels/test/op_logit_test.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,17 @@ class OpLogitOutTest : public OperatorTest {
5757

5858
op_logit_out(tf.make(sizes, /*data=*/{1, 2, 4, 8}), 0.1, out);
5959

60-
// Check that it matches (or close to) the expected output.
61-
EXPECT_TENSOR_CLOSE(
62-
out,
63-
tf_out.make(sizes, /*data=*/{2.197224, 2.197224, 2.197224, 2.197224}));
60+
auto expected =
61+
tf_out.make(sizes, /*data=*/{2.197224, 2.197224, 2.197224, 2.197224});
62+
if (DTYPE == ScalarType::Half || DTYPE == ScalarType::BFloat16) {
63+
EXPECT_TENSOR_CLOSE_WITH_TOL(
64+
out,
65+
expected,
66+
1e-2,
67+
executorch::runtime::testing::internal::kDefaultAtol);
68+
} else {
69+
EXPECT_TENSOR_CLOSE(out, expected);
70+
}
6471
}
6572

6673
// Unhandled output dtypes.

0 commit comments

Comments
 (0)