Skip to content

[MLIR][Interfaces] Change MemorySlotInterface to use OpBuilder #91341

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

Dinistro
Copy link
Contributor

@Dinistro Dinistro commented May 7, 2024

This commit changes the MemorySlotInterface back to using OpBuilder instead of a rewriter. This was originally introduced in https://reviews.llvm.org/D150432 but it was shown that patterns are a bad idea for both Mem2Reg and SROA.
Mem2Reg suffers from the usage of a rewriter due to being forced to create new basic blocks. This is an issue, as it leads to the invalidation of the dominance information, which can be expensive to recompute.

This commit changes the `MemorySlotInterface` back to using `OpBuilder`
instead of a rewriter. This was originally introduced in
https://reviews.llvm.org/D150432 but it was shown that patterns are a
bad idea for both Mem2Reg and SROA.
Mem2Reg suffers from the usage of a rewriter due to neing forced to
create new basic blocks. This is an issue, as it leads to the
invalidation of the dominance information, which can be expensive to
recompute.
@Dinistro Dinistro requested a review from gysit May 7, 2024 14:34
@Dinistro Dinistro requested a review from Moxinilian as a code owner May 7, 2024 14:34
@llvmbot
Copy link
Member

llvmbot commented May 7, 2024

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir-memref

@llvm/pr-subscribers-mlir-llvm

Author: Christian Ulmann (Dinistro)

Changes

This commit changes the MemorySlotInterface back to using OpBuilder instead of a rewriter. This was originally introduced in https://reviews.llvm.org/D150432 but it was shown that patterns are a bad idea for both Mem2Reg and SROA.
Mem2Reg suffers from the usage of a rewriter due to being forced to create new basic blocks. This is an issue, as it leads to the invalidation of the dominance information, which can be expensive to recompute.


Patch is 62.88 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/91341.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Interfaces/MemorySlotInterfaces.td (+27-36)
  • (modified) mlir/include/mlir/Transforms/Mem2Reg.h (+1-1)
  • (modified) mlir/include/mlir/Transforms/SROA.h (+1-1)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp (+121-127)
  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp (+24-28)
  • (modified) mlir/lib/Transforms/Mem2Reg.cpp (+30-60)
  • (modified) mlir/lib/Transforms/SROA.cpp (+12-14)
diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
index 764fa6d547b2e..adf182ac7069d 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
@@ -40,42 +40,40 @@ def PromotableAllocationOpInterface
         Provides the default Value of this memory slot. The provided Value
         will be used as the reaching definition of loads done before any store.
         This Value must outlive the promotion and dominate all the uses of this
-        slot's pointer. The provided rewriter can be used to create the default
+        slot's pointer. The provided builder can be used to create the default
         value on the fly.
 
-        The rewriter is located at the beginning of the block where the slot
-        pointer is defined. All IR mutations must happen through the rewriter.
+        The builder is located at the beginning of the block where the slot
+        pointer is defined.
       }], "::mlir::Value", "getDefaultValue",
       (ins
         "const ::mlir::MemorySlot &":$slot,
-        "::mlir::RewriterBase &":$rewriter)
+        "::mlir::OpBuilder &":$builder)
     >,
     InterfaceMethod<[{
         Hook triggered for every new block argument added to a block.
         This will only be called for slots declared by this operation.
 
-        The rewriter is located at the beginning of the block on call. All IR
-        mutations must happen through the rewriter.
+        The builder is located at the beginning of the block on call. All IR
+        mutations must happen through the builder.
       }],
       "void", "handleBlockArgument",
       (ins
         "const ::mlir::MemorySlot &":$slot,
         "::mlir::BlockArgument":$argument,
-        "::mlir::RewriterBase &":$rewriter
+        "::mlir::OpBuilder &":$builder
       )
     >,
     InterfaceMethod<[{
         Hook triggered once the promotion of a slot is complete. This can
         also clean up the created default value if necessary.
         This will only be called for slots declared by this operation.
-
-        All IR mutations must happen through the rewriter.
       }],
       "void", "handlePromotionComplete",
       (ins
         "const ::mlir::MemorySlot &":$slot, 
         "::mlir::Value":$defaultValue,
-        "::mlir::RewriterBase &":$rewriter)
+        "::mlir::OpBuilder &":$builder)
     >,
   ];
 }
@@ -119,15 +117,14 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
         The returned value must dominate all operations dominated by the storing
         operation.
 
