9
9
#include " mlir/Transforms/RegionUtils.h"
10
10
#include " mlir/Analysis/TopologicalSortUtils.h"
11
11
#include " mlir/IR/Block.h"
12
+ #include " mlir/IR/BuiltinOps.h"
12
13
#include " mlir/IR/IRMapping.h"
13
14
#include " mlir/IR/Operation.h"
14
15
#include " mlir/IR/PatternMatch.h"
15
16
#include " mlir/IR/RegionGraphTraits.h"
16
17
#include " mlir/IR/Value.h"
17
18
#include " mlir/Interfaces/ControlFlowInterfaces.h"
18
19
#include " mlir/Interfaces/SideEffectInterfaces.h"
20
+ #include " mlir/Support/LogicalResult.h"
19
21
20
22
#include " llvm/ADT/DepthFirstIterator.h"
21
23
#include " llvm/ADT/PostOrderIterator.h"
24
+ #include " llvm/ADT/STLExtras.h"
25
+ #include " llvm/ADT/SmallSet.h"
22
26
23
27
#include < deque>
28
+ #include < iterator>
24
29
25
30
using namespace mlir ;
26
31
@@ -674,6 +679,91 @@ static bool ableToUpdatePredOperands(Block *block) {
674
679
return true ;
675
680
}
676
681
682
+ // / Prunes the redundant list of arguments. E.g., if we are passing an argument
683
+ // / list like [x, y, z, x] this would return [x, y, z] and it would update the
684
+ // / `block` (to whom the argument are passed to) accordingly.
685
+ static SmallVector<SmallVector<Value, 8 >, 2 > pruneRedundantArguments (
686
+ const SmallVector<SmallVector<Value, 8 >, 2 > &newArguments,
687
+ RewriterBase &rewriter, Block *block) {
688
+
689
+ SmallVector<SmallVector<Value, 8 >, 2 > newArgumentsPruned (
690
+ newArguments.size (), SmallVector<Value, 8 >());
691
+
692
+ if (newArguments.empty ())
693
+ return newArguments;
694
+
695
+ // `newArguments` is a 2D array of size `numLists` x `numArgs`
696
+ unsigned numLists = newArguments.size ();
697
+ unsigned numArgs = newArguments[0 ].size ();
698
+
699
+ // Map that for each arg index contains the index that we can use in place of
700
+ // the original index. E.g., if we have newArgs = [x, y, z, x], we will have
701
+ // idxToReplacement[3] = 0
702
+ llvm::DenseMap<unsigned , unsigned > idxToReplacement;
703
+
704
+ // This is a useful data structure to track the first appearance of a Value
705
+ // on a given list of arguments
706
+ DenseMap<Value, unsigned > firstValueToIdx;
707
+ for (unsigned j = 0 ; j < numArgs; ++j) {
708
+ Value newArg = newArguments[0 ][j];
709
+ if (!firstValueToIdx.contains (newArg))
710
+ firstValueToIdx[newArg] = j;
711
+ }
712
+
713
+ // Go through the first list of arguments (list 0).
714
+ for (unsigned j = 0 ; j < numArgs; ++j) {
715
+ bool shouldReplaceJ = false ;
716
+ unsigned replacement = 0 ;
717
+ // Look back to see if there are possible redundancies in list 0. Please
718
+ // note that we are using a map to annotate when an argument was seen first
719
+ // to avoid a O(N^2) algorithm. This has the drawback that if we have two
720
+ // lists like:
721
+ // list0: [%a, %a, %a]
722
+ // list1: [%c, %b, %b]
723
+ // We cannot simplify it, because firstVlaueToIdx[%a] = 0, but we cannot
724
+ // point list1[1](==%b) or list1[2](==%b) to list1[0](==%c). However, since
725
+ // the number of arguments can be potentially unbounded we cannot afford a
726
+ // O(N^2) algorithm (to search to all the possible pairs) and we need to
727
+ // accept the trade-off.
728
+ unsigned k = firstValueToIdx[newArguments[0 ][j]];
729
+ if (k != j) {
730
+ shouldReplaceJ = true ;
731
+ replacement = k;
732
+ // If a possible redundancy is found, then scan the other lists: we
733
+ // can prune the arguments if and only if they are redundant in every
734
+ // list.
735
+ for (unsigned i = 1 ; i < numLists; ++i)
736
+ shouldReplaceJ =
737
+ shouldReplaceJ && (newArguments[i][k] == newArguments[i][j]);
738
+ }
739
+ // Save the replacement.
740
+ if (shouldReplaceJ)
741
+ idxToReplacement[j] = replacement;
742
+ }
743
+
744
+ // Populate the pruned argument list.
745
+ for (unsigned i = 0 ; i < numLists; ++i)
746
+ for (unsigned j = 0 ; j < numArgs; ++j)
747
+ if (!idxToReplacement.contains (j))
748
+ newArgumentsPruned[i].push_back (newArguments[i][j]);
749
+
750
+ // Replace the block's redundant arguments.
751
+ SmallVector<unsigned > toErase;
752
+ for (auto [idx, arg] : llvm::enumerate (block->getArguments ())) {
753
+ if (idxToReplacement.contains (idx)) {
754
+ Value oldArg = block->getArgument (idx);
755
+ Value newArg = block->getArgument (idxToReplacement[idx]);
756
+ rewriter.replaceAllUsesWith (oldArg, newArg);
757
+ toErase.push_back (idx);
758
+ }
759
+ }
760
+
761
+ // Erase the block's redundant arguments.
762
+ for (unsigned idxToErase : llvm::reverse (toErase))
763
+ block->eraseArgument (idxToErase);
764
+ return newArgumentsPruned;
765
+ }
766
+
677
767
LogicalResult BlockMergeCluster::merge (RewriterBase &rewriter) {
678
768
// Don't consider clusters that don't have blocks to merge.
679
769
if (blocksToMerge.empty ())
@@ -722,6 +812,10 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
722
812
}
723
813
}
724
814
}
815
+
816
+ // Prune redundant arguments and update the leader block argument list
817
+ newArguments = pruneRedundantArguments (newArguments, rewriter, leaderBlock);
818
+
725
819
// Update the predecessors for each of the blocks.
726
820
auto updatePredecessors = [&](Block *block, unsigned clusterIndex) {
727
821
for (auto predIt = block->pred_begin (), predE = block->pred_end ();
@@ -818,6 +912,108 @@ static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
818
912
return success (anyChanged);
819
913
}
820
914
915
+ static LogicalResult dropRedundantArguments (RewriterBase &rewriter,
916
+ Block &block) {
917
+ SmallVector<size_t > argsToErase;
918
+
919
+ // Go through the arguments of the block.
920
+ for (auto [argIdx, blockOperand] : llvm::enumerate (block.getArguments ())) {
921
+ bool sameArg = true ;
922
+ Value commonValue;
923
+
924
+ // Go through the block predecessor and flag if they pass to the block
925
+ // different values for the same argument.
926
+ for (auto predIt = block.pred_begin (), predE = block.pred_end ();
927
+ predIt != predE; ++predIt) {
928
+ auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator ());
929
+ if (!branch) {
930
+ sameArg = false ;
931
+ break ;
932
+ }
933
+ unsigned succIndex = predIt.getSuccessorIndex ();
934
+ SuccessorOperands succOperands = branch.getSuccessorOperands (succIndex);
935
+ auto branchOperands = succOperands.getForwardedOperands ();
936
+ if (!commonValue) {
937
+ commonValue = branchOperands[argIdx];
938
+ } else {
939
+ if (branchOperands[argIdx] != commonValue) {
940
+ sameArg = false ;
941
+ break ;
942
+ }
943
+ }
944
+ }
945
+
946
+ // If they are passing the same value, drop the argument.
947
+ if (commonValue && sameArg) {
948
+ argsToErase.push_back (argIdx);
949
+
950
+ // Remove the argument from the block.
951
+ rewriter.replaceAllUsesWith (blockOperand, commonValue);
952
+ }
953
+ }
954
+
955
+ // Remove the arguments.
956
+ for (auto argIdx : llvm::reverse (argsToErase)) {
957
+ block.eraseArgument (argIdx);
958
+
959
+ // Remove the argument from the branch ops.
960
+ for (auto predIt = block.pred_begin (), predE = block.pred_end ();
961
+ predIt != predE; ++predIt) {
962
+ auto branch = cast<BranchOpInterface>((*predIt)->getTerminator ());
963
+ unsigned succIndex = predIt.getSuccessorIndex ();
964
+ SuccessorOperands succOperands = branch.getSuccessorOperands (succIndex);
965
+ succOperands.erase (argIdx);
966
+ }
967
+ }
968
+ return success (!argsToErase.empty ());
969
+ }
970
+
971
+ // / This optimization drops redundant argument to blocks. I.e., if a given
972
+ // / argument to a block receives the same value from each of the block
973
+ // / predecessors, we can remove the argument from the block and use directly the
974
+ // / original value. This is a simple example:
975
+ // /
976
+ // / %cond = llvm.call @rand() : () -> i1
977
+ // / %val0 = llvm.mlir.constant(1 : i64) : i64
978
+ // / %val1 = llvm.mlir.constant(2 : i64) : i64
979
+ // / %val2 = llvm.mlir.constant(3 : i64) : i64
980
+ // / llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb2(%val0 : i64, %val2
981
+ // / : i64)
982
+ // /
983
+ // / ^bb1(%arg0 : i64, %arg1 : i64):
984
+ // / llvm.call @foo(%arg0, %arg1)
985
+ // /
986
+ // / The previous IR can be rewritten as:
987
+ // / %cond = llvm.call @rand() : () -> i1
988
+ // / %val0 = llvm.mlir.constant(1 : i64) : i64
989
+ // / %val1 = llvm.mlir.constant(2 : i64) : i64
990
+ // / %val2 = llvm.mlir.constant(3 : i64) : i64
991
+ // / llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64)
992
+ // /
993
+ // / ^bb1(%arg0 : i64):
994
+ // / llvm.call @foo(%val0, %arg0)
995
+ // /
996
+ static LogicalResult dropRedundantArguments (RewriterBase &rewriter,
997
+ MutableArrayRef<Region> regions) {
998
+ llvm::SmallSetVector<Region *, 1 > worklist;
999
+ for (Region ®ion : regions)
1000
+ worklist.insert (®ion);
1001
+ bool anyChanged = false ;
1002
+ while (!worklist.empty ()) {
1003
+ Region *region = worklist.pop_back_val ();
1004
+
1005
+ // Add any nested regions to the worklist.
1006
+ for (Block &block : *region) {
1007
+ anyChanged = succeeded (dropRedundantArguments (rewriter, block));
1008
+
1009
+ for (Operation &op : block)
1010
+ for (Region &nestedRegion : op.getRegions ())
1011
+ worklist.insert (&nestedRegion);
1012
+ }
1013
+ }
1014
+ return success (anyChanged);
1015
+ }
1016
+
821
1017
// ===----------------------------------------------------------------------===//
822
1018
// Region Simplification
823
1019
// ===----------------------------------------------------------------------===//
@@ -832,8 +1028,12 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
832
1028
bool eliminatedBlocks = succeeded (eraseUnreachableBlocks (rewriter, regions));
833
1029
bool eliminatedOpsOrArgs = succeeded (runRegionDCE (rewriter, regions));
834
1030
bool mergedIdenticalBlocks = false ;
835
- if (mergeBlocks)
1031
+ bool droppedRedundantArguments = false ;
1032
+ if (mergeBlocks) {
836
1033
mergedIdenticalBlocks = succeeded (mergeIdenticalBlocks (rewriter, regions));
1034
+ droppedRedundantArguments =
1035
+ succeeded (dropRedundantArguments (rewriter, regions));
1036
+ }
837
1037
return success (eliminatedBlocks || eliminatedOpsOrArgs ||
838
- mergedIdenticalBlocks);
1038
+ mergedIdenticalBlocks || droppedRedundantArguments );
839
1039
}
0 commit comments