Commit ac715f2

[Executorch] Refactor op_add to support op_sub broadcasting

Summary: Refactor op_add to consolidate common broadcasting-related improvements so they can be shared with op_sub.

Test Plan: Previously added tests.

Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: b0baeec
Pull Request resolved: #8255
1 parent fae9914 commit ac715f2
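
The diff below replaces opt_add_out's hand-rolled broadcasting paths with a call to a shared add/sub implementation, opt_add_sub_out_impl<false, op_name>. The contents of op_add_sub_impl.h are not part of the excerpt shown here, so the following standalone C++ sketch only illustrates the consolidation pattern the summary describes: one templated implementation serving both add and sub, with a compile-time bool flipping the sign of the alpha term. Every name in it is illustrative, not ExecuTorch code.

// Standalone illustration of the consolidation pattern (not ExecuTorch code):
// a single templated implementation serves both add and sub, mirroring how
// this commit routes opt_add_out through opt_add_sub_out_impl<false, op_name>.
#include <cstddef>
#include <cstdio>

template <bool is_sub, typename T>
void add_sub_impl(const T* a, const T* b, T alpha, T* out, std::size_t n) {
  // out[i] = a[i] + alpha * b[i] for add; out[i] = a[i] - alpha * b[i] for sub.
  const T sign = is_sub ? T(-1) : T(1);
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = a[i] + sign * alpha * b[i];
  }
}

int main() {
  float a[3] = {1.f, 2.f, 3.f};
  float b[3] = {10.f, 20.f, 30.f};
  float out[3];
  add_sub_impl<false>(a, b, 2.f, out, 3); // add path: 21 42 63
  std::printf("%g %g %g\n", out[0], out[1], out[2]);
  add_sub_impl<true>(a, b, 2.f, out, 3); // sub path: -19 -38 -57
  std::printf("%g %g %g\n", out[0], out[1], out[2]);
  return 0;
}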

File tree

3 files changed: +235 −147 lines changed


kernels/optimized/cpu/op_add.cpp

Lines changed: 8 additions & 147 deletions
@@ -14,59 +14,11 @@
 #include <executorch/runtime/kernel/kernel_includes.h>
 #include <executorch/runtime/platform/assert.h>
 
+#include <executorch/kernels/optimized/cpu/op_add_sub_impl.h>
+
 namespace torch {
 namespace executor {
 namespace native {
-namespace {
-
-template <
-    bool can_cast,
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct AddInner;
-
-template <
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct AddInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
-  static void
-  run(const Tensor& a, const Tensor& b, CTYPE_IN alpha_val, Tensor& out) {
-    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-        // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
-        [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
-          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-          CTYPE_IN value = a_casted + alpha_val * b_casted;
-
-          return static_cast<CTYPE_OUT>(value);
-        },
-        a,
-        b,
-        out);
-  }
-};
-
-template <typename CTYPE_IN>
-struct ReportCanCastBug {
-  static void run(const Tensor&, const Tensor&, CTYPE_IN, Tensor&) {
-    ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
-  }
-};
-
-template <
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct AddInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
-    : public ReportCanCastBug<CTYPE_IN> {};
-
-} // namespace
-
 using Tensor = executorch::aten::Tensor;
 using ScalarType = executorch::aten::ScalarType;
 
@@ -76,8 +28,6 @@ Tensor& opt_add_out(
     const Tensor& b,
     const Scalar& alpha,
     Tensor& out) {
-  (void)ctx;
-
   ScalarType a_type = a.scalar_type();
   ScalarType b_type = b.scalar_type();
   ScalarType out_type = out.scalar_type();
@@ -95,7 +45,9 @@ Tensor& opt_add_out(
     ET_SWITCH_REALB_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() {
       CTYPE alpha_val;
       ET_KERNEL_CHECK(
-          ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
+          ctx,
+          torch::executor::native::utils::extract_scalar(alpha, &alpha_val),
+          InvalidArgument, );
       CTYPE_B b_val = *b.const_data_ptr<CTYPE_B>();
       CTYPE b_casted = static_cast<CTYPE>(b_val);
 
@@ -115,100 +67,9 @@ Tensor& opt_add_out(
     return opt_add_out(ctx, b, a, alpha, out);
   }
 
-  auto selected_optimized_path = select_optimized_path(a, b, out);
-  if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) {
-    // Resize for dynamic shape
-    auto error = resize_tensor(out, a.sizes());
-    ET_KERNEL_CHECK_MSG(
-        ctx,
-        error == Error::Ok,
-        InvalidArgument,
-        out,
-        "Failed to resize output tensor.");
-
-    ET_SWITCH_REALB_TYPES(a_type, ctx, "add.out", CTYPE, [&]() {
-      CTYPE alpha_val;
-      ET_KERNEL_CHECK(
-          ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
-
-      using Vec = executorch::vec::Vectorized<CTYPE>;
-      executorch::vec::map2<CTYPE>(
-          [alpha_val](Vec x, Vec y) { return x + Vec(alpha_val) * y; },
-          out.mutable_data_ptr<CTYPE>(),
-          a.const_data_ptr<CTYPE>(),
-          b.const_data_ptr<CTYPE>(),
-          out.numel());
-    });
-  } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
-    ET_SWITCH_REALB_TYPES(out_type, ctx, "add.out", CTYPE, [&]() {
-      CTYPE alpha_val;
-      ET_KERNEL_CHECK_MSG(
-          ctx,
-          utils::extract_scalar(alpha, &alpha_val),
-          InvalidArgument,
-          out,
-          "Failed to extract scalar alpha.");
-      using Vec = executorch::vec::Vectorized<CTYPE>;
-      Vec alpha_val_vec(alpha_val);
-      if (selected_optimized_path ==
-              ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
-          selected_optimized_path ==
-              ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ||
-          selected_optimized_path ==
-              ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
-        // Reason we swap out args here is because handle_broadcast_elementwise
-        // handles this selected_optimized_path option a bit differently.
-        // This should really be resolved in handle_broadcast_elementwise.
-        // However, the current blocker is that handle_broadcast_elementwise
-        // tries to be agnostic of op. This should be fixed, likely by moving
-        // lambda creation to handle_broadcast_elementwise and it be aware of
-        // which op is being executed.
-        auto add_lambda = [&alpha_val_vec](auto x, auto y) {
-          return y + alpha_val_vec * x;
-        };
-        return torch::executor::handle_broadcast_elementwise<CTYPE>(
-            ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
-      } else {
-        auto add_lambda = [&alpha_val_vec](auto x, auto y) {
-          return x + alpha_val_vec * y;
-        };
-        return torch::executor::handle_broadcast_elementwise<CTYPE>(
-            ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
-      }
-    });
-  } else {
-    ScalarType common_type =
-        promoteTypes(a_type, b_type, /*half_to_float*/ true);
-    ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
-
-    ET_KERNEL_CHECK(
-        ctx,
-        resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-        InvalidArgument,
-        out);
-
-    ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "add.out", CTYPE_A, [&]() {
-      ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() {
-        using CTYPE_IN = typename torch::executor::
-            promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
-        ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
-        ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "add.out", CTYPE_OUT, [&]() {
-          CTYPE_IN alpha_val;
-          ET_KERNEL_CHECK(
-              ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
-
-          AddInner<
-              can_cast<CTYPE_IN, CTYPE_OUT>::value,
-              CTYPE_A,
-              CTYPE_B,
-              CTYPE_IN,
-              CTYPE_OUT>::run(a, b, alpha_val, out);
-        });
-      });
-    });
-  }
-
-  return out;
+  static constexpr const char op_name[] = "add.out";
+  return torch::executor::kernels::impl::opt_add_sub_out_impl<false, op_name>(
+      ctx, a, b, alpha, out);
 }
 
 Tensor& opt_add_scalar_out(
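
The commit title says this refactor exists to support op_sub broadcasting. The op_sub.cpp change is one of the other two files in this commit and is not shown above; assuming it mirrors the new opt_add_out body, its call site would plausibly look like the following, with the boolean template argument set to true to select subtraction. This is a hypothetical reconstruction, not this commit's actual op_sub diff.

  // Hypothetical mirrored call site for op_sub; inferred from the new
  // opt_add_out body above, not copied from this commit.
  static constexpr const char op_name[] = "sub.out";
  return torch::executor::kernels::impl::opt_add_sub_out_impl<true, op_name>(
      ctx, a, b, alpha, out);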
