Skip to content

[MLIR][SROA] Replace pattern based approach with a one-shot one #85437

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 18, 2024

Conversation

Dinistro
Copy link
Contributor

This commit changes MLIR's SROA implementation back from being pattern based into a full pass. This is beneficial for upcoming changes that rely more heavily on the datalayout.

Unfortunately, this change required substantial test changes, as the IRBuilder no cleans up the IR.

This commit changes MLIR's SROA implementation back from being pattern
based into a full pass. This is beneficial for upcomming changes that
rely more heavily on the datalayout.
@Dinistro Dinistro requested review from gysit and Moxinilian March 15, 2024 17:38
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir:llvm mlir labels Mar 15, 2024
@llvmbot
Copy link
Member

llvmbot commented Mar 15, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Christian Ulmann (Dinistro)

Changes

This commit changes MLIR's SROA implementation back from being pattern based into a full pass. This is beneficial for upcoming changes that rely more heavily on the datalayout.

Unfortunately, this change required substantial test changes, as the IRBuilder no cleans up the IR.


Full diff: https://github.com/llvm/llvm-project/pull/85437.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Transforms/SROA.h (-20)
  • (modified) mlir/lib/Transforms/SROA.cpp (+28-13)
  • (modified) mlir/test/Dialect/LLVMIR/sroa-intrinsics.mlir (+57-47)
diff --git a/mlir/include/mlir/Transforms/SROA.h b/mlir/include/mlir/Transforms/SROA.h
index 0b3e72400c2873..1af1fe930723f1 100644
--- a/mlir/include/mlir/Transforms/SROA.h
+++ b/mlir/include/mlir/Transforms/SROA.h
@@ -9,11 +9,9 @@
 #ifndef MLIR_TRANSFORMS_SROA_H
 #define MLIR_TRANSFORMS_SROA_H
 
-#include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
 #include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/Statistic.h"
