Skip to content

Commit 5fa637a

Browse files
authored
Refactor: ReduceOverDimListPlan (#9109)
Now we can plan-then-execute for all reductions that use dim lists.
1 parent 8dd8eb1 commit 5fa637a

File tree

1 file changed

+24
-2
lines changed

1 file changed

+24
-2
lines changed

kernels/portable/cpu/util/reduce_util.h

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -626,6 +626,28 @@ std::tuple<CTYPE, long> reduce_over_dim(
626626
[](CTYPE v) { return v; }, reduce_fun, in, dim, out_ix);
627627
}
628628

629+
/**
630+
* Execution plan for repeated reduce_over_dim_list with the same
631+
* function, input tensor, and dim_list but varying out_ix.
632+
*/
633+
class ReduceOverDimListPlan {
634+
public:
635+
ReduceOverDimListPlan(
636+
const executorch::aten::Tensor& in,
637+
const executorch::aten::optional<executorch::aten::ArrayRef<int64_t>>&
638+
dim_list)
639+
: plan_(in, dim_list) {}
640+
641+
template <typename CTYPE, typename ReduceOp>
642+
CTYPE execute(const ReduceOp& reduce_fun, const size_t out_ix) {
643+
return plan_.execute<CTYPE, CTYPE>(
644+
[](CTYPE v) { return v; }, reduce_fun, out_ix);
645+
}
646+
647+
private:
648+
MapReduceOverDimListPlan plan_;
649+
};
650+
629651
/**
630652
* Useful to reduce a tensor `in` over a given list of dimensions `dim_list`
631653
* for the output element at index `out_ix` using the reduce function
@@ -652,8 +674,8 @@ CTYPE reduce_over_dim_list(
652674
const executorch::aten::optional<executorch::aten::ArrayRef<int64_t>>&
653675
dim_list,
654676
const size_t out_ix) {
655-
return map_reduce_over_dim_list<CTYPE, CTYPE>(
656-
[](CTYPE v) { return v; }, reduce_fun, in, dim_list, out_ix);
677+
ReduceOverDimListPlan plan(in, dim_list);
678+
return plan.execute<CTYPE>(reduce_fun, out_ix);
657679
}
658680

659681
//

0 commit comments

Comments
 (0)