Skip to content

Commit 7645d9c

Browse files
authored
[mlir][scf] Fix loop iteration calculation for negative step in LoopPipelining (#110035)
This fixes the loop iteration count calculation when the step is negative: in that case the delta added to the range before the ceiling division must be `step+1` instead of `step-1`.
1 parent 29b92d0 commit 7645d9c

File tree

2 files changed

+44
-26
lines changed

2 files changed

+44
-26
lines changed

mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -648,15 +648,22 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
648648
// bounds_range = ub - lb
649649
// total_iterations = (bounds_range + step - 1) / step
650650
Type t = lb.getType();
651-
Value minus1 =
652-
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
653-
Value boundsRange = rewriter.create<arith::SubIOp>(loc, ub, lb);
654-
Value rangeIncr = rewriter.create<arith::AddIOp>(loc, boundsRange, step);
655-
Value rangeDecr = rewriter.create<arith::AddIOp>(loc, rangeIncr, minus1);
656-
Value totalIterations = rewriter.create<arith::DivUIOp>(loc, rangeDecr, step);
657-
658651
Value zero =
659652
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 0));
653+
Value one =
654+
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 1));
655+
Value minusOne =
656+
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
657+
Value stepLessZero = rewriter.create<arith::CmpIOp>(
658+
loc, arith::CmpIPredicate::slt, step, zero);
659+
Value stepDecr =
660+
rewriter.create<arith::SelectOp>(loc, stepLessZero, one, minusOne);
661+
662+
Value rangeDiff = rewriter.create<arith::SubIOp>(loc, ub, lb);
663+
Value rangeIncrStep = rewriter.create<arith::AddIOp>(loc, rangeDiff, step);
664+
Value rangeDecr =
665+
rewriter.create<arith::AddIOp>(loc, rangeIncrStep, stepDecr);
666+
Value totalIterations = rewriter.create<arith::DivSIOp>(loc, rangeDecr, step);
660667

