Commit e0c26dd

manuelcandales authored and facebook-github-bot committed
Reduce build size of op_addmm (#6018)

Summary:
Pull Request resolved: #6018

200 K -> 30 K

ghstack-source-id: 246985125
exported-using-ghexport

Reviewed By: swolchok

Differential Revision: D63994874

fbshipit-source-id: 1d26b944be8b0cdc4e2343e9efa6d0e28a0a82e3
1 parent be86a2c commit e0c26dd
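
The 200 K -> 30 K reduction comes from collapsing template instantiations. The old kernel nested three compile-time dtype switches (tensor dtype, alpha dtype, beta dtype), so the compiler emitted the kernel body once per dtype combination; the new kernel switches only on the tensor dtype and converts the two Scalars at runtime. Below is a minimal, self-contained C++ sketch of that before/after pattern, using hypothetical stand-ins (DType, dispatch_real) rather than the ExecuTorch ET_SWITCH_* macros:

```cpp
#include <cstdint>
#include <iostream>

// Hypothetical stand-in for the dtype tag; not the ExecuTorch API.
enum class DType { Float, Double, Long };

// Invokes fn with a value whose static type matches dt. Every distinct
// lambda passed here is instantiated once per supported dtype.
template <typename Fn>
void dispatch_real(DType dt, Fn fn) {
  switch (dt) {
    case DType::Float:
      fn(float{});
      break;
    case DType::Double:
      fn(double{});
      break;
    case DType::Long:
      fn(int64_t{});
      break;
  }
}

int main() {
  DType in_dtype = DType::Float;
  double alpha = 2.0, beta = 0.5;  // runtime Scalar payloads

  // Old shape: three nested dispatches over the (input, alpha, beta) dtypes
  // would instantiate the innermost body |DType|^3 times.
  // New shape: dispatch once on the tensor dtype and convert the scalars at
  // runtime, so the body is instantiated only |DType| times.
  dispatch_real(in_dtype, [&](auto tag) {
    using CTYPE = decltype(tag);
    CTYPE alpha_val = static_cast<CTYPE>(alpha);  // cf. utils::scalar_to<CTYPE>
    CTYPE beta_val = static_cast<CTYPE>(beta);
    std::cout << alpha_val * CTYPE{3} + beta_val * CTYPE{4} << "\n";
  });
  return 0;
}
```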

File tree: 2 files changed (+48 -54 lines)

kernels/portable/cpu/op_addmm.cpp (46 additions, 54 deletions)
@@ -7,7 +7,7 @@
  */
 
 #include <executorch/kernels/portable/cpu/scalar_utils.h>
-#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
+#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
 #include <executorch/kernels/portable/cpu/util/matmul_ops_util.h>
 #include <executorch/kernels/portable/cpu/vec_ops.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
@@ -53,62 +53,54 @@ Tensor& addmm_out(
 
   ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);
 
-  ScalarType alpha_dtype = utils::get_scalar_dtype(alpha);
-  ScalarType beta_dtype = utils::get_scalar_dtype(beta);
-  ET_SWITCH_REAL_TYPES_AND(
-      Half, in.scalar_type(), ctx, "addmm.out", CTYPE, [&]() {
-        ET_SWITCH_SCALAR_OBJ_TYPES(
-            alpha_dtype, ctx, "addmm.out", ALPHA_T, [&]() {
-              ET_SWITCH_SCALAR_OBJ_TYPES(
-                  beta_dtype, ctx, "addmm.out", BETA_T, [&]() {
-                    size_t m = mat1.size(0);
-                    size_t n = mat1.size(1);
-                    size_t p = mat2.size(1);
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "addmm.out";
 
-                    if (out.sizes() == in.sizes()) {
-                      // vec_addmm assumes that no broadcasting is required.
-                      vec_addmm<CTYPE, CTYPE>(
-                          out.mutable_data_ptr<CTYPE>(),
-                          in.const_data_ptr<CTYPE>(),
-                          mat1.const_data_ptr<CTYPE>(),
-                          mat2.const_data_ptr<CTYPE>(),
-                          m,
-                          n,
-                          p,
-                          convert<CTYPE>(beta.to<BETA_T>()),
-                          convert<CTYPE>(alpha.to<ALPHA_T>()));
-                    } else {
-                      // If broadcasting is required, then compute the matmul
-                      // and addition separately, using
-                      // apply_binary_elementwise_fn to perform the addition
-                      // while applying broadcasting
-                      vec_matmul<CTYPE, CTYPE>(
-                          out.mutable_data_ptr<CTYPE>(),
-                          mat1.const_data_ptr<CTYPE>(),
-                          mat2.const_data_ptr<CTYPE>(),
-                          m,
-                          n,
-                          p);
+  ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE, [&]() {
+    CTYPE alpha_val = utils::scalar_to<CTYPE>(alpha);
+    CTYPE beta_val = utils::scalar_to<CTYPE>(beta);
+    size_t m = mat1.size(0);
+    size_t n = mat1.size(1);
+    size_t p = mat2.size(1);
 
-                      CTYPE alpha_val = convert<CTYPE>(alpha.to<ALPHA_T>());
-                      CTYPE beta_val = convert<CTYPE>(beta.to<BETA_T>());
-                      apply_binary_elementwise_fn<CTYPE, CTYPE, CTYPE>(
-                          [alpha_val, beta_val](
-                              const CTYPE val_a, const CTYPE val_b) {
-                            CTYPE a_casted = static_cast<CTYPE>(val_a);
-                            CTYPE b_casted = static_cast<CTYPE>(val_b);
-                            CTYPE value =
-                                a_casted * alpha_val + b_casted * beta_val;
+    if (out.sizes() == in.sizes()) {
+      // vec_addmm assumes that no broadcasting is required.
+      vec_addmm<CTYPE, CTYPE>(
+          out.mutable_data_ptr<CTYPE>(),
+          in.const_data_ptr<CTYPE>(),
+          mat1.const_data_ptr<CTYPE>(),
+          mat2.const_data_ptr<CTYPE>(),
+          m,
+          n,
+          p,
+          beta_val,
+          alpha_val);
+    } else {
+      // If broadcasting is required, then compute the matmul
+      // and addition separately, using
+      // apply_bitensor_elementwise_fn to perform the addition
+      // while applying broadcasting
+      vec_matmul<CTYPE, CTYPE>(
+          out.mutable_data_ptr<CTYPE>(),
+          mat1.const_data_ptr<CTYPE>(),
+          mat2.const_data_ptr<CTYPE>(),
+          m,
+          n,
+          p);
 
-                            return value;
-                          },
-                          out,
-                          in,
-                          out);
-                    }
-                  });
-            });
-      });
+      utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
+          [alpha_val, beta_val](const CTYPE val_a, const CTYPE val_b) {
+            return val_a * alpha_val + val_b * beta_val;
+          },
+          ctx,
+          out,
+          utils::SupportedTensorDtypes::REALHBF16,
+          in,
+          utils::SupportedTensorDtypes::REALHBF16,
+          out,
+          utils::SupportedTensorDtypes::REALHBF16);
+    }
+  });
 
   return out;
 }
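
Counting instantiations shows where the savings come from. ET_SWITCH_REAL_TYPES_AND(Half, ...) covers roughly eight compute types and each ET_SWITCH_SCALAR_OBJ_TYPES covers three (Bool, Long, Double), so the old nesting emitted the kernel body on the order of 8 x 3 x 3 = 72 times, while the single ET_SWITCH_REALHBF16_TYPES emits it about nine times; those type counts are approximate, but a reduction of that order is consistent with the reported 200 K -> 30 K. (As a side effect, REALHBF16 also adds BFloat16 coverage, which the old REAL-plus-Half switch lacked.) The two Scalar switches are replaced by one runtime conversion per Scalar. Here is a self-contained sketch of that conversion, with a MiniScalar stand-in for the real Scalar type; the actual helper is utils::scalar_to from scalar_utils.h, whose implementation may differ:

```cpp
#include <cstdint>

// MiniScalar mimics the three payload kinds a Scalar object can carry
// (Bool, Long, Double). Stand-in for illustration only.
struct MiniScalar {
  enum class Tag { Bool, Long, Double } tag;
  union {
    bool b;
    int64_t i;
    double d;
  };
};

// One runtime branch replaces the two compile-time ET_SWITCH_SCALAR_OBJ_TYPES
// levels, so the kernel no longer needs ALPHA_T/BETA_T template parameters.
template <typename CTYPE>
CTYPE scalar_to_sketch(const MiniScalar& s) {
  switch (s.tag) {
    case MiniScalar::Tag::Double:
      return static_cast<CTYPE>(s.d);
    case MiniScalar::Tag::Long:
      return static_cast<CTYPE>(s.i);
    default:
      return static_cast<CTYPE>(s.b);
  }
}

int main() {
  MiniScalar alpha;
  alpha.tag = MiniScalar::Tag::Double;
  alpha.d = 2.0;
  // Converted once, before the kernel body runs; no extra instantiations.
  return scalar_to_sketch<float>(alpha) == 2.0f ? 0 : 1;
}
```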

shim/xplat/executorch/kernels/portable/op_registration_util.bzl (2 additions, 0 deletions)
@@ -224,6 +224,8 @@ ATEN_OPS = (
         name = "op_addmm",
         deps = [
             "//executorch/kernels/portable/cpu/util:broadcast_util",
+            "//executorch/kernels/portable/cpu/util:dtype_util",
+            "//executorch/kernels/portable/cpu/util:elementwise_util",
             "//executorch/kernels/portable/cpu/util:matmul_ops_util",
             ":scalar_utils",
             ":vec_ops",
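
The dependency additions mirror the include change in op_addmm.cpp: the new elementwise_util.h include requires the //executorch/kernels/portable/cpu/util:elementwise_util target, and :dtype_util is presumably needed alongside it because the elementwise utilities and the SupportedTensorDtypes handling lean on it. Note that :broadcast_util stays in deps even though op_addmm.cpp no longer includes broadcast_util.h directly; it may still be needed transitively, for example via matmul_ops_util.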
