|
13 | 13 |
|
14 | 14 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
15 | 15 | #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
|
16 |
| -#include "mlir/IR/IRMapping.h" |
17 | 16 | #include "mlir/IR/Matchers.h"
|
18 | 17 | #include "mlir/IR/PatternMatch.h"
|
19 |
| -#include "mlir/IR/ValueRange.h" |
20 | 18 | #include "mlir/Interfaces/DataLayoutInterfaces.h"
|
21 | 19 | #include "mlir/Interfaces/MemorySlotInterfaces.h"
|
22 | 20 | #include "llvm/ADT/STLExtras.h"
|
@@ -71,12 +69,8 @@ SmallVector<DestructurableMemorySlot> LLVM::AllocaOp::getDestructurableSlots() {
|
71 | 69 | if (!destructuredType)
|
72 | 70 | return {};
|
73 | 71 |
|
74 |
| - DenseMap<Attribute, Type> allocaTypeMap; |
75 |
| - for (Attribute index : llvm::make_first_range(destructuredType.value())) |
76 |
| - allocaTypeMap.insert({index, LLVM::LLVMPointerType::get(getContext())}); |
77 |
| - |
78 |
| - return { |
79 |
| - DestructurableMemorySlot{{getResult(), getElemType()}, {allocaTypeMap}}}; |
| 72 | + return {DestructurableMemorySlot{{getResult(), getElemType()}, |
| 73 | + *destructuredType}}; |
80 | 74 | }
|
81 | 75 |
|
82 | 76 | DenseMap<Attribute, MemorySlot>
|
@@ -182,17 +176,107 @@ DeletionKind LLVM::StoreOp::removeBlockingUses(
|
182 | 176 | return DeletionKind::Delete;
|
183 | 177 | }
|
184 | 178 |
|
| 179 | +/// Checks if `slot` can be accessed through the provided access type. |
| 180 | +static bool isValidAccessType(const MemorySlot &slot, Type accessType, |
| 181 | + const DataLayout &dataLayout) { |
| 182 | + return dataLayout.getTypeSize(accessType) <= |
| 183 | + dataLayout.getTypeSize(slot.elemType); |
| 184 | +} |
| 185 | + |
185 | 186 | LogicalResult LLVM::LoadOp::ensureOnlySafeAccesses(
|
186 | 187 | const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
|
187 | 188 | const DataLayout &dataLayout) {
|
188 |
| - return success(getAddr() != slot.ptr || getType() == slot.elemType); |
| 189 | + return success(getAddr() != slot.ptr || |
| 190 | + isValidAccessType(slot, getType(), dataLayout)); |
189 | 191 | }
|
190 | 192 |
|
191 | 193 | LogicalResult LLVM::StoreOp::ensureOnlySafeAccesses(
|
192 | 194 | const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
|
193 | 195 | const DataLayout &dataLayout) {
|
194 | 196 | return success(getAddr() != slot.ptr ||
|
195 |
| - getValue().getType() == slot.elemType); |
| 197 | + isValidAccessType(slot, getValue().getType(), dataLayout)); |
| 198 | +} |
| 199 | + |
| 200 | +/// Returns the subslot's type at the requested index. |
| 201 | +static Type getTypeAtIndex(const DestructurableMemorySlot &slot, |
| 202 | + Attribute index) { |
| 203 | + auto subelementIndexMap = |
| 204 | + slot.elemType.cast<DestructurableTypeInterface>().getSubelementIndexMap(); |
| 205 | + if (!subelementIndexMap) |
| 206 | + return {}; |
| 207 | + assert(!subelementIndexMap->empty()); |
| 208 | + |
| 209 | + // Note: Returns a null-type when no entry was found. |
| 210 | + return subelementIndexMap->lookup(index); |
| 211 | +} |
| 212 | + |
| 213 | +bool LLVM::LoadOp::canRewire(const DestructurableMemorySlot &slot, |
| 214 | + SmallPtrSetImpl<Attribute> &usedIndices, |
| 215 | + SmallVectorImpl<MemorySlot> &mustBeSafelyUsed, |
| 216 | + const DataLayout &dataLayout) { |
| 217 | + if (getVolatile_()) |
| 218 | + return false; |
| 219 | + |
| 220 | + // A load always accesses the first element of the destructured slot. |
| 221 | + auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0); |
| 222 | + Type subslotType = getTypeAtIndex(slot, index); |
| 223 | + if (!subslotType) |
| 224 | + return false; |
| 225 | + |
| 226 | + // The access can only be replaced when the subslot is read within its bounds. |
| 227 | + if (dataLayout.getTypeSize(getType()) > dataLayout.getTypeSize(subslotType)) |
| 228 | + return false; |
| 229 | + |
| 230 | + usedIndices.insert(index); |
| 231 | + return true; |
| 232 | +} |
| 233 | + |
| 234 | +DeletionKind LLVM::LoadOp::rewire(const DestructurableMemorySlot &slot, |
| 235 | + DenseMap<Attribute, MemorySlot> &subslots, |
| 236 | + RewriterBase &rewriter, |
| 237 | + const DataLayout &dataLayout) { |
| 238 | + auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0); |
| 239 | + auto it = subslots.find(index); |
| 240 | + assert(it != subslots.end()); |
| 241 | + |
| 242 | + rewriter.modifyOpInPlace( |
| 243 | + *this, [&]() { getAddrMutable().set(it->getSecond().ptr); }); |
| 244 | + return DeletionKind::Keep; |
| 245 | +} |
| 246 | + |
| 247 | +bool LLVM::StoreOp::canRewire(const DestructurableMemorySlot &slot, |
| 248 | + SmallPtrSetImpl<Attribute> &usedIndices, |
| 249 | + SmallVectorImpl<MemorySlot> &mustBeSafelyUsed, |
| 250 | + const DataLayout &dataLayout) { |
| 251 | + if (getVolatile_()) |
| 252 | + return false; |
| 253 | + |
| 254 | + // A store always accesses the first element of the destructured slot. |
| 255 | + auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0); |
| 256 | + Type subslotType = getTypeAtIndex(slot, index); |
| 257 | + if (!subslotType) |
| 258 | + return false; |
| 259 | + |
| 260 | + // The access can only be replaced when the subslot is read within its bounds. |
| 261 | + if (dataLayout.getTypeSize(getValue().getType()) > |
| 262 | + dataLayout.getTypeSize(subslotType)) |
| 263 | + return false; |
| 264 | + |
| 265 | + usedIndices.insert(index); |
| 266 | + return true; |
| 267 | +} |
| 268 | + |
| 269 | +DeletionKind LLVM::StoreOp::rewire(const DestructurableMemorySlot &slot, |
| 270 | + DenseMap<Attribute, MemorySlot> &subslots, |
| 271 | + RewriterBase &rewriter, |
| 272 | + const DataLayout &dataLayout) { |
| 273 | + auto index = IntegerAttr::get(IntegerType::get(getContext(), 32), 0); |
| 274 | + auto it = subslots.find(index); |
| 275 | + assert(it != subslots.end()); |
| 276 | + |
| 277 | + rewriter.modifyOpInPlace( |
| 278 | + *this, [&]() { getAddrMutable().set(it->getSecond().ptr); }); |
| 279 | + return DeletionKind::Keep; |
196 | 280 | }
|
197 | 281 |
|
198 | 282 | //===----------------------------------------------------------------------===//
|
@@ -390,10 +474,8 @@ bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
|
390 | 474 | auto firstLevelIndex = dyn_cast<IntegerAttr>(getIndices()[1]);
|
391 | 475 | if (!firstLevelIndex)
|
392 | 476 | return false;
|
393 |
| - assert(slot.elementPtrs.contains(firstLevelIndex)); |
394 |
| - if (!llvm::isa<LLVM::LLVMPointerType>(slot.elementPtrs.at(firstLevelIndex))) |
395 |
| - return false; |
396 | 477 | mustBeSafelyUsed.emplace_back<MemorySlot>({getResult(), reachedType});
|
| 478 | + assert(slot.elementPtrs.contains(firstLevelIndex)); |
397 | 479 | usedIndices.insert(firstLevelIndex);
|
398 | 480 | return true;
|
399 | 481 | }
|
|
0 commit comments