Skip to content

[mlir][SCF] Allow canonicalization of zero-trip count scf.forall with empty mapping. #105793

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

MaheshRavishankar
Copy link
Contributor

Current folding of one-trip count loop does not kick in with an empty mapping. Enable this for empty mapping.

@llvmbot
Copy link
Member

llvmbot commented Aug 23, 2024

@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir-scf

Author: None (MaheshRavishankar)

Changes

Current folding of one-trip count loop does not kick in with an empty mapping. Enable this for empty mapping.


Full diff: https://github.com/llvm/llvm-project/pull/105793.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+7-6)
  • (modified) mlir/test/Dialect/SCF/canonicalize.mlir (+27)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index e92d9503372cdf..bfa7db84bd9af7 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1700,7 +1700,7 @@ struct ForallOpSingleOrZeroIterationDimsFolder
   LogicalResult matchAndRewrite(ForallOp op,
                                 PatternRewriter &rewriter) const override {
     // Do not fold dimensions if they are mapped to processing units.
-    if (op.getMapping().has_value())
+    if (op.getMapping().has_value() && !op.getMapping()->empty())
       return failure();
     Location loc = op.getLoc();
 
@@ -1729,11 +1729,6 @@ struct ForallOpSingleOrZeroIterationDimsFolder
       newMixedUpperBounds.push_back(ub);
       newMixedSteps.push_back(step);
     }
-    // Exit if none of the loop dimensions perform a single iteration.
-    if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
-      return rewriter.notifyMatchFailure(
-          op, "no dimensions have 0 or 1 iterations");
-    }
 
     // All of the loop dimensions perform a single iteration. Inline loop body.
     if (newMixedLowerBounds.empty()) {
@@ -1741,6 +1736,12 @@ struct ForallOpSingleOrZeroIterationDimsFolder
       return success();
     }
 
+    // Exit if none of the loop dimensions perform a single iteration.
+    if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
+      return rewriter.notifyMatchFailure(
+          op, "no dimensions have 0 or 1 iterations");
+    }
+
     // Replace the loop by a lower-dimensional loop.
     ForallOp newOp;
     newOp = rewriter.create<ForallOp>(loc, newMixedLowerBounds,
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 268946803de7a5..ff7fafac42cb4a 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1635,6 +1635,33 @@ func.func @do_not_inline_distributed_forall_loop(
 
 // -----
 
+func.func @inline_empty_loop_with_empty_mapping(
+    %in: tensor<16xf32>) -> tensor<16xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<16xf32>
+  %1 = scf.forall () in () shared_outs (%out_ = %0) -> (tensor<16xf32>) {
+    %slice = tensor.extract_slice %out_[0] [16] [1]
+      : tensor<16xf32> to tensor<16xf32>
+    %generic = linalg.generic {
+        indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+        iterator_types = ["parallel"]}
+        ins(%slice : tensor<16xf32>) outs(%out_0 : tensor<16xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32):
+        %2 = arith.addf %b0, %b0 : f32
+        linalg.yield %2 : f32
+    } -> tensor<16xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %generic into %out_[0] [16] [1]
+        : tensor<16xf32> into tensor<16xf32>
+    }
+  }{ mapping = [] }
+  return %1 : tensor<16xf32>
+}
+// CHECK-LABEL: func @inline_empty_loop_with_empty_mapping
+//   CHECK-NOT:   scf.forall
+
+// -----
+
 func.func @collapse_one_dim_parallel(%in: tensor<8x8xf32>) -> tensor<8x8xf32> {
   %c8 = arith.constant 8 : index
   %c0 = arith.constant 0 : index

@llvmbot
Copy link
Member

llvmbot commented Aug 23, 2024

@llvm/pr-subscribers-mlir

Author: None (MaheshRavishankar)

Changes

Current folding of one-trip count loop does not kick in with an empty mapping. Enable this for empty mapping.


Full diff: https://github.com/llvm/llvm-project/pull/105793.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+7-6)
  • (modified) mlir/test/Dialect/SCF/canonicalize.mlir (+27)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index e92d9503372cdf..bfa7db84bd9af7 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1700,7 +1700,7 @@ struct ForallOpSingleOrZeroIterationDimsFolder
   LogicalResult matchAndRewrite(ForallOp op,
                                 PatternRewriter &rewriter) const override {
     // Do not fold dimensions if they are mapped to processing units.
-    if (op.getMapping().has_value())
+    if (op.getMapping().has_value() && !op.getMapping()->empty())
       return failure();
     Location loc = op.getLoc();
 
@@ -1729,11 +1729,6 @@ struct ForallOpSingleOrZeroIterationDimsFolder
       newMixedUpperBounds.push_back(ub);
       newMixedSteps.push_back(step);
     }
-    // Exit if none of the loop dimensions perform a single iteration.
-    if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
-      return rewriter.notifyMatchFailure(
-          op, "no dimensions have 0 or 1 iterations");
-    }
 
     // All of the loop dimensions perform a single iteration. Inline loop body.
     if (newMixedLowerBounds.empty()) {
@@ -1741,6 +1736,12 @@ struct ForallOpSingleOrZeroIterationDimsFolder
       return success();
     }
 
+    // Exit if none of the loop dimensions perform a single iteration.
+    if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
+      return rewriter.notifyMatchFailure(
+          op, "no dimensions have 0 or 1 iterations");
+    }
+
     // Replace the loop by a lower-dimensional loop.
     ForallOp newOp;
     newOp = rewriter.create<ForallOp>(loc, newMixedLowerBounds,
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index 268946803de7a5..ff7fafac42cb4a 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1635,6 +1635,33 @@ func.func @do_not_inline_distributed_forall_loop(
 
 // -----
 
+func.func @inline_empty_loop_with_empty_mapping(
+    %in: tensor<16xf32>) -> tensor<16xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<16xf32>
+  %1 = scf.forall () in () shared_outs (%out_ = %0) -> (tensor<16xf32>) {
+    %slice = tensor.extract_slice %out_[0] [16] [1]
+      : tensor<16xf32> to tensor<16xf32>
+    %generic = linalg.generic {
+        indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+        iterator_types = ["parallel"]}
+        ins(%slice : tensor<16xf32>) outs(%out_0 : tensor<16xf32>) {
+      ^bb0(%b0 : f32, %b1 : f32):
+        %2 = arith.addf %b0, %b0 : f32
+        linalg.yield %2 : f32
+    } -> tensor<16xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %generic into %out_[0] [16] [1]
+        : tensor<16xf32> into tensor<16xf32>
+    }
+  }{ mapping = [] }
+  return %1 : tensor<16xf32>
+}
+// CHECK-LABEL: func @inline_empty_loop_with_empty_mapping
+//   CHECK-NOT:   scf.forall
+
+// -----
+
 func.func @collapse_one_dim_parallel(%in: tensor<8x8xf32>) -> tensor<8x8xf32> {
   %c8 = arith.constant 8 : index
   %c0 = arith.constant 0 : index

MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Aug 23, 2024
Signed-off-by: MaheshRavishankar <[email protected]>
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Aug 23, 2024
Signed-off-by: MaheshRavishankar <[email protected]>
…th empty mapping.

Current folding of one-trip count loop does not kick in with an empty
mapping. Enable this for empty mapping.

Signed-off-by: MaheshRavishankar <[email protected]>
@MaheshRavishankar MaheshRavishankar force-pushed the users/MaheshRavishankar/forall_canonicalize branch from 1094b9b to 66cee0e Compare August 23, 2024 17:21
@MaheshRavishankar MaheshRavishankar merged commit 00620ab into llvm:main Aug 23, 2024
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants