Skip to content

[MLIR][Mem2Reg] Fix multi slot handling & move retry handling #91464

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion mlir/include/mlir/Interfaces/MemorySlotInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,12 @@ def PromotableAllocationOpInterface
Hook triggered once the promotion of a slot is complete. This can
also clean up the created default value if necessary.
This will only be called for slots declared by this operation.

Must return a new promotable allocation op if this operation produced
multiple promotable slots, nullopt otherwise.
}],
"void", "handlePromotionComplete",
"::std::optional<::mlir::PromotableAllocationOpInterface>",
"handlePromotionComplete",
(ins
"const ::mlir::MemorySlot &":$slot,
"::mlir::Value":$defaultValue,
Expand Down
6 changes: 3 additions & 3 deletions mlir/include/mlir/Transforms/Mem2Reg.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#ifndef MLIR_TRANSFORMS_MEM2REG_H
#define MLIR_TRANSFORMS_MEM2REG_H

#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "llvm/ADT/Statistic.h"

Expand All @@ -23,8 +22,9 @@ struct Mem2RegStatistics {
llvm::Statistic *newBlockArgumentAmount = nullptr;
};

/// Attempts to promote the memory slots of the provided allocators. Succeeds if
/// at least one memory slot was promoted.
/// Attempts to promote the memory slots of the provided allocators. Iteratively
/// retries the promotion of all slots as promoting one slot might enable
/// subsequent promotions. Succeeds if at least one memory slot was promoted.
LogicalResult
tryToPromoteMemorySlots(ArrayRef<PromotableAllocationOpInterface> allocators,
OpBuilder &builder, const DataLayout &dataLayout,
Expand Down
8 changes: 5 additions & 3 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,14 @@ void LLVM::AllocaOp::handleBlockArgument(const MemorySlot &slot,
declareOp.getLocationExpr());
}

void LLVM::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
Value defaultValue,
OpBuilder &builder) {
std::optional<PromotableAllocationOpInterface>
LLVM::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
Value defaultValue,
OpBuilder &builder) {
if (defaultValue && defaultValue.use_empty())
defaultValue.getDefiningOp()->erase();
this->erase();
return std::nullopt;
}

SmallVector<DestructurableMemorySlot> LLVM::AllocaOp::getDestructurableSlots() {
Expand Down
8 changes: 5 additions & 3 deletions mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,14 @@ Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot,
});
}

void memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
Value defaultValue,
OpBuilder &builder) {
/// Post-promotion cleanup for the memref alloca: erases the default value if
/// one exists and ended up without uses, then erases the alloca itself.
/// Always returns nullopt, i.e. no replacement allocator is reported to the
/// caller.
std::optional<PromotableAllocationOpInterface>
memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
                                          Value defaultValue,
                                          OpBuilder &builder) {
  // Guard against a null default value before querying its uses; a null
  // Value has no defining op to erase.
  if (defaultValue && defaultValue.use_empty())
    defaultValue.getDefiningOp()->erase();
  this->erase();
  return std::nullopt;
}

