Skip to content

Commit c4e8397

Browse files
[mlir][Transforms] Improve replaceOpWithMultiple API
1 parent c482b8f commit c4e8397

File tree

3 files changed

+27
-2
lines changed

3 files changed

+27
-2
lines changed

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -898,6 +898,10 @@ class ConversionPatternRewriter final : public PatternRewriter {
898898
/// Replace the given operation with the new value ranges. The number of op
899899
/// results and value ranges must match. The given operation is erased.
900900
void replaceOpWithMultiple(Operation *op, ArrayRef<ValueRange> newValues);
901+
template <typename RangeT>
902+
void replaceOpWithMultiple(Operation *op, RangeT newValues) {
903+
replaceOpWithMultiple(op, llvm::to_vector_of<ValueRange>(newValues));
904+
}
901905

902906
/// PatternRewriter hook for erasing a dead operation. The uses of this
903907
/// operation *must* be made dead by the end of the conversion process,

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -616,8 +616,7 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
616616
}
617617

618618
assert(packedResultVals.size() == op.getNumResults());
619-
rewriter.replaceOpWithMultiple(
620-
op, llvm::to_vector_of<ValueRange>(packedResultVals));
619+
rewriter.replaceOpWithMultiple(op, packedResultVals);
621620
return success();
622621
}
623622
};

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1278,6 +1278,28 @@ class TestMultiple1ToNReplacement : public ConversionPattern {
12781278
}
12791279
};
12801280

1281+
/// Test unambiguous overload resolution of replaceOpWithMultiple. This
1282+
/// function is just to trigger compiler errors. It is never executed.
1283+
[[maybe_unused]] void testReplaceOpWithMultipleOverloads(
1284+
ConversionPatternRewriter &rewriter, Operation *op, ArrayRef<ValueRange> r1,
1285+
SmallVector<ValueRange> r2, ArrayRef<SmallVector<Value>> r3,
1286+
SmallVector<SmallVector<Value>> r4, ArrayRef<ArrayRef<Value>> r5,
1287+
SmallVector<ArrayRef<Value>> r6, Value v, ValueRange vr,
1288+
ArrayRef<Value> ar) {
1289+
rewriter.replaceOpWithMultiple(op, r1);
1290+
rewriter.replaceOpWithMultiple(op, r2);
1291+
rewriter.replaceOpWithMultiple(op, r3);
1292+
rewriter.replaceOpWithMultiple(op, r4);
1293+
rewriter.replaceOpWithMultiple(op, r5);
1294+
rewriter.replaceOpWithMultiple(op, r6);
1295+
rewriter.replaceOpWithMultiple(op, {vr});
1296+
rewriter.replaceOpWithMultiple(op, {ar});
1297+
rewriter.replaceOpWithMultiple(op, {{v}});
1298+
rewriter.replaceOpWithMultiple(op, {{v, v}});
1299+
rewriter.replaceOpWithMultiple(op, {{v, v}, vr});
1300+
rewriter.replaceOpWithMultiple(op, {{v, v}, ar});
1301+
rewriter.replaceOpWithMultiple(op, {ar, {v, v}, vr});
1302+
}
12811303
} // namespace
12821304

12831305
namespace {

0 commit comments

Comments
 (0)