Skip to content

Commit 531ce38

Browse files
committed
address review comments
1 parent 2a277b4 commit 531ce38

File tree

3 files changed

+122
-75
lines changed

3 files changed

+122
-75
lines changed

mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp

Lines changed: 103 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -142,24 +142,29 @@ static bool isSupportedTypeForConversion(Type type) {
142142
}
143143

144144
/// Checks if `srcType` converts to `targetType` via a sequence of casts and
145-
/// truncations.
145+
/// truncations. Checks for narrowing or widening conversion compatibility
146+
/// depending on `narrowingConversion`.
146147
static bool areConversionCompatible(const DataLayout &layout, Type targetType,
147-
Type srcType, bool allowWidening = false) {
148+
Type srcType, bool narrowingConversion) {
148149
if (targetType == srcType)
149150
return true;
150151

151152
if (!isSupportedTypeForConversion(targetType) ||
152153
!isSupportedTypeForConversion(srcType))
153154
return false;
154155

156+
uint64_t targetSize = layout.getTypeSize(targetType);
157+
uint64_t srcSize = layout.getTypeSize(srcType);
158+
155159
// Pointer casts will only be sane when the bitsize of both pointer types is
156160
// the same.
157161
if (isa<LLVM::LLVMPointerType>(targetType) &&
158162
isa<LLVM::LLVMPointerType>(srcType))
159-
return layout.getTypeSize(targetType) == layout.getTypeSize(srcType);
163+
return targetSize == srcSize;
160164

161-
return allowWidening ||
162-
layout.getTypeSize(targetType) <= layout.getTypeSize(srcType);
165+
if (narrowingConversion)
166+
return targetSize <= srcSize;
167+
return targetSize >= srcSize;
163168
}
164169

165170
/// Checks if `dataLayout` describes a big endian layout.
@@ -168,15 +173,13 @@ static bool isBigEndian(const DataLayout &dataLayout) {
168173
return endiannessStr && endiannessStr == "big";
169174
}
170175

171-
/// The size of a byte in bits.
172-
constexpr const static uint64_t kBitsInByte = 8;
173-
174176
/// Converts a value to an integer type of the same size.
175177
/// Assumes that the type can be converted.
176-
static Value convertToIntValue(RewriterBase &rewriter, Location loc, Value val,
177-
const DataLayout &dataLayout) {
178+
static Value castToSameSizedInt(RewriterBase &rewriter, Location loc, Value val,
179+
const DataLayout &dataLayout) {
178180
Type type = val.getType();
179-
assert(isSupportedTypeForConversion(type));
181+
assert(isSupportedTypeForConversion(type) &&
182+
"expected value to have a convertible type");
180183

181184
if (isa<IntegerType>(type))
182185
return val;
@@ -189,30 +192,30 @@ static Value convertToIntValue(RewriterBase &rewriter, Location loc, Value val,
189192
return rewriter.createOrFold<LLVM::BitcastOp>(loc, valueSizeInteger, val);
190193
}
191194

192-
/// Converts an value with an integer type to `targetType`.
193-
static Value convertIntValueToType(RewriterBase &rewriter, Location loc,
194-
Value val, Type targetType) {
195-
assert(isa<IntegerType>(val.getType()));
195+
/// Converts a value with an integer type to `targetType`.
196+
static Value castIntValueToSameSizedType(RewriterBase &rewriter, Location loc,
197+
Value val, Type targetType) {
198+
assert(isa<IntegerType>(val.getType()) &&
199+
"expected value to have an integer type");
200+
assert(isSupportedTypeForConversion(targetType) &&
201+
"expected the target type to be supported for conversions");
196202
if (val.getType() == targetType)
197203
return val;
198204
if (isa<LLVM::LLVMPointerType>(targetType))
199205
return rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, val);
200206
return rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, val);
201207
}
202208

203-
/// Constructs operations that convert `inputValue` into a new value of type
204-
/// `targetType`. Assumes that this conversion is possible.
205-
static Value createConversionSequence(RewriterBase &rewriter, Location loc,
206-
Value srcValue, Type targetType,
207-
const DataLayout &dataLayout) {
208-
// Get the types of the source and target values.
209+
/// Constructs operations that convert `srcValue` into a new value of type
210+
/// `targetType`. Assumes the types have the same bitsize.
211+
static Value castSameSizedTypes(RewriterBase &rewriter, Location loc,
212+
Value srcValue, Type targetType,
213+
const DataLayout &dataLayout) {
209214
Type srcType = srcValue.getType();
210-
assert(areConversionCompatible(dataLayout, targetType, srcType) &&
215+
assert(areConversionCompatible(dataLayout, targetType, srcType,
216+
/*narrowingConversion=*/true) &&
211217
"expected that the compatibility was checked before");
212218

213-
uint64_t srcTypeSize = dataLayout.getTypeSize(srcType);
214-
uint64_t targetTypeSize = dataLayout.getTypeSize(targetType);
215-
216219
// Nothing has to be done if the types are already the same.
217220
if (srcType == targetType)
218221
return srcValue;
@@ -226,60 +229,83 @@ static Value createConversionSequence(RewriterBase &rewriter, Location loc,
226229
return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
227230
srcValue);
228231

232+
// For all other castable types, casting through integers is necessary.
233+
Value replacement = castToSameSizedInt(rewriter, loc, srcValue, dataLayout);
234+
return castIntValueToSameSizedType(rewriter, loc, replacement, targetType);
235+
}
236+
237+
/// Constructs operations that convert `srcValue` into a new value of type
238+
/// `targetType`. Performs bitlevel extraction if the source type is larger than
239+
/// the target type.
240+
/// Assumes that this conversion is possible.
241+
static Value createExtractAndCast(RewriterBase &rewriter, Location loc,
242+
Value srcValue, Type targetType,
243+
const DataLayout &dataLayout) {
244+
// Get the types of the source and target values.
245+
Type srcType = srcValue.getType();
246+
assert(areConversionCompatible(dataLayout, targetType, srcType,
247+
/*narrowingConversion=*/true) &&
248+
"expected that the compatibility was checked before");
249+
250+
uint64_t srcTypeSize = dataLayout.getTypeSizeInBits(srcType);
251+
uint64_t targetTypeSize = dataLayout.getTypeSizeInBits(targetType);
252+
if (srcTypeSize == targetTypeSize)
253+
return castSameSizedTypes(rewriter, loc, srcValue, targetType, dataLayout);
254+
229255
// First, cast the value to a same-sized integer type.
230-
Value replacement = convertToIntValue(rewriter, loc, srcValue, dataLayout);
256+
Value replacement = castToSameSizedInt(rewriter, loc, srcValue, dataLayout);
231257

232258
// Truncate the integer if the size of the target is less than the value.
233-
if (targetTypeSize != srcTypeSize) {
234-
if (isBigEndian(dataLayout)) {
235-
uint64_t shiftAmount = (srcTypeSize - targetTypeSize) * kBitsInByte;
236-
auto shiftConstant = rewriter.create<LLVM::ConstantOp>(
237-
loc, rewriter.getIntegerAttr(srcType, shiftAmount));
238-
replacement =
239-
rewriter.createOrFold<LLVM::LShrOp>(loc, srcValue, shiftConstant);
240-
}
241-
242-
replacement = rewriter.create<LLVM::TruncOp>(
243-
loc, rewriter.getIntegerType(targetTypeSize * kBitsInByte),
244-
replacement);
259+
if (isBigEndian(dataLayout)) {
260+
uint64_t shiftAmount = srcTypeSize - targetTypeSize;
261+
auto shiftConstant = rewriter.create<LLVM::ConstantOp>(
262+
loc, rewriter.getIntegerAttr(srcType, shiftAmount));
263+
replacement =
264+
rewriter.createOrFold<LLVM::LShrOp>(loc, srcValue, shiftConstant);
245265
}
246266

267+
replacement = rewriter.create<LLVM::TruncOp>(
268+
loc, rewriter.getIntegerType(targetTypeSize), replacement);
269+
247270
// Now cast the integer to the actual target type if required.
248-
return convertIntValueToType(rewriter, loc, replacement, targetType);
271+
return castIntValueToSameSizedType(rewriter, loc, replacement, targetType);
249272
}
250273

251-
Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
252-
Value reachingDef,
253-
const DataLayout &dataLayout) {
254-
uint64_t valueTypeSize = dataLayout.getTypeSizeInBits(getValue().getType());
255-
uint64_t slotTypeSize = dataLayout.getTypeSizeInBits(slot.elemType);
256-
if (slotTypeSize <= valueTypeSize)
257-
return createConversionSequence(rewriter, getLoc(), getValue(),
258-
slot.elemType, dataLayout);
274+
/// Constructs operations that insert the bits of `srcValue` into the
275+
/// "beginning" of `reachingDef` (beginning is endianness dependent).
276+
/// Assumes that this conversion is possible.
277+
static Value createInsertAndCast(RewriterBase &rewriter, Location loc,
278+
Value srcValue, Value reachingDef,
279+
const DataLayout &dataLayout) {
259280

260-
assert(reachingDef && reachingDef.getType() == slot.elemType &&
261-
"expected the reaching definition's type to slot's type");
281+
assert(areConversionCompatible(dataLayout, reachingDef.getType(),
282+
srcValue.getType(),
283+
/*narrowingConversion=*/false) &&
284+
"expected that the compatibility was checked before");
285+
uint64_t valueTypeSize = dataLayout.getTypeSizeInBits(srcValue.getType());
286+
uint64_t slotTypeSize = dataLayout.getTypeSizeInBits(reachingDef.getType());
287+
if (slotTypeSize == valueTypeSize)
288+
return castSameSizedTypes(rewriter, loc, srcValue, reachingDef.getType(),
289+
dataLayout);
262290

263291
// In the case where the store only overwrites parts of the memory,
264292
// bit fiddling is required to construct the new value.
265293

266294
// First convert both values to integers of the same size.
267-
Value defAsInt =
268-
convertToIntValue(rewriter, getLoc(), reachingDef, dataLayout);
269-
Value valueAsInt =
270-
convertToIntValue(rewriter, getLoc(), getValue(), dataLayout);
295+
Value defAsInt = castToSameSizedInt(rewriter, loc, reachingDef, dataLayout);
296+
Value valueAsInt = castToSameSizedInt(rewriter, loc, srcValue, dataLayout);
271297
// Extend the value to the size of the reaching definition.
272-
valueAsInt = rewriter.createOrFold<LLVM::ZExtOp>(getLoc(), defAsInt.getType(),
273-
valueAsInt);
298+
valueAsInt =
299+
rewriter.createOrFold<LLVM::ZExtOp>(loc, defAsInt.getType(), valueAsInt);
274300
uint64_t sizeDifference = slotTypeSize - valueTypeSize;
275301
if (isBigEndian(dataLayout)) {
276302
// On big endian systems, a store to the base pointer overwrites the most
277303
// significant bits. To accommodate this, the stored value needs to be
278304
// shifted into the corresponding position.
279305
Value bigEndianShift = rewriter.create<LLVM::ConstantOp>(
280-
getLoc(), rewriter.getIntegerAttr(defAsInt.getType(), sizeDifference));
281-
valueAsInt = rewriter.createOrFold<LLVM::ShlOp>(getLoc(), valueAsInt,
282-
bigEndianShift);
306+
loc, rewriter.getIntegerAttr(defAsInt.getType(), sizeDifference));
307+
valueAsInt =
308+
rewriter.createOrFold<LLVM::ShlOp>(loc, valueAsInt, bigEndianShift);
283309
}
284310

285311
// Construct the mask that is used to erase the bits that are overwritten by
@@ -298,14 +324,23 @@ Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
298324

299325
// Mask out the affected bits ...
300326
Value mask = rewriter.create<LLVM::ConstantOp>(
301-
getLoc(), rewriter.getIntegerAttr(defAsInt.getType(), maskValue));
302-
Value masked = rewriter.createOrFold<LLVM::AndOp>(getLoc(), defAsInt, mask);
327+
loc, rewriter.getIntegerAttr(defAsInt.getType(), maskValue));
328+
Value masked = rewriter.createOrFold<LLVM::AndOp>(loc, defAsInt, mask);
303329

304330
// ... and combine the result with the new value.
305-
Value combined =
306-
rewriter.createOrFold<LLVM::OrOp>(getLoc(), masked, valueAsInt);
331+
Value combined = rewriter.createOrFold<LLVM::OrOp>(loc, masked, valueAsInt);
332+
333+
return castIntValueToSameSizedType(rewriter, loc, combined,
334+
reachingDef.getType());
335+
}
307336

308-
return convertIntValueToType(rewriter, getLoc(), combined, slot.elemType);
337+
Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
338+
Value reachingDef,
339+
const DataLayout &dataLayout) {
340+
assert(reachingDef && reachingDef.getType() == slot.elemType &&
341+
"expected the reaching definition's type to match the slot's type");
342+
return createInsertAndCast(rewriter, getLoc(), getValue(), reachingDef,
343+
dataLayout);
309344
}
310345

311346
bool LLVM::LoadOp::canUsesBeRemoved(
@@ -317,11 +352,10 @@ bool LLVM::LoadOp::canUsesBeRemoved(
317352
Value blockingUse = (*blockingUses.begin())->get();
318353
// If the blocking use is the slot ptr itself, there will be enough
319354
// context to reconstruct the result of the load at removal time, so it can
320-
// be removed (provided it loads the exact stored value and is not
321-
// volatile).
355+
// be removed (provided it is not volatile).
322356
return blockingUse == slot.ptr && getAddr() == slot.ptr &&
323357
areConversionCompatible(dataLayout, getResult().getType(),
324-
slot.elemType) &&
358+
slot.elemType, /*narrowingConversion=*/true) &&
325359
!getVolatile_();
326360
}
327361

@@ -331,9 +365,8 @@ DeletionKind LLVM::LoadOp::removeBlockingUses(
331365
const DataLayout &dataLayout) {
332366
// `canUsesBeRemoved` checked this blocking use must be the loaded slot
333367
// pointer.
334-
Value newResult =
335-
createConversionSequence(rewriter, getLoc(), reachingDefinition,
336-
getResult().getType(), dataLayout);
368+
Value newResult = createExtractAndCast(rewriter, getLoc(), reachingDefinition,
369+
getResult().getType(), dataLayout);
337370
rewriter.replaceAllUsesWith(getResult(), newResult);
338371
return DeletionKind::Delete;
339372
}
@@ -352,7 +385,7 @@ bool LLVM::StoreOp::canUsesBeRemoved(
352385
getValue() != slot.ptr &&
353386
areConversionCompatible(dataLayout, slot.elemType,
354387
getValue().getType(),
355-
/*allowWidening=*/true) &&
388+
/*narrowingConversion=*/false) &&
356389
!getVolatile_();
357390
}
358391