void memref::AllocaOp::handleBlockArgument(const MemorySlot &slot,
Expand Down
85 changes: 56 additions & 29 deletions mlir/lib/Transforms/Mem2Reg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,9 @@ class MemorySlotPromoter {
/// Actually promotes the slot by mutating IR. Promoting a slot DOES
/// invalidate the MemorySlotPromotionInfo of other slots. Preparation of
/// promotion info should NOT be performed in batches.
void promoteSlot();
/// Returns a promotable allocation op if a new allocator was created, nullopt
/// otherwise.
std::optional<PromotableAllocationOpInterface> promoteSlot();

private:
/// Computes the reaching definition for all the operations that require
Expand Down Expand Up @@ -595,7 +597,8 @@ void MemorySlotPromoter::removeBlockingUses() {
"after promotion, the slot pointer should not be used anymore");
}

void MemorySlotPromoter::promoteSlot() {
std::optional<PromotableAllocationOpInterface>
MemorySlotPromoter::promoteSlot() {
computeReachingDefInRegion(slot.ptr.getParentRegion(),
getOrCreateDefaultValue());

Expand All @@ -622,7 +625,7 @@ void MemorySlotPromoter::promoteSlot() {
if (statistics.promotedAmount)
(*statistics.promotedAmount)++;

allocator.handlePromotionComplete(slot, defaultValue, builder);
return allocator.handlePromotionComplete(slot, defaultValue, builder);
}

LogicalResult mlir::tryToPromoteMemorySlots(
Expand All @@ -636,20 +639,50 @@ LogicalResult mlir::tryToPromoteMemorySlots(
// lazily and cached to avoid expensive recomputation.
BlockIndexCache blockIndexCache;

for (PromotableAllocationOpInterface allocator : allocators) {
for (MemorySlot slot : allocator.getPromotableSlots()) {
if (slot.ptr.use_empty())
continue;

MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
if (info) {
MemorySlotPromoter(slot, allocator, builder, dominance, dataLayout,
std::move(*info), statistics, blockIndexCache)
.promoteSlot();
promotedAny = true;
SmallVector<PromotableAllocationOpInterface> workList(allocators.begin(),
allocators.end());

SmallVector<PromotableAllocationOpInterface> newWorkList;
newWorkList.reserve(workList.size());
while (true) {
bool changesInThisRound = false;
for (PromotableAllocationOpInterface allocator : workList) {
bool changedAllocator = false;
for (MemorySlot slot : allocator.getPromotableSlots()) {
if (slot.ptr.use_empty())
continue;

MemorySlotPromotionAnalyzer analyzer(slot, dominance, dataLayout);
std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo();
if (info) {
std::optional<PromotableAllocationOpInterface> newAllocator =
MemorySlotPromoter(slot, allocator, builder, dominance,
dataLayout, std::move(*info), statistics,
blockIndexCache)
.promoteSlot();
changedAllocator = true;
// Add newly created allocators to the worklist for further
// processing.
if (newAllocator)
newWorkList.push_back(*newAllocator);

// A break is required, since promoting a slot may invalidate the
// remaining slots of an allocator.
break;
}
}
if (!changedAllocator)
newWorkList.push_back(allocator);
changesInThisRound |= changedAllocator;
}
if (!changesInThisRound)
break;
promotedAny = true;

// Swap the vector's backing memory and clear the entries in newWorkList
// afterwards. This ensures that additional heap allocations can be avoided.
workList.swap(newWorkList);
newWorkList.clear();
}

return success(promotedAny);
Expand Down Expand Up @@ -677,22 +710,16 @@ struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {

OpBuilder builder(&region.front(), region.front().begin());

// Promoting a slot can allow for further promotion of other slots,
// promotion is tried until no promotion succeeds.
while (true) {
SmallVector<PromotableAllocationOpInterface> allocators;
// Build a list of allocators to attempt to promote the slots of.
region.walk([&](PromotableAllocationOpInterface allocator) {
allocators.emplace_back(allocator);
});

// Attempt promoting until no promotion succeeds.
if (failed(tryToPromoteMemorySlots(allocators, builder, dataLayout,
dominance, statistics)))
break;
SmallVector<PromotableAllocationOpInterface> allocators;
// Build a list of allocators to attempt to promote the slots of.
region.walk([&](PromotableAllocationOpInterface allocator) {
allocators.emplace_back(allocator);
});

// Attempt promoting as many of the slots as possible.
if (succeeded(tryToPromoteMemorySlots(allocators, builder, dataLayout,
dominance, statistics)))
changed = true;
}
}
if (!changed)
markAllAnalysesPreserved();
Expand Down
28 changes: 28 additions & 0 deletions mlir/test/Transforms/mem2reg.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(mem2reg))' --split-input-file | FileCheck %s

// Verifies that allocators with multiple slots are handled properly.

// CHECK-LABEL: func.func @multi_slot_alloca
func.func @multi_slot_alloca() -> (i32, i32) {
// Both results of the allocator are only read through memref.load below, and
// the test expects no allocator to remain in this function after the pass.
// CHECK-NOT: test.multi_slot_alloca
%1, %2 = test.multi_slot_alloca : () -> (memref<i32>, memref<i32>)
%3 = memref.load %1[] : memref<i32>
%4 = memref.load %2[] : memref<i32>
return %3, %4 : i32, i32
}

// -----

// Verifies that a multi slot allocator can be partially promoted.

func.func private @consumer(memref<i32>)

// CHECK-LABEL: func.func @multi_slot_alloca_only_second
func.func @multi_slot_alloca_only_second() -> (i32, i32) {
// The expected output still contains the allocator, but defining a single
// result (the pattern below matches exactly one SSA name).
// CHECK: %{{[[:alnum:]]+}} = test.multi_slot_alloca
%1, %2 = test.multi_slot_alloca : () -> (memref<i32>, memref<i32>)
// %1 is handed to @consumer in addition to being loaded; presumably this is
// what keeps the first slot from being promoted (confirm against the pass).
func.call @consumer(%1) : (memref<i32>) -> ()
%3 = memref.load %1[] : memref<i32>
%4 = memref.load %2[] : memref<i32>
return %3, %4 : i32, i32
}
59 changes: 59 additions & 0 deletions mlir/test/lib/Dialect/Test/TestOpDefs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"

using namespace mlir;
using namespace test;
Expand Down Expand Up @@ -1172,3 +1173,61 @@ void TestOpWithVersionedProperties::writeToMlirBytecode(
writer.writeVarInt(prop.value1);
writer.writeVarInt(prop.value2);
}

//===----------------------------------------------------------------------===//
// TestMultiSlotAlloca
//===----------------------------------------------------------------------===//

/// Exposes every result of this operation as its own memory slot. The slot's
/// element type is taken from the memref type of the corresponding result.
llvm::SmallVector<MemorySlot> TestMultiSlotAlloca::getPromotableSlots() {
  SmallVector<MemorySlot> promotable;
  for (Value memref : getResults()) {
    Type elemType = cast<MemRefType>(memref.getType()).getElementType();
    promotable.push_back(MemorySlot{memref, elemType});
  }
  return promotable;
}

/// Materializes the default value of a slot as a TestOpConstant located at this
/// operation's location, typed with the slot's element type and carrying the
/// i32 integer attribute 42.
Value TestMultiSlotAlloca::getDefaultValue(const MemorySlot &slot,
                                           OpBuilder &builder) {
  auto defaultAttr = builder.getI32IntegerAttr(42);
  return builder.create<TestOpConstant>(getLoc(), slot.elemType, defaultAttr);
}

/// Block-argument hook of the promotable allocation interface. Intentionally a
/// no-op for this test operation: none of the parameters are inspected.
void TestMultiSlotAlloca::handleBlockArgument(const MemorySlot &slot,
BlockArgument argument,
OpBuilder &builder) {
// Not relevant for testing.
}

std::optional<PromotableAllocationOpInterface>
TestMultiSlotAlloca::handlePromotionComplete(const MemorySlot &slot,
Value defaultValue,
OpBuilder &builder) {
if (defaultValue && defaultValue.use_empty())
defaultValue.getDefiningOp()->erase();

if (getNumResults() == 1) {
erase();
return std::nullopt;
}

SmallVector<Type> newTypes;
SmallVector<Value> remainingValues;

for (Value oldResult : getResults()) {
if (oldResult == slot.ptr)
continue;
remainingValues.push_back(oldResult);
newTypes.push_back(oldResult.getType());
}

OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPoint(*this);
auto replacement = builder.create<TestMultiSlotAlloca>(getLoc(), newTypes);
for (auto [oldResult, newResult] :
llvm::zip_equal(remainingValues, replacement.getResults()))
oldResult.replaceAllUsesWith(newResult);

erase();
return replacement;
}
1 change: 1 addition & 0 deletions mlir/test/lib/Dialect/Test/TestOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/MemorySlotInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
Expand Down
11 changes: 11 additions & 0 deletions mlir/test/lib/Dialect/Test/TestOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/InferIntRangeInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/MemorySlotInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"


Expand Down Expand Up @@ -3167,4 +3168,14 @@ def TestOpOptionallyImplementingInterface
let arguments = (ins BoolAttr:$implementsInterface);
}

//===----------------------------------------------------------------------===//
// Test Mem2Reg
//===----------------------------------------------------------------------===//

// Test-only operation `multi_slot_alloca` with a variadic number of results,
// each a memref of i32. It declares the methods of
// PromotableAllocationOpInterface so they can be implemented in C++, and uses
// the generic functional-type assembly format.
def TestMultiSlotAlloca : TEST_Op<"multi_slot_alloca",
[DeclareOpInterfaceMethods<PromotableAllocationOpInterface>]> {
let results = (outs Variadic<MemRefOf<[I32]>>:$results);
let assemblyFormat = "attr-dict `:` functional-type(operands, results)";
}

#endif // TEST_OPS
Loading