Skip to content

Commit 441b672

Browse files
authored
[mlir] Fix block merging (#102038)
With this PR I am trying to address: #63230. What changed: - While merging identical blocks, don't add a block argument if it is "identical" to another block argument, i.e., if the two block arguments refer to the same `Value`. The operations' operands in the block will point to the argument we already inserted. This needs to happen for all the arguments we pass to the different successors of the parent block. - After merging the blocks, get rid of "unnecessary" arguments, i.e., if all the predecessors pass the same value for a block argument, there is no need to pass it as an argument. - This last simplification clashed with `BufferDeallocationSimplification`. The reason, I think, is that `BufferDeallocationSimplification` contains an analysis based on the block structure: if we simplify the block structure (by merging blocks and/or dropping block arguments), the analysis becomes invalid. The solution I found is to do a more prudent simplification when running that pass. **Note-1**: I ran all the integration tests (`-DMLIR_INCLUDE_INTEGRATION_TESTS=ON`) and they passed. **Note-2**: I fixed a bug found by @Dinistro in #97697. The issue was that, when looking for redundant arguments, I was not considering that the block might already have some arguments. So the index (in the block args list) of the i-th `newArgument` is `i+numOfOldArguments`.
1 parent 734c048 commit 441b672

File tree

12 files changed

+483
-83
lines changed

12 files changed

+483
-83
lines changed

mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -463,10 +463,15 @@ struct BufferDeallocationSimplificationPass
463463
SplitDeallocWhenNotAliasingAnyOther,
464464
RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
465465
analysis);
466+
// We don't want that the block structure changes invalidating the
467+
// `BufferOriginAnalysis` so we apply the rewrites with a `Normal` level of
468+
// region simplification
469+
GreedyRewriteConfig config;
470+
config.enableRegionSimplification = GreedySimplifyRegionLevel::Normal;
466471
populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
467472

468-
if (failed(
469-
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
473+
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
474+
config)))
470475
signalPassFailure();
471476
}
472477
};

mlir/lib/Transforms/Utils/RegionUtils.cpp

Lines changed: 210 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,23 @@
99
#include "mlir/Transforms/RegionUtils.h"
1010
#include "mlir/Analysis/TopologicalSortUtils.h"
1111
#include "mlir/IR/Block.h"
12+
#include "mlir/IR/BuiltinOps.h"
1213
#include "mlir/IR/IRMapping.h"
1314
#include "mlir/IR/Operation.h"
1415
#include "mlir/IR/PatternMatch.h"
1516
#include "mlir/IR/RegionGraphTraits.h"
1617
#include "mlir/IR/Value.h"
1718
#include "mlir/Interfaces/ControlFlowInterfaces.h"
1819
#include "mlir/Interfaces/SideEffectInterfaces.h"
20+
#include "mlir/Support/LogicalResult.h"
1921

2022
#include "llvm/ADT/DepthFirstIterator.h"
2123
#include "llvm/ADT/PostOrderIterator.h"
24+
#include "llvm/ADT/STLExtras.h"
25+
#include "llvm/ADT/SmallSet.h"
2226

2327
#include <deque>
28+
#include <iterator>
2429

2530
using namespace mlir;
2631

@@ -674,6 +679,94 @@ static bool ableToUpdatePredOperands(Block *block) {
674679
return true;
675680
}
676681

