Skip to content

Commit 858e9fd

Browse files
authored
[executorch] Propagate mul optimizations from D61504544/D61560825/D61560826 to add/sub/div
Differential Revision: D61577411 Pull Request resolved: #4816
1 parent 8e3361e commit 858e9fd

File tree

9 files changed

+465
-101
lines changed

9 files changed

+465
-101
lines changed

kernels/optimized/cpu/binary_ops.h

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/runtime/kernel/kernel_includes.h>
12+
13+
namespace torch {
14+
namespace executor {
15+
namespace internal {
16+
// NOTE: we bake ArrayRef iterators being pointers into the return
17+
// type here because we assume that iterators are portable across
18+
// ArrayRef copies.
19+
// Returns a pointer to the first size entry that is not 1, i.e. skips
// any leading broadcast-trivial dimensions. If every entry is 1 (or the
// array is empty), returns arr.end().
inline const Tensor::SizesType* arrayref_begin_ignoring_leading_1s(
    ArrayRef<Tensor::SizesType> arr) {
  const Tensor::SizesType* pos = arr.begin();
  while (pos != arr.end() && *pos == 1) {
    ++pos;
  }
  return pos;
}
24+
25+
// Returns true when the two size lists are identical once any leading
// 1-sized dimensions are stripped from each (e.g. [1, 3, 4] vs [3, 4]).
inline bool sizes_match_ignoring_leading_1s(
    ArrayRef<Tensor::SizesType> lhs,
    ArrayRef<Tensor::SizesType> rhs) {
  const auto* lhs_it = arrayref_begin_ignoring_leading_1s(lhs);
  const auto* rhs_it = arrayref_begin_ignoring_leading_1s(rhs);

  // Trailing ranks must agree before comparing element-wise.
  if ((lhs.end() - lhs_it) != (rhs.end() - rhs_it)) {
    return false;
  }
  return std::equal(lhs_it, lhs.end(), rhs_it);
}
37+
} // namespace internal
38+
39+
// Which vectorized fast path (if any) a binary elementwise op may take.
enum class ElementwiseOptimizedPath {
  // No optimized path applies; caller falls back to the generic path.
  kNone,
  // Operands are shape-compatible enough to process as one flat 1-D pass.
  kTreatAs1d,
  // First operand is (ignoring leading 1s) 2-D and the second is 1-D with
  // a length matching the first's last dimension; broadcast row-wise.
  kBroadcast2dBy1d,
  // Same broadcast as kBroadcast2dBy1d, but the 1-D tensor is the first
  // argument, so the operands must be swapped before dispatch.
  kBroadcast2dBy1dReverseArguments,
};
45+
46+
namespace internal {
47+
inline ElementwiseOptimizedPath select_broadcast_2d_by_1d_optimized_path(
48+
const Tensor& lhs,
49+
const Tensor& rhs) {
50+
auto lhs_begin = arrayref_begin_ignoring_leading_1s(lhs.sizes());
51+
auto lhs_end = lhs.sizes().end();
52+
53+
auto rhs_begin = arrayref_begin_ignoring_leading_1s(rhs.sizes());
54+
auto rhs_end = rhs.sizes().end();
55+
56+
const auto lhs_size = lhs_end - lhs_begin;
57+
const auto rhs_size = rhs_end - rhs_begin;
58+
if (lhs_size == 2 && rhs_size == 1 && lhs_begin[1] == rhs_begin[0]) {
59+
return ElementwiseOptimizedPath::kBroadcast2dBy1d;
60+
}
61+
62+
if (lhs_size == 1 && rhs_size == 2 && rhs_begin[1] == lhs_begin[0]) {
63+
return ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments;
64+
}
65+
66+
return ElementwiseOptimizedPath::kNone;
67+
}
68+
} // namespace internal
69+
70+
// Picks the fastest applicable elementwise path for (a op b) -> out.
// Requires all three dtypes to match and to not be Half; otherwise the
// generic (non-optimized) path is used.
inline ElementwiseOptimizedPath select_optimized_path(
    const Tensor& a,
    const Tensor& b,
    const Tensor& out) {
  const ScalarType a_type = a.scalar_type();
  const ScalarType b_type = b.scalar_type();
  const ScalarType out_type = out.scalar_type();

  const bool dtypes_allow_optimization =
      a_type == b_type && a_type == out_type && a_type != ScalarType::Half;
  if (!dtypes_allow_optimization) {
    return ElementwiseOptimizedPath::kNone;
  }

  // Identical shapes — or equal element counts with shapes that differ
  // only by leading 1s (or an out that matches element-wise) — can be
  // handled as a single contiguous 1-D traversal.
  const bool same_shape = a.sizes().equals(b.sizes());
  const bool flat_compatible = a.numel() == b.numel() &&
      (a.numel() == out.numel() ||
       internal::sizes_match_ignoring_leading_1s(a.sizes(), b.sizes()));
  if (same_shape || flat_compatible) {
    return ElementwiseOptimizedPath::kTreatAs1d;
  }

  // Otherwise, try the 2-D-by-1-D broadcasting fast path.
  return internal::select_broadcast_2d_by_1d_optimized_path(a, b);
}
89+
90+
} // namespace executor
91+
} // namespace torch

