Skip to content

Commit 6e9ea6e

Browse files
authored
[MLIR][LLVM][Mem2Reg] Extends support for partial stores (#89740)
This commit enhances the LLVM dialect's Mem2Reg interfaces to support partial stores to memory slots. To achieve this, the `getStored` interface method is extended with a reaching-definition parameter, which is now necessary to compute the value the slot holds after the store.
1 parent 9f2a068 commit 6e9ea6e

File tree

5 files changed

+305
-87
lines changed

5 files changed

+305
-87
lines changed

mlir/include/mlir/Interfaces/MemorySlotInterfaces.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
128128
"::mlir::Value", "getStored",
129129
(ins "const ::mlir::MemorySlot &":$slot,
130130
"::mlir::RewriterBase &":$rewriter,
131+
"::mlir::Value":$reachingDef,
131132
"const ::mlir::DataLayout &":$dataLayout)
132133
>,
133134
InterfaceMethod<[{

mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp

Lines changed: 159 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ bool LLVM::LoadOp::loadsFrom(const MemorySlot &slot) {
113113
bool LLVM::LoadOp::storesTo(const MemorySlot &slot) { return false; }
114114

115115
/// Interface stub: `storesTo` always returns false for a load, so this must
/// never be reached. None of the parameters are used.
Value LLVM::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
                              Value reachingDef,
                              const DataLayout &dataLayout) {
  llvm_unreachable("getStored should not be called on LoadOp");
}
119119

@@ -142,23 +142,29 @@ static bool isSupportedTypeForConversion(Type type) {
142142
}
143143

144144
/// Checks that `rhs` can be converted to `lhs` by a sequence of casts and
145-
/// truncations.
145+
/// truncations. Checks for narrowing or widening conversion compatibility
146+
/// depending on `narrowingConversion`.
146147
static bool areConversionCompatible(const DataLayout &layout, Type targetType,
147-
Type srcType) {
148+
Type srcType, bool narrowingConversion) {
148149
if (targetType == srcType)
149150
return true;
150151

151152
if (!isSupportedTypeForConversion(targetType) ||
152153
!isSupportedTypeForConversion(srcType))
153154
return false;
154155

156+
uint64_t targetSize = layout.getTypeSize(targetType);
157+
uint64_t srcSize = layout.getTypeSize(srcType);
158+
155159
// Pointer casts will only be sane when the bitsize of both pointer types is
156160
// the same.
157161
if (isa<LLVM::LLVMPointerType>(targetType) &&
158162
isa<LLVM::LLVMPointerType>(srcType))
159-
return layout.getTypeSize(targetType) == layout.getTypeSize(srcType);
163+
return targetSize == srcSize;
160164

161-
return layout.getTypeSize(targetType) <= layout.getTypeSize(srcType);
165+
if (narrowingConversion)
166+
return targetSize <= srcSize;
167+
return targetSize >= srcSize;
162168
}
163169

164170
/// Checks if `dataLayout` describes a big endian layout.
@@ -167,22 +173,49 @@ static bool isBigEndian(const DataLayout &dataLayout) {
167173
return endiannessStr && endiannessStr == "big";
168174
}
169175

170-
/// The size of a byte in bits.
171-
constexpr const static uint64_t kBitsInByte = 8;
176+
/// Converts a value to an integer type of the same size.
177+
/// Assumes that the type can be converted.
178+
static Value castToSameSizedInt(RewriterBase &rewriter, Location loc, Value val,
179+
const DataLayout &dataLayout) {
180+
Type type = val.getType();
181+
assert(isSupportedTypeForConversion(type) &&
182+
"expected value to have a convertible type");
183+
184+
if (isa<IntegerType>(type))
185+
return val;
186+
187+
uint64_t typeBitSize = dataLayout.getTypeSizeInBits(type);
188+
IntegerType valueSizeInteger = rewriter.getIntegerType(typeBitSize);
189+
190+
if (isa<LLVM::LLVMPointerType>(type))
191+
return rewriter.createOrFold<LLVM::PtrToIntOp>(loc, valueSizeInteger, val);
192+
return rewriter.createOrFold<LLVM::BitcastOp>(loc, valueSizeInteger, val);
193+
}
194+
195+
/// Converts a value with an integer type to `targetType`.
196+
static Value castIntValueToSameSizedType(RewriterBase &rewriter, Location loc,
197+
Value val, Type targetType) {
198+
assert(isa<IntegerType>(val.getType()) &&
199+
"expected value to have an integer type");
200+
assert(isSupportedTypeForConversion(targetType) &&
201+
"expected the target type to be supported for conversions");
202+
if (val.getType() == targetType)
203+
return val;
204+
if (isa<LLVM::LLVMPointerType>(targetType))
205+
return rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, val);
206+
return rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, val);
207+
}
172208

173-
/// Constructs operations that convert `inputValue` into a new value of type
174-
/// `targetType`. Assumes that this conversion is possible.
175-
static Value createConversionSequence(RewriterBase &rewriter, Location loc,
176-
Value srcValue, Type targetType,
177-
const DataLayout &dataLayout) {
178-
// Get the types of the source and target values.
209+
/// Constructs operations that convert `srcValue` into a new value of type
210+
/// `targetType`. Assumes the types have the same bitsize.
211+
static Value castSameSizedTypes(RewriterBase &rewriter, Location loc,
212+
Value srcValue, Type targetType,
213+
const DataLayout &dataLayout) {
179214
Type srcType = srcValue.getType();
180-
assert(areConversionCompatible(dataLayout, targetType, srcType) &&
215+
assert(areConversionCompatible(dataLayout, targetType, srcType,
216+
/*narrowingConversion=*/true) &&
181217
"expected that the compatibility was checked before");
182218

183-
uint64_t srcTypeSize = dataLayout.getTypeSize(srcType);
184-
uint64_t targetTypeSize = dataLayout.getTypeSize(targetType);
185-
186219
// Nothing has to be done if the types are already the same.
187220
if (srcType == targetType)
188221
return srcValue;
@@ -196,48 +229,117 @@ static Value createConversionSequence(RewriterBase &rewriter, Location loc,
196229
return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
197230
srcValue);
198231

199-
IntegerType valueSizeInteger =
200-
rewriter.getIntegerType(srcTypeSize * kBitsInByte);
201-
Value replacement = srcValue;
232+
// For all other castable types, casting through integers is necessary.
233+
Value replacement = castToSameSizedInt(rewriter, loc, srcValue, dataLayout);
234+
return castIntValueToSameSizedType(rewriter, loc, replacement, targetType);
235+
}
236+
237+
/// Constructs operations that convert `srcValue` into a new value of type
238+
/// `targetType`. Performs bit-level extraction if the source type is larger
239+
/// than the target type. Assumes that this conversion is possible.
240+
static Value createExtractAndCast(RewriterBase &rewriter, Location loc,
241+
Value srcValue, Type targetType,
242+
const DataLayout &dataLayout) {
243+
// Get the types of the source and target values.
244+
Type srcType = srcValue.getType();
245+
assert(areConversionCompatible(dataLayout, targetType, srcType,
246+
/*narrowingConversion=*/true) &&
247+
"expected that the compatibility was checked before");
248+
249+
uint64_t srcTypeSize = dataLayout.getTypeSizeInBits(srcType);
250+
uint64_t targetTypeSize = dataLayout.getTypeSizeInBits(targetType);
251+
if (srcTypeSize == targetTypeSize)
252+
return castSameSizedTypes(rewriter, loc, srcValue, targetType, dataLayout);
202253

203254
// First, cast the value to a same-sized integer type.
204-
if (isa<LLVM::LLVMPointerType>(srcType))
205-
replacement = rewriter.createOrFold<LLVM::PtrToIntOp>(loc, valueSizeInteger,
206-
replacement);
207-
else if (replacement.getType() != valueSizeInteger)
208-
replacement = rewriter.createOrFold<LLVM::BitcastOp>(loc, valueSizeInteger,
209-
replacement);
255+
Value replacement = castToSameSizedInt(rewriter, loc, srcValue, dataLayout);
210256

211257
// Truncate the integer if the size of the target is less than the value.
212-
if (targetTypeSize != srcTypeSize) {
213-
if (isBigEndian(dataLayout)) {
214-
uint64_t shiftAmount = (srcTypeSize - targetTypeSize) * kBitsInByte;
215-
auto shiftConstant = rewriter.create<LLVM::ConstantOp>(
216-
loc, rewriter.getIntegerAttr(srcType, shiftAmount));
217-
replacement =
218-
rewriter.createOrFold<LLVM::LShrOp>(loc, srcValue, shiftConstant);
219-
}
220-
221-
replacement = rewriter.create<LLVM::TruncOp>(
222-
loc, rewriter.getIntegerType(targetTypeSize * kBitsInByte),
223-
replacement);
258+
if (isBigEndian(dataLayout)) {
259+
uint64_t shiftAmount = srcTypeSize - targetTypeSize;
260+
auto shiftConstant = rewriter.create<LLVM::ConstantOp>(
261+
loc, rewriter.getIntegerAttr(srcType, shiftAmount));
262+
replacement =
263+
rewriter.createOrFold<LLVM::LShrOp>(loc, srcValue, shiftConstant);
224264
}
225265

266+
replacement = rewriter.create<LLVM::TruncOp>(
267+
loc, rewriter.getIntegerType(targetTypeSize), replacement);
268+
226269
// Now cast the integer to the actual target type if required.
227-
if (isa<LLVM::LLVMPointerType>(targetType))
228-
replacement =
229-
rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, replacement);
230-
else if (replacement.getType() != targetType)
231-
replacement =
232-
rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, replacement);
270+
return castIntValueToSameSizedType(rewriter, loc, replacement, targetType);
271+
}
272+
273+
/// Constructs operations that insert the bits of `srcValue` into the
274+
/// "beginning" of `reachingDef` (beginning is endianness dependent).
275+
/// Assumes that this conversion is possible.
276+
static Value createInsertAndCast(RewriterBase &rewriter, Location loc,
277+
Value srcValue, Value reachingDef,
278+
const DataLayout &dataLayout) {
279+
280+
assert(areConversionCompatible(dataLayout, reachingDef.getType(),
281+
srcValue.getType(),
282+
/*narrowingConversion=*/false) &&
283+
"expected that the compatibility was checked before");
284+
uint64_t valueTypeSize = dataLayout.getTypeSizeInBits(srcValue.getType());
285+
uint64_t slotTypeSize = dataLayout.getTypeSizeInBits(reachingDef.getType());
286+
if (slotTypeSize == valueTypeSize)
287+
return castSameSizedTypes(rewriter, loc, srcValue, reachingDef.getType(),
288+
dataLayout);
289+
290+
// In the case where the store only overwrites parts of the memory,
291+
// bit fiddling is required to construct the new value.
292+
293+
// First convert both values to integers of the same size.
294+
Value defAsInt = castToSameSizedInt(rewriter, loc, reachingDef, dataLayout);
295+
Value valueAsInt = castToSameSizedInt(rewriter, loc, srcValue, dataLayout);
296+
// Extend the value to the size of the reaching definition.
297+
valueAsInt =
298+
rewriter.createOrFold<LLVM::ZExtOp>(loc, defAsInt.getType(), valueAsInt);
299+
uint64_t sizeDifference = slotTypeSize - valueTypeSize;
300+
if (isBigEndian(dataLayout)) {
301+
// On big endian systems, a store to the base pointer overwrites the most
302+
// significant bits. To accomodate for this, the stored value needs to be
303+
// shifted into the according position.
304+
Value bigEndianShift = rewriter.create<LLVM::ConstantOp>(
305+
loc, rewriter.getIntegerAttr(defAsInt.getType(), sizeDifference));
306+
valueAsInt =
307+
rewriter.createOrFold<LLVM::ShlOp>(loc, valueAsInt, bigEndianShift);
308+
}
309+
310+
// Construct the mask that is used to erase the bits that are overwritten by
311+
// the store.
312+
APInt maskValue;
313+
if (isBigEndian(dataLayout)) {
314+
// Build a mask that has the most significant bits set to zero.
315+
// Note: This is the same as 2^sizeDifference - 1
316+
maskValue = APInt::getAllOnes(sizeDifference).zext(slotTypeSize);
317+
} else {
318+
// Build a mask that has the least significant bits set to zero.
319+
// Note: This is the same as -(2^valueTypeSize)
320+
maskValue = APInt::getAllOnes(valueTypeSize).zext(slotTypeSize);
321+
maskValue.flipAllBits();
322+
}
323+
324+
// Mask out the affected bits ...
325+
Value mask = rewriter.create<LLVM::ConstantOp>(
326+
loc, rewriter.getIntegerAttr(defAsInt.getType(), maskValue));
327+
Value masked = rewriter.createOrFold<LLVM::AndOp>(loc, defAsInt, mask);
328+
329+
// ... and combine the result with the new value.
330+
Value combined = rewriter.createOrFold<LLVM::OrOp>(loc, masked, valueAsInt);
233331

234-
return replacement;
332+
return castIntValueToSameSizedType(rewriter, loc, combined,
333+
reachingDef.getType());
235334
}
236335

237336
/// Returns the value the slot holds after this store, obtained by inserting
/// the stored value into `reachingDef` via `createInsertAndCast`.
/// `reachingDef` must be non-null and have the slot's element type (asserted).
Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
                               Value reachingDef,
                               const DataLayout &dataLayout) {
  assert(reachingDef && reachingDef.getType() == slot.elemType &&
         "expected the reaching definition's type to match the slot's type");

  Location storeLoc = getLoc();
  Value storedValue = getValue();
  return createInsertAndCast(rewriter, storeLoc, storedValue, reachingDef,
                             dataLayout);
}
242344

243345
bool LLVM::LoadOp::canUsesBeRemoved(
@@ -249,11 +351,10 @@ bool LLVM::LoadOp::canUsesBeRemoved(
249351
Value blockingUse = (*blockingUses.begin())->get();
250352
// If the blocking use is the slot ptr itself, there will be enough
251353
// context to reconstruct the result of the load at removal time, so it can
252-
// be removed (provided it loads the exact stored value and is not
253-
// volatile).
354+
// be removed (provided it is not volatile).
254355
return blockingUse == slot.ptr && getAddr() == slot.ptr &&
255356
areConversionCompatible(dataLayout, getResult().getType(),
256-
slot.elemType) &&
357+
slot.elemType, /*narrowingConversion=*/true) &&
257358
!getVolatile_();
258359
}
259360

@@ -263,9 +364,8 @@ DeletionKind LLVM::LoadOp::removeBlockingUses(
263364
const DataLayout &dataLayout) {
264365
// `canUsesBeRemoved` checked this blocking use must be the loaded slot
265366
// pointer.
266-
Value newResult =
267-
createConversionSequence(rewriter, getLoc(), reachingDefinition,
268-
getResult().getType(), dataLayout);
367+
Value newResult = createExtractAndCast(rewriter, getLoc(), reachingDefinition,
368+
getResult().getType(), dataLayout);
269369
rewriter.replaceAllUsesWith(getResult(), newResult);
270370
return DeletionKind::Delete;
271371
}
@@ -283,7 +383,8 @@ bool LLVM::StoreOp::canUsesBeRemoved(
283383
return blockingUse == slot.ptr && getAddr() == slot.ptr &&
284384
getValue() != slot.ptr &&
285385
areConversionCompatible(dataLayout, slot.elemType,
286-
getValue().getType()) &&
386+
getValue().getType(),
387+
/*narrowingConversion=*/false) &&
287388
!getVolatile_();
288389
}
289390