mlir/lib/Transforms/Mem2Reg.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ class MemorySlotPromoter {
191191

192192
/// Lazily-constructed default value representing the content of the slot when
193193
/// no store has been executed. This function may mutate IR.
194-
Value getLazyDefaultValue();
194+
Value getOrCreateDefaultValue();
195195

196196
MemorySlot slot;
197197
PromotableAllocationOpInterface allocator;
@@ -232,7 +232,7 @@ MemorySlotPromoter::MemorySlotPromoter(
232232
#endif // NDEBUG
233233
}
234234

235-
Value MemorySlotPromoter::getLazyDefaultValue() {
235+
Value MemorySlotPromoter::getOrCreateDefaultValue() {
236236
if (defaultValue)
237237
return defaultValue;
238238

@@ -567,7 +567,7 @@ void MemorySlotPromoter::removeBlockingUses() {
567567
// If no reaching definition is known, this use is outside the reach of
568568
// the slot. The default value should thus be used.
569569
if (!reachingDef)
570-
reachingDef = getLazyDefaultValue();
570+
reachingDef = getOrCreateDefaultValue();
571571

572572
rewriter.setInsertionPointAfter(toPromote);
573573
if (toPromoteMemOp.removeBlockingUses(
@@ -601,7 +601,8 @@ void MemorySlotPromoter::removeBlockingUses() {
601601
}
602602

603603
void MemorySlotPromoter::promoteSlot() {
604-
computeReachingDefInRegion(slot.ptr.getParentRegion(), getLazyDefaultValue());
604+
computeReachingDefInRegion(slot.ptr.getParentRegion(),
605+
getOrCreateDefaultValue());
605606

606607
// Now that reaching definitions are known, remove all users.
607608
removeBlockingUses();
@@ -617,7 +618,7 @@ void MemorySlotPromoter::promoteSlot() {
617618
succOperands.size() + 1 == mergePoint->getNumArguments());
618619
if (succOperands.size() + 1 == mergePoint->getNumArguments())
619620
rewriter.modifyOpInPlace(
620-
user, [&]() { succOperands.append(getLazyDefaultValue()); });
621+
user, [&]() { succOperands.append(getOrCreateDefaultValue()); });
621622
}
622623
}
623624

mlir/test/Dialect/LLVMIR/mem2reg.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1144,3 +1144,16 @@ llvm.func @stores_with_different_types_branches(%arg0: i64, %arg1: f32, %cond: i
11441144
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f64
11451145
llvm.return %2 : f64
11461146
}
1147+
1148+
// -----
1149+
1150+
// Verify that mem2reg does not touch stores with undefined semantics.
1151+
1152+
// CHECK-LABEL: @store_out_of_bounds
1153+
llvm.func @store_out_of_bounds(%arg : i64) {
1154+
%0 = llvm.mlir.constant(1 : i32) : i32
1155+
// CHECK: llvm.alloca
1156+
%1 = llvm.alloca %0 x i32 : (i32) -> !llvm.ptr
1157+
llvm.store %arg, %1 : i64, !llvm.ptr
1158+
llvm.return
1159+
}

0 commit comments

Comments
 (0)