Skip to content

Commit 2a277b4

Browse files
committed
[MLIR][LLVM][Mem2Reg] Extends support for partial stores
This commit enhances the LLVM dialect's Mem2Reg interfaces to support partial stores to memory slots. To achieve this, the `getStored` interface method is extended with a parameter for the reaching definition, which is now required to compute the slot's value after a store that overwrites only part of it.
1 parent eaa2eac commit 2a277b4

File tree

5 files changed

+223
-51
lines changed

5 files changed

+223
-51
lines changed

mlir/include/mlir/Interfaces/MemorySlotInterfaces.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
128128
"::mlir::Value", "getStored",
129129
(ins "const ::mlir::MemorySlot &":$slot,
130130
"::mlir::RewriterBase &":$rewriter,
131+
"::mlir::Value":$reachingDef,
131132
"const ::mlir::DataLayout &":$dataLayout)
132133
>,
133134
InterfaceMethod<[{

mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp

Lines changed: 97 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ bool LLVM::LoadOp::loadsFrom(const MemorySlot &slot) {
113113
bool LLVM::LoadOp::storesTo(const MemorySlot &slot) { return false; }
114114

115115
Value LLVM::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
116-
const DataLayout &dataLayout) {
116+
Value reachingDef, const DataLayout &dataLayout) {
117117
llvm_unreachable("getStored should not be called on LoadOp");
118118
}
119119

@@ -144,7 +144,7 @@ static bool isSupportedTypeForConversion(Type type) {
144144
/// Checks that `srcType` can be converted to `targetType` by a sequence of casts and
145145
/// truncations.
146146
static bool areConversionCompatible(const DataLayout &layout, Type targetType,
147-
Type srcType) {
147+
Type srcType, bool allowWidening = false) {
148148
if (targetType == srcType)
149149
return true;
150150

@@ -158,7 +158,8 @@ static bool areConversionCompatible(const DataLayout &layout, Type targetType,
158158
isa<LLVM::LLVMPointerType>(srcType))
159159
return layout.getTypeSize(targetType) == layout.getTypeSize(srcType);
160160

161-
return layout.getTypeSize(targetType) <= layout.getTypeSize(srcType);
161+
return allowWidening ||
162+
layout.getTypeSize(targetType) <= layout.getTypeSize(srcType);
162163
}
163164

164165
/// Checks if `dataLayout` describes a big endian layout.
@@ -170,6 +171,35 @@ static bool isBigEndian(const DataLayout &dataLayout) {
170171
/// The size of a byte in bits.
171172
constexpr const static uint64_t kBitsInByte = 8;
172173

174+
/// Converts a value to an integer type of the same size.
175+
/// Assumes that the type can be converted.
176+
static Value convertToIntValue(RewriterBase &rewriter, Location loc, Value val,
177+
const DataLayout &dataLayout) {
178+
Type type = val.getType();
179+
assert(isSupportedTypeForConversion(type));
180+
181+
if (isa<IntegerType>(type))
182+
return val;
183+
184+
uint64_t typeBitSize = dataLayout.getTypeSizeInBits(type);
185+
IntegerType valueSizeInteger = rewriter.getIntegerType(typeBitSize);
186+
187+
if (isa<LLVM::LLVMPointerType>(type))
188+
return rewriter.createOrFold<LLVM::PtrToIntOp>(loc, valueSizeInteger, val);
189+
return rewriter.createOrFold<LLVM::BitcastOp>(loc, valueSizeInteger, val);
190+
}
191+
192+
/// Converts a value with an integer type to `targetType`.
193+
static Value convertIntValueToType(RewriterBase &rewriter, Location loc,
194+
Value val, Type targetType) {
195+
assert(isa<IntegerType>(val.getType()));
196+
if (val.getType() == targetType)
197+
return val;
198+
if (isa<LLVM::LLVMPointerType>(targetType))
199+
return rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, val);
200+
return rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, val);
201+
}
202+
173203
/// Constructs operations that convert `inputValue` into a new value of type
174204
/// `targetType`. Assumes that this conversion is possible.
175205
static Value createConversionSequence(RewriterBase &rewriter, Location loc,
@@ -196,17 +226,8 @@ static Value createConversionSequence(RewriterBase &rewriter, Location loc,
196226
return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
197227
srcValue);
198228

199-
IntegerType valueSizeInteger =
200-
rewriter.getIntegerType(srcTypeSize * kBitsInByte);
201-
Value replacement = srcValue;
202-
203229
// First, cast the value to a same-sized integer type.
204-
if (isa<LLVM::LLVMPointerType>(srcType))
205-
replacement = rewriter.createOrFold<LLVM::PtrToIntOp>(loc, valueSizeInteger,
206-
replacement);
207-
else if (replacement.getType() != valueSizeInteger)
208-
replacement = rewriter.createOrFold<LLVM::BitcastOp>(loc, valueSizeInteger,
209-
replacement);
230+
Value replacement = convertToIntValue(rewriter, loc, srcValue, dataLayout);
210231

211232
// Truncate the integer if the size of the target is less than the value.
212233
if (targetTypeSize != srcTypeSize) {
@@ -224,20 +245,67 @@ static Value createConversionSequence(RewriterBase &rewriter, Location loc,
224245
}
225246

226247
// Now cast the integer to the actual target type if required.
227-
if (isa<LLVM::LLVMPointerType>(targetType))
228-
replacement =
229-
rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, replacement);
230-
else if (replacement.getType() != targetType)
231-
replacement =
232-
rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, replacement);
233-
234-
return replacement;
248+
return convertIntValueToType(rewriter, loc, replacement, targetType);
235249
}
236250

237251
Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
252+
Value reachingDef,
238253
const DataLayout &dataLayout) {
239-
return createConversionSequence(rewriter, getLoc(), getValue(), slot.elemType,
240-
dataLayout);
254+
uint64_t valueTypeSize = dataLayout.getTypeSizeInBits(getValue().getType());
255+
uint64_t slotTypeSize = dataLayout.getTypeSizeInBits(slot.elemType);
256+
if (slotTypeSize <= valueTypeSize)
257+
return createConversionSequence(rewriter, getLoc(), getValue(),
258+
slot.elemType, dataLayout);
259+
260+
assert(reachingDef && reachingDef.getType() == slot.elemType &&
261+
"expected the reaching definition's type to slot's type");
262+
263+
// In the case where the store only overwrites parts of the memory,
264+
// bit fiddling is required to construct the new value.
265+
266+
// First convert both values to integers of the same size.
267+
Value defAsInt =
268+
convertToIntValue(rewriter, getLoc(), reachingDef, dataLayout);
269+
Value valueAsInt =
270+
convertToIntValue(rewriter, getLoc(), getValue(), dataLayout);
271+
// Extend the value to the size of the reaching definition.
272+
valueAsInt = rewriter.createOrFold<LLVM::ZExtOp>(getLoc(), defAsInt.getType(),
273+
valueAsInt);
274+
uint64_t sizeDifference = slotTypeSize - valueTypeSize;
275+
if (isBigEndian(dataLayout)) {
276+
// On big endian systems, a store to the base pointer overwrites the most
277+
// significant bits. To accommodate this, the stored value needs to be
278+
// shifted into the corresponding position.
279+
Value bigEndianShift = rewriter.create<LLVM::ConstantOp>(
280+
getLoc(), rewriter.getIntegerAttr(defAsInt.getType(), sizeDifference));
281+
valueAsInt = rewriter.createOrFold<LLVM::ShlOp>(getLoc(), valueAsInt,
282+
bigEndianShift);
283+
}
284+
285+
// Construct the mask that is used to erase the bits that are overwritten by
286+
// the store.
287+
APInt maskValue;
288+
if (isBigEndian(dataLayout)) {
289+
// Build a mask that has the most significant bits set to zero.
290+
// Note: This is the same as 2^sizeDifference - 1
291+
maskValue = APInt::getAllOnes(sizeDifference).zext(slotTypeSize);
292+
} else {
293+
// Build a mask that has the least significant bits set to zero.
294+
// Note: This is the same as -(2^valueTypeSize)
295+
maskValue = APInt::getAllOnes(valueTypeSize).zext(slotTypeSize);
296+
maskValue.flipAllBits();
297+
}
298+
299+
// Mask out the affected bits ...
300+
Value mask = rewriter.create<LLVM::ConstantOp>(
301+
getLoc(), rewriter.getIntegerAttr(defAsInt.getType(), maskValue));
302+
Value masked = rewriter.createOrFold<LLVM::AndOp>(getLoc(), defAsInt, mask);
303+
304+
// ... and combine the result with the new value.
305+
Value combined =
306+
rewriter.createOrFold<LLVM::OrOp>(getLoc(), masked, valueAsInt);
307+
308+
return convertIntValueToType(rewriter, getLoc(), combined, slot.elemType);
241309
}
242310