@@ -838,6 +939,7 @@ bool LLVM::MemsetOp::storesTo(const MemorySlot &slot) {
838939
}
839940

840941
Value LLVM::MemsetOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
942+
Value reachingDef,
841943
const DataLayout &dataLayout) {
842944
// TODO: Support non-integer types.
843945
return TypeSwitch<Type, Value>(slot.elemType)
@@ -1149,6 +1251,7 @@ bool LLVM::MemcpyOp::storesTo(const MemorySlot &slot) {
11491251
}
11501252

11511253
/// Forwards to `memcpyGetStored` with this op, the slot and the rewriter.
/// `reachingDef` and `dataLayout` are accepted for the interface but are not
/// used here.
Value LLVM::MemcpyOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
                                Value reachingDef,
                                const DataLayout &dataLayout) {
  return memcpyGetStored(*this, slot, rewriter);
}
@@ -1199,7 +1302,7 @@ bool LLVM::MemcpyInlineOp::storesTo(const MemorySlot &slot) {
11991302
}
12001303

12011304
/// Forwards to `memcpyGetStored` with this op, the slot and the rewriter.
/// `reachingDef` and `dataLayout` are accepted for the interface but are not
/// used here.
Value LLVM::MemcpyInlineOp::getStored(const MemorySlot &slot,
                                      RewriterBase &rewriter,
                                      Value reachingDef,
                                      const DataLayout &dataLayout) {
  return memcpyGetStored(*this, slot, rewriter);
}
@@ -1252,6 +1355,7 @@ bool LLVM::MemmoveOp::storesTo(const MemorySlot &slot) {
12521355
}
12531356

