```diff
@@ -7,7 +7,7 @@
  */
 
 #include <executorch/kernels/portable/cpu/scalar_utils.h>
-#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
+#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
 #include <executorch/kernels/portable/cpu/util/matmul_ops_util.h>
 #include <executorch/kernels/portable/cpu/vec_ops.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
@@ -53,62 +53,54 @@ Tensor& addmm_out(
 
   ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);
 
-  ScalarType alpha_dtype = utils::get_scalar_dtype(alpha);
-  ScalarType beta_dtype = utils::get_scalar_dtype(beta);
-  ET_SWITCH_REAL_TYPES_AND(
-      Half, in.scalar_type(), ctx, "addmm.out", CTYPE, [&]() {
-        ET_SWITCH_SCALAR_OBJ_TYPES(
-            alpha_dtype, ctx, "addmm.out", ALPHA_T, [&]() {
-              ET_SWITCH_SCALAR_OBJ_TYPES(
-                  beta_dtype, ctx, "addmm.out", BETA_T, [&]() {
-                    size_t m = mat1.size(0);
-                    size_t n = mat1.size(1);
-                    size_t p = mat2.size(1);
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "addmm.out";
 
-                    if (out.sizes() == in.sizes()) {
-                      // vec_addmm assumes that no broadcasting is required.
-                      vec_addmm<CTYPE, CTYPE>(
-                          out.mutable_data_ptr<CTYPE>(),
-                          in.const_data_ptr<CTYPE>(),
-                          mat1.const_data_ptr<CTYPE>(),
-                          mat2.const_data_ptr<CTYPE>(),
-                          m,
-                          n,
-                          p,
-                          convert<CTYPE>(beta.to<BETA_T>()),
-                          convert<CTYPE>(alpha.to<ALPHA_T>()));
-                    } else {
-                      // If broadcasting is required, them compute the matmul
-                      // and addition separately, using
-                      // apply_binary_elementwise_fn to perform the addition
-                      // while applying broadcasting
-                      vec_matmul<CTYPE, CTYPE>(
-                          out.mutable_data_ptr<CTYPE>(),
-                          mat1.const_data_ptr<CTYPE>(),
-                          mat2.const_data_ptr<CTYPE>(),
-                          m,
-                          n,
-                          p);
+  ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE, [&]() {
+    CTYPE alpha_val = utils::scalar_to<CTYPE>(alpha);
+    CTYPE beta_val = utils::scalar_to<CTYPE>(beta);
+    size_t m = mat1.size(0);
+    size_t n = mat1.size(1);
+    size_t p = mat2.size(1);
 
-                      CTYPE alpha_val = convert<CTYPE>(alpha.to<ALPHA_T>());
-                      CTYPE beta_val = convert<CTYPE>(beta.to<BETA_T>());
-                      apply_binary_elementwise_fn<CTYPE, CTYPE, CTYPE>(
-                          [alpha_val, beta_val](
-                              const CTYPE val_a, const CTYPE val_b) {
-                            CTYPE a_casted = static_cast<CTYPE>(val_a);
-                            CTYPE b_casted = static_cast<CTYPE>(val_b);
-                            CTYPE value =
-                                a_casted * alpha_val + b_casted * beta_val;
+    if (out.sizes() == in.sizes()) {
+      // vec_addmm assumes that no broadcasting is required.
+      vec_addmm<CTYPE, CTYPE>(
+          out.mutable_data_ptr<CTYPE>(),
+          in.const_data_ptr<CTYPE>(),
+          mat1.const_data_ptr<CTYPE>(),
+          mat2.const_data_ptr<CTYPE>(),
+          m,
+          n,
+          p,
+          beta_val,
+          alpha_val);
+    } else {
+      // If broadcasting is required, then compute the matmul
+      // and the addition separately, using
+      // apply_bitensor_elementwise_fn to perform the addition
+      // while applying broadcasting.
+      vec_matmul<CTYPE, CTYPE>(
+          out.mutable_data_ptr<CTYPE>(),
+          mat1.const_data_ptr<CTYPE>(),
+          mat2.const_data_ptr<CTYPE>(),
+          m,
+          n,
+          p);
 
-                            return value;
-                          },
-                          out,
-                          in,
-                          out);
-                    }
-                  });
-            });
-      });
+      utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
+          [alpha_val, beta_val](const CTYPE val_a, const CTYPE val_b) {
+            return val_a * alpha_val + val_b * beta_val;
+          },
+          ctx,
+          out,
+          utils::SupportedTensorDtypes::REALHBF16,
+          in,
+          utils::SupportedTensorDtypes::REALHBF16,
+          out,
+          utils::SupportedTensorDtypes::REALHBF16);
+    }
+  });
 
   return out;
 }
```
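
The core of this change is collapsing three nested dtype switches (tensor dtype, alpha dtype, beta dtype) into a single `ET_SWITCH_REALHBF16_TYPES` on the tensor dtype, made possible by converting both `Scalar` arguments to the compute type once via `utils::scalar_to<CTYPE>`. A minimal plain-C++ sketch of that pattern follows; `ScalarValue`, `scalar_to`, `dispatch_real_types`, and `addmm_like` are illustrative stand-ins, not ExecuTorch APIs:

```cpp
#include <cstdint>
#include <variant>

// Hypothetical stand-in for a runtime-typed Scalar payload.
using ScalarValue = std::variant<int64_t, double, bool>;

// Convert whatever the scalar holds to the kernel's compute type once,
// so the scalar dtype never needs its own dispatch level (analogous to
// utils::scalar_to<CTYPE> in the diff).
template <typename CTYPE>
CTYPE scalar_to(const ScalarValue& s) {
  return std::visit([](auto v) { return static_cast<CTYPE>(v); }, s);
}

enum class Dtype { Float, Double };

// A single switch on the tensor dtype, standing in for
// ET_SWITCH_REALHBF16_TYPES: the generic lambda body is instantiated
// once per supported compute type.
template <typename Fn>
void dispatch_real_types(Dtype dtype, Fn&& fn) {
  switch (dtype) {
    case Dtype::Float:
      fn(float{});
      break;
    case Dtype::Double:
      fn(double{});
      break;
  }
}

void addmm_like(Dtype dtype, const ScalarValue& alpha, const ScalarValue& beta) {
  dispatch_real_types(dtype, [&](auto tag) {
    using CTYPE = decltype(tag);
    // Converted once up front, so no nested switches on the scalar dtypes.
    CTYPE alpha_val = scalar_to<CTYPE>(alpha);
    CTYPE beta_val = scalar_to<CTYPE>(beta);
    // ... kernel body uses alpha_val / beta_val directly ...
    (void)alpha_val;
    (void)beta_val;
  });
}
```

Presumably the payoff is fewer template instantiations: the old code stamped out the kernel body for every (tensor, alpha, beta) dtype combination, while the new code stamps it out once per tensor dtype.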
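For the broadcasting branch, the kernel decomposes addmm into a plain matmul followed by a broadcast-aware elementwise combine, i.e. out = alpha * (mat1 @ mat2) + beta * in. Below is a self-contained reference sketch of that decomposition in plain C++; `naive_matmul` and `addmm_reference` are hypothetical names, and only the simple row-broadcast of `in` is modeled:

```cpp
#include <cstddef>
#include <vector>

// mat1 is m x n, mat2 is n x p, out is m x p, all row-major.
void naive_matmul(
    std::vector<float>& out,
    const std::vector<float>& mat1,
    const std::vector<float>& mat2,
    size_t m,
    size_t n,
    size_t p) {
  for (size_t i = 0; i < m; ++i) {
    for (size_t j = 0; j < p; ++j) {
      float acc = 0.0f;
      for (size_t k = 0; k < n; ++k) {
        acc += mat1[i * n + k] * mat2[k * p + j];
      }
      out[i * p + j] = acc;
    }
  }
}

// `in` holds either m*p elements (no broadcasting) or p elements (one row
// broadcast across all m output rows), the simplest broadcast case the
// real kernel's elementwise helper has to cover.
std::vector<float> addmm_reference(
    const std::vector<float>& in,
    const std::vector<float>& mat1,
    const std::vector<float>& mat2,
    size_t m,
    size_t n,
    size_t p,
    float beta,
    float alpha) {
  std::vector<float> out(m * p);
  // Step 1: out = mat1 x mat2 (the vec_matmul call in the diff).
  naive_matmul(out, mat1, mat2, m, n, p);
  // Step 2: out = alpha * out + beta * in, matching the lambda passed to
  // apply_bitensor_elementwise_fn: val_a * alpha_val + val_b * beta_val.
  const bool broadcast_in = (in.size() == p);
  for (size_t i = 0; i < m; ++i) {
    for (size_t j = 0; j < p; ++j) {
      const float bias = broadcast_in ? in[j] : in[i * p + j];
      out[i * p + j] = alpha * out[i * p + j] + beta * bias;
    }
  }
  return out;
}
```

The non-broadcast fast path in the diff (`vec_addmm`) instead fuses these two steps into a single pass, which is why it is kept for the case where `out.sizes() == in.sizes()`.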