Commit a14c5d1 (parent: a3c2b7a)

[ExecuTorch] Add optimized op_linear

If we happen to be running without a delegate, implementing linear directly is much more efficient than permute_copy_out (which materializes a transpose) followed by matmul.

Differential Revision: D62154007 (https://our.internmc.facebook.com/intern/diff/D62154007/)
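To make the claim concrete: linear computes out = in * weight^T for a row-major input of shape [n, k] and weight of shape [m, k]. Below is a minimal sketch of the two strategies the message contrasts; the function names and plain float buffers are illustrative, not code from this commit.

#include <cstddef>
#include <vector>

// Direct linear: out[i][j] = dot(row i of in, row j of weight).
// Reads weight in place; no intermediate buffer is needed.
void linear_direct(
    const float* in, const float* weight, float* out,
    std::size_t n, std::size_t k, std::size_t m) {
  for (std::size_t i = 0; i < n; ++i) {
    for (std::size_t j = 0; j < m; ++j) {
      float acc = 0.0f;
      for (std::size_t kk = 0; kk < k; ++kk) {
        acc += in[i * k + kk] * weight[j * k + kk];
      }
      out[i * m + j] = acc;
    }
  }
}

// The un-delegated fallback: first materialize weight^T (an extra k*m
// buffer plus a full copy, which is what permute_copy_out amounts to),
// then run a plain matmul against it.
void linear_via_transpose(
    const float* in, const float* weight, float* out,
    std::size_t n, std::size_t k, std::size_t m) {
  std::vector<float> weight_t(k * m);
  for (std::size_t j = 0; j < m; ++j) {
    for (std::size_t kk = 0; kk < k; ++kk) {
      weight_t[kk * m + j] = weight[j * k + kk]; // the copy being avoided
    }
  }
  for (std::size_t i = 0; i < n; ++i) {
    for (std::size_t j = 0; j < m; ++j) {
      float acc = 0.0f;
      for (std::size_t kk = 0; kk < k; ++kk) {
        acc += in[i * k + kk] * weight_t[kk * m + j];
      }
      out[i * m + j] = acc;
    }
  }
}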

File tree: 9 files changed (+434, -0 lines)

kernels/aten/functions.yaml (2 additions, 0 deletions)

@@ -215,6 +215,8 @@
 
 - op: linalg_vector_norm.out
 
+- op: linear.out
+
 - op: log.out
 
 - op: log10.out

kernels/optimized/cpu/op_linear.cpp (new file, 80 additions)

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/optimized/blas/CPUBlas.h>
#include <executorch/kernels/portable/cpu/util/matmul_ops_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

#include <array>

namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;

Tensor& opt_linear_out(
    RuntimeContext& ctx,
    const Tensor& in,
    const Tensor& mat2,
    const optional<Tensor>& bias,
    Tensor& out) {
  ET_KERNEL_CHECK_MSG(
      ctx,
      !bias.has_value(),
      InvalidArgument,
      out,
      "bias not supported yet in linear");
  ET_KERNEL_CHECK(ctx, check_linear_args(in, mat2, out), InvalidArgument, out);

  size_t output_ndim = 0;
  std::array<exec_aten::SizesType, kTensorDimensionLimit> output_sizes;
  get_linear_out_target_size(in, mat2, output_sizes.data(), &output_ndim);
  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(out, {output_sizes.data(), output_ndim}) == Error::Ok,
      InvalidArgument,
      out);

  // gemm on some platforms doesn't tolerate empty input.
  if (out.numel() == 0) {
    return out;
  }

  int flattened_input_dim = 1;
  for (int ii = 0; ii < in.dim() - 1; ++ii) {
    flattened_input_dim *= in.sizes()[ii];
  }
  ET_SWITCH_REAL_TYPES_AND2(
      Half, BFloat16, in.scalar_type(), ctx, "mm.out", CTYPE, [&]() {
        size_t n = flattened_input_dim;
        size_t k = in.sizes()[in.dim() - 1];
        size_t m = mat2.size(0);

        executorch::cpublas::gemm(
            executorch::cpublas::TransposeType::Transpose,
            executorch::cpublas::TransposeType::NoTranspose,
            m,
            n,
            k,
            static_cast<CTYPE>(1),
            mat2.const_data_ptr<CTYPE>(),
            k,
            in.const_data_ptr<CTYPE>(),
            k,
            static_cast<CTYPE>(0),
            out.mutable_data_ptr<CTYPE>(),
            m);
      });

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch
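Why this single gemm call implements linear: assuming executorch::cpublas::gemm follows the standard column-major BLAS convention (as its transpose-type and leading-dimension parameters suggest), a row-major [r, c] buffer read column-major with leading dimension c is exactly its c x r transpose, so the transpose is obtained by description rather than by copying. A sketch of the mapping (my reading, not comments from the commit):

// Column-major BLAS gemm computes C = op(A) * op(B), with C m-by-n.
//
//   A = mat2 (row-major [m, k]), lda = k, transa = Transpose
//     -> A reads column-major as mat2^T [k, m], so op(A) = mat2 (m x k)
//   B = in (row-major [n, k] after flattening batch dims), ldb = k,
//       transb = NoTranspose -> op(B) = in^T (k x n)
//   C = out, ldc = m -> C = mat2 * in^T (m x n, column-major)
//
// A column-major [m, n] result occupies the same buffer as a row-major
// [n, m] one, so out = (mat2 * in^T)^T = in * mat2^T: exactly linear,
// with no transpose ever materialized.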

kernels/optimized/cpu/targets.bzl (7 additions, 0 deletions)

@@ -40,6 +40,13 @@ _OPTIMIZED_ATEN_OPS = (
             "//executorch/kernels/portable/cpu:scalar_utils",
         ],
     ),
+    op_target(
+        name = "op_linear",
+        deps = [
+            "//executorch/kernels/optimized:libblas",
+            "//executorch/kernels/portable/cpu/util:matmul_ops_util",
+        ],
+    ),
     op_target(
         name = "op_log_softmax",
         deps = select({

kernels/optimized/optimized-oss.yaml (5 additions, 0 deletions)

@@ -45,6 +45,11 @@
   - arg_meta: null
     kernel_name: torch::executor::opt_le_tensor_out
 
+- op: linear.out
+  kernels:
+  - arg_meta: null
+    kernel_name: torch::executor::opt_linear_out
+
 - op: mul.out
   kernels:
   - arg_meta: null

kernels/optimized/optimized.yaml (5 additions, 0 deletions)

@@ -52,6 +52,11 @@
   - arg_meta: null
     kernel_name: torch::executor::opt_le_tensor_out
 
+- op: linear.out
+  kernels:
+  - arg_meta: null
+    kernel_name: torch::executor::opt_linear_out
+
 - op: mm.out
   kernels:
   - arg_meta: null

kernels/portable/cpu/util/matmul_ops_util.cpp (25 additions, 0 deletions)

@@ -71,6 +71,19 @@ bool check_mm_args(const Tensor& in, const Tensor& mat2, Tensor& out) {
   return true;
 }
 
+bool check_linear_args(const Tensor& in, const Tensor& mat2, Tensor& out) {
+  ET_LOG_AND_RETURN_IF_FALSE(in.dim() == out.dim());
+  ET_LOG_AND_RETURN_IF_FALSE(in.dim() >= 2);
+  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(mat2, 2));
+
+  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, mat2, out));
+
+  ET_LOG_AND_RETURN_IF_FALSE(
+      tensors_have_same_size_at_dims(in, in.dim() - 1, mat2, 1));
+
+  return true;
+}
+
 void get_mm_out_target_size(
     const Tensor& mat1,
     const Tensor& mat2,
@@ -81,5 +94,17 @@ void get_mm_out_target_size(
   out_sizes[1] = mat2.size(1);
 }
 
+void get_linear_out_target_size(
+    const Tensor& mat1,
+    const Tensor& mat2,
+    Tensor::SizesType* out_sizes,
+    size_t* out_ndim) {
+  *out_ndim = mat1.dim();
+  for (int ii = 0; ii < mat1.dim() - 1; ++ii) {
+    out_sizes[ii] = mat1.sizes()[ii];
+  }
+  out_sizes[mat1.dim() - 1] = mat2.size(0);
+}
+
 } // namespace executor
 } // namespace torch
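As a worked example of the new helpers, take hypothetical shapes in = [2, 3, 5] and mat2 = [4, 5]: check_linear_args passes (in.dim() >= 2, mat2 has rank 2, and in's last dimension 5 matches mat2.size(1)), and get_linear_out_target_size keeps every leading dimension of in while replacing the last with mat2.size(0). A standalone restatement of that size logic with plain arrays in place of the ExecuTorch tensor types (illustrative sketch only):

#include <cstddef>
#include <cstdio>

// Mirrors get_linear_out_target_size: copy every leading dim of mat1,
// then make the last output dim mat2's row count.
void linear_out_sizes(
    const int* mat1_sizes,
    std::size_t mat1_dim,
    int mat2_rows,
    int* out_sizes,
    std::size_t* out_ndim) {
  *out_ndim = mat1_dim;
  for (std::size_t i = 0; i + 1 < mat1_dim; ++i) {
    out_sizes[i] = mat1_sizes[i];
  }
  out_sizes[mat1_dim - 1] = mat2_rows;
}

int main() {
  const int in_sizes[] = {2, 3, 5}; // e.g. [batch, rows, k]
  int out_sizes[3];
  std::size_t ndim = 0;
  linear_out_sizes(in_sizes, 3, /*mat2_rows=*/4, out_sizes, &ndim);
  // Prints: out sizes = [2, 3, 4]
  std::printf(
      "out sizes = [%d, %d, %d]\n", out_sizes[0], out_sizes[1], out_sizes[2]);
  return 0;
}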

kernels/portable/cpu/util/matmul_ops_util.h (8 additions, 0 deletions)

@@ -37,5 +37,13 @@ void get_mm_out_target_size(
     Tensor::SizesType* out_sizes,
     size_t* out_ndim);
 
+bool check_linear_args(const Tensor& in, const Tensor& mat2, Tensor& out);
+
+void get_linear_out_target_size(
+    const Tensor& mat1,
+    const Tensor& mat2,
+    Tensor::SizesType* out_sizes,
+    size_t* out_ndim);
+
 } // namespace executor
 } // namespace torch
