@@ -329,7 +329,7 @@ class ApplyOverDimListPlan {
329
329
dim_list,
330
330
const int64_t start = 0 ,
331
331
const int64_t end = -1 )
332
- : in_(in) {
332
+ : dim_list_(dim_list), in_(in) {
333
333
ET_CHECK (check_dim_list_is_valid (in, dim_list));
334
334
out_numel_ = get_out_numel (in_, dim_list);
335
335
if (in.numel () == 0 ) {
@@ -372,13 +372,22 @@ class ApplyOverDimListPlan {
372
372
fn,
373
373
in_,
374
374
is_in_dim_list_.data (),
375
- get_init_index (in_, dim_list_, out_ix),
375
+ get_init_index (in_, dim_list_. value () , out_ix),
376
376
ustart_,
377
377
uend_);
378
378
return ;
379
379
}
380
380
}
381
381
382
+ const executorch::aten::Tensor& get_input_tensor () const {
383
+ return in_;
384
+ }
385
+
386
+ const executorch::aten::optional<executorch::aten::ArrayRef<int64_t >>&
387
+ get_dim_list () const {
388
+ return dim_list_;
389
+ }
390
+
382
391
private:
383
392
// Start argument to apply_on_flat_ix_with_{stride,dim_mask}_and_base.
384
393
size_t ustart_;
@@ -396,7 +405,7 @@ class ApplyOverDimListPlan {
396
405
};
397
406
ExecutionMode mode_;
398
407
size_t out_numel_;
399
- executorch::aten::ArrayRef<int64_t > dim_list_;
408
+ executorch::aten::optional<executorch::aten:: ArrayRef<int64_t > > dim_list_;
400
409
std::array<bool , kTensorDimensionLimit > is_in_dim_list_;
401
410
const executorch::aten::Tensor& in_;
402
411
};
@@ -502,6 +511,52 @@ std::tuple<CTYPE_OUT, long> map_reduce_over_dim(
502
511
return std::tuple<CTYPE_OUT, long >{acc_val, acc_ix};
503
512
}
504
513
514
+ /* *
515
+ * Execution plan for repeated map_reduce_over_dim_list with the same
516
+ * function, input tensor, and dim_list but varying out_ix.
517
+ */
518
+ class MapReduceOverDimListPlan {
519
+ public:
520
+ MapReduceOverDimListPlan (
521
+ const executorch::aten::Tensor& in,
522
+ const executorch::aten::optional<executorch::aten::ArrayRef<int64_t >>&
523
+ dim_list)
524
+ : plan_(in, dim_list, 1 , -1 ) {
525
+ ET_CHECK_MSG (in.numel () > 0 , " Input tensor must be nonempty" );
526
+ }
527
+
528
+ template <
529
+ typename CTYPE_IN,
530
+ typename CTYPE_OUT,
531
+ typename MapOp,
532
+ typename ReduceOp>
533
+ CTYPE_OUT execute (
534
+ const MapOp& map_fun,
535
+ const ReduceOp& reduce_fun,
536
+ const size_t out_ix) const {
537
+ const size_t init_index =
538
+ get_init_index (plan_.get_input_tensor (), plan_.get_dim_list (), out_ix);
539
+
540
+ const CTYPE_IN* const in_data =
541
+ plan_.get_input_tensor ().const_data_ptr <CTYPE_IN>();
542
+ CTYPE_OUT acc_val = map_fun (in_data[init_index]);
543
+
544
+ if (plan_.get_input_tensor ().numel () == 1 ) {
545
+ return acc_val;
546
+ }
547
+
548
+ plan_.execute (
549
+ [&acc_val, reduce_fun, map_fun, in_data](const size_t in_ix) {
550
+ acc_val = reduce_fun (map_fun (in_data[in_ix]), acc_val);
551
+ },
552
+ out_ix);
553
+ return acc_val;
554
+ }
555
+
556
+ private:
557
+ ApplyOverDimListPlan plan_;
558
+ };
559
+
505
560
/* *
506
561
* Useful to reduce a tensor `in` over a given list of dimensions `dim_list`
507
562
* for the output element at index `out_ix`, first applying the map `map_fun`
@@ -537,35 +592,8 @@ CTYPE_OUT map_reduce_over_dim_list(
537
592
const executorch::aten::optional<executorch::aten::ArrayRef<int64_t >>&
538
593
dim_list,
539
594
const size_t out_ix) {
540
- ET_CHECK (check_dim_list_is_valid (in, dim_list));
541
-
542
- ET_CHECK_MSG (
543
- out_ix < get_out_numel (in, dim_list),
544
- " Out index %zd is out of bounds" ,
545
- out_ix);
546
-
547
- ET_CHECK_MSG (in.numel () > 0 , " Input tensor must be nonempty" );
548
-
549
- const size_t init_index = get_init_index (in, dim_list, out_ix);
550
-
551
- const CTYPE_IN* const in_data = in.const_data_ptr <CTYPE_IN>();
552
- CTYPE_OUT acc_val = map_fun (in_data[init_index]);
553
-
554
- if (in.numel () == 1 ) {
555
- return acc_val;
556
- }
557
-
558
- apply_over_dim_list (
559
- [&acc_val, reduce_fun, map_fun, in_data](const size_t in_ix) {
560
- acc_val = reduce_fun (map_fun (in_data[in_ix]), acc_val);
561
- },
562
- in,
563
- dim_list,
564
- out_ix,
565
- 1 ,
566
- -1 );
567
-
568
- return acc_val;
595
+ MapReduceOverDimListPlan plan (in, dim_list);
596
+ return plan.execute <CTYPE_IN, CTYPE_OUT>(map_fun, reduce_fun, out_ix);
569
597
}
570
598
571
599
/* *
0 commit comments