9
9
#include " mlir/Transforms/RegionUtils.h"
10
10
#include " mlir/Analysis/TopologicalSortUtils.h"
11
11
#include " mlir/IR/Block.h"
12
+ #include " mlir/IR/BuiltinOps.h"
12
13
#include " mlir/IR/IRMapping.h"
13
14
#include " mlir/IR/Operation.h"
14
15
#include " mlir/IR/PatternMatch.h"
15
16
#include " mlir/IR/RegionGraphTraits.h"
16
17
#include " mlir/IR/Value.h"
17
18
#include " mlir/Interfaces/ControlFlowInterfaces.h"
18
19
#include " mlir/Interfaces/SideEffectInterfaces.h"
20
+ #include " mlir/Support/LogicalResult.h"
19
21
20
22
#include " llvm/ADT/DepthFirstIterator.h"
21
23
#include " llvm/ADT/PostOrderIterator.h"
24
+ #include " llvm/ADT/STLExtras.h"
25
+ #include " llvm/ADT/SmallSet.h"
22
26
23
27
#include < deque>
28
+ #include < iterator>
24
29
25
30
using namespace mlir ;
26
31
@@ -699,9 +704,8 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
699
704
blockIterators.push_back (mergeBlock->begin ());
700
705
701
706
// Update each of the predecessor terminators with the new arguments.
702
- SmallVector<SmallVector<Value, 8 >, 2 > newArguments (
703
- 1 + blocksToMerge.size (),
704
- SmallVector<Value, 8 >(operandsToMerge.size ()));
707
+ SmallVector<SmallVector<Value, 8 >, 2 > newArguments (1 + blocksToMerge.size (),
708
+ SmallVector<Value, 8 >());
705
709
unsigned curOpIndex = 0 ;
706
710
for (const auto &it : llvm::enumerate (operandsToMerge)) {
707
711
unsigned nextOpOffset = it.value ().first - curOpIndex;
@@ -712,13 +716,22 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
712
716
Block::iterator &blockIter = blockIterators[i];
713
717
std::advance (blockIter, nextOpOffset);
714
718
auto &operand = blockIter->getOpOperand (it.value ().second );
715
- newArguments[i][it.index ()] = operand.get ();
716
-
717
- // Update the operand and insert an argument if this is the leader.
718
- if (i == 0 ) {
719
- Value operandVal = operand.get ();
720
- operand.set (leaderBlock->addArgument (operandVal.getType (),
721
- operandVal.getLoc ()));
719
+ Value operandVal = operand.get ();
720
+ Value *it = std::find (newArguments[i].begin (), newArguments[i].end (),
721
+ operandVal);
722
+ if (it == newArguments[i].end ()) {
723
+ newArguments[i].push_back (operandVal);
724
+ // Update the operand and insert an argument if this is the leader.
725
+ if (i == 0 ) {
726
+ operand.set (leaderBlock->addArgument (operandVal.getType (),
727
+ operandVal.getLoc ()));
728
+ }
729
+ } else if (i == 0 ) {
730
+ // If this is the leader, update the operand but do not insert a new
731
+ // argument. Instead, the opearand should point to one of the
732
+ // arguments we already passed (and that contained `operandVal`)
733
+ operand.set (leaderBlock->getArgument (
734
+ std::distance (newArguments[i].begin (), it)));
722
735
}
723
736
}
724
737
}
@@ -818,6 +831,104 @@ static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
818
831
return success (anyChanged);
819
832
}
820
833
834
+ static LogicalResult dropRedundantArguments (RewriterBase &rewriter,
835
+ Block &block) {
836
+ SmallVector<size_t > argsToErase;
837
+
838
+ // Go through the arguments of the block
839
+ for (size_t argIdx = 0 ; argIdx < block.getNumArguments (); argIdx++) {
840
+ bool sameArg = true ;
841
+ Value commonValue;
842
+
843
+ // Go through the block predecessor and flag if they pass to the block
844
+ // different values for the same argument
845
+ for (auto predIt = block.pred_begin (), predE = block.pred_end ();
846
+ predIt != predE; ++predIt) {
847
+ auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator ());
848
+ if (!branch) {
849
+ sameArg = false ;
850
+ break ;
851
+ }
852
+ unsigned succIndex = predIt.getSuccessorIndex ();
853
+ SuccessorOperands succOperands = branch.getSuccessorOperands (succIndex);
854
+ auto operands = succOperands.getForwardedOperands ();
855
+ if (!commonValue) {
856
+ commonValue = operands[argIdx];
857
+ } else {
858
+ if (operands[argIdx] != commonValue) {
859
+ sameArg = false ;
860
+ break ;
861
+ }
862
+ }
863
+ }
864
+
865
+ // If they are passing the same value, drop the argument
866
+ if (commonValue && sameArg) {
867
+ argsToErase.push_back (argIdx);
868
+
869
+ // Remove the argument from the block
870
+ Value argVal = block.getArgument (argIdx);
871
+ rewriter.replaceAllUsesWith (argVal, commonValue);
872
+ }
873
+ }
874
+
875
+ // Remove the arguments
876
+ for (auto argIdx : llvm::reverse (argsToErase)) {
877
+ block.eraseArgument (argIdx);
878
+
879
+ // Remove the argument from the branch ops
880
+ for (auto predIt = block.pred_begin (), predE = block.pred_end ();
881
+ predIt != predE; ++predIt) {
882
+ auto branch = cast<BranchOpInterface>((*predIt)->getTerminator ());
883
+ unsigned succIndex = predIt.getSuccessorIndex ();
884
+ SuccessorOperands succOperands = branch.getSuccessorOperands (succIndex);
885
+ succOperands.erase (argIdx);
886
+ }
887
+ }
888
+ return success (!argsToErase.empty ());
889
+ }
890
+
891
+ // / This optimization drops redundant argument to blocks. I.e., if a given
892
+ // / argument to a block receives the same value from each of the block
893
+ // / predecessors, we can remove the argument from the block and use directly the
894
+ // / original value. This is a simple example:
895
+ // /
896
+ // / %cond = llvm.call @rand() : () -> i1
897
+ // / %val0 = llvm.mlir.constant(1 : i64) : i64
898
+ // / %val1 = llvm.mlir.constant(2 : i64) : i64
899
+ // / %val2 = llvm.mlir.constant(2 : i64) : i64
900
+ // / llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb2(%val0 : i64, %val2
901
+ // / : i64) ^bb1(%arg0 : i64, %arg1 : i64):
902
+ // / llvm.call @foo(%arg0, %arg1)
903
+ // /
904
+ // / The previous IR can be rewritten as:
905
+ // / %cond = llvm.call @rand() : () -> i1
906
+ // / %val = llvm.mlir.constant(1 : i64) : i64
907
+ // / llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64)
908
+ // / ^bb1(%arg0 : i64):
909
+ // / llvm.call @foo(%val0, %arg1)
910
+ // /
911
+ static LogicalResult dropRedundantArguments (RewriterBase &rewriter,
912
+ MutableArrayRef<Region> regions) {
913
+ llvm::SmallSetVector<Region *, 1 > worklist;
914
+ for (auto ®ion : regions)
915
+ worklist.insert (®ion);
916
+ bool anyChanged = false ;
917
+ while (!worklist.empty ()) {
918
+ Region *region = worklist.pop_back_val ();
919
+
920
+ // Add any nested regions to the worklist.
921
+ for (Block &block : *region) {
922
+ anyChanged = succeeded (dropRedundantArguments (rewriter, block));
923
+
924
+ for (auto &op : block)
925
+ for (auto &nestedRegion : op.getRegions ())
926
+ worklist.insert (&nestedRegion);
927
+ }
928
+ }
929
+ return success (anyChanged);
930
+ }
931
+
821
932
// ===----------------------------------------------------------------------===//
822
933
// Region Simplification
823
934
// ===----------------------------------------------------------------------===//
@@ -832,8 +943,12 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
832
943
bool eliminatedBlocks = succeeded (eraseUnreachableBlocks (rewriter, regions));
833
944
bool eliminatedOpsOrArgs = succeeded (runRegionDCE (rewriter, regions));
834
945
bool mergedIdenticalBlocks = false ;
835
- if (mergeBlocks)
946
+ bool droppedRedundantArguments = false ;
947
+ if (mergeBlocks) {
836
948
mergedIdenticalBlocks = succeeded (mergeIdenticalBlocks (rewriter, regions));
949
+ droppedRedundantArguments =
950
+ succeeded (dropRedundantArguments (rewriter, regions));
951
+ }
837
952
return success (eliminatedBlocks || eliminatedOpsOrArgs ||
838
- mergedIdenticalBlocks);
953
+ mergedIdenticalBlocks || droppedRedundantArguments );
839
954
}
0 commit comments