Skip to content

Commit f534674

Browse files
giuserosyuxuanchen1997
authored andcommitted
[mlir] Fix block merging (#97697)
Summary: With this PR I am trying to address: #63230. What changed: - While merging identical blocks, don't add a block argument if it is "identical" to another block argument. I.e., if the two block arguments refer to the same `Value`. The operations operands in the block will point to the argument we already inserted. This needs to happen to all the arguments we pass to the different successors of the parent block - After merged the blocks, get rid of "unnecessary" arguments. I.e., if all the predecessors pass the same block argument, there is no need to pass it as an argument. - This last simplification clashed with `BufferDeallocationSimplification`. The reason, I think, is that the two simplifications are clashing. I.e., `BufferDeallocationSimplification` contains an analysis based on the block structure. If we simplify the block structure (by merging and/or dropping block arguments) the analysis is invalid . The solution I found is to do a more prudent simplification when running that pass. **Note**: this a rework of #96871 . I ran all the integration tests (`-DMLIR_INCLUDE_INTEGRATION_TESTS=ON`) and they passed. Test Plan: Reviewers: Subscribers: Tasks: Tags: Differential Revision: https://phabricator.intern.facebook.com/D60250916
1 parent 6064e7e commit f534674

File tree

12 files changed

+445
-83
lines changed

12 files changed

+445
-83
lines changed

mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -463,10 +463,15 @@ struct BufferDeallocationSimplificationPass
463463
SplitDeallocWhenNotAliasingAnyOther,
464464
RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
465465
analysis);
466+
// We don't want that the block structure changes invalidating the
467+
// `BufferOriginAnalysis` so we apply the rewrites witha `Normal` level of
468+
// region simplification
469+
GreedyRewriteConfig config;
470+
config.enableRegionSimplification = GreedySimplifyRegionLevel::Normal;
466471
populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
467472

468-
if (failed(
469-
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
473+
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
474+
config)))
470475
signalPassFailure();
471476
}
472477
};

mlir/lib/Transforms/Utils/RegionUtils.cpp

Lines changed: 202 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,23 @@
99
#include "mlir/Transforms/RegionUtils.h"
1010
#include "mlir/Analysis/TopologicalSortUtils.h"
1111
#include "mlir/IR/Block.h"
12+
#include "mlir/IR/BuiltinOps.h"
1213
#include "mlir/IR/IRMapping.h"
1314
#include "mlir/IR/Operation.h"
1415
#include "mlir/IR/PatternMatch.h"
1516
#include "mlir/IR/RegionGraphTraits.h"
1617
#include "mlir/IR/Value.h"
1718
#include "mlir/Interfaces/ControlFlowInterfaces.h"
1819
#include "mlir/Interfaces/SideEffectInterfaces.h"
20+
#include "mlir/Support/LogicalResult.h"
1921

2022
#include "llvm/ADT/DepthFirstIterator.h"
2123
#include "llvm/ADT/PostOrderIterator.h"
24+
#include "llvm/ADT/STLExtras.h"
25+
#include "llvm/ADT/SmallSet.h"
2226

2327
#include <deque>
28+
#include <iterator>
2429

2530
using namespace mlir;
2631

@@ -674,6 +679,91 @@ static bool ableToUpdatePredOperands(Block *block) {
674679
return true;
675680
}
676681

