Skip to content

Commit 1d88092

Browse files
swolchokYIWENX14
authored andcommitted
Support Half/BFloat16 in leaky_relu (#7825)
Partial fix for #7748.
1 parent 82ca9cf commit 1d88092

File tree

2 files changed

+15
-9
lines changed

2 files changed

+15
-9
lines changed

kernels/portable/cpu/op_leaky_relu.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ Tensor& leaky_relu_out(
4444

4545
ET_KERNEL_CHECK(ctx, in_type == out_type, InvalidArgument, out);
4646

47-
ET_SWITCH_FLOAT_TYPES(in_type, ctx, "leaky_relu.out", CTYPE, [&]() {
47+
ET_SWITCH_FLOATHBF16_TYPES(in_type, ctx, "leaky_relu.out", CTYPE, [&]() {
4848
CTYPE negative_slope_casted;
4949
ET_SWITCH_SCALAR_OBJ_TYPES(
5050
sc_type, ctx, "leaky_relu.out", CTYPE_MIN, [&]() {

kernels/test/op_leaky_relu_test.cpp

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,21 @@ class OpLeakyReluTest : public OperatorTest {
2929
return torch::executor::aten::leaky_relu_outf(
3030
context_, in, negative_slope, out);
3131
}
32-
};
32+
template <ScalarType DTYPE>
33+
void test_leaky_relu_dtype() {
34+
TensorFactory<DTYPE> tf;
35+
Tensor in = tf.ones({2, 2});
36+
Tensor out = tf.zeros({2, 2});
3337

34-
TEST_F(OpLeakyReluTest, SanityCheck) {
35-
TensorFactory<ScalarType::Float> tf;
36-
Tensor in = tf.ones({2, 2});
37-
Tensor out = tf.zeros({2, 2});
38+
Tensor ret = op_leaky_relu_out(in, -0.01, out);
3839

39-
Tensor ret = op_leaky_relu_out(in, -0.01, out);
40+
EXPECT_TENSOR_EQ(out, ret);
41+
EXPECT_TENSOR_EQ(out, tf.ones({2, 2}));
42+
}
43+
};
4044

41-
EXPECT_TENSOR_EQ(out, ret);
42-
EXPECT_TENSOR_EQ(out, tf.ones({2, 2}));
45+
TEST_F(OpLeakyReluTest, SanityCheck) {
46+
#define TEST_ENTRY(ctype, dtype) test_leaky_relu_dtype<ScalarType::dtype>();
47+
ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
48+
#undef TEST_ENTRY
4349
}

0 commit comments

Comments
 (0)