
Commit 6f0ae19

Replace {map_,}reduce_over_dim_list with {Map,}ReduceOverDimListPlan in kernels/portable (#9110)
Plan-then-execute for reductions should avoid repeating preparatory work.
1 parent 5fa637a
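
For context, the replaced free functions re-derived the dim-list bookkeeping (which dims are reduced, the stride walk, the number of reduced elements) on every call, i.e. once per output element; the plan classes hoist that work into a constructor that runs once per kernel invocation, leaving execute() to do only the actual reduction. Below is a minimal sketch of the plan-then-execute idea, not the real ReduceOverDimListPlan from kernels/portable/cpu/util/reduce_util.h; the class name, index mapping, and signatures are illustrative only.

#include <cstddef>
#include <vector>

// Toy "plan": the constructor precomputes, once, which input elements
// feed each output element; execute() just replays that mapping.
// (A real plan would derive this from the tensor's sizes, strides, and
// dim list instead of materializing index vectors.)
class SketchReducePlan {
 public:
  SketchReducePlan(std::size_t in_size, std::size_t out_size)
      : groups_(out_size) {
    for (std::size_t i = 0; i < in_size; ++i) {
      groups_[i % out_size].push_back(i); // hypothetical mapping
    }
  }

  // Fold all inputs mapped to out_ix; reduce(v, acc) matches the
  // reducer-lambda shape used by the kernels below.
  template <typename CTYPE, typename ReduceFn>
  CTYPE execute(ReduceFn reduce, const CTYPE* in_data, std::size_t out_ix)
      const {
    const std::vector<std::size_t>& ixs = groups_[out_ix];
    CTYPE acc = in_data[ixs[0]]; // assumes a non-empty reduction group
    for (std::size_t k = 1; k < ixs.size(); ++k) {
      acc = reduce(in_data[ixs[k]], acc);
    }
    return acc;
  }

 private:
  std::vector<std::vector<std::size_t>> groups_;
};

With that shape, a kernel's inner loop shrinks to one plan.execute(reduce_fn, ..., out_ix) call per output element, and the preparatory work no longer scales with out.numel(); that is exactly the change each diff below makes.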

File tree (6 files changed: +26 -23 lines)

kernels/portable/cpu/op_amax.cpp
kernels/portable/cpu/op_amin.cpp
kernels/portable/cpu/op_any.cpp
kernels/portable/cpu/op_mean.cpp
kernels/portable/cpu/op_sum.cpp
kernels/portable/cpu/op_var.cpp


kernels/portable/cpu/op_amax.cpp

Lines changed: 2 additions & 3 deletions

@@ -43,15 +43,14 @@ Tensor& amax_out(
   ET_KERNEL_CHECK(
       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
 
+  ReduceOverDimListPlan plan(in, dim_list);
   ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "amax.out", CTYPE, [&]() {
     CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
     for (const auto out_ix : c10::irange(out.numel())) {
-      out_data[out_ix] = reduce_over_dim_list<CTYPE>(
+      out_data[out_ix] = plan.execute<CTYPE>(
           [](CTYPE v, CTYPE max_v) {
             return std::isnan(v) || v > max_v ? v : max_v;
           },
-          in,
-          dim_list,
           out_ix);
     }
   });

kernels/portable/cpu/op_amin.cpp

Lines changed: 2 additions & 3 deletions

@@ -42,15 +42,14 @@ Tensor& amin_out(
   ET_KERNEL_CHECK(
       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
 
+  ReduceOverDimListPlan plan(in, dim_list);
   ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "amin.out", CTYPE, [&]() {
     CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
     for (const auto out_ix : c10::irange(out.numel())) {
-      out_data[out_ix] = reduce_over_dim_list<CTYPE>(
+      out_data[out_ix] = plan.execute<CTYPE>(
           [](CTYPE v, CTYPE min_v) {
             return std::isnan(v) || v < min_v ? v : min_v;
           },
-          in,
-          dim_list,
           out_ix);
     }
   });

kernels/portable/cpu/op_any.cpp

Lines changed: 9 additions & 4 deletions

@@ -10,6 +10,8 @@
 #include <executorch/kernels/portable/cpu/util/reduce_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
+#include <optional>
+
 namespace torch {
 namespace executor {
 namespace native {
@@ -79,6 +81,11 @@ Tensor& any_dims_out(
   ScalarType out_type = out.scalar_type();
   constexpr auto name = "any.dims_out";
 
+  const bool in_not_empty = in.numel() > 0;
+  std::optional<MapReduceOverDimListPlan> plan;
+  if ((!dim_list.has_value() || !dim_list.value().empty()) && in_not_empty) {
+    plan.emplace(in, dim_list);
+  }
   ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, name, CTYPE_IN, [&] {
     ET_SWITCH_TWO_TYPES(Bool, Byte, out_type, ctx, name, CTYPE_OUT, [&] {
       CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
@@ -91,12 +98,10 @@ Tensor& any_dims_out(
       } else {
         for (const auto out_ix : c10::irange(out.numel())) {
           bool any = false;
-          if (in.numel() > 0) {
-            any = map_reduce_over_dim_list<CTYPE_IN, bool>(
+          if (in_not_empty) {
+            any = plan->execute<CTYPE_IN, bool>(
                 [](CTYPE_IN v) { return static_cast<bool>(v); },
                 [](bool outv, bool acc) { return acc || outv; },
-                in,
-                dim_list,
                 out_ix);
           }
           out_data[out_ix] = static_cast<CTYPE_OUT>(any);
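
Note the extra care op_any.cpp needs: the plan is only meaningful when a reduction will actually run, so it is wrapped in std::optional and emplaced only for a non-empty input with an absent or non-empty dim list (op_sum.cpp does the same below, keyed on plan.has_value()). A minimal sketch of that deferred-construction pattern, with hypothetical names:

#include <optional>

// Hypothetical plan whose constructor assumes it has real work to do.
struct SketchPlan {
  explicit SketchPlan(int numel) : numel_(numel) {} // preparatory work here
  bool execute() const { return numel_ > 0; }
  int numel_;
};

bool any_or_false(int numel) {
  std::optional<SketchPlan> plan;
  if (numel > 0) {
    plan.emplace(numel); // construct in place only when the guard holds
  }
  // If the plan was never built, fall back to the empty-reduction default.
  return plan.has_value() ? plan->execute() : false;
}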

kernels/portable/cpu/op_mean.cpp

Lines changed: 2 additions & 3 deletions

@@ -45,6 +45,7 @@ Tensor& mean_dim_out(
       InvalidArgument,
       out);
 
+  MapReduceOverDimListPlan plan(in, dim_list);
   ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "mean.out", CTYPE_IN, [&] {
     ET_SWITCH_FLOATHBF16_TYPES(
         out.scalar_type(), ctx, "mean.out", CTYPE_OUT, [&] {
@@ -53,11 +54,9 @@ Tensor& mean_dim_out(
           for (const auto out_ix : c10::irange(out.numel())) {
             CTYPE_OUT sum = 0;
             if (in.numel() > 0) {
-              sum = map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
+              sum = plan.execute<CTYPE_IN, CTYPE_OUT>(
                   [](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
                   [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
-                  in,
-                  dim_list,
                   out_ix);
             }
             out_data[out_ix] = sum / static_cast<float>(num);

kernels/portable/cpu/op_sum.cpp

Lines changed: 8 additions & 4 deletions

@@ -11,6 +11,8 @@
 #include <executorch/runtime/kernel/kernel_includes.h>
 #include <executorch/runtime/platform/assert.h>
 
+#include <optional>
+
 namespace torch {
 namespace executor {
 namespace native {
@@ -44,19 +46,21 @@ Tensor& sum_dim_out(
 
   ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);
 
+  std::optional<MapReduceOverDimListPlan> plan;
+  if (in.numel() > 0) {
+    plan.emplace(in, dim_list);
+  }
   ET_SWITCH_REALHBBF16_TYPES(
       in.scalar_type(), ctx, "sum.IntList_out", CTYPE_IN, [&] {
         ET_SWITCH_REALHBBF16_TYPES(
             out.scalar_type(), ctx, "sum.IntList_out", CTYPE_OUT, [&] {
               CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
               for (const auto out_ix : c10::irange(out.numel())) {
                 CTYPE_OUT sum = 0;
-                if (in.numel() > 0) {
-                  sum = map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
+                if (plan.has_value()) {
+                  sum = plan->execute<CTYPE_IN, CTYPE_OUT>(
                       [](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
                       [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
-                      in,
-                      dim_list,
                       out_ix);
                 }
                 out_data[out_ix] = sum;

kernels/portable/cpu/op_var.cpp

Lines changed: 3 additions & 6 deletions

@@ -32,23 +32,20 @@ void compute_variance(
       out_data[out_ix] = NAN;
     }
   } else {
+    MapReduceOverDimListPlan plan(in, dim_list);
     for (const auto out_ix : c10::irange(out.numel())) {
-      CTYPE_OUT sum = map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
+      CTYPE_OUT sum = plan.execute<CTYPE_IN, CTYPE_OUT>(
           [](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
           [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
-          in,
-          dim_list,
           out_ix);
       CTYPE_OUT mean = sum / static_cast<CTYPE_OUT>(num);
-      CTYPE_OUT sum2 = map_reduce_over_dim_list<CTYPE_IN, CTYPE_OUT>(
+      CTYPE_OUT sum2 = plan.execute<CTYPE_IN, CTYPE_OUT>(
           [mean](CTYPE_IN v) {
            return (
                (static_cast<CTYPE_OUT>(v) - mean) *
                (static_cast<CTYPE_OUT>(v) - mean));
          },
           [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
-          in,
-          dim_list,
           out_ix);
       out_data[out_ix] = sum2 / denominator;
     }
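
op_var.cpp is where the hoisting pays off twice over: each output element runs two reduction passes (a sum to get the mean, then a sum of squared deviations), and both now share a single plan instead of re-deriving the dim-list bookkeeping per call. As a standalone sketch of the arithmetic those two execute calls perform, over a flat array, with denominator standing in for the corrected count (e.g. num minus a Bessel-style correction):

#include <vector>

// Two-pass variance: pass 1 computes the mean, pass 2 the squared
// deviations. Both passes walk the same element set, which is why one
// precomputed plan can serve both execute calls above.
double variance_sketch(const std::vector<double>& v, double denominator) {
  double sum = 0;
  for (double x : v) { // pass 1: plain sum
    sum += x;
  }
  const double mean = sum / static_cast<double>(v.size());
  double sum2 = 0;
  for (double x : v) { // pass 2: squared deviations from the mean
    sum2 += (x - mean) * (x - mean);
  }
  return sum2 / denominator;
}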
