@@ -679,6 +679,64 @@ static bool ableToUpdatePredOperands(Block *block) {
679
679
return true ;
680
680
}
681
681
682
+ // / Prunes the redundant list of arguments. E.g., if we are passing an argument
683
+ // / list like [x, y, z, x] this would return [x, y, z] and it would update the
684
+ // / `block` (to whom the argument are passed to) accordingly
685
+ static void
686
+ pruneRedundantArguments (SmallVector<SmallVector<Value, 8 >, 2 > &newArguments,
687
+ RewriterBase &rewriter, Block *block) {
688
+ SmallVector<SmallVector<Value, 8 >, 2 > newArgumentsPruned (
689
+ newArguments.size (), SmallVector<Value, 8 >());
690
+
691
+ if (!newArguments.empty ()) {
692
+ llvm::DenseMap<unsigned , unsigned > toReplace;
693
+ // Go through the first list of arguments (list 0)
694
+ for (unsigned j = 0 ; j < newArguments[0 ].size (); j++) {
695
+ bool shouldReplaceJ = false ;
696
+ unsigned replacement = 0 ;
697
+ // Look back to see if there are possible redundancies in
698
+ // list 0
699
+ for (unsigned k = 0 ; k < j; k++) {
700
+ if (newArguments[0 ][k] == newArguments[0 ][j]) {
701
+ shouldReplaceJ = true ;
702
+ replacement = k;
703
+ // If a possible redundancy is found, then scan the other lists: we
704
+ // can prune the arguments if and only if they are redundant in every
705
+ // list
706
+ for (unsigned i = 1 ; i < newArguments.size (); i++)
707
+ shouldReplaceJ =
708
+ shouldReplaceJ && (newArguments[i][k] == newArguments[i][j]);
709
+ }
710
+ }
711
+ // Save the replacement
712
+ if (shouldReplaceJ)
713
+ toReplace[j] = replacement;
714
+ }
715
+
716
+ // Populate the pruned argument list
717
+ for (unsigned i = 0 ; i < newArguments.size (); i++)
718
+ for (unsigned j = 0 ; j < newArguments[i].size (); j++)
719
+ if (!toReplace.contains (j))
720
+ newArgumentsPruned[i].push_back (newArguments[i][j]);
721
+
722
+ // Replace the block's redundant arguments
723
+ SmallVector<unsigned > toErase;
724
+ for (auto [idx, arg] : llvm::enumerate (block->getArguments ())) {
725
+ if (toReplace.contains (idx)) {
726
+ Value oldArg = block->getArgument (idx);
727
+ Value newArg = block->getArgument (toReplace[idx]);
728
+ rewriter.replaceAllUsesWith (oldArg, newArg);
729
+ toErase.push_back (idx);
730
+ }
731
+ }
732
+
733
+ // Erase the block's redundant arguments
734
+ for (auto idxToErase : llvm::reverse (toErase))
735
+ block->eraseArgument (idxToErase);
736
+ newArguments = newArgumentsPruned;
737
+ }
738
+ }
739
+
682
740
LogicalResult BlockMergeCluster::merge (RewriterBase &rewriter) {
683
741
// Don't consider clusters that don't have blocks to merge.
684
742
if (blocksToMerge.empty ())
@@ -704,8 +762,9 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
704
762
blockIterators.push_back (mergeBlock->begin ());
705
763
706
764
// Update each of the predecessor terminators with the new arguments.
707
- SmallVector<SmallVector<Value, 8 >, 2 > newArguments (1 + blocksToMerge.size (),
708
- SmallVector<Value, 8 >());
765
+ SmallVector<SmallVector<Value, 8 >, 2 > newArguments (
766
+ 1 + blocksToMerge.size (),
767
+ SmallVector<Value, 8 >(operandsToMerge.size ()));
709
768
unsigned curOpIndex = 0 ;
710
769
for (const auto &it : llvm::enumerate (operandsToMerge)) {
711
770
unsigned nextOpOffset = it.value ().first - curOpIndex;
@@ -716,25 +775,20 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
716
775
Block::iterator &blockIter = blockIterators[i];
717
776
std::advance (blockIter, nextOpOffset);
718
777
auto &operand = blockIter->getOpOperand (it.value ().second );
719
- Value operandVal = operand.get ();
720
- Value *it = std::find (newArguments[i].begin (), newArguments[i].end (),
721
- operandVal);
722
- if (it == newArguments[i].end ()) {
723
- newArguments[i].push_back (operandVal);
724
- // Update the operand and insert an argument if this is the leader.
725
- if (i == 0 ) {
726
- operand.set (leaderBlock->addArgument (operandVal.getType (),
727
- operandVal.getLoc ()));
728
- }
729
- } else if (i == 0 ) {
730
- // If this is the leader, update the operand but do not insert a new
731
- // argument. Instead, the opearand should point to one of the
732
- // arguments we already passed (and that contained `operandVal`)
733
- operand.set (leaderBlock->getArgument (
734
- std::distance (newArguments[i].begin (), it)));
778
+ newArguments[i][it.index ()] = operand.get ();
779
+
780
+ // Update the operand and insert an argument if this is the leader.
781
+ if (i == 0 ) {
782
+ Value operandVal = operand.get ();
783
+ operand.set (leaderBlock->addArgument (operandVal.getType (),
784
+ operandVal.getLoc ()));
735
785
}
736
786
}
737
787
}
788
+
789
+ // Prune redundant arguments and update the leader block argument list
790
+ pruneRedundantArguments (newArguments, rewriter, leaderBlock);
791
+
738
792
// Update the predecessors for each of the blocks.
739
793
auto updatePredecessors = [&](Block *block, unsigned clusterIndex) {
740
794
for (auto predIt = block->pred_begin (), predE = block->pred_end ();
@@ -896,17 +950,22 @@ static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
896
950
// / %cond = llvm.call @rand() : () -> i1
897
951
// / %val0 = llvm.mlir.constant(1 : i64) : i64
898
952
// / %val1 = llvm.mlir.constant(2 : i64) : i64
899
- // / %val2 = llvm.mlir.constant(2 : i64) : i64
953
+ // / %val2 = llvm.mlir.constant(3 : i64) : i64
900
954
// / llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb2(%val0 : i64, %val2
901
- // / : i64) ^bb1(%arg0 : i64, %arg1 : i64):
955
+ // / : i64)
956
+ // /
957
+ // / ^bb1(%arg0 : i64, %arg1 : i64):
902
958
// / llvm.call @foo(%arg0, %arg1)
903
959
// /
904
960
// / The previous IR can be rewritten as:
905
961
// / %cond = llvm.call @rand() : () -> i1
906
- // / %val = llvm.mlir.constant(1 : i64) : i64
962
+ // / %val0 = llvm.mlir.constant(1 : i64) : i64
963
+ // / %val1 = llvm.mlir.constant(2 : i64) : i64
964
+ // / %val2 = llvm.mlir.constant(3 : i64) : i64
907
965
// / llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64)
966
+ // /
908
967
// / ^bb1(%arg0 : i64):
909
- // / llvm.call @foo(%val0, %arg1 )
968
+ // / llvm.call @foo(%val0, %arg0 )
910
969
// /
911
970
static LogicalResult dropRedundantArguments (RewriterBase &rewriter,
912
971
MutableArrayRef<Region> regions) {
0 commit comments