Skip to content

Commit a75c52d

Browse files
committed
Fix block merging
1 parent 0cfd03a commit a75c52d

File tree

3 files changed

+137
-17
lines changed

3 files changed

+137
-17
lines changed

mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -463,10 +463,15 @@ struct BufferDeallocationSimplificationPass
463463
SplitDeallocWhenNotAliasingAnyOther,
464464
RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
465465
analysis);
466+
// We don't want the block structure to change and invalidate the
467+
// `BufferOriginAnalysis` so we apply the rewrites with a `Normal` level of
468+
// region simplification
469+
GreedyRewriteConfig config;
470+
config.enableRegionSimplification = GreedySimplifyRegionLevel::Normal;
466471
populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
467472

468-
if (failed(
469-
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
473+
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
474+
config)))
470475
signalPassFailure();
471476
}
472477
};

mlir/lib/Transforms/Utils/RegionUtils.cpp

Lines changed: 127 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,23 @@
99
#include "mlir/Transforms/RegionUtils.h"
1010
#include "mlir/Analysis/TopologicalSortUtils.h"
1111
#include "mlir/IR/Block.h"
12+
#include "mlir/IR/BuiltinOps.h"
1213
#include "mlir/IR/IRMapping.h"
1314
#include "mlir/IR/Operation.h"
1415
#include "mlir/IR/PatternMatch.h"
1516
#include "mlir/IR/RegionGraphTraits.h"
1617
#include "mlir/IR/Value.h"
1718
#include "mlir/Interfaces/ControlFlowInterfaces.h"
1819
#include "mlir/Interfaces/SideEffectInterfaces.h"
20+
#include "mlir/Support/LogicalResult.h"
1921

2022
#include "llvm/ADT/DepthFirstIterator.h"
2123
#include "llvm/ADT/PostOrderIterator.h"
24+
#include "llvm/ADT/STLExtras.h"
25+
#include "llvm/ADT/SmallSet.h"
2226

2327
#include <deque>
28+
#include <iterator>
2429

2530
using namespace mlir;
2631

@@ -699,9 +704,8 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
699704
blockIterators.push_back(mergeBlock->begin());
700705

701706
// Update each of the predecessor terminators with the new arguments.
702-
SmallVector<SmallVector<Value, 8>, 2> newArguments(
703-
1 + blocksToMerge.size(),
704-
SmallVector<Value, 8>(operandsToMerge.size()));
707+
SmallVector<SmallVector<Value, 8>, 2> newArguments(1 + blocksToMerge.size(),
708+
SmallVector<Value, 8>());
705709
unsigned curOpIndex = 0;
706710
for (const auto &it : llvm::enumerate(operandsToMerge)) {
707711
unsigned nextOpOffset = it.value().first - curOpIndex;
@@ -712,13 +716,22 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
712716
Block::iterator &blockIter = blockIterators[i];
713717
std::advance(blockIter, nextOpOffset);
714718
auto &operand = blockIter->getOpOperand(it.value().second);
715-
newArguments[i][it.index()] = operand.get();
716-
717-
// Update the operand and insert an argument if this is the leader.
718-
if (i == 0) {
719-
Value operandVal = operand.get();
720-
operand.set(leaderBlock->addArgument(operandVal.getType(),
721-
operandVal.getLoc()));
719+
Value operandVal = operand.get();
720+
Value *it = std::find(newArguments[i].begin(), newArguments[i].end(),
721+
operandVal);
722+
if (it == newArguments[i].end()) {
723+
newArguments[i].push_back(operandVal);
724+
// Update the operand and insert an argument if this is the leader.
725+
if (i == 0) {
726+
operand.set(leaderBlock->addArgument(operandVal.getType(),
727+
operandVal.getLoc()));
728+
}
729+
} else if (i == 0) {
730+
// If this is the leader, update the operand but do not insert a new
731+
// argument. Instead, the operand should point to one of the
732+
// arguments we already passed (and that contained `operandVal`)
733+
operand.set(leaderBlock->getArgument(
734+
std::distance(newArguments[i].begin(), it)));
722735
}
723736
}
724737
}
@@ -818,6 +831,104 @@ static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
818831
return success(anyChanged);
819832
}
820833