12541357
/// Forwards to `memcpyGetStored` with this op, the slot and the rewriter.
/// `reachingDef` and `dataLayout` are accepted for the interface but are not
/// used here.
Value LLVM::MemmoveOp::getStored(const MemorySlot &slot,
                                 RewriterBase &rewriter, Value reachingDef,
                                 const DataLayout &dataLayout) {
  return memcpyGetStored(*this, slot, rewriter);
}

mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ bool memref::LoadOp::loadsFrom(const MemorySlot &slot) {
161161
bool memref::LoadOp::storesTo(const MemorySlot &slot) { return false; }
162162

163163
/// Interface stub: `storesTo` always returns false for a load, so this must
/// never be reached. None of the parameters are used.
Value memref::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
                                Value reachingDef,
                                const DataLayout &dataLayout) {
  llvm_unreachable("getStored should not be called on LoadOp");
}
@@ -242,6 +243,7 @@ bool memref::StoreOp::storesTo(const MemorySlot &slot) {
242243
}
243244

244245
/// Returns the value operand of this store unchanged. `slot`, `rewriter`,
/// `reachingDef` and `dataLayout` are not used here.
Value memref::StoreOp::getStored(const MemorySlot &slot,
                                 RewriterBase &rewriter, Value reachingDef,
                                 const DataLayout &dataLayout) {
  return getValue();
}

0 commit comments

Comments
 (0)