Skip to content

Commit c5e457a

Browse files
authored
Special-case dim-list ops when the dim list has size 1 (#9111)
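Applying a function over a single dim is much faster than going through the general apply-over-dim-list logic, which has to do extra work (tracking the current delinearized index). This change adds a dedicated execution mode that is used when the dim list contains exactly one dim.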
1 parent 6f0ae19 commit c5e457a

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

kernels/portable/cpu/util/reduce_util.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,10 @@ class ApplyOverDimListPlan {
347347
return;
348348
}
349349
dim_list_ = dim_list.value();
350+
if (dim_list_.value().size() == 1) {
351+
mode_ = ExecutionMode::OnlyOneDim;
352+
return;
353+
}
350354
is_in_dim_list_.fill(0);
351355
for (const auto& d : dim_list.value()) {
352356
const size_t non_neg_d = d < 0 ? d + in.dim() : d;
@@ -367,6 +371,14 @@ class ApplyOverDimListPlan {
367371
apply_on_flat_ix_with_stride_and_base(
368372
fn, /*stride=*/1, /*base=*/0, ustart_, uend_);
369373
return;
374+
case ExecutionMode::OnlyOneDim:
375+
apply_on_flat_and_dim_ix_with_stride_and_base(
376+
[&](const auto in_ix, const auto dim_ix) { fn(in_ix); },
377+
in_.strides()[ET_NORMALIZE_IX(dim_list_.value()[0], in_.dim())],
378+
get_init_index(in_, dim_list_.value(), out_ix),
379+
ustart_,
380+
uend_);
381+
return;
370382
case ExecutionMode::NormalDimMask:
371383
apply_on_flat_ix_with_dim_mask_and_base(
372384
fn,
@@ -399,6 +411,9 @@ class ApplyOverDimListPlan {
399411
// Iterate over the entire tensor with
400412
// apply_on_flat_ix_with_stride_and_base.
401413
NoDimMaskOrZeroDimension,
414+
// dim_list has size 1, iterate with
415+
// apply_on_flat_and_dim_ix_with_stride_and_base
416+
OnlyOneDim,
402417
// General mode, iterate with
403418
// apply_on_flat_ix_with_dim_mask_and_base.
404419
NormalDimMask

0 commit comments

Comments
 (0)