834+
/// Drops every argument of `block` for which all predecessors forward the
/// very same value. Uses of such an argument are redirected to that value,
/// then the argument is erased from the block and the matching operand is
/// erased from each predecessor's branch terminator. Returns success if at
/// least one argument was dropped, failure otherwise.
static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
                                            Block &block) {
  SmallVector<size_t> redundantArgs;

  // Inspect the block arguments one at a time.
  for (size_t idx = 0; idx < block.getNumArguments(); idx++) {
    // Value forwarded by the first predecessor inspected; stays null when the
    // block has no predecessors.
    Value incoming;
    bool allAgree = true;

    // Compare what every predecessor edge forwards for this argument.
    for (auto pred = block.pred_begin(), predEnd = block.pred_end();
         pred != predEnd; ++pred) {
      auto branchOp = dyn_cast<BranchOpInterface>((*pred)->getTerminator());
      // A terminator that is not a branch cannot be inspected: keep the
      // argument.
      if (!branchOp) {
        allAgree = false;
        break;
      }
      unsigned edgeIdx = pred.getSuccessorIndex();
      SuccessorOperands edgeOperands = branchOp.getSuccessorOperands(edgeIdx);
      auto forwarded = edgeOperands.getForwardedOperands();
      if (!incoming) {
        incoming = forwarded[idx];
        continue;
      }
      if (forwarded[idx] != incoming) {
        allAgree = false;
        break;
      }
    }

    // Keep the argument unless every predecessor agreed on one value.
    if (!incoming || !allAgree)
      continue;

    redundantArgs.push_back(idx);

    // Redirect the users of the argument to the common value right away; the
    // argument itself is erased after the scan.
    Value blockArg = block.getArgument(idx);
    rewriter.replaceAllUsesWith(blockArg, incoming);
  }

  // Erase from the highest index down so that the remaining indices stay
  // valid.
  for (auto idx : llvm::reverse(redundantArgs)) {
    block.eraseArgument(idx);

    // Drop the matching operand from each predecessor branch.
    for (auto pred = block.pred_begin(), predEnd = block.pred_end();
         pred != predEnd; ++pred) {
      auto branchOp = cast<BranchOpInterface>((*pred)->getTerminator());
      unsigned edgeIdx = pred.getSuccessorIndex();
      SuccessorOperands edgeOperands = branchOp.getSuccessorOperands(edgeIdx);
      edgeOperands.erase(idx);
    }
  }
  return success(!redundantArgs.empty());
}
890+
891+
/// This optimization drops redundant argument to blocks. I.e., if a given
892+
/// argument to a block receives the same value from each of the block
893+
/// predecessors, we can remove the argument from the block and use directly the
894+
/// original value. This is a simple example:
895+
///
896+
/// %cond = llvm.call @rand() : () -> i1
897+
/// %val0 = llvm.mlir.constant(1 : i64) : i64
898+
/// %val1 = llvm.mlir.constant(2 : i64) : i64
899+
/// %val2 = llvm.mlir.constant(2 : i64) : i64
900+
/// llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb2(%val0 : i64, %val2
901+
/// : i64) ^bb1(%arg0 : i64, %arg1 : i64):
902+
/// llvm.call @foo(%arg0, %arg1)
903+
///
904+
/// The previous IR can be rewritten as:
905+
/// %cond = llvm.call @rand() : () -> i1
906+
/// %val = llvm.mlir.constant(1 : i64) : i64
907+
/// llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64)
908+
/// ^bb1(%arg0 : i64):
909+
/// llvm.call @foo(%val0, %arg1)
910+
///
911+
static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
912+
MutableArrayRef<Region> regions) {
913+
llvm::SmallSetVector<Region *, 1> worklist;
914+
for (auto &region : regions)
915+
worklist.insert(&region);
916+
bool anyChanged = false;
917+
while (!worklist.empty()) {
918+
Region *region = worklist.pop_back_val();
919+
920+
// Add any nested regions to the worklist.
921+
for (Block &block : *region) {
922+
anyChanged = succeeded(dropRedundantArguments(rewriter, block));
923+
924+
for (auto &op : block)
925+
for (auto &nestedRegion : op.getRegions())
926+
worklist.insert(&nestedRegion);
927+
}
928+
}
929+
return success(anyChanged);
930+
}
931+
821932
//===----------------------------------------------------------------------===//
822933
// Region Simplification
823934
//===----------------------------------------------------------------------===//
@@ -832,8 +943,12 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
832943
bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
833944
bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
834945
bool mergedIdenticalBlocks = false;
835-
if (mergeBlocks)
946+
bool droppedRedundantArguments = false;
947+
if (mergeBlocks) {
836948
mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(rewriter, regions));
949+
droppedRedundantArguments =
950+
succeeded(dropRedundantArguments(rewriter, regions));
951+
}
837952
return success(eliminatedBlocks || eliminatedOpsOrArgs ||
838-
mergedIdenticalBlocks);
953+
mergedIdenticalBlocks || droppedRedundantArguments);
839954
}

mlir/test/Dialect/Linalg/detensorize_entry_block.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ func.func @main(%arg0: tensor<f32>) -> tensor<f32> {
1515
// CHECK-LABEL: @main
1616
// CHECK-SAME: (%[[ARG0:.+]]: tensor<f32>) -> tensor<f32>
1717
// CHECK: %[[EXTRACTED:.+]] = tensor.extract %[[ARG0]][] : tensor<f32>
18-
// CHECK: cf.br ^{{.*}}(%[[EXTRACTED]] : f32)
19-
// CHECK: ^{{.*}}(%[[ARG1:.+]]: f32):
20-
// CHECK: %[[ELEMENTS:.+]] = tensor.from_elements %[[ARG1]] : tensor<f32>
18+
// CHECK: cf.br ^{{.*}}
19+
// CHECK: ^{{.*}}:
20+
// CHECK: %[[ELEMENTS:.+]] = tensor.from_elements %[[EXTRACTED]] : tensor<f32>
2121
// CHECK: return %[[ELEMENTS]] : tensor<f32>

0 commit comments

Comments
 (0)