[mlir][SCF] Use Affine ops for indexing math. #108450

Merged: 2 commits merged into llvm:main on Sep 28, 2024

MaheshRavishankar (Contributor) commented Sep 12, 2024

For induction variables of `index` type, the indexing math is better represented using affine ops such as `affine.delinearize_index`.

This also further demonstrates that some of these `affine` ops might need to move to a different dialect. For one, these ops only support `IndexType` when they should be able to work with any integer type.

This change also includes some canonicalization patterns for the `affine.delinearize_index` operation to:

  1. Drop unit basis values.
  2. Remove the `delinearize_index` op when the `linear_index` is the induction variable of a normalized loop, the basis has size 1, and that single basis value is also the upper bound of the normalized loop.

Signed-off-by: MaheshRavishankar <[email protected]>
llvmbot (Member) commented Sep 13, 2024

@llvm/pr-subscribers-mlir-scf
@llvm/pr-subscribers-mlir-affine

@llvm/pr-subscribers-mlir

Author: None (MaheshRavishankar)

Changes

For induction variables of `index` type, the indexing math is better represented using affine ops such as `affine.delinearize_index`.

This also further demonstrates that some of these `affine` ops might need to move to a different dialect. For one, these ops only support `IndexType` when they should be able to work with any integer type.


Patch is 47.09 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/108450.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Affine/Passes.td (+1-1)
  • (modified) mlir/include/mlir/Dialect/SCF/Transforms/Passes.td (+1)
  • (modified) mlir/lib/Dialect/SCF/Transforms/ParallelLoopCollapsing.cpp (+1)
  • (modified) mlir/lib/Dialect/SCF/Utils/Utils.cpp (+77-1)
  • (modified) mlir/test/Dialect/Affine/loop-coalescing.mlir (+118-144)
  • (modified) mlir/test/Dialect/SCF/transform-op-coalesce.mlir (+28-45)
  • (modified) mlir/test/Transforms/parallel-loop-collapsing.mlir (+2-5)
  • (modified) mlir/test/Transforms/single-parallel-loop-collapsing.mlir (+5-10)
diff --git a/mlir/include/mlir/Dialect/Affine/Passes.td b/mlir/include/mlir/Dialect/Affine/Passes.td
index 1036e93a039240..b08e803345f76e 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.td
+++ b/mlir/include/mlir/Dialect/Affine/Passes.td
@@ -394,7 +394,7 @@ def LoopCoalescing : Pass<"affine-loop-coalescing", "func::FuncOp"> {
   let summary = "Coalesce nested loops with independent bounds into a single "
                 "loop";
   let constructor = "mlir::affine::createLoopCoalescingPass()";
-  let dependentDialects = ["arith::ArithDialect"];
+  let dependentDialects = ["affine::AffineDialect","arith::ArithDialect"];
 }
 
 def SimplifyAffineStructures : Pass<"affine-simplify-structures", "func::FuncOp"> {
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index 9b29affb97c432..53d1ae10dc87d8 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -56,6 +56,7 @@ def SCFParallelLoopFusion : Pass<"scf-parallel-loop-fusion"> {
 def TestSCFParallelLoopCollapsing : Pass<"test-scf-parallel-loop-collapsing"> {
   let summary = "Test parallel loops collapsing transformation";
   let constructor = "mlir::createTestSCFParallelLoopCollapsingPass()";
+  let dependentDialects = ["affine::AffineDialect"];
   let description = [{
       This pass is purely for testing the scf::collapseParallelLoops
       transformation. The transformation does not have opinions on how a
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopCollapsing.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopCollapsing.cpp
index 6ba7020e86fa67..358a3b38a4cd32 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopCollapsing.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopCollapsing.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/SCF/Transforms/Passes.h"
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Utils/Utils.h"
 #include "mlir/Transforms/RegionUtils.h"
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index a794a121d6267b..c9f0955256b9bf 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/SCF/Utils/Utils.h"
 #include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -671,9 +672,26 @@ LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
   return success();
 }
 
+Range emitNormalizedLoopBoundsForIndexType(RewriterBase &rewriter, Location loc,
+                                           OpFoldResult lb, OpFoldResult ub,
+                                           OpFoldResult step) {
+  Range normalizedLoopBounds;
+  normalizedLoopBounds.offset = rewriter.getIndexAttr(0);
+  normalizedLoopBounds.stride = rewriter.getIndexAttr(1);
+  AffineExpr s0, s1, s2;
+  bindSymbols(rewriter.getContext(), s0, s1, s2);
+  AffineExpr e = (s1 - s0).ceilDiv(s2);
+  normalizedLoopBounds.size =
+      affine::makeComposedFoldedAffineApply(rewriter, loc, e, {lb, ub, step});
+  return normalizedLoopBounds;
+}
+
 Range mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
                                      OpFoldResult lb, OpFoldResult ub,
                                      OpFoldResult step) {
+  if (getType(lb) == rewriter.getIndexType()) {
+    return emitNormalizedLoopBoundsForIndexType(rewriter, loc, lb, ub, step);
+  }
   // For non-index types, generate `arith` instructions
   // Check if the loop is already known to have a constant zero lower bound or
   // a constant one step.
@@ -714,9 +732,38 @@ Range mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
   return {newLowerBound, newUpperBound, newStep};
 }
 
+static void denormalizeInductionVariableForIndexType(RewriterBase &rewriter,
+                                                     Location loc,
+                                                     Value normalizedIv,
+                                                     OpFoldResult origLb,
+                                                     OpFoldResult origStep) {
+  AffineExpr d0, s0, s1;
+  bindSymbols(rewriter.getContext(), s0, s1);
+  bindDims(rewriter.getContext(), d0);
+  AffineExpr e = d0 * s1 + s0;
+  OpFoldResult denormalizedIv = affine::makeComposedFoldedAffineApply(
+      rewriter, loc, e, ArrayRef<OpFoldResult>{normalizedIv, origLb, origStep});
+  Value denormalizedIvVal =
+      getValueOrCreateConstantIndexOp(rewriter, loc, denormalizedIv);
+  SmallPtrSet<Operation *, 1> preservedUses;
+  // If an `affine.apply` operation is generated for denormalization, the use
+  // of `origLb` in those ops must not be replaced. These are not generated
+  // when `origLb == 0` and `origStep == 1`.
+  if (!isConstantIntValue(origLb, 0) || !isConstantIntValue(origStep, 1)) {
+    if (Operation *preservedUse = denormalizedIvVal.getDefiningOp()) {
+      preservedUses.insert(preservedUse);
+    }
+  }
+  rewriter.replaceAllUsesExcept(normalizedIv, denormalizedIvVal, preservedUses);
+}
+
 void mlir::denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
                                         Value normalizedIv, OpFoldResult origLb,
                                         OpFoldResult origStep) {
+  if (getType(origLb) == rewriter.getIndexType()) {
+    return denormalizeInductionVariableForIndexType(rewriter, loc, normalizedIv,
+                                                    origLb, origStep);
+  }
   Value denormalizedIv;
   SmallPtrSet<Operation *, 2> preserve;
   bool isStepOne = isConstantIntValue(origStep, 1);
@@ -739,10 +786,29 @@ void mlir::denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
   rewriter.replaceAllUsesExcept(normalizedIv, denormalizedIv, preserve);
 }
 
+static OpFoldResult getProductOfIndexes(RewriterBase &rewriter, Location loc,
+                                        ArrayRef<OpFoldResult> values) {
+  assert(!values.empty() && "unexpected empty array");
+  AffineExpr s0, s1;
+  bindSymbols(rewriter.getContext(), s0, s1);
+  AffineExpr mul = s0 * s1;
+  OpFoldResult products = rewriter.getIndexAttr(1);
+  for (auto v : values) {
+    products = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, mul, ArrayRef<OpFoldResult>{products, v});
+  }
+  return products;
+}
+
 /// Helper function to multiply a sequence of values.
 static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
                                        ArrayRef<Value> values) {
   assert(!values.empty() && "unexpected empty list");
+  if (getType(values.front()) == rewriter.getIndexType()) {
+    SmallVector<OpFoldResult> ofrs = getAsOpFoldResult(values);
+    OpFoldResult product = getProductOfIndexes(rewriter, loc, ofrs);
+    return getValueOrCreateConstantIndexOp(rewriter, loc, product);
+  }
   std::optional<Value> productOf;
   for (auto v : values) {
     auto vOne = getConstantIntValue(v);
@@ -757,7 +823,7 @@ static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
   if (!productOf) {
     productOf = rewriter
                     .create<arith::ConstantOp>(
-                        loc, rewriter.getOneAttr(values.front().getType()))
+                        loc, rewriter.getOneAttr(getType(values.front())))
                     .getResult();
   }
   return productOf.value();
@@ -774,6 +840,16 @@ static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
 static std::pair<SmallVector<Value>, SmallPtrSet<Operation *, 2>>
 delinearizeInductionVariable(RewriterBase &rewriter, Location loc,
                              Value linearizedIv, ArrayRef<Value> ubs) {
+
+  if (linearizedIv.getType() == rewriter.getIndexType()) {
+    Operation *delinearizedOp =
+        rewriter.create<affine::AffineDelinearizeIndexOp>(loc, linearizedIv,
+                                                          ubs);
+    auto resultVals = llvm::map_to_vector(
+        delinearizedOp->getResults(), [](OpResult r) -> Value { return r; });
+    return {resultVals, SmallPtrSet<Operation *, 2>{delinearizedOp}};
+  }
+
   SmallVector<Value> delinearizedIvs(ubs.size());
   SmallPtrSet<Operation *, 2> preservedUsers;
 
diff --git a/mlir/test/Dialect/Affine/loop-coalescing.mlir b/mlir/test/Dialect/Affine/loop-coalescing.mlir
index 45dd299295f640..f6e7b21bc66aba 100644
--- a/mlir/test/Dialect/Affine/loop-coalescing.mlir
+++ b/mlir/test/Dialect/Affine/loop-coalescing.mlir
@@ -1,14 +1,15 @@
-// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -affine-loop-coalescing --cse %s | FileCheck %s
+// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -affine-loop-coalescing --cse --mlir-print-local-scope %s | FileCheck %s
 
 // CHECK-LABEL: @one_3d_nest
 func.func @one_3d_nest() {
   // Capture original bounds.  Note that for zero-based step-one loops, the
   // upper bound is also the number of iterations.
-  // CHECK: %[[orig_lb:.*]] = arith.constant 0
-  // CHECK: %[[orig_step:.*]] = arith.constant 1
-  // CHECK: %[[orig_ub_k:.*]] = arith.constant 3
-  // CHECK: %[[orig_ub_i:.*]] = arith.constant 42
-  // CHECK: %[[orig_ub_j:.*]] = arith.constant 56
+  // CHECK-DAG: %[[orig_lb:.*]] = arith.constant 0
+  // CHECK-DAG: %[[orig_step:.*]] = arith.constant 1
+  // CHECK-DAG: %[[orig_ub_k:.*]] = arith.constant 3
+  // CHECK-DAG: %[[orig_ub_i:.*]] = arith.constant 42
+  // CHECK-DAG: %[[orig_ub_j:.*]] = arith.constant 56
+  // CHECK-DAG: %[[range:.*]] = arith.constant 7056
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index
@@ -16,9 +17,6 @@ func.func @one_3d_nest() {
   %c42 = arith.constant 42 : index
   %c56 = arith.constant 56 : index
   // The range of the new scf.
-  // CHECK:     %[[partial_range:.*]] = arith.muli %[[orig_ub_i]], %[[orig_ub_j]]
-  // CHECK-NEXT:%[[range:.*]] = arith.muli %[[partial_range]], %[[orig_ub_k]]
-
   // Updated loop bounds.
   // CHECK: scf.for %[[i:.*]] = %[[orig_lb]] to %[[range]] step %[[orig_step]]
   scf.for %i = %c0 to %c42 step %c1 {
@@ -26,13 +24,11 @@ func.func @one_3d_nest() {
     // CHECK-NOT: scf.for
 
     // Reconstruct original IVs from the linearized one.
-    // CHECK: %[[orig_k:.*]] = arith.remsi %[[i]], %[[orig_ub_k]]
-    // CHECK: %[[div:.*]] = arith.divsi %[[i]], %[[orig_ub_k]]
-    // CHECK: %[[orig_j:.*]] = arith.remsi %[[div]], %[[orig_ub_j]]
-    // CHECK: %[[orig_i:.*]] = arith.divsi %[[div]], %[[orig_ub_j]]
+    // CHECK: %[[delinearize:.+]]:3 = affine.delinearize_index %[[i]]
+    // CHECK-SAME: into (%[[orig_ub_i]], %[[orig_ub_j]], %[[orig_ub_k]])
     scf.for %j = %c0 to %c56 step %c1 {
       scf.for %k = %c0 to %c3 step %c1 {
-        // CHECK: "use"(%[[orig_i]], %[[orig_j]], %[[orig_k]])
+        // CHECK: "use"(%[[delinearize]]#0, %[[delinearize]]#1, %[[delinearize]]#2)
         "use"(%i, %j, %k) : (index, index, index) -> ()
       }
     }
@@ -40,6 +36,8 @@ func.func @one_3d_nest() {
   return
 }
 
+// -----
+
 // Check that there is no chasing the replacement of value uses by ensuring
 // multiple uses of loop induction variables get rewritten to the same values.
 
@@ -52,13 +50,10 @@ func.func @multi_use() {
   scf.for %i = %c1 to %c10 step %c1 {
     scf.for %j = %c1 to %c10 step %c1 {
       scf.for %k = %c1 to %c10 step %c1 {
-        // CHECK: %[[k_unshifted:.*]] = arith.remsi %[[iv]], %[[k_extent:.*]]
-        // CHECK: %[[ij:.*]] = arith.divsi %[[iv]], %[[k_extent]]
-        // CHECK: %[[j_unshifted:.*]] = arith.remsi %[[ij]], %[[j_extent:.*]]
-        // CHECK: %[[i_unshifted:.*]] = arith.divsi %[[ij]], %[[j_extent]]
-        // CHECK: %[[k:.*]] = arith.addi %[[k_unshifted]]
-        // CHECK: %[[j:.*]] = arith.addi %[[j_unshifted]]
-        // CHECK: %[[i:.*]] = arith.addi %[[i_unshifted]]
+      	// CHECK: %[[delinearize:.+]]:3 = affine.delinearize_index %[[iv]]
+        // CHECK: %[[k:.*]] = affine.apply affine_map<(d0) -> (d0 + 1)>(%[[delinearize]]#2)
+        // CHECK: %[[j:.*]] = affine.apply affine_map<(d0) -> (d0 + 1)>(%[[delinearize]]#1)
+        // CHECK: %[[i:.*]] = affine.apply affine_map<(d0) -> (d0 + 1)>(%[[delinearize]]#0)
 
         // CHECK: "use1"(%[[i]], %[[j]], %[[k]])
         "use1"(%i,%j,%k) : (index,index,index) -> ()
@@ -72,12 +67,20 @@ func.func @multi_use() {
   return
 }
 
+// -----
+
 func.func @unnormalized_loops() {
-  // CHECK: %[[orig_step_i:.*]] = arith.constant 2
+  // Normalized lower bound and step for the outer scf.
+  // CHECK-DAG: %[[lb_i:.*]] = arith.constant 0
+  // CHECK-DAG: %[[step_i:.*]] = arith.constant 1
+  // CHECK-DAG: %[[orig_step_j_and_numiter_i:.*]] = arith.constant 3
+
+  // Number of iterations in the inner loop, the pattern is the same as above,
+  // only capture the final result.
+  // CHECK-DAG: %[[numiter_j:.*]] = arith.constant 4
+
+  // CHECK-DAG: %[[range:.*]] = arith.constant 12
 
-  // CHECK: %[[orig_step_j_and_numiter_i:.*]] = arith.constant 3
-  // CHECK: %[[orig_lb_i:.*]] = arith.constant 5
-  // CHECK: %[[orig_lb_j:.*]] = arith.constant 7
   %c2 = arith.constant 2 : index
   %c3 = arith.constant 3 : index
   %c5 = arith.constant 5 : index
@@ -85,28 +88,18 @@ func.func @unnormalized_loops() {
   %c10 = arith.constant 10 : index
   %c17 = arith.constant 17 : index
 
-  // Normalized lower bound and step for the outer scf.
-  // CHECK: %[[lb_i:.*]] = arith.constant 0
-  // CHECK: %[[step_i:.*]] = arith.constant 1
-
-  // Number of iterations in the inner loop, the pattern is the same as above,
-  // only capture the final result.
-  // CHECK: %[[numiter_j:.*]] = arith.constant 4
 
   // New bounds of the outer scf.
-  // CHECK: %[[range:.*]] = arith.muli %[[orig_step_j_and_numiter_i:.*]], %[[numiter_j]]
   // CHECK: scf.for %[[i:.*]] = %[[lb_i]] to %[[range]] step %[[step_i]]
   scf.for %i = %c5 to %c10 step %c2 {
     // The inner loop has been removed.
     // CHECK-NOT: scf.for
     scf.for %j = %c7 to %c17 step %c3 {
       // The IVs are rewritten.
-      // CHECK: %[[normalized_j:.*]] = arith.remsi %[[i]], %[[numiter_j]]
-      // CHECK: %[[normalized_i:.*]] = arith.divsi %[[i]], %[[numiter_j]]
-      // CHECK: %[[scaled_j:.*]] = arith.muli %[[normalized_j]], %[[orig_step_j_and_numiter_i]]
-      // CHECK: %[[orig_j:.*]] = arith.addi %[[scaled_j]], %[[orig_lb_j]]
-      // CHECK: %[[scaled_i:.*]] = arith.muli %[[normalized_i]], %[[orig_step_i]]
-      // CHECK: %[[orig_i:.*]] = arith.addi %[[scaled_i]], %[[orig_lb_i]]
+      // CHECK: %[[delinearize:.+]]:2 = affine.delinearize_index %[[i]]
+      // CHECK-SAME: into (%[[orig_step_j_and_numiter_i]], %[[numiter_j]])
+      // CHECK: %[[orig_j:.*]] = affine.apply affine_map<(d0) -> (d0 * 3 + 7)>(%[[delinearize]]#1)
+      // CHECK: %[[orig_i:.*]] = affine.apply affine_map<(d0) -> (d0 * 2 + 5)>(%[[delinearize]]#0)
       // CHECK: "use"(%[[orig_i]], %[[orig_j]])
       "use"(%i, %j) : (index, index) -> ()
     }
@@ -114,20 +107,21 @@ func.func @unnormalized_loops() {
   return
 }
 
+// -----
+
 func.func @noramalized_loops_with_yielded_iter_args() {
-  // CHECK: %[[orig_lb:.*]] = arith.constant 0
-  // CHECK: %[[orig_step:.*]] = arith.constant 1
-  // CHECK: %[[orig_ub_k:.*]] = arith.constant 3
-  // CHECK: %[[orig_ub_i:.*]] = arith.constant 42
-  // CHECK: %[[orig_ub_j:.*]] = arith.constant 56
+  // CHECK-DAG: %[[orig_lb:.*]] = arith.constant 0
+  // CHECK-DAG: %[[orig_ub_i:.*]] = arith.constant 42
+  // CHECK-DAG: %[[orig_step:.*]] = arith.constant 1
+  // CHECK-DAG: %[[orig_ub_j:.*]] = arith.constant 56
+  // CHECK-DAG: %[[orig_ub_k:.*]] = arith.constant 3
+  // CHECK-DAG: %[[range:.*]] = arith.constant 7056
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c3 = arith.constant 3 : index
   %c42 = arith.constant 42 : index
   %c56 = arith.constant 56 : index
   // The range of the new scf.
-  // CHECK:     %[[partial_range:.*]] = arith.muli %[[orig_ub_i]], %[[orig_ub_j]]
-  // CHECK-NEXT:%[[range:.*]] = arith.muli %[[partial_range]], %[[orig_ub_k]]
 
   // Updated loop bounds.
   // CHECK: scf.for %[[i:.*]] = %[[orig_lb]] to %[[range]] step %[[orig_step]] iter_args(%[[VAL_1:.*]] = %[[orig_lb]]) -> (index) {
@@ -136,13 +130,10 @@ func.func @noramalized_loops_with_yielded_iter_args() {
     // CHECK-NOT: scf.for
 
     // Reconstruct original IVs from the linearized one.
-    // CHECK: %[[orig_k:.*]] = arith.remsi %[[i]], %[[orig_ub_k]]
-    // CHECK: %[[div:.*]] = arith.divsi %[[i]], %[[orig_ub_k]]
-    // CHECK: %[[orig_j:.*]] = arith.remsi %[[div]], %[[orig_ub_j]]
-    // CHECK: %[[orig_i:.*]] = arith.divsi %[[div]], %[[orig_ub_j]]
+    // CHECK: %[[delinearize:.+]]:3 = affine.delinearize_index %[[i]] into (%[[orig_ub_i]], %[[orig_ub_j]], %[[orig_ub_k]])
     %1:1 = scf.for %j = %c0 to %c56 step %c1 iter_args(%arg1 = %arg0) -> (index){
       %0:1 = scf.for %k = %c0 to %c3 step %c1 iter_args(%arg2 = %arg1) -> (index) {
-        // CHECK: "use"(%[[orig_i]], %[[orig_j]], %[[orig_k]])
+        // CHECK: "use"(%[[delinearize]]#0, %[[delinearize]]#1, %[[delinearize]]#2)
         "use"(%i, %j, %k) : (index, index, index) -> ()
         // CHECK: scf.yield %[[VAL_1]] : index
         scf.yield %arg2 : index
@@ -154,20 +145,21 @@ func.func @noramalized_loops_with_yielded_iter_args() {
   return
 }
 
+// -----
+
 func.func @noramalized_loops_with_shuffled_yielded_iter_args() {
-  // CHECK: %[[orig_lb:.*]] = arith.constant 0
-  // CHECK: %[[orig_step:.*]] = arith.constant 1
-  // CHECK: %[[orig_ub_k:.*]] = arith.constant 3
-  // CHECK: %[[orig_ub_i:.*]] = arith.constant 42
-  // CHECK: %[[orig_ub_j:.*]] = arith.constant 56
+  // CHECK-DAG: %[[orig_lb:.*]] = arith.constant 0
+  // CHECK-DAG: %[[orig_step:.*]] = arith.constant 1
+  // CHECK-DAG: %[[orig_ub_k:.*]] = arith.constant 3
+  // CHECK-DAG: %[[orig_ub_i:.*]] = arith.constant 42
+  // CHECK-DAG: %[[orig_ub_j:.*]] = arith.constant 56
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c3 = arith.constant 3 : index
   %c42 = arith.constant 42 : index
   %c56 = arith.constant 56 : index
   // The range of the new scf.
-  // CHECK:     %[[partial_range:.*]] = arith.muli %[[orig_ub_i]], %[[orig_ub_j]]
-  // CHECK-NEXT:%[[range:.*]] = arith.muli %[[partial_range]], %[[orig_ub_k]]
+  // CHECK-DAG:%[[range:.*]] = arith.constant 7056
 
   // Updated loop bounds.
   // CHECK: scf.for %[[i:.*]] = %[[orig_lb]] to %[[range]] step %[[orig_step]] iter_args(%[[VAL_1:.*]] = %[[orig_lb]], %[[VAL_2:.*]] = %[[orig_lb]]) -> (index, index) {
@@ -176,13 +168,11 @@ func.func @noramalized_loops_with_shuffled_yielded_iter_args() {
     // CHECK-NOT: scf.for
 
     // Reconstruct original IVs from the linearized one.
-    // CHECK: %[[orig_k:.*]] = arith.remsi %[[i]], %[[orig_ub_k]]
-    // CHECK: %[[div:.*]] = arith.divsi %[[i]], %[[orig_ub_k]]
-    // CHECK: %[[orig_j:.*]] = arith.remsi %[[div]], %[[orig_ub_j]]
-    // CHECK: %[[orig_i:.*]] = arith.divsi %[[div]], %[[orig_ub_j]]
+    // CHECK: %[[delinearize:.+]]:3 = affine.delinearize_index %[[i]]
+    // CHECK-SAME: into (%[[orig_ub_i]], %[[orig_ub_j]], %[[orig_ub_k]])
     %1:2 = scf.for %j = %c0 to %c56 step %c1 iter_args(%arg2 = %arg0, %arg3 = %arg1) -> (index, index){
       %0:2 = scf.for %k = %c0 to %c3 step %c1 iter_args(%arg4 = %arg2, %arg5 = %arg3) -> (index, index) {
-        // CHECK: "use"(%[[orig_i]], %[[orig_j]], %[[orig_k]])
+        // CHECK: "use"(%[[delinearize]]#0, %[[delinearize]]#1, %[[delinearize]]#2)
         "use"(%i, %j, %k) : (index, index, index) -> ()
         // CHECK: scf.yield %[[VAL_2]], %[[VAL_1]] : index, index
         scf.yield %arg5, %arg4 : index, index
@@ -194,20 +184,21 @@ func.func @noramalized_loops_with_shuffled_yielded_iter_args() {
   return
 }
 
+// -----
+
 func.func @noramalized_loops_with_yielded_non_iter_args() {
-  // CHECK: %[[orig_lb:.*]] = arith.constant 0
-  // CHECK: %[[orig_step:.*]] = arith.constant 1
-  // CHECK: %[[orig_ub_k:.*]] = arith.constant 3
-  // CHECK: %[[orig_ub_i:.*]] = arith.constant 42
-  // CHECK: %[[orig_ub_j:.*]] = arith.constant 56
+  // CHECK-DAG: %[[orig_lb:.*]] = arith.constant 0
+  // CHECK-DAG: %[[orig_step:.*]] = arith.constant 1
+  // CHECK-DAG: %[[orig_ub_k:.*]] = arith.constant 3
+  // CHECK-DAG: %[[orig_ub_i:.*]] = arith.constant 42
+  // CHECK-DAG: %[[orig_ub_j:.*...
[truncated]
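
The index-type paths added to `Utils.cpp` in the diff above build the affine expressions `(ub - lb) ceildiv step` for the normalized trip count and `iv * step + lb` for denormalization. The sketch below replays that arithmetic in plain C++ for the bounds used in `@unnormalized_loops`. It is an illustration only: `ceilDiv`, `tripCount`, `denormalize`, and `coalescedLoopMatchesNest` are hypothetical names, and the helpers assume a positive step and `ub >= lb`.

```cpp
#include <cassert>
#include <cstdint>

// Ceiling division for non-negative numerators and positive denominators.
int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

// Size of the normalized range [0, size) iterated with step 1.
int64_t tripCount(int64_t lb, int64_t ub, int64_t step) {
  return ceilDiv(ub - lb, step);
}

// Map a normalized IV back to the original iteration space.
int64_t denormalize(int64_t normalizedIv, int64_t lb, int64_t step) {
  return normalizedIv * step + lb;
}

// Replays `scf.for %i = 5 to 10 step 2 { scf.for %j = 7 to 17 step 3 }` as a
// single coalesced loop and checks that it visits the same (i, j) pairs in
// the same order as the original nest.
bool coalescedLoopMatchesNest() {
  const int64_t numI = tripCount(5, 10, 2);
  const int64_t numJ = tripCount(7, 17, 3);
  int64_t linear = 0;
  for (int64_t i = 5; i < 10; i += 2) {
    for (int64_t j = 7; j < 17; j += 3, ++linear) {
      // Delinearize into (numI, numJ), then denormalize each component.
      int64_t gotI = denormalize(linear / numJ, 5, 2);
      int64_t gotJ = denormalize(linear % numJ, 7, 3);
      if (gotI != i || gotJ != j)
        return false;
    }
  }
  // The coalesced loop runs numI * numJ iterations in total.
  return linear == numI * numJ;
}
```

The trip counts come out as 3 and 4, giving a coalesced range of 12, which lines up with the `arith.constant 3`, `4`, and `12` values and the `d0 * 2 + 5` and `d0 * 3 + 7` maps that the updated test checks for.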

MaheshRavishankar (Contributor, Author) commented:

There are two commits in this PR. They could be reviewed independently.

hanhanW (Contributor) left a comment:

Overall looks good, just some nits

MaheshRavishankar merged commit cca3217 into llvm:main on Sep 28, 2024 (8 checks passed)
bjacob added a commit that referenced this pull request Sep 30, 2024
This fixes `cast`'s that were introduced in
#108450.

Signed-off-by: Benoit Jacob <[email protected]>
bjacob added a commit that referenced this pull request Oct 1, 2024
#110518 fixed assertion
failures in `cast` introduced in
#108450.

Signed-off-by: Benoit Jacob <[email protected]>