9
9
#include " mlir/Transforms/RegionUtils.h"
10
10
#include " mlir/Analysis/TopologicalSortUtils.h"
11
11
#include " mlir/IR/Block.h"
12
- #include " mlir/IR/BuiltinOps.h"
13
12
#include " mlir/IR/IRMapping.h"
14
13
#include " mlir/IR/Operation.h"
15
14
#include " mlir/IR/PatternMatch.h"
16
15
#include " mlir/IR/RegionGraphTraits.h"
17
16
#include " mlir/IR/Value.h"
18
17
#include " mlir/Interfaces/ControlFlowInterfaces.h"
19
18
#include " mlir/Interfaces/SideEffectInterfaces.h"
20
- #include " mlir/Support/LogicalResult.h"
21
19
22
20
#include " llvm/ADT/DepthFirstIterator.h"
23
21
#include " llvm/ADT/PostOrderIterator.h"
24
- #include " llvm/ADT/STLExtras.h"
25
- #include " llvm/ADT/SmallSet.h"
26
22
27
23
#include < deque>
28
- #include < iterator>
29
24
30
25
using namespace mlir ;
31
26
@@ -679,91 +674,6 @@ static bool ableToUpdatePredOperands(Block *block) {
679
674
return true ;
680
675
}
681
676
682
- // / Prunes the redundant list of arguments. E.g., if we are passing an argument
683
- // / list like [x, y, z, x] this would return [x, y, z] and it would update the
684
- // / `block` (to whom the argument are passed to) accordingly.
685
- static SmallVector<SmallVector<Value, 8 >, 2 > pruneRedundantArguments (
686
- const SmallVector<SmallVector<Value, 8 >, 2 > &newArguments,
687
- RewriterBase &rewriter, Block *block) {
688
-
689
- SmallVector<SmallVector<Value, 8 >, 2 > newArgumentsPruned (
690
- newArguments.size (), SmallVector<Value, 8 >());
691
-
692
- if (newArguments.empty ())
693
- return newArguments;
694
-
695
- // `newArguments` is a 2D array of size `numLists` x `numArgs`
696
- unsigned numLists = newArguments.size ();
697
- unsigned numArgs = newArguments[0 ].size ();
698
-
699
- // Map that for each arg index contains the index that we can use in place of
700
- // the original index. E.g., if we have newArgs = [x, y, z, x], we will have
701
- // idxToReplacement[3] = 0
702
- llvm::DenseMap<unsigned , unsigned > idxToReplacement;
703
-
704
- // This is a useful data structure to track the first appearance of a Value
705
- // on a given list of arguments
706
- DenseMap<Value, unsigned > firstValueToIdx;
707
- for (unsigned j = 0 ; j < numArgs; ++j) {
708
- Value newArg = newArguments[0 ][j];
709
- if (!firstValueToIdx.contains (newArg))
710
- firstValueToIdx[newArg] = j;
711
- }
712
-
713
- // Go through the first list of arguments (list 0).
714
- for (unsigned j = 0 ; j < numArgs; ++j) {
715
- bool shouldReplaceJ = false ;
716
- unsigned replacement = 0 ;
717
- // Look back to see if there are possible redundancies in list 0. Please
718
- // note that we are using a map to annotate when an argument was seen first
719
- // to avoid a O(N^2) algorithm. This has the drawback that if we have two
720
- // lists like:
721
- // list0: [%a, %a, %a]
722
- // list1: [%c, %b, %b]
723
- // We cannot simplify it, because firstVlaueToIdx[%a] = 0, but we cannot
724
- // point list1[1](==%b) or list1[2](==%b) to list1[0](==%c). However, since
725
- // the number of arguments can be potentially unbounded we cannot afford a
726
- // O(N^2) algorithm (to search to all the possible pairs) and we need to
727
- // accept the trade-off.
728
- unsigned k = firstValueToIdx[newArguments[0 ][j]];
729
- if (k != j) {
730
- shouldReplaceJ = true ;
731
- replacement = k;
732
- // If a possible redundancy is found, then scan the other lists: we
733
- // can prune the arguments if and only if they are redundant in every
734
- // list.
735
- for (unsigned i = 1 ; i < numLists; ++i)
736
- shouldReplaceJ =
737
- shouldReplaceJ && (newArguments[i][k] == newArguments[i][j]);
738
- }
739
- // Save the replacement.
740
- if (shouldReplaceJ)
741
- idxToReplacement[j] = replacement;
742
- }
743
-
744
- // Populate the pruned argument list.
745
- for (unsigned i = 0 ; i < numLists; ++i)
746
- for (unsigned j = 0 ; j < numArgs; ++j)
747
- if (!idxToReplacement.contains (j))
748
- newArgumentsPruned[i].push_back (newArguments[i][j]);
749
-
750
- // Replace the block's redundant arguments.
751
- SmallVector<unsigned > toErase;
752
- for (auto [idx, arg] : llvm::enumerate (block->getArguments ())) {
753
- if (idxToReplacement.contains (idx)) {
754
- Value oldArg = block->getArgument (idx);
755
- Value newArg = block->getArgument (idxToReplacement[idx]);
756
- rewriter.replaceAllUsesWith (oldArg, newArg);
757
- toErase.push_back (idx);
758
- }
759
- }
760
-
761
- // Erase the block's redundant arguments.
762
- for (unsigned idxToErase : llvm::reverse (toErase))
763
- block->eraseArgument (idxToErase);
764
- return newArgumentsPruned;
765
- }
766
-
767
677
LogicalResult BlockMergeCluster::merge (RewriterBase &rewriter) {
768
678
// Don't consider clusters that don't have blocks to merge.
769
679
if (blocksToMerge.empty ())
@@ -812,10 +722,6 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
812
722
}
813
723
}
814
724
}
815
-
816
- // Prune redundant arguments and update the leader block argument list
817
- newArguments = pruneRedundantArguments (newArguments, rewriter, leaderBlock);
818
-
819
725
// Update the predecessors for each of the blocks.
820
726
auto updatePredecessors = [&](Block *block, unsigned clusterIndex) {
821
727
for (auto predIt = block->pred_begin (), predE = block->pred_end ();
@@ -912,108 +818,6 @@ static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
912
818
return success (anyChanged);
913
819
}
914
820
915
- static LogicalResult dropRedundantArguments (RewriterBase &rewriter,
916
- Block &block) {
917
- SmallVector<size_t > argsToErase;
918
-
919
- // Go through the arguments of the block.
920
- for (auto [argIdx, blockOperand] : llvm::enumerate (block.getArguments ())) {
921
- bool sameArg = true ;
922
- Value commonValue;
923
-
924
- // Go through the block predecessor and flag if they pass to the block
925
- // different values for the same argument.
926
- for (auto predIt = block.pred_begin (), predE = block.pred_end ();
927
- predIt != predE; ++predIt) {
928
- auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator ());
929
- if (!branch) {
930
- sameArg = false ;
931
- break ;
932
- }
933
- unsigned succIndex = predIt.getSuccessorIndex ();
934
- SuccessorOperands succOperands = branch.getSuccessorOperands (succIndex);
935
- auto branchOperands = succOperands.getForwardedOperands ();
936
- if (!commonValue) {
937
- commonValue = branchOperands[argIdx];
938
- } else {
939
- if (branchOperands[argIdx] != commonValue) {
940
- sameArg = false ;
941
- break ;
942
- }
943
- }
944
- }
945
-
946
- // If they are passing the same value, drop the argument.
947
- if (commonValue && sameArg) {
948
- argsToErase.push_back (argIdx);
949
-
950
- // Remove the argument from the block.
951
- rewriter.replaceAllUsesWith (blockOperand, commonValue);
952
- }
953
- }
954
-
955
- // Remove the arguments.
956
- for (auto argIdx : llvm::reverse (argsToErase)) {
957
- block.eraseArgument (argIdx);
958
-
959
- // Remove the argument from the branch ops.
960
- for (auto predIt = block.pred_begin (), predE = block.pred_end ();
961
- predIt != predE; ++predIt) {
962
- auto branch = cast<BranchOpInterface>((*predIt)->getTerminator ());
963
- unsigned succIndex = predIt.getSuccessorIndex ();
964
- SuccessorOperands succOperands = branch.getSuccessorOperands (succIndex);
965
- succOperands.erase (argIdx);
966
- }
967
- }
968
- return success (!argsToErase.empty ());
969
- }
970
-
971
- // / This optimization drops redundant argument to blocks. I.e., if a given
972
- // / argument to a block receives the same value from each of the block
973
- // / predecessors, we can remove the argument from the block and use directly the
974
- // / original value. This is a simple example:
975
- // /
976
- // / %cond = llvm.call @rand() : () -> i1
977
- // / %val0 = llvm.mlir.constant(1 : i64) : i64
978
- // / %val1 = llvm.mlir.constant(2 : i64) : i64
979
- // / %val2 = llvm.mlir.constant(3 : i64) : i64
980
- // / llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb2(%val0 : i64, %val2
981
- // / : i64)
982
- // /
983
- // / ^bb1(%arg0 : i64, %arg1 : i64):
984
- // / llvm.call @foo(%arg0, %arg1)
985
- // /
986
- // / The previous IR can be rewritten as:
987
- // / %cond = llvm.call @rand() : () -> i1
988
- // / %val0 = llvm.mlir.constant(1 : i64) : i64
989
- // / %val1 = llvm.mlir.constant(2 : i64) : i64
990
- // / %val2 = llvm.mlir.constant(3 : i64) : i64
991
- // / llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64)
992
- // /
993
- // / ^bb1(%arg0 : i64):
994
- // / llvm.call @foo(%val0, %arg0)
995
- // /
996
- static LogicalResult dropRedundantArguments (RewriterBase &rewriter,
997
- MutableArrayRef<Region> regions) {
998
- llvm::SmallSetVector<Region *, 1 > worklist;
999
- for (Region ®ion : regions)
1000
- worklist.insert (®ion);
1001
- bool anyChanged = false ;
1002
- while (!worklist.empty ()) {
1003
- Region *region = worklist.pop_back_val ();
1004
-
1005
- // Add any nested regions to the worklist.
1006
- for (Block &block : *region) {
1007
- anyChanged = succeeded (dropRedundantArguments (rewriter, block));
1008
-
1009
- for (Operation &op : block)
1010
- for (Region &nestedRegion : op.getRegions ())
1011
- worklist.insert (&nestedRegion);
1012
- }
1013
- }
1014
- return success (anyChanged);
1015
- }
1016
-
1017
821
// ===----------------------------------------------------------------------===//
1018
822
// Region Simplification
1019
823
// ===----------------------------------------------------------------------===//
@@ -1028,12 +832,8 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
1028
832
bool eliminatedBlocks = succeeded (eraseUnreachableBlocks (rewriter, regions));
1029
833
bool eliminatedOpsOrArgs = succeeded (runRegionDCE (rewriter, regions));
1030
834
bool mergedIdenticalBlocks = false ;
1031
- bool droppedRedundantArguments = false ;
1032
- if (mergeBlocks) {
835
+ if (mergeBlocks)
1033
836
mergedIdenticalBlocks = succeeded (mergeIdenticalBlocks (rewriter, regions));
1034
- droppedRedundantArguments =
1035
- succeeded (dropRedundantArguments (rewriter, regions));
1036
- }
1037
837
return success (eliminatedBlocks || eliminatedOpsOrArgs ||
1038
- mergedIdenticalBlocks || droppedRedundantArguments );
838
+ mergedIdenticalBlocks);
1039
839
}
0 commit comments