@@ -347,6 +347,10 @@ class ApplyOverDimListPlan {
347
347
return ;
348
348
}
349
349
dim_list_ = dim_list.value ();
350
+ if (dim_list_.value ().size () == 1 ) {
351
+ mode_ = ExecutionMode::OnlyOneDim;
352
+ return ;
353
+ }
350
354
is_in_dim_list_.fill (0 );
351
355
for (const auto & d : dim_list.value ()) {
352
356
const size_t non_neg_d = d < 0 ? d + in.dim () : d;
@@ -367,6 +371,14 @@ class ApplyOverDimListPlan {
367
371
apply_on_flat_ix_with_stride_and_base (
368
372
fn, /* stride=*/ 1 , /* base=*/ 0 , ustart_, uend_);
369
373
return ;
374
+ case ExecutionMode::OnlyOneDim:
375
+ apply_on_flat_and_dim_ix_with_stride_and_base (
376
+ [&](const auto in_ix, const auto dim_ix) { fn (in_ix); },
377
+ in_.strides ()[ET_NORMALIZE_IX (dim_list_.value ()[0 ], in_.dim ())],
378
+ get_init_index (in_, dim_list_.value (), out_ix),
379
+ ustart_,
380
+ uend_);
381
+ return ;
370
382
case ExecutionMode::NormalDimMask:
371
383
apply_on_flat_ix_with_dim_mask_and_base (
372
384
fn,
@@ -399,6 +411,9 @@ class ApplyOverDimListPlan {
399
411
// Iterate over the entire tensor with
400
412
// apply_on_flat_ix_with_stride_and_base.
401
413
NoDimMaskOrZeroDimension,
414
+ // dim_list has size 1, iterate with
415
+ // apply_on_flat_and_dim_ix_with_stride_and_base
416
+ OnlyOneDim,
402
417
// General mode, iterate with
403
418
// apply_on_flat_ix_with_dim_mask_and_base.
404
419
NormalDimMask
0 commit comments