Skip to content

Commit b6a28eb

Browse files
Use SmallVector<Value,1> by default
1 parent 5aadd7f commit b6a28eb

File tree

5 files changed

+46
-25
lines changed

5 files changed

+46
-25
lines changed

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -897,11 +897,33 @@ class ConversionPatternRewriter final : public PatternRewriter {
897897

898898
/// Replace the given operation with the new value ranges. The number of op
899899
/// results and value ranges must match. The given operation is erased.
900-
void replaceOpWithMultiple(Operation *op, ArrayRef<ValueRange> newValues);
901-
template <typename RangeT>
902-
void replaceOpWithMultiple(Operation *op, RangeT &&newValues) {
903-
replaceOpWithMultiple(op,
904-
ArrayRef(llvm::to_vector_of<ValueRange>(newValues)));
900+
void replaceOpWithMultiple(Operation *op,
901+
ArrayRef<SmallVector<Value, 1>> newValues);
902+
// Note: This overload matches SmallVector<ValueRange>,
903+
// SmallVector<SmallVector<Value>>, etc.
904+
template <typename RangeRangeT>
905+
void replaceOpWithMultiple(Operation *op, RangeRangeT &&newValues) {
906+
// Note: Prefer the ArrayRef<SmallVector<Value, 1>> overload because it
907+
// does not copy the replacements vector.
908+
auto vals = llvm::map_to_vector(newValues, [](const auto &r) {
909+
// Note: Create intermediate ValueRange because SmallVector<Value, 1>
910+
// is not constructible from SmallVector<Value>.
911+
return SmallVector<Value, 1>(ValueRange(r));
912+
});
913+
replaceOpWithMultiple(op, ArrayRef(vals));
914+
}
915+
// Note: This overload matches initializer list of ValueRange,
916+
// SmallVector<Value>, etc.
917+
template <typename RangeT = ValueRange>
918+
void replaceOpWithMultiple(Operation *op, ArrayRef<RangeT> newValues) {
919+
// Note: Prefer the ArrayRef<SmallVector<Value, 1>> overload because it
920+
// does not copy the replacements vector.
921+
auto vals = llvm::map_to_vector(newValues, [](const RangeT &r) {
922+
// Note: Create intermediate ValueRange because SmallVector<Value, 1>
923+
// is not constructible from SmallVector<Value>.
924+
return SmallVector<Value, 1>(ValueRange(r));
925+
});
926+
replaceOpWithMultiple(op, ArrayRef(vals));
905927
}
906928

907929
/// PatternRewriter hook for erasing a dead operation. The uses of this

mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ struct LegalizeArithConstantOpsByDecomposition
192192
auto tileCount = getNumberOfSMETilesForVectorType(vectorType);
193193
auto tileSplat = rewriter.create<arith::ConstantOp>(
194194
constantOp.getLoc(), denseAttr.resizeSplat(smeTileType));
195-
SmallVector<Value> repl(tileCount, tileSplat);
195+
SmallVector<Value, 1> repl(tileCount, tileSplat);
196196
rewriter.replaceOpWithMultiple(constantOp, {repl});
197197

198198
return success();
@@ -232,7 +232,7 @@ struct LegalizeVectorOuterProductOpsByDecomposition
232232
auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
233233
VectorType sliceType = VectorType::Builder(smeTileType).dropDim(0);
234234

235-
SmallVector<Value> resultSMETiles;
235+
SmallVector<Value, 1> resultSMETiles;
236236
for (auto [index, smeTile] : llvm::enumerate(
237237
decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
238238

@@ -310,7 +310,7 @@ struct LegalizeTransferReadOpsByDecomposition
310310
auto loc = readOp.getLoc();
311311
auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
312312

313-
SmallVector<Value> resultSMETiles;
313+
SmallVector<Value, 1> resultSMETiles;
314314
for (SMESubTile smeTile :
315315
decomposeToSMETiles(rewriter, vectorType, smeTileType, transposed)) {
316316
auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile);

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -585,7 +585,7 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
585585
auto newCall = rewriter.create<func::CallOp>(
586586
loc, op.getCallee(), finalRetTy, flattenValues(adaptor.getOperands()));
587587
// (2) Gather sparse tensor returns.
588-
SmallVector<SmallVector<Value>> packedResultVals;
588+
SmallVector<SmallVector<Value, 1>> packedResultVals;
589589
// Tracks the offset of current return value (of the original call)
590590
// relative to the new call (after sparse tensor flattening);
591591
unsigned retOffset = 0;
@@ -752,7 +752,7 @@ class SparseTensorAllocConverter
752752
if (op.getCopy()) {
753753
auto desc = getDescriptorFromTensorTuple(
754754
adaptor.getCopy(), cast<RankedTensorType>(op.getCopy().getType()));
755-
SmallVector<Value> fields;
755+
SmallVector<Value, 1> fields;
756756
fields.reserve(desc.getNumFields());
757757
// Memcpy on memref fields.
758758
for (auto field : desc.getMemRefFields()) {
@@ -823,7 +823,7 @@ class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
823823
/*dimSizesValues=*/lvlSizesValues);
824824
// Construct allocation for each field.
825825
Value sizeHint; // none
826-
SmallVector<Value> fields;
826+
SmallVector<Value, 1> fields;
827827
createAllocFields(rewriter, loc, resType, enableBufferInitialization,
828828
sizeHint, lvlSizesValues, fields);
829829

@@ -1176,7 +1176,7 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
11761176
Location loc = op.getLoc();
11771177
auto srcDesc = getDescriptorFromTensorTuple(adaptor.getSource(),
11781178
op.getSource().getType());
1179-
SmallVector<Value> fields;
1179+
SmallVector<Value, 1> fields;
11801180
foreachFieldAndTypeInSparseTensor(
11811181
SparseTensorType(cast<RankedTensorType>(op.getResult().getType())),
11821182
[&rewriter, &fields, srcDesc,

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -947,7 +947,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
947947
OpBuilder::InsertPoint previous) override;
948948

949949
/// Notifies that an op is about to be replaced with the given values.
950-
void notifyOpReplaced(Operation *op, ArrayRef<ValueRange> newValues);
950+
void notifyOpReplaced(Operation *op, ArrayRef<ValueVector> newValues);
951951

952952
/// Notifies that a block is about to be erased.
953953
void notifyBlockIsBeingErased(Block *block);
@@ -1520,7 +1520,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
15201520
}
15211521

15221522
void ConversionPatternRewriterImpl::notifyOpReplaced(
1523-
Operation *op, ArrayRef<ValueRange> newValues) {
1523+
Operation *op, ArrayRef<ValueVector> newValues) {
15241524
assert(newValues.size() == op->getNumResults());
15251525
assert(!ignoredOps.contains(op) && "operation was already replaced");
15261526

@@ -1640,19 +1640,15 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
16401640
impl->logger.startLine()
16411641
<< "** Replace : '" << op->getName() << "'(" << op << ")\n";
16421642
});
1643-
SmallVector<ValueRange> newVals;
1644-
for (size_t i = 0; i < newValues.size(); ++i) {
1645-
if (newValues[i]) {
1646-
newVals.push_back(newValues.slice(i, 1));
1647-
} else {
1648-
newVals.push_back(ValueRange());
1649-
}
1650-
}
1643+
SmallVector<ValueVector> newVals =
1644+
llvm::map_to_vector(newValues, [](Value v) -> ValueVector {
1645+
return v ? ValueVector{v} : ValueVector();
1646+
});
16511647
impl->notifyOpReplaced(op, newVals);
16521648
}
16531649

16541650
void ConversionPatternRewriter::replaceOpWithMultiple(
1655-
Operation *op, ArrayRef<ValueRange> newValues) {
1651+
Operation *op, ArrayRef<SmallVector<Value, 1>> newValues) {
16561652
assert(op->getNumResults() == newValues.size() &&
16571653
"incorrect # of replacement values");
16581654
LLVM_DEBUG({
@@ -1667,7 +1663,7 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
16671663
impl->logger.startLine()
16681664
<< "** Erase : '" << op->getName() << "'(" << op << ")\n";
16691665
});
1670-
SmallVector<ValueRange> nullRepls(op->getNumResults(), {});
1666+
SmallVector<ValueVector> nullRepls(op->getNumResults(), ValueVector());
16711667
impl->notifyOpReplaced(op, nullRepls);
16721668
}
16731669

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1284,14 +1284,17 @@ class TestMultiple1ToNReplacement : public ConversionPattern {
12841284
ConversionPatternRewriter &rewriter, Operation *op, ArrayRef<ValueRange> r1,
12851285
SmallVector<ValueRange> r2, ArrayRef<SmallVector<Value>> r3,
12861286
SmallVector<SmallVector<Value>> r4, ArrayRef<ArrayRef<Value>> r5,
1287-
SmallVector<ArrayRef<Value>> r6, Value v, ValueRange vr,
1287+
SmallVector<ArrayRef<Value>> r6, SmallVector<SmallVector<Value, 1>> r7,
1288+
ArrayRef<SmallVector<Value, 1>> r8, Value v, ValueRange vr,
12881289
ArrayRef<Value> ar) {
12891290
rewriter.replaceOpWithMultiple(op, r1);
12901291
rewriter.replaceOpWithMultiple(op, r2);
12911292
rewriter.replaceOpWithMultiple(op, r3);
12921293
rewriter.replaceOpWithMultiple(op, r4);
12931294
rewriter.replaceOpWithMultiple(op, r5);
12941295
rewriter.replaceOpWithMultiple(op, r6);
1296+
rewriter.replaceOpWithMultiple(op, r7);
1297+
rewriter.replaceOpWithMultiple(op, r8);
12951298
rewriter.replaceOpWithMultiple(op, {vr});
12961299
rewriter.replaceOpWithMultiple(op, {ar});
12971300
rewriter.replaceOpWithMultiple(op, {{v}});

0 commit comments

Comments
 (0)