Skip to content

Commit 99c6237

Browse files
GregoryComer authored and facebook-github-bot committed
Add direct copy fast path for portable copy op
Summary: The PR adds a direct memcpy fast-path for portable copy and copy_ ops. This speeds up copy significantly in cases where no broadcasting is needed. This is most noticeable when copying buffer mutations back, such as transformer KV cache when managing the cache as a mutable buffer. Prior to this change, an encoder/decoder model was taking ~25% of the total runtime copying KV cache back after permuting. With this change, the copy becomes significantly cheaper. I benchmarked a simple model on S23 and Pixel 5: ``` class TestModel(torch.nn.Module): def __init__(self): super().__init__() self.register_buffer("buffer", torch.zeros((2, 10, 1024, 1024))) def forward(self, x): self.buffer.add_(x) return self.buffer model = TestModel() inputs = (torch.randn(2, 10, 1024, 1024),) lowered = to_edge_transform_and_lower( torch.export.export(model, inputs), partitioner=[XnnpackPartitioner()], ).to_executorch() ``` S23, average of 50 runs, time in copy_: 4.1ms vs 22.3ms Pixel 5, average of 50 runs, time in copy_: 12.1ms vs 66.6ms This is approximately a ~5.5x speedup of the copy operator. Differential Revision: D73656456
1 parent 7e034ca commit 99c6237

File tree

1 file changed

+32
-22
lines changed

1 file changed

+32
-22
lines changed

kernels/portable/cpu/op_copy.cpp

Lines changed: 32 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -46,17 +46,22 @@ Tensor& copy_out(
4646
// @lint-ignore CLANGTIDY facebook-hte-CArray
4747
static constexpr const char op_name[] = "copy.out";
4848

49-
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy.out", CTYPE, [&]() {
50-
utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
51-
[](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
52-
ctx,
53-
in,
54-
utils::SupportedTensorDtypes::REALHBBF16,
55-
src,
56-
utils::SupportedTensorDtypes::REALHBBF16,
57-
out,
58-
utils::SupportedTensorDtypes::REALHBBF16);
59-
});
49+
// Use direct copy fast path if broadcast is not needed and tensors are non-empty
50+
if (tensors_have_same_shape(out, src) && src.numel() > 0) {
51+
std::memcpy(out.mutable_data_ptr(), src.const_data_ptr(), src.nbytes());
52+
} else {
53+
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy.out", CTYPE, [&]() {
54+
utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
55+
[](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
56+
ctx,
57+
in,
58+
utils::SupportedTensorDtypes::REALHBBF16,
59+
src,
60+
utils::SupportedTensorDtypes::REALHBBF16,
61+
out,
62+
utils::SupportedTensorDtypes::REALHBBF16);
63+
});
64+
}
6065

6166
return out;
6267
}
@@ -79,17 +84,22 @@ Tensor& copy_(
7984
// @lint-ignore CLANGTIDY facebook-hte-CArray
8085
static constexpr const char op_name[] = "copy_";
8186

82-
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy_", CTYPE, [&]() {
83-
utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
84-
[](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
85-
ctx,
86-
in,
87-
utils::SupportedTensorDtypes::REALHBBF16,
88-
src,
89-
utils::SupportedTensorDtypes::REALHBBF16,
90-
in,
91-
utils::SupportedTensorDtypes::REALHBBF16);
92-
});
87+
// Use direct copy fast path if broadcast is not needed and tensors are non-empty
88+
if (tensors_have_same_shape(in, src) && src.numel() > 0) {
89+
std::memcpy(in.mutable_data_ptr(), src.const_data_ptr(), in.nbytes());
90+
} else {
91+
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy_", CTYPE, [&]() {
92+
utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
93+
[](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
94+
ctx,
95+
in,
96+
utils::SupportedTensorDtypes::REALHBBF16,
97+
src,
98+
utils::SupportedTensorDtypes::REALHBBF16,
99+
in,
100+
utils::SupportedTensorDtypes::REALHBBF16);
101+
});
102+
}
93103

94104
return in;
95105
}

0 commit comments

Comments
 (0)