661668
SmallVector<Value> predicates(maxStage + 1);
662669
for (int64_t i = 0; i < maxStage; i++) {
@@ -665,7 +672,7 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
665672
Value minusI =
666673
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -i));
667674
Value iterI = rewriter.create<arith::AddIOp>(
668-
loc, rewriter.create<arith::AddIOp>(loc, totalIterations, minus1),
675+
loc, rewriter.create<arith::AddIOp>(loc, totalIterations, minusOne),
669676
minusI);
670677
// newLastIter = lb + step * iterI
671678
Value newlastIter = rewriter.create<arith::AddIOp>(

mlir/test/Dialect/SCF/loop-pipelining.mlir

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -766,24 +766,29 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
766766

767767
// Check for predicated epilogue for dynamic loop.
768768
// CHECK-LABEL: dynamic_loop(
769-
// CHECK: %[[C0:.*]] = arith.constant 0 : index
770-
// CHECK: %{{.*}}:2 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
769+
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
770+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
771+
// CHECK-DAG: %[[CM1:.*]] = arith.constant -1 : index
772+
// CHECK: %[[UBM:.*]] = arith.subi %[[UB:.*]], %{{.*}}
773+
// CHECK: %{{.*}}:2 = scf.for %[[ARG5:.*]] = %[[LB:.*]] to %[[UBM]] step %[[STEP:.*]] iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
771774
// CHECK: memref.store %[[ARG6]], %{{.*}}[%[[ARG5]]]
772775
// CHECK: %[[ADDF_24:.*]] = arith.addf %[[ARG7]], %{{.*}}
773776
// CHECK: %[[MULI_25:.*]] = arith.muli %{{.*}}, %{{.*}}
774777
// CHECK: %[[ADDI_26:.*]] = arith.addi %[[ARG5]], %[[MULI_25]]
775778
// CHECK: %[[LOAD_27:.*]] = memref.load %{{.*}}[%[[ADDI_26]]]
776779
// CHECK: scf.yield %[[ADDF_24]], %[[LOAD_27]]
777780
// CHECK: }
778-
// CHECK: %[[SUBI_10:.*]] = arith.subi %{{.*}}, %{{.*}}
779-
// CHECK: %[[ADDI_11:.*]] = arith.addi %[[SUBI_10]], %{{.*}}
780-
// CHECK: %[[ADDI_12:.*]] = arith.addi %[[ADDI_11]], %{{.*}}-1
781-
// CHECK: %[[DIVUI_13:.*]] = arith.divui %[[ADDI_12]], %{{.*}}
782-
// CHECK: %[[ADDI_14:.*]] = arith.addi %[[DIVUI_13]], %{{.*}}-1
781+
// CHECK: %[[CMPI_10:.*]] = arith.cmpi slt, %[[STEP]], %[[C0]]
782+
// CHECK: %[[SEL_10:.*]] = arith.select %[[CMPI_10]], %[[C1]], %[[CM1]]
783+
// CHECK: %[[SUBI_10:.*]] = arith.subi %[[UB]], %[[LB]]
784+
// CHECK: %[[ADDI_11:.*]] = arith.addi %[[SUBI_10]], %[[STEP]]
785+
// CHECK: %[[ADDI_12:.*]] = arith.addi %[[ADDI_11]], %[[SEL_10]]
786+
// CHECK: %[[DIVSI_13:.*]] = arith.divsi %[[ADDI_12]], %[[STEP]]
787+
// CHECK: %[[ADDI_14:.*]] = arith.addi %[[DIVSI_13]], %[[CM1]]
783788
// CHECK: %[[MULI_15:.*]] = arith.muli %{{.*}}, %[[ADDI_14]]
784789
// CHECK: %[[ADDI_16:.*]] = arith.addi %{{.*}}, %[[MULI_15]]
785790
// CHECK: %[[CMPI_17:.*]] = arith.cmpi sge, %[[ADDI_14]], %[[C0]]
786-
// CHECK: %[[ADDI_18:.*]] = arith.addi %[[DIVUI_13]], %{{.*}}-1
791+
// CHECK: %[[ADDI_18:.*]] = arith.addi %[[DIVSI_13]], %{{.*}}-1
787792
// CHECK: %[[ADDI_19:.*]] = arith.addi %[[ADDI_18]], %{{.*}}-1
788793
// CHECK: %[[MULI_20:.*]] = arith.muli %{{.*}}, %[[ADDI_19]]
789794
// CHECK: %[[ADDI_21:.*]] = arith.addi %{{.*}}, %[[MULI_20]]
@@ -834,32 +839,38 @@ func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %
834839

835840
// Check for predicated epilogue for dynamic loop.
836841
// CHECK-LABEL: func.func @dynamic_loop_result
837-
// CHECK: %{{.*}}:2 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
842+
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
843+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
844+
// CHECK-DAG: %[[CM1:.*]] = arith.constant -1 : index
845+
// CHECK: %[[UBM:.*]] = arith.subi %[[UB:.*]], %{{.*}}
846+
// CHECK: %{{.*}}:2 = scf.for %[[ARG5:.*]] = %[[LB:.*]] to %[[UBM]] step %[[STEP:.*]] iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
838847
// CHECK: %[[ADDF_13:.*]] = arith.addf %[[ARG7]], %[[ARG6]]
839848
// CHECK: %[[MULF_14:.*]] = arith.mulf %[[ADDF_13]], %{{.*}}
840849
// CHECK: %[[ADDI_15:.*]] = arith.addi %[[ARG5]], %{{.*}}
841850
// CHECK: %[[LOAD_16:.*]] = memref.load %{{.*}}[%[[ADDI_15]]]
842851
// CHECK: scf.yield %[[MULF_14]], %[[LOAD_16]]
843852
// CHECK: }
844-
// CHECK: %[[SUBI_4:.*]] = arith.subi %{{.*}}, %{{.*}}
845-
// CHECK: %[[ADDI_5:.*]] = arith.addi %[[SUBI_4]], %{{.*}}
846-
// CHECK: %[[ADDI_6:.*]] = arith.addi %[[ADDI_5]], %{{.*}}-1
847-
// CHECK: %[[DIVUI_7:.*]] = arith.divui %[[ADDI_6]], %{{.*}}
848-
// CHECK: %[[ADDI_8:.*]] = arith.addi %[[DIVUI_7]], %{{.*}}-1
849-
// CHECK: %[[CMPI_9:.*]] = arith.cmpi sge, %[[ADDI_8]], %{{.*}}
850-
// CHECK: %[[IF_10:.*]] = scf.if %[[CMPI_9]]
853+
// CHECK: %[[CMPI_4:.*]] = arith.cmpi slt, %[[STEP]], %[[C0]]
854+
// CHECK: %[[SELECT_5:.*]] = arith.select %[[CMPI_4]], %[[C1]], %[[CM1]]
855+
// CHECK: %[[SUBI_6:.*]] = arith.subi %[[UB]], %[[LB]]
856+
// CHECK: %[[ADDI_7:.*]] = arith.addi %[[SUBI_6]], %[[STEP]]
857+
// CHECK: %[[ADDI_8:.*]] = arith.addi %[[ADDI_7]], %[[SELECT_5]]
858+
// CHECK: %[[DIVSI_9:.*]] = arith.divsi %[[ADDI_8]], %[[STEP]]
859+
// CHECK: %[[ADDI_10:.*]] = arith.addi %[[DIVSI_9]], %[[CM1]]
860+
// CHECK: %[[CMPI_11:.*]] = arith.cmpi sge, %[[ADDI_10]], %[[C0]]
861+
// CHECK: %[[IF_10:.*]] = scf.if %[[CMPI_11]]
851862
// CHECK: %[[ADDF_13:.*]] = arith.addf %{{.*}}#1, %{{.*}}#0
852863
// CHECK: scf.yield %[[ADDF_13]]
853864
// CHECK: } else {
854865
// CHECK: scf.yield %{{.*}}
855866
// CHECK: }
856-
// CHECK: %[[IF_11:.*]] = scf.if %[[CMPI_9]]
867+
// CHECK: %[[IF_11:.*]] = scf.if %[[CMPI_11]]
857868
// CHECK: %[[MULF_13:.*]] = arith.mulf %[[IF_10]], %{{.*}}
858869
// CHECK: scf.yield %[[MULF_13]]
859870
// CHECK: } else {
860871
// CHECK: scf.yield %{{.*}}
861872
// CHECK: }
862-
// CHECK: %[[SELECT_12:.*]] = arith.select %[[CMPI_9]], %[[IF_11]], %{{.*}}#0
873+
// CHECK: %[[SELECT_12:.*]] = arith.select %[[CMPI_11]], %[[IF_11]], %{{.*}}#0
863874
// CHECK: memref.store %[[SELECT_12]], %{{.*}}[%{{.*}}]
864875
func.func @dynamic_loop_result(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %ub: index, %step: index) {
865876
%cf0 = arith.constant 1.0 : f32

0 commit comments

Comments
 (0)