Commit c35fdea
Make BroadcastIndexesRange efficient if there is no broadcasting
It seems to be OK to just check for broadcasting when creating the iterator and then leave a highly predictable branch inside the loop. This avoids near-duplicating the code that would otherwise be needed to handle the broadcast and non-broadcast cases separately.

ghstack-source-id: 8be7c18
ghstack-comment-id: 2725945271
Pull Request resolved: #9298
1 parent 91e1037 commit c35fdea
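
The core idea, as a minimal standalone C++ sketch (the type and member names here are illustrative, not the actual ExecuTorch classes): pay for the broadcast check once when the iterator is constructed, cache the result, and branch on it inside operator++. When no input is broadcast, the branch takes the same arm on every iteration, so the branch predictor makes it nearly free.

    #include <cstddef>

    // Toy model of the pattern in this commit (hypothetical type).
    // dim_or_zero_ doubles as a flag: 0 means "no broadcasting".
    struct ToyBroadcastIterator {
      ToyBroadcastIterator(bool any_input_broadcasts, std::size_t dim)
          : dim_or_zero_(any_input_broadcasts ? dim : 0) {}

      ToyBroadcastIterator& operator++() {
        ++out_index_;
        if (dim_or_zero_ == 0) {
          // Fast path, taken every time when nothing broadcasts:
          // the input index is simply the output index.
          input_index_ = out_index_;
          return *this;
        }
        // Slow path: the per-dimension broadcast-stride arithmetic
        // (elided) would update input_index_ here.
        return *this;
      }

      std::size_t out_index_ = 0;
      std::size_t input_index_ = 0;
      std::size_t dim_or_zero_;
    };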

File tree

4 files changed: +47 -100 lines changed
kernels/optimized/cpu/op_where.cpp

Lines changed: 14 additions & 32 deletions
@@ -48,42 +48,24 @@ Tensor& opt_where_out(
       cond.scalar_type() == ScalarType::Bool) {
     auto out_numel = out.numel();
     ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-      const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
-      const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
-      const bool cond_is_broadcasted = !out.sizes().equals(cond.sizes());
-      const bool any_is_broadcasted =
-          (a_is_broadcasted || b_is_broadcasted || cond_is_broadcasted);
       const CTYPE_COMPUTE* const data_a = a.const_data_ptr<CTYPE_COMPUTE>();
       const CTYPE_COMPUTE* const data_b = b.const_data_ptr<CTYPE_COMPUTE>();
       const bool* const data_cond = cond.const_data_ptr<bool>();
       CTYPE_COMPUTE* const data_out = out.data_ptr<CTYPE_COMPUTE>();
-      if (any_is_broadcasted) {
-        executorch::extension::parallel_for(
-            0,
-            out_numel,
-            ::executorch::extension::internal::GRAIN_SIZE,
-            [&](const auto begin, const auto end) {
-              auto range = BroadcastIndexesRange<3>(out, a, b, cond);
-              auto begin_it = range.begin();
-              begin_it += begin;
-              for (; (*begin_it)[0] < end; ++begin_it) {
-                const auto [out_index, a_index, b_index, cond_index] =
-                    *begin_it;
-                data_out[out_index] =
-                    data_cond[cond_index] ? data_a[a_index] : data_b[b_index];
-              }
-            });
-      } else {
-        executorch::extension::parallel_for(
-            0,
-            out_numel,
-            ::executorch::extension::internal::GRAIN_SIZE,
-            [&](const auto begin, const auto end) {
-              for (const auto i : c10::irange(begin, end)) {
-                data_out[i] = data_cond[i] ? data_a[i] : data_b[i];
-              }
-            });
-      }
+      executorch::extension::parallel_for(
+          0,
+          out_numel,
+          ::executorch::extension::internal::GRAIN_SIZE,
+          [&](const auto begin, const auto end) {
+            auto range = BroadcastIndexesRange<3>(out, a, b, cond);
+            auto begin_it = range.begin();
+            begin_it += begin;
+            for (; (*begin_it)[0] < end; ++begin_it) {
+              const auto [out_index, a_index, b_index, cond_index] = *begin_it;
+              data_out[out_index] =
+                  data_cond[cond_index] ? data_a[a_index] : data_b[b_index];
+            }
+          });
     });
   } else {
     // Fall back for mixed dtype to keep code size and compile time
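
For context on the loop shape above: parallel_for gives each worker a [begin, end) slice of output elements, the iterator is seeked forward with += begin, and iteration stops once the output index (element 0 of the dereferenced index array) reaches end. A toy sketch of one worker's chunk traversal, under assumed names (only the shape of the iteration is taken from the diff above):

    // `range` is assumed to yield arrays whose element 0 is the output
    // index; the remaining elements index the (possibly broadcast) inputs.
    template <typename Range>
    void process_chunk(const Range& range, long begin, long end) {
      auto it = range.begin();
      it += begin;  // seek to this worker's first output element
      for (; (*it)[0] < end; ++it) {
        const auto indexes = *it;
        (void)indexes;  // indexes[1], indexes[2], ... address the inputs here
      }
    }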

kernels/portable/cpu/util/broadcast_indexes_range.h

Lines changed: 17 additions & 9 deletions
@@ -34,14 +34,17 @@ class BroadcastIndexesIterator {
 
   template <typename... Args>
   explicit BroadcastIndexesIterator(const Tensor& output, const Args&... args)
-      : output_dim_(output.dim()),
-        output_shape_(output.sizes()),
-        effective_input_broadcast_strides_{
-            effective_input_broadcast_stride(output, args)...} {
+      : output_dim_or_zero_if_no_broadcasting_(
+            ((args.sizes() == output.sizes()) && ...) ? 0 : output.dim()),
+        output_shape_(output.sizes()) {
     static_assert(
         sizeof...(args) == kNumInputs && (std::is_same_v<Args, Tensor> && ...),
         "BroadcastIndexesIterator constructor requires kNumInputs input tensor"
         "arguments!");
+    if (output_dim_or_zero_if_no_broadcasting_ != 0) {
+      effective_input_broadcast_strides_ = {
+          effective_input_broadcast_stride(output, args)...};
+    }
   }
 
   struct make_end_t {
@@ -73,9 +76,14 @@ class BroadcastIndexesIterator {
 
   BroadcastIndexesIterator& operator++() {
     output_index()++;
+    if (output_dim_or_zero_if_no_broadcasting_ == 0) {
+      std::fill(
+          current_indexes_.begin() + 1, current_indexes_.end(), output_index());
+      return *this;
+    }
     // TODO: add optimization for particular input tensors not being
     // broadcasted?
-    for (auto ii = output_dim_ - 1; ii >= 0; --ii) {
+    for (auto ii = output_dim_or_zero_if_no_broadcasting_ - 1; ii >= 0; --ii) {
       // You might wonder what happens if output_shape_[ii] == 0. In
       // that case, output.numel() would be 0, and thus we would have
       // begin() == end() and no iteration.
@@ -121,7 +129,8 @@ class BroadcastIndexesIterator {
           delinearized_output_index_.size());
       for (const auto ii : c10::irange(1, kNumInputs + 1)) {
         current_indexes_[ii] = 0;
-        for (const auto jj : c10::irange(output_dim_)) {
+        for (const auto jj :
+             c10::irange(output_dim_or_zero_if_no_broadcasting_)) {
          current_indexes_[ii] += delinearized_output_index_[jj] *
              effective_input_broadcast_strides_[ii - 1][jj];
         }
@@ -180,7 +189,7 @@ class BroadcastIndexesIterator {
   // followed by kNumInputs input indexes.
   std::array<ssize_t, kNumInputs + 1> current_indexes_ = {0};
   ShapeType delinearized_output_index_ = {0};
-  ssize_t output_dim_;
+  ssize_t output_dim_or_zero_if_no_broadcasting_;
   ArrayRef<exec_aten::SizesType> output_shape_;
   // The linear index for a broadcast tensor is
   // sum(delinearized_output_index_[i] * input_stride_[i] if
@@ -189,8 +198,7 @@ class BroadcastIndexesIterator {
   // output_dim. This is straightforwardly implementable with an
   // adjusted stride array that contains 0s where the padded input
   // shape would contain 1s.
-  std::array<ShapeType, kNumInputs> effective_input_broadcast_strides_ = {
-      {{0}}};
+  std::array<ShapeType, kNumInputs> effective_input_broadcast_strides_;
 };
 } // namespace internal
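
For reference, the intended consumption pattern is the same one the broadcast_util.h changes below switch to. A minimal sketch, assuming two float tensors a and b that are broadcast-compatible with out (namespaces elided):

    // Each dereference yields (out_index, a_index, b_index). With the
    // change above, when neither input is broadcast, operator++ skips
    // the per-dimension stride arithmetic and all three indexes advance
    // in lockstep.
    const float* const data_a = a.const_data_ptr<float>();
    const float* const data_b = b.const_data_ptr<float>();
    float* const data_out = out.mutable_data_ptr<float>();
    for (const auto [out_index, a_index, b_index] :
         BroadcastIndexesRange<2>(out, a, b)) {
      data_out[out_index] = data_a[a_index] + data_b[b_index];
    }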

kernels/portable/cpu/util/broadcast_util.h

Lines changed: 7 additions & 32 deletions
@@ -254,26 +254,13 @@ inline void apply_binary_elementwise_fn(
     const Tensor& a,
     const Tensor& b,
     const Tensor& out) {
-  const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
-  const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
-  const bool any_is_broadcasted = (a_is_broadcasted || b_is_broadcasted);
-
   const CTYPE_A* const data_a = a.const_data_ptr<CTYPE_A>();
   const CTYPE_B* const data_b = b.const_data_ptr<CTYPE_B>();
   CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();
 
-  if (any_is_broadcasted) {
-    for (const auto [out_index, a_index, b_index] :
-         BroadcastIndexesRange<2>(out, a, b)) {
-      data_out[out_index] = compute_fun(data_a[a_index], data_b[b_index]);
-    }
-  } else {
-    for (const auto i : c10::irange(out.numel())) {
-      size_t a_linear_index = i;
-      size_t b_linear_index = i;
-
-      data_out[i] = compute_fun(data_a[a_linear_index], data_b[b_linear_index]);
-    }
+  for (const auto [out_index, a_index, b_index] :
+       BroadcastIndexesRange<2>(out, a, b)) {
+    data_out[out_index] = compute_fun(data_a[a_index], data_b[b_index]);
   }
 }
 
@@ -294,27 +281,15 @@ inline void apply_ternary_elementwise_fn(
     const Tensor& b,
     const Tensor& c,
     const Tensor& out) {
-  const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
-  const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
-  const bool c_is_broadcasted = !out.sizes().equals(c.sizes());
-  const bool any_is_broadcasted =
-      (a_is_broadcasted || b_is_broadcasted || c_is_broadcasted);
-
   const CTYPE_A* const data_a = a.const_data_ptr<CTYPE_A>();
   const CTYPE_B* const data_b = b.const_data_ptr<CTYPE_B>();
   const CTYPE_C* const data_c = c.const_data_ptr<CTYPE_C>();
   CTYPE_OUT* const data_out = out.mutable_data_ptr<CTYPE_OUT>();
 
-  if (any_is_broadcasted) {
-    for (const auto [out_index, a_index, b_index, c_index] :
-         BroadcastIndexesRange<3>(out, a, b, c)) {
-      data_out[out_index] =
-          compute_fun(data_a[a_index], data_b[b_index], data_c[c_index]);
-    }
-  } else {
-    for (const auto i : c10::irange(out.numel())) {
-      data_out[i] = compute_fun(data_a[i], data_b[i], data_c[i]);
-    }
+  for (const auto [out_index, a_index, b_index, c_index] :
+       BroadcastIndexesRange<3>(out, a, b, c)) {
+    data_out[out_index] =
+        compute_fun(data_a[a_index], data_b[b_index], data_c[c_index]);
   }
 }
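
With the branch folded into the iterator, callers of these helpers no longer pay a code-size cost for dual loops. A hypothetical call of apply_binary_elementwise_fn, assuming the template parameters shown in the signature above and an elementwise add (namespaces elided):

    // Hypothetical usage: CTYPE_A, CTYPE_B, CTYPE_OUT are all float here;
    // broadcasting (or its absence) is handled inside the range.
    apply_binary_elementwise_fn<float, float, float>(
        [](float x, float y) { return x + y; }, a, b, out);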

kernels/portable/cpu/util/elementwise_util.h

Lines changed: 9 additions & 27 deletions
@@ -76,11 +76,6 @@ inline void apply_elementwise_fn(
       internal::check_tensor_dtype(out, out_dtypes, compute_type),
       InvalidArgument, );
 
-  bool any_is_broadcasted = false;
-  if constexpr (kNumInputs > 1) {
-    any_is_broadcasted = (!out.sizes().equals(inputs.first->sizes()) || ...);
-  }
-
   struct InputInfo {
     load_to_common_fn<CTYPE_COMMON> load_to_common;
     const char* data_ptr;
@@ -99,29 +94,16 @@
   char* const data_out = reinterpret_cast<char*>(out.mutable_data_ptr());
   const auto out_element_size = out.element_size();
 
-  if (any_is_broadcasted) {
-    for (const auto& indexes :
-         BroadcastIndexesRange<kNumInputs>(out, (*inputs.first)...)) {
-      std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
-      for (const auto idx : c10::irange(kNumInputs)) {
-        const auto& input_info = inputs_info[idx];
-        loaded_inputs[idx] = input_info.load_to_common(
-            &input_info.data_ptr[indexes[idx + 1] * input_info.element_size]);
-      }
-      auto result = std::apply(compute_fun, loaded_inputs);
-      store_common_to_out(result, &data_out[indexes[0] * out_element_size]);
-    }
-  } else {
-    for (const auto i : c10::irange(out.numel())) {
-      std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
-      for (const auto idx : c10::irange(kNumInputs)) {
-        const auto& input_info = inputs_info[idx];
-        loaded_inputs[idx] = input_info.load_to_common(
-            &input_info.data_ptr[i * input_info.element_size]);
-      }
-      auto result = std::apply(compute_fun, loaded_inputs);
-      store_common_to_out(result, &data_out[i * out_element_size]);
+  for (const auto& indexes :
+       BroadcastIndexesRange<kNumInputs>(out, (*inputs.first)...)) {
+    std::array<CTYPE_COMMON, kNumInputs> loaded_inputs;
+    for (const auto idx : c10::irange(kNumInputs)) {
+      const auto& input_info = inputs_info[idx];
+      loaded_inputs[idx] = input_info.load_to_common(
+          &input_info.data_ptr[indexes[idx + 1] * input_info.element_size]);
     }
+    auto result = std::apply(compute_fun, loaded_inputs);
+    store_common_to_out(result, &data_out[indexes[0] * out_element_size]);
   }
 }
 } // namespace internal
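
The loop above keeps a single code path for every input dtype by pairing an untyped base pointer with a per-input conversion function and element size; indexes[idx + 1] selects the (possibly broadcast) element for input idx. A toy sketch of that type-erased load (hypothetical names, with float standing in for CTYPE_COMMON):

    #include <cstddef>

    // Toy version of the per-input load driven by the loop above.
    struct ToyInputInfo {
      float (*load_to_common)(const void*);  // one element -> common type
      const char* data_ptr;                  // untyped base pointer
      std::size_t element_size;              // bytes per element
    };

    inline float load_element(const ToyInputInfo& info, std::size_t index) {
      // Byte addressing lets the same loop body serve any input dtype.
      return info.load_to_common(info.data_ptr + index * info.element_size);
    }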
