Skip to content

Commit 0289ae5

Browse files
authored
[MLIR][LLVM][SROA] Support incorrectly typed memory accesses (#85813)
This commit relaxes the assumption of type consistency for LLVM dialect load and store operations in SROA. Instead, there is now a check that loads and stores are in the bounds specified by the sub-slot they access. This commit additionally removes the corresponding patterns from the type consistency pass, as they are no longer necessary. Note: It will be necessary to extend Mem2Reg with the logic for differently sized accesses as well. This is non-the-less a strict upgrade for productive flows, as the type consistency pass can produce invalid IR for some odd cases.
1 parent 90454a6 commit 0289ae5

File tree

8 files changed

+204
-263
lines changed

8 files changed

+204
-263
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,8 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [Pure,
323323
}
324324

325325
def LLVM_LoadOp : LLVM_MemAccessOpBase<"load",
326-
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
326+
[DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
327+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
327328
DeclareOpInterfaceMethods<PromotableMemOpInterface>,
328329
DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>]> {
329330
dag args = (ins LLVM_AnyPointer:$addr,
@@ -402,7 +403,8 @@ def LLVM_LoadOp : LLVM_MemAccessOpBase<"load",
402403
}
403404

404405
def LLVM_StoreOp : LLVM_MemAccessOpBase<"store",
405-
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
406+
[DeclareOpInterfaceMethods<DestructurableAccessorOpInterface>,
407+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
406408
DeclareOpInterfaceMethods<PromotableMemOpInterface>,
407409
DeclareOpInterfaceMethods<SafeMemorySlotAccessOpInterface>]> {
408410
dag args = (ins LLVM_LoadableType:$value,

mlir/include/mlir/Dialect/LLVMIR/Transforms/TypeConsistency.h

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,18 +29,6 @@ namespace LLVM {
2929
/// interpret pointee types as consistently as possible.
3030
std::unique_ptr<Pass> createTypeConsistencyPass();
3131

32-
/// Transforms uses of pointers to a whole struct to uses of pointers to the
33-
/// first element of a struct. This is achieved by inserting a GEP to the first
34-
/// element when possible.
35-
template <class User>
36-
class AddFieldGetterToStructDirectUse : public OpRewritePattern<User> {
37-
public:
38-
using OpRewritePattern<User>::OpRewritePattern;
39-
40-
LogicalResult matchAndRewrite(User user,
41-
PatternRewriter &rewriter) const override;
42-
};
43-
4432
/// Canonicalizes GEPs of which the base type and the pointer's type hint do not
4533
/// match. This is done by replacing the original GEP into a GEP with the type
4634
/// hint as a base type when an element of the hinted type aligns with the

mlir/include/mlir/Interfaces/MemorySlotInterfaces.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,7 @@ struct MemorySlot {
2626

2727
/// Memory slot attached with information about its destructuring procedure.
2828
struct DestructurableMemorySlot : public MemorySlot {
29-
/// Maps an index within the memory slot to the type of the pointer that
30-
/// will be generated to access the element directly.
29+
/// Maps an index within the memory slot to the corresponding subelement type.
3130
DenseMap<Attribute, Type> elementPtrs;
3231
};
3332

mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp

Lines changed: 95 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,8 @@
1313

1414
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1515
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
16-
#include "mlir/IR/IRMapping.h"
1716
#include "mlir/IR/Matchers.h"
1817
#include "mlir/IR/PatternMatch.h"
19-
#include "mlir/IR/ValueRange.h"
2018
#include "mlir/Interfaces/DataLayoutInterfaces.h"
2119
#include "mlir/Interfaces/MemorySlotInterfaces.h"
2220
#include "llvm/ADT/STLExtras.h"
@@ -71,12 +69,8 @@ SmallVector<DestructurableMemorySlot> LLVM::AllocaOp::getDestructurableSlots() {
7169
if (!destructuredType)
7270
return {};
7371

74-
DenseMap<Attribute, Type> allocaTypeMap;
75-
for (Attribute index : llvm::make_first_range(destructuredType.value()))
76-
allocaTypeMap.insert({index, LLVM::LLVMPointerType::get(getContext())});
77-
78-
return {
79-
DestructurableMemorySlot{{getResult(), getElemType()}, {allocaTypeMap}}};
72+
return {DestructurableMemorySlot{{getResult(), getElemType()},
73+
*destructuredType}};
8074
}
8175

8276
DenseMap<Attribute, MemorySlot>
@@ -182,17 +176,107 @@ DeletionKind LLVM::StoreOp::removeBlockingUses(
182176
return DeletionKind::Delete;
183177
}
184178

179+
/// Checks if `slot` can be accessed through the provided access type.
180+
static bool isValidAccessType(const MemorySlot &slot, Type accessType,
181+
const DataLayout &dataLayout) {
182+
return dataLayout.getTypeSize(accessType) <=
183+
dataLayout.getTypeSize(slot.elemType);
184+
}
185+
185186
LogicalResult LLVM::LoadOp::ensureOnlySafeAccesses(
186187
const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
187188
const DataLayout &dataLayout) {
188-
return success(getAddr() != slot.ptr || getType() == slot.elemType);
189+
return success(getAddr() != slot.ptr ||
190+
isValidAccessType(slot, getType(), dataLayout));
189191
}
190192

191193
LogicalResult LLVM::StoreOp::ensureOnlySafeAccesses(
192194
const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
193195
const DataLayout &dataLayout) {
194196
return success(getAddr() != slot.ptr ||
195-
getValue().getType() == slot.elemType);
197+
isValidAccessType(slot, getValue().getType(), dataLayout));
198+
}
199+
200+
/// Returns the subslot's type at the requested index.
201+
static Type getTypeAtIndex(const DestructurableMemorySlot &slot,
202+
Attribute index) {
203+
auto subelementIndexMap =
204+
slot.elemType.cast<DestructurableTypeInterface>().getSubelementIndexMap();
205+
if (!subelementIndexMap)
206+
return {};
207+
assert(!subelementIndexMap->empty());
208+
209+
// Note: Returns a null-type when no entry was found.
210+
return subelementIndexMap->lookup(index);
211+
}
212+
213+
bool LLVM::LoadOp::canRewire(const DestructurableMemorySlot &slot,
214+
SmallPtrSetImpl<Attribute> &usedIndices,
215+
SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
216+
const DataLayout &dataLayout) {
217+
if (getVolatile_())
218+
return false;
219+
220+
// A load always accesses the first element of the destructured slot.
221+
auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
222+
Type subslotType = getTypeAtIndex(slot, index);
223+
if (!subslotType)
224+
return false;
225+
226+
// The access can only be replaced when the subslot is read within its bounds.
227+
if (dataLayout.getTypeSize(getType()) > dataLayout.getTypeSize(subslotType))
228+
return false;
229+
230+
usedIndices.insert(index);
231+
return true;
232+
}
233+
234+
DeletionKind LLVM::LoadOp::rewire(const DestructurableMemorySlot &slot,
235+
DenseMap<Attribute, MemorySlot> &subslots,
236+
RewriterBase &rewriter,
237+
const DataLayout &dataLayout) {
238+
auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
239+
auto it = subslots.find(index);
240+
assert(it != subslots.end());
241+
242+
rewriter.modifyOpInPlace(
243+
*this, [&]() { getAddrMutable().set(it->getSecond().ptr); });
244+
return DeletionKind::Keep;
245+
}
246+
247+
bool LLVM::StoreOp::canRewire(const DestructurableMemorySlot &slot,
248+
SmallPtrSetImpl<Attribute> &usedIndices,
249+
SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
250+
const DataLayout &dataLayout) {
251+
if (getVolatile_())
252+
return false;
253+
254+
// A store always accesses the first element of the destructured slot.
255+
auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
256+
Type subslotType = getTypeAtIndex(slot, index);
257+
if (!subslotType)
258+
return false;
259+
260+
// The access can only be replaced when the subslot is read within its bounds.
261+
if (dataLayout.getTypeSize(getValue().getType()) >
262+
dataLayout.getTypeSize(subslotType))
263+
return false;
264+
265+
usedIndices.insert(index);
266+
return true;
267+
}
268+
269+
DeletionKind LLVM::StoreOp::rewire(const DestructurableMemorySlot &slot,
270+
DenseMap<Attribute, MemorySlot> &subslots,
271+
RewriterBase &rewriter,
272+
const DataLayout &dataLayout) {
273+
auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0);
274+
auto it = subslots.find(index);
275+
assert(it != subslots.end());
276+
277+
rewriter.modifyOpInPlace(
278+
*this, [&]() { getAddrMutable().set(it->getSecond().ptr); });
279+
return DeletionKind::Keep;
196280
}
197281

198282
//===----------------------------------------------------------------------===//
@@ -390,10 +474,8 @@ bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
390474
auto firstLevelIndex = dyn_cast<IntegerAttr>(getIndices()[1]);
391475
if (!firstLevelIndex)
392476
return false;
393-
assert(slot.elementPtrs.contains(firstLevelIndex));
394-
if (!llvm::isa<LLVM::LLVMPointerType>(slot.elementPtrs.at(firstLevelIndex)))
395-
return false;
396477
mustBeSafelyUsed.emplace_back<MemorySlot>({getResult(), reachedType});
478+
assert(slot.elementPtrs.contains(firstLevelIndex));
397479
usedIndices.insert(firstLevelIndex);
398480
return true;
399481
}

mlir/lib/Dialect/LLVMIR/Transforms/TypeConsistency.cpp

Lines changed: 0 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -49,104 +49,6 @@ static bool areBitcastCompatible(DataLayout &layout, Type lhs, Type rhs) {
4949
layout.getTypeSize(lhs) == layout.getTypeSize(rhs));
5050
}
5151

52-
//===----------------------------------------------------------------------===//
53-
// AddFieldGetterToStructDirectUse
54-
//===----------------------------------------------------------------------===//
55-
56-
/// Gets the type of the first subelement of `type` if `type` is destructurable,
57-
/// nullptr otherwise.
58-
static Type getFirstSubelementType(Type type) {
59-
auto destructurable = dyn_cast<DestructurableTypeInterface>(type);
60-
if (!destructurable)
61-
return nullptr;
62-
63-
Type subelementType = destructurable.getTypeAtIndex(
64-
IntegerAttr::get(IntegerType::get(type.getContext(), 32), 0));
65-
if (subelementType)
66-
return subelementType;
67-
68-
return nullptr;
69-
}
70-
71-
/// Extracts a pointer to the first field of an `elemType` from the address
72-
/// pointer of the provided MemOp, and rewires the MemOp so it uses that pointer
73-
/// instead.
74-
template <class MemOp>
75-
static void insertFieldIndirection(MemOp op, PatternRewriter &rewriter,
76-
Type elemType) {
77-
PatternRewriter::InsertionGuard guard(rewriter);
78-
79-
rewriter.setInsertionPointAfterValue(op.getAddr());
80-
SmallVector<GEPArg> firstTypeIndices{0, 0};
81-
82-
Value properPtr = rewriter.create<GEPOp>(
83-
op->getLoc(), LLVM::LLVMPointerType::get(op.getContext()), elemType,
84-
op.getAddr(), firstTypeIndices);
85-
86-
rewriter.modifyOpInPlace(op,
87-
[&]() { op.getAddrMutable().assign(properPtr); });
88-
}
89-
90-
template <>
91-
LogicalResult AddFieldGetterToStructDirectUse<LoadOp>::matchAndRewrite(
92-
LoadOp load, PatternRewriter &rewriter) const {
93-
PatternRewriter::InsertionGuard guard(rewriter);
94-
95-
Type inconsistentElementType =
96-
isElementTypeInconsistent(load.getAddr(), load.getType());
97-
if (!inconsistentElementType)
98-
return failure();
99-
Type firstType = getFirstSubelementType(inconsistentElementType);
100-
if (!firstType)
101-
return failure();
102-
DataLayout layout = DataLayout::closest(load);
103-
if (!areBitcastCompatible(layout, firstType, load.getResult().getType()))
104-
return failure();
105-
106-
insertFieldIndirection<LoadOp>(load, rewriter, inconsistentElementType);
107-
108-
// If the load does not use the first type but a type that can be casted from
109-
// it, add a bitcast and change the load type.
110-
if (firstType != load.getResult().getType()) {
111-
rewriter.setInsertionPointAfterValue(load.getResult());
112-
BitcastOp bitcast = rewriter.create<BitcastOp>(
113-
load->getLoc(), load.getResult().getType(), load.getResult());
114-
rewriter.modifyOpInPlace(load,
115-
[&]() { load.getResult().setType(firstType); });
116-
rewriter.replaceAllUsesExcept(load.getResult(), bitcast.getResult(),
117-
bitcast);
118-
}
119-
120-
return success();
121-
}
122-
123-
template <>
124-
LogicalResult AddFieldGetterToStructDirectUse<StoreOp>::matchAndRewrite(
125-
StoreOp store, PatternRewriter &rewriter) const {
126-
PatternRewriter::InsertionGuard guard(rewriter);
127-
128-
Type inconsistentElementType =
129-
isElementTypeInconsistent(store.getAddr(), store.getValue().getType());
130-
if (!inconsistentElementType)
131-
return failure();
132-
Type firstType = getFirstSubelementType(inconsistentElementType);
133-
if (!firstType)
134-
return failure();
135-
136-
DataLayout layout = DataLayout::closest(store);
137-
// Check that the first field has the right type or can at least be bitcast
138-
// to the right type.
139-
if (!areBitcastCompatible(layout, firstType, store.getValue().getType()))
140-
return failure();
141-
142-
insertFieldIndirection<StoreOp>(store, rewriter, inconsistentElementType);
143-
144-
rewriter.modifyOpInPlace(
145-
store, [&]() { store.getValueMutable().assign(store.getValue()); });
146-
147-
return success();
148-
}
149-
15052
//===----------------------------------------------------------------------===//
15153
// CanonicalizeAlignedGep
15254
//===----------------------------------------------------------------------===//
@@ -684,9 +586,6 @@ struct LLVMTypeConsistencyPass
684586
: public LLVM::impl::LLVMTypeConsistencyBase<LLVMTypeConsistencyPass> {
685587
void runOnOperation() override {
686588
RewritePatternSet rewritePatterns(&getContext());
687-
rewritePatterns.add<AddFieldGetterToStructDirectUse<LoadOp>>(&getContext());
688-
rewritePatterns.add<AddFieldGetterToStructDirectUse<StoreOp>>(
689-
&getContext());
690589
rewritePatterns.add<CanonicalizeAlignedGep>(&getContext());
691590
rewritePatterns.add<SplitStores>(&getContext(), maxVectorSplitSize);
692591
rewritePatterns.add<BitcastStores>(&getContext());

mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,11 +120,8 @@ memref::AllocaOp::getDestructurableSlots() {
120120
if (!destructuredType)
121121
return {};
122122

123-
DenseMap<Attribute, Type> indexMap;
124-
for (auto const &[index, type] : *destructuredType)
125-
indexMap.insert({index, MemRefType::get({}, type)});
126-
127-
return {DestructurableMemorySlot{{getMemref(), memrefType}, indexMap}};
123+
return {
124+
DestructurableMemorySlot{{getMemref(), memrefType}, *destructuredType}};
128125
}
129126

130127
DenseMap<Attribute, MemorySlot>

0 commit comments

Comments
 (0)