Skip to content

Commit fb4fe1d

Browse files
GregoryComer and facebook-github-bot
authored and committed
Add direct copy fast path for portable copy op (#10487)
Summary: The PR adds a direct memcpy fast-path for portable copy and copy_ ops. This speeds up copy significantly in cases where no broadcasting is needed. This is most noticable when copying buffer mutations back, such as transformer KV cache when managing the cache as a mutable buffer. Prior to this change, an encoder/decoder model was taking ~25% of the total runtime copying KV cache back after permuting. With this change, the copy becomes significantly cheaper. I benchmarked a simple model on S23 and Pixel 5: ``` class TestModel(torch.nn.Module): def __init__(self): super().__init__() self.register_buffer("buffer", torch.zeros((2, 10, 1024, 1024))) def forward(self, x): self.buffer.add_(x) return self.buffer model = TestModel() inputs = (torch.randn(2, 10, 1024, 1024),) lowered = to_edge_transform_and_lower( torch.export.export(model, inputs), partitioner=[XnnpackPartitioner()], ).to_executorch() ``` S23, average of 50 runs, time in copy_: 4.1ms vs 22.3ms Pixel 5, average of 50 runs, time in copy_: 12.1ms vs 66.6ms This is approximately a ~5.5x speedup of the copy operator. Reviewed By: swolchok Differential Revision: D73656456 Pulled By: GregoryComer
1 parent 12079fe commit fb4fe1d

File tree

1 file changed

+106
-91
lines changed

1 file changed

+106
-91
lines changed

kernels/portable/cpu/op_copy.cpp

Lines changed: 106 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -6,94 +6,109 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#include <cstring>
10-
11-
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
12-
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
13-
#include <executorch/runtime/kernel/kernel_includes.h>
14-
15-
namespace torch {
16-
namespace executor {
17-
namespace native {
18-
19-
using Tensor = executorch::aten::Tensor;
20-
21-
// copy.out(const Tensor& in, const Tensor& src, bool non_blocking, Tensor(a!)
22-
// out) -> Tensor(a!), see caffe2/aten/src/ATen/native/Copy.cpp
23-
// TODO: We actually shouldn't see this op with the proper functionalization,
24-
// and this op needs to be deleted
25-
Tensor& copy_out(
26-
KernelRuntimeContext& ctx,
27-
const Tensor& in,
28-
const Tensor& src,
29-
bool non_blocking,
30-
Tensor& out) {
31-
(void)ctx;
32-
// Right now we only support blocking data transfer
33-
ET_KERNEL_CHECK(ctx, non_blocking == false, InvalidArgument, out);
34-
35-
ET_KERNEL_CHECK(ctx, tensors_have_same_dtype(in, out), InvalidArgument, out);
36-
37-
ET_KERNEL_CHECK(
38-
ctx, tensor_is_broadcastable_to(src, in), InvalidArgument, out);
39-
40-
ET_KERNEL_CHECK(
41-
ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
42-
43-
ET_KERNEL_CHECK(
44-
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
45-
46-
// @lint-ignore CLANGTIDY facebook-hte-CArray
47-
static constexpr const char op_name[] = "copy.out";
48-
49-
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy.out", CTYPE, [&]() {
50-
utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
51-
[](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
52-
ctx,
53-
in,
54-
utils::SupportedTensorDtypes::REALHBBF16,
55-
src,
56-
utils::SupportedTensorDtypes::REALHBBF16,
57-
out,
58-
utils::SupportedTensorDtypes::REALHBBF16);
59-
});
60-
61-
return out;
62-
}
63-
64-
Tensor& copy_(
65-
KernelRuntimeContext& ctx,
66-
Tensor& in,
67-
const Tensor& src,
68-
bool non_blocking) {
69-
(void)ctx;
70-
// Right now we only support blocking data transfer
71-
ET_KERNEL_CHECK(ctx, non_blocking == false, InvalidArgument, in);
72-
73-
ET_KERNEL_CHECK(
74-
ctx, tensor_is_broadcastable_to(src, in), InvalidArgument, in);
75-
76-
ET_KERNEL_CHECK(
77-
ctx, tensors_have_same_dim_order(in, src), InvalidArgument, in);
78-
79-
// @lint-ignore CLANGTIDY facebook-hte-CArray
80-
static constexpr const char op_name[] = "copy_";
81-
82-
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy_", CTYPE, [&]() {
83-
utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
84-
[](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
85-
ctx,
86-
in,
87-
utils::SupportedTensorDtypes::REALHBBF16,
88-
src,
89-
utils::SupportedTensorDtypes::REALHBBF16,
90-
in,
91-
utils::SupportedTensorDtypes::REALHBBF16);
92-
});
93-
94-
return in;
95-
}
96-
97-
} // namespace native
98-
} // namespace executor
99-
} // namespace torch
9+
#include <cstring>
10+
11+
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
12+
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
13+
#include <executorch/runtime/kernel/kernel_includes.h>
14+
15+
namespace torch {
16+
namespace executor {
17+
namespace native {
18+
19+
using Tensor = executorch::aten::Tensor;
20+
21+
// copy.out(const Tensor& in, const Tensor& src, bool non_blocking, Tensor(a!)
22+
// out) -> Tensor(a!), see caffe2/aten/src/ATen/native/Copy.cpp
23+
// TODO: We actually shouldn't see this op with the proper functionalization,
24+
// and this op needs to be deleted
25+
Tensor& copy_out(
26+
KernelRuntimeContext& ctx,
27+
const Tensor& in,
28+
const Tensor& src,
29+
bool non_blocking,
30+
Tensor& out) {
31+
(void)ctx;
32+
// Right now we only support blocking data transfer
33+
ET_KERNEL_CHECK(ctx, non_blocking == false, InvalidArgument, out);
34+
35+
ET_KERNEL_CHECK(ctx, tensors_have_same_dtype(in, out), InvalidArgument, out);
36+
37+
ET_KERNEL_CHECK(
38+
ctx, tensor_is_broadcastable_to(src, in), InvalidArgument, out);
39+
40+
ET_KERNEL_CHECK(
41+
ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out);
42+
43+
ET_KERNEL_CHECK(
44+
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
45+
46+
// @lint-ignore CLANGTIDY facebook-hte-CArray
47+
static constexpr const char op_name[] = "copy.out";
48+
49+
// Use direct copy fast path if broadcast is not needed and tensors are
50+
// non-empty
51+
if (internal::sizes_match_ignoring_leading_1s(out.sizes(), src.sizes()) &&
52+
src.numel() > 0) {
53+
std::memcpy(out.mutable_data_ptr(), src.const_data_ptr(), src.nbytes());
54+
} else {
55+
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy.out", CTYPE, [&]() {
56+
utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
57+
[](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
58+
ctx,
59+
in,
60+
utils::SupportedTensorDtypes::REALHBBF16,
61+
src,
62+
utils::SupportedTensorDtypes::REALHBBF16,
63+
out,
64+
utils::SupportedTensorDtypes::REALHBBF16);
65+
});
66+
}
67+
68+
return out;
69+
}
70+
71+
Tensor& copy_(
    KernelRuntimeContext& ctx,
    Tensor& in,
    const Tensor& src,
    bool non_blocking) {
  (void)ctx;
  // Right now we only support blocking data transfer
  ET_KERNEL_CHECK(ctx, non_blocking == false, InvalidArgument, in);

  ET_KERNEL_CHECK(
      ctx, tensor_is_broadcastable_to(src, in), InvalidArgument, in);

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(in, src), InvalidArgument, in);

