Skip to content

Commit ac39fa7

Browse files
authored
[MLIR][Mem2Reg][LLVM] Enhance partial load support (#89094)
This commit improves the LLVM dialect's Mem2Reg interfaces to support promotions of partial loads from larger memory slots. To support this, the Mem2Reg interface methods are extended with additional data layout parameters. The data layout is required to determine type sizes to produce correct conversion sequences. Note: There will be additional follow-ups that introduce similar functionality for stores, and there are plans to support accesses into the middle of memory slots.
1 parent 8d6a9c0 commit ac39fa7

File tree

5 files changed

+277
-71
lines changed

5 files changed

+277
-71
lines changed

mlir/include/mlir/Interfaces/MemorySlotInterfaces.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,8 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
127127
}],
128128
"::mlir::Value", "getStored",
129129
(ins "const ::mlir::MemorySlot &":$slot,
130-
"::mlir::RewriterBase &":$rewriter)
130+
"::mlir::RewriterBase &":$rewriter,
131+
"const ::mlir::DataLayout &":$dataLayout)
131132
>,
132133
InterfaceMethod<[{
133134
Checks that this operation can be promoted to no longer use the provided
@@ -172,7 +173,8 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
172173
(ins "const ::mlir::MemorySlot &":$slot,
173174
"const ::llvm::SmallPtrSetImpl<mlir::OpOperand *> &":$blockingUses,
174175
"::mlir::RewriterBase &":$rewriter,
175-
"::mlir::Value":$reachingDefinition)
176+
"::mlir::Value":$reachingDefinition,
177+
"const ::mlir::DataLayout &":$dataLayout)
176178
>,
177179
];
178180
}

mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp

Lines changed: 134 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,8 @@ bool LLVM::LoadOp::loadsFrom(const MemorySlot &slot) {
112112

113113
bool LLVM::LoadOp::storesTo(const MemorySlot &slot) { return false; }
114114

115-
Value LLVM::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter) {
115+
Value LLVM::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
116+
const DataLayout &dataLayout) {
116117
llvm_unreachable("getStored should not be called on LoadOp");
117118
}
118119

@@ -122,37 +123,121 @@ bool LLVM::StoreOp::storesTo(const MemorySlot &slot) {
122123
return getAddr() == slot.ptr;
123124
}
124125

125-
/// Checks that two types are the same or can be cast into one another.
126-
static bool areCastCompatible(const DataLayout &layout, Type lhs, Type rhs) {
127-
return lhs == rhs || (!isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(lhs) &&
128-
!isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(rhs) &&
129-
layout.getTypeSize(lhs) == layout.getTypeSize(rhs));
126+
/// Checks if `type` can be used in any kind of conversion sequences.
127+
static bool isSupportedTypeForConversion(Type type) {
128+
// Aggregate types are not bitcastable.
129+
if (isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(type))
130+
return false;
131+
132+
// LLVM vector types are only used for either pointers or target specific
133+
// types. These types cannot be cast in the general case, thus the memory
134+
// optimizations do not support them.
135+
if (isa<LLVM::LLVMFixedVectorType, LLVM::LLVMScalableVectorType>(type))
136+
return false;
137+
138+
// Scalable types are not supported.
139+
if (auto vectorType = dyn_cast<VectorType>(type))
140+
return !vectorType.isScalable();
141+
return true;
130142
}
131143

144+
/// Checks that `srcType` can be converted to `targetType` by a sequence of
145+
/// casts and truncations.
146+
static bool areConversionCompatible(const DataLayout &layout, Type targetType,
147+
Type srcType) {
148+
if (targetType == srcType)
149+
return true;
150+
151+
if (!isSupportedTypeForConversion(targetType) ||
152+
!isSupportedTypeForConversion(srcType))
153+
return false;
154+
155+
// Pointer casts will only be sane when the bitsize of both pointer types is
156+
// the same.
157+
if (isa<LLVM::LLVMPointerType>(targetType) &&
158+
isa<LLVM::LLVMPointerType>(srcType))
159+
return layout.getTypeSize(targetType) == layout.getTypeSize(srcType);
160+
161+
return layout.getTypeSize(targetType) <= layout.getTypeSize(srcType);
162+
}
163+
164+
/// Checks if `dataLayout` describes a big endian layout.
165+
static bool isBigEndian(const DataLayout &dataLayout) {
166+
auto endiannessStr = dyn_cast_or_null<StringAttr>(dataLayout.getEndianness());
167+
return endiannessStr && endiannessStr == "big";
168+
}
169+
170+
/// The size of a byte in bits.
171+
constexpr const static uint64_t kBitsInByte = 8;
172+
132173
/// Constructs operations that convert `inputValue` into a new value of type
133174
/// `targetType`. Assumes that this conversion is possible.
134175
static Value createConversionSequence(RewriterBase &rewriter, Location loc,
135-
Value inputValue, Type targetType) {
136-
if (inputValue.getType() == targetType)
137-
return inputValue;
138-
139-
if (!isa<LLVM::LLVMPointerType>(targetType) &&
140-
!isa<LLVM::LLVMPointerType>(inputValue.getType()))
141-
return rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, inputValue);
176+
Value srcValue, Type targetType,
177+
const DataLayout &dataLayout) {
178+
// Get the types of the source and target values.
179+
Type srcType = srcValue.getType();
180+
assert(areConversionCompatible(dataLayout, targetType, srcType) &&
181+
"expected that the compatibility was checked before");
182+
183+
uint64_t srcTypeSize = dataLayout.getTypeSize(srcType);
184+
uint64_t targetTypeSize = dataLayout.getTypeSize(targetType);
185+
186+
// Nothing has to be done if the types are already the same.
187+
if (srcType == targetType)
188+
return srcValue;
189+
190+
// In the special case of casting one pointer to another, we want to generate
191+
// an address space cast. Bitcasts of pointers are not allowed and using
192+
// pointer to integer conversions are not equivalent due to the loss of
193+
// provenance.
194+
if (isa<LLVM::LLVMPointerType>(targetType) &&
195+
isa<LLVM::LLVMPointerType>(srcType))
196+
return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
197+
srcValue);
198+
199+
IntegerType valueSizeInteger =
200+
rewriter.getIntegerType(srcTypeSize * kBitsInByte);
201+
Value replacement = srcValue;
202+
203+
// First, cast the value to a same-sized integer type.
204+
if (isa<LLVM::LLVMPointerType>(srcType))
205+
replacement = rewriter.createOrFold<LLVM::PtrToIntOp>(loc, valueSizeInteger,
206+
replacement);
207+
else if (replacement.getType() != valueSizeInteger)
208+
replacement = rewriter.createOrFold<LLVM::BitcastOp>(loc, valueSizeInteger,
209+
replacement);
210+
211+
// Truncate the integer if the size of the target is less than the value.
212+
if (targetTypeSize != srcTypeSize) {
213+
if (isBigEndian(dataLayout)) {
214+
uint64_t shiftAmount = (srcTypeSize - targetTypeSize) * kBitsInByte;
215+
auto shiftConstant = rewriter.create<LLVM::ConstantOp>(
216+
loc, rewriter.getIntegerAttr(srcType, shiftAmount));
217+
replacement =
218+
rewriter.createOrFold<LLVM::LShrOp>(loc, srcValue, shiftConstant);
219+
}
142220

143-
if (!isa<LLVM::LLVMPointerType>(targetType))
144-
return rewriter.createOrFold<LLVM::PtrToIntOp>(loc, targetType, inputValue);
221+
replacement = rewriter.create<LLVM::TruncOp>(
222+
loc, rewriter.getIntegerType(targetTypeSize * kBitsInByte),
223+
replacement);
224+
}
145225

146-
if (!isa<LLVM::LLVMPointerType>(inputValue.getType()))
147-
return rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, inputValue);
226+
// Now cast the integer to the actual target type if required.
227+
if (isa<LLVM::LLVMPointerType>(targetType))
228+
replacement =
229+
rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, replacement);
230+
else if (replacement.getType() != targetType)
231+
replacement =
232+
rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, replacement);
148233

149-
return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
150-
inputValue);
234+
return replacement;
151235
}
152236

153-
Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter) {
154-
return createConversionSequence(rewriter, getLoc(), getValue(),
155-
slot.elemType);
237+
Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
238+
const DataLayout &dataLayout) {
239+
return createConversionSequence(rewriter, getLoc(), getValue(), slot.elemType,
240+
dataLayout);
156241
}
157242

158243
bool LLVM::LoadOp::canUsesBeRemoved(
@@ -167,17 +252,20 @@ bool LLVM::LoadOp::canUsesBeRemoved(
167252
// be removed (provided it loads the exact stored value and is not
168253
// volatile).
169254
return blockingUse == slot.ptr && getAddr() == slot.ptr &&
170-
areCastCompatible(dataLayout, getResult().getType(), slot.elemType) &&
255+
areConversionCompatible(dataLayout, getResult().getType(),
256+
slot.elemType) &&
171257
!getVolatile_();
172258
}
173259

174260
DeletionKind LLVM::LoadOp::removeBlockingUses(
175261
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
176-
RewriterBase &rewriter, Value reachingDefinition) {
262+
RewriterBase &rewriter, Value reachingDefinition,
263+
const DataLayout &dataLayout) {
177264
// `canUsesBeRemoved` checked this blocking use must be the loaded slot
178265
// pointer.
179-
Value newResult = createConversionSequence(
180-
rewriter, getLoc(), reachingDefinition, getResult().getType());
266+
Value newResult =
267+
createConversionSequence(rewriter, getLoc(), reachingDefinition,
268+
getResult().getType(), dataLayout);
181269
rewriter.replaceAllUsesWith(getResult(), newResult);
182270
return DeletionKind::Delete;
183271
}
@@ -194,13 +282,15 @@ bool LLVM::StoreOp::canUsesBeRemoved(
194282
// store OF the slot pointer, only INTO the slot pointer.
195283
return blockingUse == slot.ptr && getAddr() == slot.ptr &&
196284
getValue() != slot.ptr &&
197-
areCastCompatible(dataLayout, slot.elemType, getValue().getType()) &&
285+
areConversionCompatible(dataLayout, slot.elemType,
286+
getValue().getType()) &&
198287
!getVolatile_();
199288
}
200289

201290
DeletionKind LLVM::StoreOp::removeBlockingUses(
202291
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
203-
RewriterBase &rewriter, Value reachingDefinition) {
292+
RewriterBase &rewriter, Value reachingDefinition,
293+
const DataLayout &dataLayout) {
204294
return DeletionKind::Delete;
205295
}
206296

@@ -747,8 +837,8 @@ bool LLVM::MemsetOp::storesTo(const MemorySlot &slot) {
747837
return getDst() == slot.ptr;
748838
}
749839

750-
Value LLVM::MemsetOp::getStored(const MemorySlot &slot,
751-
RewriterBase &rewriter) {
840+
Value LLVM::MemsetOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
841+
const DataLayout &dataLayout) {
752842
// TODO: Support non-integer types.
753843
return TypeSwitch<Type, Value>(slot.elemType)
754844
.Case([&](IntegerType intType) -> Value {
@@ -802,7 +892,8 @@ bool LLVM::MemsetOp::canUsesBeRemoved(
802892

803893
DeletionKind LLVM::MemsetOp::removeBlockingUses(
804894
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
805-
RewriterBase &rewriter, Value reachingDefinition) {
895+
RewriterBase &rewriter, Value reachingDefinition,
896+
const DataLayout &dataLayout) {
806897
return DeletionKind::Delete;
807898
}
808899

@@ -1059,8 +1150,8 @@ bool LLVM::MemcpyOp::storesTo(const MemorySlot &slot) {
10591150
return memcpyStoresTo(*this, slot);
10601151
}
10611152

1062-
Value LLVM::MemcpyOp::getStored(const MemorySlot &slot,
1063-
RewriterBase &rewriter) {
1153+
Value LLVM::MemcpyOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
1154+
const DataLayout &dataLayout) {
10641155
return memcpyGetStored(*this, slot, rewriter);
10651156
}
10661157

@@ -1074,7 +1165,8 @@ bool LLVM::MemcpyOp::canUsesBeRemoved(
10741165

10751166
DeletionKind LLVM::MemcpyOp::removeBlockingUses(
10761167
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1077-
RewriterBase &rewriter, Value reachingDefinition) {
1168+
RewriterBase &rewriter, Value reachingDefinition,
1169+
const DataLayout &dataLayout) {
10781170
return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
10791171
reachingDefinition);
10801172
}
@@ -1109,7 +1201,8 @@ bool LLVM::MemcpyInlineOp::storesTo(const MemorySlot &slot) {
11091201
}
11101202

11111203
Value LLVM::MemcpyInlineOp::getStored(const MemorySlot &slot,
1112-
RewriterBase &rewriter) {
1204+
RewriterBase &rewriter,
1205+
const DataLayout &dataLayout) {
11131206
return memcpyGetStored(*this, slot, rewriter);
11141207
}
11151208

@@ -1123,7 +1216,8 @@ bool LLVM::MemcpyInlineOp::canUsesBeRemoved(
11231216

11241217
DeletionKind LLVM::MemcpyInlineOp::removeBlockingUses(
11251218
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1126-
RewriterBase &rewriter, Value reachingDefinition) {
1219+
RewriterBase &rewriter, Value reachingDefinition,
1220+
const DataLayout &dataLayout) {
11271221
return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
11281222
reachingDefinition);
11291223
}
@@ -1159,8 +1253,8 @@ bool LLVM::MemmoveOp::storesTo(const MemorySlot &slot) {
11591253
return memcpyStoresTo(*this, slot);
11601254
}
11611255

1162-
Value LLVM::MemmoveOp::getStored(const MemorySlot &slot,
1163-
RewriterBase &rewriter) {
1256+
Value LLVM::MemmoveOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
1257+
const DataLayout &dataLayout) {
11641258
return memcpyGetStored(*this, slot, rewriter);
11651259
}
11661260

@@ -1174,7 +1268,8 @@ bool LLVM::MemmoveOp::canUsesBeRemoved(
11741268

11751269
DeletionKind LLVM::MemmoveOp::removeBlockingUses(
11761270
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1177-
RewriterBase &rewriter, Value reachingDefinition) {
1271+
RewriterBase &rewriter, Value reachingDefinition,
1272+
const DataLayout &dataLayout) {
11781273
return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
11791274
reachingDefinition);
11801275
}

mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,8 @@ bool memref::LoadOp::loadsFrom(const MemorySlot &slot) {
160160

161161
bool memref::LoadOp::storesTo(const MemorySlot &slot) { return false; }
162162

163-
Value memref::LoadOp::getStored(const MemorySlot &slot,
164-
RewriterBase &rewriter) {
163+
Value memref::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
164+
const DataLayout &dataLayout) {
165165
llvm_unreachable("getStored should not be called on LoadOp");
166166
}
167167

@@ -178,7 +178,8 @@ bool memref::LoadOp::canUsesBeRemoved(
178178

179179
DeletionKind memref::LoadOp::removeBlockingUses(
180180
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
181-
RewriterBase &rewriter, Value reachingDefinition) {
181+
RewriterBase &rewriter, Value reachingDefinition,
182+
const DataLayout &dataLayout) {
182183
// `canUsesBeRemoved` checked this blocking use must be the loaded slot
183184
// pointer.
184185
rewriter.replaceAllUsesWith(getResult(), reachingDefinition);
@@ -240,8 +241,8 @@ bool memref::StoreOp::storesTo(const MemorySlot &slot) {
240241
return getMemRef() == slot.ptr;
241242
}
242243

243-
Value memref::StoreOp::getStored(const MemorySlot &slot,
244-
RewriterBase &rewriter) {
244+
Value memref::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
245+
const DataLayout &dataLayout) {
245246
return getValue();
246247
}
247248

@@ -258,7 +259,8 @@ bool memref::StoreOp::canUsesBeRemoved(
258259

259260
DeletionKind memref::StoreOp::removeBlockingUses(
260261
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
261-
RewriterBase &rewriter, Value reachingDefinition) {
262+
RewriterBase &rewriter, Value reachingDefinition,
263+
const DataLayout &dataLayout) {
262264
return DeletionKind::Delete;
263265
}
264266

0 commit comments

Comments
 (0)