Skip to content

Commit 249a387

Browse files
committed
Fix integration tests
1 parent 8d9fe79 commit 249a387

File tree

2 files changed

+167
-22
lines changed

2 files changed

+167
-22
lines changed

mlir/lib/Transforms/Utils/RegionUtils.cpp

Lines changed: 81 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -679,6 +679,64 @@ static bool ableToUpdatePredOperands(Block *block) {
679679
return true;
680680
}
681681

682+
/// Prunes the redundant list of arguments. E.g., if we are passing an argument
683+
/// list like [x, y, z, x] this would return [x, y, z] and it would update the
684+
/// `block` (to whom the argument are passed to) accordingly
685+
static void
686+
pruneRedundantArguments(SmallVector<SmallVector<Value, 8>, 2> &newArguments,
687+
RewriterBase &rewriter, Block *block) {
688+
SmallVector<SmallVector<Value, 8>, 2> newArgumentsPruned(
689+
newArguments.size(), SmallVector<Value, 8>());
690+
691+
if (!newArguments.empty()) {
692+
llvm::DenseMap<unsigned, unsigned> toReplace;
693+
// Go through the first list of arguments (list 0)
694+
for (unsigned j = 0; j < newArguments[0].size(); j++) {
695+
bool shouldReplaceJ = false;
696+
unsigned replacement = 0;
697+
// Look back to see if there are possible redundancies in
698+
// list 0
699+
for (unsigned k = 0; k < j; k++) {
700+
if (newArguments[0][k] == newArguments[0][j]) {
701+
shouldReplaceJ = true;
702+
replacement = k;
703+
// If a possible redundancy is found, then scan the other lists: we
704+
// can prune the arguments if and only if they are redundant in every
705+
// list
706+
for (unsigned i = 1; i < newArguments.size(); i++)
707+
shouldReplaceJ =
708+
shouldReplaceJ && (newArguments[i][k] == newArguments[i][j]);
709+
}
710+
}
711+
// Save the replacement
712+
if (shouldReplaceJ)
713+
toReplace[j] = replacement;
714+
}
715+
716+
// Populate the pruned argument list
717+
for (unsigned i = 0; i < newArguments.size(); i++)
718+
for (unsigned j = 0; j < newArguments[i].size(); j++)
719+
if (!toReplace.contains(j))
720+
newArgumentsPruned[i].push_back(newArguments[i][j]);
721+
722+
// Replace the block's redundant arguments
723+
SmallVector<unsigned> toErase;
724+
for (auto [idx, arg] : llvm::enumerate(block->getArguments())) {
725+
if (toReplace.contains(idx)) {
726+
Value oldArg = block->getArgument(idx);
727+
Value newArg = block->getArgument(toReplace[idx]);
728+
rewriter.replaceAllUsesWith(oldArg, newArg);
729+
toErase.push_back(idx);
730+
}
731+
}
732+
733+
// Erase the block's redundant arguments
734+
for (auto idxToErase : llvm::reverse(toErase))
735+
block->eraseArgument(idxToErase);
736+
newArguments = newArgumentsPruned;
737+
}
738+
}
739+
682740
LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
683741
// Don't consider clusters that don't have blocks to merge.
684742
if (blocksToMerge.empty())
@@ -704,8 +762,9 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
704762
blockIterators.push_back(mergeBlock->begin());
705763

