[MLIR][LLVM][Mem2Reg] Extends support for partial stores #89740
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-core

Author: Christian Ulmann (Dinistro)

Changes

This commit enhances the LLVM dialect's Mem2Reg interfaces to support partial stores to memory slots. To achieve this support, the `getStored` interface method has to be extended with a parameter carrying the reaching definition, which is now necessary to produce the resulting value after this store.

Full diff: https://github.com/llvm/llvm-project/pull/89740.diff

5 Files Affected:
- mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
- mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
- mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
- mlir/lib/Transforms/Mem2Reg.cpp
- mlir/test/Dialect/LLVMIR/mem2reg.mlir
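For orientation before the diff: a minimal sketch (mine, not taken verbatim from the patch) of the IR this rewrite produces when an i16 is stored into an i32 slot under a little-endian layout. The reaching definition %def has the bits the store overwrites masked off, and the zero-extended stored value is OR-ed in; this mirrors the @smaller_store_forwarding test below.

%zext   = llvm.zext %v : i16 to i32              // widen the stored i16 value
%mask   = llvm.mlir.constant(-65536 : i32) : i32 // 0xFFFF0000: keeps the untouched high bits
%masked = llvm.and %def, %mask : i32             // erase the low 16 bits being overwritten
%newdef = llvm.or %masked, %zext : i32           // merge in the new bits; this becomes the slot's new reaching definition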
diff --git a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
index 8c642c0ed26aca..764fa6d547b2eb 100644
--- a/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
+++ b/mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
@@ -128,6 +128,7 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
"::mlir::Value", "getStored",
(ins "const ::mlir::MemorySlot &":$slot,
"::mlir::RewriterBase &":$rewriter,
+ "::mlir::Value":$reachingDef,
"const ::mlir::DataLayout &":$dataLayout)
>,
InterfaceMethod<[{
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index f2ab3eae2c343e..230c7fe8001bc1 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -113,7 +113,7 @@ bool LLVM::LoadOp::loadsFrom(const MemorySlot &slot) {
bool LLVM::LoadOp::storesTo(const MemorySlot &slot) { return false; }
Value LLVM::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
- const DataLayout &dataLayout) {
+ Value reachingDef, const DataLayout &dataLayout) {
llvm_unreachable("getStored should not be called on LoadOp");
}
@@ -144,7 +144,7 @@ static bool isSupportedTypeForConversion(Type type) {
/// Checks that `rhs` can be converted to `lhs` by a sequence of casts and
/// truncations.
static bool areConversionCompatible(const DataLayout &layout, Type targetType,
- Type srcType) {
+ Type srcType, bool allowWidening = false) {
if (targetType == srcType)
return true;
@@ -158,7 +158,8 @@ static bool areConversionCompatible(const DataLayout &layout, Type targetType,
isa<LLVM::LLVMPointerType>(srcType))
return layout.getTypeSize(targetType) == layout.getTypeSize(srcType);
- return layout.getTypeSize(targetType) <= layout.getTypeSize(srcType);
+ return allowWidening ||
+ layout.getTypeSize(targetType) <= layout.getTypeSize(srcType);
}
/// Checks if `dataLayout` describes a little endian layout.
@@ -170,6 +171,35 @@ static bool isBigEndian(const DataLayout &dataLayout) {
/// The size of a byte in bits.
constexpr const static uint64_t kBitsInByte = 8;
+/// Converts a value to an integer type of the same size.
+/// Assumes that the type can be converted.
+static Value convertToIntValue(RewriterBase &rewriter, Location loc, Value val,
+ const DataLayout &dataLayout) {
+ Type type = val.getType();
+ assert(isSupportedTypeForConversion(type));
+
+ if (isa<IntegerType>(type))
+ return val;
+
+ uint64_t typeBitSize = dataLayout.getTypeSizeInBits(type);
+ IntegerType valueSizeInteger = rewriter.getIntegerType(typeBitSize);
+
+ if (isa<LLVM::LLVMPointerType>(type))
+ return rewriter.createOrFold<LLVM::PtrToIntOp>(loc, valueSizeInteger, val);
+ return rewriter.createOrFold<LLVM::BitcastOp>(loc, valueSizeInteger, val);
+}
+
+/// Converts a value with an integer type to `targetType`.
+static Value convertIntValueToType(RewriterBase &rewriter, Location loc,
+ Value val, Type targetType) {
+ assert(isa<IntegerType>(val.getType()));
+ if (val.getType() == targetType)
+ return val;
+ if (isa<LLVM::LLVMPointerType>(targetType))
+ return rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, val);
+ return rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, val);
+}
+
/// Constructs operations that convert `inputValue` into a new value of type
/// `targetType`. Assumes that this conversion is possible.
static Value createConversionSequence(RewriterBase &rewriter, Location loc,
@@ -196,17 +226,8 @@ static Value createConversionSequence(RewriterBase &rewriter, Location loc,
return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
srcValue);
- IntegerType valueSizeInteger =
- rewriter.getIntegerType(srcTypeSize * kBitsInByte);
- Value replacement = srcValue;
-
// First, cast the value to a same-sized integer type.
- if (isa<LLVM::LLVMPointerType>(srcType))
- replacement = rewriter.createOrFold<LLVM::PtrToIntOp>(loc, valueSizeInteger,
- replacement);
- else if (replacement.getType() != valueSizeInteger)
- replacement = rewriter.createOrFold<LLVM::BitcastOp>(loc, valueSizeInteger,
- replacement);
+ Value replacement = convertToIntValue(rewriter, loc, srcValue, dataLayout);
// Truncate the integer if the size of the target is less than the value.
if (targetTypeSize != srcTypeSize) {
@@ -224,20 +245,67 @@ static Value createConversionSequence(RewriterBase &rewriter, Location loc,
}
// Now cast the integer to the actual target type if required.
- if (isa<LLVM::LLVMPointerType>(targetType))
- replacement =
- rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, replacement);
- else if (replacement.getType() != targetType)
- replacement =
- rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, replacement);
-
- return replacement;
+ return convertIntValueToType(rewriter, loc, replacement, targetType);
}
Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+ Value reachingDef,
const DataLayout &dataLayout) {
- return createConversionSequence(rewriter, getLoc(), getValue(), slot.elemType,
- dataLayout);
+ uint64_t valueTypeSize = dataLayout.getTypeSizeInBits(getValue().getType());
+ uint64_t slotTypeSize = dataLayout.getTypeSizeInBits(slot.elemType);
+ if (slotTypeSize <= valueTypeSize)
+ return createConversionSequence(rewriter, getLoc(), getValue(),
+ slot.elemType, dataLayout);
+
+  assert(reachingDef && reachingDef.getType() == slot.elemType &&
+         "expected the reaching definition's type to match the slot's type");
+
+ // In the case where the store only overwrites parts of the memory,
+ // bit fiddling is required to construct the new value.
+
+ // First convert both values to integers of the same size.
+ Value defAsInt =
+ convertToIntValue(rewriter, getLoc(), reachingDef, dataLayout);
+ Value valueAsInt =
+ convertToIntValue(rewriter, getLoc(), getValue(), dataLayout);
+ // Extend the value to the size of the reaching definition.
+ valueAsInt = rewriter.createOrFold<LLVM::ZExtOp>(getLoc(), defAsInt.getType(),
+ valueAsInt);
+ uint64_t sizeDifference = slotTypeSize - valueTypeSize;
+ if (isBigEndian(dataLayout)) {
+ // On big endian systems, a store to the base pointer overwrites the most
+    // significant bits. To accommodate this, the stored value needs to be
+    // shifted into the corresponding position.
+ Value bigEndianShift = rewriter.create<LLVM::ConstantOp>(
+ getLoc(), rewriter.getIntegerAttr(defAsInt.getType(), sizeDifference));
+ valueAsInt = rewriter.createOrFold<LLVM::ShlOp>(getLoc(), valueAsInt,
+ bigEndianShift);
+ }
+
+ // Construct the mask that is used to erase the bits that are overwritten by
+ // the store.
+ APInt maskValue;
+ if (isBigEndian(dataLayout)) {
+ // Build a mask that has the most significant bits set to zero.
+ // Note: This is the same as 2^sizeDifference - 1
+ maskValue = APInt::getAllOnes(sizeDifference).zext(slotTypeSize);
+ } else {
+ // Build a mask that has the least significant bits set to zero.
+ // Note: This is the same as -(2^valueTypeSize)
+ maskValue = APInt::getAllOnes(valueTypeSize).zext(slotTypeSize);
+ maskValue.flipAllBits();
+ }
+
+ // Mask out the affected bits ...
+ Value mask = rewriter.create<LLVM::ConstantOp>(
+ getLoc(), rewriter.getIntegerAttr(defAsInt.getType(), maskValue));
+ Value masked = rewriter.createOrFold<LLVM::AndOp>(getLoc(), defAsInt, mask);
+
+ // ... and combine the result with the new value.
+ Value combined =
+ rewriter.createOrFold<LLVM::OrOp>(getLoc(), masked, valueAsInt);
+
+ return convertIntValueToType(rewriter, getLoc(), combined, slot.elemType);
}
bool LLVM::LoadOp::canUsesBeRemoved(
@@ -283,7 +351,8 @@ bool LLVM::StoreOp::canUsesBeRemoved(
return blockingUse == slot.ptr && getAddr() == slot.ptr &&
getValue() != slot.ptr &&
areConversionCompatible(dataLayout, slot.elemType,
- getValue().getType()) &&
+ getValue().getType(),
+ /*allowWidening=*/true) &&
!getVolatile_();
}
@@ -838,6 +907,7 @@ bool LLVM::MemsetOp::storesTo(const MemorySlot &slot) {
}
Value LLVM::MemsetOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+ Value reachingDef,
const DataLayout &dataLayout) {
// TODO: Support non-integer types.
return TypeSwitch<Type, Value>(slot.elemType)
@@ -1149,6 +1219,7 @@ bool LLVM::MemcpyOp::storesTo(const MemorySlot &slot) {
}
Value LLVM::MemcpyOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+ Value reachingDef,
const DataLayout &dataLayout) {
return memcpyGetStored(*this, slot, rewriter);
}
@@ -1199,7 +1270,7 @@ bool LLVM::MemcpyInlineOp::storesTo(const MemorySlot &slot) {
}
Value LLVM::MemcpyInlineOp::getStored(const MemorySlot &slot,
- RewriterBase &rewriter,
+ RewriterBase &rewriter, Value reachingDef,
const DataLayout &dataLayout) {
return memcpyGetStored(*this, slot, rewriter);
}
@@ -1252,6 +1323,7 @@ bool LLVM::MemmoveOp::storesTo(const MemorySlot &slot) {
}
Value LLVM::MemmoveOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+ Value reachingDef,
const DataLayout &dataLayout) {
return memcpyGetStored(*this, slot, rewriter);
}
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
index ebbf20f1b76b67..958c5f0c8dbc75 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
@@ -161,6 +161,7 @@ bool memref::LoadOp::loadsFrom(const MemorySlot &slot) {
bool memref::LoadOp::storesTo(const MemorySlot &slot) { return false; }
Value memref::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+ Value reachingDef,
const DataLayout &dataLayout) {
llvm_unreachable("getStored should not be called on LoadOp");
}
@@ -242,6 +243,7 @@ bool memref::StoreOp::storesTo(const MemorySlot &slot) {
}
Value memref::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
+ Value reachingDef,
const DataLayout &dataLayout) {
return getValue();
}
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index 0c1ce70f070852..d6881b600aea7b 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -438,7 +438,7 @@ Value MemorySlotPromoter::computeReachingDefInBlock(Block *block,
if (memOp.storesTo(slot)) {
rewriter.setInsertionPointAfter(memOp);
- Value stored = memOp.getStored(slot, rewriter, dataLayout);
+ Value stored = memOp.getStored(slot, rewriter, reachingDef, dataLayout);
assert(stored && "a memory operation storing to a slot must provide a "
"new definition of the slot");
reachingDef = stored;
@@ -452,6 +452,7 @@ Value MemorySlotPromoter::computeReachingDefInBlock(Block *block,
void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
Value reachingDef) {
+ assert(reachingDef && "expected an initial reaching def to be provided");
if (region->hasOneBlock()) {
    computeReachingDefInBlock(&region->front(), reachingDef);
return;
@@ -508,12 +509,11 @@ void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
}
job.reachingDef = computeReachingDefInBlock(block, job.reachingDef);
+ assert(job.reachingDef);
if (auto terminator = dyn_cast<BranchOpInterface>(block->getTerminator())) {
for (BlockOperand &blockOperand : terminator->getBlockOperands()) {
if (info.mergePoints.contains(blockOperand.get())) {
- if (!job.reachingDef)
- job.reachingDef = getLazyDefaultValue();
rewriter.modifyOpInPlace(terminator, [&]() {
terminator.getSuccessorOperands(blockOperand.getOperandNumber())
.append(job.reachingDef);
@@ -601,7 +601,7 @@ void MemorySlotPromoter::removeBlockingUses() {
}
void MemorySlotPromoter::promoteSlot() {
- computeReachingDefInRegion(slot.ptr.getParentRegion(), {});
+ computeReachingDefInRegion(slot.ptr.getParentRegion(), getLazyDefaultValue());
// Now that reaching definitions are known, remove all users.
removeBlockingUses();
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg.mlir b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
index 644d30f9f9f133..130a8fce2def14 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
@@ -856,28 +856,6 @@ llvm.func @stores_with_different_types(%arg0: i64, %arg1: f64, %cond: i1) -> f64
// -----
-// Verifies that stores with smaller bitsize inputs are not replaced. A trivial
-// implementation will be incorrect due to endianness considerations.
-
-// CHECK-LABEL: @stores_with_different_type_sizes
-llvm.func @stores_with_different_type_sizes(%arg0: i64, %arg1: f32, %cond: i1) -> f64 {
- %0 = llvm.mlir.constant(1 : i32) : i32
- // CHECK: llvm.alloca
- %1 = llvm.alloca %0 x i64 {alignment = 4 : i64} : (i32) -> !llvm.ptr
- llvm.cond_br %cond, ^bb1, ^bb2
-^bb1:
- llvm.store %arg0, %1 {alignment = 4 : i64} : i64, !llvm.ptr
- llvm.br ^bb3
-^bb2:
- llvm.store %arg1, %1 {alignment = 4 : i64} : f32, !llvm.ptr
- llvm.br ^bb3
-^bb3:
- %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f64
- llvm.return %2 : f64
-}
-
-// -----
-
// CHECK-LABEL: @load_smaller_int
llvm.func @load_smaller_int() -> i16 {
%0 = llvm.mlir.constant(1 : i32) : i32
@@ -1047,3 +1025,122 @@ llvm.func @scalable_llvm_vector() -> i16 {
%2 = llvm.load %1 : !llvm.ptr -> i16
llvm.return %2 : i16
}
+
+// -----
+
+// CHECK-LABEL: @smaller_store_forwarding
+// CHECK-SAME: %[[ARG:.+]]: i16
+llvm.func @smaller_store_forwarding(%arg : i16) {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-NOT: llvm.alloca
+ // CHECK: %[[UNDEF:.+]] = llvm.mlir.undef : i32
+ %1 = llvm.alloca %0 x i32 : (i32) -> !llvm.ptr
+
+ // CHECK: %[[ZEXT:.+]] = llvm.zext %[[ARG]] : i16 to i32
+ // CHECK: %[[MASK:.+]] = llvm.mlir.constant(-65536 : i32) : i32
+ // CHECK: %[[MASKED:.+]] = llvm.and %[[UNDEF]], %[[MASK]]
+ // CHECK: %[[NEW_DEF:.+]] = llvm.or %[[MASKED]], %[[ZEXT]]
+ llvm.store %arg, %1 : i16, !llvm.ptr
+ llvm.return
+}
+
+// -----
+
+module attributes { dlti.dl_spec = #dlti.dl_spec<
+ #dlti.dl_entry<"dlti.endianness", "big">
+>} {
+ // CHECK-LABEL: @smaller_store_forwarding_big_endian
+ // CHECK-SAME: %[[ARG:.+]]: i16
+ llvm.func @smaller_store_forwarding_big_endian(%arg : i16) {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-NOT: llvm.alloca
+ // CHECK: %[[UNDEF:.+]] = llvm.mlir.undef : i32
+ %1 = llvm.alloca %0 x i32 : (i32) -> !llvm.ptr
+
+ // CHECK: %[[ZEXT:.+]] = llvm.zext %[[ARG]] : i16 to i32
+ // CHECK: %[[SHIFT_WIDTH:.+]] = llvm.mlir.constant(16 : i32) : i32
+ // CHECK: %[[SHIFTED:.+]] = llvm.shl %[[ZEXT]], %[[SHIFT_WIDTH]]
+ // CHECK: %[[MASK:.+]] = llvm.mlir.constant(65535 : i32) : i32
+ // CHECK: %[[MASKED:.+]] = llvm.and %[[UNDEF]], %[[MASK]]
+ // CHECK: %[[NEW_DEF:.+]] = llvm.or %[[MASKED]], %[[SHIFTED]]
+ llvm.store %arg, %1 : i16, !llvm.ptr
+ llvm.return
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @smaller_store_forwarding_type_mix
+// CHECK-SAME: %[[ARG:.+]]: vector<1xi8>
+llvm.func @smaller_store_forwarding_type_mix(%arg : vector<1xi8>) {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-NOT: llvm.alloca
+ // CHECK: %[[UNDEF:.+]] = llvm.mlir.undef : f32
+ %1 = llvm.alloca %0 x f32 : (i32) -> !llvm.ptr
+
+ // CHECK: %[[CASTED_DEF:.+]] = llvm.bitcast %[[UNDEF]] : f32 to i32
+ // CHECK: %[[CASTED_ARG:.+]] = llvm.bitcast %[[ARG]] : vector<1xi8> to i8
+ // CHECK: %[[ZEXT:.+]] = llvm.zext %[[CASTED_ARG]] : i8 to i32
+ // CHECK: %[[MASK:.+]] = llvm.mlir.constant(-256 : i32) : i32
+ // CHECK: %[[MASKED:.+]] = llvm.and %[[CASTED_DEF]], %[[MASK]]
+ // CHECK: %[[NEW_DEF:.+]] = llvm.or %[[MASKED]], %[[ZEXT]]
+ // CHECK: %[[CASTED_NEW_DEF:.+]] = llvm.bitcast %[[NEW_DEF]] : i32 to f32
+ llvm.store %arg, %1 : vector<1xi8>, !llvm.ptr
+ llvm.return
+}
+
+// -----
+
+module attributes { dlti.dl_spec = #dlti.dl_spec<
+ #dlti.dl_entry<"dlti.endianness", "big">
+>} {
+ // CHECK-LABEL: @smaller_store_forwarding_type_mix
+ // CHECK-SAME: %[[ARG:.+]]: vector<1xi8>
+ llvm.func @smaller_store_forwarding_type_mix(%arg : vector<1xi8>) {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-NOT: llvm.alloca
+ // CHECK: %[[UNDEF:.+]] = llvm.mlir.undef : f32
+ %1 = llvm.alloca %0 x f32 : (i32) -> !llvm.ptr
+
+ // CHECK: %[[CASTED_DEF:.+]] = llvm.bitcast %[[UNDEF]] : f32 to i32
+ // CHECK: %[[CASTED_ARG:.+]] = llvm.bitcast %[[ARG]] : vector<1xi8> to i8
+ // CHECK: %[[ZEXT:.+]] = llvm.zext %[[CASTED_ARG]] : i8 to i32
+ // CHECK: %[[SHIFT_WIDTH:.+]] = llvm.mlir.constant(24 : i32) : i32
+ // CHECK: %[[SHIFTED:.+]] = llvm.shl %[[ZEXT]], %[[SHIFT_WIDTH]]
+ // CHECK: %[[MASK:.+]] = llvm.mlir.constant(16777215 : i32) : i32
+ // CHECK: %[[MASKED:.+]] = llvm.and %[[CASTED_DEF]], %[[MASK]]
+ // CHECK: %[[NEW_DEF:.+]] = llvm.or %[[MASKED]], %[[SHIFTED]]
+ // CHECK: %[[CASTED_NEW_DEF:.+]] = llvm.bitcast %[[NEW_DEF]] : i32 to f32
+ llvm.store %arg, %1 : vector<1xi8>, !llvm.ptr
+ llvm.return
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @stores_with_different_types_branches
+// CHECK-SAME: %[[ARG0:.+]]: i64
+// CHECK-SAME: %[[ARG1:.+]]: f32
+llvm.func @stores_with_different_types_branches(%arg0: i64, %arg1: f32, %cond: i1) -> f64 {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-NOT: llvm.alloca
+ // CHECK: %[[UNDEF:.+]] = llvm.mlir.undef : i64
+ %1 = llvm.alloca %0 x i64 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+ llvm.cond_br %cond, ^bb1, ^bb2
+^bb1:
+ llvm.store %arg0, %1 {alignment = 4 : i64} : i64, !llvm.ptr
+ // CHECK: llvm.br ^[[BB3:.+]](%[[ARG0]] : i64)
+ llvm.br ^bb3
+^bb2:
+ llvm.store %arg1, %1 {alignment = 4 : i64} : f32, !llvm.ptr
+ // CHECK: %[[CAST:.+]] = llvm.bitcast %[[ARG1]] : f32 to i32
+ // CHECK: %[[ZEXT:.+]] = llvm.zext %[[CAST]] : i32 to i64
+ // CHECK: %[[MASK:.+]] = llvm.mlir.constant(-4294967296 : i64) : i64
+ // CHECK: %[[MASKED:.+]] = llvm.and %[[UNDEF]], %[[MASK]]
+ // CHECK: %[[NEW_DEF:.+]] = llvm.or %[[MASKED]], %[[ZEXT]]
+ // CHECK: llvm.br ^[[BB3]](%[[NEW_DEF]] : i64)
+ llvm.br ^bb3
+^bb3:
+ %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f64
+ llvm.return %2 : f64
+}
Nice! I added some naming / nit comments and have some questions.
I think some of those changes start to touch where cost models matter. It would be nice to revisit once the work on cost models in MLIR gets there!
@@ -191,7 +191,7 @@ class MemorySlotPromoter {

   /// Lazily-constructed default value representing the content of the slot when
   /// no store has been executed. This function may mutate IR.
-  Value getLazyDefaultValue();
+  Value getOrCreateDefaultValue();
Don't forget to update the doc comment for the `defaultValue` field.
// CHECK-LABEL: @smaller_store_forwarding
// CHECK-SAME: %[[ARG:.+]]: i16
llvm.func @smaller_store_forwarding(%arg : i16) {
Do we know that little-endian is the default? On tests where that matters, I'd be in favor of also explicitly setting endianness to little, if not for correctness then for readability.
Little endian is the default. Currently, MLIR's data layout does not properly convey this information, but LLVM's does. I'll try to allocate time to address this matter.
nice! LGTM
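For reference, an explicitly little-endian layout would mirror the big-endian modules already used in the new tests; a minimal sketch (not part of the patch):

module attributes { dlti.dl_spec = #dlti.dl_spec<
  #dlti.dl_entry<"dlti.endianness", "little">
>} {
  // Functions placed here see an explicitly little-endian data layout.
}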
/// Constructs operations that convert `srcValue` into a new value of type
/// `targetType`. Performs bitlevel extraction if the source type is larger than
Suggested change:
-/// `targetType`. Performs bitlevel extraction if the source type is larger than
+/// `targetType`. Performs bit-level extraction if the source type is larger than

nit: I believe it is two words.
This commit enhances the LLVM dialect's Mem2Reg interfaces to support partial stores to memory slots. To achieve this support, the `getStored` interface method has to be extended with a parameter of the reaching definition, which is now necessary to produce the resulting value after this store.

Dinistro force-pushed from e52a703 to d97edec
Rebasing in the hope of unblocking the clang-format action issue.