Commit 89a0bdf
Make BroadcastIndexesRange efficient if there is no broadcasting (#9298)
It seems to be OK to just check for broadcasting when creating the iterator and then leave a highly predictable branch inside the loop. This avoids near-duplicating the code to handle the broadcast and non-broadcast cases separately.
1 parent c4c2aaf commit 89a0bdf
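To illustrate the pattern (a minimal standalone sketch with hypothetical names, not the ExecuTorch API): the broadcast check is paid once at construction, and the hot loop keeps a single branch on the cached result. Because the flag never changes during iteration, the branch predictor resolves it essentially for free, so the non-broadcast case runs at near fast-path speed without a duplicated loop body.

    #include <cstddef>
    #include <cstdio>

    // Hypothetical 1-D model of the commit's approach: decide once whether
    // the input broadcasts, then branch on that cached flag inside the loop.
    struct Index1D {
      std::size_t input_size;  // input length along the single dimension
      bool no_broadcast;       // decided once: input length == output length?

      Index1D(std::size_t in_len, std::size_t out_len)
          : input_size(in_len), no_broadcast(in_len == out_len) {}

      std::size_t input_index(std::size_t out_index) const {
        if (no_broadcast) {
          return out_index;  // highly predictable branch: identity indexing
        }
        return out_index % input_size;  // a size-1 input pins the index to 0
      }
    };

    int main() {
      const float a[1] = {5.f};  // size-1 input broadcast to a length-6 output
      float out[6];
      Index1D idx(1, 6);
      for (std::size_t i = 0; i < 6; ++i) {
        out[i] = a[idx.input_index(i)] + static_cast<float>(i);
      }
      for (float v : out) std::printf("%g ", v);  // prints: 5 6 7 8 9 10
      return 0;
    }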

4 files changed (+47, -100 lines)

kernels/optimized/cpu/op_where.cpp

Lines changed: 14 additions & 32 deletions
@@ -48,42 +48,24 @@ Tensor& opt_where_out(
       cond.scalar_type() == ScalarType::Bool) {
     auto out_numel = out.numel();
     ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-      const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
-      const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
-      const bool cond_is_broadcasted = !out.sizes().equals(cond.sizes());
-      const bool any_is_broadcasted =
-          (a_is_broadcasted || b_is_broadcasted || cond_is_broadcasted);
       const CTYPE_COMPUTE* const data_a = a.const_data_ptr<CTYPE_COMPUTE>();
       const CTYPE_COMPUTE* const data_b = b.const_data_ptr<CTYPE_COMPUTE>();
       const bool* const data_cond = cond.const_data_ptr<bool>();
       CTYPE_COMPUTE* const data_out = out.data_ptr<CTYPE_COMPUTE>();
-      if (any_is_broadcasted) {
-        executorch::extension::parallel_for(
-            0,
-            out_numel,
-            ::executorch::extension::internal::GRAIN_SIZE,
-            [&](const auto begin, const auto end) {
-              auto range = BroadcastIndexesRange<3>(out, a, b, cond);
-              auto begin_it = range.begin();
-              begin_it += begin;
-              for (; (*begin_it)[0] < end; ++begin_it) {
-                const auto [out_index, a_index, b_index, cond_index] =
-                    *begin_it;
-                data_out[out_index] =
-                    data_cond[cond_index] ? data_a[a_index] : data_b[b_index];
-              }
-            });
-      } else {
-        executorch::extension::parallel_for(
-            0,
-            out_numel,
-            ::executorch::extension::internal::GRAIN_SIZE,
-            [&](const auto begin, const auto end) {
-              for (const auto i : c10::irange(begin, end)) {
-                data_out[i] = data_cond[i] ? data_a[i] : data_b[i];
-              }
-            });
-      }
+      executorch::extension::parallel_for(
+          0,
+          out_numel,
+          ::executorch::extension::internal::GRAIN_SIZE,
+          [&](const auto begin, const auto end) {
+            auto range = BroadcastIndexesRange<3>(out, a, b, cond);
+            auto begin_it = range.begin();
+            begin_it += begin;
+            for (; (*begin_it)[0] < end; ++begin_it) {
+              const auto [out_index, a_index, b_index, cond_index] = *begin_it;
+              data_out[out_index] =
+                  data_cond[cond_index] ? data_a[a_index] : data_b[b_index];
+            }
+          });
     });
   } else {
     // Fall back for mixed dtype to keep code size and compile time
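A note on the loop bound above: entry 0 of the dereferenced iterator is the output's linear index, so each parallel_for worker advances the range iterator by `begin` and stops once that index reaches `end`, touching exactly its [begin, end) slice of the output. A minimal standalone sketch of the idiom (hypothetical types, not the ExecuTorch API):

    #include <cstddef>
    #include <vector>

    // Models just enough of BroadcastIndexesRange's iterator to show the
    // chunking idiom; out_index stands in for (*it)[0].
    struct ChunkIter {
      std::size_t out_index = 0;
      ChunkIter& operator+=(std::size_t n) { out_index += n; return *this; }
      ChunkIter& operator++() { ++out_index; return *this; }
    };

    // One parallel_for worker: writes exactly the output slice [begin, end),
    // so the workers' slices are disjoint and jointly cover the output.
    void worker(std::size_t begin, std::size_t end, std::vector<int>& out) {
      ChunkIter it;  // models range.begin()
      it += begin;   // jump to this chunk's start
      for (; it.out_index < end; ++it) {
        out[it.out_index] += 1;
      }
    }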

kernels/portable/cpu/util/broadcast_indexes_range.h

Lines changed: 17 additions & 9 deletions
@@ -34,14 +34,17 @@ class BroadcastIndexesIterator {
 
   template <typename... Args>
   explicit BroadcastIndexesIterator(const Tensor& output, const Args&... args)
-      : output_dim_(output.dim()),
-        output_shape_(output.sizes()),
-        effective_input_broadcast_strides_{
-            effective_input_broadcast_stride(output, args)...} {
+      : output_dim_or_zero_if_no_broadcasting_(
+            ((args.sizes() == output.sizes()) && ...) ? 0 : output.dim()),
+        output_shape_(output.sizes()) {
     static_assert(
         sizeof...(args) == kNumInputs && (std::is_same_v<Args, Tensor> && ...),
         "BroadcastIndexesIterator constructor requires kNumInputs input tensor"
         "arguments!");
+    if (output_dim_or_zero_if_no_broadcasting_ != 0) {
+      effective_input_broadcast_strides_ = {
+          effective_input_broadcast_stride(output, args)...};
+    }
   }
 
   struct make_end_t {
@@ -73,9 +76,14 @@ class BroadcastIndexesIterator
 
   BroadcastIndexesIterator& operator++() {
     output_index()++;
+    if (output_dim_or_zero_if_no_broadcasting_ == 0) {
+      std::fill(
+          current_indexes_.begin() + 1, current_indexes_.end(), output_index());
+      return *this;
+    }
     // TODO: add optimization for particular input tensors not being
     // broadcasted?
-    for (auto ii = output_dim_ - 1; ii >= 0; --ii) {
+    for (auto ii = output_dim_or_zero_if_no_broadcasting_ - 1; ii >= 0; --ii) {
       // You might wonder what happens if output_shape_[ii] == 0. In
       // that case, output.numel() would be 0, and thus we would have
       // begin() == end() and no iteration.
@@ -121,7 +129,8 @@ class BroadcastIndexesIterator
         delinearized_output_index_.size());
     for (const auto ii : c10::irange(1, kNumInputs + 1)) {
       current_indexes_[ii] = 0;
-      for (const auto jj : c10::irange(output_dim_)) {
+      for (const auto jj :
+           c10::irange(output_dim_or_zero_if_no_broadcasting_)) {
         current_indexes_[ii] += delinearized_output_index_[jj] *
             effective_input_broadcast_strides_[ii - 1][jj];
       }
@@ -180,7 +189,7 @@ class BroadcastIndexesIterator
   // followed by kNumInputs input indexes.
   std::array<ssize_t, kNumInputs + 1> current_indexes_ = {0};
   ShapeType delinearized_output_index_ = {0};
-  ssize_t output_dim_;
+  ssize_t output_dim_or_zero_if_no_broadcasting_;
   ArrayRef<exec_aten::SizesType> output_shape_;
   // The linear index for a broadcast tensor is
   // sum(delinearized_output_index_[i] * input_stride_[i] if
@@ -189,8 +198,7 @@ class BroadcastIndexesIterator
   // output_dim. This is straightforwardly implementable with an
   // adjusted stride array that contains 0s where the padded input
   // shape would contain 1s.
-  std::array<ShapeType, kNumInputs> effective_input_broadcast_strides_ = {
-      {{0}}};
+  std::array<ShapeType, kNumInputs> effective_input_broadcast_strides_;
 };
 } // namespace internal
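For orientation, a hedged usage sketch of the range this header defines (tensor setup and data pointers elided; the loop shape mirrors the call sites in broadcast_util.h below). Entry 0 of each yielded tuple is the output's linear index; the remaining entries index each input under its effective broadcast strides, which are 0 along broadcasted dimensions. When every input matches the output's shape, each tuple degenerates to (i, i, ..., i), which is exactly what the new std::fill fast path in operator++ produces.

    // Illustration only: `out`, `a`, `b`, `compute_fun`, and the data_*
    // pointers are assumed to be set up as in apply_binary_elementwise_fn.
    for (const auto [out_index, a_index, b_index] :
         BroadcastIndexesRange<2>(out, a, b)) {
      // out_index walks 0 .. out.numel() - 1; a_index and b_index advance
      // under each input's effective broadcast strides.
      data_out[out_index] = compute_fun(data_a[a_index], data_b[b_index]);
    }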

kernels/portable/cpu/util/broadcast_util.h

Lines changed: 7 additions & 32 deletions
@@ -254,26 +254,13 @@ inline void apply_binary_elementwise_fn(
     const Tensor& a,
     const Tensor& b,
     const Tensor& out) {
-  const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
-  const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
-  const bool any_is_broadcasted = (a_is_broadcasted || b_is_broadcasted);
-
   const CTYPE_A* const data_a = a.const_data_ptr<CTYPE_A>();
   const CTYPE_B* const data_b = b.const_data_ptr<CTYPE_B>();
   CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();
 
-  if (any_is_broadcasted) {
-    for (const auto [out_index, a_index, b_index] :
-         BroadcastIndexesRange<2>(out, a, b)) {
-      data_out[out_index] = compute_fun(data_a[a_index], data_b[b_index]);
-    }
-  } else {
-    for (const auto i : c10::irange(out.numel())) {
-      size_t a_linear_index = i;
-      size_t b_linear_index = i;
-
-      data_out[i] = compute_fun(data_a[a_linear_index], data_b[b_linear_index]);
-    }
+  for (const auto [out_index, a_index, b_index] :
+       BroadcastIndexesRange<2>(out, a, b)) {
+    data_out[out_index] = compute_fun(data_a[a_index], data_b[b_index]);
   }
 }
 
@@ -294,27 +281,15 @@ inline void apply_ternary_elementwise_fn(
     const Tensor& b,
     const Tensor& c,
     const Tensor& out) {
-  const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
-  const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
-  const bool c_is_broadcasted = !out.sizes().equals(c.sizes());
-  const bool any_is_broadcasted =
-      (a_is_broadcasted || b_is_broadcasted || c_is_broadcasted);
-
   const CTYPE_A* const data_a = a.const_data_ptr<CTYPE_A>();
   const CTYPE_B* const data_b = b.const_data_ptr<CTYPE_B>();
   const CTYPE_C* const data_c = c.const_data_ptr<CTYPE_C>();
   CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();
 
-  if (any_is_broadcasted) {
-    for (const auto [out_index, a_index, b_index, c_index] :
-         BroadcastIndexesRange<3>(out, a, b, c)) {
-      data_out[out_index] =
-          compute_fun(data_a[a_index], data_b[b_index], data_c[c_index]);
-    }
-  } else {
-    for (const auto i : c10::irange(out.numel())) {
-      data_out[i] = compute_fun(data_a[i], data_b[i], data_c[i]);
-    }
+  for (const auto [out_index, a_index, b_index, c_index] :
+       BroadcastIndexesRange<3>(out, a, b, c)) {
+    data_out[out_index] =
+        compute_fun(data_a[a_index], data_b[b_index], data_c[c_index]);
   }
 }

kernels/portable/cpu/util/elementwise_util.h

Lines changed: 9 additions & 27 deletions
@@ -76,11 +76,6 @@ inline void apply_elementwise_fn(
       internal::check_tensor_dtype(out, out_dtypes, compute_type),
       InvalidArgument, );
 
-  bool any_is_broadcasted = false;
-  if constexpr (kNumInputs > 1) {
-    any_is_broadcasted = (!out.sizes().equals(inputs.first->sizes()) || ...);
-  }
-
   struct InputInfo {
     load_to_common_fn<CTYPE_COMMON> load_to_common;
     const char* data_ptr;
@@ -99,29 +94,16 @@ inline void apply_elementwise_fn(
   char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
   const auto out_element_size = out.element_size();
 
-  if (any_is_broadcasted) {
-    for (const auto& indexes :
-         BroadcastIndexesRange<kNumInputs>(out, (*inputs.first)...)) {
-      std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
-      for (const auto idx : c10::irange(kNumInputs)) {
-        const auto& input_info = inputs_info[idx];
-        loaded_inputs[idx] = input_info.load_to_common(
-            &input_info.data_ptr[indexes[idx + 1] * input_info.element_size]);
-      }
-      auto result = std::apply(compute_fun, loaded_inputs);
-      store_common_to_out(result, &data_out[indexes[0] * out_element_size]);
-    }
-  } else {
-    for (const auto i : c10::irange(out.numel())) {
-      std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
-      for (const auto idx : c10::irange(kNumInputs)) {
-        const auto& input_info = inputs_info[idx];
-        loaded_inputs[idx] = input_info.load_to_common(
-            &input_info.data_ptr[i * input_info.element_size]);
-      }
-      auto result = std::apply(compute_fun, loaded_inputs);
-      store_common_to_out(result, &data_out[i * out_element_size]);
+  for (const auto& indexes :
+       BroadcastIndexesRange<kNumInputs>(out, (*inputs.first)...)) {
+    std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
+    for (const auto idx : c10::irange(kNumInputs)) {
+      const auto& input_info = inputs_info[idx];
+      loaded_inputs[idx] = input_info.load_to_common(
+          &input_info.data_ptr[indexes[idx + 1] * input_info.element_size]);
     }
+    auto result = std::apply(compute_fun, loaded_inputs);
+    store_common_to_out(result, &data_out[indexes[0] * out_element_size]);
   }
 }
 } // namespace internal
