[MLIR][Mem2Reg][LLVM] Enhance partial load support #89094

Merged: 4 commits, Apr 18, 2024

Dinistro
Contributor

This commit improves the LLVM dialect's Mem2Reg interfaces to support promotion of partial loads from larger memory slots. To support this, the Mem2Reg interface methods are extended with an additional data layout parameter. The data layout is required to determine type sizes, which are needed to produce correct conversion sequences.

Note: Follow-ups will introduce similar functionality for stores, and there are plans to support accesses into the middle of memory slots.
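
As a rough illustration of why the data layout matters here, the following standalone C++ sketch models a partial load of the leading bytes of a wider integer slot value. It does not use any MLIR API: `extractPartialLoad` and its parameters are hypothetical names chosen for this example, and it assumes integer values of at most 8 bytes. On a little-endian layout the leading bytes of the slot hold the low-order bits, so a truncation suffices; on a big-endian layout they hold the high-order bits, so the value has to be shifted right first.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical model (not MLIR API): returns the value that a load of
// `loadBytes` bytes at offset 0 would observe, given that the slot currently
// holds `slotValue`, which is `slotBytes` bytes wide.
uint64_t extractPartialLoad(uint64_t slotValue, unsigned slotBytes,
                            unsigned loadBytes, bool littleEndian) {
  assert(loadBytes >= 1 && loadBytes <= slotBytes && slotBytes <= 8);
  constexpr unsigned kBitsInByte = 8;
  uint64_t value = slotValue;
  // Big endian: the first bytes in memory hold the most significant bits,
  // so shift them down before truncating.
  if (!littleEndian)
    value >>= (slotBytes - loadBytes) * kBitsInByte;
  // Truncate to the width of the load.
  if (loadBytes < 8)
    value &= (uint64_t{1} << (loadBytes * kBitsInByte)) - 1;
  return value;
}
```

For a 4-byte slot holding `0x11223344`, a 2-byte load yields `0x3344` under a little-endian layout and `0x1122` under a big-endian one.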

@llvmbot
Member

llvmbot commented Apr 17, 2024

@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir-memref

@llvm/pr-subscribers-mlir

Author: Christian Ulmann (Dinistro)

Patch is 24.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/89094.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Interfaces/MemorySlotInterfaces.td (+4-2)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp (+136-38)
  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp (+8-6)
  • (modified) mlir/lib/Transforms/Mem2Reg.cpp (+10-7)
  • (modified) mlir/test/Dialect/LLVMIR/mem2reg.mlir (+112-17)
diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
index 9db89361c78002..8c642c0ed26aca 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
@@ -127,7 +127,8 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
       }],
       "::mlir::Value", "getStored",
       (ins "const ::mlir::MemorySlot &":$slot,
-           "::mlir::RewriterBase &":$rewriter)
+           "::mlir::RewriterBase &":$rewriter,
+           "const ::mlir::DataLayout &":$dataLayout)
     >,
     InterfaceMethod<[{
         Checks that this operation can be promoted to no longer use the provided
@@ -172,7 +173,8 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
       (ins "const ::mlir::MemorySlot &":$slot,
            "const ::llvm::SmallPtrSetImpl<mlir::OpOperand *> &":$blockingUses,
            "::mlir::RewriterBase &":$rewriter,
-           "::mlir::Value":$reachingDefinition)
+           "::mlir::Value":$reachingDefinition,
+           "const ::mlir::DataLayout &":$dataLayout)
     >,
   ];
 }
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index c7ca0b4a5843ad..0c4d019f5654ac 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -112,7 +112,8 @@ bool LLVM::LoadOp::loadsFrom(const MemorySlot &slot) {
 
 bool LLVM::LoadOp::storesTo(const MemorySlot &slot) { return false; }
 
-Value LLVM::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter) {
+Value LLVM::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+                              const DataLayout &dataLayout) {
   llvm_unreachable("getStored should not be called on LoadOp");
 }
 
@@ -122,37 +123,124 @@ bool LLVM::StoreOp::storesTo(const MemorySlot &slot) {
   return getAddr() == slot.ptr;
 }
 
-/// Checks that two types are the same or can be cast into one another.
-static bool areCastCompatible(const DataLayout &layout, Type lhs, Type rhs) {
-  return lhs == rhs || (!isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(lhs) &&
-                        !isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(rhs) &&
-                        layout.getTypeSize(lhs) == layout.getTypeSize(rhs));
+/// Checks if `type` can be used in any kind of conversion sequence.
+static bool isSupportedTypeForConversion(Type type) {
+  // Aggregate types are not bitcastable.
+  if (isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(type))
+    return false;
+
+  // LLVM vector types are only used for either pointers or target specific
+  // types. These types cannot be cast in the general case, thus the memory
+  // optimizations do not support them.
+  if (isa<LLVM::LLVMFixedVectorType, LLVM::LLVMScalableVectorType>(type))
+    return false;
+
+  // Scalable types are not supported.
+  if (auto vectorType = dyn_cast<VectorType>(type))
+    return !vectorType.isScalable();
+  return true;
+}
+
+/// Checks that `rhs` can be converted to `lhs` by a sequence of casts and
+/// truncations.
+static bool areConversionCompatible(const DataLayout &layout, Type lhs,
+                                    Type rhs) {
+  if (lhs == rhs)
+    return true;
+
+  // Unsupported types, such as aggregates, cannot be cast.
+  if (!isSupportedTypeForConversion(lhs) || !isSupportedTypeForConversion(rhs))
+    return false;
+  return layout.getTypeSize(lhs) <= layout.getTypeSize(rhs);
 }
 
+/// Checks if `dataLayout` describes a little endian layout.
+static bool isLittleEndian(const DataLayout &dataLayout) {
+  auto endiannessStr = dyn_cast_or_null<StringAttr>(dataLayout.getEndianness());
+  return !endiannessStr || endiannessStr == "little";
+}
+
+/// The size of a byte in bits.
+constexpr const static uint64_t kBitsInByte = 8;
+
 /// Constructs operations that convert `inputValue` into a new value of type
 /// `targetType`. Assumes that this conversion is possible.
 static Value createConversionSequence(RewriterBase &rewriter, Location loc,
-                                      Value inputValue, Type targetType) {
-  if (inputValue.getType() == targetType)
-    return inputValue;
+                                      Value srcValue, Type targetType,
+                                      const DataLayout &dataLayout) {
+  // Get the types of the source and destination values.
+  Type srcType = srcValue.getType();
+
+  uint64_t srcTypeSize = dataLayout.getTypeSize(srcType);
+  uint64_t targetTypeSize = dataLayout.getTypeSize(targetType);
+
+  // Nothing has to be done if the types are already the same.
+  if (srcType == targetType)
+    return srcValue;
+
+  // The code below is currently not capable of handling aggregate types as it
+  // makes use of bitcasts. Aggregates cannot be bitcast.
+  // TODO: We should have a `LLVMAggregateType` base class to easily perform
+  // this `isa`.
+  if (isa<LLVM::LLVMArrayType, LLVM::LLVMStructType>(srcType) ||
+      isa<LLVM::LLVMArrayType, LLVM::LLVMStructType>(targetType))
+    return nullptr;
+
+  // In the special case of casting one pointer to another, we want to generate
+  // an address space cast. Bitcasts of pointers are not allowed and
+  // pointer-to-integer conversions are not equivalent due to the loss of
+  // provenance.
+  if (isa<LLVM::LLVMPointerType>(targetType) &&
+      isa<LLVM::LLVMPointerType>(srcType)) {
+    // Abort the conversion if the pointers have different bitwidths.
+    if (srcTypeSize != targetTypeSize)
+      return nullptr;
+    return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
+                                                        srcValue);
+  }
 
-  if (!isa<LLVM::LLVMPointerType>(targetType) &&
-      !isa<LLVM::LLVMPointerType>(inputValue.getType()))
-    return rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, inputValue);
+  IntegerType valueSizeInteger =
+      rewriter.getIntegerType(srcTypeSize * kBitsInByte);
+  Value replacement = srcValue;
+
+  // First, cast the value to a same-sized integer type.
+  if (isa<LLVM::LLVMPointerType>(srcType))
+    replacement = rewriter.createOrFold<LLVM::PtrToIntOp>(loc, valueSizeInteger,
+                                                          replacement);
+  else if (replacement.getType() != valueSizeInteger)
+    replacement = rewriter.createOrFold<LLVM::BitcastOp>(loc, valueSizeInteger,
+                                                         replacement);
+
+  // Truncate the integer if the size of the read is less than the value.
+  if (targetTypeSize != srcTypeSize) {
+    if (!isLittleEndian(dataLayout)) {
+      uint64_t shiftAmount = (srcTypeSize - targetTypeSize) * kBitsInByte;
+      auto shiftConstant = rewriter.create<LLVM::ConstantOp>(
+          loc, rewriter.getIntegerAttr(valueSizeInteger, shiftAmount));
+      replacement =
+          rewriter.createOrFold<LLVM::LShrOp>(loc, replacement, shiftConstant);
+    }
 
-  if (!isa<LLVM::LLVMPointerType>(targetType))
-    return rewriter.createOrFold<LLVM::PtrToIntOp>(loc, targetType, inputValue);
+    replacement = rewriter.create<LLVM::TruncOp>(
+        loc, rewriter.getIntegerType(targetTypeSize * kBitsInByte),
+        replacement);
+  }
 
-  if (!isa<LLVM::LLVMPointerType>(inputValue.getType()))
-    return rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, inputValue);
+  // Now cast the integer to the actual destination type if required.
+  if (isa<LLVM::LLVMPointerType>(targetType))
+    replacement =
+        rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, replacement);
+  else if (replacement.getType() != targetType)
+    replacement =
+        rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, replacement);
 
-  return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
-                                                      inputValue);
+  return replacement;
 }
 
-Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter) {
-  return createConversionSequence(rewriter, getLoc(), getValue(),
-                                  slot.elemType);
+Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+                               const DataLayout &dataLayout) {
+  return createConversionSequence(rewriter, getLoc(), getValue(), slot.elemType,
+                                  dataLayout);
 }
 
 bool LLVM::LoadOp::canUsesBeRemoved(
@@ -167,17 +255,20 @@ bool LLVM::LoadOp::canUsesBeRemoved(
   // be removed (provided it loads the exact stored value and is not
   // volatile).
   return blockingUse == slot.ptr && getAddr() == slot.ptr &&
-         areCastCompatible(dataLayout, getResult().getType(), slot.elemType) &&
+         areConversionCompatible(dataLayout, getResult().getType(),
+                                 slot.elemType) &&
          !getVolatile_();
 }
 
 DeletionKind LLVM::LoadOp::removeBlockingUses(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    RewriterBase &rewriter, Value reachingDefinition) {
+    RewriterBase &rewriter, Value reachingDefinition,
+    const DataLayout &dataLayout) {
   // `canUsesBeRemoved` checked this blocking use must be the loaded slot
   // pointer.
-  Value newResult = createConversionSequence(
-      rewriter, getLoc(), reachingDefinition, getResult().getType());
+  Value newResult =
+      createConversionSequence(rewriter, getLoc(), reachingDefinition,
+                               getResult().getType(), dataLayout);
   rewriter.replaceAllUsesWith(getResult(), newResult);
   return DeletionKind::Delete;
 }
@@ -194,13 +285,15 @@ bool LLVM::StoreOp::canUsesBeRemoved(
   // store OF the slot pointer, only INTO the slot pointer.
   return blockingUse == slot.ptr && getAddr() == slot.ptr &&
          getValue() != slot.ptr &&
-         areCastCompatible(dataLayout, slot.elemType, getValue().getType()) &&
+         areConversionCompatible(dataLayout, slot.elemType,
+                                 getValue().getType()) &&
          !getVolatile_();
 }
 
 DeletionKind LLVM::StoreOp::removeBlockingUses(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    RewriterBase &rewriter, Value reachingDefinition) {
+    RewriterBase &rewriter, Value reachingDefinition,
+    const DataLayout &dataLayout) {
   return DeletionKind::Delete;
 }
 
@@ -747,8 +840,8 @@ bool LLVM::MemsetOp::storesTo(const MemorySlot &slot) {
   return getDst() == slot.ptr;
 }
 
-Value LLVM::MemsetOp::getStored(const MemorySlot &slot,
-                                RewriterBase &rewriter) {
+Value LLVM::MemsetOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+                                const DataLayout &dataLayout) {
   // TODO: Support non-integer types.
   return TypeSwitch<Type, Value>(slot.elemType)
       .Case([&](IntegerType intType) -> Value {
@@ -802,7 +895,8 @@ bool LLVM::MemsetOp::canUsesBeRemoved(
 
 DeletionKind LLVM::MemsetOp::removeBlockingUses(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    RewriterBase &rewriter, Value reachingDefinition) {
+    RewriterBase &rewriter, Value reachingDefinition,
+    const DataLayout &dataLayout) {
   return DeletionKind::Delete;
 }
 
@@ -1059,8 +1153,8 @@ bool LLVM::MemcpyOp::storesTo(const MemorySlot &slot) {
   return memcpyStoresTo(*this, slot);
 }
 
-Value LLVM::MemcpyOp::getStored(const MemorySlot &slot,
-                                RewriterBase &rewriter) {
+Value LLVM::MemcpyOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+                                const DataLayout &dataLayout) {
   return memcpyGetStored(*this, slot, rewriter);
 }
 
@@ -1074,7 +1168,8 @@ bool LLVM::MemcpyOp::canUsesBeRemoved(
 
 DeletionKind LLVM::MemcpyOp::removeBlockingUses(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    RewriterBase &rewriter, Value reachingDefinition) {
+    RewriterBase &rewriter, Value reachingDefinition,
+    const DataLayout &dataLayout) {
   return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
                                   reachingDefinition);
 }
@@ -1109,7 +1204,8 @@ bool LLVM::MemcpyInlineOp::storesTo(const MemorySlot &slot) {
 }
 
 Value LLVM::MemcpyInlineOp::getStored(const MemorySlot &slot,
-                                      RewriterBase &rewriter) {
+                                      RewriterBase &rewriter,
+                                      const DataLayout &dataLayout) {
   return memcpyGetStored(*this, slot, rewriter);
 }
 
@@ -1123,7 +1219,8 @@ bool LLVM::MemcpyInlineOp::canUsesBeRemoved(
 
 DeletionKind LLVM::MemcpyInlineOp::removeBlockingUses(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    RewriterBase &rewriter, Value reachingDefinition) {
+    RewriterBase &rewriter, Value reachingDefinition,
+    const DataLayout &dataLayout) {
   return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
                                   reachingDefinition);
 }
@@ -1159,8 +1256,8 @@ bool LLVM::MemmoveOp::storesTo(const MemorySlot &slot) {
   return memcpyStoresTo(*this, slot);
 }
 
-Value LLVM::MemmoveOp::getStored(const MemorySlot &slot,
-                                 RewriterBase &rewriter) {
+Value LLVM::MemmoveOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+                                 const DataLayout &dataLayout) {
   return memcpyGetStored(*this, slot, rewriter);
 }
 
@@ -1174,7 +1271,8 @@ bool LLVM::MemmoveOp::canUsesBeRemoved(
 
 DeletionKind LLVM::MemmoveOp::removeBlockingUses(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    RewriterBase &rewriter, Value reachingDefinition) {
+    RewriterBase &rewriter, Value reachingDefinition,
+    const DataLayout &dataLayout) {
   return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
                                   reachingDefinition);
 }
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index 6c5250d527ade8..ebbf20f1b76b67 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -160,8 +160,8 @@ bool memref::LoadOp::loadsFrom(const MemorySlot &slot) {
 
 bool memref::LoadOp::storesTo(const MemorySlot &slot) { return false; }
 
-Value memref::LoadOp::getStored(const MemorySlot &slot,
-                                RewriterBase &rewriter) {
+Value memref::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+                                const DataLayout &dataLayout) {
   llvm_unreachable("getStored should not be called on LoadOp");
 }
 
@@ -178,7 +178,8 @@ bool memref::LoadOp::canUsesBeRemoved(
 
 DeletionKind memref::LoadOp::removeBlockingUses(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    RewriterBase &rewriter, Value reachingDefinition) {
+    RewriterBase &rewriter, Value reachingDefinition,
+    const DataLayout &dataLayout) {
   // `canUsesBeRemoved` checked this blocking use must be the loaded slot
   // pointer.
   rewriter.replaceAllUsesWith(getResult(), reachingDefinition);
@@ -240,8 +241,8 @@ bool memref::StoreOp::storesTo(const MemorySlot &slot) {
   return getMemRef() == slot.ptr;
 }
 
-Value memref::StoreOp::getStored(const MemorySlot &slot,
-                                 RewriterBase &rewriter) {
+Value memref::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+                                 const DataLayout &dataLayout) {
   return getValue();
 }
 
@@ -258,7 +259,8 @@ bool memref::StoreOp::canUsesBeRemoved(
 
 DeletionKind memref::StoreOp::removeBlockingUses(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    RewriterBase &rewriter, Value reachingDefinition) {
+    RewriterBase &rewriter, Value reachingDefinition,
+    const DataLayout &dataLayout) {
   return DeletionKind::Delete;
 }
 
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index abe565ea862f8f..1e620e46af84ea 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -165,7 +165,7 @@ class MemorySlotPromoter {
 public:
   MemorySlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator,
                      RewriterBase &rewriter, DominanceInfo &dominance,
-                     MemorySlotPromotionInfo info,
+                     const DataLayout &dataLayout, MemorySlotPromotionInfo info,
                      const Mem2RegStatistics &statistics);
 
   /// Actually promotes the slot by mutating IR. Promoting a slot DOES
@@ -204,6 +204,7 @@ class MemorySlotPromoter {
   DenseMap<PromotableMemOpInterface, Value> reachingDefs;
   DenseMap<PromotableMemOpInterface, Value> replacedValuesMap;
   DominanceInfo &dominance;
+  const DataLayout &dataLayout;
   MemorySlotPromotionInfo info;
   const Mem2RegStatistics &statistics;
 };
@@ -213,9 +214,11 @@ class MemorySlotPromoter {
 MemorySlotPromoter::MemorySlotPromoter(
     MemorySlot slot, PromotableAllocationOpInterface allocator,
     RewriterBase &rewriter, DominanceInfo &dominance,
-    MemorySlotPromotionInfo info, const Mem2RegStatistics &statistics)
+    const DataLayout &dataLayout, MemorySlotPromotionInfo info,
+    const Mem2RegStatistics &statistics)
     : slot(slot), allocator(allocator), rewriter(rewriter),
-      dominance(dominance), info(std::move(info)), statistics(statistics) {
+      dominance(dominance), dataLayout(dataLayout), info(std::move(info)),
+      statistics(statistics) {
 #ifndef NDEBUG
   auto isResultOrNewBlockArgument = [&]() {
     if (BlockArgument arg = dyn_cast<BlockArgument>(slot.ptr))
@@ -435,7 +438,7 @@ Value MemorySlotPromoter::computeReachingDefInBlock(Block *block,
 
       if (memOp.storesTo(slot)) {
         rewriter.setInsertionPointAfter(memOp);
-        Value stored = memOp.getStored(slot, rewriter);
+        Value stored = memOp.getStored(slot, rewriter, dataLayout);
         assert(stored && "a memory operation storing to a slot must provide a "
                          "new definition of the slot");
         reachingDef = stored;
@@ -568,8 +571,8 @@ void MemorySlotPromoter::removeBlockingUses() {
 
       rewriter.setInsertionPointAfter(toPromote);
       if (toPromoteMemOp.removeBlockingUses(
-              slot, info.userToBlockingUses[toPromote], rewriter,
-              reachingDef) == DeletionKind::Delete)
+              slot, info.userToBlockingUses[toPromote], rewriter, reachingDef,
+              dataLayout) == DeletionKind::Delete)
         toErase.push_back(toPromote);
       if (toPromoteMemOp.storesTo(slot))
         if (Value replacedValue = replacedValuesMap[toPromoteMemOp])
@@ -642,7 +645,7 @@ LogicalResult mlir::tryToPromoteMemorySlots(
       MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
       std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
       if (info) {
-        MemorySlotPromoter(slot, allocator, rewriter, dominance,
+        MemorySlotPromoter(slot, allocator, rewriter, dominance, dataLayout,
                            std::move(*info), statistics)
             .promoteSlot();
         promotedAny = true;
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg.mlir b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
index fa5d842302d0f4..e724c2e8679501 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
@@ -448,19 +448,6 @@ llvm.func @store_load_forward() -> i32 {
 
 // -----
 
-// CHECK-LABEL: llvm.func @store_load_wrong_type
-llvm.func @store_load_wrong_type() -> i16 {
-  %0 = llvm.mlir.constant(1 : i32) : i32
-  %1 = llvm.mlir.constant(0 : i32) : i32
-  // CHECK: = llvm.alloca
-  %2 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
-  llvm.store %1, %2 {alignment = 4 : i64} : i32, !llvm.ptr
-  %3 = llvm.load %2 {alignment = 2 : i64} : !llvm.ptr -> i16
-  llvm.return %3 : i16
-}
-
-// -----
-
 // CHECK-LABEL: llvm.func @merge_point_cycle
 llvm.func @merge_point_cycle() {
   // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef : i32
@@ -894,7 +881,7 @@ llvm.func @stores_with_different_type_sizes(%arg0: i64, %arg1: f32, %cond: i1) -
 // CHECK-LABEL: @load_smaller_int
 llvm.func @load_smaller_int() -> i16 {
   %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK: llvm.alloca
+  // CHECK-NOT: llvm.alloca
   %1 = ...
[truncated]
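
The promotion condition that `areConversionCompatible` introduces in the patch above can be summarized with a small standalone model. This is a sketch under stated assumptions, not the MLIR implementation: `Kind`, `TypeInfo`, `isSupported`, and `isConversionCompatible` are hypothetical names, type identity is reduced to an integer `id`, and type sizes are passed in directly instead of being queried from a `DataLayout`.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-ins for the type categories the patch distinguishes.
enum class Kind { Scalar, FixedVector, ScalableVector, LLVMVector, Aggregate };

// Hypothetical type descriptor; `id` models type identity.
struct TypeInfo {
  int id;
  Kind kind;
  uint64_t sizeInBytes;
};

// Models isSupportedTypeForConversion: aggregates, LLVM dialect vector types,
// and scalable vectors are rejected; everything else is accepted.
bool isSupported(const TypeInfo &type) {
  return type.kind == Kind::Scalar || type.kind == Kind::FixedVector;
}

// Models areConversionCompatible(layout, lhs, rhs): identical types always
// pass; otherwise both types must be supported and the target (`lhs`) must
// not be larger than the source (`rhs`).
bool isConversionCompatible(const TypeInfo &lhs, const TypeInfo &rhs) {
  if (lhs.id == rhs.id)
    return true;
  if (!isSupported(lhs) || !isSupported(rhs))
    return false;
  return lhs.sizeInBytes <= rhs.sizeInBytes;
}
```

Under this model, a load passes its result type as `lhs` and the slot element type as `rhs`, so a 2-byte load from a 4-byte slot is accepted while a 4-byte load from a 2-byte slot is still rejected. This is consistent with the `load_smaller_int` test switching from `CHECK: llvm.alloca` to `CHECK-NOT: llvm.alloca`.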

@llvmbot
Member

llvmbot commented Apr 17, 2024

@llvm/pr-subscribers-mlir-llvm

          !getVolatile_();
 }
 
 DeletionKind LLVM::StoreOp::removeBlockingUses(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    RewriterBase &rewriter, Value reachingDefinition) {
+    RewriterBase &rewriter, Value reachingDefinition,
+    const DataLayout &dataLayout) {
   return DeletionKind::Delete;
 }
 
@@ -747,8 +840,8 @@ bool LLVM::MemsetOp::storesTo(const MemorySlot &slot) {
   return getDst() == slot.ptr;
 }
 
-Value LLVM::MemsetOp::getStored(const MemorySlot &slot,
-                                RewriterBase &rewriter) {
+Value LLVM::MemsetOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+                                const DataLayout &dataLayout) {
   // TODO: Support non-integer types.
   return TypeSwitch<Type, Value>(slot.elemType)
       .Case([&](IntegerType intType) -> Value {
@@ -802,7 +895,8 @@ bool LLVM::MemsetOp::canUsesBeRemoved(
 
 DeletionKind LLVM::MemsetOp::removeBlockingUses(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    RewriterBase &rewriter, Value reachingDefinition) {
+    RewriterBase &rewriter, Value reachingDefinition,
+    const DataLayout &dataLayout) {
   return DeletionKind::Delete;
 }
 
@@ -1059,8 +1153,8 @@ bool LLVM::MemcpyOp::storesTo(const MemorySlot &slot) {
   return memcpyStoresTo(*this, slot);
 }
 
-Value LLVM::MemcpyOp::getStored(const MemorySlot &slot,
-                                RewriterBase &rewriter) {
+Value LLVM::MemcpyOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+                                const DataLayout &dataLayout) {
   return memcpyGetStored(*this, slot, rewriter);
 }
 
@@ -1074,7 +1168,8 @@ bool LLVM::MemcpyOp::canUsesBeRemoved(
 
 DeletionKind LLVM::MemcpyOp::removeBlockingUses(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    RewriterBase &rewriter, Value reachingDefinition) {
+    RewriterBase &rewriter, Value reachingDefinition,
+    const DataLayout &dataLayout) {
   return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
                                   reachingDefinition);
 }
@@ -1109,7 +1204,8 @@ bool LLVM::MemcpyInlineOp::storesTo(const MemorySlot &slot) {
 }
 
 Value LLVM::MemcpyInlineOp::getStored(const MemorySlot &slot,
-                                      RewriterBase &rewriter) {
+                                      RewriterBase &rewriter,
+                                      const DataLayout &dataLayout) {
   return memcpyGetStored(*this, slot, rewriter);
 }
 
@@ -1123,7 +1219,8 @@ bool LLVM::MemcpyInlineOp::canUsesBeRemoved(
 
 DeletionKind LLVM::MemcpyInlineOp::removeBlockingUses(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    RewriterBase &rewriter, Value reachingDefinition) {
+    RewriterBase &rewriter, Value reachingDefinition,
+    const DataLayout &dataLayout) {
   return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
                                   reachingDefinition);
 }
@@ -1159,8 +1256,8 @@ bool LLVM::MemmoveOp::storesTo(const MemorySlot &slot) {
   return memcpyStoresTo(*this, slot);
 }
 
-Value LLVM::MemmoveOp::getStored(const MemorySlot &slot,
-                                 RewriterBase &rewriter) {
+Value LLVM::MemmoveOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+                                 const DataLayout &dataLayout) {
   return memcpyGetStored(*this, slot, rewriter);
 }
 
@@ -1174,7 +1271,8 @@ bool LLVM::MemmoveOp::canUsesBeRemoved(
 
 DeletionKind LLVM::MemmoveOp::removeBlockingUses(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    RewriterBase &rewriter, Value reachingDefinition) {
+    RewriterBase &rewriter, Value reachingDefinition,
+    const DataLayout &dataLayout) {
   return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
                                   reachingDefinition);
 }
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index 6c5250d527ade8..ebbf20f1b76b67 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -160,8 +160,8 @@ bool memref::LoadOp::loadsFrom(const MemorySlot &slot) {
 
 bool memref::LoadOp::storesTo(const MemorySlot &slot) { return false; }
 
-Value memref::LoadOp::getStored(const MemorySlot &slot,
-                                RewriterBase &rewriter) {
+Value memref::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+                                const DataLayout &dataLayout) {
   llvm_unreachable("getStored should not be called on LoadOp");
 }
 
@@ -178,7 +178,8 @@ bool memref::LoadOp::canUsesBeRemoved(
 
 DeletionKind memref::LoadOp::removeBlockingUses(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    RewriterBase &rewriter, Value reachingDefinition) {
+    RewriterBase &rewriter, Value reachingDefinition,
+    const DataLayout &dataLayout) {
   // `canUsesBeRemoved` checked this blocking use must be the loaded slot
   // pointer.
   rewriter.replaceAllUsesWith(getResult(), reachingDefinition);
@@ -240,8 +241,8 @@ bool memref::StoreOp::storesTo(const MemorySlot &slot) {
   return getMemRef() == slot.ptr;
 }
 
-Value memref::StoreOp::getStored(const MemorySlot &slot,
-                                 RewriterBase &rewriter) {
+Value memref::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+                                 const DataLayout &dataLayout) {
   return getValue();
 }
 
@@ -258,7 +259,8 @@ bool memref::StoreOp::canUsesBeRemoved(
 
 DeletionKind memref::StoreOp::removeBlockingUses(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
-    RewriterBase &rewriter, Value reachingDefinition) {
+    RewriterBase &rewriter, Value reachingDefinition,
+    const DataLayout &dataLayout) {
   return DeletionKind::Delete;
 }
 
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index abe565ea862f8f..1e620e46af84ea 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -165,7 +165,7 @@ class MemorySlotPromoter {
 public:
   MemorySlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator,
                      RewriterBase &rewriter, DominanceInfo &dominance,
-                     MemorySlotPromotionInfo info,
+                     const DataLayout &dataLayout, MemorySlotPromotionInfo info,
                      const Mem2RegStatistics &statistics);
 
   /// Actually promotes the slot by mutating IR. Promoting a slot DOES
@@ -204,6 +204,7 @@ class MemorySlotPromoter {
   DenseMap<PromotableMemOpInterface, Value> reachingDefs;
   DenseMap<PromotableMemOpInterface, Value> replacedValuesMap;
   DominanceInfo &dominance;
+  const DataLayout &dataLayout;
   MemorySlotPromotionInfo info;
   const Mem2RegStatistics &statistics;
 };
@@ -213,9 +214,11 @@ class MemorySlotPromoter {
 MemorySlotPromoter::MemorySlotPromoter(
     MemorySlot slot, PromotableAllocationOpInterface allocator,
     RewriterBase &rewriter, DominanceInfo &dominance,
-    MemorySlotPromotionInfo info, const Mem2RegStatistics &statistics)
+    const DataLayout &dataLayout, MemorySlotPromotionInfo info,
+    const Mem2RegStatistics &statistics)
     : slot(slot), allocator(allocator), rewriter(rewriter),
-      dominance(dominance), info(std::move(info)), statistics(statistics) {
+      dominance(dominance), dataLayout(dataLayout), info(std::move(info)),
+      statistics(statistics) {
 #ifndef NDEBUG
   auto isResultOrNewBlockArgument = [&]() {
     if (BlockArgument arg = dyn_cast<BlockArgument>(slot.ptr))
@@ -435,7 +438,7 @@ Value MemorySlotPromoter::computeReachingDefInBlock(Block *block,
 
       if (memOp.storesTo(slot)) {
         rewriter.setInsertionPointAfter(memOp);
-        Value stored = memOp.getStored(slot, rewriter);
+        Value stored = memOp.getStored(slot, rewriter, dataLayout);
         assert(stored && "a memory operation storing to a slot must provide a "
                          "new definition of the slot");
         reachingDef = stored;
@@ -568,8 +571,8 @@ void MemorySlotPromoter::removeBlockingUses() {
 
       rewriter.setInsertionPointAfter(toPromote);
       if (toPromoteMemOp.removeBlockingUses(
-              slot, info.userToBlockingUses[toPromote], rewriter,
-              reachingDef) == DeletionKind::Delete)
+              slot, info.userToBlockingUses[toPromote], rewriter, reachingDef,
+              dataLayout) == DeletionKind::Delete)
         toErase.push_back(toPromote);
       if (toPromoteMemOp.storesTo(slot))
         if (Value replacedValue = replacedValuesMap[toPromoteMemOp])
@@ -642,7 +645,7 @@ LogicalResult mlir::tryToPromoteMemorySlots(
       MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
       std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
       if (info) {
-        MemorySlotPromoter(slot, allocator, rewriter, dominance,
+        MemorySlotPromoter(slot, allocator, rewriter, dominance, dataLayout,
                            std::move(*info), statistics)
             .promoteSlot();
         promotedAny = true;
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg.mlir b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
index fa5d842302d0f4..e724c2e8679501 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
@@ -448,19 +448,6 @@ llvm.func @store_load_forward() -> i32 {
 
 // -----
 
-// CHECK-LABEL: llvm.func @store_load_wrong_type
-llvm.func @store_load_wrong_type() -> i16 {
-  %0 = llvm.mlir.constant(1 : i32) : i32
-  %1 = llvm.mlir.constant(0 : i32) : i32
-  // CHECK: = llvm.alloca
-  %2 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
-  llvm.store %1, %2 {alignment = 4 : i64} : i32, !llvm.ptr
-  %3 = llvm.load %2 {alignment = 2 : i64} : !llvm.ptr -> i16
-  llvm.return %3 : i16
-}
-
-// -----
-
 // CHECK-LABEL: llvm.func @merge_point_cycle
 llvm.func @merge_point_cycle() {
   // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef : i32
@@ -894,7 +881,7 @@ llvm.func @stores_with_different_type_sizes(%arg0: i64, %arg1: f32, %cond: i1) -
 // CHECK-LABEL: @load_smaller_int
 llvm.func @load_smaller_int() -> i16 {
   %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK: llvm.alloca
+  // CHECK-NOT: llvm.alloca
   %1 = ...
[truncated]
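
The truncation step of `createConversionSequence` in the diff above can be modeled on plain integers. The following is a minimal sketch, not the MLIR code: `partialLoadFromStart` is a hypothetical helper, and it assumes the load starts at the beginning of the slot and that all sizes fit into 64 bits.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper (not part of the patch): models the integer part of the
// conversion sequence for a load of `targetBytes` bytes from the start of a
// slot that holds `srcBytes` bytes.
inline uint64_t partialLoadFromStart(uint64_t srcValue, unsigned srcBytes,
                                     unsigned targetBytes, bool littleEndian) {
  assert(targetBytes >= 1 && targetBytes <= srcBytes && srcBytes <= 8);
  constexpr unsigned kBitsInByte = 8;
  uint64_t replacement = srcValue;

  // On a big-endian layout the first bytes in memory are the most significant
  // ones, so they have to be shifted down before truncating.
  if (!littleEndian && targetBytes != srcBytes)
    replacement >>= (srcBytes - targetBytes) * kBitsInByte;

  // Truncate to the width of the load (a full 8-byte load needs no mask).
  if (targetBytes < 8)
    replacement &= (uint64_t{1} << (targetBytes * kBitsInByte)) - 1;
  return replacement;
}
```

Under these assumptions, reading the first two bytes of the 4-byte value `0x11223344` gives `0x3344` on a little-endian layout and `0x1122` on a big-endian one.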

@Dinistro Dinistro force-pushed the users/dinistro/mem2reg-allow-smaller-accesses branch from 3193204 to a0a1619 Compare April 17, 2024 16:10
Member

@Moxinilian Moxinilian left a comment

Looks great. It's a bit of a shame that the most basic interfaces are becoming a bit complex due to arguably very LLVM-centric problems. Do you think there is a way to keep the simple API while still allowing this somehow?

Contributor

@gysit gysit left a comment

LGTM modulo nits!

@gysit
Contributor

gysit commented Apr 17, 2024

It's a bit of a shame that the most basic interfaces are becoming a bit complex due to arguably very LLVM-centric problems. Do you think there is a way to keep the simple API while still allowing this somehow?

The data layout may be relevant for other dialects as well though?

@Moxinilian
Member

Yeah, most likely. I just mean that using it seems like an edge case (I don't expect it to be super frequent to have loads and stores on the same pointers with different types, right? This seems to only be a concern for dialects at a similar level of abstraction as LLVM). It would be nice if the (I expect) more common case, where memory operations are well behaved, did not have to bother. But it's only a small thing.

@Dinistro
Contributor Author

Thanks for the reviews, will address the comments tomorrow.
Regarding the generality issue: Sadly, I do not have a concrete idea of how to simplify the interfaces. Note that one does not need to use the provided parameters, though.

@Dinistro
Contributor Author

We discussed the data layout parameter a bit more: We considered switching back to using the DataLayout::closest function, but this kills all the caching benefits and forces every user to create a somewhat expensive object. Therefore, we will stick with the current design for now, and might consider changing it when we have a better idea, or someone else has a proposition.
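
The trade-off described above can be sketched with a toy model. `ToyDataLayout` below is a hypothetical stand-in, not `mlir::DataLayout`; it only assumes that constructing a layout object has a cost and that size queries are cached per instance. The two free functions are likewise illustrative names for the two designs.

```cpp
#include <cassert>
#include <map>
#include <string>

// Toy stand-in for a data layout object (an assumption for illustration):
// it counts constructions and caches size queries per instance.
struct ToyDataLayout {
  // Global construction counter, kept in a function-local static.
  static int &constructions() {
    static int count = 0;
    return count;
  }

  ToyDataLayout() { ++constructions(); }

  unsigned getTypeSize(const std::string &type) const {
    auto it = cache.find(type);
    if (it != cache.end())
      return it->second; // Cache hit: no recomputation.
    ++cacheMisses;
    unsigned size = (type == "i16") ? 2 : 4; // Arbitrary toy sizes.
    cache.emplace(type, size);
    return size;
  }

  mutable std::map<std::string, unsigned> cache;
  mutable int cacheMisses = 0;
};

// Design kept by the PR: the caller threads one shared layout through.
inline unsigned sizeWithSharedLayout(const ToyDataLayout &layout,
                                     const std::string &type) {
  return layout.getTypeSize(type);
}

// Considered alternative: every query builds its own layout object, so
// nothing is ever served from a cache.
inline unsigned sizeWithFreshLayout(const std::string &type) {
  ToyDataLayout layout;
  return layout.getTypeSize(type);
}
```

In this model, repeated queries through the shared object construct nothing new and miss the cache once per type, while the per-call variant pays for one construction and one cache miss on every query.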

@Dinistro Dinistro merged commit ac39fa7 into main Apr 18, 2024
@Dinistro Dinistro deleted the users/dinistro/mem2reg-allow-smaller-accesses branch April 18, 2024 11:09
4 participants