Skip to content

Commit ef73336

Browse files
committed
Address review feedbacks
1 parent 12755f7 commit ef73336

File tree

1 file changed

+68
-59
lines changed

1 file changed

+68
-59
lines changed

mlir/lib/Transforms/Utils/RegionUtils.cpp

Lines changed: 68 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -681,60 +681,70 @@ static bool ableToUpdatePredOperands(Block *block) {
681681

682682
/// Prunes the redundant list of arguments. E.g., if we are passing an argument
683683
/// list like [x, y, z, x] this would return [x, y, z] and it would update the
684-
/// `block` (to whom the argument are passed to) accordingly
685-
static void
686-
pruneRedundantArguments(SmallVector<SmallVector<Value, 8>, 2> &newArguments,
687-
RewriterBase &rewriter, Block *block) {
684+
/// `block` (to whom the argument are passed to) accordingly.
685+
static SmallVector<SmallVector<Value, 8>, 2> pruneRedundantArguments(
686+
const SmallVector<SmallVector<Value, 8>, 2> &newArguments,
687+
RewriterBase &rewriter, Block *block) {
688+
688689
SmallVector<SmallVector<Value, 8>, 2> newArgumentsPruned(
689690
newArguments.size(), SmallVector<Value, 8>());
690691

691-
if (!newArguments.empty()) {
692-
llvm::DenseMap<unsigned, unsigned> toReplace;
693-
// Go through the first list of arguments (list 0)
694-
for (unsigned j = 0; j < newArguments[0].size(); j++) {
695-
bool shouldReplaceJ = false;
696-
unsigned replacement = 0;
697-
// Look back to see if there are possible redundancies in
698-
// list 0
699-
for (unsigned k = 0; k < j; k++) {
700-
if (newArguments[0][k] == newArguments[0][j]) {
701-
shouldReplaceJ = true;
702-
replacement = k;
703-
// If a possible redundancy is found, then scan the other lists: we
704-
// can prune the arguments if and only if they are redundant in every
705-
// list
706-
for (unsigned i = 1; i < newArguments.size(); i++)
707-
shouldReplaceJ =
708-
shouldReplaceJ && (newArguments[i][k] == newArguments[i][j]);
709-
}
692+
if (newArguments.empty())
693+
return newArguments;
694+
695+
// `newArguments` is a 2D array of size `numLists` x `numArgs`
696+
unsigned numLists = newArguments.size();
697+
unsigned numArgs = newArguments[0].size();
698+
699+
// Map that for each arg index contains the index that we can use in place of
700+
// the original index. E.g., if we have newArgs = [x, y, z, x], we will have
701+
// idxToReplacement[3] = 0
702+
llvm::DenseMap<unsigned, unsigned> idxToReplacement;
703+
704+
// Go through the first list of arguments (list 0).
705+
for (unsigned j = 0; j < numArgs; ++j) {
706+
bool shouldReplaceJ = false;
707+
unsigned replacement = 0;
708+
// Look back to see if there are possible redundancies in
709+
// list 0.
710+
for (unsigned k = 0; k < j; k++) {
711+
if (newArguments[0][k] == newArguments[0][j]) {
712+
shouldReplaceJ = true;
713+
replacement = k;
714+
// If a possible redundancy is found, then scan the other lists: we
715+
// can prune the arguments if and only if they are redundant in every
716+
// list.
717+
for (unsigned i = 1; i < numLists; ++i)
718+
shouldReplaceJ =
719+
shouldReplaceJ && (newArguments[i][k] == newArguments[i][j]);
710720
}
711-
// Save the replacement
712-
if (shouldReplaceJ)
713-
toReplace[j] = replacement;
714721
}
722+
// Save the replacement.
723+
if (shouldReplaceJ)
724+
idxToReplacement[j] = replacement;
725+
}
715726

716-
// Populate the pruned argument list
717-
for (unsigned i = 0; i < newArguments.size(); i++)
718-
for (unsigned j = 0; j < newArguments[i].size(); j++)
719-
if (!toReplace.contains(j))
720-
newArgumentsPruned[i].push_back(newArguments[i][j]);
721-
722-
// Replace the block's redundant arguments
723-
SmallVector<unsigned> toErase;
724-
for (auto [idx, arg] : llvm::enumerate(block->getArguments())) {
725-
if (toReplace.contains(idx)) {
726-
Value oldArg = block->getArgument(idx);
727-
Value newArg = block->getArgument(toReplace[idx]);
728-
rewriter.replaceAllUsesWith(oldArg, newArg);
729-
toErase.push_back(idx);
730-
}
727+
// Populate the pruned argument list.
728+
for (unsigned i = 0; i < numLists; ++i)
729+
for (unsigned j = 0; j < numArgs; ++j)
730+
if (!idxToReplacement.contains(j))
731+
newArgumentsPruned[i].push_back(newArguments[i][j]);
732+
733+
// Replace the block's redundant arguments.
734+
SmallVector<unsigned> toErase;
735+
for (auto [idx, arg] : llvm::enumerate(block->getArguments())) {
736+
if (idxToReplacement.contains(idx)) {
737+
Value oldArg = block->getArgument(idx);
738+
Value newArg = block->getArgument(idxToReplacement[idx]);
739+
rewriter.replaceAllUsesWith(oldArg, newArg);
740+
toErase.push_back(idx);
731741
}
732-
733-
// Erase the block's redundant arguments
734-
for (auto idxToErase : llvm::reverse(toErase))
735-
block->eraseArgument(idxToErase);
736-
newArguments = newArgumentsPruned;
737742
}
743+
744+
// Erase the block's redundant arguments.
745+
for (unsigned idxToErase : llvm::reverse(toErase))
746+
block->eraseArgument(idxToErase);
747+
return newArgumentsPruned;
738748
}
739749

740750
LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
@@ -787,7 +797,7 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
787797
}
788798

789799
// Prune redundant arguments and update the leader block argument list
790-
pruneRedundantArguments(newArguments, rewriter, leaderBlock);
800+
newArguments = pruneRedundantArguments(newArguments, rewriter, leaderBlock);
791801

792802
// Update the predecessors for each of the blocks.
793803
auto updatePredecessors = [&](Block *block, unsigned clusterIndex) {
@@ -889,13 +899,13 @@ static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
889899
Block &block) {
890900
SmallVector<size_t> argsToErase;
891901

892-
// Go through the arguments of the block
893-
for (size_t argIdx = 0; argIdx < block.getNumArguments(); argIdx++) {
902+
// Go through the arguments of the block.
903+
for (auto [argIdx, blockOperand] : llvm::enumerate(block.getArguments())) {
894904
bool sameArg = true;
895905
Value commonValue;
896906

897907
// Go through the block predecessor and flag if they pass to the block
898-
// different values for the same argument
908+
// different values for the same argument.
899909
for (auto predIt = block.pred_begin(), predE = block.pred_end();
900910
predIt != predE; ++predIt) {
901911
auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator());
@@ -905,32 +915,31 @@ static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
905915
}
906916
unsigned succIndex = predIt.getSuccessorIndex();
907917
SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
908-
auto operands = succOperands.getForwardedOperands();
918+
auto branchOperands = succOperands.getForwardedOperands();
909919
if (!commonValue) {
910-
commonValue = operands[argIdx];
920+
commonValue = branchOperands[argIdx];
911921
} else {
912-
if (operands[argIdx] != commonValue) {
922+
if (branchOperands[argIdx] != commonValue) {
913923
sameArg = false;
914924
break;
915925
}
916926
}
917927
}
918928

919-
// If they are passing the same value, drop the argument
929+
// If they are passing the same value, drop the argument.
920930
if (commonValue && sameArg) {
921931
argsToErase.push_back(argIdx);
922932

923-
// Remove the argument from the block
924-
Value argVal = block.getArgument(argIdx);
925-
rewriter.replaceAllUsesWith(argVal, commonValue);
933+
// Remove the argument from the block.
934+
rewriter.replaceAllUsesWith(blockOperand, commonValue);
926935
}
927936
}
928937

929-
// Remove the arguments
938+
// Remove the arguments.
930939
for (auto argIdx : llvm::reverse(argsToErase)) {
931940
block.eraseArgument(argIdx);
932941

933-
// Remove the argument from the branch ops
942+
// Remove the argument from the branch ops.
934943
for (auto predIt = block.pred_begin(), predE = block.pred_end();
935944
predIt != predE; ++predIt) {
936945
auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());

0 commit comments

Comments
 (0)