Skip to content

Commit 055852b

Browse files
authored
Revert "[mlir] Fix block merging (#97697)"
This reverts commit c63125d.
1 parent 8608cc1 commit 055852b

File tree

12 files changed

+83
-445
lines changed

12 files changed

+83
-445
lines changed

mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -463,15 +463,10 @@ struct BufferDeallocationSimplificationPass
463463
SplitDeallocWhenNotAliasingAnyOther,
464464
RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
465465
analysis);
466-
// We don't want that the block structure changes invalidating the
467-
// `BufferOriginAnalysis` so we apply the rewrites with a `Normal` level of
468-
// region simplification
469-
GreedyRewriteConfig config;
470-
config.enableRegionSimplification = GreedySimplifyRegionLevel::Normal;
471466
populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
472467

473-
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
474-
config)))
468+
if (failed(
469+
applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
475470
signalPassFailure();
476471
}
477472
};

mlir/lib/Transforms/Utils/RegionUtils.cpp

Lines changed: 2 additions & 202 deletions
Original file line numberDiff line numberDiff line change
@@ -9,23 +9,18 @@
99
#include "mlir/Transforms/RegionUtils.h"
1010
#include "mlir/Analysis/TopologicalSortUtils.h"
1111
#include "mlir/IR/Block.h"
12-
#include "mlir/IR/BuiltinOps.h"
1312
#include "mlir/IR/IRMapping.h"
1413
#include "mlir/IR/Operation.h"
1514
#include "mlir/IR/PatternMatch.h"
1615
#include "mlir/IR/RegionGraphTraits.h"
1716
#include "mlir/IR/Value.h"
1817
#include "mlir/Interfaces/ControlFlowInterfaces.h"
1918
#include "mlir/Interfaces/SideEffectInterfaces.h"
20-
#include "mlir/Support/LogicalResult.h"
2119

2220
#include "llvm/ADT/DepthFirstIterator.h"
2321
#include "llvm/ADT/PostOrderIterator.h"
24-
#include "llvm/ADT/STLExtras.h"
25-
#include "llvm/ADT/SmallSet.h"
2622

2723
#include <deque>
28-
#include <iterator>
2924

3025
using namespace mlir;
3126

@@ -679,91 +674,6 @@ static bool ableToUpdatePredOperands(Block *block) {
679674
return true;
680675
}
681676