682+
/// Prunes the redundant list of new arguments. E.g., if we are passing an
683+
/// argument list like [x, y, z, x] this would return [x, y, z] and it would
684+
/// update the `block` (to whom the argument are passed to) accordingly. The new
685+
/// arguments are passed as arguments at the back of the block, hence we need to
686+
/// know how many `numOldArguments` were before, in order to correctly replace
687+
/// the new arguments in the block
688+
static SmallVector<SmallVector<Value, 8>, 2> pruneRedundantArguments(
689+
const SmallVector<SmallVector<Value, 8>, 2> &newArguments,
690+
RewriterBase &rewriter, unsigned numOldArguments, Block *block) {
691+
692+
SmallVector<SmallVector<Value, 8>, 2> newArgumentsPruned(
693+
newArguments.size(), SmallVector<Value, 8>());
694+
695+
if (newArguments.empty())
696+
return newArguments;
697+
698+
// `newArguments` is a 2D array of size `numLists` x `numArgs`
699+
unsigned numLists = newArguments.size();
700+
unsigned numArgs = newArguments[0].size();
701+
702+
// Map that for each arg index contains the index that we can use in place of
703+
// the original index. E.g., if we have newArgs = [x, y, z, x], we will have
704+
// idxToReplacement[3] = 0
705+
llvm::DenseMap<unsigned, unsigned> idxToReplacement;
706+
707+
// This is a useful data structure to track the first appearance of a Value
708+
// on a given list of arguments
709+
DenseMap<Value, unsigned> firstValueToIdx;
710+
for (unsigned j = 0; j < numArgs; ++j) {
711+
Value newArg = newArguments[0][j];
712+
if (!firstValueToIdx.contains(newArg))
713+
firstValueToIdx[newArg] = j;
714+
}
715+
716+
// Go through the first list of arguments (list 0).
717+
for (unsigned j = 0; j < numArgs; ++j) {
718+
// Look back to see if there are possible redundancies in list 0. Please
719+
// note that we are using a map to annotate when an argument was seen first
720+
// to avoid a O(N^2) algorithm. This has the drawback that if we have two
721+
// lists like:
722+
// list0: [%a, %a, %a]
723+
// list1: [%c, %b, %b]
724+
// We cannot simplify it, because firstValueToIdx[%a] = 0, but we cannot
725+
// point list1[1](==%b) or list1[2](==%b) to list1[0](==%c). However, since
726+
// the number of arguments can be potentially unbounded we cannot afford a
727+
// O(N^2) algorithm (to search to all the possible pairs) and we need to
728+
// accept the trade-off.
729+
unsigned k = firstValueToIdx[newArguments[0][j]];
730+
if (k == j)
731+
continue;
732+
733+
bool shouldReplaceJ = true;
734+
unsigned replacement = k;
735+
// If a possible redundancy is found, then scan the other lists: we
736+
// can prune the arguments if and only if they are redundant in every
737+
// list.
738+
for (unsigned i = 1; i < numLists; ++i)
739+
shouldReplaceJ =
740+
shouldReplaceJ && (newArguments[i][k] == newArguments[i][j]);
741+
// Save the replacement.
742+
if (shouldReplaceJ)
743+
idxToReplacement[j] = replacement;
744+
}
745+
746+
// Populate the pruned argument list.
747+
for (unsigned i = 0; i < numLists; ++i)
748+
for (unsigned j = 0; j < numArgs; ++j)
749+
if (!idxToReplacement.contains(j))
750+
newArgumentsPruned[i].push_back(newArguments[i][j]);
751+
752+
// Replace the block's redundant arguments.
753+
SmallVector<unsigned> toErase;
754+
for (auto [idx, arg] : llvm::enumerate(block->getArguments())) {
755+
if (idxToReplacement.contains(idx)) {
756+
Value oldArg = block->getArgument(numOldArguments + idx);
757+
Value newArg =
758+
block->getArgument(numOldArguments + idxToReplacement[idx]);
759+
rewriter.replaceAllUsesWith(oldArg, newArg);
760+
toErase.push_back(numOldArguments + idx);
761+
}
762+
}
763+
764+
// Erase the block's redundant arguments.
765+
for (unsigned idxToErase : llvm::reverse(toErase))
766+
block->eraseArgument(idxToErase);
767+
return newArgumentsPruned;
768+
}
769+
677770
LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
678771
// Don't consider clusters that don't have blocks to merge.
679772
if (blocksToMerge.empty())
@@ -703,6 +796,7 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
703796
1 + blocksToMerge.size(),
704797
SmallVector<Value, 8>(operandsToMerge.size()));
705798
unsigned curOpIndex = 0;
799+
unsigned numOldArguments = leaderBlock->getNumArguments();
706800
for (const auto &it : llvm::enumerate(operandsToMerge)) {
707801
unsigned nextOpOffset = it.value().first - curOpIndex;
708802
curOpIndex = it.value().first;
@@ -722,6 +816,11 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
722816
}
723817
}
724818
}
819+
820+
// Prune redundant arguments and update the leader block argument list
821+
newArguments = pruneRedundantArguments(newArguments, rewriter,
822+
numOldArguments, leaderBlock);
823+
725824
// Update the predecessors for each of the blocks.
726825
auto updatePredecessors = [&](Block *block, unsigned clusterIndex) {
727826
for (auto predIt = block->pred_begin(), predE = block->pred_end();
@@ -818,6 +917,111 @@ static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
818917
return success(anyChanged);
819918
}
820919

920+
/// If a block's argument is always the same across different invocations, then
921+
/// drop the argument and use the value directly inside the block
922+
static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
923+
Block &block) {
924+
SmallVector<size_t> argsToErase;
925+
926+
// Go through the arguments of the block.
927+
for (auto [argIdx, blockOperand] : llvm::enumerate(block.getArguments())) {
928+
bool sameArg = true;
929+
Value commonValue;
930+
931+
// Go through the block predecessor and flag if they pass to the block
932+
// different values for the same argument.
933+
for (Block::pred_iterator predIt = block.pred_begin(),
934+
predE = block.pred_end();
935+
predIt != predE; ++predIt) {
936+
auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator());
937+
if (!branch) {
938+
sameArg = false;
939+
break;
940+
}
941+
unsigned succIndex = predIt.getSuccessorIndex();
942+
SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
943+
auto branchOperands = succOperands.getForwardedOperands();
944+
if (!commonValue) {
945+
commonValue = branchOperands[argIdx];
946+
continue;
947+
}
948+
if (branchOperands[argIdx] != commonValue) {
949+
sameArg = false;
950+
break;
951+
}
952+
}
953+
954+
// If they are passing the same value, drop the argument.
955+
if (commonValue && sameArg) {
956+
argsToErase.push_back(argIdx);
957+
958+
// Remove the argument from the block.
959+
rewriter.replaceAllUsesWith(blockOperand, commonValue);
960+
}
961+
}
962+
963+
// Remove the arguments.
964+
for (size_t argIdx : llvm::reverse(argsToErase)) {
965+
block.eraseArgument(argIdx);
966+
967+
// Remove the argument from the branch ops.
968+
for (auto predIt = block.pred_begin(), predE = block.pred_end();
969+
predIt != predE; ++predIt) {
970+
auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
971+
unsigned succIndex = predIt.getSuccessorIndex();
972+
SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
973+
succOperands.erase(argIdx);
974+
}
975+
}
976+
return success(!argsToErase.empty());
977+
}
978+
979+
/// This optimization drops redundant argument to blocks. I.e., if a given
980+
/// argument to a block receives the same value from each of the block
981+
/// predecessors, we can remove the argument from the block and use directly the
982+
/// original value. This is a simple example:
983+
///
984+
/// %cond = llvm.call @rand() : () -> i1
985+
/// %val0 = llvm.mlir.constant(1 : i64) : i64
986+
/// %val1 = llvm.mlir.constant(2 : i64) : i64
987+
/// %val2 = llvm.mlir.constant(3 : i64) : i64
988+
/// llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb2(%val0 : i64, %val2
989+
/// : i64)
990+
///
991+
/// ^bb1(%arg0 : i64, %arg1 : i64):
992+
/// llvm.call @foo(%arg0, %arg1)
993+
///
994+
/// The previous IR can be rewritten as:
995+
/// %cond = llvm.call @rand() : () -> i1
996+
/// %val0 = llvm.mlir.constant(1 : i64) : i64
997+
/// %val1 = llvm.mlir.constant(2 : i64) : i64
998+
/// %val2 = llvm.mlir.constant(3 : i64) : i64
999+
/// llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64)
1000+
///
1001+
/// ^bb1(%arg0 : i64):
1002+
/// llvm.call @foo(%val0, %arg0)
1003+
///
1004+
static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
1005+
MutableArrayRef<Region> regions) {
1006+
llvm::SmallSetVector<Region *, 1> worklist;
1007+
for (Region &region : regions)
1008+
worklist.insert(&region);
1009+
bool anyChanged = false;
1010+
while (!worklist.empty()) {
1011+
Region *region = worklist.pop_back_val();
1012+
1013+
// Add any nested regions to the worklist.
1014+
for (Block &block : *region) {
1015+
anyChanged = succeeded(dropRedundantArguments(rewriter, block));
1016+
1017+
for (Operation &op : block)
1018+
for (Region &nestedRegion : op.getRegions())
1019+
worklist.insert(&nestedRegion);
1020+
}
1021+
}
1022+
return success(anyChanged);
1023+
}
1024+
8211025
//===----------------------------------------------------------------------===//
8221026
// Region Simplification
8231027
//===----------------------------------------------------------------------===//
@@ -832,8 +1036,12 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
8321036
bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
8331037
bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
8341038
bool mergedIdenticalBlocks = false;
835-
if (mergeBlocks)
1039+
bool droppedRedundantArguments = false;
1040+
if (mergeBlocks) {
8361041
mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(rewriter, regions));
1042+
droppedRedundantArguments =
1043+
succeeded(dropRedundantArguments(rewriter, regions));
1044+
}
8371045
return success(eliminatedBlocks || eliminatedOpsOrArgs ||
838-
mergedIdenticalBlocks);
1046+
mergedIdenticalBlocks || droppedRedundantArguments);
8391047
}

mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -178,28 +178,32 @@ func.func @condBranchDynamicTypeNested(
178178
// CHECK-NEXT: ^bb1
179179
// CHECK-NOT: bufferization.dealloc
180180
// CHECK-NOT: bufferization.clone
181-
// CHECK: cf.br ^bb5([[ARG1]], %false{{[0-9_]*}} :
181+
// CHECK: cf.br ^bb6([[ARG1]], %false{{[0-9_]*}} :
182182
// CHECK: ^bb2([[IDX:%.*]]:{{.*}})
183183
// CHECK: [[ALLOC1:%.*]] = memref.alloc([[IDX]])
184184
// CHECK-NEXT: test.buffer_based
185185
// CHECK-NEXT: [[NOT_ARG0:%.+]] = arith.xori [[ARG0]], %true
186186
// CHECK-NEXT: [[OWN:%.+]] = arith.select [[ARG0]], [[ARG0]], [[NOT_ARG0]]
187187
// CHECK-NOT: bufferization.dealloc
188188
// CHECK-NOT: bufferization.clone
189-
// CHECK: cf.cond_br{{.*}}, ^bb3, ^bb3
189+
// CHECK: cf.cond_br{{.*}}, ^bb3, ^bb4
190190
// CHECK-NEXT: ^bb3:
191191
// CHECK-NOT: bufferization.dealloc
192192
// CHECK-NOT: bufferization.clone
193-
// CHECK: cf.br ^bb4([[ALLOC1]], [[OWN]]
194-
// CHECK-NEXT: ^bb4([[ALLOC2:%.*]]:{{.*}}, [[COND1:%.+]]:{{.*}})
193+
// CHECK: cf.br ^bb5([[ALLOC1]], [[OWN]]
194+
// CHECK-NEXT: ^bb4:
195195
// CHECK-NOT: bufferization.dealloc
196196
// CHECK-NOT: bufferization.clone
197-
// CHECK: cf.br ^bb5([[ALLOC2]], [[COND1]]
198-
// CHECK-NEXT: ^bb5([[ALLOC4:%.*]]:{{.*}}, [[COND2:%.+]]:{{.*}})
197+
// CHECK: cf.br ^bb5([[ALLOC1]], [[OWN]]
198+
// CHECK-NEXT: ^bb5([[ALLOC2:%.*]]:{{.*}}, [[COND1:%.+]]:{{.*}})
199+
// CHECK-NOT: bufferization.dealloc
200+
// CHECK-NOT: bufferization.clone
201+
// CHECK: cf.br ^bb6([[ALLOC2]], [[COND1]]
202+
// CHECK-NEXT: ^bb6([[ALLOC4:%.*]]:{{.*}}, [[COND2:%.+]]:{{.*}})
199203
// CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC4]]
200204
// CHECK-NEXT: [[OWN:%.+]]:2 = bufferization.dealloc ([[BASE]] :{{.*}}) if ([[COND2]]) retain ([[ALLOC4]], [[ARG2]] :
201-
// CHECK: cf.br ^bb6([[ALLOC4]], [[OWN]]#0
202-
// CHECK-NEXT: ^bb6([[ALLOC5:%.*]]:{{.*}}, [[COND3:%.+]]:{{.*}})
205+
// CHECK: cf.br ^bb7([[ALLOC4]], [[OWN]]#0
206+
// CHECK-NEXT: ^bb7([[ALLOC5:%.*]]:{{.*}}, [[COND3:%.+]]:{{.*}})
203207
// CHECK: test.copy
204208
// CHECK: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC5]]
205209
// CHECK-NEXT: bufferization.dealloc ([[BASE]] : {{.*}}) if ([[COND3]])

mlir/test/Dialect/Linalg/detensorize_entry_block.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ func.func @main(%arg0: tensor<f32>) -> tensor<f32> {
1515
// CHECK-LABEL: @main
1616
// CHECK-SAME: (%[[ARG0:.+]]: tensor<f32>) -> tensor<f32>
1717
// CHECK: %[[EXTRACTED:.+]] = tensor.extract %[[ARG0]][] : tensor<f32>
18-
// CHECK: cf.br ^{{.*}}(%[[EXTRACTED]] : f32)
19-
// CHECK: ^{{.*}}(%[[ARG1:.+]]: f32):
20-
// CHECK: %[[ELEMENTS:.+]] = tensor.from_elements %[[ARG1]] : tensor<f32>
18+
// CHECK: cf.br ^{{.*}}
19+
// CHECK: ^{{.*}}:
20+
// CHECK: %[[ELEMENTS:.+]] = tensor.from_elements %[[EXTRACTED]] : tensor<f32>
2121
// CHECK: return %[[ELEMENTS]] : tensor<f32>

0 commit comments

Comments
 (0)