Skip to content

Commit 63d22f7

Browse files
authored
[MLIR][LLVM][SROA] Make GEP handling type agnostic (#86950)
This commit removes SROA's type consistency constraints from the LLVM dialect's GEPOp. Valid indexing is now checked purely by computing the GEP's byte offset with the aid of the data layout. To simplify the handling of "nested subslots", SROA is handed memory slots that hold byte array types. This ensures that subsequent accesses only need to check that they stay in-bounds, which lifts the requirement of determining the sub-types for all but the first level of subslots.
1 parent eb08c0f commit 63d22f7

File tree

2 files changed

+312
-59
lines changed

2 files changed

+312
-59
lines changed

mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp

Lines changed: 182 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
#include "llvm/ADT/STLExtras.h"
2121
#include "llvm/ADT/TypeSwitch.h"
2222

23+
#define DEBUG_TYPE "sroa"
24+
2325
using namespace mlir;
2426

2527
//===----------------------------------------------------------------------===//
@@ -431,29 +433,165 @@ DeletionKind LLVM::GEPOp::removeBlockingUses(
431433
return DeletionKind::Delete;
432434
}
433435

434-
static bool isFirstIndexZero(LLVM::GEPOp gep) {
435-
IntegerAttr index =
436-
llvm::dyn_cast_if_present<IntegerAttr>(gep.getIndices()[0]);
437-
return index && index.getInt() == 0;
436+
/// Returns the amount of bytes the provided GEP elements will offset the
437+
/// pointer by. Returns nullopt if no constant offset could be computed.
438+
static std::optional<uint64_t> gepToByteOffset(const DataLayout &dataLayout,
439+
LLVM::GEPOp gep) {
440+
// Collects all indices.
441+
SmallVector<uint64_t> indices;
442+
for (auto index : gep.getIndices()) {
443+
auto constIndex = dyn_cast<IntegerAttr>(index);
444+
if (!constIndex)
445+
return {};
446+
int64_t gepIndex = constIndex.getInt();
447+
// Negative indices are not supported.
448+
if (gepIndex < 0)
449+
return {};
450+
indices.push_back(gepIndex);
451+
}
452+
453+
Type currentType = gep.getElemType();
454+
uint64_t offset = indices[0] * dataLayout.getTypeSize(currentType);
455+
456+
for (uint64_t index : llvm::drop_begin(indices)) {
457+
bool shouldCancel =
458+
TypeSwitch<Type, bool>(currentType)
459+
.Case([&](LLVM::LLVMArrayType arrayType) {
460+
offset +=
461+
index * dataLayout.getTypeSize(arrayType.getElementType());
462+
currentType = arrayType.getElementType();
463+
return false;
464+
})
465+
.Case([&](LLVM::LLVMStructType structType) {
466+
ArrayRef<Type> body = structType.getBody();
467+
assert(index < body.size() && "expected valid struct indexing");
468+
for (uint32_t i : llvm::seq(index)) {
469+
if (!structType.isPacked())
470+
offset = llvm::alignTo(
471+
offset, dataLayout.getTypeABIAlignment(body[i]));
472+
offset += dataLayout.getTypeSize(body[i]);
473+
}
474+
475+
// Align for the current type as well.
476+
if (!structType.isPacked())
477+
offset = llvm::alignTo(
478+
offset, dataLayout.getTypeABIAlignment(body[index]));
479+
currentType = body[index];
480+
return false;
481+
})
482+
.Default([&](Type type) {
483+
LLVM_DEBUG(llvm::dbgs()
484+
<< "[sroa] Unsupported type for offset computations"
485+
<< type << "\n");
486+
return true;
487+
});
488+
489+
if (shouldCancel)
490+
return std::nullopt;
491+
}
492+
493+
return offset;
494+
}
495+
496+
namespace {
/// Describes where an access lands inside of a slot: which element of the
/// slot's aggregate type it hits, together with the matching byte offset.
struct SubslotAccessInfo {
  /// Index, within the parent slot, of the subslot the access falls into.
  uint32_t index;
  /// Byte offset of the access, counted from the start of that subslot.
  uint64_t subslotOffset;
};
} // namespace
506+
507+
/// Computes subslot access information for an access into `slot` with the given
508+
/// offset.
509+
/// Returns nullopt when the offset is out-of-bounds or when the access is into
510+
/// the padding of `slot`.
511+
static std::optional<SubslotAccessInfo>
512+
getSubslotAccessInfo(const DestructurableMemorySlot &slot,
513+
const DataLayout &dataLayout, LLVM::GEPOp gep) {
514+
std::optional<uint64_t> offset = gepToByteOffset(dataLayout, gep);
515+
if (!offset)
516+
return {};
517+
518+
// Helper to check that a constant index is in the bounds of the GEP index
519+
// representation. LLVM dialects's GEP arguments have a limited bitwidth, thus
520+
// this additional check is necessary.
521+
auto isOutOfBoundsGEPIndex = [](uint64_t index) {
522+
return index >= (1 << LLVM::kGEPConstantBitWidth);
523+
};
524+
525+
Type type = slot.elemType;
526+
if (*offset >= dataLayout.getTypeSize(type))
527+
return {};
528+
return TypeSwitch<Type, std::optional<SubslotAccessInfo>>(type)
529+
.Case([&](LLVM::LLVMArrayType arrayType)
530+
-> std::optional<SubslotAccessInfo> {
531+
// Find which element of the array contains the offset.
532+
uint64_t elemSize = dataLayout.getTypeSize(arrayType.getElementType());
533+
uint64_t index = *offset / elemSize;
534+
if (isOutOfBoundsGEPIndex(index))
535+
return {};
536+
return SubslotAccessInfo{static_cast<uint32_t>(index),
537+
*offset - (index * elemSize)};
538+
})
539+
.Case([&](LLVM::LLVMStructType structType)
540+
-> std::optional<SubslotAccessInfo> {
541+
uint64_t distanceToStart = 0;
542+
// Walk over the elements of the struct to find in which of
543+
// them the offset is.
544+
for (auto [index, elem] : llvm::enumerate(structType.getBody())) {
545+
uint64_t elemSize = dataLayout.getTypeSize(elem);
546+
if (!structType.isPacked()) {
547+
distanceToStart = llvm::alignTo(
548+
distanceToStart, dataLayout.getTypeABIAlignment(elem));
549+
// If the offset is in padding, cancel the rewrite.
550+
if (offset < distanceToStart)
551+
return {};
552+
}
553+
554+
if (offset < distanceToStart + elemSize) {
555+
if (isOutOfBoundsGEPIndex(index))
556+
return {};
557+
// The offset is within this element, stop iterating the
558+
// struct and return the index.
559+
return SubslotAccessInfo{static_cast<uint32_t>(index),
560+
*offset - distanceToStart};
561+
}
562+
563+
// The offset is not within this element, continue walking
564+
// over the struct.
565+
distanceToStart += elemSize;
566+
}
567+
568+
return {};
569+
});
570+
}
571+
572+
/// Constructs a byte array type of the given size.
573+
static LLVM::LLVMArrayType getByteArrayType(MLIRContext *context,
574+
unsigned size) {
575+
auto byteType = IntegerType::get(context, 8);
576+
return LLVM::LLVMArrayType::get(context, byteType, size);
438577
}
439578

440579
/// Checks that this GEP only produces a pointer into `slot`. GEPs based on a
/// different pointer succeed without further checks. Otherwise, the constant
/// byte offset of the GEP has to lie strictly inside of the slot, and the
/// result pointer is registered in `mustBeSafelyUsed` as a slot typed as a
/// byte array spanning the bytes that remain after the offset.
LogicalResult LLVM::GEPOp::ensureOnlySafeAccesses(
    const MemorySlot &slot, SmallVectorImpl<MemorySlot> &mustBeSafelyUsed,
    const DataLayout &dataLayout) {
  if (getBase() != slot.ptr)
    return success();

  // Without a constant offset the access cannot be proven safe.
  std::optional<uint64_t> byteOffset = gepToByteOffset(dataLayout, *this);
  if (!byteOffset)
    return failure();

  // The offset has to point strictly inside of the slot.
  const uint64_t slotBytes = dataLayout.getTypeSize(slot.elemType);
  const bool startsInsideSlot = *byteOffset < slotBytes;
  if (!startsInsideSlot)
    return failure();

  // Any access that stays within the remaining bytes is considered legal, so
  // users of the result only have to respect a byte array of that size.
  MemorySlot remainder{getRes(),
                       getByteArrayType(getContext(), slotBytes - *byteOffset)};
  mustBeSafelyUsed.push_back(remainder);
  return success();
}
459597

@@ -464,60 +602,45 @@ bool LLVM::GEPOp::canRewire(const DestructurableMemorySlot &slot,
464602
if (!isa<LLVM::LLVMPointerType>(getBase().getType()))
465603
return false;
466604

467-
if (getBase() != slot.ptr || slot.elemType != getElemType())
468-
return false;
469-
if (!isFirstIndexZero(*this))
470-
return false;
471-
// Dynamic indices can be out-of-bounds (even negative), so an access with
472-
// dynamic indices can never be properly rewired.
473-
if (!getDynamicIndices().empty())
474-
return false;
475-
Type reachedType = getResultPtrElementType();
476-
if (!reachedType || getIndices().size() < 2)
605+
if (getBase() != slot.ptr)
477606
return false;
478-
auto firstLevelIndex = dyn_cast<IntegerAttr>(getIndices()[1]);
479-
if (!firstLevelIndex)
607+
std::optional<SubslotAccessInfo> accessInfo =
608+
getSubslotAccessInfo(slot, dataLayout, *this);
609+
if (!accessInfo)
480610
return false;
481-
mustBeSafelyUsed.emplace_back<MemorySlot>({getResult(), reachedType});
482-
assert(slot.elementPtrs.contains(firstLevelIndex));
483-
usedIndices.insert(firstLevelIndex);
611+
auto indexAttr =
612+
IntegerAttr::get(IntegerType::get(getContext(), 32), accessInfo->index);
613+
assert(slot.elementPtrs.contains(indexAttr));
614+
usedIndices.insert(indexAttr);
615+
616+
// The remainder of the subslot should be accessed in-bounds. Thus, we create
617+
// a dummy slot with the size of the remainder.
618+
Type subslotType = slot.elementPtrs.lookup(indexAttr);
619+
uint64_t slotSize = dataLayout.getTypeSize(subslotType);
620+
LLVM::LLVMArrayType remainingSlotType =
621+
getByteArrayType(getContext(), slotSize - accessInfo->subslotOffset);
622+
mustBeSafelyUsed.emplace_back<MemorySlot>({getRes(), remainingSlotType});
623+
484624
return true;
485625
}
486626

487627
/// Redirects this GEP to the subslot that its access falls into. All uses of
/// the result are replaced by a byte-typed GEP (or its folded value) that
/// offsets the subslot's pointer by the access' offset into the subslot. The
/// original GEP is reported for deletion.
DeletionKind LLVM::GEPOp::rewire(const DestructurableMemorySlot &slot,
                                 DenseMap<Attribute, MemorySlot> &subslots,
                                 RewriterBase &rewriter,
                                 const DataLayout &dataLayout) {
  std::optional<SubslotAccessInfo> accessInfo =
      getSubslotAccessInfo(slot, dataLayout, *this);
  assert(accessInfo && "expected access info to be checked before");

  // Subslots are keyed by the 32-bit integer attribute of their index.
  Type indexType = IntegerType::get(getContext(), 32);
  auto subslotKey = IntegerAttr::get(indexType, accessInfo->index);
  const MemorySlot &targetSlot = subslots.at(subslotKey);

  // Address the accessed location as a byte offset from the subslot pointer.
  Type i8Type = IntegerType::get(rewriter.getContext(), 8);
  Value rewiredPtr = rewriter.createOrFold<LLVM::GEPOp>(
      getLoc(), getResult().getType(), i8Type, targetSlot.ptr,
      ArrayRef<GEPArg>(accessInfo->subslotOffset), getInbounds());
  rewriter.replaceAllUsesWith(getResult(), rewiredPtr);
  return DeletionKind::Delete;
}
522645

523646
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)