9
9
#include " mlir/Transforms/RegionUtils.h"
10
10
#include " mlir/Analysis/TopologicalSortUtils.h"
11
11
#include " mlir/IR/Block.h"
12
+ #include " mlir/IR/BuiltinOps.h"
12
13
#include " mlir/IR/IRMapping.h"
13
14
#include " mlir/IR/Operation.h"
14
15
#include " mlir/IR/PatternMatch.h"
15
16
#include " mlir/IR/RegionGraphTraits.h"
16
17
#include " mlir/IR/Value.h"
17
18
#include " mlir/Interfaces/ControlFlowInterfaces.h"
18
19
#include " mlir/Interfaces/SideEffectInterfaces.h"
20
+ #include " mlir/Support/LogicalResult.h"
19
21
20
22
#include " llvm/ADT/DepthFirstIterator.h"
21
23
#include " llvm/ADT/PostOrderIterator.h"
24
+ #include " llvm/ADT/STLExtras.h"
25
+ #include " llvm/ADT/SmallSet.h"
22
26
23
27
#include < deque>
28
+ #include < iterator>
24
29
25
30
using namespace mlir ;
26
31
@@ -674,6 +679,94 @@ static bool ableToUpdatePredOperands(Block *block) {
674
679
return true ;
675
680
}
676
681
682
+ // / Prunes the redundant list of new arguments. E.g., if we are passing an
683
+ // / argument list like [x, y, z, x] this would return [x, y, z] and it would
684
+ // / update the `block` (to whom the argument are passed to) accordingly. The new
685
+ // / arguments are passed as arguments at the back of the block, hence we need to
686
+ // / know how many `numOldArguments` were before, in order to correctly replace
687
+ // / the new arguments in the block
688
+ static SmallVector<SmallVector<Value, 8 >, 2 > pruneRedundantArguments (
689
+ const SmallVector<SmallVector<Value, 8 >, 2 > &newArguments,
690
+ RewriterBase &rewriter, unsigned numOldArguments, Block *block) {
691
+
692
+ SmallVector<SmallVector<Value, 8 >, 2 > newArgumentsPruned (
693
+ newArguments.size (), SmallVector<Value, 8 >());
694
+
695
+ if (newArguments.empty ())
696
+ return newArguments;
697
+
698
+ // `newArguments` is a 2D array of size `numLists` x `numArgs`
699
+ unsigned numLists = newArguments.size ();
700
+ unsigned numArgs = newArguments[0 ].size ();
701
+
702
+ // Map that for each arg index contains the index that we can use in place of
703
+ // the original index. E.g., if we have newArgs = [x, y, z, x], we will have
704
+ // idxToReplacement[3] = 0
705
+ llvm::DenseMap<unsigned , unsigned > idxToReplacement;
706
+
707
+ // This is a useful data structure to track the first appearance of a Value
708
+ // on a given list of arguments
709
+ DenseMap<Value, unsigned > firstValueToIdx;
710
+ for (unsigned j = 0 ; j < numArgs; ++j) {
711
+ Value newArg = newArguments[0 ][j];
712
+ if (!firstValueToIdx.contains (newArg))
713
+ firstValueToIdx[newArg] = j;
714
+ }
715
+
716
+ // Go through the first list of arguments (list 0).
717
+ for (unsigned j = 0 ; j < numArgs; ++j) {
718
+ // Look back to see if there are possible redundancies in list 0. Please
719
+ // note that we are using a map to annotate when an argument was seen first
720
+ // to avoid a O(N^2) algorithm. This has the drawback that if we have two
721
+ // lists like:
722
+ // list0: [%a, %a, %a]
723
+ // list1: [%c, %b, %b]
724
+ // We cannot simplify it, because firstValueToIdx[%a] = 0, but we cannot
725
+ // point list1[1](==%b) or list1[2](==%b) to list1[0](==%c). However, since
726
+ // the number of arguments can be potentially unbounded we cannot afford a
727
+ // O(N^2) algorithm (to search to all the possible pairs) and we need to
728
+ // accept the trade-off.
729
+ unsigned k = firstValueToIdx[newArguments[0 ][j]];
730
+ if (k == j)
731
+ continue ;
732
+
733
+ bool shouldReplaceJ = true ;
734
+ unsigned replacement = k;
735
+ // If a possible redundancy is found, then scan the other lists: we
736
+ // can prune the arguments if and only if they are redundant in every
737
+ // list.
738
+ for (unsigned i = 1 ; i < numLists; ++i)
739
+ shouldReplaceJ =
740
+ shouldReplaceJ && (newArguments[i][k] == newArguments[i][j]);
741
+ // Save the replacement.
742
+ if (shouldReplaceJ)
743
+ idxToReplacement[j] = replacement;
744
+ }
745
+
746
+ // Populate the pruned argument list.
747
+ for (unsigned i = 0 ; i < numLists; ++i)
748
+ for (unsigned j = 0 ; j < numArgs; ++j)
749
+ if (!idxToReplacement.contains (j))
750
+ newArgumentsPruned[i].push_back (newArguments[i][j]);
751
+
752
+ // Replace the block's redundant arguments.
753
+ SmallVector<unsigned > toErase;
754
+ for (auto [idx, arg] : llvm::enumerate (block->getArguments ())) {
755
+ if (idxToReplacement.contains (idx)) {
756
+ Value oldArg = block->getArgument (numOldArguments + idx);
757
+ Value newArg =
758
+ block->getArgument (numOldArguments + idxToReplacement[idx]);
759
+ rewriter.replaceAllUsesWith (oldArg, newArg);
760
+ toErase.push_back (numOldArguments + idx);
761
+ }
762
+ }
763
+
764
+ // Erase the block's redundant arguments.
765
+ for (unsigned idxToErase : llvm::reverse (toErase))
766
+ block->eraseArgument (idxToErase);
767
+ return newArgumentsPruned;
768
+ }
769
+
677
770
LogicalResult BlockMergeCluster::merge (RewriterBase &rewriter) {
678
771
// Don't consider clusters that don't have blocks to merge.
679
772
if (blocksToMerge.empty ())
@@ -703,6 +796,7 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
703
796
1 + blocksToMerge.size (),
704
797
SmallVector<Value, 8 >(operandsToMerge.size ()));
705
798
unsigned curOpIndex = 0 ;
799
+ unsigned numOldArguments = leaderBlock->getNumArguments ();
706
800
for (const auto &it : llvm::enumerate (operandsToMerge)) {
707
801
unsigned nextOpOffset = it.value ().first - curOpIndex;
708
802
curOpIndex = it.value ().first ;
@@ -722,6 +816,11 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
722
816
}
723
817
}
724
818
}
819
+
820
+ // Prune redundant arguments and update the leader block argument list
821
+ newArguments = pruneRedundantArguments (newArguments, rewriter,
822
+ numOldArguments, leaderBlock);
823
+
725
824
// Update the predecessors for each of the blocks.
726
825
auto updatePredecessors = [&](Block *block, unsigned clusterIndex) {
727
826
for (auto predIt = block->pred_begin (), predE = block->pred_end ();
@@ -818,6 +917,111 @@ static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
818
917
return success (anyChanged);
819
918
}
820
919
920
+ // / If a block's argument is always the same across different invocations, then
921
+ // / drop the argument and use the value directly inside the block
922
+ static LogicalResult dropRedundantArguments (RewriterBase &rewriter,
923
+ Block &block) {
924
+ SmallVector<size_t > argsToErase;
925
+
926
+ // Go through the arguments of the block.
927
+ for (auto [argIdx, blockOperand] : llvm::enumerate (block.getArguments ())) {
928
+ bool sameArg = true ;
929
+ Value commonValue;
930
+
931
+ // Go through the block predecessor and flag if they pass to the block
932
+ // different values for the same argument.
933
+ for (Block::pred_iterator predIt = block.pred_begin (),
934
+ predE = block.pred_end ();
935
+ predIt != predE; ++predIt) {
936
+ auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator ());
937
+ if (!branch) {
938
+ sameArg = false ;
939
+ break ;
940
+ }
941
+ unsigned succIndex = predIt.getSuccessorIndex ();
942
+ SuccessorOperands succOperands = branch.getSuccessorOperands (succIndex);
943
+ auto branchOperands = succOperands.getForwardedOperands ();
944
+ if (!commonValue) {
945
+ commonValue = branchOperands[argIdx];
946
+ continue ;
947
+ }
948
+ if (branchOperands[argIdx] != commonValue) {
949
+ sameArg = false ;
950
+ break ;
951
+ }
952
+ }
953
+
954
+ // If they are passing the same value, drop the argument.
955
+ if (commonValue && sameArg) {
956
+ argsToErase.push_back (argIdx);
957
+
958
+ // Remove the argument from the block.
959
+ rewriter.replaceAllUsesWith (blockOperand, commonValue);
960
+ }
961
+ }
962
+
963
+ // Remove the arguments.
964
+ for (size_t argIdx : llvm::reverse (argsToErase)) {
965
+ block.eraseArgument (argIdx);
966
+
967
+ // Remove the argument from the branch ops.
968
+ for (auto predIt = block.pred_begin (), predE = block.pred_end ();
969
+ predIt != predE; ++predIt) {
970
+ auto branch = cast<BranchOpInterface>((*predIt)->getTerminator ());
971
+ unsigned succIndex = predIt.getSuccessorIndex ();
972
+ SuccessorOperands succOperands = branch.getSuccessorOperands (succIndex);
973
+ succOperands.erase (argIdx);
974
+ }
975
+ }
976
+ return success (!argsToErase.empty ());
977
+ }
978
+
979
+ // / This optimization drops redundant argument to blocks. I.e., if a given
980
+ // / argument to a block receives the same value from each of the block
981
+ // / predecessors, we can remove the argument from the block and use directly the
982
+ // / original value. This is a simple example:
983
+ // /
984
+ // / %cond = llvm.call @rand() : () -> i1
985
+ // / %val0 = llvm.mlir.constant(1 : i64) : i64
986
+ // / %val1 = llvm.mlir.constant(2 : i64) : i64
987
+ // / %val2 = llvm.mlir.constant(3 : i64) : i64
988
+ // / llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb2(%val0 : i64, %val2
989
+ // / : i64)
990
+ // /
991
+ // / ^bb1(%arg0 : i64, %arg1 : i64):
992
+ // / llvm.call @foo(%arg0, %arg1)
993
+ // /
994
+ // / The previous IR can be rewritten as:
995
+ // / %cond = llvm.call @rand() : () -> i1
996
+ // / %val0 = llvm.mlir.constant(1 : i64) : i64
997
+ // / %val1 = llvm.mlir.constant(2 : i64) : i64
998
+ // / %val2 = llvm.mlir.constant(3 : i64) : i64
999
+ // / llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64)
1000
+ // /
1001
+ // / ^bb1(%arg0 : i64):
1002
+ // / llvm.call @foo(%val0, %arg0)
1003
+ // /
1004
+ static LogicalResult dropRedundantArguments (RewriterBase &rewriter,
1005
+ MutableArrayRef<Region> regions) {
1006
+ llvm::SmallSetVector<Region *, 1 > worklist;
1007
+ for (Region ®ion : regions)
1008
+ worklist.insert (®ion);
1009
+ bool anyChanged = false ;
1010
+ while (!worklist.empty ()) {
1011
+ Region *region = worklist.pop_back_val ();
1012
+
1013
+ // Add any nested regions to the worklist.
1014
+ for (Block &block : *region) {
1015
+ anyChanged = succeeded (dropRedundantArguments (rewriter, block));
1016
+
1017
+ for (Operation &op : block)
1018
+ for (Region &nestedRegion : op.getRegions ())
1019
+ worklist.insert (&nestedRegion);
1020
+ }
1021
+ }
1022
+ return success (anyChanged);
1023
+ }
1024
+
821
1025
// ===----------------------------------------------------------------------===//
822
1026
// Region Simplification
823
1027
// ===----------------------------------------------------------------------===//
@@ -832,8 +1036,12 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
832
1036
bool eliminatedBlocks = succeeded (eraseUnreachableBlocks (rewriter, regions));
833
1037
bool eliminatedOpsOrArgs = succeeded (runRegionDCE (rewriter, regions));
834
1038
bool mergedIdenticalBlocks = false ;
835
- if (mergeBlocks)
1039
+ bool droppedRedundantArguments = false ;
1040
+ if (mergeBlocks) {
836
1041
mergedIdenticalBlocks = succeeded (mergeIdenticalBlocks (rewriter, regions));
1042
+ droppedRedundantArguments =
1043
+ succeeded (dropRedundantArguments (rewriter, regions));
1044
+ }
837
1045
return success (eliminatedBlocks || eliminatedOpsOrArgs ||
838
- mergedIdenticalBlocks);
1046
+ mergedIdenticalBlocks || droppedRedundantArguments );
839
1047
}
0 commit comments