Skip to content

Commit 24797e2

Browse files
committed
review comments
1 parent fa1a29f commit 24797e2

File tree

1 file changed

+18
-5
lines changed

1 file changed

+18
-5
lines changed

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3885,8 +3885,8 @@ struct WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> {
38853885
}
38863886
};
38873887

3888-
/// If both ranges contain same values return mappping indices from args1 to
3889-
/// args2. Otherwise return std::nullopt
3888+
/// If both ranges contain same values return mappping indices from args2 to
3889+
/// args1. Otherwise return std::nullopt.
38903890
static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
38913891
ValueRange args2) {
38923892
if (args1.size() != args2.size())
@@ -3898,16 +3898,26 @@ static std::optional<SmallVector<unsigned>> getArgsMapping(ValueRange args1,
38983898
if (it == args2.end())
38993899
return std::nullopt;
39003900

3901-
auto j = it - args2.begin();
3902-
ret[j] = static_cast<unsigned>(i);
3901+
ret[std::distance(args2.begin(), it)] = static_cast<unsigned>(i);
39033902
}
39043903

39053904
return ret;
39063905
}
39073906

3907+
static bool hasDuplicates(ValueRange args) {
3908+
llvm::SmallDenseSet<Value> set;
3909+
for (Value arg : args) {
3910+
if (set.contains(arg))
3911+
return true;
3912+
3913+
set.insert(arg);
3914+
}
3915+
return false;
3916+
}
3917+
39083918
/// If `before` block args are directly forwarded to `scf.condition`, rearrange
39093919
/// `scf.condition` args into same order as block args. Update `after` block
3910-
// args and results values accordingly.
3920+
// args and op result values accordingly.
39113921
/// Needed to simplify `scf.while` -> `scf.for` uplifting.
39123922
struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
39133923
using OpRewritePattern::OpRewritePattern;
@@ -3921,6 +3931,9 @@ struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> {
39213931
if (beforeArgs == termArgs)
39223932
return failure();
39233933

3934+
if (hasDuplicates(termArgs))
3935+
return failure();
3936+
39243937
auto mapping = getArgsMapping(beforeArgs, termArgs);
39253938
if (!mapping)
39263939
return failure();

0 commit comments

Comments
 (0)