-        If IR must be mutated to extract a concrete value being stored, mutation
-        must happen through the provided rewriter. The rewriter is located
-        immediately after the memory operation on call. No IR deletion is
-        allowed in this method. IR mutations must not introduce new uses of the
-        memory slot. Existing control flow must not be modified.
+        The builder is located immediately after the memory operation on call.
+        No IR deletion is allowed in this method. IR mutations must not
+        introduce new uses of the memory slot. Existing control flow must not
+        be modified.
       }],
       "::mlir::Value", "getStored",
       (ins "const ::mlir::MemorySlot &":$slot,
-           "::mlir::RewriterBase &":$rewriter,
+           "::mlir::OpBuilder &":$builder,
            "::mlir::Value":$reachingDef,
            "const ::mlir::DataLayout &":$dataLayout)
     >,
@@ -166,14 +163,13 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
         have been done at the point of calling this method, but it will be done
         eventually.
 
-        The rewriter is located after the promotable operation on call. All IR
-        mutations must happen through the rewriter.
+        The builder is located after the promotable operation on call.
       }],
       "::mlir::DeletionKind",
       "removeBlockingUses",
       (ins "const ::mlir::MemorySlot &":$slot,
            "const ::llvm::SmallPtrSetImpl<mlir::OpOperand *> &":$blockingUses,
-           "::mlir::RewriterBase &":$rewriter,
+           "::mlir::OpBuilder &":$builder,
            "::mlir::Value":$reachingDefinition,
            "const ::mlir::DataLayout &":$dataLayout)
     >,
@@ -224,13 +220,12 @@ def PromotableOpInterface : OpInterface<"PromotableOpInterface"> {
         have been done at the point of calling this method, but it will be done
         eventually.
 
-        The rewriter is located after the promotable operation on call. All IR
-        mutations must happen through the rewriter.
+        The builder is located after the promotable operation on call.
       }],
       "::mlir::DeletionKind",
       "removeBlockingUses",
       (ins "const ::llvm::SmallPtrSetImpl<mlir::OpOperand *> &":$blockingUses,
-           "::mlir::RewriterBase &":$rewriter)
+           "::mlir::OpBuilder &":$builder)
     >,
     InterfaceMethod<[{
         This method allows the promoted operation to visit the SSA values used
@@ -254,13 +249,12 @@ def PromotableOpInterface : OpInterface<"PromotableOpInterface"> {
         scheduled for removal and if `requiresReplacedValues` returned
         true.
 
-        The rewriter is located after the promotable operation on call. All IR
-        mutations must happen through the rewriter. During the transformation,
-        *no operation should be deleted*.
+        The builder is located after the promotable operation on call. During
+        the transformation, *no operation should be deleted*.
       }],
       "void", "visitReplacedValues",
       (ins "::llvm::ArrayRef<std::pair<::mlir::Operation*, ::mlir::Value>>":$mutatedDefs,
-           "::mlir::RewriterBase &":$rewriter), [{}], [{ return; }]
+           "::mlir::OpBuilder &":$builder), [{}], [{ return; }]
     >,
   ];
 }
@@ -293,25 +287,23 @@ def DestructurableAllocationOpInterface
         at the end of this call. Only generates subslots for the indices found in
         `usedIndices` since all other subslots are unused.
 
-        The rewriter is located at the beginning of the block where the slot
-        pointer is defined. All IR mutations must happen through the rewriter.
+        The builder is located at the beginning of the block where the slot
+        pointer is defined.
       }],
       "::llvm::DenseMap<::mlir::Attribute, ::mlir::MemorySlot>",
       "destructure",
       (ins "const ::mlir::DestructurableMemorySlot &":$slot,
            "const ::llvm::SmallPtrSetImpl<::mlir::Attribute> &":$usedIndices,
-           "::mlir::RewriterBase &":$rewriter)
+           "::mlir::OpBuilder &":$builder)
     >,
     InterfaceMethod<[{
         Hook triggered once the destructuring of a slot is complete, meaning the
         original slot is no longer being refered to and could be deleted.
         This will only be called for slots declared by this operation.
-
-        All IR mutations must happen through the rewriter.
       }],
       "void", "handleDestructuringComplete",
       (ins "const ::mlir::DestructurableMemorySlot &":$slot,
-           "::mlir::RewriterBase &":$rewriter)
+           "::mlir::OpBuilder &":$builder)
     >,
   ];
 }
