
[MLIR][SCF] Add support for loop pipeline peeling for dynamic loops. #106436


Merged
merged 5 commits into llvm:main on Sep 4, 2024

Conversation

Contributor

@sjw36 sjw36 commented Aug 28, 2024

* Allow speculative execution and predicate results per stage.

@llvmbot
Member

llvmbot commented Aug 28, 2024

@llvm/pr-subscribers-mlir-scf

@llvm/pr-subscribers-mlir

Author: SJW (sjw36)

Changes
* Allow speculative execution and predicate results per stage.

Full diff: https://github.com/llvm/llvm-project/pull/106436.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp (+86-38)
  • (modified) mlir/test/Dialect/SCF/loop-pipelining.mlir (+39-4)
  • (modified) mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp (+2-2)
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index d8e1cc0ecef88e..95fa7c8b0ef7d5 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -94,8 +94,8 @@ struct LoopPipelinerInternal {
       RewriterBase &rewriter);
   /// Emits the epilogue, this creates `maxStage - 1` part which will contain
   /// operations from stages [i; maxStage], where i is the part index.
-  void emitEpilogue(RewriterBase &rewriter,
-                    llvm::SmallVector<Value> &returnValues);
+  LogicalResult emitEpilogue(RewriterBase &rewriter,
+                             llvm::SmallVector<Value> &returnValues);
 };
 
 bool LoopPipelinerInternal::initializeLoopInfo(
@@ -133,10 +133,6 @@ bool LoopPipelinerInternal::initializeLoopInfo(
     LDBG("--no epilogue or predicate set -> BAIL");
     return false;
   }
-  if (dynamicLoop && peelEpilogue) {
-    LDBG("--dynamic loop doesn't support epilogue yet -> BAIL");
-    return false;
-  }
   std::vector<std::pair<Operation *, unsigned>> schedule;
   options.getScheduleFn(forOp, schedule);
   if (schedule.empty()) {
@@ -313,10 +309,10 @@ void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
           });
       int predicateIdx = i - stages[op];
       if (predicates[predicateIdx]) {
+        OpBuilder::InsertionGuard insertGuard(rewriter);
         newOp = predicateFn(rewriter, newOp, predicates[predicateIdx]);
         assert(newOp && "failed to predicate op.");
       }
-      rewriter.setInsertionPointAfter(newOp);
       if (annotateFn)
         annotateFn(newOp, PipeliningOption::PipelinerPart::Prologue, i);
       for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
@@ -561,6 +557,7 @@ LogicalResult LoopPipelinerInternal::createKernel(
     }
 
     if (predicates[useStage]) {
+      OpBuilder::InsertionGuard insertGuard(rewriter);
       newOp = predicateFn(rewriter, newOp, predicates[useStage]);
       if (!newOp)
         return failure();
@@ -568,7 +565,6 @@ LogicalResult LoopPipelinerInternal::createKernel(
       for (auto values : llvm::zip(op->getResults(), newOp->getResults()))
         mapping.map(std::get<0>(values), std::get<1>(values));
     }
-    rewriter.setInsertionPointAfter(newOp);
     if (annotateFn)
       annotateFn(newOp, PipeliningOption::PipelinerPart::Kernel, 0);
   }
@@ -640,70 +636,121 @@ LogicalResult LoopPipelinerInternal::createKernel(
   return success();
 }
 
-void LoopPipelinerInternal::emitEpilogue(
-    RewriterBase &rewriter, llvm::SmallVector<Value> &returnValues) {
+LogicalResult
+LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
+                                    llvm::SmallVector<Value> &returnValues) {
+  Location loc = forOp.getLoc();
   // Emit different versions of the induction variable. They will be
   // removed by dead code if not used.
+
+  // bounds_range = ub - lb
+  // total_iterations = bounds_range / step + (bounds_range % step ? 1 : 0)
+  Type t = lb.getType();
+  Value minus1 =
+      rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
+
+  Value const_0 =
+      rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 0));
+  Value const_1 =
+      rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 1));
+  Value boundsRange = rewriter.create<arith::SubIOp>(loc, ub, lb);
+  Value boundsRem = rewriter.create<arith::RemUIOp>(loc, boundsRange, step);
+  Value hasRem = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne,
+                                                boundsRem, const_0);
+  Value totalIterations = rewriter.create<arith::AddIOp>(
+      loc, rewriter.create<arith::DivUIOp>(loc, boundsRange, step),
+      rewriter.create<arith::SelectOp>(loc, hasRem, const_1, const_0));
+
+  SmallVector<Value> predicates(maxStage + 1);
   for (int64_t i = 0; i < maxStage; i++) {
-    Location loc = forOp.getLoc();
-    Type t = lb.getType();
-    Value minusOne =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
-    // number of iterations = ((ub - 1) - lb) / step
-    Value totalNumIteration = rewriter.create<arith::DivUIOp>(
-        loc,
-        rewriter.create<arith::SubIOp>(
-            loc, rewriter.create<arith::AddIOp>(loc, ub, minusOne), lb),
-        step);
-    // newLastIter = lb + step * ((((ub - 1) - lb) / step) - i)
+    // iterI = total_iters - 1 - i
+    // May go negative...
     Value minusI =
         rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -i));
+    Value iterI = rewriter.create<arith::AddIOp>(
+        loc, rewriter.create<arith::AddIOp>(loc, totalIterations, minus1),
+        minusI);
+    // newLastIter = lb + step * iterI
     Value newlastIter = rewriter.create<arith::AddIOp>(
-        loc, lb,
-        rewriter.create<arith::MulIOp>(
-            loc, step,
-            rewriter.create<arith::AddIOp>(loc, totalNumIteration, minusI)));
+        loc, lb, rewriter.create<arith::MulIOp>(loc, step, iterI));
+
     setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
+
+    if (dynamicLoop) {
+      // pred = iterI >= 0
+      predicates[i + 1] = rewriter.create<arith::CmpIOp>(
+          loc, arith::CmpIPredicate::sge, iterI, const_0);
+    }
   }
+
   // Emit `maxStage - 1` epilogue part that includes operations from stages
   // [i; maxStage].
   for (int64_t i = 1; i <= maxStage; i++) {
+    SmallVector<std::pair<Value, unsigned>> returnMap(returnValues.size());
     for (Operation *op : opOrder) {
       if (stages[op] < i)
         continue;
+      unsigned currentVersion = maxStage - stages[op] + i;
+      unsigned nextVersion = currentVersion + 1;
       Operation *newOp =
           cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) {
             auto it = valueMapping.find(newOperand->get());
             if (it != valueMapping.end()) {
-              Value replacement = it->second[maxStage - stages[op] + i];
+              Value replacement = it->second[currentVersion];
               newOperand->set(replacement);
             }
           });
       if (annotateFn)
         annotateFn(newOp, PipeliningOption::PipelinerPart::Epilogue, i - 1);
-      for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
-        setValueMapping(op->getResult(destId), newOp->getResult(destId),
-                        maxStage - stages[op] + i);
+      if (dynamicLoop) {
+        OpBuilder::InsertionGuard insertGuard(rewriter);
+        newOp = predicateFn(rewriter, newOp, predicates[currentVersion]);
+        if (!newOp)
+          return failure();
+      }
+
+      for (auto [opRes, newRes] :
+           llvm::zip(op->getResults(), newOp->getResults())) {
+        setValueMapping(opRes, newRes, currentVersion);
         // If the value is a loop carried dependency update the loop argument
         // mapping and keep track of the last version to replace the original
         // forOp uses.
         for (OpOperand &operand :
              forOp.getBody()->getTerminator()->getOpOperands()) {
-          if (operand.get() != op->getResult(destId))
+          if (operand.get() != opRes)
             continue;
-          unsigned version = maxStage - stages[op] + i + 1;
           // If the version is greater than maxStage it means it maps to the
           // original forOp returned value.
-          if (version > maxStage) {
-            returnValues[operand.getOperandNumber()] = newOp->getResult(destId);
-            continue;
-          }
-          setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
-                          newOp->getResult(destId), version);
+          unsigned ri = operand.getOperandNumber();
+          returnValues[ri] = newRes;
+          Value mapVal = forOp.getRegionIterArgs()[ri];
+          returnMap[ri] = std::make_pair(mapVal, currentVersion);
+          if (nextVersion <= maxStage)
+            setValueMapping(mapVal, newRes, nextVersion);
+        }
+      }
+    }
+    if (dynamicLoop) {
+      // Select return values from this stage (live outs) based on predication.
+      // If the stage is valid select the peeled value, else use previous stage
+      // value.
+      for (auto pair : llvm::enumerate(returnValues)) {
+        unsigned ri = pair.index();
+        auto [mapVal, currentVersion] = returnMap[ri];
+        if (mapVal) {
+          unsigned nextVersion = currentVersion + 1;
+          Value pred = predicates[currentVersion];
+          Value prevValue = valueMapping[mapVal][currentVersion];
+          auto selOp = rewriter.create<arith::SelectOp>(loc, pred, pair.value(),
+                                                        prevValue);
+          returnValues[ri] = selOp;
+          if (nextVersion <= maxStage)
+            setValueMapping(mapVal, selOp, nextVersion);
         }
       }
     }
   }
+  return success();
 }
 
 void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) {
@@ -760,7 +807,8 @@ FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
   if (options.peelEpilogue) {
     // 4. Emit the epilogue after the new forOp.
     rewriter.setInsertionPointAfter(newForOp);
-    pipeliner.emitEpilogue(rewriter, returnValues);
+    if (failed(pipeliner.emitEpilogue(rewriter, returnValues)))
+      return failure();
   }
   // 5. Erase the original loop and replace the uses with the epilogue output.
   if (forOp->getNumResults() > 0)
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 9687f80f5ddfc8..957dc5295c0583 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -764,11 +764,46 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
 //      NOEPILOGUE:     memref.load %[[A]][%[[IV3]]] : memref<?xf32>
 //      NOEPILOGUE:   scf.yield %[[V2]], %[[L3]] : f32, f32
 
-// In case dynamic loop pipelining is off check that the transformation didn't
-// apply.
+// Check for predicated epilogue for dynamic loop.
 // CHECK-LABEL: dynamic_loop(
-//   CHECK-NOT:   memref.load
-//       CHECK:   scf.for
+//        CHECK:   %{{.*}}:2 = scf.for %[[ARG5:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
+//        CHECK:       memref.store %[[ARG6]], %{{.*}}[%[[ARG5]]] 
+//        CHECK:       %[[ADDF_26:.*]] = arith.addf %[[ARG7]], %{{.*}} 
+//        CHECK:       %[[MULI_27:.*]] = arith.muli %{{.*}}, %{{.*}} 
+//        CHECK:       %[[ADDI_28:.*]] = arith.addi %[[ARG5]], %[[MULI_27]] 
+//        CHECK:       %[[LOAD_29:.*]] = memref.load %{{.*}}[%[[ADDI_28]]] 
+//        CHECK:       scf.yield %[[ADDF_26]], %[[LOAD_29]] 
+//        CHECK:   }
+//        CHECK:   %[[SUBI_10:.*]] = arith.subi %{{.*}}, %{{.*}} 
+//        CHECK:   %[[REMUI_11:.*]] = arith.remui %[[SUBI_10]], %{{.*}} 
+//        CHECK:   %[[CMPI_12:.*]] = arith.cmpi ne, %[[REMUI_11]], %{{.*}} 
+//        CHECK:   %[[SELECT_13:.*]] = arith.select %[[CMPI_12]], %{{.*}}, %{{.*}} 
+//        CHECK:   %[[DIVUI_14:.*]] = arith.divui %[[SUBI_10]], %{{.*}} 
+//        CHECK:   %[[ADDI_15:.*]] = arith.addi %[[DIVUI_14]], %[[SELECT_13]] 
+//        CHECK:   %[[ADDI_16:.*]] = arith.addi %[[ADDI_15]], %{{.*}}-1 
+//        CHECK:   %[[MULI_17:.*]] = arith.muli %{{.*}}, %[[ADDI_16]] 
+//        CHECK:   %[[ADDI_18:.*]] = arith.addi %{{.*}}, %[[MULI_17]] 
+//        CHECK:   %[[CMPI_19:.*]] = arith.cmpi sge, %[[ADDI_16]], %{{.*}} 
+//        CHECK:   %[[ADDI_20:.*]] = arith.addi %[[ADDI_15]], %{{.*}}-1 
+//        CHECK:   %[[ADDI_21:.*]] = arith.addi %[[ADDI_20]], %{{.*}}-1 
+//        CHECK:   %[[MULI_22:.*]] = arith.muli %{{.*}}, %[[ADDI_21]] 
+//        CHECK:   %[[ADDI_23:.*]] = arith.addi %{{.*}}, %[[MULI_22]] 
+//        CHECK:   %[[CMPI_24:.*]] = arith.cmpi sge, %[[ADDI_21]], %{{.*}} 
+//        CHECK:   scf.if %[[CMPI_19]] {
+//        CHECK:     memref.store %{{.*}}#0, %{{.*}}[%[[ADDI_23]]] 
+//        CHECK:   } else {
+//        CHECK:   }
+//        CHECK:   %[[IF_25:.*]] = scf.if %[[CMPI_24]] -> (f32) {
+//        CHECK:     %[[ADDF_26:.*]] = arith.addf %{{.*}}#1, %{{.*}} 
+//        CHECK:     scf.yield %[[ADDF_26]] 
+//        CHECK:   } else {
+//        CHECK:     scf.yield %{{.*}} 
+//        CHECK:   }
+//        CHECK:   scf.if %[[CMPI_24]] {
+//        CHECK:     memref.store %[[IF_25]], %{{.*}}[%[[ADDI_18]]] 
+//        CHECK:   } else {
+//        CHECK:   }
+//        CHECK:   return
 func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %ub: index, %step: index) {
   %cf = arith.constant 1.0 : f32
   scf.for %i0 = %lb to %ub step %step {
diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
index 8a92d840ad1302..3ff7f9966e93da 100644
--- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
@@ -214,12 +214,12 @@ struct TestSCFPipeliningPass
     RewritePatternSet patterns(&getContext());
     mlir::scf::PipeliningOption options;
     options.getScheduleFn = getSchedule;
+    options.supportDynamicLoops = true;
+    options.predicateFn = predicateOp;
     if (annotatePipeline)
       options.annotateFn = annotate;
     if (noEpiloguePeeling) {
-      options.supportDynamicLoops = true;
       options.peelEpilogue = false;
-      options.predicateFn = predicateOp;
     }
     scf::populateSCFLoopPipeliningPatterns(patterns, options);
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
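
The iteration-count and per-stage predicate arithmetic that the patched `emitEpilogue` generates can be modeled with a small Python sketch. This is a hypothetical model of the arithmetic only (the function and variable names are mine, not from the pass), assuming non-negative `lb`, `ub`, and `step`, matching the unsigned div/rem the patch emits:

```python
def epilogue_indices(lb, ub, step, max_stage):
    """Model the per-stage IV and predicate the patched emitEpilogue builds.

    Mirrors:
      total_iterations = (ub - lb) / step + ((ub - lb) % step ? 1 : 0)
    Returns one (last_iter, predicate) pair per peeled stage.
    """
    bounds_range = ub - lb
    total_iterations = bounds_range // step + (1 if bounds_range % step else 0)
    out = []
    for i in range(max_stage):
        iter_i = total_iterations - 1 - i     # may go negative, as the patch notes
        last_iter = lb + step * iter_i        # newLastIter = lb + step * iterI
        out.append((last_iter, iter_i >= 0))  # pred = iterI >= 0 (dynamic loops)
    return out

# Four total iterations (IVs 0, 3, 6, 9): both peeled stages are valid.
assert epilogue_indices(0, 10, 3, 2) == [(9, True), (6, True)]
# One total iteration: the older peeled stage is predicated off.
assert epilogue_indices(0, 2, 3, 2) == [(0, True), (-3, False)]
```

This is why the predicate per stage is needed: with a short dynamic trip count, `iterI` goes negative for the earlier peeled stages and those stages must not execute.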

* Allow speculative execution and predicate results per stage.
@sjw36 sjw36 force-pushed the mlir-pipeline-peel-dynamic branch from e1dcc2b to 9be66a1 Compare August 29, 2024 19:46
Member

@antiagainst antiagainst left a comment
Overall LGTM, but I'm not super familiar here, so I will wait for @ThomasRaoux for the final approval.

}
}
}
if (dynamicLoop) {
Member

Can we add a test to exercise this case?

@@ -764,11 +764,44 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
// NOEPILOGUE: memref.load %[[A]][%[[IV3]]] : memref<?xf32>
// NOEPILOGUE: scf.yield %[[V2]], %[[L3]] : f32, f32

// In case dynamic loop pipelining is off check that the transformation didn't
Contributor
Do we want to keep a test checking this case (not pipelining when dynamic loop support is turned off?)

Contributor Author
I would need to add a new switch to TestSCFUtils and probably a new test file so we don't run all the tests again without dynamic loop support. Or perhaps add it to the annotate run? Is that acceptable?

Contributor
Hmm, not sure if this is worth the effort to validate that the transformation is disabled. I think I'm OK if you'd want to skip it.

for (auto pair : llvm::enumerate(returnValues)) {
unsigned ri = pair.index();
auto [mapVal, currentVersion] = returnMap[ri];
if (mapVal) {
Contributor
Do we need to predicate all the return values? I would think that we could predicate only the values that are later used outside of the loop, otherwise it is OK to speculatively execute.

Contributor Author
When maxStage > 2, multiple stages are peeled. But if K is only 1, only the last stage would be executed, with selected results bypassing the previous peeled stages to the loop results (which would actually be the init values).

Some results may not be used outside the loop and would be optimized away. But since we capture these as we peel each iteration, they feed into the next iteration, and the final set replaces the forOp results.

Contributor
I think I see what you mean. This will happen only for dependencies within the same stage, right? For example:

```
i = i + 1
store(ptr, i)
```

If both ops are in the same stage (say, the last), you need to predicate `i = i + 1`; otherwise, once you finally get to execute the store, you have the wrong value of `i`. But if `i = i + 1` were in the previous stage, the normal accounting for value versions would take care of it.

Contributor Author

@sjw36 sjw36 Sep 3, 2024
Yes, versioning takes care of that dependency. But a case where each stage returns a new value based on the old value requires the select:

```
%result:2 = scf.for {...}
// Stage N-2
%s1 = mul %result#0, %c32
%sel1 = select %valid_stage_1, %s1, %result#0
// Stage N-1
%s2 = mul %sel1, %c32
%sel2 = select %valid_stage_2, %s2, %sel1
```

I will add an example test.
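
The chaining described above can be illustrated with a tiny Python model (the names are mine and the `* 32` stands in for each peeled stage's computation): each peeled stage either applies its computation or forwards the previous stage's value, so the selects must chain:

```python
def chained_selects(loop_result, stage_valid):
    """stage_valid[i] is the predicate of peeled stage i, oldest first."""
    acc = loop_result
    for valid in stage_valid:
        computed = acc * 32               # this stage's own computation
        acc = computed if valid else acc  # per-stage arith.select
    return acc

# Both peeled stages valid: both multiplies apply.
assert chained_selects(2, [True, True]) == 2048
# Short trip count: the first peeled stage is invalid and its multiply is dropped.
assert chained_selects(2, [False, True]) == 64
```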

Contributor

Yes, please include an example; I'm not sure I understand in what case that would be needed.
My thinking is that if the value doesn't escape the loop, then any uses of an op that was predicated should also be predicated, so we shouldn't need the select.

Contributor

@pawelszczerbuk pawelszczerbuk Sep 3, 2024
> But a case where each stage returns a new value based on the old value requires the select.

Could we have a check for that, instead of adding selects for all the return values? I can imagine removing them may be hard afterwards

EDIT: actually I take that back. We have spent some time with @ThomasRaoux analyzing different cases and we agree that predicates are needed for all the return values. I guess the test case for it won't hurt :) But the code looks correct!

Contributor Author
For example:

```mlir
func.func @dynamic_loop_result(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %ub: index, %step: index) {
  %cf0 = arith.constant 1.0 : f32
  %cf1 = arith.constant 33.0 : f32
  %cst = arith.constant 0 : index
  %res:1 = scf.for %i0 = %lb to %ub step %step iter_args (%arg0 = %cf0) -> (f32) {
    %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref<?xf32>
    %A1_elem = arith.addf %A_elem, %arg0 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 0 } : f32
    %A2_elem = arith.mulf %A1_elem, %cf1 { __test_pipelining_stage__ = 1, __test_pipelining_op_order__ = 1 } : f32
    scf.yield %A2_elem : f32
  } { __test_pipelining_loop__ }
  memref.store %res#0, %result[%cst] : memref<?xf32>
  return
}
```

I see now the example predicates every operation using the predicateFn, not just the side-effecting ops. So this becomes:

```mlir
func.func @dynamic_loop_result(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: index, %arg3: index, %arg4: index) {
  %c-1 = arith.constant -1 : index
  %cst = arith.constant 0.000000e+00 : f32
  %cst_0 = arith.constant 1.000000e+00 : f32
  %cst_1 = arith.constant 3.300000e+01 : f32
  %c0 = arith.constant 0 : index
  %0 = arith.cmpi slt, %arg2, %arg3 : index
  %1 = scf.if %0 -> (f32) {
    %13 = memref.load %arg0[%arg2] : memref<?xf32>
    scf.yield %13 : f32
  } else {
    scf.yield %cst : f32
  }
  %2 = arith.subi %arg3, %arg4 : index
  %3:2 = scf.for %arg5 = %arg2 to %2 step %arg4 iter_args(%arg6 = %cst_0, %arg7 = %1) -> (f32, f32) {
    %13 = arith.addf %arg7, %arg6 : f32
    %14 = arith.mulf %13, %cst_1 : f32
    %15 = arith.addi %arg5, %arg4 : index
    %16 = memref.load %arg0[%15] : memref<?xf32>
    scf.yield %14, %16 : f32, f32
  }
  %4 = arith.subi %arg3, %arg2 : index
  %5 = arith.addi %4, %arg4 : index
  %6 = arith.addi %5, %c-1 : index
  %7 = arith.divui %6, %arg4 : index
  %8 = arith.addi %7, %c-1 : index
  %9 = arith.cmpi sge, %8, %arg2 : index
  %10 = scf.if %9 -> (f32) {
    %13 = arith.addf %3#1, %3#0 : f32
    scf.yield %13 : f32
  } else {
    scf.yield %cst : f32
  }
  %11 = scf.if %9 -> (f32) {
    %13 = arith.mulf %10, %cst_1 : f32
    scf.yield %13 : f32
  } else {
    scf.yield %cst : f32
  }
  %12 = arith.select %9, %11, %3#0 : f32   /// redundant
  memref.store %12, %arg1[%c0] : memref<?xf32>
  return
}
```

As you can see, every operation is guarded (including ops that do not produce a loop result), so it doesn't really do speculative execution.

If only the side-effecting ops were guarded and only the results were selected based on the stage range, the result would be:

```mlir
func.func @dynamic_loop_result(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: index, %arg3: index, %arg4: index) {
  %c-1 = arith.constant -1 : index
  %cst = arith.constant 0.000000e+00 : f32
  %cst_0 = arith.constant 1.000000e+00 : f32
  %cst_1 = arith.constant 3.300000e+01 : f32
  %c0 = arith.constant 0 : index
  %0 = arith.cmpi slt, %arg2, %arg3 : index
  %1 = scf.if %0 -> (f32) {
    %13 = memref.load %arg0[%arg2] : memref<?xf32>
    scf.yield %13 : f32
  } else {
    scf.yield %cst : f32
  }
  %2 = arith.subi %arg3, %arg4 : index
  %3:2 = scf.for %arg5 = %arg2 to %2 step %arg4 iter_args(%arg6 = %cst_0, %arg7 = %1) -> (f32, f32) {
    %13 = arith.addf %arg7, %arg6 : f32
    %14 = arith.mulf %13, %cst_1 : f32
    %15 = arith.addi %arg5, %arg4 : index
    %16 = memref.load %arg0[%15] : memref<?xf32>
    scf.yield %14, %16 : f32, f32
  }
  %4 = arith.subi %arg3, %arg2 : index
  %5 = arith.addi %4, %arg4 : index
  %6 = arith.addi %5, %c-1 : index
  %7 = arith.divui %6, %arg4 : index
  %8 = arith.addi %7, %c-1 : index
  %9 = arith.cmpi sge, %8, %arg2 : index
  %10 = arith.addf %3#1, %3#0 : f32
  %11 = arith.mulf %10, %cst_1 : f32
  %12 = arith.select %9, %11, %3#0 : f32
  memref.store %12, %arg1[%c0] : memref<?xf32>
  return
}
```

And this seems to be what the Prologue logic is doing as well (see line 343).
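
That preferred lowering (guard only side effects, select the carried result) can be checked against plain loop semantics with a scalar Python model. This is a hypothetical sketch; the function names and the scalar modeling are mine, and it mirrors the two-stage `addf`/`mulf` example above:

```python
def reference(A, lb, ub, step, init=1.0):
    """Plain scf.for semantics: acc = (A[i] + acc) * 33.0 per iteration."""
    acc = init
    for i in range(lb, ub, step):
        acc = (A[i] + acc) * 33.0
    return acc

def pipelined(A, lb, ub, step, init=1.0):
    """Two-stage pipeline: load in stage 0, addf/mulf in stage 1.

    The kernel runs trip-1 iterations; the last addf/mulf pair is peeled
    into the epilogue, executed speculatively, and one select keeps the
    peeled value only when that iteration is valid.
    """
    acc = init
    loaded = A[lb] if lb < ub else 0.0   # predicated prologue load
    i = lb
    while i < ub - step:                 # kernel: scf.for lb to (ub - step)
        acc = (loaded + acc) * 33.0
        loaded = A[i + step]
        i += step
    trip = (ub - lb + step - 1) // step if ub > lb else 0
    valid = trip >= 1                    # epilogue stage predicate
    peeled = (loaded + acc) * 33.0       # speculative addf/mulf
    return peeled if valid else acc      # single arith.select on the result
```

Running both over the same inputs, including zero- and one-trip dynamic bounds, produces identical results, which is the property the predicated epilogue has to preserve.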

Contributor

@pawelszczerbuk pawelszczerbuk left a comment
Thanks! This is a nice improvement!

Contributor

@ThomasRaoux ThomasRaoux left a comment
Looks really good, thanks for the improvements

@sjw36
Contributor Author

sjw36 commented Sep 4, 2024

Thanks folks. I added a test for scf.for with results.

@antiagainst antiagainst merged commit ebf0599 into llvm:main Sep 4, 2024
6 of 7 checks passed
antiagainst pushed a commit to triton-lang/triton that referenced this pull request Sep 4, 2024
Select epilogue results based on iteration predication
and allow speculative execution.

For instance, when pipelining with num_stages==3
```
load (0)
load(1)
local_store(0)
%res = for (0..K-1) {
  dot(i)
  load(i+2)
  local_store(i+1)
}
%d1 = dot(K-2)
local_store(K-1)
%s1 = select %valid_iteration1, %d1, %res#0
%d0 = dot(K-1)
%s0 = select %valid_iteration0, %d0, %s1
```

This mirrors upstream change
llvm/llvm-project#106436
bertmaher pushed a commit to bertmaher/triton that referenced this pull request Dec 10, 2024