243311
bool LLVM::LoadOp::canUsesBeRemoved(
@@ -283,7 +351,8 @@ bool LLVM::StoreOp::canUsesBeRemoved(
283351
return blockingUse == slot.ptr && getAddr() == slot.ptr &&
284352
getValue() != slot.ptr &&
285353
areConversionCompatible(dataLayout, slot.elemType,
286-
getValue().getType()) &&
354+
getValue().getType(),
355+
/*allowWidening=*/true) &&
287356
!getVolatile_();
288357
}
289358

@@ -838,6 +907,7 @@ bool LLVM::MemsetOp::storesTo(const MemorySlot &slot) {
838907
}
839908

840909
Value LLVM::MemsetOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
910+
Value reachingDef,
841911
const DataLayout &dataLayout) {
842912
// TODO: Support non-integer types.
843913
return TypeSwitch<Type, Value>(slot.elemType)
@@ -1149,6 +1219,7 @@ bool LLVM::MemcpyOp::storesTo(const MemorySlot &slot) {
11491219
}
11501220

11511221
Value LLVM::MemcpyOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
1222+
Value reachingDef,
11521223
const DataLayout &dataLayout) {
11531224
return memcpyGetStored(*this, slot, rewriter);
11541225
}
@@ -1199,7 +1270,7 @@ bool LLVM::MemcpyInlineOp::storesTo(const MemorySlot &slot) {
11991270
}
12001271