@@ -376,15 +368,14 @@ def DestructurableAccessorOpInterface
         Rewires the use of a slot to the generated subslots, without deleting
         any operation. Returns whether the accessor should be deleted.
 
-        All IR mutations must happen through the rewriter. Deletion of
-        operations is not allowed, only the accessor can be scheduled for
-        deletion by returning the appropriate value.
+        Deletion of operations is not allowed, only the accessor can be
+        scheduled for deletion by returning the appropriate value.
       }],
       "::mlir::DeletionKind",
       "rewire",
       (ins "const ::mlir::DestructurableMemorySlot &":$slot,
            "::llvm::DenseMap<::mlir::Attribute, ::mlir::MemorySlot> &":$subslots,
-           "::mlir::RewriterBase &":$rewriter,
+           "::mlir::OpBuilder &":$builder,
            "const ::mlir::DataLayout &":$dataLayout)
     >
   ];
diff --git a/mlir/include/mlir/Transforms/Mem2Reg.h b/mlir/include/mlir/Transforms/Mem2Reg.h
index ed10644e26a51..b4f939d654142 100644
--- a/mlir/include/mlir/Transforms/Mem2Reg.h
+++ b/mlir/include/mlir/Transforms/Mem2Reg.h
@@ -27,7 +27,7 @@ struct Mem2RegStatistics {
 /// at least one memory slot was promoted.
 LogicalResult
 tryToPromoteMemorySlots(ArrayRef<PromotableAllocationOpInterface> allocators,
-                        RewriterBase &rewriter, const DataLayout &dataLayout,
+                        OpBuilder &builder, const DataLayout &dataLayout,
                         Mem2RegStatistics statistics = {});
 
 } // namespace mlir
diff --git a/mlir/include/mlir/Transforms/SROA.h b/mlir/include/mlir/Transforms/SROA.h
index d09a7989edeab..fa84fb1eae73a 100644
--- a/mlir/include/mlir/Transforms/SROA.h
+++ b/mlir/include/mlir/Transforms/SROA.h
@@ -31,7 +31,7 @@ struct SROAStatistics {
 /// failure if no slot was destructured.
 LogicalResult tryToDestructureMemorySlots(
     ArrayRef<DestructurableAllocationOpInterface> allocators,
-    RewriterBase &rewriter, const DataLayout &dataLayout,
+    OpBuilder &builder, const DataLayout &dataLayout,
     SROAStatistics statistics = {});
 
 } // namespace mlir
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 6eeb13ebffb51..70102e1c81920 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -36,26 +36,26 @@ llvm::SmallVector<MemorySlot> LLVM::AllocaOp::getPromotableSlots() {
 }
 
 Value LLVM::AllocaOp::getDefaultValue(const MemorySlot &slot,
-                                      RewriterBase &rewriter) {
-  return rewriter.create<LLVM::UndefOp>(getLoc(), slot.elemType);
+                                      OpBuilder &builder) {
+  return builder.create<LLVM::UndefOp>(getLoc(), slot.elemType);
 }
 
 void LLVM::AllocaOp::handleBlockArgument(const MemorySlot &slot,
                                          BlockArgument argument,
-                                         RewriterBase &rewriter) {
+                                         OpBuilder &builder) {
   for (Operation *user : getOperation()->getUsers())
     if (auto declareOp = llvm::dyn_cast<LLVM::DbgDeclareOp>(user))
-      rewriter.create<LLVM::DbgValueOp>(declareOp.getLoc(), argument,
-                                        declareOp.getVarInfo(),
-                                        declareOp.getLocationExpr());
+      builder.create<LLVM::DbgValueOp>(declareOp.getLoc(), argument,
+                                       declareOp.getVarInfo(),
+                                       declareOp.getLocationExpr());
 }
 
 void LLVM::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
                                              Value defaultValue,
