Skip to content

Commit 1083643

Browse files
manuelcandales and facebook-github-bot
authored and committed
Reduce build size of op_copy (#6019)
Summary: Pull Request resolved: #6019. Build size: 133 K -> 15 K. ghstack-source-id: 246985122. exported-using-ghexport. Reviewed By: malfet, swolchok. Differential Revision: D63994873. fbshipit-source-id: c2cb4381e68934eebd078dd3e8f72bbb9c6320e6
1 parent e0c26dd commit 1083643

File tree

2 files changed

+29
-26
lines changed

2 files changed

+29
-26
lines changed

kernels/portable/cpu/op_copy.cpp

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <cstring>
1010

1111
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
12+
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
1213
#include <executorch/runtime/kernel/kernel_includes.h>
1314

1415
namespace torch {
@@ -42,19 +43,19 @@ Tensor& copy_out(
4243
ET_KERNEL_CHECK(
4344
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
4445

45-
ScalarType in_type = in.scalar_type();
46-
ScalarType src_type = src.scalar_type();
47-
48-
ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "copy.out", CTYPE, [&]() {
49-
ET_SWITCH_REALHBBF16_TYPES(src_type, ctx, "copy.out", CTYPE_SRC, [&]() {
50-
apply_binary_elementwise_fn<CTYPE, CTYPE_SRC, CTYPE>(
51-
[](const CTYPE val_in, const CTYPE_SRC val_src) {
52-
return convert<CTYPE, CTYPE_SRC>(val_src);
53-
},
54-
in,
55-
src,
56-
out);
57-
});
46+
// @lint-ignore CLANGTIDY facebook-hte-CArray
47+
static constexpr const char op_name[] = "copy.out";
48+
49+
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy.out", CTYPE, [&]() {
50+
utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
51+
[](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
52+
ctx,
53+
in,
54+
utils::SupportedTensorDtypes::REALHBBF16,
55+
src,
56+
utils::SupportedTensorDtypes::REALHBBF16,
57+
out,
58+
utils::SupportedTensorDtypes::REALHBBF16);
5859
});
5960

6061
return out;
@@ -75,19 +76,19 @@ Tensor& copy_(
7576
ET_KERNEL_CHECK(
7677
ctx, tensors_have_same_dim_order(in, src), InvalidArgument, in);
7778

78-
ScalarType in_type = in.scalar_type();
79-
ScalarType src_type = src.scalar_type();
80-
81-
ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "copy_", CTYPE, [&]() {
82-
ET_SWITCH_REALHBBF16_TYPES(src_type, ctx, "copy_", CTYPE_SRC, [&]() {
83-
apply_binary_elementwise_fn<CTYPE, CTYPE_SRC, CTYPE>(
84-
[](const CTYPE val_in, const CTYPE_SRC val_src) {
85-
return convert<CTYPE, CTYPE_SRC>(val_src);
86-
},
87-
in,
88-
src,
89-
in);
90-
});
79+
// @lint-ignore CLANGTIDY facebook-hte-CArray
80+
static constexpr const char op_name[] = "copy_";
81+
82+
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy_", CTYPE, [&]() {
83+
utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
84+
[](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
85+
ctx,
86+
in,
87+
utils::SupportedTensorDtypes::REALHBBF16,
88+
src,
89+
utils::SupportedTensorDtypes::REALHBBF16,
90+
in,
91+
utils::SupportedTensorDtypes::REALHBBF16);
9192
});
9293

9394
return in;

shim/xplat/executorch/kernels/portable/op_registration_util.bzl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,8 @@ ATEN_OPS = (
425425
name = "op_copy",
426426
deps = [
427427
"//executorch/kernels/portable/cpu/util:broadcast_util",
428+
"//executorch/kernels/portable/cpu/util:dtype_util",
429+
"//executorch/kernels/portable/cpu/util:elementwise_util",
428430
"//executorch/runtime/core/exec_aten/util:scalar_type_util",
429431
"//executorch/runtime/core/exec_aten/util:tensor_util",
430432
":scalar_utils",

0 commit comments

Comments
 (0)