Skip to content

Commit 68397af

Browse files
swolchok authored and facebook-github-bot committed
Optimized op_mm using CPUBlas gemm (#5242)
Summary: Pull Request resolved: #5242 No immediate need for this, but it is extremely simple to implement so why not support it? ghstack-source-id: 241919004 exported-using-ghexport Reviewed By: kimishpatel Differential Revision: D62151659 fbshipit-source-id: 7cb5850981ad0666a304e7917d407847037ffa2d
1 parent af80804 commit 68397af

File tree

4 files changed

+84
-1
lines changed

4 files changed

+84
-1
lines changed

kernels/optimized/cpu/op_mm.cpp

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/kernels/optimized/blas/CPUBlas.h>
10+
#include <executorch/kernels/portable/cpu/util/matmul_ops_util.h>
11+
#include <executorch/runtime/kernel/kernel_includes.h>
12+
13+
#include <array>
14+
15+
namespace torch {
16+
namespace executor {
17+
namespace native {
18+
19+
using Tensor = exec_aten::Tensor;
20+
21+
/// Optimized mm.out kernel: computes `out = in @ mat2` via CPUBlas gemm.
///
/// @param ctx  Kernel runtime context used for error reporting.
/// @param in   Left operand; a 2-D tensor of shape (rows, inner).
/// @param mat2 Right operand; a 2-D tensor of shape (inner, cols).
/// @param out  Output tensor, resized here to (rows, cols).
/// @return `out`, also returned on argument-check failure per kernel
///         convention (ET_KERNEL_CHECK returns `out` early).
Tensor& opt_mm_out(
    RuntimeContext& ctx,
    const Tensor& in,
    const Tensor& mat2,
    Tensor& out) {
  // Shape/dtype validation is delegated to the shared portable util.
  ET_KERNEL_CHECK(ctx, check_mm_args(in, mat2, out), InvalidArgument, out);

  // Resize `out` to the target (rows, cols) shape computed by the util.
  size_t output_ndim = 0;
  std::array<exec_aten::SizesType, kTensorDimensionLimit> output_sizes;
  get_mm_out_target_size(in, mat2, output_sizes.data(), &output_ndim);
  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(out, {output_sizes.data(), output_ndim}) == Error::Ok,
      InvalidArgument,
      out);

  // Nothing to compute for an empty result (any dimension is zero).
  if (out.numel() == 0) {
    return out;
  }

  ET_SWITCH_REAL_TYPES_AND2(
      Half, BFloat16, in.scalar_type(), ctx, "mm.out", CTYPE, [&]() {
        const size_t rows = in.size(0); // rows of `in` and of `out`
        const size_t inner = in.size(1); // shared contraction dimension
        const size_t cols = mat2.size(1); // cols of `mat2` and of `out`

        // gemm works in column-major order, while our tensors are
        // row-major. Rather than transposing, we use the identity
        // (A @ B).t() == B.t() @ A.t(): handing gemm the row-major
        // buffers of `mat2` and `in` makes it see mat2.t() and in.t(),
        // and its column-major (cols x rows) result is precisely the
        // row-major (rows x cols) layout `out` expects.
        executorch::cpublas::gemm(
            executorch::cpublas::TransposeType::NoTranspose,
            executorch::cpublas::TransposeType::NoTranspose,
            cols,
            rows,
            inner,
            static_cast<CTYPE>(1),
            mat2.const_data_ptr<CTYPE>(),
            cols,
            in.const_data_ptr<CTYPE>(),
            inner,
            static_cast<CTYPE>(0),
            out.mutable_data_ptr<CTYPE>(),
            cols);
      });

  return out;
}
68+
69+
} // namespace native
70+
} // namespace executor
71+
} // namespace torch

kernels/optimized/cpu/targets.bzl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,13 @@ _OPTIMIZED_ATEN_OPS = (
5252
],
5353
}),
5454
),
55+
op_target(
56+
name = "op_mm",
57+
deps = [
58+
"//executorch/kernels/optimized:libblas",
59+
"//executorch/kernels/portable/cpu/util:matmul_ops_util",
60+
],
61+
),
5562
op_target(
5663
name = "op_mul",
5764
deps = [

kernels/optimized/optimized.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,11 @@
5252
- arg_meta: null
5353
kernel_name: torch::executor::opt_le_tensor_out
5454

55+
- op: mm.out
56+
kernels:
57+
- arg_meta: null
58+
kernel_name: torch::executor::opt_mm_out
59+
5560
- op: mul.out
5661
kernels:
5762
- arg_meta: null

kernels/test/targets.bzl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ def define_common_targets():
244244
_common_op_test("op_mean_test", ["aten", "portable"])
245245
_common_op_test("op_min_test", ["aten", "portable"])
246246
_common_op_test("op_minimum_test", ["aten", "portable"])
247-
_common_op_test("op_mm_test", ["aten", "portable"])
247+
_common_op_test("op_mm_test", ["aten", "portable", "optimized"])
248248
_common_op_test("op_mul_test", ["aten", "portable", "optimized"])
249249
_common_op_test("op_narrow_copy_test", ["aten", "portable"])
250250
_common_op_test("op_native_batch_norm_test", ["aten", "portable"])

0 commit comments

Comments
 (0)