682+
/// Prunes the redundant list of arguments. E.g., if we are passing an argument
683+
/// list like [x, y, z, x] this would return [x, y, z] and it would update the
684+
/// `block` (to whom the argument are passed to) accordingly.
685+
static SmallVector<SmallVector<Value, 8>, 2> pruneRedundantArguments(
686+
const SmallVector<SmallVector<Value, 8>, 2> &newArguments,
687+
RewriterBase &rewriter, Block *block) {
688+
689+
SmallVector<SmallVector<Value, 8>, 2> newArgumentsPruned(
690+
newArguments.size(), SmallVector<Value, 8>());
691+
692+
if (newArguments.empty())
693+
return newArguments;
694+
695+
// `newArguments` is a 2D array of size `numLists` x `numArgs`
696+
unsigned numLists = newArguments.size();
697+
unsigned numArgs = newArguments[0].size();
698+
699+
// Map that for each arg index contains the index that we can use in place of
700+
// the original index. E.g., if we have newArgs = [x, y, z, x], we will have
701+
// idxToReplacement[3] = 0
702+
llvm::DenseMap<unsigned, unsigned> idxToReplacement;
703+
704+
// This is a useful data structure to track the first appearance of a Value
705+
// on a given list of arguments
706+
DenseMap<Value, unsigned> firstValueToIdx;
707+
for (unsigned j = 0; j < numArgs; ++j) {
708+
Value newArg = newArguments[0][j];
709+
if (!firstValueToIdx.contains(newArg))
710+
firstValueToIdx[newArg] = j;
711+
}
712+
713+
// Go through the first list of arguments (list 0).
714+
for (unsigned j = 0; j < numArgs; ++j) {
715+
bool shouldReplaceJ = false;
716+
unsigned replacement = 0;
717+
// Look back to see if there are possible redundancies in list 0. Please
718+
// note that we are using a map to annotate when an argument was seen first
719+
// to avoid a O(N^2) algorithm. This has the drawback that if we have two
720+
// lists like:
721+
// list0: [%a, %a, %a]
722+
// list1: [%c, %b, %b]
723+
// We cannot simplify it, because firstVlaueToIdx[%a] = 0, but we cannot
724+
// point list1[1](==%b) or list1[2](==%b) to list1[0](==%c). However, since
725+
// the number of arguments can be potentially unbounded we cannot afford a
726+
// O(N^2) algorithm (to search to all the possible pairs) and we need to
727+
// accept the trade-off.
728+
unsigned k = firstValueToIdx[newArguments[0][j]];
729+
if (k != j) {
730+
shouldReplaceJ = true;
731+
replacement = k;
732+
// If a possible redundancy is found, then scan the other lists: we
733+
// can prune the arguments if and only if they are redundant in every
734+
// list.
735+
for (unsigned i = 1; i < numLists; ++i)
736+
shouldReplaceJ =
737+
shouldReplaceJ && (newArguments[i][k] == newArguments[i][j]);
738+
}
739+
// Save the replacement.
740+
if (shouldReplaceJ)
741+
idxToReplacement[j] = replacement;
742+
}
743+
744+
// Populate the pruned argument list.
745+
for (unsigned i = 0; i < numLists; ++i)
746+
for (unsigned j = 0; j < numArgs; ++j)
747+
if (!idxToReplacement.contains(j))
748+
newArgumentsPruned[i].push_back(newArguments[i][j]);
749+
750+
// Replace the block's redundant arguments.
751+
SmallVector<unsigned> toErase;
752+
for (auto [idx, arg] : llvm::enumerate(block->getArguments())) {
753+
if (idxToReplacement.contains(idx)) {
754+
Value oldArg = block->getArgument(idx);
755+
Value newArg = block->getArgument(idxToReplacement[idx]);
756+
rewriter.replaceAllUsesWith(oldArg, newArg);
757+
toErase.push_back(idx);
758+
}
759+
}
760+
761+
// Erase the block's redundant arguments.
762+
for (unsigned idxToErase : llvm::reverse(toErase))
763+
block->eraseArgument(idxToErase);
764+
return newArgumentsPruned;
765+
}
766+
677767
LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
678768
// Don't consider clusters that don't have blocks to merge.
679769
if (blocksToMerge.empty())
@@ -722,6 +812,10 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
722812
}
723813
}
724814
}
815+
816+
// Prune redundant arguments and update the leader block argument list
817+
newArguments = pruneRedundantArguments(newArguments, rewriter, leaderBlock);
818+
725819
// Update the predecessors for each of the blocks.
726820
auto updatePredecessors = [&](Block *block, unsigned clusterIndex) {
727821
for (auto predIt = block->pred_begin(), predE = block->pred_end();
@@ -818,6 +912,108 @@ static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
818912
return success(anyChanged);
819913
}
820914

915+
static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
916+
Block &block) {
917+
SmallVector<size_t> argsToErase;
918+
919+
// Go through the arguments of the block.
920+
for (auto [argIdx, blockOperand] : llvm::enumerate(block.getArguments())) {
921+
bool sameArg = true;
922+
Value commonValue;
923+
924+
// Go through the block predecessor and flag if they pass to the block
925+
// different values for the same argument.
926+
for (auto predIt = block.pred_begin(), predE = block.pred_end();
927+
predIt != predE; ++predIt) {
928+
auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator());
929+
if (!branch) {
930+
sameArg = false;
931+
break;
932+
}
933+
unsigned succIndex = predIt.getSuccessorIndex();
934+
SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
935+
auto branchOperands = succOperands.getForwardedOperands();
936+
if (!commonValue) {
937+
commonValue = branchOperands[argIdx];
938+
} else {
939+
if (branchOperands[argIdx] != commonValue) {
940+
sameArg = false;
941+
break;
942+
}
943+
}
944+
}
945+
946+
// If they are passing the same value, drop the argument.
947+
if (commonValue && sameArg) {
948+
argsToErase.push_back(argIdx);
949+
950+
// Remove the argument from the block.
951+
rewriter.replaceAllUsesWith(blockOperand, commonValue);
952+
}
953+
}
954+
955+
// Remove the arguments.
956+
for (auto argIdx : llvm::reverse(argsToErase)) {
957+
block.eraseArgument(argIdx);
958+
959+
// Remove the argument from the branch ops.
960+
for (auto predIt = block.pred_begin(), predE = block.pred_end();
961+
predIt != predE; ++predIt) {
962+
auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
963+
unsigned succIndex = predIt.getSuccessorIndex();
964+
SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
965+
succOperands.erase(argIdx);
966+
}
967+
}
968+
return success(!argsToErase.empty());
969+
}
970+
971+
/// This optimization drops redundant argument to blocks. I.e., if a given
972+
/// argument to a block receives the same value from each of the block
973+
/// predecessors, we can remove the argument from the block and use directly the
974+
/// original value. This is a simple example:
975+
///
976+
/// %cond = llvm.call @rand() : () -> i1
977+
/// %val0 = llvm.mlir.constant(1 : i64) : i64
978+
/// %val1 = llvm.mlir.constant(2 : i64) : i64
979+
/// %val2 = llvm.mlir.constant(3 : i64) : i64
980+
/// llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb2(%val0 : i64, %val2
981+
/// : i64)
982+
///
983+
/// ^bb1(%arg0 : i64, %arg1 : i64):
984+
/// llvm.call @foo(%arg0, %arg1)
985+
///
986+
/// The previous IR can be rewritten as:
987+
/// %cond = llvm.call @rand() : () -> i1
988+
/// %val0 = llvm.mlir.constant(1 : i64) : i64
989+
/// %val1 = llvm.mlir.constant(2 : i64) : i64
990+
/// %val2 = llvm.mlir.constant(3 : i64) : i64
991+
/// llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64)
992+
///
993+
/// ^bb1(%arg0 : i64):
994+
/// llvm.call @foo(%val0, %arg0)
995+
///
996+
static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
997+
MutableArrayRef<Region> regions) {
998+
llvm::SmallSetVector<Region *, 1> worklist;
999+
for (Region &region : regions)
1000+
worklist.insert(&region);
1001+
bool anyChanged = false;
1002+
while (!worklist.empty()) {
1003+
Region *region = worklist.pop_back_val();
1004+
1005+
// Add any nested regions to the worklist.
1006+
for (Block &block : *region) {
1007+
anyChanged = succeeded(dropRedundantArguments(rewriter, block));
1008+
1009+
for (Operation &op : block)
1010+
for (Region &nestedRegion : op.getRegions())
1011+
worklist.insert(&nestedRegion);
1012+
}
1013+
}
1014+
return success(anyChanged);
1015+
}
1016+
8211017
//===----------------------------------------------------------------------===//
8221018
// Region Simplification
8231019
//===----------------------------------------------------------------------===//
@@ -832,8 +1028,12 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
8321028
bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
8331029
bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
8341030
bool mergedIdenticalBlocks = false;
835-
if (mergeBlocks)
1031+
bool droppedRedundantArguments = false;
1032+
if (mergeBlocks) {
8361033
mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(rewriter, regions));
1034+
droppedRedundantArguments =
1035+
succeeded(dropRedundantArguments(rewriter, regions));
1036+
}
8371037
return success(eliminatedBlocks || eliminatedOpsOrArgs ||
838-
mergedIdenticalBlocks);
1038+
mergedIdenticalBlocks || droppedRedundantArguments);
8391039
}

mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -178,28 +178,32 @@ func.func @condBranchDynamicTypeNested(
178178
// CHECK-NEXT: ^bb1
179179
// CHECK-NOT: bufferization.dealloc
180180
// CHECK-NOT: bufferization.clone
181-
// CHECK: cf.br ^bb5([[ARG1]], %false{{[0-9_]*}} :
181+
// CHECK: cf.br ^bb6([[ARG1]], %false{{[0-9_]*}} :
182182
// CHECK: ^bb2([[IDX:%.*]]:{{.*}})
183183
// CHECK: [[ALLOC1:%.*]] = memref.alloc([[IDX]])
184184
// CHECK-NEXT: test.buffer_based
185185
// CHECK-NEXT: [[NOT_ARG0:%.+]] = arith.xori [[ARG0]], %true
186186
// CHECK-NEXT: [[OWN:%.+]] = arith.select [[ARG0]], [[ARG0]], [[NOT_ARG0]]
187187
// CHECK-NOT: bufferization.dealloc
188188
// CHECK-NOT: bufferization.clone
189-
// CHECK: cf.cond_br{{.*}}, ^bb3, ^bb3
189+
// CHECK: cf.cond_br{{.*}}, ^bb3, ^bb4
190190
// CHECK-NEXT: ^bb3:
191191
// CHECK-NOT: bufferization.dealloc
192192
// CHECK-NOT: bufferization.clone
193-
// CHECK: cf.br ^bb4([[ALLOC1]], [[OWN]]
194-
// CHECK-NEXT: ^bb4([[ALLOC2:%.*]]:{{.*}}, [[COND1:%.+]]:{{.*}})
193+
// CHECK: cf.br ^bb5([[ALLOC1]], [[OWN]]
194+
// CHECK-NEXT: ^bb4:
195195
// CHECK-NOT: bufferization.dealloc
196196
// CHECK-NOT: bufferization.clone
197-
// CHECK: cf.br ^bb5([[ALLOC2]], [[COND1]]
198-
// CHECK-NEXT: ^bb5([[ALLOC4:%.*]]:{{.*}}, [[COND2:%.+]]:{{.*}})
197+
// CHECK: cf.br ^bb5([[ALLOC1]], [[OWN]]
198+
// CHECK-NEXT: ^bb5([[ALLOC2:%.*]]:{{.*}}, [[COND1:%.+]]:{{.*}})
199+
// CHECK-NOT: bufferization.dealloc
200+
// CHECK-NOT: bufferization.clone
201+
// CHECK: cf.br ^bb6([[ALLOC2]], [[COND1]]
202+
// CHECK-NEXT: ^bb6([[ALLOC4:%.*]]:{{.*}}, [[COND2:%.+]]:{{.*}})
199203
// CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC4]]
200204
// CHECK-NEXT: [[OWN:%.+]]:2 = bufferization.dealloc ([[BASE]] :{{.*}}) if ([[COND2]]) retain ([[ALLOC4]], [[ARG2]] :
201-
// CHECK: cf.br ^bb6([[ALLOC4]], [[OWN]]#0
202-
// CHECK-NEXT: ^bb6([[ALLOC5:%.*]]:{{.*}}, [[COND3:%.+]]:{{.*}})
205+
// CHECK: cf.br ^bb7([[ALLOC4]], [[OWN]]#0
206+
// CHECK-NEXT: ^bb7([[ALLOC5:%.*]]:{{.*}}, [[COND3:%.+]]:{{.*}})
203207
// CHECK: test.copy
204208
// CHECK: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC5]]
205209
// CHECK-NEXT: bufferization.dealloc ([[BASE]] : {{.*}}) if ([[COND3]])

mlir/test/Dialect/Linalg/detensorize_entry_block.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ func.func @main(%arg0: tensor<f32>) -> tensor<f32> {
1515
// CHECK-LABEL: @main
1616
// CHECK-SAME: (%[[ARG0:.+]]: tensor<f32>) -> tensor<f32>
1717
// CHECK: %[[EXTRACTED:.+]] = tensor.extract %[[ARG0]][] : tensor<f32>
18-
// CHECK: cf.br ^{{.*}}(%[[EXTRACTED]] : f32)
19-
// CHECK: ^{{.*}}(%[[ARG1:.+]]: f32):
20-
// CHECK: %[[ELEMENTS:.+]] = tensor.from_elements %[[ARG1]] : tensor<f32>
18+
// CHECK: cf.br ^{{.*}}
19+
// CHECK: ^{{.*}}:
20+
// CHECK: %[[ELEMENTS:.+]] = tensor.from_elements %[[EXTRACTED]] : tensor<f32>
2121
// CHECK: return %[[ELEMENTS]] : tensor<f32>

0 commit comments

Comments
 (0)