Skip to content

[mlir][scf] Fix loop iteration calculation for negative step in LoopPipelining #110035

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 25, 2024

Conversation

sjw36
Copy link
Contributor

@sjw36 sjw36 commented Sep 25, 2024

This fixes loop iteration count calculation if the step is
a negative value, where we should adjust the added
delta from `step-1` to `step+1` when doing the ceil div.

…ipelining

    This fixes loop iteration count calculation if the step is
    a negative value, where we should adjust the added
    delta from `step-1` to `step+1` when doing the ceil div.
@llvmbot
Copy link
Member

llvmbot commented Sep 25, 2024

@llvm/pr-subscribers-mlir-scf

@llvm/pr-subscribers-mlir

Author: SJW (sjw36)

Changes
This fixes loop iteration count calculation if the step is
a negative value, where we should adjust the added
delta from `step-1` to `step+1` when doing the ceil div.

Full diff: https://github.com/llvm/llvm-project/pull/110035.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp (+15-8)
  • (modified) mlir/test/Dialect/SCF/loop-pipelining.mlir (+29-18)
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 3d6da066875f99..83c9cf69ba0364 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -648,15 +648,22 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
   // bounds_range = ub - lb
   // total_iterations = (bounds_range + step - 1) / step
   Type t = lb.getType();
-  Value minus1 =
-      rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
-  Value boundsRange = rewriter.create<arith::SubIOp>(loc, ub, lb);
-  Value rangeIncr = rewriter.create<arith::AddIOp>(loc, boundsRange, step);
-  Value rangeDecr = rewriter.create<arith::AddIOp>(loc, rangeIncr, minus1);
-  Value totalIterations = rewriter.create<arith::DivUIOp>(loc, rangeDecr, step);
-
   Value zero =
       rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 0));