706764
// Update each of the predecessor terminators with the new arguments.
707-
SmallVector<SmallVector<Value, 8>, 2> newArguments(1 + blocksToMerge.size(),
708-
SmallVector<Value, 8>());
765+
SmallVector<SmallVector<Value, 8>, 2> newArguments(
766+
1 + blocksToMerge.size(),
767+
SmallVector<Value, 8>(operandsToMerge.size()));
709768
unsigned curOpIndex = 0;
710769
for (const auto &it : llvm::enumerate(operandsToMerge)) {
711770
unsigned nextOpOffset = it.value().first - curOpIndex;
@@ -716,25 +775,20 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
716775
Block::iterator &blockIter = blockIterators[i];
717776
std::advance(blockIter, nextOpOffset);
718777
auto &operand = blockIter->getOpOperand(it.value().second);
719-
Value operandVal = operand.get();
720-
Value *it = std::find(newArguments[i].begin(), newArguments[i].end(),
721-
operandVal);
722-
if (it == newArguments[i].end()) {
723-
newArguments[i].push_back(operandVal);
724-
// Update the operand and insert an argument if this is the leader.
725-
if (i == 0) {
726-
operand.set(leaderBlock->addArgument(operandVal.getType(),
727-
operandVal.getLoc()));
728-
}
729-
} else if (i == 0) {
730-
// If this is the leader, update the operand but do not insert a new
731-
// argument. Instead, the opearand should point to one of the
732-
// arguments we already passed (and that contained `operandVal`)
733-
operand.set(leaderBlock->getArgument(
734-
std::distance(newArguments[i].begin(), it)));
778+
newArguments[i][it.index()] = operand.get();
779+
780+
// Update the operand and insert an argument if this is the leader.
781+
if (i == 0) {
782+
Value operandVal = operand.get();
783+
operand.set(leaderBlock->addArgument(operandVal.getType(),
784+
operandVal.getLoc()));
735785
}
736786
}
737787
}
788+
789+
// Prune redundant arguments and update the leader block argument list
790+
pruneRedundantArguments(newArguments, rewriter, leaderBlock);
791+
738792
// Update the predecessors for each of the blocks.
739793
auto updatePredecessors = [&](Block *block, unsigned clusterIndex) {
740794
for (auto predIt = block->pred_begin(), predE = block->pred_end();
@@ -896,17 +950,22 @@ static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
896950
/// %cond = llvm.call @rand() : () -> i1
897951
/// %val0 = llvm.mlir.constant(1 : i64) : i64
898952
/// %val1 = llvm.mlir.constant(2 : i64) : i64
899-
/// %val2 = llvm.mlir.constant(2 : i64) : i64
953+
/// %val2 = llvm.mlir.constant(3 : i64) : i64
900954
/// llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb2(%val0 : i64, %val2
901-
/// : i64) ^bb1(%arg0 : i64, %arg1 : i64):
955+
/// : i64)
956+
///
957+
/// ^bb1(%arg0 : i64, %arg1 : i64):
902958
/// llvm.call @foo(%arg0, %arg1)
903959
///
904960
/// The previous IR can be rewritten as:
905961
/// %cond = llvm.call @rand() : () -> i1
906-
/// %val = llvm.mlir.constant(1 : i64) : i64
962+
/// %val0 = llvm.mlir.constant(1 : i64) : i64
963+
/// %val1 = llvm.mlir.constant(2 : i64) : i64
964+
/// %val2 = llvm.mlir.constant(3 : i64) : i64
907965
/// llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64)
966+
///
908967
/// ^bb1(%arg0 : i64):
909-
/// llvm.call @foo(%val0, %arg1)
968+
/// llvm.call @foo(%val0, %arg0)
910969
///
911970
static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
912971
MutableArrayRef<Region> regions) {

mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,89 @@ llvm.func @large_merge_block(%arg0: i64) {
7474
^bb27: // 2 preds: ^bb13, ^bb26
7575
llvm.return
7676
}
77+
78+
llvm.func @redundant_args0(%cond : i1) {
79+
%0 = llvm.mlir.constant(0 : i64) : i64
80+
%2 = llvm.mlir.constant(1 : i64) : i64
81+
%3 = llvm.mlir.constant(2 : i64) : i64
82+
// CHECK %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
83+
// CHECK %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
84+
// CHECK %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64
85+
86+
llvm.cond_br %cond, ^bb1, ^bb2
87+
88+
// CHECK: llvm.cond_br %{{.*}}, ^bb{{.*}}(%[[C0]], %[[C0]] : i64, i64), ^bb{{.*}}(%[[C1]], %[[C2]] : i64, i64)
89+
// CHECK: ^bb{{.*}}(%{{.*}}: i64, %{{.*}}: i64)
90+
^bb1:
91+
llvm.call @foo(%0) : (i64) -> ()
92+
llvm.call @foo(%0) : (i64) -> ()
93+
llvm.br ^bb3
94+
^bb2:
95+
llvm.call @foo(%2) : (i64) -> ()
96+
llvm.call @foo(%3) : (i64) -> ()
97+
llvm.br ^bb3
98+
^bb3:
99+
llvm.return
100+
}
101+
102+
llvm.func @redundant_args1(%cond : i1) {
103+
%0 = llvm.mlir.constant(0 : i64) : i64
104+
%2 = llvm.mlir.constant(1 : i64) : i64
105+
%3 = llvm.mlir.constant(2 : i64) : i64
106+
// CHECK %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
107+
// CHECK %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
108+
// CHECK %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64
109+
110+
llvm.cond_br %cond, ^bb1, ^bb2
111+
112+
// CHECK: llvm.cond_br %{{.*}}, ^bb{{.*}}(%[[C1]], %[[C2]] : i64, i64), ^bb{{.*}}(%[[C0]], %[[C0]] : i64, i64)
113+
// CHECK: ^bb{{.*}}(%{{.*}}: i64, %{{.*}}: i64)
114+
^bb1:
115+
llvm.call @foo(%2) : (i64) -> ()
116+
llvm.call @foo(%3) : (i64) -> ()
117+
llvm.br ^bb3
118+
^bb2:
119+
llvm.call @foo(%0) : (i64) -> ()
120+
llvm.call @foo(%0) : (i64) -> ()
121+
llvm.br ^bb3
122+
^bb3:
123+
llvm.return
124+
}
125+
126+
llvm.func @redundant_args_complex(%cond : i1) {
127+
%0 = llvm.mlir.constant(0 : i64) : i64
128+
%1 = llvm.mlir.constant(1 : i64) : i64
129+
%2 = llvm.mlir.constant(2 : i64) : i64
130+
%3 = llvm.mlir.constant(3 : i64) : i64
131+
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
132+
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
133+
// CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64
134+
// CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64
135+
136+
llvm.cond_br %cond, ^bb1, ^bb2
137+
138+
// CHECK: llvm.cond_br %{{.*}}, ^bb{{.*}}(%[[C2]], %[[C1]], %[[C3]] : i64, i64, i64), ^bb{{.*}}(%[[C0]], %[[C3]], %[[C2]] : i64, i64, i64)
139+
// CHECK: ^bb{{.*}}(%[[arg0:.*]]: i64, %[[arg1:.*]]: i64, %[[arg2:.*]]: i64):
140+
// CHECK: llvm.call @foo(%[[arg0]])
141+
// CHECK: llvm.call @foo(%[[arg0]])
142+
// CHECK: llvm.call @foo(%[[arg1]])
143+
// CHECK: llvm.call @foo(%[[C2]])
144+
// CHECK: llvm.call @foo(%[[arg2]])
145+
146+
^bb1:
147+
llvm.call @foo(%2) : (i64) -> ()
148+
llvm.call @foo(%2) : (i64) -> ()
149+
llvm.call @foo(%1) : (i64) -> ()
150+
llvm.call @foo(%2) : (i64) -> ()
151+
llvm.call @foo(%3) : (i64) -> ()
152+
llvm.br ^bb3
153+
^bb2:
154+
llvm.call @foo(%0) : (i64) -> ()
155+
llvm.call @foo(%0) : (i64) -> ()
156+
llvm.call @foo(%3) : (i64) -> ()
157+
llvm.call @foo(%2) : (i64) -> ()
158+
llvm.call @foo(%2) : (i64) -> ()
159+
llvm.br ^bb3
160+
^bb3:
161+
llvm.return
162+
}

0 commit comments

Comments
 (0)