
Commit 43cda48

SS-JIA authored and facebook-github-bot committed
Modernize matmul kernels
Reviewed By: manuelcandales
Differential Revision: D48272462
fbshipit-source-id: 1095f8c3d2a67f6acac34bd59251095ec8a00835
1 parent 09e8e19 commit 43cda48

File tree

7 files changed: +274 -377 lines changed


kernels/portable/cpu/op_addmm.cpp

Lines changed: 70 additions & 172 deletions
@@ -8,195 +8,93 @@
 
 #include <executorch/kernels/portable/cpu/scalar_utils.h>
 #include <executorch/kernels/portable/cpu/util/broadcast_util.h>
+#include <executorch/kernels/portable/cpu/util/matmul_ops_util.h>
 #include <executorch/kernels/portable/cpu/vec_ops.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
-/**
- * torch.addmm(input, mat1, mat2, *, beta=1, alpha=1, out=None) → Tensor
- * Performs a matrix multiplication of the matrices mat1 and mat2. The matrix
- * input is added to the final result.
- *
- * If mat1 is an (n × m) tensor and mat2 is an (m × p) tensor, then input
- * must be broadcastable with an (n × p) tensor, and out will be an (n × p)
- * tensor.
- *
- * alpha and beta are scaling factors on the matrix product of mat1 and mat2
- * and the added matrix input, respectively.
- *
- * out = β input + α (mat1 @ mat2)
- * If beta is 0, then input will be ignored, and nan and inf in it will not be
- * propagated.
- *
- * For inputs of type FloatTensor or DoubleTensor, arguments beta and alpha
- * must be real numbers; otherwise they should be integers.
- */
 namespace torch {
 namespace executor {
 namespace native {
 
 using Tensor = exec_aten::Tensor;
 using Scalar = exec_aten::Scalar;
 
-namespace {
-
-/**
- * Asserts that the parameters are valid.
- * mat1 (m x n), mat2 (n x p), out (m x p), self (m x p)
- * z[i][j] = sum(x[i][k] * y[k][j]), for k in range(n)
- */
-void check_addmm_out_args(
-    const Tensor& self,
-    const Tensor& mat1,
-    const Tensor& mat2,
-    const Scalar& beta,
-    const Scalar& alpha,
-    Tensor& out) {
-  // Ensure self can be broadcast to out.
-  ET_CHECK_MSG(
-      tensor_is_broadcastable_to(self, out),
-      "input tensor can not be broadcasted to out");
-  // Ensure the dimension is 2 for all tensors.
-  // Does not test self here because it will be broadcast to out.size() after
-  // this function, so we just need to ensure out.dim() meets the requirement.
-  ET_CHECK_MSG(mat1.dim() == 2, "mat1.dim() %zd != 2", mat1.dim());
-  ET_CHECK_MSG(mat2.dim() == 2, "mat2.dim() %zd != 2", mat2.dim());
-  ET_CHECK_MSG(out.dim() == 2, "out.dim() %zd != 2", out.dim());
-  // Ensure all four tensors have the same dtype.
-  ET_CHECK_SAME_DTYPE3(self, mat1, mat2);
-  ET_CHECK_SAME_DTYPE2(self, out);
-  // Ensure beta and alpha have the same type. Maybe support mixing types
-  // in the future.
-  ET_CHECK_SCALAR_SAME_TYPE(beta, alpha);
-  // Ensure the out size is compatible with the input tensors.
-  ET_CHECK_MSG(
-      mat2.size(1) == out.size(1),
-      "mat2.size(1) %zd != out.size(1) %zd",
-      mat2.size(1),
-      out.size(1));
-  ET_CHECK_MSG(
-      mat1.size(0) == out.size(0),
-      "mat1.size(0) %zd != out.size(0) %zd",
-      mat1.size(0),
-      out.size(0));
-  // Ensure mat1 can be multiplied with mat2.
-  ET_CHECK_MSG(
-      mat1.size(1) == mat2.size(0),
-      "mat1.size(1) %zd != mat2.size(0) %zd",
-      mat1.size(1),
-      mat2.size(0));
-}
-
-// For simplicity, assume all tensors are of the same type and all scalars
-// are of the same type. `self` can be broadcast to mat1 @ mat2. T is the
-// tensor dtype; scalar types are handled inside.
-template <typename T>
-Tensor& addmm_out_kernel(
-    const Tensor& self,
-    const Tensor& mat1,
-    const Tensor& mat2,
-    const Scalar& beta,
-    const Scalar& alpha,
-    Tensor& out) {
-  const T* self_data = self.const_data_ptr<T>();
-  const T* mat1_data = mat1.const_data_ptr<T>();
-  const T* mat2_data = mat2.const_data_ptr<T>();
-  T* out_data = out.mutable_data_ptr<T>();
-
-  size_t m = mat1.size(0);
-  size_t n = mat1.size(1);
-  size_t p = mat2.size(1);
-
-  if (beta.isBoolean()) {
-    vec_addmm<T, bool>(
-        out_data,
-        self_data,
-        mat1_data,
-        mat2_data,
-        m,
-        n,
-        p,
-        beta.to<bool>(),
-        alpha.to<bool>());
-  } else if (beta.isIntegral(/*includeBool=*/false)) {
-    vec_addmm<T, int64_t>(
-        out_data,
-        self_data,
-        mat1_data,
-        mat2_data,
-        m,
-        n,
-        p,
-        beta.to<int64_t>(),
-        alpha.to<int64_t>());
-  } else if (beta.isFloatingPoint()) {
-    vec_addmm<T, double>(
-        out_data,
-        self_data,
-        mat1_data,
-        mat2_data,
-        m,
-        n,
-        p,
-        beta.to<double>(),
-        alpha.to<double>());
-  } else {
-    ET_CHECK_MSG(false, "Unhandled scalar type");
-  }
-  return out;
-}
-
-void resize_out_tensor(const Tensor& mat1, const Tensor& mat2, Tensor& out) {
-  Tensor::SizesType expected_output_size[2];
-  expected_output_size[0] = mat1.size(0);
-  expected_output_size[1] = mat2.size(1);
-
-  ArrayRef<Tensor::SizesType> output_size{
-      expected_output_size, static_cast<size_t>(out.dim())};
-
-  torch::executor::Error err = resize_tensor(out, output_size);
-  ET_CHECK_MSG(
-      err == torch::executor::Error::Ok,
-      "Failed to resize out Tensor in addmm_out");
-}
-} // namespace
-
-/**
- * addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar
- * alpha=1, Tensor(a!) out) -> Tensor(a!)
- */
 Tensor& addmm_out(
     RuntimeContext& ctx,
-    const Tensor& self,
+    const Tensor& in,
     const Tensor& mat1,
     const Tensor& mat2,
     const Scalar& beta,
     const Scalar& alpha,
     Tensor& out) {
-  resize_out_tensor(mat1, mat2, out);
-  check_addmm_out_args(self, mat1, mat2, beta, alpha, out);
-
-  // The tensor self needs to be broadcast iff its size is different from the
-  // target one (out.size()).
-  bool broadcasted = !out.sizes().equals(self.sizes());
-  const Tensor& broadcasted_tensor =
-      broadcasted ? broadcast_tensor(self, out) : self;
-  auto scalar_type = broadcasted_tensor.scalar_type();
-
-#define ADDMM_TENSOR(ctype, dtype)                                             \
-  case ScalarType::dtype:                                                      \
-    addmm_out_kernel<ctype>(broadcasted_tensor, mat1, mat2, beta, alpha, out); \
-    break;
-
-  switch (scalar_type) {
-    ET_FORALL_REAL_TYPES(ADDMM_TENSOR)
-    default:
-      ET_CHECK_MSG(false, "Unhandled dtype %hhd", scalar_type);
-  }
-#undef ADDMM_TENSOR
-
-  if (broadcasted) {
-    free_broadcast_tensor(broadcasted_tensor);
-  }
+  ET_KERNEL_CHECK(
+      ctx,
+      check_addmm_args(in, mat1, mat2, beta, alpha, out),
+      InvalidArgument,
+      out);
+
+  size_t output_ndim = 0;
+  exec_aten::SizesType output_sizes[kTensorDimensionLimit];
+  get_mm_out_target_size(mat1, mat2, output_sizes, &output_ndim);
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_tensor(out, {output_sizes, output_ndim}) == Error::Ok,
+      InvalidArgument,
+      out);
+
+  ET_KERNEL_CHECK(
+      ctx, tensor_is_broadcastable_to(in, out), InvalidArgument, out);
+
+  ScalarType alpha_dtype = utils::get_scalar_dtype(alpha);
+  ScalarType beta_dtype = utils::get_scalar_dtype(beta);
+  ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "addmm", CTYPE, [&]() {
+    ET_SWITCH_SCALAR_OBJ_TYPES(alpha_dtype, ctx, "addmm", ALPHA_T, [&]() {
+      ET_SWITCH_SCALAR_OBJ_TYPES(beta_dtype, ctx, "addmm", BETA_T, [&]() {
+        size_t m = mat1.size(0);
+        size_t n = mat1.size(1);
+        size_t p = mat2.size(1);
+
+        if (out.sizes() == in.sizes()) {
+          // vec_addmm assumes that no broadcasting is required.
+          vec_addmm<CTYPE, CTYPE>(
+              out.mutable_data_ptr<CTYPE>(),
+              in.const_data_ptr<CTYPE>(),
+              mat1.const_data_ptr<CTYPE>(),
+              mat2.const_data_ptr<CTYPE>(),
+              m,
+              n,
+              p,
+              convert<CTYPE>(beta.to<BETA_T>()),
+              convert<CTYPE>(alpha.to<ALPHA_T>()));
+        } else {
+          // If broadcasting is required, then compute the matmul and addition
+          // separately, using apply_binary_elementwise_fn to perform the
+          // addition while applying broadcasting.
+          vec_matmul<CTYPE, CTYPE>(
+              out.mutable_data_ptr<CTYPE>(),
+              mat1.const_data_ptr<CTYPE>(),
+              mat2.const_data_ptr<CTYPE>(),
+              m,
+              n,
+              p);
+
+          CTYPE alpha_val = convert<CTYPE>(alpha.to<ALPHA_T>());
+          CTYPE beta_val = convert<CTYPE>(beta.to<BETA_T>());
+          apply_binary_elementwise_fn<CTYPE, CTYPE, CTYPE>(
+              [alpha_val, beta_val](const CTYPE val_a, const CTYPE val_b) {
+                CTYPE a_casted = static_cast<CTYPE>(val_a);
+                CTYPE b_casted = static_cast<CTYPE>(val_b);
+                CTYPE value = a_casted + alpha_val * b_casted * beta_val;
+
+                return value;
+              },
+              out,
+              in,
+              out);
+        }
+      });
+    });
+  });
 
   return out;
 }
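
Note: both the old and new kernels implement the addmm semantics out = β · in + α · (mat1 @ mat2), where mat1 is (m × n), mat2 is (n × p), and in is broadcast to the (m × p) output. The standalone sketch below is not ExecuTorch code; it assumes plain row-major float buffers and supports only a (1 × p) row bias as its broadcast case, purely to illustrate the math that vec_addmm and the vec_matmul + elementwise-add path compute.

// Minimal reference sketch of addmm semantics (not the ExecuTorch kernel):
//   out = beta * input + alpha * (mat1 @ mat2)
// mat1 is (m x n), mat2 is (n x p), out is (m x p). For simplicity, the bias
// "input" is either a full (m x p) matrix or a (1 x p) row broadcast across
// rows; the real kernel handles general broadcasting.
#include <cstddef>
#include <cstdio>
#include <vector>

void addmm_reference(
    float* out,
    const float* input,
    bool input_is_row, // true if input is (1 x p) and must be broadcast
    const float* mat1,
    const float* mat2,
    size_t m,
    size_t n,
    size_t p,
    float beta,
    float alpha) {
  for (size_t i = 0; i < m; ++i) {
    for (size_t j = 0; j < p; ++j) {
      // Accumulate the (i, j) entry of mat1 @ mat2.
      float acc = 0.0f;
      for (size_t k = 0; k < n; ++k) {
        acc += mat1[i * n + k] * mat2[k * p + j];
      }
      const float bias = input_is_row ? input[j] : input[i * p + j];
      out[i * p + j] = beta * bias + alpha * acc;
    }
  }
}

int main() {
  // mat1 is 2x3, mat2 is 3x2, and the bias is a 1x2 row broadcast over rows.
  std::vector<float> mat1 = {1, 2, 3, 4, 5, 6};
  std::vector<float> mat2 = {1, 0, 0, 1, 1, 1};
  std::vector<float> bias = {10, 20};
  std::vector<float> out(4);
  addmm_reference(
      out.data(),
      bias.data(),
      /*input_is_row=*/true,
      mat1.data(),
      mat2.data(),
      2,
      3,
      2,
      /*beta=*/1.0f,
      /*alpha=*/1.0f);
  // Expected result: [[14, 25], [20, 31]]
  for (size_t i = 0; i < 2; ++i) {
    std::printf("%g %g\n", out[i * 2 + 0], out[i * 2 + 1]);
  }
  return 0;
}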

0 commit comments