  // @lint-ignore CLANGTIDY facebook-hte-CArray
  static constexpr const char op_name[] = "copy_";

  // Use the direct memcpy fast path when neither broadcasting nor dtype
  // conversion is needed and the tensors are non-empty. The scalar-type check
  // is required for correctness: `tensor_is_broadcastable_to` does not compare
  // dtypes, and the elementwise fallback below converts `src` across
  // REALHBBF16 dtypes — a raw byte copy is valid only when `src` and `in`
  // already have the same scalar type (then matching sizes imply matching
  // nbytes).
  if (in.scalar_type() == src.scalar_type() &&
      internal::sizes_match_ignoring_leading_1s(in.sizes(), src.sizes()) &&
      src.numel() > 0) {
    std::memcpy(in.mutable_data_ptr(), src.const_data_ptr(), in.nbytes());
  } else {
    // Slow path: handles broadcasting of `src` to `in`'s shape and dtype
    // conversion between supported real/half/bfloat16/bool types.
    ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy_", CTYPE, [&]() {
      utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
          [](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
          ctx,
          in,
          utils::SupportedTensorDtypes::REALHBBF16,
          src,
          utils::SupportedTensorDtypes::REALHBBF16,
          in,
          utils::SupportedTensorDtypes::REALHBBF16);
    });
  }

  return in;
}
110+
111+
} // namespace native
112+
} // namespace executor
113+
} // namespace torch
114+

0 commit comments

Comments
 (0)