Skip to content

Commit 8dd8eb1

Browse files
authored
Refactor: MapReduceOverDimListPlan (#9108)
Another step toward plan-then-execute for reductions.
1 parent a08c58b commit 8dd8eb1

File tree

1 file changed

+60
-32
lines changed

1 file changed

+60
-32
lines changed

kernels/portable/cpu/util/reduce_util.h

Lines changed: 60 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ class ApplyOverDimListPlan {
329329
dim_list,
330330
const int64_t start = 0,
331331
const int64_t end = -1)
332-
: in_(in) {
332+
: dim_list_(dim_list), in_(in) {
333333
ET_CHECK(check_dim_list_is_valid(in, dim_list));
334334
out_numel_ = get_out_numel(in_, dim_list);
335335
if (in.numel() == 0) {
@@ -372,13 +372,22 @@ class ApplyOverDimListPlan {
372372
fn,
373373
in_,
374374
is_in_dim_list_.data(),
375-
get_init_index(in_, dim_list_, out_ix),
375+
get_init_index(in_, dim_list_.value(), out_ix),
376376
ustart_,
377377
uend_);
378378
return;
379379
}
380380
}
381381

382+
// Returns the input tensor this plan was constructed with. The plan stores
// the tensor by reference (in_), not by value, so the returned reference is
// only usable while the tensor passed to the constructor is still alive.
const executorch::aten::Tensor& get_input_tensor() const {
383+
return in_;
384+
}
385+
386+
// Returns the dim_list this plan was constructed with, exactly as stored in
// dim_list_ (it is an optional, so it may hold no value).
// NOTE(review): ArrayRef is presumably a non-owning view, in which case the
// underlying dimension array must outlive the plan -- confirm.
const executorch::aten::optional<executorch::aten::ArrayRef<int64_t>>&
387+
get_dim_list() const {
388+
return dim_list_;
389+
}
390+
382391
private:
383392
// Start argument to apply_on_flat_ix_with_{stride,dim_mask}_and_base.
384393
size_t ustart_;
@@ -396,7 +405,7 @@ class ApplyOverDimListPlan {
396405
};
397406
ExecutionMode mode_;
398407
size_t out_numel_;
399-
executorch::aten::ArrayRef<int64_t> dim_list_;
408+
executorch::aten::optional<executorch::aten::ArrayRef<int64_t>> dim_list_;
400409
std::array<bool, kTensorDimensionLimit> is_in_dim_list_;
401410
const executorch::aten::Tensor& in_;
402411
};
@@ -502,6 +511,52 @@ std::tuple<CTYPE_OUT, long> map_reduce_over_dim(
502511
return std::tuple<CTYPE_OUT, long>{acc_val, acc_ix};
503512
}
504513

514+
/**
515+
* Execution plan for repeated map_reduce_over_dim_list with the same
516+
* function, input tensor, and dim_list but varying out_ix.
517+
*/
518+
class MapReduceOverDimListPlan {
519+
public:
520+
MapReduceOverDimListPlan(
521+
const executorch::aten::Tensor& in,
522+
const executorch::aten::optional<executorch::aten::ArrayRef<int64_t>>&
523+
dim_list)
524+
: plan_(in, dim_list, 1, -1) {
525+
ET_CHECK_MSG(in.numel() > 0, "Input tensor must be nonempty");
526+
}
527+
528+
template <
529+
typename CTYPE_IN,
530+
typename CTYPE_OUT,
531+
typename MapOp,
532+
typename ReduceOp>
533+
CTYPE_OUT execute(
534+
const MapOp& map_fun,
535+
const ReduceOp& reduce_fun,
536+
const size_t out_ix) const {
537+
const size_t init_index =
538+
get_init_index(plan_.get_input_tensor(), plan_.get_dim_list(), out_ix);
539+
540+
const CTYPE_IN* const in_data =
541+
plan_.get_input_tensor().const_data_ptr<CTYPE_IN>();
542+
CTYPE_OUT acc_val = map_fun(in_data[init_index]);
543+
544+
if (plan_.get_input_tensor().numel() == 1) {
545+
return acc_val;
546+
}
547+
548+
plan_.execute(
549+
[&acc_val, reduce_fun, map_fun, in_data](const size_t in_ix) {
550+
acc_val = reduce_fun(map_fun(in_data[in_ix]), acc_val);
551+
},
552+
out_ix);
553+
return acc_val;
554+
}
555+
556+
private:
557+
ApplyOverDimListPlan plan_;
558+
};
559+
505560
/**
506561
* Useful to reduce a tensor `in` over a given list of dimensions `dim_list`
507562
* for the output element at index `out_ix`, first applying the map `map_fun`
@@ -537,35 +592,8 @@ CTYPE_OUT map_reduce_over_dim_list(
537592
const executorch::aten::optional<executorch::aten::ArrayRef<int64_t>>&
538593
dim_list,
539594
const size_t out_ix) {
540-
ET_CHECK(check_dim_list_is_valid(in, dim_list));
541-
542-
ET_CHECK_MSG(
543-
out_ix < get_out_numel(in, dim_list),
544-
"Out index %zd is out of bounds",
545-
out_ix);
546-
547-
ET_CHECK_MSG(in.numel() > 0, "Input tensor must be nonempty");
548-
549-
const size_t init_index = get_init_index(in, dim_list, out_ix);
550-
551-
const CTYPE_IN* const in_data = in.const_data_ptr<CTYPE_IN>();
552-
CTYPE_OUT acc_val = map_fun(in_data[init_index]);
553-
554-
if (in.numel() == 1) {
555-
return acc_val;
556-
}
557-
558-
apply_over_dim_list(
559-
[&acc_val, reduce_fun, map_fun, in_data](const size_t in_ix) {
560-
acc_val = reduce_fun(map_fun(in_data[in_ix]), acc_val);
561-
},
562-
in,
563-
dim_list,
564-
out_ix,
565-
1,
566-
-1);
567-
568-
return acc_val;
595+
MapReduceOverDimListPlan plan(in, dim_list);
596+
return plan.execute<CTYPE_IN, CTYPE_OUT>(map_fun, reduce_fun, out_ix);
569597
}
570598

571599
/**

0 commit comments

Comments
 (0)