Skip to content

Commit 60bb4ef

Browse files
GregoryComerfacebook-github-bot
authored andcommitted
Add direct copy fast path for portable copy op (#10487)
Summary: The PR adds a direct memcpy fast-path for portable copy and copy_ ops. This speeds up copy significantly in cases where no broadcasting is needed. This is most noticable when copying buffer mutations back, such as transformer KV cache when managing the cache as a mutable buffer. Prior to this change, an encoder/decoder model was taking ~25% of the total runtime copying KV cache back after permuting. With this change, the copy becomes significantly cheaper. I benchmarked a simple model on S23 and Pixel 5: ``` class TestModel(torch.nn.Module): def __init__(self): super().__init__() self.register_buffer("buffer", torch.zeros((2, 10, 1024, 1024))) def forward(self, x): self.buffer.add_(x) return self.buffer model = TestModel() inputs = (torch.randn(2, 10, 1024, 1024),) lowered = to_edge_transform_and_lower( torch.export.export(model, inputs), partitioner=[XnnpackPartitioner()], ).to_executorch() ``` S23, average of 50 runs, time in copy_: 4.1ms vs 22.3ms Pixel 5, average of 50 runs, time in copy_: 12.1ms vs 66.6ms This is approximately a ~5.5x speedup of the copy operator. Reviewed By: swolchok Differential Revision: D73656456 Pulled By: GregoryComer
1 parent 12079fe commit 60bb4ef

File tree

1 file changed

+36
-22
lines changed

1 file changed

+36
-22
lines changed

kernels/portable/cpu/op_copy.cpp

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -46,17 +46,24 @@ Tensor& copy_out(
4646
// @lint-ignore CLANGTIDY facebook-hte-CArray
4747
static constexpr const char op_name[] = "copy.out";
4848

49-
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy.out", CTYPE, [&]() {
50-
utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
51-
[](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
52-
ctx,
53-
in,
54-
utils::SupportedTensorDtypes::REALHBBF16,
55-
src,
56-
utils::SupportedTensorDtypes::REALHBBF16,
57-
out,
58-
utils::SupportedTensorDtypes::REALHBBF16);
59-
});
49+
// Use direct copy fast path if broadcast is not needed and tensors are
50+
// non-empty
51+
if (internal::sizes_match_ignoring_leading_1s(out.sizes(), src.sizes()) &&
52+
src.numel() > 0) {
53+
std::memcpy(out.mutable_data_ptr(), src.const_data_ptr(), src.nbytes());
54+
} else {
55+
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy.out", CTYPE, [&]() {
56+
utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
57+
[](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
58+
ctx,
59+
in,
60+
utils::SupportedTensorDtypes::REALHBBF16,
61+
src,
62+
utils::SupportedTensorDtypes::REALHBBF16,
63+
out,
64+
utils::SupportedTensorDtypes::REALHBBF16);
65+
});
66+
}
6067

6168
return out;
6269
}
@@ -79,17 +86,24 @@ Tensor& copy_(
7986
// @lint-ignore CLANGTIDY facebook-hte-CArray
8087
static constexpr const char op_name[] = "copy_";
8188

82-
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy_", CTYPE, [&]() {
83-
utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
84-
[](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
85-
ctx,
86-
in,
87-
utils::SupportedTensorDtypes::REALHBBF16,
88-
src,
89-
utils::SupportedTensorDtypes::REALHBBF16,
90-
in,
91-
utils::SupportedTensorDtypes::REALHBBF16);
92-
});
89+
// Use direct copy fast path if broadcast is not needed and tensors are
90+
// non-empty
91+
if (internal::sizes_match_ignoring_leading_1s(in.sizes(), src.sizes()) &&
92+
src.numel() > 0) {
93+
std::memcpy(in.mutable_data_ptr(), src.const_data_ptr(), in.nbytes());
94+
} else {
95+
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy_", CTYPE, [&]() {
96+
utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
97+
[](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
98+
ctx,
99+
in,
100+
utils::SupportedTensorDtypes::REALHBBF16,
101+
src,
102+
utils::SupportedTensorDtypes::REALHBBF16,
103+
in,
104+
utils::SupportedTensorDtypes::REALHBBF16);
105+
});
106+
}
93107

94108
return in;
95109
}

0 commit comments

Comments
 (0)