+  Value one =
+      rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 1));
+  Value minusOne =
+      rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
+  Value stepLessZero = rewriter.create<arith::CmpIOp>(
+      loc, arith::CmpIPredicate::slt, step, zero);
+  Value stepDecr =
+      rewriter.create<arith::SelectOp>(loc, stepLessZero, one, minusOne);
+
+  Value rangeDiff = rewriter.create<arith::SubIOp>(loc, ub, lb);
+  Value rangeIncrStep = rewriter.create<arith::AddIOp>(loc, rangeDiff, step);
+  Value rangeDecr =
+      rewriter.create<arith::AddIOp>(loc, rangeIncrStep, stepDecr);
+  Value totalIterations = rewriter.create<arith::DivSIOp>(loc, rangeDecr, step);
 
   SmallVector<Value> predicates(maxStage + 1);
   for (int64_t i = 0; i < maxStage; i++) {
@@ -665,7 +672,7 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
     Value minusI =
         rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -i));
     Value iterI = rewriter.create<arith::AddIOp>(
-        loc, rewriter.create<arith::AddIOp>(loc, totalIterations, minus1),
+        loc, rewriter.create<arith::AddIOp>(loc, totalIterations, minusOne),
         minusI);
     // newLastIter = lb + step * iterI
     Value newlastIter = rewriter.create<arith::AddIOp>(
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 4747aad977a492..af49d2afc049ba 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -766,8 +766,11 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
 
 // Check for predicated epilogue for dynamic loop.
 // CHECK-LABEL: dynamic_loop(
-//        CHECK:   %[[C0:.*]] = arith.constant 0 : index
-//        CHECK:   %{{.*}}:2 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
+//    CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+//    CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//    CHECK-DAG:   %[[CM1:.*]] = arith.constant -1 : index
+//        CHECK:   %[[UBM:.*]] = arith.subi %[[UB:.*]], %{{.*}}
+//        CHECK:   %{{.*}}:2 = scf.for %[[ARG5:.*]] = %[[LB:.*]] to %[[UBM]] step %[[STEP:.*]] iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
 //        CHECK:       memref.store %[[ARG6]], %{{.*}}[%[[ARG5]]]
 //        CHECK:       %[[ADDF_24:.*]] = arith.addf %[[ARG7]], %{{.*}}
 //        CHECK:       %[[MULI_25:.*]] = arith.muli %{{.*}}, %{{.*}}
@@ -775,15 +778,17 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
 //        CHECK:       %[[LOAD_27:.*]] = memref.load %{{.*}}[%[[ADDI_26]]]
 //        CHECK:       scf.yield %[[ADDF_24]], %[[LOAD_27]]
 //        CHECK:   }
-//        CHECK:   %[[SUBI_10:.*]] = arith.subi %{{.*}}, %{{.*}}
-//        CHECK:   %[[ADDI_11:.*]] = arith.addi %[[SUBI_10]], %{{.*}}
-//        CHECK:   %[[ADDI_12:.*]] = arith.addi %[[ADDI_11]], %{{.*}}-1
-//        CHECK:   %[[DIVUI_13:.*]] = arith.divui %[[ADDI_12]], %{{.*}}
-//        CHECK:   %[[ADDI_14:.*]] = arith.addi %[[DIVUI_13]], %{{.*}}-1
+//        CHECK:   %[[CMPI_10:.*]] = arith.cmpi slt, %[[STEP]], %[[C0]]
+//        CHECK:   %[[SEL_10:.*]] = arith.select %[[CMPI_10]], %[[C1]], %[[CM1]]
+//        CHECK:   %[[SUBI_10:.*]] = arith.subi %[[UB]], %[[LB]]
+//        CHECK:   %[[ADDI_11:.*]] = arith.addi %[[SUBI_10]], %[[STEP]]
+//        CHECK:   %[[ADDI_12:.*]] = arith.addi %[[ADDI_11]], %[[SEL_10]]
+//        CHECK:   %[[DIVSI_13:.*]] = arith.divsi %[[ADDI_12]], %[[STEP]]
+//        CHECK:   %[[ADDI_14:.*]] = arith.addi %[[DIVSI_13]], %[[CM1]]
 //        CHECK:   %[[MULI_15:.*]] = arith.muli %{{.*}}, %[[ADDI_14]]
 //        CHECK:   %[[ADDI_16:.*]] = arith.addi %{{.*}}, %[[MULI_15]]
 //        CHECK:   %[[CMPI_17:.*]] = arith.cmpi sge, %[[ADDI_14]], %[[C0]]
-//        CHECK:   %[[ADDI_18:.*]] = arith.addi %[[DIVUI_13]], %{{.*}}-1
+//        CHECK:   %[[ADDI_18:.*]] = arith.addi %[[DIVSI_13]], %{{.*}}-1
 //        CHECK:   %[[ADDI_19:.*]] = arith.addi %[[ADDI_18]], %{{.*}}-1
 //        CHECK:   %[[MULI_20:.*]] = arith.muli %{{.*}}, %[[ADDI_19]]
 //        CHECK:   %[[ADDI_21:.*]] = arith.addi %{{.*}}, %[[MULI_20]]
@@ -834,32 +839,38 @@ func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %
 
 // Check for predicated epilogue for dynamic loop.
 // CHECK-LABEL:   func.func @dynamic_loop_result
-//       CHECK:     %{{.*}}:2 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
+//   CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[CM1:.*]] = arith.constant -1 : index
+//       CHECK:   %[[UBM:.*]] = arith.subi %[[UB:.*]], %{{.*}}
+//       CHECK:   %{{.*}}:2 = scf.for %[[ARG5:.*]] = %[[LB:.*]] to %[[UBM]] step %[[STEP:.*]] iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
 //       CHECK:       %[[ADDF_13:.*]] = arith.addf %[[ARG7]], %[[ARG6]]
 //       CHECK:       %[[MULF_14:.*]] = arith.mulf %[[ADDF_13]], %{{.*}}
 //       CHECK:       %[[ADDI_15:.*]] = arith.addi %[[ARG5]], %{{.*}}
 //       CHECK:       %[[LOAD_16:.*]] = memref.load %{{.*}}[%[[ADDI_15]]]
 //       CHECK:       scf.yield %[[MULF_14]], %[[LOAD_16]]
 //       CHECK:     }
-//       CHECK:     %[[SUBI_4:.*]] = arith.subi %{{.*}}, %{{.*}}
-//       CHECK:     %[[ADDI_5:.*]] = arith.addi %[[SUBI_4]], %{{.*}}
-//       CHECK:     %[[ADDI_6:.*]] = arith.addi %[[ADDI_5]], %{{.*}}-1
-//       CHECK:     %[[DIVUI_7:.*]] = arith.divui %[[ADDI_6]], %{{.*}}
-//       CHECK:     %[[ADDI_8:.*]] = arith.addi %[[DIVUI_7]], %{{.*}}-1
-//       CHECK:     %[[CMPI_9:.*]] = arith.cmpi sge, %[[ADDI_8]], %{{.*}}
-//       CHECK:     %[[IF_10:.*]] = scf.if %[[CMPI_9]]
+//       CHECK:     %[[CMPI_4:.*]] = arith.cmpi slt, %[[STEP]], %[[C0]]
+//       CHECK:     %[[SELECT_5:.*]] = arith.select %[[CMPI_4]], %[[C1]], %[[CM1]]
+//       CHECK:     %[[SUBI_6:.*]] = arith.subi %[[UB]], %[[LB]]
+//       CHECK:     %[[ADDI_7:.*]] = arith.addi %[[SUBI_6]], %[[STEP]]
+//       CHECK:     %[[ADDI_8:.*]] = arith.addi %[[ADDI_7]], %[[SELECT_5]]
+//       CHECK:     %[[DIVSI_9:.*]] = arith.divsi %[[ADDI_8]], %[[STEP]]
+//       CHECK:     %[[ADDI_10:.*]] = arith.addi %[[DIVSI_9]], %[[CM1]]
+//       CHECK:     %[[CMPI_11:.*]] = arith.cmpi sge, %[[ADDI_10]], %[[C0]]
+//       CHECK:     %[[IF_10:.*]] = scf.if %[[CMPI_11]]
 //       CHECK:       %[[ADDF_13:.*]] = arith.addf %{{.*}}#1, %{{.*}}#0
 //       CHECK:       scf.yield %[[ADDF_13]]
 //       CHECK:     } else {
 //       CHECK:       scf.yield %{{.*}}
 //       CHECK:     }
-//       CHECK:     %[[IF_11:.*]] = scf.if %[[CMPI_9]]
+//       CHECK:     %[[IF_11:.*]] = scf.if %[[CMPI_11]]
 //       CHECK:       %[[MULF_13:.*]] = arith.mulf %[[IF_10]], %{{.*}}
 //       CHECK:       scf.yield %[[MULF_13]]
 //       CHECK:     } else {
 //       CHECK:       scf.yield %{{.*}}
 //       CHECK:     }
-//       CHECK:     %[[SELECT_12:.*]] = arith.select %[[CMPI_9]], %[[IF_11]], %{{.*}}#0
+//       CHECK:     %[[SELECT_12:.*]] = arith.select %[[CMPI_11]], %[[IF_11]], %{{.*}}#0
 //       CHECK:     memref.store %[[SELECT_12]], %{{.*}}[%{{.*}}]
 func.func @dynamic_loop_result(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %ub: index, %step: index) {
   %cf0 = arith.constant 1.0 : f32

@antiagainst antiagainst merged commit 7645d9c into llvm:main Sep 25, 2024
9 of 10 checks passed
Sterling-Augustine pushed a commit to Sterling-Augustine/llvm-project that referenced this pull request Sep 27, 2024
…ipelining (llvm#110035)

This fixes loop iteration count calculation if the step is
    a negative value, where we should adjust the added
    delta from `step-1` to `step+1` when doing the ceil div.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants