Skip to content

Commit 5e37eeb

Browse files
committed
[MLIR][Mem2Reg][LLVM] Enhance partial load support
This commit improves the LLVM dialect's Mem2Reg interfaces to support promotion of partial loads from larger memory slots. To support this, the Mem2Reg interface methods are extended with additional data layout parameters. The data layout is required to determine type sizes to produce correct conversion sequences.
1 parent ee284d2 commit 5e37eeb

File tree

5 files changed

+270
-70
lines changed

5 files changed

+270
-70
lines changed

mlir/include/mlir/Interfaces/MemorySlotInterfaces.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,8 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
127127
}],
128128
"::mlir::Value", "getStored",
129129
(ins "const ::mlir::MemorySlot &":$slot,
130-
"::mlir::RewriterBase &":$rewriter)
130+
"::mlir::RewriterBase &":$rewriter,
131+
"const ::mlir::DataLayout &":$dataLayout)
131132
>,
132133
InterfaceMethod<[{
133134
Checks that this operation can be promoted to no longer use the provided
@@ -172,7 +173,8 @@ def PromotableMemOpInterface : OpInterface<"PromotableMemOpInterface"> {
172173
(ins "const ::mlir::MemorySlot &":$slot,
173174
"const ::llvm::SmallPtrSetImpl<mlir::OpOperand *> &":$blockingUses,
174175
"::mlir::RewriterBase &":$rewriter,
175-
"::mlir::Value":$reachingDefinition)
176+
"::mlir::Value":$reachingDefinition,
177+
"const ::mlir::DataLayout &":$dataLayout)
176178
>,
177179
];
178180
}

mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp

Lines changed: 136 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,8 @@ bool LLVM::LoadOp::loadsFrom(const MemorySlot &slot) {
112112

113113
bool LLVM::LoadOp::storesTo(const MemorySlot &slot) { return false; }
114114

115-
Value LLVM::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter) {
115+
Value LLVM::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
116+
const DataLayout &dataLayout) {
116117
llvm_unreachable("getStored should not be called on LoadOp");
117118
}
118119

@@ -122,37 +123,124 @@ bool LLVM::StoreOp::storesTo(const MemorySlot &slot) {
122123
return getAddr() == slot.ptr;
123124
}
124125

125-
/// Checks that two types are the same or can be cast into one another.
126-
static bool areCastCompatible(const DataLayout &layout, Type lhs, Type rhs) {
127-
return lhs == rhs || (!isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(lhs) &&
128-
!isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(rhs) &&
129-
layout.getTypeSize(lhs) == layout.getTypeSize(rhs));
126+
/// Checks if `type` can be used in any kind of conversion sequence.
127+
static bool isSupportedTypeForConversion(Type type) {
128+
// Aggregate types are not bitcastable.
129+
if (isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(type))
130+
return false;
131+
132+
// LLVM vector types are only used for either pointers or target specific
133+
// types. These types cannot be cast in the general case, thus the memory
134+
// optimizations do not support them.
135+
if (isa<LLVM::LLVMFixedVectorType, LLVM::LLVMScalableVectorType>(type))
136+
return false;
137+
138+
// Scalable types are not supported.
139+
if (auto vectorType = dyn_cast<VectorType>(type))
140+
return !vectorType.isScalable();
141+
return true;
142+
}
143+
144+
/// Checks that `rhs` can be converted to `lhs` by a sequence of casts and
145+
/// truncations.
146+
static bool areConversionCompatible(const DataLayout &layout, Type lhs,
147+
Type rhs) {
148+
if (lhs == rhs)
149+
return true;
150+
151+
// Unsupported types (e.g. aggregates) cannot be cast.
152+
if (!isSupportedTypeForConversion(lhs) || !isSupportedTypeForConversion(rhs))
153+
return false;
154+
return layout.getTypeSize(lhs) <= layout.getTypeSize(rhs);
130155
}
131156

157+
/// Checks if `dataLayout` describes a little endian layout.
158+
static bool isLittleEndian(const DataLayout &dataLayout) {
159+
auto endiannessStr = dyn_cast_or_null<StringAttr>(dataLayout.getEndianness());
160+
return !endiannessStr || endiannessStr == "little";
161+
}
162+
163+
/// The size of a byte in bits.
164+
constexpr const static uint64_t kBitsInByte = 8;
165+
132166
/// Constructs operations that convert `inputValue` into a new value of type
133167
/// `targetType`. Assumes that this conversion is possible.
134168
static Value createConversionSequence(RewriterBase &rewriter, Location loc,
135-
Value inputValue, Type targetType) {
136-
if (inputValue.getType() == targetType)
137-
return inputValue;
169+
Value srcValue, Type targetType,
170+
const DataLayout &dataLayout) {
171+
// Get the types of the source and destination values.
172+
Type srcType = srcValue.getType();
173+
174+
uint64_t srcTypeSize = dataLayout.getTypeSize(srcType);
175+
uint64_t targetTypeSize = dataLayout.getTypeSize(targetType);
176+
177+
// Nothing has to be done if the types are already the same.
178+
if (srcType == targetType)
179+
return srcValue;
180+
181+
// The code below is currently not capable of handling aggregate types as it
182+
// makes use of bitcasts. Aggregates cannot be bitcast.
183+
// TODO: We should have a `LLVMAggregateType` base class to easily perform
184+
// this `isa`.
185+
if (isa<LLVM::LLVMArrayType, LLVM::LLVMStructType>(srcType) ||
186+
isa<LLVM::LLVMArrayType, LLVM::LLVMStructType>(targetType))
187+
return nullptr;
188+
189+
// In the special case of casting one pointer to another, we want to generate
190+
// an address space cast. Bitcasts of pointers are not allowed and using
191+
// pointer to integer conversions are not equivalent due to the loss of
192+
// provenance.
193+
if (isa<LLVM::LLVMPointerType>(targetType) &&
194+
isa<LLVM::LLVMPointerType>(srcType)) {
195+
// Abort the conversion if the pointers have different bitwidths.
196+
if (srcTypeSize != targetTypeSize)
197+
return nullptr;
198+
return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
199+
srcValue);
200+
}
138201

139-
if (!isa<LLVM::LLVMPointerType>(targetType) &&
140-
!isa<LLVM::LLVMPointerType>(inputValue.getType()))
141-
return rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, inputValue);
202+
IntegerType valueSizeInteger =
203+
rewriter.getIntegerType(srcTypeSize * kBitsInByte);
204+
Value replacement = srcValue;
205+
206+
// First, cast the value to a same-sized integer type.
207+
if (isa<LLVM::LLVMPointerType>(srcType))
208+
replacement = rewriter.createOrFold<LLVM::PtrToIntOp>(loc, valueSizeInteger,
209+
replacement);
210+
else if (replacement.getType() != valueSizeInteger)
211+
replacement = rewriter.createOrFold<LLVM::BitcastOp>(loc, valueSizeInteger,
212+
replacement);
213+
214+
// Truncate the integer if the size of the read is less than the value.
215+
if (targetTypeSize != srcTypeSize) {
216+
if (!isLittleEndian(dataLayout)) {
217+
uint64_t shiftAmount = (srcTypeSize - targetTypeSize) * kBitsInByte;
218+
auto shiftConstant = rewriter.create<LLVM::ConstantOp>(
219+
loc, rewriter.getIntegerAttr(srcType, shiftAmount));
220+
replacement =
221+
rewriter.createOrFold<LLVM::LShrOp>(loc, srcValue, shiftConstant);
222+
}
142223

143-
if (!isa<LLVM::LLVMPointerType>(targetType))
144-
return rewriter.createOrFold<LLVM::PtrToIntOp>(loc, targetType, inputValue);
224+
replacement = rewriter.create<LLVM::TruncOp>(
225+
loc, rewriter.getIntegerType(targetTypeSize * kBitsInByte),
226+
replacement);
227+
}
145228

146-
if (!isa<LLVM::LLVMPointerType>(inputValue.getType()))
147-
return rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, inputValue);
229+
// Now cast the integer to the actual destination type if required.
230+
if (isa<LLVM::LLVMPointerType>(targetType))
231+
replacement =
232+
rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, replacement);
233+
else if (replacement.getType() != targetType)
234+
replacement =
235+
rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, replacement);
148236

149-
return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
150-
inputValue);
237+
return replacement;
151238
}
152239

153-
Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter) {
154-
return createConversionSequence(rewriter, getLoc(), getValue(),
155-
slot.elemType);
240+
Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
241+
const DataLayout &dataLayout) {
242+
return createConversionSequence(rewriter, getLoc(), getValue(), slot.elemType,
243+
dataLayout);
156244
}
157245

158246
bool LLVM::LoadOp::canUsesBeRemoved(
@@ -167,17 +255,20 @@ bool LLVM::LoadOp::canUsesBeRemoved(
167255
// be removed (provided it loads the exact stored value and is not
168256
// volatile).
169257
return blockingUse == slot.ptr && getAddr() == slot.ptr &&
170-
areCastCompatible(dataLayout, getResult().getType(), slot.elemType) &&
258+
areConversionCompatible(dataLayout, getResult().getType(),
259+
slot.elemType) &&
171260
!getVolatile_();
172261
}
173262

174263
DeletionKind LLVM::LoadOp::removeBlockingUses(
175264
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
176-
RewriterBase &rewriter, Value reachingDefinition) {
265+
RewriterBase &rewriter, Value reachingDefinition,
266+
const DataLayout &dataLayout) {
177267
// `canUsesBeRemoved` checked this blocking use must be the loaded slot
178268
// pointer.
179-
Value newResult = createConversionSequence(
180-
rewriter, getLoc(), reachingDefinition, getResult().getType());
269+
Value newResult =
270+
createConversionSequence(rewriter, getLoc(), reachingDefinition,
271+
getResult().getType(), dataLayout);
181272
rewriter.replaceAllUsesWith(getResult(), newResult);
182273
return DeletionKind::Delete;
183274
}
@@ -194,13 +285,15 @@ bool LLVM::StoreOp::canUsesBeRemoved(
194285
// store OF the slot pointer, only INTO the slot pointer.
195286
return blockingUse == slot.ptr && getAddr() == slot.ptr &&
196287
getValue() != slot.ptr &&
197-
areCastCompatible(dataLayout, slot.elemType, getValue().getType()) &&
288+
areConversionCompatible(dataLayout, slot.elemType,
289+
getValue().getType()) &&
198290
!getVolatile_();
199291
}
200292

201293
DeletionKind LLVM::StoreOp::removeBlockingUses(
202294
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
203-
RewriterBase &rewriter, Value reachingDefinition) {
295+
RewriterBase &rewriter, Value reachingDefinition,
296+
const DataLayout &dataLayout) {
204297
return DeletionKind::Delete;
205298
}
206299

@@ -747,8 +840,8 @@ bool LLVM::MemsetOp::storesTo(const MemorySlot &slot) {
747840
return getDst() == slot.ptr;
748841
}
749842

750-
Value LLVM::MemsetOp::getStored(const MemorySlot &slot,
751-
RewriterBase &rewriter) {
843+
Value LLVM::MemsetOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
844+
const DataLayout &dataLayout) {
752845
// TODO: Support non-integer types.
753846
return TypeSwitch<Type, Value>(slot.elemType)
754847
.Case([&](IntegerType intType) -> Value {
@@ -802,7 +895,8 @@ bool LLVM::MemsetOp::canUsesBeRemoved(
802895

803896
DeletionKind LLVM::MemsetOp::removeBlockingUses(
804897
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
805-
RewriterBase &rewriter, Value reachingDefinition) {
898+
RewriterBase &rewriter, Value reachingDefinition,
899+
const DataLayout &dataLayout) {
806900
return DeletionKind::Delete;
807901
}
808902

@@ -1059,8 +1153,8 @@ bool LLVM::MemcpyOp::storesTo(const MemorySlot &slot) {
10591153
return memcpyStoresTo(*this, slot);
10601154
}
10611155

1062-
Value LLVM::MemcpyOp::getStored(const MemorySlot &slot,
1063-
RewriterBase &rewriter) {
1156+
Value LLVM::MemcpyOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
1157+
const DataLayout &dataLayout) {
10641158
return memcpyGetStored(*this, slot, rewriter);
10651159
}
10661160

@@ -1074,7 +1168,8 @@ bool LLVM::MemcpyOp::canUsesBeRemoved(
10741168

10751169
DeletionKind LLVM::MemcpyOp::removeBlockingUses(
10761170
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1077-
RewriterBase &rewriter, Value reachingDefinition) {
1171+
RewriterBase &rewriter, Value reachingDefinition,
1172+
const DataLayout &dataLayout) {
10781173
return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
10791174
reachingDefinition);
10801175
}
@@ -1109,7 +1204,8 @@ bool LLVM::MemcpyInlineOp::storesTo(const MemorySlot &slot) {
11091204
}
11101205

11111206
Value LLVM::MemcpyInlineOp::getStored(const MemorySlot &slot,
1112-
RewriterBase &rewriter) {
1207+
RewriterBase &rewriter,
1208+
const DataLayout &dataLayout) {
11131209
return memcpyGetStored(*this, slot, rewriter);
11141210
}
11151211

@@ -1123,7 +1219,8 @@ bool LLVM::MemcpyInlineOp::canUsesBeRemoved(
11231219

11241220
DeletionKind LLVM::MemcpyInlineOp::removeBlockingUses(
11251221
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1126-
RewriterBase &rewriter, Value reachingDefinition) {
1222+
RewriterBase &rewriter, Value reachingDefinition,
1223+
const DataLayout &dataLayout) {
11271224
return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
11281225
reachingDefinition);
11291226
}
@@ -1159,8 +1256,8 @@ bool LLVM::MemmoveOp::storesTo(const MemorySlot &slot) {
11591256
return memcpyStoresTo(*this, slot);
11601257
}
11611258

1162-
Value LLVM::MemmoveOp::getStored(const MemorySlot &slot,
1163-
RewriterBase &rewriter) {
1259+
Value LLVM::MemmoveOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
1260+
const DataLayout &dataLayout) {
11641261
return memcpyGetStored(*this, slot, rewriter);
11651262
}
11661263

@@ -1174,7 +1271,8 @@ bool LLVM::MemmoveOp::canUsesBeRemoved(
11741271

11751272
DeletionKind LLVM::MemmoveOp::removeBlockingUses(
11761273
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1177-
RewriterBase &rewriter, Value reachingDefinition) {
1274+
RewriterBase &rewriter, Value reachingDefinition,
1275+
const DataLayout &dataLayout) {
11781276
return memcpyRemoveBlockingUses(*this, slot, blockingUses, rewriter,
11791277
reachingDefinition);
11801278
}

mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,8 @@ bool memref::LoadOp::loadsFrom(const MemorySlot &slot) {
160160

161161
bool memref::LoadOp::storesTo(const MemorySlot &slot) { return false; }
162162

163-
Value memref::LoadOp::getStored(const MemorySlot &slot,
164-
RewriterBase &rewriter) {
163+
Value memref::LoadOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
164+
const DataLayout &dataLayout) {
165165
llvm_unreachable("getStored should not be called on LoadOp");
166166
}
167167

@@ -178,7 +178,8 @@ bool memref::LoadOp::canUsesBeRemoved(
178178

179179
DeletionKind memref::LoadOp::removeBlockingUses(
180180
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
181-
RewriterBase &rewriter, Value reachingDefinition) {
181+
RewriterBase &rewriter, Value reachingDefinition,
182+
const DataLayout &dataLayout) {
182183
// `canUsesBeRemoved` checked this blocking use must be the loaded slot
183184
// pointer.
184185
rewriter.replaceAllUsesWith(getResult(), reachingDefinition);
@@ -240,8 +241,8 @@ bool memref::StoreOp::storesTo(const MemorySlot &slot) {
240241
return getMemRef() == slot.ptr;
241242
}
242243

243-
Value memref::StoreOp::getStored(const MemorySlot &slot,
244-
RewriterBase &rewriter) {
244+
Value memref::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter,
245+
const DataLayout &dataLayout) {
245246
return getValue();
246247
}
247248

@@ -258,7 +259,8 @@ bool memref::StoreOp::canUsesBeRemoved(
258259

259260
DeletionKind memref::StoreOp::removeBlockingUses(
260261
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
261-
RewriterBase &rewriter, Value reachingDefinition) {
262+
RewriterBase &rewriter, Value reachingDefinition,
263+
const DataLayout &dataLayout) {
262264
return DeletionKind::Delete;
263265
}
264266

0 commit comments

Comments
 (0)