-                                             RewriterBase &rewriter) {
+                                             OpBuilder &builder) {
   if (defaultValue && defaultValue.use_empty())
-    rewriter.eraseOp(defaultValue.getDefiningOp());
-  rewriter.eraseOp(*this);
+    defaultValue.getDefiningOp()->erase();
+  this->erase();
 }
 
 SmallVector<DestructurableMemorySlot> LLVM::AllocaOp::getDestructurableSlots() {
@@ -78,16 +78,16 @@ SmallVector<DestructurableMemorySlot> LLVM::AllocaOp::getDestructurableSlots() {
 DenseMap<Attribute, MemorySlot>
 LLVM::AllocaOp::destructure(const DestructurableMemorySlot &slot,
                             const SmallPtrSetImpl<Attribute> &usedIndices,
-                            RewriterBase &rewriter) {
+                            OpBuilder &builder) {
   assert(slot.ptr == getResult());
-  rewriter.setInsertionPointAfter(*this);
+  builder.setInsertionPointAfter(*this);
 
   auto destructurableType = cast<DestructurableTypeInterface>(getElemType());
   DenseMap<Attribute, MemorySlot> slotMap;
   for (Attribute index : usedIndices) {
     Type elemType = destructurableType.getTypeAtIndex(index);
     assert(elemType && "used index must exist");
-    auto subAlloca = rewriter.create<LLVM::AllocaOp>(
+    auto subAlloca = builder.create<LLVM::AllocaOp>(
         getLoc(), LLVM::LLVMPointerType::get(getContext()), elemType,
         getArraySize());
     slotMap.try_emplace<MemorySlot>(index, {subAlloca.getResult(), elemType});
@@ -97,9 +97,9 @@ LLVM::AllocaOp::destructure(const DestructurableMemorySlot &slot,
 }
 
 void LLVM::AllocaOp::handleDestructuringComplete(
-    const DestructurableMemorySlot &slot, RewriterBase &rewriter) {
+    const DestructurableMemorySlot &slot, OpBuilder &builder) {
   assert(slot.ptr == getResult());
-  rewriter.eraseOp(*this);
+  this->erase();
 }
 
 //===----------------------------------------------------------------------===//
@@ -112,7 +112,7 @@ bool LLVM::LoadOp::loadsFrom(const MemorySlot &slot) {
 
 bool LLVM::LoadOp::storesTo(const MemorySlot &slot) { return false; }
 
-Value LLVM::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+Value LLVM::LoadOp::getStored(const MemorySlot &slot, OpBuilder &builder,
                               Value reachingDef, const DataLayout &dataLayout) {
   llvm_unreachable("getStored should not be called on LoadOp");
 }
@@ -175,7 +175,7 @@ static bool isBigEndian(const DataLayout &dataLayout) {
 
 /// Converts a value to an integer type of the same size.
 /// Assumes that the type can be converted.
-static Value castToSameSizedInt(RewriterBase &rewriter, Location loc, Value val,
+static Value castToSameSizedInt(OpBuilder &builder, Location loc, Value val,
                                 const DataLayout &dataLayout) {
   Type type = val.getType();
   assert(isSupportedTypeForConversion(type) &&
@@ -185,15 +185,15 @@ static Value castToSameSizedInt(RewriterBase &rewriter, Location loc, Value val,
     return val;
 
   uint64_t typeBitSize = dataLayout.getTypeSizeInBits(type);
-  IntegerType valueSizeInteger = rewriter.getIntegerType(typeBitSize);
+  IntegerType valueSizeInteger = builder.getIntegerType(typeBitSize);
 
   if (isa<LLVM::LLVMPointerType>(type))
-    return rewriter.createOrFold<LLVM::PtrToIntOp>(loc, valueSizeInteger, val);
-  return rewriter.createOrFold<LLVM::BitcastOp>(loc, valueSizeInteger, val);
+    return builder.createOrFold<LLVM::PtrToIntOp>(loc, valueSizeInteger, val);
+  return builder.createOrFold<LLVM::BitcastOp>(loc, valueSizeInteger, val);
 }
 
 /// Converts a value with an integer type to `targetType`.
-static Value castIntValueToSameSizedType(RewriterBase &rewriter, Location loc,
+static Value castIntValueToSameSizedType(OpBuilder &builder, Location loc,
                                          Value val, Type targetType) {
   assert(isa<IntegerType>(val.getType()) &&
          "expected value to have an integer type");
@@ -202,13 +202,13 @@ static Value castIntValueToSameSizedType(RewriterBase &rewriter, Location loc,
   if (val.getType() == targetType)
     return val;
   if (isa<LLVM::LLVMPointerType>(targetType))
-    return rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, val);
-  return rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, val);
+    return builder.createOrFold<LLVM::IntToPtrOp>(loc, targetType, val);
+  return builder.createOrFold<LLVM::BitcastOp>(loc, targetType, val);
 }
 
 /// Constructs operations that convert `srcValue` into a new value of type
 /// `targetType`. Assumes the types have the same bitsize.
-static Value castSameSizedTypes(RewriterBase &rewriter, Location loc,
+static Value castSameSizedTypes(OpBuilder &builder, Location loc,
                                 Value srcValue, Type targetType,
                                 const DataLayout &dataLayout) {
   Type srcType = srcValue.getType();
@@ -226,18 +226,18 @@ static Value castSameSizedTypes(RewriterBase &rewriter, Location loc,
   // provenance.
   if (isa<LLVM::LLVMPointerType>(targetType) &&
       isa<LLVM::LLVMPointerType>(srcType))
-    return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
-                                                        srcValue);
+    return builder.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
+                                                       srcValue);
 
   // For all other castable types, casting through integers is necessary.
-  Value replacement = castToSameSizedInt(rewriter, loc, srcValue, dataLayout);
-  return castIntValueToSameSizedType(rewriter, loc, replacement, targetType);
+  Value replacement = castToSameSizedInt(builder, loc, srcValue, dataLayout);
+  return castIntValueToSameSizedType(builder, loc, replacement, targetType);
 }
 
 /// Constructs operations that convert `srcValue` into a new value of type
 /// `targetType`. Performs bit-level extraction if the source type is larger
 /// than the target type. Assumes that this conversion is possible.
-static Value createExtractAndCast(RewriterBase &rewriter, Location loc,
+static Value createExtractAndCast(OpBuilder &builder, Location loc,
                                   Value srcValue, Type targetType,
                                   const DataLayout &dataLayout) {
   // Get the types of the source and target values.
@@ -249,31 +249,31 @@ static Value createExtractAndCast(RewriterBase &rewriter, Location loc,
   uint64_t srcTypeSize = dataLayout.getTypeSizeInBits(srcType);
   uint64_t targetTypeSize = dataLayout.getTypeSizeInBits(targetType);
   if (srcTypeSize == targetTypeSize)
-    return castSameSizedTypes(rewriter, loc, srcValue, targetType, dataLayout);
+    return castSameSizedTypes(builder, loc, srcValue, targetType, dataLayout);
 
   // First, cast the value to a same-sized integer type.
-  Value replacement = castToSameSizedInt(rewriter, loc, srcValue, dataLayout);
+  Value replacement = castToSameSizedInt(builder, loc, srcValue, dataLayout);
 
   // Truncate the integer if the size of the target is less than the value.
   if (isBigEndian(dataLayout)) {
     uint64_t shiftAmount = srcTypeSize - targetTypeSize;
-    auto shiftConstant = rewriter.create<LLVM::ConstantOp>(
-        loc, rewriter.getIntegerAttr(srcType, shiftAmount));
+    auto shiftConstant = builder.create<LLVM::ConstantOp>(
+        loc, builder.getIntegerAttr(srcType, shiftAmount));
     replacement =
-        rewriter.createOrFold<LLVM::LShrOp>(loc, srcValue, shiftConstant);
+        builder.createOrFold<LLVM::LShrOp>(loc, srcValue, shiftConstant);
   }
 
-  replacement = rewriter.create<LLVM::TruncOp>(
-      loc, rewriter.getIntegerType(targetTypeSize), replacement);
+  replacement = builder.create<LLVM::TruncOp>(
+      loc, builder.getIntegerType(targetTypeSize), replacement);
 
   // Now cast the integer to the actual target type if required.
-  return castIntValueToSameSizedType(rewriter, loc, replacement, targetType);
+  return castIntValueToSameSizedType(builder, loc, replacement, targetType);
 }
 
 /// Constructs operations that insert the bits of `srcValue` into the
 /// "beginning" of `reachingDef` (beginning is endianness dependent).
 /// Assumes that this conversion is possible.
-static Value createInsertAndCast(RewriterBase &rewriter, Location loc,
+static Value createInsertAndCast(OpBuilder &builder, Location loc,
                                  Value srcValue, Value reachingDef,
                                  const DataLayout &dataLayout) {
 
@@ -284,27 +284,27 @@ static Value createInsertAndCast(RewriterBase &rewriter, Location loc,
   uint64_t valueTypeSize = dataLayout.getTypeSizeInBits(srcValue.getType());
   uint64_t slotTypeSize = dataLayout.getTypeSizeInBits(reachingDef.getType());
   if (slotTypeSize == valueTypeSize)
-    return castSameSizedTypes(rewriter, loc, srcValue, reachingDef.getType(),
+    return castSameSizedTypes(builder, loc, srcValue, reachingDef.getType(),
                               dataLayout);
 
   // In the case where the store only overwrites parts of the memory,
   // bit fiddling is required to construct the new value.
 
   // First convert both values to integers of the same size.
-  Value defAsInt = castToSameSizedInt(rewriter, loc, reachingDef, dataLayout);
-  Value valueAsInt = castToSameSizedInt(rewriter, loc, srcValue, dataLayout);
+  Value defAsInt = castToSameSizedInt(builder, loc, reachingDef, dataLayout);
...
[truncated]

Copy link
Contributor

@gysit gysit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

Copy link
Member

@Moxinilian Moxinilian left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I really wish there was a way to extend the RewriterBase instead of doing this, as even though this is not used in patterns, being able to hook into the rewriting process seems useful. Alas!

@Dinistro
Copy link
Contributor Author

Dinistro commented May 8, 2024

I really wish there was a way to extend the RewriterBase instead of doing this, as even though this is not used in patterns, being able to hook into the rewriting process seems useful. Alas!

I've talked to multiple people that complained about the rewriter not being able to properly deal with block arguments. I did never investigate how hard this would be to implement, but given that many people encountered this, I suspect it's highly non-trivial.

@Dinistro Dinistro merged commit 084e2b5 into main May 8, 2024
@Dinistro Dinistro deleted the users/dinistro/change-memory-slot-interface-to-builders branch May 8, 2024 05:40
@matthias-springer
Copy link
Member

A rewriter is essentially a builder with extra API. So I am surprised that switching from RewriterBase to OpBuilder made any difference.

I've talked to multiple people that complained about the rewriter not being able to properly deal with block arguments.

What kind of functionality/API is missing?

This is an issue, as it leads to the invalidation of the dominance information

How does this relate to dominance information? Either you modify the IR in such a way that dominance info is invalidated or not; whether you use a rewriter or a builder should not matter.

@Dinistro
Copy link
Contributor Author

Dinistro commented May 8, 2024

A rewriter is essentially a builder with extra API. So I am surprised that switching from RewriterBase to OpBuilder made any difference.

The problem is purely related to adding block arguments. RewriterBase forces us to create a new block and then inline the old block into this one. The DominanceInfo internally stores block pointers, which are turned into dangling pointers by such a change. See https://github.com/llvm/llvm-project/pull/91341/files#diff-814ca4e9963a13044682d70a817c6c3503c779b3aabfa06398162f7a573b0b94L478-L504 for the concrete diff

What kind of functionality/API is missing?

A way of adding block arguments to an existing block without needing a workaround involving block creation and block inlining.

How does this relate to dominance information? Either you modify the IR in such a way that dominance info is invalidated or not; whether you use a rewriter or a builder should not matter.

The rewriter's lack of a "block argument adding" API leads to the usage of APIs that modify the IR in a way that it destroys dominance information.

@matthias-springer
Copy link
Member

I see. We are missing a RewriterBase::addBlockArgument function. And the general convention is "when we have a rewriter, all modifications must be done through the rewriter", so creating a new block was the only way to adhere to that convention. I can add RewriterBase::addBlockArgument, so that we have it in the future.

There is one downside of switching everything to OpBuilder: if a function takes an OpBuilder (as opposed to a rewriter), it "looks" like the function merely adds new IR, but does not modify/erase existing IR. (But you are actually erasing ops.) I didn't look at the code in detail, maybe it is obvious from the documentation that IR may get modified/erased.

@Dinistro
Copy link
Contributor Author

Dinistro commented May 8, 2024

I see. We are missing a RewriterBase::addBlockArgument function. And the general convention is "when we have a rewriter, all modifications must be done through the rewriter", so creating a new block was the only way to adhere to that convention. I can add RewriterBase::addBlockArgument, so that we have it in the future.

That would be fantastic. We sadly did not have enough time to investigate into this, but I know that many parties would be happy about such an extension 😄

There is one downside of switching everything to OpBuilder: if a function takes an OpBuilder (as opposed to a rewriter), it "looks" like the function merely adds new IR, but does not modify/erase existing IR. (But you are actually erasing ops.) I didn't look at the code in detail, maybe it is obvious from the documentation that IR may get modified/erased.

I understand that this might be a bit odd, but I would assume that people who run a Mem2Reg function are aware that it will delete unnecessary allocas and rewire the IR accordingly.

@Moxinilian
Copy link
Member

If RewriterBase gets an addBlockArgument, I am wholly in favor of reverting this change. I tried to add this feature last year but lost all hopes when I tried to make the conversion rewriter implement it, so if someone with more experience than me last year can make it work that would be great!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants