@@ -626,6 +626,28 @@ std::tuple<CTYPE, long> reduce_over_dim(
626
626
[](CTYPE v) { return v; }, reduce_fun, in, dim, out_ix);
627
627
}
628
628
629
+ /* *
630
+ * Execution plan for repeated reduce_over_dim_list with the same
631
+ * function, input tensor, and dim_list but varying out_ix.
632
+ */
633
+ class ReduceOverDimListPlan {
634
+ public:
635
+ ReduceOverDimListPlan (
636
+ const executorch::aten::Tensor& in,
637
+ const executorch::aten::optional<executorch::aten::ArrayRef<int64_t >>&
638
+ dim_list)
639
+ : plan_(in, dim_list) {}
640
+
641
+ template <typename CTYPE, typename ReduceOp>
642
+ CTYPE execute (const ReduceOp& reduce_fun, const size_t out_ix) {
643
+ return plan_.execute <CTYPE, CTYPE>(
644
+ [](CTYPE v) { return v; }, reduce_fun, out_ix);
645
+ }
646
+
647
+ private:
648
+ MapReduceOverDimListPlan plan_;
649
+ };
650
+
629
651
/* *
630
652
* Useful to reduce a tensor `in` over a given list of dimensions `dim_list`
631
653
* for the output element at index `out_ix` using the reduce function
@@ -652,8 +674,8 @@ CTYPE reduce_over_dim_list(
652
674
const executorch::aten::optional<executorch::aten::ArrayRef<int64_t >>&
653
675
dim_list,
654
676
const size_t out_ix) {
655
- return map_reduce_over_dim_list<CTYPE, CTYPE>(
656
- [](CTYPE v) { return v; }, reduce_fun, in, dim_list , out_ix);
677
+ ReduceOverDimListPlan plan (in, dim_list);
678
+ return plan. execute <CTYPE>( reduce_fun, out_ix);
657
679
}
658
680
659
681
//
0 commit comments