kernels/optimized/cpu/op_add.cpp

Lines changed: 72 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
#include <executorch/kernels/optimized/cpu/binary_ops.h>
910
#include <executorch/kernels/optimized/vec/functional.h>
1011
#include <executorch/kernels/optimized/vec/vec.h>
1112
#include <executorch/kernels/portable/cpu/scalar_utils.h>
@@ -81,8 +82,41 @@ Tensor& opt_add_out(
8182
ScalarType b_type = b.scalar_type();
8283
ScalarType out_type = out.scalar_type();
8384

84-
if (a_type == b_type && a_type == out_type && a.sizes().equals(b.sizes()) &&
85-
a_type != ScalarType::Half) {
85+
if (b.numel() == 1) {
86+
if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half) {
87+
auto error = resize_tensor(out, a.sizes());
88+
ET_KERNEL_CHECK_MSG(
89+
ctx,
90+
error == Error::Ok,
91+
InvalidArgument,
92+
out,
93+
"Failed to resize output tensor.");
94+
ET_SWITCH_REALB_TYPES(a_type, ctx, "add.out", CTYPE, [&]() {
95+
ET_SWITCH_REALB_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() {
96+
CTYPE alpha_val;
97+
ET_KERNEL_CHECK(
98+
ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
99+
CTYPE_B b_val = *b.const_data_ptr<CTYPE_B>();
100+
CTYPE b_casted = static_cast<CTYPE>(b_val);
101+
102+
using Vec = executorch::vec::Vectorized<CTYPE>;
103+
executorch::vec::map<CTYPE>(
104+
[alpha_val, b_casted](Vec x) {
105+
return x + Vec(alpha_val * b_casted);
106+
},
107+
out.mutable_data_ptr<CTYPE>(),
108+
a.const_data_ptr<CTYPE>(),
109+
out.numel());
110+
});
111+
});
112+
return out;
113+
}
114+
} else if (a.numel() == 1) {
115+
return opt_add_out(ctx, b, a, alpha, out);
116+
}
117+
118+
auto selected_optimized_path = select_optimized_path(a, b, out);
119+
if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) {
86120
// Resize for dynamic shape
87121
auto error = resize_tensor(out, a.sizes());
88122
ET_KERNEL_CHECK_MSG(
@@ -105,6 +139,42 @@ Tensor& opt_add_out(
105139
b.const_data_ptr<CTYPE>(),
106140
out.numel());
107141
});
142+
} else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
143+
const Tensor* lhs;
144+
const Tensor* rhs;
145+
if (selected_optimized_path ==
146+
ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
147+
lhs = &b;
148+
rhs = &a;
149+
} else {
150+
// Catch failure to update logic when adding new broadcasting possibility.
151+
ET_DCHECK(
152+
selected_optimized_path ==
153+
ElementwiseOptimizedPath::kBroadcast2dBy1d);
154+
lhs = &a;
155+
rhs = &b;
156+
}
157+
auto error = resize_tensor(out, lhs->sizes());
158+
ET_KERNEL_CHECK_MSG(
159+
ctx,
160+
error == Error::Ok,
161+
InvalidArgument,
162+
out,
163+
"Failed to resize output tensor.");
164+
ET_SWITCH_REALB_TYPES(out_type, ctx, "add.out", CTYPE, [&]() {
165+
CTYPE alpha_val;
166+
ET_KERNEL_CHECK(
167+
ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
168+
169+
using Vec = executorch::vec::Vectorized<CTYPE>;
170+
executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
171+
[alpha_val](Vec x, Vec y) { return x + Vec(alpha_val) * y; },
172+
out.mutable_data_ptr<CTYPE>(),
173+
lhs->const_data_ptr<CTYPE>(),
174+
rhs->const_data_ptr<CTYPE>(),
175+
lhs->sizes()[lhs->dim() - 2],
176+
lhs->sizes()[lhs->dim() - 1]);
177+
});
108178
} else {
109179
ScalarType common_type =
110180
promoteTypes(a_type, b_type, /*half_to_float*/ true);

kernels/optimized/cpu/op_div.cpp

Lines changed: 112 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
#include <executorch/kernels/optimized/cpu/binary_ops.h>
910
#include <executorch/kernels/optimized/vec/functional.h>
1011
#include <executorch/kernels/optimized/vec/vec.h>
1112
#include <executorch/kernels/portable/cpu/scalar_utils.h>
@@ -48,7 +49,57 @@ Tensor& opt_div_out(
4849
ScalarType b_type = b.scalar_type();
4950
ScalarType out_type = out.scalar_type();
5051

51-
if (a_type == b_type && a_type == out_type && a.sizes().equals(b.sizes())) {
52+
if (a.numel() == 1 || b.numel() == 1) {
53+
if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half) {
54+
const Tensor* tensor;
55+
const Tensor* scalar;
56+
ScalarType tensor_type;
57+
ScalarType scalar_type;
58+
if (a.numel() == 1) {
59+
tensor = &b;
60+
tensor_type = b_type;
61+
scalar = &a;
62+
scalar_type = a_type;
63+
} else {
64+
tensor = &a;
65+
tensor_type = a_type;
66+
scalar = &b;
67+
scalar_type = b_type;
68+
}
69+
auto error = resize_tensor(out, tensor->sizes());
70+
ET_KERNEL_CHECK_MSG(
71+
ctx,
72+
error == Error::Ok,
73+
InvalidArgument,
74+
out,
75+
"Failed to resize output tensor.");
76+
ET_SWITCH_REALB_TYPES(tensor_type, ctx, "div.out", CTYPE, [&]() {
77+
ET_SWITCH_REALB_TYPES(scalar_type, ctx, "div.out", CTYPE_SCALAR, [&]() {
78+
CTYPE_SCALAR scalar_val = *scalar->const_data_ptr<CTYPE_SCALAR>();
79+
CTYPE scalar_casted = static_cast<CTYPE>(scalar_val);
80+
81+
using Vec = executorch::vec::Vectorized<CTYPE>;
82+
if (a.numel() == 1) {
83+
executorch::vec::map<CTYPE>(
84+
[scalar_casted](Vec x) { return Vec(scalar_casted) / x; },
85+
out.mutable_data_ptr<CTYPE>(),
86+
tensor->const_data_ptr<CTYPE>(),
87+
out.numel());
88+
} else {
89+
executorch::vec::map<CTYPE>(
90+
[scalar_casted](Vec x) { return x / Vec(scalar_casted); },
91+
out.mutable_data_ptr<CTYPE>(),
92+
tensor->const_data_ptr<CTYPE>(),
93+
out.numel());
94+
}
95+
});
96+
});
97+
return out;
98+
}
99+
}
100+
101+
auto selected_optimized_path = select_optimized_path(a, b, out);
102+
if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) {
52103
// Resize for dynamic shape
53104
auto error = resize_tensor(out, a.sizes());
54105
ET_KERNEL_CHECK_MSG(
@@ -67,6 +118,49 @@ Tensor& opt_div_out(
67118
b.const_data_ptr<CTYPE>(),
68119
out.numel());
69120
});
121+
} else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
122+
const Tensor* lhs;
123+
const Tensor* rhs;
124+
if (selected_optimized_path ==
125+
ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
126+
lhs = &b;
127+
rhs = &a;
128+
} else {
129+
// Catch failure to update logic when subing new broadcasting possibility.
130+
ET_DCHECK(
131+
selected_optimized_path ==
132+
ElementwiseOptimizedPath::kBroadcast2dBy1d);
133+
lhs = &a;
134+
rhs = &b;
135+
}
136+
auto error = resize_tensor(out, lhs->sizes());
137+
ET_KERNEL_CHECK_MSG(
138+
ctx,
139+
error == Error::Ok,
140+
InvalidArgument,
141+
out,
142+
"Failed to resize output tensor.");
143+
ET_SWITCH_REALB_TYPES(out_type, ctx, "sub.out", CTYPE, [&]() {
144+
using Vec = executorch::vec::Vectorized<CTYPE>;
145+
if (selected_optimized_path ==
146+
ElementwiseOptimizedPath::kBroadcast2dBy1dReverseArguments) {
147+
executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
148+
[](Vec x, Vec y) { return y / x; },
149+
out.mutable_data_ptr<CTYPE>(),
150+
lhs->const_data_ptr<CTYPE>(),
151+
rhs->const_data_ptr<CTYPE>(),
152+
lhs->sizes()[lhs->dim() - 2],
153+
lhs->sizes()[lhs->dim() - 1]);
154+
} else {
155+
executorch::vec::broadcasting_map_2d_by_1d<CTYPE>(
156+
[](Vec x, Vec y) { return x / y; },
157+
out.mutable_data_ptr<CTYPE>(),
158+
lhs->const_data_ptr<CTYPE>(),
159+
rhs->const_data_ptr<CTYPE>(),
160+
lhs->sizes()[lhs->dim() - 2],
161+
lhs->sizes()[lhs->dim() - 1]);
162+
}
163+
});
70164
} else {
71165
ScalarType common_type = get_compute_type(a_type, b_type);
72166
ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
@@ -77,25 +171,23 @@ Tensor& opt_div_out(
77171
InvalidArgument,
78172
out);
79173

80-
ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "div.out", CTYPE_A, [&]() {
81-
ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "div.out", CTYPE_B, [&]() {
82-
ET_SWITCH_REAL_TYPES_AND(
83-
Bool, common_type, ctx, "div.out", CTYPE_IN, [&]() {
84-
ET_SWITCH_REAL_TYPES_AND(
85-
Bool, out_type, ctx, "div.out", CTYPE_OUT, [&]() {
86-
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
87-
[](const CTYPE_A val_a, const CTYPE_B val_b) {
88-
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
89-
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
90-
CTYPE_IN value = a_casted / b_casted;
91-
92-
return static_cast<CTYPE_OUT>(value);
93-
},
94-
a,
95-
b,
96-
out);
97-
});
98-
});
174+
ET_SWITCH_REALB_TYPES(a_type, ctx, "div.out", CTYPE_A, [&]() {
175+
ET_SWITCH_REALB_TYPES(b_type, ctx, "div.out", CTYPE_B, [&]() {
176+
ET_SWITCH_REALB_TYPES(common_type, ctx, "div.out", CTYPE_IN, [&]() {
177+
ET_SWITCH_REALB_TYPES(out_type, ctx, "div.out", CTYPE_OUT, [&]() {
178+
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
179+
[](const CTYPE_A val_a, const CTYPE_B val_b) {
180+
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
181+
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
182+
CTYPE_IN value = a_casted / b_casted;
183+
184+
return static_cast<CTYPE_OUT>(value);
185+
},
186+
a,
187+
b,
188+
out);
189+
});
190+
});
99191
});
100192
});
101193
}

0 commit comments

Comments
 (0)