Commit e108f50

[Executorch] Refactor op_add to support op_sub broadcasting
Summary: Refactor op_add to consolidate common broadcasting-related improvements.
Test Plan: Previously added tests
Reviewers:
Subscribers:
Tasks:
Tags:
ghstack-source-id: d75b959
Pull Request resolved: #8255
1 parent dc3d7fe commit e108f50

File tree

4 files changed: +251 additions, -136 deletions


kernels/optimized/cpu/op_add.cpp

Lines changed: 7 additions & 136 deletions

@@ -14,59 +14,11 @@
 #include <executorch/runtime/kernel/kernel_includes.h>
 #include <executorch/runtime/platform/assert.h>
 
+#include <executorch/kernels/optimized/cpu/op_add_sub_impl.h>
+
 namespace torch {
 namespace executor {
 namespace native {
-namespace {
-
-template <
-    bool can_cast,
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct AddInner;
-
-template <
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct AddInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
-  static void
-  run(const Tensor& a, const Tensor& b, CTYPE_IN alpha_val, Tensor& out) {
-    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-        // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
-        [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
-          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-          CTYPE_IN value = a_casted + alpha_val * b_casted;
-
-          return static_cast<CTYPE_OUT>(value);
-        },
-        a,
-        b,
-        out);
-  }
-};
-
-template <typename CTYPE_IN>
-struct ReportCanCastBug {
-  static void run(const Tensor&, const Tensor&, CTYPE_IN, Tensor&) {
-    ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
-  }
-};
-
-template <
-    typename CTYPE_A,
-    typename CTYPE_B,
-    typename CTYPE_IN,
-    typename CTYPE_OUT>
-struct AddInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
-    : public ReportCanCastBug<CTYPE_IN> {};
-
-} // namespace
-
 using Tensor = executorch::aten::Tensor;
 using ScalarType = executorch::aten::ScalarType;
 
@@ -76,8 +28,6 @@ Tensor& opt_add_out(
     const Tensor& b,
     const Scalar& alpha,
     Tensor& out) {
-  (void)ctx;
-
   ScalarType a_type = a.scalar_type();
   ScalarType b_type = b.scalar_type();
   ScalarType out_type = out.scalar_type();
@@ -95,7 +45,9 @@ Tensor& opt_add_out(
        ET_SWITCH_REALB_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() {
          CTYPE alpha_val;
          ET_KERNEL_CHECK(
-              ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
+              ctx,
+              torch::executor::native::utils::extract_scalar(alpha, &alpha_val),
+              InvalidArgument, );
          CTYPE_B b_val = *b.const_data_ptr<CTYPE_B>();
          CTYPE b_casted = static_cast<CTYPE>(b_val);
 
@@ -115,89 +67,8 @@ Tensor& opt_add_out(
    return opt_add_out(ctx, b, a, alpha, out);
  }
 
-  auto selected_optimized_path = select_optimized_path(a, b, out);
-  if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) {
-    // Resize for dynamic shape
-    auto error = resize_tensor(out, a.sizes());
-    ET_KERNEL_CHECK_MSG(
-        ctx,
-        error == Error::Ok,
-        InvalidArgument,
-        out,
-        "Failed to resize output tensor.");
-
-    ET_SWITCH_REALB_TYPES(a_type, ctx, "add.out", CTYPE, [&]() {
-      CTYPE alpha_val;
-      ET_KERNEL_CHECK(
-          ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
-
-      using Vec = executorch::vec::Vectorized<CTYPE>;
-      executorch::vec::map2<CTYPE>(
-          [alpha_val](Vec x, Vec y) { return x + Vec(alpha_val) * y; },
-          out.mutable_data_ptr<CTYPE>(),
-          a.const_data_ptr<CTYPE>(),
-          b.const_data_ptr<CTYPE>(),
-          out.numel());
-    });
-  } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
-    if (selected_optimized_path ==
-            ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
-        selected_optimized_path ==
-            ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ||
-        selected_optimized_path ==
-            ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
-      // This behavior is a bit confusing.
-      // The reason we swap the args here is that handle_broadcast_elementwise
-      // handles this selected_optimized_path option a bit differently.
-      // This should really be resolved in handle_broadcast_elementwise.
-      // However, the current blocker is that handle_broadcast_elementwise
-      // tries to be agnostic of the op. This should be fixed, likely by moving
-      // lambda creation into handle_broadcast_elementwise and making it aware
-      // of which op is being executed.
-      auto add_lambda = [](auto x, auto y, auto alpha_val) {
-        return y + alpha_val * x;
-      };
-      return torch::executor::handle_broadcast_elementwise<BinaryOpType::kAdd>(
-          ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
-    } else {
-      auto add_lambda = [](auto x, auto y, auto alpha_val) {
-        return x + alpha_val * y;
-      };
-      return torch::executor::handle_broadcast_elementwise<BinaryOpType::kAdd>(
-          ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
-    }
-  } else {
-    ScalarType common_type =
-        promoteTypes(a_type, b_type, /*half_to_float*/ true);
-    ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
-
-    ET_KERNEL_CHECK(
-        ctx,
-        resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-        InvalidArgument,
-        out);
-
-    ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "add.out", CTYPE_A, [&]() {
-      ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() {
-        using CTYPE_IN = typename torch::executor::
-            promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
-        ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
-        ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "add.out", CTYPE_OUT, [&]() {
-          CTYPE_IN alpha_val;
-          ET_KERNEL_CHECK(
-              ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
-
-          AddInner<
-              can_cast<CTYPE_IN, CTYPE_OUT>::value,
-              CTYPE_A,
-              CTYPE_B,
-              CTYPE_IN,
-              CTYPE_OUT>::run(a, b, alpha_val, out);
-        });
-      });
-    });
-  }
-
-  return out;
+  return torch::executor::kernels::impl::opt_add_sub_out_impl(
+      ctx, a, b, alpha, out);
 }
 
 Tensor& opt_add_scalar_out(
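
The shared implementation that follows relies on a simple identity on the paths where alpha has already been extracted into a concrete C type: a - alpha * b equals a + (-alpha) * b, so subtraction can reuse the addition kernels by negating alpha. A minimal standalone illustration of that identity (not part of this commit):

// Illustration only: subtraction expressed as addition with a negated alpha,
// which is the trick the shared add/sub kernel uses once alpha is a C value.
#include <cassert>

int main() {
  float a = 5.0f, b = 2.0f, alpha = 3.0f;
  float direct_sub = a - alpha * b;       // 5 - 3*2 = -1
  float add_neg_alpha = a + (-alpha) * b; // 5 + (-3)*2 = -1
  assert(direct_sub == add_neg_alpha);
  return 0;
}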
Lines changed: 200 additions & 0 deletions (new file)

@@ -0,0 +1,200 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/kernels/optimized/cpu/binary_ops.h>
+#include <executorch/kernels/optimized/vec/functional.h>
+#include <executorch/kernels/optimized/vec/vec.h>
+#include <executorch/kernels/portable/cpu/scalar_utils.h>
+#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
+#include <executorch/runtime/kernel/kernel_includes.h>
+#include <executorch/runtime/platform/assert.h>
+
+namespace torch {
+namespace executor {
+namespace kernels {
+namespace impl {
+
+namespace {
+template <
+    bool can_cast,
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_IN,
+    typename CTYPE_OUT>
+struct AddInner;
+
+template <
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_IN,
+    typename CTYPE_OUT>
+struct AddInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
+  static void
+  run(const Tensor& a, const Tensor& b, CTYPE_IN alpha_val, Tensor& out) {
+    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
+        // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
+        [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
+          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
+          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
+          CTYPE_IN value = a_casted + alpha_val * b_casted;
+
+          return static_cast<CTYPE_OUT>(value);
+        },
+        a,
+        b,
+        out);
+  }
+};
+
+template <typename CTYPE_IN>
+struct ReportCanCastBug {
+  static void run(const Tensor&, const Tensor&, CTYPE_IN, Tensor&) {
+    ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
+  }
+};
+
+template <
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_IN,
+    typename CTYPE_OUT>
+struct AddInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
+    : public ReportCanCastBug<CTYPE_IN> {};
+
+} // namespace
+
+using Tensor = executorch::aten::Tensor;
+using ScalarType = executorch::aten::ScalarType;
+
+Tensor& opt_add_sub_out_impl(
+    KernelRuntimeContext& ctx,
+    const Tensor& a,
+    const Tensor& b,
+    const Scalar& alpha,
+    Tensor& out,
+    const bool is_sub) {
+  (void)ctx;
+
+  ScalarType a_type = a.scalar_type();
+  ScalarType b_type = b.scalar_type();
+  ScalarType out_type = out.scalar_type();
+
+  auto selected_optimized_path = select_optimized_path(a, b, out);
+  if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) {
+    // Resize for dynamic shape
+    auto error = resize_tensor(out, a.sizes());
+    ET_KERNEL_CHECK_MSG(
+        ctx,
+        error == Error::Ok,
+        InvalidArgument,
+        out,
+        "Failed to resize output tensor.");
+
+    ET_SWITCH_REALB_TYPES(a_type, ctx, "add.out", CTYPE, [&]() {
+      CTYPE alpha_val;
+      ET_KERNEL_CHECK(
+          ctx,
+          torch::executor::native::utils::extract_scalar(alpha, &alpha_val),
+          InvalidArgument, );
+      if (is_sub) {
+        alpha_val = -alpha_val;
+      }
+      using Vec = executorch::vec::Vectorized<CTYPE>;
+      executorch::vec::map2<CTYPE>(
+          [alpha_val](Vec x, Vec y) { return x + Vec(alpha_val) * y; },
+          out.mutable_data_ptr<CTYPE>(),
+          a.const_data_ptr<CTYPE>(),
+          b.const_data_ptr<CTYPE>(),
+          out.numel());
+    });
+  } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
+    // Cannot apply the trick of -alpha here because alpha is Scalar without
+    // support for - operator. At least not right now.
+    if (is_sub) {
+      if (selected_optimized_path ==
+              ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
+          selected_optimized_path ==
+              ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ||
+          selected_optimized_path ==
+              ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
+        auto add_lambda = [](auto x, auto y, auto alpha_val) {
+          return y - alpha_val * x;
+        };
+        return torch::executor::handle_broadcast_elementwise<BinaryOpType::kSub>(
+            ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
+      } else {
+        auto add_lambda = [](auto x, auto y, auto alpha_val) {
+          return x - alpha_val * y;
+        };
+        return torch::executor::handle_broadcast_elementwise<BinaryOpType::kSub>(
+            ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
+      }
+    } else {
+      if (selected_optimized_path ==
+              ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments ||
+          selected_optimized_path ==
+              ElementwiseOptimizedPath::kBroadcastLastDimReverseArguments ||
+          selected_optimized_path ==
+              ElementwiseOptimizedPath::kBroadcastNdByNdReverseArguments) {
+        auto add_lambda = [](auto x, auto y, auto alpha_val) {
+          return y + alpha_val * x;
+        };
+        return torch::executor::handle_broadcast_elementwise<BinaryOpType::kAdd>(
+            ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
+      } else {
+        auto add_lambda = [](auto x, auto y, auto alpha_val) {
+          return x + alpha_val * y;
+        };
+        return torch::executor::handle_broadcast_elementwise<BinaryOpType::kAdd>(
+            ctx, add_lambda, a, b, out, selected_optimized_path, alpha);
+      }
+    }
+  } else {
+    ScalarType common_type =
+        promoteTypes(a_type, b_type, /*half_to_float*/ true);
+    ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
+
+    ET_KERNEL_CHECK(
+        ctx,
+        resize_to_broadcast_target_size(a, b, out) == Error::Ok,
+        InvalidArgument,
+        out);
+
+    ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "add.out", CTYPE_A, [&]() {
+      ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() {
+        using CTYPE_IN = typename torch::executor::
+            promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
+        ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
+        ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "add.out", CTYPE_OUT, [&]() {
+          CTYPE_IN alpha_val;
+          ET_KERNEL_CHECK(
+              ctx,
+              torch::executor::native::utils::extract_scalar(alpha, &alpha_val),
+              InvalidArgument, );
+          if (is_sub) {
+            alpha_val = -alpha_val;
+          }
+
+          AddInner<
+              can_cast<CTYPE_IN, CTYPE_OUT>::value,
+              CTYPE_A,
+              CTYPE_B,
+              CTYPE_IN,
+              CTYPE_OUT>::run(a, b, alpha_val, out);
+        });
+      });
+    });
+  }
+
+  return out;
+}
+
+} // namespace impl
+} // namespace kernels
+} // namespace executor
+} // namespace torch
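
The commit description says the point of this refactor is to let op_sub reuse op_add's broadcasting path. This capture shows only two of the four changed files; the new file above appears to be the op_add_sub_impl.h header included from op_add.cpp, and the five-argument call in op_add.cpp suggests its declaration gives is_sub a default of false, though that declaration is not shown here. Given the signature of opt_add_sub_out_impl, an op_sub out-kernel would presumably just forward with is_sub = true. The sketch below is an assumption based on that signature, not code from this diff; opt_sub_out as written here is hypothetical.

// Hypothetical sketch (not part of this diff): an op_sub out-kernel delegating
// to the shared add/sub implementation added above, assuming the same
// argument order as opt_add_out.
Tensor& opt_sub_out(
    KernelRuntimeContext& ctx,
    const Tensor& a,
    const Tensor& b,
    const Scalar& alpha,
    Tensor& out) {
  // is_sub = true makes the shared impl negate alpha on the non-broadcast
  // paths and select the subtraction lambdas on the broadcast paths.
  return torch::executor::kernels::impl::opt_add_sub_out_impl(
      ctx, a, b, alpha, out, /*is_sub=*/true);
}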