12011272
Value LLVM::MemcpyInlineOp::getStored(const MemorySlot &slot,
1202-
RewriterBase &rewriter,
1273+
RewriterBase &rewriter, Value reachingDef,
12031274
const DataLayout &dataLayout) {
12041275
return memcpyGetStored(*this, slot, rewriter);
12051276
}
@@ -1252,6 +1323,7 @@ bool LLVM::MemmoveOp::storesTo(const MemorySlot &slot) {
12521323
}
12531324

12541325
Value LLVM::MemmoveOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
1326+
Value reachingDef,
12551327
const DataLayout &dataLayout) {
12561328
return memcpyGetStored(*this, slot, rewriter);
12571329
}

mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ bool memref::LoadOp::loadsFrom(const MemorySlot &slot) {
161161
bool memref::LoadOp::storesTo(const MemorySlot &slot) { return false; }
162162

163163
Value memref::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
164+
Value reachingDef,
164165
const DataLayout &dataLayout) {
165166
llvm_unreachable("getStored should not be called on LoadOp");
166167
}
@@ -242,6 +243,7 @@ bool memref::StoreOp::storesTo(const MemorySlot &slot) {
242243
}
243244

244245
Value memref::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
246+
Value reachingDef,
245247
const DataLayout &dataLayout) {
246248
return getValue();
247249
}

mlir/lib/Transforms/Mem2Reg.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ Value MemorySlotPromoter::computeReachingDefInBlock(Block *block,
438438

439439
if (memOp.storesTo(slot)) {
440440
rewriter.setInsertionPointAfter(memOp);
441-
Value stored = memOp.getStored(slot, rewriter, dataLayout);
441+
Value stored = memOp.getStored(slot, rewriter, reachingDef, dataLayout);
442442
assert(stored && "a memory operation storing to a slot must provide a "
443443
"new definition of the slot");
444444
reachingDef = stored;
@@ -452,6 +452,7 @@ Value MemorySlotPromoter::computeReachingDefInBlock(Block *block,
452452

453453
void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
454454
Value reachingDef) {
455+
assert(reachingDef && "expected an initial reaching def to be provided");
455456
if (region->hasOneBlock()) {
456457
computeReachingDefInBlock(&region->front(), reachingDef);
457458
return;
@@ -508,12 +509,11 @@ void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
508509
}
509510

510511
job.reachingDef = computeReachingDefInBlock(block, job.reachingDef);
512+
assert(job.reachingDef);
511513

512514
if (auto terminator = dyn_cast<BranchOpInterface>(block->getTerminator())) {
513515
for (BlockOperand &blockOperand : terminator->getBlockOperands()) {
514516
if (info.mergePoints.contains(blockOperand.get())) {
515-
if (!job.reachingDef)
516-
job.reachingDef = getLazyDefaultValue();
517517
rewriter.modifyOpInPlace(terminator, [&]() {
518518
terminator.getSuccessorOperands(blockOperand.getOperandNumber())
519519
.append(job.reachingDef);
@@ -601,7 +601,7 @@ void MemorySlotPromoter::removeBlockingUses() {
601601
}
602602

603603
void MemorySlotPromoter::promoteSlot() {
604-
computeReachingDefInRegion(slot.ptr.getParentRegion(), {});
604+
computeReachingDefInRegion(slot.ptr.getParentRegion(), getLazyDefaultValue());
605605

606606
// Now that reaching definitions are known, remove all users.
607607
removeBlockingUses();

0 commit comments

Comments
 (0)