682-
/// Prunes the redundant list of arguments. E.g., if we are passing an argument
683-
/// list like [x, y, z, x] this would return [x, y, z] and it would update the
684-
/// `block` (to which the arguments are passed) accordingly.
685-
static SmallVector<SmallVector<Value, 8>, 2> pruneRedundantArguments(
686-
const SmallVector<SmallVector<Value, 8>, 2> &newArguments,
687-
RewriterBase &rewriter, Block *block) {
688-
689-
SmallVector<SmallVector<Value, 8>, 2> newArgumentsPruned(
690-
newArguments.size(), SmallVector<Value, 8>());
691-
692-
if (newArguments.empty())
693-
return newArguments;
694-
695-
// `newArguments` is a 2D array of size `numLists` x `numArgs`
696-
unsigned numLists = newArguments.size();
697-
unsigned numArgs = newArguments[0].size();
698-
699-
// Map that for each arg index contains the index that we can use in place of
700-
// the original index. E.g., if we have newArgs = [x, y, z, x], we will have
701-
// idxToReplacement[3] = 0
702-
llvm::DenseMap<unsigned, unsigned> idxToReplacement;
703-
704-
// This is a useful data structure to track the first appearance of a Value
705-
// on a given list of arguments
706-
DenseMap<Value, unsigned> firstValueToIdx;
707-
for (unsigned j = 0; j < numArgs; ++j) {
708-
Value newArg = newArguments[0][j];
709-
if (!firstValueToIdx.contains(newArg))
710-
firstValueToIdx[newArg] = j;
711-
}
712-
713-
// Go through the first list of arguments (list 0).
714-
for (unsigned j = 0; j < numArgs; ++j) {
715-
bool shouldReplaceJ = false;
716-
unsigned replacement = 0;
717-
// Look back to see if there are possible redundancies in list 0. Please
718-
// note that we are using a map to annotate when an argument was seen first
719-
// to avoid a O(N^2) algorithm. This has the drawback that if we have two
720-
// lists like:
721-
// list0: [%a, %a, %a]
722-
// list1: [%c, %b, %b]
723-
// We cannot simplify it, because firstValueToIdx[%a] = 0, but we cannot
724-
// point list1[1](==%b) or list1[2](==%b) to list1[0](==%c). However, since
725-
// the number of arguments can be potentially unbounded we cannot afford a
726-
// O(N^2) algorithm (to search all the possible pairs) and we need to
727-
// accept the trade-off.
728-
unsigned k = firstValueToIdx[newArguments[0][j]];
729-
if (k != j) {
730-
shouldReplaceJ = true;
731-
replacement = k;
732-
// If a possible redundancy is found, then scan the other lists: we
733-
// can prune the arguments if and only if they are redundant in every
734-
// list.
735-
for (unsigned i = 1; i < numLists; ++i)
736-
shouldReplaceJ =
737-
shouldReplaceJ && (newArguments[i][k] == newArguments[i][j]);
738-
}
739-
// Save the replacement.
740-
if (shouldReplaceJ)
741-
idxToReplacement[j] = replacement;
742-
}
743-
744-
// Populate the pruned argument list.
745-
for (unsigned i = 0; i < numLists; ++i)
746-
for (unsigned j = 0; j < numArgs; ++j)
747-
if (!idxToReplacement.contains(j))
748-
newArgumentsPruned[i].push_back(newArguments[i][j]);
749-
750-
// Replace the block's redundant arguments.
751-
SmallVector<unsigned> toErase;
752-
for (auto [idx, arg] : llvm::enumerate(block->getArguments())) {
753-
if (idxToReplacement.contains(idx)) {
754-
Value oldArg = block->getArgument(idx);
755-
Value newArg = block->getArgument(idxToReplacement[idx]);
756-
rewriter.replaceAllUsesWith(oldArg, newArg);
757-
toErase.push_back(idx);
758-
}
759-
}
760-
761-
// Erase the block's redundant arguments.
762-
for (unsigned idxToErase : llvm::reverse(toErase))
763-
block->eraseArgument(idxToErase);
764-
return newArgumentsPruned;
765-
}
766-
767677
LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
768678
// Don't consider clusters that don't have blocks to merge.
769679
if (blocksToMerge.empty())
@@ -812,10 +722,6 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
812722
}
813723
}
814724
}
815-
816-
// Prune redundant arguments and update the leader block argument list
817-
newArguments = pruneRedundantArguments(newArguments, rewriter, leaderBlock);
818-
819725
// Update the predecessors for each of the blocks.
820726
auto updatePredecessors = [&](Block *block, unsigned clusterIndex) {
821727
for (auto predIt = block->pred_begin(), predE = block->pred_end();
@@ -912,108 +818,6 @@ static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
912818
return success(anyChanged);
913819
}
914820

915-
static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
916-
Block &block) {
917-
SmallVector<size_t> argsToErase;
918-
919-
// Go through the arguments of the block.
920-
for (auto [argIdx, blockOperand] : llvm::enumerate(block.getArguments())) {
921-
bool sameArg = true;
922-
Value commonValue;
923-
924-
// Go through the block predecessors and flag if they pass to the block
925-
// different values for the same argument.
926-
for (auto predIt = block.pred_begin(), predE = block.pred_end();
927-
predIt != predE; ++predIt) {
928-
auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator());
929-
if (!branch) {
930-
sameArg = false;
931-
break;
932-
}
933-
unsigned succIndex = predIt.getSuccessorIndex();
934-
SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
935-
auto branchOperands = succOperands.getForwardedOperands();
936-
if (!commonValue) {
937-
commonValue = branchOperands[argIdx];
938-
} else {
939-
if (branchOperands[argIdx] != commonValue) {
940-
sameArg = false;
941-
break;
942-
}
943-
}
944-
}
945-
946-
// If they are passing the same value, drop the argument.
947-
if (commonValue && sameArg) {
948-
argsToErase.push_back(argIdx);
949-
950-
// Redirect all uses of the argument to the common value.
951-
rewriter.replaceAllUsesWith(blockOperand, commonValue);
952-
}
953-
}
954-
955-
// Remove the arguments.
956-
for (auto argIdx : llvm::reverse(argsToErase)) {
957-
block.eraseArgument(argIdx);
958-
959-
// Remove the argument from the branch ops.
960-
for (auto predIt = block.pred_begin(), predE = block.pred_end();
961-
predIt != predE; ++predIt) {
962-
auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
963-
unsigned succIndex = predIt.getSuccessorIndex();
964-
SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
965-
succOperands.erase(argIdx);
966-
}
967-
}
968-
return success(!argsToErase.empty());
969-
}
970-
971-
/// This optimization drops redundant arguments to blocks. I.e., if a given
972-
/// argument to a block receives the same value from each of the block
973-
/// predecessors, we can remove the argument from the block and use directly the
974-
/// original value. This is a simple example:
975-
///
976-
/// %cond = llvm.call @rand() : () -> i1
977-
/// %val0 = llvm.mlir.constant(1 : i64) : i64
978-
/// %val1 = llvm.mlir.constant(2 : i64) : i64
979-
/// %val2 = llvm.mlir.constant(3 : i64) : i64
980-
/// llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb2(%val0 : i64, %val2
981-
/// : i64)
982-
///
983-
/// ^bb1(%arg0 : i64, %arg1 : i64):
984-
/// llvm.call @foo(%arg0, %arg1)
985-
///
986-
/// The previous IR can be rewritten as:
987-
/// %cond = llvm.call @rand() : () -> i1
988-
/// %val0 = llvm.mlir.constant(1 : i64) : i64
989-
/// %val1 = llvm.mlir.constant(2 : i64) : i64
990-
/// %val2 = llvm.mlir.constant(3 : i64) : i64
991-
/// llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64)
992-
///
993-
/// ^bb1(%arg0 : i64):
994-
/// llvm.call @foo(%val0, %arg0)
995-
///
996-
static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
997-
MutableArrayRef<Region> regions) {
998-
llvm::SmallSetVector<Region *, 1> worklist;
999-
for (Region &region : regions)
1000-
worklist.insert(&region);
1001-
bool anyChanged = false;
1002-
while (!worklist.empty()) {
1003-
Region *region = worklist.pop_back_val();
1004-
1005-
// Add any nested regions to the worklist.
1006-
for (Block &block : *region) {
1007-
anyChanged = succeeded(dropRedundantArguments(rewriter, block));
1008-
1009-
for (Operation &op : block)
1010-
for (Region &nestedRegion : op.getRegions())
1011-
worklist.insert(&nestedRegion);
1012-
}
1013-
}
1014-
return success(anyChanged);
1015-
}
1016-
1017821
//===----------------------------------------------------------------------===//
1018822
// Region Simplification
1019823
//===----------------------------------------------------------------------===//
@@ -1028,12 +832,8 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
1028832
bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
1029833
bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
1030834
bool mergedIdenticalBlocks = false;
1031-
bool droppedRedundantArguments = false;
1032-
if (mergeBlocks) {
835+
if (mergeBlocks)
1033836
mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(rewriter, regions));
1034-
droppedRedundantArguments =
1035-
succeeded(dropRedundantArguments(rewriter, regions));
1036-
}
1037837
return success(eliminatedBlocks || eliminatedOpsOrArgs ||
1038-
mergedIdenticalBlocks || droppedRedundantArguments);
838+
mergedIdenticalBlocks);
1039839
}

mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -178,32 +178,28 @@ func.func @condBranchDynamicTypeNested(
178178
// CHECK-NEXT: ^bb1
179179
// CHECK-NOT: bufferization.dealloc
180180
// CHECK-NOT: bufferization.clone
181-
// CHECK: cf.br ^bb6([[ARG1]], %false{{[0-9_]*}} :
181+
// CHECK: cf.br ^bb5([[ARG1]], %false{{[0-9_]*}} :
182182
// CHECK: ^bb2([[IDX:%.*]]:{{.*}})
183183
// CHECK: [[ALLOC1:%.*]] = memref.alloc([[IDX]])
184184
// CHECK-NEXT: test.buffer_based
185185
// CHECK-NEXT: [[NOT_ARG0:%.+]] = arith.xori [[ARG0]], %true
186186
// CHECK-NEXT: [[OWN:%.+]] = arith.select [[ARG0]], [[ARG0]], [[NOT_ARG0]]
187187
// CHECK-NOT: bufferization.dealloc
188188
// CHECK-NOT: bufferization.clone
189-
// CHECK: cf.cond_br{{.*}}, ^bb3, ^bb4
189+
// CHECK: cf.cond_br{{.*}}, ^bb3, ^bb3
190190
// CHECK-NEXT: ^bb3:
191191
// CHECK-NOT: bufferization.dealloc
192192
// CHECK-NOT: bufferization.clone
193-
// CHECK: cf.br ^bb5([[ALLOC1]], [[OWN]]
194-
// CHECK-NEXT: ^bb4:
193+
// CHECK: cf.br ^bb4([[ALLOC1]], [[OWN]]
194+
// CHECK-NEXT: ^bb4([[ALLOC2:%.*]]:{{.*}}, [[COND1:%.+]]:{{.*}})
195195
// CHECK-NOT: bufferization.dealloc
196196
// CHECK-NOT: bufferization.clone
197-
// CHECK: cf.br ^bb5([[ALLOC1]], [[OWN]]
198-
// CHECK-NEXT: ^bb5([[ALLOC2:%.*]]:{{.*}}, [[COND1:%.+]]:{{.*}})
199-
// CHECK-NOT: bufferization.dealloc
200-
// CHECK-NOT: bufferization.clone
201-
// CHECK: cf.br ^bb6([[ALLOC2]], [[COND1]]
202-
// CHECK-NEXT: ^bb6([[ALLOC4:%.*]]:{{.*}}, [[COND2:%.+]]:{{.*}})
197+
// CHECK: cf.br ^bb5([[ALLOC2]], [[COND1]]
198+
// CHECK-NEXT: ^bb5([[ALLOC4:%.*]]:{{.*}}, [[COND2:%.+]]:{{.*}})
203199
// CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC4]]
204200
// CHECK-NEXT: [[OWN:%.+]]:2 = bufferization.dealloc ([[BASE]] :{{.*}}) if ([[COND2]]) retain ([[ALLOC4]], [[ARG2]] :
205-
// CHECK: cf.br ^bb7([[ALLOC4]], [[OWN]]#0
206-
// CHECK-NEXT: ^bb7([[ALLOC5:%.*]]:{{.*}}, [[COND3:%.+]]:{{.*}})
201+
// CHECK: cf.br ^bb6([[ALLOC4]], [[OWN]]#0
202+
// CHECK-NEXT: ^bb6([[ALLOC5:%.*]]:{{.*}}, [[COND3:%.+]]:{{.*}})
207203
// CHECK: test.copy
208204
// CHECK: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC5]]
209205
// CHECK-NEXT: bufferization.dealloc ([[BASE]] : {{.*}}) if ([[COND3]])

mlir/test/Dialect/Linalg/detensorize_entry_block.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ func.func @main(%arg0: tensor<f32>) -> tensor<f32> {
1515
// CHECK-LABEL: @main
1616
// CHECK-SAME: (%[[ARG0:.+]]: tensor<f32>) -> tensor<f32>
1717
// CHECK: %[[EXTRACTED:.+]] = tensor.extract %[[ARG0]][] : tensor<f32>
18-
// CHECK: cf.br ^{{.*}}
19-
// CHECK: ^{{.*}}:
20-
// CHECK: %[[ELEMENTS:.+]] = tensor.from_elements %[[EXTRACTED]] : tensor<f32>
18+
// CHECK: cf.br ^{{.*}}(%[[EXTRACTED]] : f32)
19+
// CHECK: ^{{.*}}(%[[ARG1:.+]]: f32):
20+
// CHECK: %[[ELEMENTS:.+]] = tensor.from_elements %[[ARG1]] : tensor<f32>
2121
// CHECK: return %[[ELEMENTS]] : tensor<f32>

0 commit comments

Comments
 (0)