[MLIR][Mem2Reg][LLVM] Enhance partial load support #89094
This commit improves the LLVM dialect's Mem2Reg interfaces to support promotion of partial loads from larger memory slots. To support this, the Mem2Reg interface methods are extended with an additional data layout parameter. The data layout is required to determine type sizes, which are needed to produce correct conversion sequences.
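For readers skimming the patch, here is a minimal sketch of the value a promoted partial load is expected to produce, modeled on plain integers rather than MLIR values. `partialLoad` and its parameters are hypothetical names introduced only for illustration; the patch itself emits `LLVM::LShrOp` and `LLVM::TruncOp` from `createConversionSequence`. The shift on big-endian layouts mirrors the `isLittleEndian` branch in the diff.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper, not part of the patch: models reading `targetBytes`
// bytes from the start of a slot that holds the `srcBytes`-byte integer
// `value`.
uint64_t partialLoad(uint64_t value, unsigned srcBytes, unsigned targetBytes,
                     bool littleEndian) {
  assert(targetBytes >= 1 && targetBytes <= srcBytes && srcBytes <= 8);
  // On a big-endian layout the first bytes in memory are the most
  // significant ones, so shift them down before truncating.
  if (!littleEndian)
    value >>= (srcBytes - targetBytes) * 8;
  // Truncate to the width of the load.
  if (targetBytes == 8)
    return value;
  return value & ((uint64_t{1} << (targetBytes * 8)) - 1);
}
```

Under these assumptions, an i32 slot holding 0x11223344 that is read back as i16 yields 0x3344 on a little-endian layout and 0x1122 on a big-endian one.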
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir
Author: Christian Ulmann (Dinistro)
Changes
This commit improves LLVM dialect's Mem2Reg interfaces to support promotions of partial loads from larger memory slots. To support this, the Mem2Reg interface methods are extended with additional data layout parameters. The data layout is required to determine type sizes to produce correct conversion sequences.
Note: There will be additional followups that introduce a similar functionality for stores, and there are plans to support accesses into the middle of memory slots.
Patch is 24.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/89094.diff
5 Files Affected:
diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
index 9db89361c78002..8c642c0ed26aca 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
@@ -127,7 +127,8 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
}],
"::mlir::Value", "getStored",
(ins "const ::mlir::MemorySlot &":$slot,
- "::mlir::RewriterBase &":$rewriter)
+ "::mlir::RewriterBase &":$rewriter,
+ "const ::mlir::DataLayout &":$dataLayout)
>,
InterfaceMethod<[{
Checks that this operation can be promoted to no longer use the provided
@@ -172,7 +173,8 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
(ins "const ::mlir::MemorySlot &":$slot,
"const ::llvm::SmallPtrSetImpl<mlir::OpOperand *> &":$blockingUses,
"::mlir::RewriterBase &":$rewriter,
- "::mlir::Value":$reachingDefinition)
+ "::mlir::Value":$reachingDefinition,
+ "const ::mlir::DataLayout &":$dataLayout)
>,
];
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index c7ca0b4a5843ad..0c4d019f5654ac 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -112,7 +112,8 @@ bool LLVM::LoadOp::loadsFrom(const MemorySlot &slot) {
bool LLVM::LoadOp::storesTo(const MemorySlot &slot) { return false; }
-Value LLVM::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter) {
+Value LLVM::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+ const DataLayout &dataLayout) {
llvm_unreachable("getStored should not be called on LoadOp");
}
@@ -122,37 +123,124 @@ bool LLVM::StoreOp::storesTo(const MemorySlot &slot) {
return getAddr() == slot.ptr;
}
-/// Checks that two types are the same or can be cast into one another.
-static bool areCastCompatible(const DataLayout &layout, Type lhs, Type rhs) {
- return lhs == rhs || (!isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(lhs) &&
- !isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(rhs) &&
- layout.getTypeSize(lhs) == layout.getTypeSize(rhs));
+/// Checks if `type` can be used in any kind of conversion sequences.
+static bool isSupportedTypeForConversion(Type type) {
+ // Aggregate types are not bitcastable.
+ if (isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(type))
+ return false;
+
+ // LLVM vector types are only used for either pointers or target specific
+ // types. These types cannot be cast in the general case, thus the memory
+ // optimizations do not support them.
+ if (isa<LLVM::LLVMFixedVectorType, LLVM::LLVMScalableVectorType>(type))
+ return false;
+
+ // Scalable types are not supported.
+ if (auto vectorType = dyn_cast<VectorType>(type))
+ return !vectorType.isScalable();
+ return true;
+}
+
+/// Checks that `rhs` can be converted to `lhs` by a sequence of casts and
+/// truncations.
+static bool areConversionCompatible(const DataLayout &layout, Type lhs,
+ Type rhs) {
+ if (lhs == rhs)
+ return true;
+
+ // Aggregate types and other unsupported types cannot be cast.
+ if (!isSupportedTypeForConversion(lhs) || !isSupportedTypeForConversion(rhs))
+ return false;
+ return layout.getTypeSize(lhs) <= layout.getTypeSize(rhs);
}
+/// Checks if `dataLayout` describes a little endian layout.
+static bool isLittleEndian(const DataLayout &dataLayout) {
+ auto endiannessStr = dyn_cast_or_null<StringAttr>(dataLayout.getEndianness());
+ return !endiannessStr || endiannessStr == "little";
+}
+
+/// The size of a byte in bits.
+constexpr const static uint64_t kBitsInByte = 8;
+
/// Constructs operations that convert `inputValue` into a new value of type
/// `targetType`. Assumes that this conversion is possible.
static Value createConversionSequence(RewriterBase &rewriter, Location loc,
- Value inputValue, Type targetType) {
- if (inputValue.getType() == targetType)
- return inputValue;
+ Value srcValue, Type targetType,
+ const DataLayout &dataLayout) {
+ // Get the types of the source and destination values.
+ Type srcType = srcValue.getType();
+
+ uint64_t srcTypeSize = dataLayout.getTypeSize(srcType);
+ uint64_t targetTypeSize = dataLayout.getTypeSize(targetType);
+
+ // Nothing has to be done if the types are already the same.
+ if (srcType == targetType)
+ return srcValue;
+
+ // The code below is currently not capable of handling aggregate types as it
+ // makes use of bitcasts. Aggregates cannot be bitcast.
+ // TODO: We should have a `LLVMAggregateType` base class to easily perform
+ // this `isa`.
+ if (isa<LLVM::LLVMArrayType, LLVM::LLVMStructType>(srcType) ||
+ isa<LLVM::LLVMArrayType, LLVM::LLVMStructType>(targetType))
+ return nullptr;
+
+ // In the special case of casting one pointer to another, we want to generate
+ // an address space cast. Bitcasts of pointers are not allowed and using
+ // pointer to integer conversions is not equivalent due to the loss of
+ // provenance.
+ if (isa<LLVM::LLVMPointerType>(targetType) &&
+ isa<LLVM::LLVMPointerType>(srcType)) {
+ // Abort the conversion if the pointers have different bitwidths.
+ if (srcTypeSize != targetTypeSize)
+ return nullptr;
+ return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
+ srcValue);
+ }
- if (!isa<LLVM::LLVMPointerType>(targetType) &&
- !isa<LLVM::LLVMPointerType>(inputValue.getType()))
- return rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, inputValue);
+ IntegerType valueSizeInteger =
+ rewriter.getIntegerType(srcTypeSize * kBitsInByte);
+ Value replacement = srcValue;
+
+ // First, cast the value to a same-sized integer type.
+ if (isa<LLVM::LLVMPointerType>(srcType))
+ replacement = rewriter.createOrFold<LLVM::PtrToIntOp>(loc, valueSizeInteger,
+ replacement);
+ else if (replacement.getType() != valueSizeInteger)
+ replacement = rewriter.createOrFold<LLVM::BitcastOp>(loc, valueSizeInteger,
+ replacement);
+
+ // Truncate the integer if the size of the read is less than the value.
+ if (targetTypeSize != srcTypeSize) {
+ if (!isLittleEndian(dataLayout)) {
+ uint64_t shiftAmount = (srcTypeSize - targetTypeSize) * kBitsInByte;
+ auto shiftConstant = rewriter.create<LLVM::ConstantOp>(
+ loc, rewriter.getIntegerAttr(srcType, shiftAmount));
+ replacement =
+ rewriter.createOrFold<LLVM::LShrOp>(loc, srcValue, shiftConstant);
+ }
- if (!isa<LLVM::LLVMPointerType>(targetType))
- return rewriter.createOrFold<LLVM::PtrToIntOp>(loc, targetType, inputValue);
+ replacement = rewriter.create<LLVM::TruncOp>(
+ loc, rewriter.getIntegerType(targetTypeSize * kBitsInByte),
+ replacement);
+ }
- if (!isa<LLVM::LLVMPointerType>(inputValue.getType()))
- return rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, inputValue);
+ // Now cast the integer to the actual destination type if required.
+ if (isa<LLVM::LLVMPointerType>(targetType))
+ replacement =
+ rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, replacement);
+ else if (replacement.getType() != targetType)
+ replacement =
+ rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, replacement);
- return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
- inputValue);
+ return replacement;
}
-Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter) {
- return createConversionSequence(rewriter, getLoc(), getValue(),
- slot.elemType);
+Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+ const DataLayout &dataLayout) {
+ return createConversionSequence(rewriter, getLoc(), getValue(), slot.elemType,
+ dataLayout);
}
bool LLVM::LoadOp::canUsesBeRemoved(
@@ -167,17 +255,20 @@ bool LLVM::LoadOp::canUsesBeRemoved(
// be removed (provided it loads the exact stored value and is not
// volatile).
return blockingUse == slot.ptr && getAddr() == slot.ptr &&
- areCastCompatible(dataLayout, getResult().getType(), slot.elemType) &&
+ areConversionCompatible(dataLayout, getResult().getType(),
+ slot.elemType) &&
!getVolatile_();
}
DeletionKind LLVM::LoadOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
- RewriterBase &rewriter, Value reachingDefinition) {
+ RewriterBase &rewriter, Value reachingDefinition,
+ const DataLayout &dataLayout) {
// `canUsesBeRemoved` checked this blocking use must be the loaded slot
// pointer.
- Value newResult = createConversionSequence(
- rewriter, getLoc(), reachingDefinition, getResult().getType());
+ Value newResult =
+ createConversionSequence(rewriter, getLoc(), reachingDefinition,
+ getResult().getType(), dataLayout);
rewriter.replaceAllUsesWith(getResult(), newResult);
return DeletionKind::Delete;
}
@@ -194,13 +285,15 @@ bool LLVM::StoreOp::canUsesBeRemoved(
// store OF the slot pointer, only INTO the slot pointer.
return blockingUse == slot.ptr && getAddr() == slot.ptr &&
getValue() != slot.ptr &&
- areCastCompatible(dataLayout, slot.elemType, getValue().getType()) &&
+ areConversionCompatible(dataLayout, slot.elemType,
+ getValue().getType()) &&
!getVolatile_();
}
DeletionKind LLVM::StoreOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
- RewriterBase &rewriter, Value reachingDefinition) {
+ RewriterBase &rewriter, Value reachingDefinition,
+ const DataLayout &dataLayout) {
return DeletionKind::Delete;
}
@@ -747,8 +840,8 @@ bool LLVM::MemsetOp::storesTo(const MemorySlot &slot) {
return getDst() == slot.ptr;
}
-Value LLVM::MemsetOp::getStored(const MemorySlot &slot,
- RewriterBase &rewriter) {
+Value LLVM::MemsetOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+ const DataLayout &dataLayout) {
// TODO: Support non-integer types.
return TypeSwitch<Type, Value>(slot.elemType)
.Case([&](IntegerType intType) -> Value {
@@ -802,7 +895,8 @@ bool LLVM::MemsetOp::canUsesBeRemoved(
DeletionKind LLVM::MemsetOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
- RewriterBase &rewriter, Value reachingDefinition) {
+ RewriterBase &rewriter, Value reachingDefinition,
+ const DataLayout &dataLayout) {
return DeletionKind::Delete;
}
@@ -1059,8 +1153,8 @@ bool LLVM::MemcpyOp::storesTo(const MemorySlot &slot) {
return memcpyStoresTo(*this, slot);
}
-Value LLVM::MemcpyOp::getStored(const MemorySlot &slot,
- RewriterBase &rewriter) {
+Value LLVM::MemcpyOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+ const DataLayout &dataLayout) {
return memcpyGetStored(*this, slot, rewriter);
}
@@ -1074,7 +1168,8 @@ bool LLVM::MemcpyOp::canUsesBeRemoved(
DeletionKind LLVM::MemcpyOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
- RewriterBase &rewriter, Value reachingDefinition) {
+ RewriterBase &rewriter, Value reachingDefinition,
+ const DataLayout &dataLayout) {
return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
reachingDefinition);
}
@@ -1109,7 +1204,8 @@ bool LLVM::MemcpyInlineOp::storesTo(const MemorySlot &slot) {
}
Value LLVM::MemcpyInlineOp::getStored(const MemorySlot &slot,
- RewriterBase &rewriter) {
+ RewriterBase &rewriter,
+ const DataLayout &dataLayout) {
return memcpyGetStored(*this, slot, rewriter);
}
@@ -1123,7 +1219,8 @@ bool LLVM::MemcpyInlineOp::canUsesBeRemoved(
DeletionKind LLVM::MemcpyInlineOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
- RewriterBase &rewriter, Value reachingDefinition) {
+ RewriterBase &rewriter, Value reachingDefinition,
+ const DataLayout &dataLayout) {
return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
reachingDefinition);
}
@@ -1159,8 +1256,8 @@ bool LLVM::MemmoveOp::storesTo(const MemorySlot &slot) {
return memcpyStoresTo(*this, slot);
}
-Value LLVM::MemmoveOp::getStored(const MemorySlot &slot,
- RewriterBase &rewriter) {
+Value LLVM::MemmoveOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+ const DataLayout &dataLayout) {
return memcpyGetStored(*this, slot, rewriter);
}
@@ -1174,7 +1271,8 @@ bool LLVM::MemmoveOp::canUsesBeRemoved(
DeletionKind LLVM::MemmoveOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
- RewriterBase &rewriter, Value reachingDefinition) {
+ RewriterBase &rewriter, Value reachingDefinition,
+ const DataLayout &dataLayout) {
return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
reachingDefinition);
}
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index 6c5250d527ade8..ebbf20f1b76b67 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -160,8 +160,8 @@ bool memref::LoadOp::loadsFrom(const MemorySlot &slot) {
bool memref::LoadOp::storesTo(const MemorySlot &slot) { return false; }
-Value memref::LoadOp::getStored(const MemorySlot &slot,
- RewriterBase &rewriter) {
+Value memref::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+ const DataLayout &dataLayout) {
llvm_unreachable("getStored should not be called on LoadOp");
}
@@ -178,7 +178,8 @@ bool memref::LoadOp::canUsesBeRemoved(
DeletionKind memref::LoadOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
- RewriterBase &rewriter, Value reachingDefinition) {
+ RewriterBase &rewriter, Value reachingDefinition,
+ const DataLayout &dataLayout) {
// `canUsesBeRemoved` checked this blocking use must be the loaded slot
// pointer.
rewriter.replaceAllUsesWith(getResult(), reachingDefinition);
@@ -240,8 +241,8 @@ bool memref::StoreOp::storesTo(const MemorySlot &slot) {
return getMemRef() == slot.ptr;
}
-Value memref::StoreOp::getStored(const MemorySlot &slot,
- RewriterBase &rewriter) {
+Value memref::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+ const DataLayout &dataLayout) {
return getValue();
}
@@ -258,7 +259,8 @@ bool memref::StoreOp::canUsesBeRemoved(
DeletionKind memref::StoreOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
- RewriterBase &rewriter, Value reachingDefinition) {
+ RewriterBase &rewriter, Value reachingDefinition,
+ const DataLayout &dataLayout) {
return DeletionKind::Delete;
}
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index abe565ea862f8f..1e620e46af84ea 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -165,7 +165,7 @@ class MemorySlotPromoter {
public:
MemorySlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator,
RewriterBase &rewriter, DominanceInfo &dominance,
- MemorySlotPromotionInfo info,
+ const DataLayout &dataLayout, MemorySlotPromotionInfo info,
const Mem2RegStatistics &statistics);
/// Actually promotes the slot by mutating IR. Promoting a slot DOES
@@ -204,6 +204,7 @@ class MemorySlotPromoter {
DenseMap<PromotableMemOpInterface, Value> reachingDefs;
DenseMap<PromotableMemOpInterface, Value> replacedValuesMap;
DominanceInfo &dominance;
+ const DataLayout &dataLayout;
MemorySlotPromotionInfo info;
const Mem2RegStatistics &statistics;
};
@@ -213,9 +214,11 @@ class MemorySlotPromoter {
MemorySlotPromoter::MemorySlotPromoter(
MemorySlot slot, PromotableAllocationOpInterface allocator,
RewriterBase &rewriter, DominanceInfo &dominance,
- MemorySlotPromotionInfo info, const Mem2RegStatistics &statistics)
+ const DataLayout &dataLayout, MemorySlotPromotionInfo info,
+ const Mem2RegStatistics &statistics)
: slot(slot), allocator(allocator), rewriter(rewriter),
- dominance(dominance), info(std::move(info)), statistics(statistics) {
+ dominance(dominance), dataLayout(dataLayout), info(std::move(info)),
+ statistics(statistics) {
#ifndef NDEBUG
auto isResultOrNewBlockArgument = [&]() {
if (BlockArgument arg = dyn_cast<BlockArgument>(slot.ptr))
@@ -435,7 +438,7 @@ Value MemorySlotPromoter::computeReachingDefInBlock(Block *block,
if (memOp.storesTo(slot)) {
rewriter.setInsertionPointAfter(memOp);
- Value stored = memOp.getStored(slot, rewriter);
+ Value stored = memOp.getStored(slot, rewriter, dataLayout);
assert(stored && "a memory operation storing to a slot must provide a "
"new definition of the slot");
reachingDef = stored;
@@ -568,8 +571,8 @@ void MemorySlotPromoter::removeBlockingUses() {
rewriter.setInsertionPointAfter(toPromote);
if (toPromoteMemOp.removeBlockingUses(
- slot, info.userToBlockingUses[toPromote], rewriter,
- reachingDef) == DeletionKind::Delete)
+ slot, info.userToBlockingUses[toPromote], rewriter, reachingDef,
+ dataLayout) == DeletionKind::Delete)
toErase.push_back(toPromote);
if (toPromoteMemOp.storesTo(slot))
if (Value replacedValue = replacedValuesMap[toPromoteMemOp])
@@ -642,7 +645,7 @@ LogicalResult mlir::tryToPromoteMemorySlots(
MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
if (info) {
- MemorySlotPromoter(slot, allocator, rewriter, dominance,
+ MemorySlotPromoter(slot, allocator, rewriter, dominance, dataLayout,
std::move(*info), statistics)
.promoteSlot();
promotedAny = true;
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg.mlir b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
index fa5d842302d0f4..e724c2e8679501 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
@@ -448,19 +448,6 @@ llvm.func @store_load_forward() -> i32 {
// -----
-// CHECK-LABEL: llvm.func @store_load_wrong_type
-llvm.func @store_load_wrong_type() -> i16 {
- %0 = llvm.mlir.constant(1 : i32) : i32
- %1 = llvm.mlir.constant(0 : i32) : i32
- // CHECK: = llvm.alloca
- %2 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
- llvm.store %1, %2 {alignment = 4 : i64} : i32, !llvm.ptr
- %3 = llvm.load %2 {alignment = 2 : i64} : !llvm.ptr -> i16
- llvm.return %3 : i16
-}
-
-// -----
-
// CHECK-LABEL: llvm.func @merge_point_cycle
llvm.func @merge_point_cycle() {
// CHECK: %[[UNDEF:.*]] = llvm.mlir.undef : i32
@@ -894,7 +881,7 @@ llvm.func @stores_with_different_type_sizes(%arg0: i64, %arg1: f32, %cond: i1) -
// CHECK-LABEL: @load_smaller_int
llvm.func @load_smaller_int() -> i16 {
%0 = llvm.mlir.constant(1 : i32) : i32
- // CHECK: llvm.alloca
+ // CHECK-NOT: llvm.alloca
%1 = ...
[truncated]
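The removed `store_load_wrong_type` test and the flipped `CHECK-NOT` in `load_smaller_int` both follow from the relaxed compatibility check. A rough model of that check, assuming a toy type representation: `ToyType`, `Kind`, and the function names below are hypothetical stand-ins, not MLIR API. `Kind::Castable` covers whatever `isSupportedTypeForConversion` accepts in the diff.

```cpp
#include <cstdint>

// Hypothetical stand-ins for the type categories the patch distinguishes.
enum class Kind { Castable, Aggregate, LLVMVector, ScalableVector };

struct ToyType {
  Kind kind;
  uint64_t sizeInBytes;
  bool operator==(const ToyType &other) const {
    return kind == other.kind && sizeInBytes == other.sizeInBytes;
  }
};

// Only non-aggregate, non-LLVM-vector, non-scalable types take part in
// conversion sequences.
bool isSupported(const ToyType &type) { return type.kind == Kind::Castable; }

// `rhs` (the available value) can be converted to `lhs` (the requested type)
// if the types match, or if both are supported and `lhs` is not larger.
bool conversionCompatible(const ToyType &lhs, const ToyType &rhs) {
  if (lhs == rhs)
    return true;
  if (!isSupported(lhs) || !isSupported(rhs))
    return false;
  return lhs.sizeInBytes <= rhs.sizeInBytes;
}
```

In this model an i16 load from an i32 slot is accepted, while an i32 load from an i16 slot is still rejected, which matches the `<=` comparison that replaced the old equality check.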
+Value memref::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+ const DataLayout &dataLayout) {
return getValue();
}
@@ -258,7 +259,8 @@ bool memref::StoreOp::canUsesBeRemoved(
DeletionKind memref::StoreOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
- RewriterBase &rewriter, Value reachingDefinition) {
+ RewriterBase &rewriter, Value reachingDefinition,
+ const DataLayout &dataLayout) {
return DeletionKind::Delete;
}
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index abe565ea862f8f..1e620e46af84ea 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -165,7 +165,7 @@ class MemorySlotPromoter {
public:
MemorySlotPromoter(MemorySlot slot, PromotableAllocationOpInterface allocator,
RewriterBase &rewriter, DominanceInfo &dominance,
- MemorySlotPromotionInfo info,
+ const DataLayout &dataLayout, MemorySlotPromotionInfo info,
const Mem2RegStatistics &statistics);
/// Actually promotes the slot by mutating IR. Promoting a slot DOES
@@ -204,6 +204,7 @@ class MemorySlotPromoter {
DenseMap<PromotableMemOpInterface, Value> reachingDefs;
DenseMap<PromotableMemOpInterface, Value> replacedValuesMap;
DominanceInfo &dominance;
+ const DataLayout &dataLayout;
MemorySlotPromotionInfo info;
const Mem2RegStatistics &statistics;
};
@@ -213,9 +214,11 @@ class MemorySlotPromoter {
MemorySlotPromoter::MemorySlotPromoter(
MemorySlot slot, PromotableAllocationOpInterface allocator,
RewriterBase &rewriter, DominanceInfo &dominance,
- MemorySlotPromotionInfo info, const Mem2RegStatistics &statistics)
+ const DataLayout &dataLayout, MemorySlotPromotionInfo info,
+ const Mem2RegStatistics &statistics)
: slot(slot), allocator(allocator), rewriter(rewriter),
- dominance(dominance), info(std::move(info)), statistics(statistics) {
+ dominance(dominance), dataLayout(dataLayout), info(std::move(info)),
+ statistics(statistics) {
#ifndef NDEBUG
auto isResultOrNewBlockArgument = [&]() {
if (BlockArgument arg = dyn_cast<BlockArgument>(slot.ptr))
@@ -435,7 +438,7 @@ Value MemorySlotPromoter::computeReachingDefInBlock(Block *block,
if (memOp.storesTo(slot)) {
rewriter.setInsertionPointAfter(memOp);
- Value stored = memOp.getStored(slot, rewriter);
+ Value stored = memOp.getStored(slot, rewriter, dataLayout);
assert(stored && "a memory operation storing to a slot must provide a "
"new definition of the slot");
reachingDef = stored;
@@ -568,8 +571,8 @@ void MemorySlotPromoter::removeBlockingUses() {
rewriter.setInsertionPointAfter(toPromote);
if (toPromoteMemOp.removeBlockingUses(
- slot, info.userToBlockingUses[toPromote], rewriter,
- reachingDef) == DeletionKind::Delete)
+ slot, info.userToBlockingUses[toPromote], rewriter, reachingDef,
+ dataLayout) == DeletionKind::Delete)
toErase.push_back(toPromote);
if (toPromoteMemOp.storesTo(slot))
if (Value replacedValue = replacedValuesMap[toPromoteMemOp])
@@ -642,7 +645,7 @@ LogicalResult mlir::tryToPromoteMemorySlots(
MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
if (info) {
- MemorySlotPromoter(slot, allocator, rewriter, dominance,
+ MemorySlotPromoter(slot, allocator, rewriter, dominance, dataLayout,
std::move(*info), statistics)
.promoteSlot();
promotedAny = true;
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg.mlir b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
index fa5d842302d0f4..e724c2e8679501 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
@@ -448,19 +448,6 @@ llvm.func @store_load_forward() -> i32 {
// -----
-// CHECK-LABEL: llvm.func @store_load_wrong_type
-llvm.func @store_load_wrong_type() -> i16 {
- %0 = llvm.mlir.constant(1 : i32) : i32
- %1 = llvm.mlir.constant(0 : i32) : i32
- // CHECK: = llvm.alloca
- %2 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
- llvm.store %1, %2 {alignment = 4 : i64} : i32, !llvm.ptr
- %3 = llvm.load %2 {alignment = 2 : i64} : !llvm.ptr -> i16
- llvm.return %3 : i16
-}
-
-// -----
-
// CHECK-LABEL: llvm.func @merge_point_cycle
llvm.func @merge_point_cycle() {
// CHECK: %[[UNDEF:.*]] = llvm.mlir.undef : i32
@@ -894,7 +881,7 @@ llvm.func @stores_with_different_type_sizes(%arg0: i64, %arg1: f32, %cond: i1) -
// CHECK-LABEL: @load_smaller_int
llvm.func @load_smaller_int() -> i16 {
%0 = llvm.mlir.constant(1 : i32) : i32
- // CHECK: llvm.alloca
+ // CHECK-NOT: llvm.alloca
%1 = ...
[truncated]
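
The `load_smaller_int` test now expects the alloca to disappear, so the `i16` load from an `i32` slot has to be rebuilt from the reaching definition instead of going through memory. The sketch below is a minimal model of that equivalence, not the PR's implementation: it assumes a little-endian layout and a load at offset 0, and the helper names `promotedPartialLoad` and `loadThroughMemory` are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper, not part of MLIR: models the value a promoted i16 load
// must yield when the slot's reaching definition is an i32. On a little-endian
// target the lowest-addressed bytes hold the least significant bits, so the
// partial load is equivalent to truncating the stored value.
inline uint16_t promotedPartialLoad(uint32_t reachingDefinition) {
  return static_cast<uint16_t>(reachingDefinition);
}

// Reference behaviour: model the unpromoted IR by writing the i32 into a
// 4-byte slot in explicit little-endian order, then reading the first 2 bytes.
inline uint16_t loadThroughMemory(uint32_t stored) {
  uint8_t slot[4];
  for (unsigned i = 0; i < 4; ++i)
    slot[i] = static_cast<uint8_t>(stored >> (8 * i)); // models llvm.store i32
  // Models llvm.load i16 from the start of the slot.
  return static_cast<uint16_t>(slot[0] | (slot[1] << 8));
}
```

Both functions agree for any input under these assumptions, which is why the conversion sequence needs the type sizes: it has to know how many bits of the reaching definition the load actually covers.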
Looks great. It's a bit of a shame that the most basic interfaces are becoming a bit complex due to arguably very LLVM-centric problems. Do you think there is a way to keep the simple API while still allowing this somehow?
LGTM modulo nits!
The data layout may be relevant for other dialects as well though?
Yeah, most likely. I just mean that using it seems like an edge case (I don't expect it to be super frequent to have loads and stores on the same pointers with different types, right? This seems to only be a concern for dialects at a similar level of abstraction as LLVM). It would be nice if the (I expect) more common case, where memory operations are well behaved, did not have to bother. But it's only a small thing.
Thanks for the reviews, will address the comments tomorrow.
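
To make the concern above concrete, here is one hypothetical shape such an API could take; it is not what this PR does, and every name in it (`LayoutModel`, `MemOpModel`, `getStoredSimple`) is invented for illustration. The layout-aware hook forwards to a simple hook by default, so only ops that care about sizes override it.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for a data layout: it only records the width in bits
// of the type being accessed.
struct LayoutModel {
  unsigned valueBits;
};

// Hypothetical interface, not the MLIR PromotableMemOpInterface.
struct MemOpModel {
  virtual ~MemOpModel() = default;
  // Simple hook for ops whose accesses always match the slot type.
  virtual uint64_t getStoredSimple() const = 0;
  // Layout-aware hook; by default it ignores the layout and forwards.
  virtual uint64_t getStored(const LayoutModel &) const {
    return getStoredSimple();
  }
};

// A well-behaved op only implements the simple hook.
struct WellBehavedStore : MemOpModel {
  uint64_t value;
  explicit WellBehavedStore(uint64_t v) : value(v) {}
  uint64_t getStoredSimple() const override { return value; }
};

// An op that cares about sizes overrides the layout-aware hook as well.
struct SizeSensitiveStore : MemOpModel {
  uint64_t value;
  explicit SizeSensitiveStore(uint64_t v) : value(v) {}
  uint64_t getStoredSimple() const override { return value; }
  uint64_t getStored(const LayoutModel &layout) const override {
    // Keep only the low `valueBits` bits of the stored value.
    if (layout.valueBits >= 64)
      return value;
    return value & ((uint64_t{1} << layout.valueBits) - 1);
  }
};
```

The trade-off in this sketch is a second method on the interface in exchange for leaving simple implementations untouched.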
We discussed the data layout parameter a bit more: We considered switching back to using the |