-#include <variant>
 
 namespace mlir {
 
@@ -29,24 +27,6 @@ struct SROAStatistics {
   llvm::Statistic *maxSubelementAmount = nullptr;
 };
 
-/// Pattern applying SROA to the regions of the operations on which it
-/// matches.
-class SROAPattern
-    : public OpInterfaceRewritePattern<DestructurableAllocationOpInterface> {
-public:
-  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
-
-  SROAPattern(MLIRContext *context, SROAStatistics statistics = {},
-              PatternBenefit benefit = 1)
-      : OpInterfaceRewritePattern(context, benefit), statistics(statistics) {}
-
-  LogicalResult matchAndRewrite(DestructurableAllocationOpInterface allocator,
-                                PatternRewriter &rewriter) const override;
-
-private:
-  SROAStatistics statistics;
-};
-
 /// Attempts to destructure the slots of destructurable allocators. Returns
 /// failure if no slot was destructured.
 LogicalResult tryToDestructureMemorySlots(
diff --git a/mlir/lib/Transforms/SROA.cpp b/mlir/lib/Transforms/SROA.cpp
index 3ceda51e1c894c..977714e75373cf 100644
--- a/mlir/lib/Transforms/SROA.cpp
+++ b/mlir/lib/Transforms/SROA.cpp
@@ -9,7 +9,6 @@
 #include "mlir/Transforms/SROA.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Passes.h"
 
 namespace mlir {
@@ -205,13 +204,6 @@ LogicalResult mlir::tryToDestructureMemorySlots(
   return success(destructuredAny);
 }
 
-LogicalResult
-SROAPattern::matchAndRewrite(DestructurableAllocationOpInterface allocator,
-                             PatternRewriter &rewriter) const {
-  hasBoundedRewriteRecursion();
-  return tryToDestructureMemorySlots({allocator}, rewriter, statistics);
-}
-
 namespace {
 
 struct SROA : public impl::SROABase<SROA> {
@@ -223,12 +215,35 @@ struct SROA : public impl::SROABase<SROA> {
     SROAStatistics statistics{&destructuredAmount, &slotsWithMemoryBenefit,
                               &maxSubelementAmount};
 
-    RewritePatternSet rewritePatterns(&getContext());
-    rewritePatterns.add<SROAPattern>(&getContext(), statistics);
-    FrozenRewritePatternSet frozen(std::move(rewritePatterns));
+    bool changed = false;
+
+    for (Region &region : scopeOp->getRegions()) {
+      if (region.getBlocks().empty())
+        continue;
 
-    if (failed(applyPatternsAndFoldGreedily(scopeOp, frozen)))
-      signalPassFailure();
+      OpBuilder builder(&region.front(), region.front().begin());
+      IRRewriter rewriter(builder);
+
+      // Destructuring a slot can allow for further destructuring of other
+      // slots, destructuring is tried until no destructuring succeeds.
+      while (true) {
+        SmallVector<DestructurableAllocationOpInterface> allocators;
+        // Build a list of allocators to attempt to destructure the slots of.
+        for (Block &block : region)
+          for (Operation &op : block.getOperations())
+            if (auto allocator =
+                    dyn_cast<DestructurableAllocationOpInterface>(op))
+              allocators.emplace_back(allocator);
+
+        if (failed(
+                tryToDestructureMemorySlots(allocators, rewriter, statistics)))
+          break;
+
+        changed = true;
+      }
+    }
+    if (!changed)
+      markAllAnalysesPreserved();
   }
 };
 
diff --git a/mlir/test/Dialect/LLVMIR/sroa-intrinsics.mlir b/mlir/test/Dialect/LLVMIR/sroa-intrinsics.mlir
index c2e3458134ba4b..04c4af1a3e79a8 100644
--- a/mlir/test/Dialect/LLVMIR/sroa-intrinsics.mlir
+++ b/mlir/test/Dialect/LLVMIR/sroa-intrinsics.mlir
@@ -146,12 +146,10 @@ llvm.func @invalid_indirect_memset() -> i32 {
 
 // CHECK-LABEL: llvm.func @memset_double_use
 llvm.func @memset_double_use() -> i32 {
-  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
-  // CHECK-DAG: %[[ALLOCA_INT:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
-  // CHECK-DAG: %[[ALLOCA_FLOAT:.*]] = llvm.alloca %[[ALLOCA_LEN]] x f32
-  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
-  // After SROA, only one i32 will be actually used, so only 4 bytes will be set.
-  // CHECK-DAG: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA_FLOAT:.*]] = llvm.alloca %[[ALLOCA_LEN]] x f32
+  // CHECK: %[[ALLOCA_INT:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
+  // CHECK: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
   %0 = llvm.mlir.constant(1 : i32) : i32
   %1 = llvm.alloca %0 x !llvm.struct<"foo", (i32, f32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
   %memset_value = llvm.mlir.constant(42 : i8) : i8
@@ -159,8 +157,11 @@ llvm.func @memset_double_use() -> i32 {
   %memset_len = llvm.mlir.constant(8 : i32) : i32
   // We expect two generated memset, one for each field.
   // CHECK-NOT: "llvm.intr.memset"
-  // CHECK-DAG: "llvm.intr.memset"(%[[ALLOCA_INT]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
-  // CHECK-DAG: "llvm.intr.memset"(%[[ALLOCA_FLOAT]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
+  // After SROA, only one i32 will be actually used, so only 4 bytes will be set.
+  // CHECK: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: "llvm.intr.memset"(%[[ALLOCA_INT]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
+  // CHECK: %[[MEMSET_LEN:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: "llvm.intr.memset"(%[[ALLOCA_FLOAT]], %[[MEMSET_VALUE]], %[[MEMSET_LEN]]) <{isVolatile = false}>
   // CHECK-NOT: "llvm.intr.memset"
   "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
   %2 = llvm.getelementptr %1[0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<"foo", (i32, f32)>
@@ -208,13 +209,10 @@ llvm.func @memset_considers_alignment() -> i32 {
 
 // CHECK-LABEL: llvm.func @memset_considers_packing
 llvm.func @memset_considers_packing() -> i32 {
-  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
-  // CHECK-DAG: %[[ALLOCA_INT:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
-  // CHECK-DAG: %[[ALLOCA_FLOAT:.*]] = llvm.alloca %[[ALLOCA_LEN]] x f32
-  // CHECK-DAG: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
-  // After SROA, only 32-bit values will be actually used, so only 4 bytes will be set.
-  // CHECK-DAG: %[[MEMSET_LEN_WHOLE:.*]] = llvm.mlir.constant(4 : i32) : i32
-  // CHECK-DAG: %[[MEMSET_LEN_PARTIAL:.*]] = llvm.mlir.constant(3 : i32) : i32
+  // CHECK: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA_FLOAT:.*]] = llvm.alloca %[[ALLOCA_LEN]] x f32
+  // CHECK: %[[ALLOCA_INT:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
+  // CHECK: %[[MEMSET_VALUE:.*]] = llvm.mlir.constant(42 : i8) : i8
   %0 = llvm.mlir.constant(1 : i32) : i32
   %1 = llvm.alloca %0 x !llvm.struct<"foo", packed (i8, i32, f32)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
   %memset_value = llvm.mlir.constant(42 : i8) : i8
@@ -222,7 +220,10 @@ llvm.func @memset_considers_packing() -> i32 {
   %memset_len = llvm.mlir.constant(8 : i32) : i32
   // Now all fields are touched by the memset.
   // CHECK-NOT: "llvm.intr.memset"
+  // After SROA, only 32-bit values will be actually used, so only 4 bytes will be set.
+  // CHECK: %[[MEMSET_LEN_WHOLE:.*]] = llvm.mlir.constant(4 : i32) : i32
   // CHECK: "llvm.intr.memset"(%[[ALLOCA_INT]], %[[MEMSET_VALUE]], %[[MEMSET_LEN_WHOLE]]) <{isVolatile = false}>
+  // CHECK: %[[MEMSET_LEN_PARTIAL:.*]] = llvm.mlir.constant(3 : i32) : i32
   // CHECK: "llvm.intr.memset"(%[[ALLOCA_FLOAT]], %[[MEMSET_VALUE]], %[[MEMSET_LEN_PARTIAL]]) <{isVolatile = false}>
   // CHECK-NOT: "llvm.intr.memset"
   "llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
@@ -241,14 +242,14 @@ llvm.func @memset_considers_packing() -> i32 {
 // CHECK-LABEL: llvm.func @memcpy_dest
 // CHECK-SAME: (%[[OTHER_ARRAY:.*]]: !llvm.ptr)
 llvm.func @memcpy_dest(%other_array: !llvm.ptr) -> i32 {
-  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
-  // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
-  // After SROA, only one i32 will be actually used, so only 4 bytes will be set.
-  // CHECK-DAG: %[[MEMCPY_LEN:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
   %0 = llvm.mlir.constant(1 : i32) : i32
   %1 = llvm.alloca %0 x !llvm.array<10 x i32> : (i32) -> !llvm.ptr
   %memcpy_len = llvm.mlir.constant(40 : i32) : i32
   // CHECK: %[[SLOT_IN_OTHER:.*]] = llvm.getelementptr %[[OTHER_ARRAY]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i32>
+  // After SROA, only one i32 will be actually used, so only 4 bytes will be set.
+  // CHECK: %[[MEMCPY_LEN:.*]] = llvm.mlir.constant(4 : i32) : i32
   // CHECK: "llvm.intr.memcpy"(%[[ALLOCA]], %[[SLOT_IN_OTHER]], %[[MEMCPY_LEN]]) <{isVolatile = false}>
   "llvm.intr.memcpy"(%1, %other_array, %memcpy_len) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
   %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i32>
@@ -261,9 +262,8 @@ llvm.func @memcpy_dest(%other_array: !llvm.ptr) -> i32 {
 // CHECK-LABEL: llvm.func @memcpy_src
 // CHECK-SAME: (%[[OTHER_ARRAY:.*]]: !llvm.ptr)
 llvm.func @memcpy_src(%other_array: !llvm.ptr) -> i32 {
-  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
   // After SROA, only one i32 will be actually used, so only 4 bytes will be set.
-  // CHECK-DAG: %[[MEMCPY_LEN:.*]] = llvm.mlir.constant(4 : i32) : i32
   // CHECK-COUNT-4: = llvm.alloca %[[ALLOCA_LEN]] x i32
   %0 = llvm.mlir.constant(1 : i32) : i32
   %1 = llvm.alloca %0 x !llvm.array<4 x i32> : (i32) -> !llvm.ptr
@@ -271,14 +271,18 @@ llvm.func @memcpy_src(%other_array: !llvm.ptr) -> i32 {
   // Unfortunately because of FileCheck limitations it is not possible to check which slot gets read from.
   // We can only check that the amount of operations and allocated slots is correct, which should be sufficient
   // as unused slots are not generated.
-  // CHECK-DAG: %[[SLOT_IN_OTHER:.*]] = llvm.getelementptr %[[OTHER_ARRAY]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<4 x i32>
-  // CHECK-DAG: "llvm.intr.memcpy"(%[[SLOT_IN_OTHER]], %{{.*}}, %[[MEMCPY_LEN]]) <{isVolatile = false}>
-  // CHECK-DAG: %[[SLOT_IN_OTHER:.*]] = llvm.getelementptr %[[OTHER_ARRAY]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<4 x i32>
-  // CHECK-DAG: "llvm.intr.memcpy"(%[[SLOT_IN_OTHER]], %{{.*}}, %[[MEMCPY_LEN]]) <{isVolatile = false}>
-  // CHECK-DAG: %[[SLOT_IN_OTHER:.*]] = llvm.getelementptr %[[OTHER_ARRAY]][0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<4 x i32>
-  // CHECK-DAG: "llvm.intr.memcpy"(%[[SLOT_IN_OTHER]], %{{.*}}, %[[MEMCPY_LEN]]) <{isVolatile = false}>
-  // CHECK-DAG: %[[SLOT_IN_OTHER:.*]] = llvm.getelementptr %[[OTHER_ARRAY]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<4 x i32>
-  // CHECK-DAG: "llvm.intr.memcpy"(%[[SLOT_IN_OTHER]], %{{.*}}, %[[MEMCPY_LEN]]) <{isVolatile = false}>
+  // CHECK: %[[SLOT_IN_OTHER:.*]] = llvm.getelementptr %[[OTHER_ARRAY]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<4 x i32>
+  // CHECK: %[[MEMCPY_LEN:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: "llvm.intr.memcpy"(%[[SLOT_IN_OTHER]], %{{.*}}, %[[MEMCPY_LEN]]) <{isVolatile = false}>
+  // CHECK: %[[SLOT_IN_OTHER:.*]] = llvm.getelementptr %[[OTHER_ARRAY]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<4 x i32>
+  // CHECK: %[[MEMCPY_LEN:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: "llvm.intr.memcpy"(%[[SLOT_IN_OTHER]], %{{.*}}, %[[MEMCPY_LEN]]) <{isVolatile = false}>
+  // CHECK: %[[SLOT_IN_OTHER:.*]] = llvm.getelementptr %[[OTHER_ARRAY]][0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<4 x i32>
+  // CHECK: %[[MEMCPY_LEN:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: "llvm.intr.memcpy"(%[[SLOT_IN_OTHER]], %{{.*}}, %[[MEMCPY_LEN]]) <{isVolatile = false}>
+  // CHECK: %[[SLOT_IN_OTHER:.*]] = llvm.getelementptr %[[OTHER_ARRAY]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<4 x i32>
+  // CHECK: %[[MEMCPY_LEN:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: "llvm.intr.memcpy"(%[[SLOT_IN_OTHER]], %{{.*}}, %[[MEMCPY_LEN]]) <{isVolatile = false}>
   "llvm.intr.memcpy"(%other_array, %1, %memcpy_len) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
   %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<4 x i32>
   %3 = llvm.load %2 : !llvm.ptr -> i32
@@ -289,14 +293,18 @@ llvm.func @memcpy_src(%other_array: !llvm.ptr) -> i32 {
 
 // CHECK-LABEL: llvm.func @memcpy_double
 llvm.func @memcpy_double() -> i32 {
-  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
-  // CHECK-DAG: %[[MEMCPY_LEN:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
   %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK-COUNT-2: = llvm.alloca %[[ALLOCA_LEN]] x i32
+  // CHECK: = llvm.alloca %[[ALLOCA_LEN]] x i32
+  // TODO: This should also disappear.
+  // CHECK: = llvm.alloca %[[ALLOCA_LEN]] x !llvm.array<1 x i32>
   %1 = llvm.alloca %0 x !llvm.array<1 x i32> : (i32) -> !llvm.ptr
   %2 = llvm.alloca %0 x !llvm.array<1 x i32> : (i32) -> !llvm.ptr
+  // Match the dead constant, to avoid collision with the newly created one.
+  // CHECK: llvm.mlir.constant
   %memcpy_len = llvm.mlir.constant(4 : i32) : i32
   // CHECK-NOT: "llvm.intr.memcpy"
+  // CHECK: %[[MEMCPY_LEN:.*]] = llvm.mlir.constant(4 : i32) : i32
   // CHECK: "llvm.intr.memcpy"(%{{.*}}, %{{.*}}, %[[MEMCPY_LEN]]) <{isVolatile = false}>
   // CHECK-NOT: "llvm.intr.memcpy"
   "llvm.intr.memcpy"(%1, %2, %memcpy_len) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
@@ -346,14 +354,14 @@ llvm.func @memcpy_no_volatile(%other_array: !llvm.ptr) -> i32 {
 // CHECK-LABEL: llvm.func @memmove_dest
 // CHECK-SAME: (%[[OTHER_ARRAY:.*]]: !llvm.ptr)
 llvm.func @memmove_dest(%other_array: !llvm.ptr) -> i32 {
-  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
-  // CHECK-DAG: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
-  // After SROA, only one i32 will be actually used, so only 4 bytes will be set.
-  // CHECK-DAG: %[[MEMMOVE_LEN:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: %[[ALLOCA:.*]] = llvm.alloca %[[ALLOCA_LEN]] x i32
   %0 = llvm.mlir.constant(1 : i32) : i32
   %1 = llvm.alloca %0 x !llvm.array<10 x i32> : (i32) -> !llvm.ptr
   %memmove_len = llvm.mlir.constant(40 : i32) : i32
   // CHECK: %[[SLOT_IN_OTHER:.*]] = llvm.getelementptr %[[OTHER_ARRAY]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i32>
+  // After SROA, only one i32 will be actually used, so only 4 bytes will be set.
+  // CHECK: %[[MEMMOVE_LEN:.*]] = llvm.mlir.constant(4 : i32) : i32
   // CHECK: "llvm.intr.memmove"(%[[ALLOCA]], %[[SLOT_IN_OTHER]], %[[MEMMOVE_LEN]]) <{isVolatile = false}>
   "llvm.intr.memmove"(%1, %other_array, %memmove_len) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
   %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<10 x i32>
@@ -366,9 +374,7 @@ llvm.func @memmove_dest(%other_array: !llvm.ptr) -> i32 {
 // CHECK-LABEL: llvm.func @memmove_src
 // CHECK-SAME: (%[[OTHER_ARRAY:.*]]: !llvm.ptr)
 llvm.func @memmove_src(%other_array: !llvm.ptr) -> i32 {
-  // CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
-  // After SROA, only one i32 will be actually used, so only 4 bytes will be set.
-  // CHECK-DAG: %[[MEMMOVE_LEN:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
   // CHECK-COUNT-4: = llvm.alloca %[[ALLOCA_LEN]] x i32
   %0 = llvm.mlir.constant(1 : i32) : i32
   %1 = llvm.alloca %0 x !llvm.array<4 x i32> : (i32) -> !llvm.ptr
@@ -376,14 +382,18 @@ llvm.func @memmove_src(%other_array: !llvm.ptr) -> i32 {
   // Unfortunately because of FileCheck limitations it is not possible to check which slot gets read from.
   // We can only check that the amount of operations and allocated slots is correct, which should be sufficient
   // as unused slots are not generated.
-  // CHECK-DAG: %[[SLOT_IN_OTHER:.*]] = llvm.getelementptr %[[OTHER_ARRAY]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<4 x i32>
-  // CHECK-DAG: "llvm.intr.memmove"(%[[SLOT_IN_OTHER]], %{{.*}}, %[[MEMMOVE_LEN]]) <{isVolatile = false}>
-  // CHECK-DAG: %[[SLOT_IN_OTHER:.*]] = llvm.getelementptr %[[OTHER_ARRAY]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<4 x i32>
-  // CHECK-DAG: "llvm.intr.memmove"(%[[SLOT_IN_OTHER]], %{{.*}}, %[[MEMMOVE_LEN]]) <{isVolatile = false}>
-  // CHECK-DAG: %[[SLOT_IN_OTHER:.*]] = llvm.getelementptr %[[OTHER_ARRAY]][0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<4 x i32>
-  // CHECK-DAG: "llvm.intr.memmove"(%[[SLOT_IN_OTHER]], %{{.*}}, %[[MEMMOVE_LEN]]) <{isVolatile = false}>
-  // CHECK-DAG: %[[SLOT_IN_OTHER:.*]] = llvm.getelementptr %[[OTHER_ARRAY]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<4 x i32>
-  // CHECK-DAG: "llvm.intr.memmove"(%[[SLOT_IN_OTHER]], %{{.*}}, %[[MEMMOVE_LEN]]) <{isVolatile = false}>
+  // CHECK: %[[SLOT_IN_OTHER:.*]] = llvm.getelementptr %[[OTHER_ARRAY]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<4 x i32>
+  // CHECK: %[[MEMMOVE_LEN:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: "llvm.intr.memmove"(%[[SLOT_IN_OTHER]], %{{.*}}, %[[MEMMOVE_LEN]]) <{isVolatile = false}>
+  // CHECK: %[[SLOT_IN_OTHER:.*]] = llvm.getelementptr %[[OTHER_ARRAY]][0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<4 x i32>
+  // CHECK: %[[MEMMOVE_LEN:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: "llvm.intr.memmove"(%[[SLOT_IN_OTHER]], %{{.*}}, %[[MEMMOVE_LEN]]) <{isVolatile = false}>
+  // CHECK: %[[SLOT_IN_OTHER:.*]] = llvm.getelementptr %[[OTHER_ARRAY]][0, 2] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<4 x i32>
+  // CHECK: %[[MEMMOVE_LEN:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: "llvm.intr.memmove"(%[[SLOT_IN_OTHER]], %{{.*}}, %[[MEMMOVE_LEN]]) <{isVolatile = false}>
+  // CHECK: %[[SLOT_IN_OTHER:.*]] = llvm.getelementptr %[[OTHER_ARRAY]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<4 x i32>
+  // CHECK: %[[MEMMOVE_LEN:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: "llvm.intr.memmove"(%[[SLOT_IN_OTHER]], %{{.*}}, %[[MEMMOVE_LEN]]) <{isVolatile = false}>
   "llvm.intr.memmove"(%other_array, %1, %memmove_len) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
   %2 = llvm.getelementptr %1[0, 1] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<4 x i32>
   %3 = llvm.load %2 : !llvm.ptr -> i32

Copy link
Contributor

@gysit gysit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

// CHECK-COUNT-2: = llvm.alloca %[[ALLOCA_LEN]] x i32
// CHECK: = llvm.alloca %[[ALLOCA_LEN]] x i32
// TODO: This should also disappear.
// CHECK: = llvm.alloca %[[ALLOCA_LEN]] x !llvm.array<1 x i32>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does this no longer disappear?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Forgot to investigate. Will get back to you once I found the cause.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Found the issue:
This is due to changing from the rewriter to a builder. The rewriter was able to fold the GEPOp zero indexing GEPOp away, while this is not done be the builder.
Somehow, SROA is not happy with this intermediate GEP, which degrades it's performance in such cases.

I'll address this as part of the improvement efforts.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The rewriter was able to fold the GEPOp zero indexing GEPOp away, while this is not done be the builder.

It's not the builder/rewriter I believe, instead it is the Greedy driver doing a lot of cleanup like this on the fly.

That said you could do it yourself by using createOrFold instead of create if these are GEPOp you are creating. And otherwise have the algorithm match GEPOp and fold them as it goes.

for (Operation &op : block.getOperations())
if (auto allocator =
dyn_cast<DestructurableAllocationOpInterface>(op))
allocators.emplace_back(allocator);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The previous implementation was a recursive walk, this is a simple block iteration, isn't this a change in behavior?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Other thing: instead of reconstructing the allocators list at each iteration, can you modify tryToDestructureMemorySlots to actually preserve it up-to-date as we go?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch regarding the recursion difference. Will fix accordingly.

Other thing: instead of reconstructing the allocators list at each iteration, can you modify tryToDestructureMemorySlots to actually preserve it up-to-date as we go?

We have substantial SROA improvements in the pipeline, I'll try to address this as part of the upcoming changes, if that is fine.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG, can you add a TODO in the code to capture this?

@Dinistro Dinistro merged commit 63897a5 into main Mar 18, 2024
@Dinistro Dinistro deleted the users/dinistro/remove-sroa-pattern branch March 18, 2024 07:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:llvm mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants