@@ -28,6 +28,19 @@ class OpGluOutTest : public OperatorTest {
28
28
return torch::executor::aten::glu_outf (context_, self, dim, out);
29
29
}
30
30
31
+ template <ScalarType DTYPE>
32
+ void expect_tensor_close (Tensor actual, Tensor expected) {
33
+ if (DTYPE == ScalarType::Half || DTYPE == ScalarType::BFloat16) {
34
+ EXPECT_TENSOR_CLOSE_WITH_TOL (
35
+ actual,
36
+ expected,
37
+ 1e-2 ,
38
+ executorch::runtime::testing::internal::kDefaultAtol );
39
+ } else {
40
+ EXPECT_TENSOR_CLOSE (actual, expected);
41
+ }
42
+ }
43
+
31
44
// Common testing for glu operator
32
45
template <ScalarType DTYPE, ScalarType OUT_DTYPE>
33
46
void test_glu_out () {
@@ -41,14 +54,14 @@ class OpGluOutTest : public OperatorTest {
41
54
Tensor in = tf.ones (sizes);
42
55
Tensor out = tf_out.zeros (out_sizes_1);
43
56
op_glu_out (in, 0 , out);
44
- EXPECT_TENSOR_CLOSE (
57
+ expect_tensor_close<DTYPE> (
45
58
out,
46
59
tf_out.make (
47
60
out_sizes_1, /* data=*/ {0.731059 , 0.731059 , 0.731059 , 0.731059 }));
48
61
const std::vector<int32_t > out_sizes_2 = {4 , 1 };
49
62
out = tf_out.zeros (out_sizes_2);
50
63
op_glu_out (in, 1 , out);
51
- EXPECT_TENSOR_CLOSE (
64
+ expect_tensor_close<DTYPE> (
52
65
out,
53
66
tf_out.make (
54
67
out_sizes_2, /* data=*/ {0.731059 , 0.731059 , 0.731059 , 0.731059 }));
0 commit comments