
Commit 001453c

replace with multiple
Apply suggestions from code review
Co-authored-by: Markus Böck <[email protected]>
address comments
1 parent 6c9256d commit 001453c

5 files changed, +164 -99 lines changed


mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 22 additions & 2 deletions
@@ -795,12 +795,32 @@ class ConversionPatternRewriter final : public PatternRewriter {
   /// patterns even if a failure is encountered during the rewrite step.
   bool canRecoverFromRewriteFailure() const override { return true; }

-  /// PatternRewriter hook for replacing an operation.
+  /// Replace the given operation with the new values. The number of op results
+  /// and replacement values must match. The types may differ: the dialect
+  /// conversion driver will reconcile any surviving type mismatches at the end
+  /// of the conversion process with source materializations. The given
+  /// operation is erased.
   void replaceOp(Operation *op, ValueRange newValues) override;

-  /// PatternRewriter hook for replacing an operation.
+  /// Replace the given operation with the results of the new op. The number of
+  /// op results must match. The types may differ: the dialect conversion
+  /// driver will reconcile any surviving type mismatches at the end of the
+  /// conversion process with source materializations. The original operation
+  /// is erased.
   void replaceOp(Operation *op, Operation *newOp) override;

+  /// Replace the given operation with the new value ranges. The number of op
+  /// results and value ranges must match. If an original SSA value is replaced
+  /// by multiple SSA values (i.e., a value range has more than 1 element), the
+  /// conversion driver will insert an argument materialization to convert the
+  /// N SSA values back into 1 SSA value of the original type. The given
+  /// operation is erased.
+  ///
+  /// Note: The argument materialization is a workaround until we have full 1:N
+  /// support in the dialect conversion. (It is going to disappear from both
+  /// `replaceOpWithMultiple` and `applySignatureConversion`.)
+  void replaceOpWithMultiple(Operation *op, ArrayRef<ValueRange> newValues);
+
   /// PatternRewriter hook for erasing a dead operation. The uses of this
   /// operation *must* be made dead by the end of the conversion process,
   /// otherwise an assert will be issued.
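
Note: the new hook is meant to be called from conversion patterns that expand one op result into several values. A minimal sketch of such a pattern, assuming a hypothetical op type `SomeOp` and a hypothetical helper `expandResults` (neither is part of this commit):

#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVectorExtras.h"

using namespace mlir;

// Hypothetical pattern: every result of the matched op is replaced by a
// range of new values; where a 1:N mapping survives, the driver inserts an
// argument materialization to rebuild one value of the original type.
struct ExpandSomeOpResults : public OpConversionPattern<SomeOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SomeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Placeholder for whatever logic computes the replacements; it must
    // yield exactly one vector of values per original result.
    SmallVector<SmallVector<Value>> expanded = expandResults(op, rewriter);
    rewriter.replaceOpWithMultiple(op,
                                   llvm::to_vector_of<ValueRange>(expanded));
    return success();
  }
};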

mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp

Lines changed: 12 additions & 28 deletions
@@ -141,47 +141,31 @@ struct DecomposeCallGraphTypesForCallOp : public OpConversionPattern<CallOp> {
                                               getTypeConverter()));
     }

-    // Create the new result types for the new `CallOp` and track the indices in
-    // the new call op's results that correspond to the old call op's results.
-    //
-    // expandedResultIndices[i] = "list of new result indices that old result i
-    // expanded to".
+    // Create the new result types for the new `CallOp` and track the number of
+    // replacement types for each original op result.
     SmallVector<Type, 2> newResultTypes;
-    SmallVector<SmallVector<unsigned, 2>, 4> expandedResultIndices;
+    SmallVector<unsigned> expandedResultSizes;
     for (Type resultType : op.getResultTypes()) {
       unsigned oldSize = newResultTypes.size();
       if (failed(typeConverter->convertType(resultType, newResultTypes)))
         return failure();
-      auto &resultMapping = expandedResultIndices.emplace_back();
-      for (unsigned i = oldSize, e = newResultTypes.size(); i < e; i++)
-        resultMapping.push_back(i);
+      expandedResultSizes.push_back(newResultTypes.size() - oldSize);
     }

     CallOp newCallOp = rewriter.create<CallOp>(op.getLoc(), op.getCalleeAttr(),
                                                newResultTypes, newOperands);

-    // Build a replacement value for each result to replace its uses. If a
-    // result has multiple mapping values, it needs to be materialized as a
-    // single value.
-    SmallVector<Value, 2> replacedValues;
+    // Build a replacement value for each result to replace its uses.
+    SmallVector<ValueRange> replacedValues;
     replacedValues.reserve(op.getNumResults());
+    unsigned startIdx = 0;
     for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) {
-      auto decomposedValues = llvm::to_vector<6>(
-          llvm::map_range(expandedResultIndices[i],
-                          [&](unsigned i) { return newCallOp.getResult(i); }));
-      if (decomposedValues.empty()) {
-        // No replacement is required.
-        replacedValues.push_back(nullptr);
-      } else if (decomposedValues.size() == 1) {
-        replacedValues.push_back(decomposedValues.front());
-      } else {
-        // Materialize a single Value to replace the original Value.
-        Value materialized = getTypeConverter()->materializeArgumentConversion(
-            rewriter, op.getLoc(), op.getType(i), decomposedValues);
-        replacedValues.push_back(materialized);
-      }
+      ValueRange repl =
+          newCallOp.getResults().slice(startIdx, expandedResultSizes[i]);
+      replacedValues.push_back(repl);
+      startIdx += expandedResultSizes[i];
     }
-    rewriter.replaceOp(op, replacedValues);
+    rewriter.replaceOpWithMultiple(op, replacedValues);
     return success();
   }
 };
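
Stripped of the CallOp specifics, the change above boils down to one idiom: record how many converted types each original result expanded to, then carve the new op's flat result list into per-result ranges. A minimal standalone sketch of that idiom with illustrative names (not code from this commit):

#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Carve the flat results of `newOp` into one ValueRange per original result
// of `op` (expandedSizes[i] = number of values result i expanded to), then
// perform the 1:N replacement.
static void replaceWithSlices(ConversionPatternRewriter &rewriter,
                              Operation *op, Operation *newOp,
                              ArrayRef<unsigned> expandedSizes) {
  SmallVector<ValueRange> replacements;
  replacements.reserve(op->getNumResults());
  unsigned start = 0;
  for (unsigned size : expandedSizes) {
    // A contiguous (possibly empty) run of new results maps to this result.
    replacements.push_back(newOp->getResults().slice(start, size));
    start += size;
  }
  rewriter.replaceOpWithMultiple(op, replacements);
}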

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp

Lines changed: 20 additions & 23 deletions
@@ -600,8 +600,8 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
     flattenOperands(adaptor.getOperands(), flattened);
     auto newCall = rewriter.create<func::CallOp>(loc, op.getCallee(),
                                                  finalRetTy, flattened);
-    // (2) Create cast operation for sparse tensor returns.
-    SmallVector<Value> castedRet;
+    // (2) Gather sparse tensor returns.
+    SmallVector<SmallVector<Value>> packedResultVals;
     // Tracks the offset of current return value (of the original call)
     // relative to the new call (after sparse tensor flattening);
     unsigned retOffset = 0;

@@ -618,21 +618,22 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
       assert(!sparseFlat.empty());
       if (sparseFlat.size() > 1) {
         auto flatSize = sparseFlat.size();
-        ValueRange fields(iterator_range<ResultRange::iterator>(
-            newCall.result_begin() + retOffset,
-            newCall.result_begin() + retOffset + flatSize));
-        castedRet.push_back(genTuple(rewriter, loc, retType, fields));
+        packedResultVals.emplace_back();
+        llvm::append_range(packedResultVals.back(),
+                           newCall.getResults().slice(retOffset, flatSize));
         retOffset += flatSize;
       } else {
         // If this is an 1:1 conversion, no need for casting.
-        castedRet.push_back(newCall.getResult(retOffset));
+        packedResultVals.emplace_back();
+        packedResultVals.back().push_back(newCall.getResult(retOffset));
         retOffset++;
       }
       sparseFlat.clear();
     }

-    assert(castedRet.size() == op.getNumResults());
-    rewriter.replaceOp(op, castedRet);
+    assert(packedResultVals.size() == op.getNumResults());
+    rewriter.replaceOpWithMultiple(
+        op, llvm::to_vector_of<ValueRange>(packedResultVals));
     return success();
   }
 };

@@ -776,7 +777,7 @@ class SparseTensorAllocConverter
     // Reuses specifier.
     fields.push_back(desc.getSpecifier());
     assert(fields.size() == desc.getNumFields());
-    rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
+    rewriter.replaceOpWithMultiple(op, {fields});
     return success();
   }

@@ -796,7 +797,7 @@ class SparseTensorAllocConverter
                          sizeHint, lvlSizesValues, fields);

     // Replace operation with resulting memrefs.
-    rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
+    rewriter.replaceOpWithMultiple(op, {fields});
     return success();
   }

@@ -837,7 +838,7 @@ class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
                          sizeHint, lvlSizesValues, fields);

     // Replace operation with resulting memrefs.
-    rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
+    rewriter.replaceOpWithMultiple(op, {fields});
     return success();
   }

@@ -893,7 +894,7 @@ class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
     if (op.getHasInserts())
       genEndInsert(rewriter, op.getLoc(), desc);
     // Replace operation with resulting memrefs.
-    rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), desc));
+    rewriter.replaceOpWithMultiple(op, {desc.getFields()});
     return success();
   }
 };

@@ -1006,15 +1007,14 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
     rewriter.create<scf::YieldOp>(loc, insertRet);

     rewriter.setInsertionPointAfter(loop);
-    Value result = genTuple(rewriter, loc, dstType, loop->getResults());
     // Deallocate the buffers on exit of the full loop nest.
     Operation *parent = getTop(op);
     rewriter.setInsertionPointAfter(parent);
     rewriter.create<memref::DeallocOp>(loc, values);
     rewriter.create<memref::DeallocOp>(loc, filled);
     rewriter.create<memref::DeallocOp>(loc, added);
     // Replace operation with resulting memrefs.
-    rewriter.replaceOp(op, result);
+    rewriter.replaceOpWithMultiple(op, {loop->getResults()});
     return success();
   }
 };

@@ -1041,8 +1041,7 @@ class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
                              params, /*genCall=*/true);
     SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
     // Replace operation with resulting memrefs.
-    rewriter.replaceOp(op,
-                       genTuple(rewriter, loc, op.getDest().getType(), ret));
+    rewriter.replaceOpWithMultiple(op, {ret});
     return success();
   }
 };

@@ -1215,8 +1214,7 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
       return true;
     });

-    rewriter.replaceOp(
-        op, genTuple(rewriter, loc, op.getResult().getType(), fields));
+    rewriter.replaceOpWithMultiple(op, {fields});
     return success();
   }
 };

@@ -1271,8 +1269,7 @@ class SparseExtractSliceConverter
     // NOTE: we can not generate tuples directly from descriptor here, as the
     // descriptor is holding the original type, yet we want the slice type
     // here (they shared every memref but with an updated specifier).
-    rewriter.replaceOp(op, genTuple(rewriter, loc, op.getResult().getType(),
-                                    desc.getFields()));
+    rewriter.replaceOpWithMultiple(op, {desc.getFields()});
     return success();
   }
 };

@@ -1403,7 +1400,7 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
     }
     desc.setValMemSize(rewriter, loc, memSize);

-    rewriter.replaceOp(op, genTuple(rewriter, loc, desc));
+    rewriter.replaceOpWithMultiple(op, {desc.getFields()});
     return success();
   }
 };

@@ -1577,7 +1574,7 @@ struct SparseNewConverter : public OpConversionPattern<NewOp> {
                 EmitCInterface::Off);

     // Replace operation with resulting memrefs.
-    rewriter.replaceOp(op, genTuple(rewriter, loc, dstTp, fields));
+    rewriter.replaceOpWithMultiple(op, {fields});
     return success();
   }
 };
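
Each converter above rewrites a single-result sparse tensor op into its N backing fields. Instead of materializing a tuple with genTuple and calling replaceOp, the fields are handed to the driver directly: the brace form `{fields}` builds a one-element ArrayRef<ValueRange>, i.e. one value range for the op's single result. In isolation, the shape of these calls looks like this (illustrative helper, not code from this commit):

#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Single-result 1:N replacement: the op's one result is replaced by all
// values in `fields`; the driver materializes a single value of the
// original type only where a use still requires it.
static void replaceSingleResultWithFields(ConversionPatternRewriter &rewriter,
                                          Operation *op,
                                          ArrayRef<Value> fields) {
  rewriter.replaceOpWithMultiple(op, {fields});
}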

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp

Lines changed: 13 additions & 8 deletions
@@ -54,19 +54,24 @@ convertSparseTensorType(RankedTensorType rtp, SmallVectorImpl<Type> &fields) {
 // The sparse tensor type converter (defined in Passes.h).
 //===----------------------------------------------------------------------===//

+static Value materializeTuple(OpBuilder &builder, RankedTensorType tp,
+                              ValueRange inputs, Location loc) {
+  if (!getSparseTensorEncoding(tp))
+    // Not a sparse tensor.
+    return Value();
+  // Sparsifier knows how to cancel out these casts.
+  return genTuple(builder, loc, tp, inputs);
+}
+
 SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
   addConversion([](Type type) { return type; });
   addConversion(convertSparseTensorType);

   // Required by scf.for 1:N type conversion.
-  addSourceMaterialization([](OpBuilder &builder, RankedTensorType tp,
-                              ValueRange inputs, Location loc) -> Value {
-    if (!getSparseTensorEncoding(tp))
-      // Not a sparse tensor.
-      return Value();
-    // Sparsifier knows how to cancel out these casts.
-    return genTuple(builder, loc, tp, inputs);
-  });
+  addSourceMaterialization(materializeTuple);
+
+  // Required as a workaround until we have full 1:N support.
+  addArgumentMaterialization(materializeTuple);
 }

 //===----------------------------------------------------------------------===//
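
The lambda is hoisted into materializeTuple so the same callback can be registered both as a source materialization and as an argument materialization; the latter is the workaround referenced in the new replaceOpWithMultiple documentation above. A minimal sketch of the same registration shape for a generic converter (illustrative only; the sparse tensor converter emits genTuple instead of a plain cast):

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Illustrative type converter that registers one callback as both a source
// and an argument materialization, mirroring the structure of
// SparseTensorTypeToBufferConverter.
struct ExampleTypeConverter : public TypeConverter {
  ExampleTypeConverter() {
    addConversion([](Type type) { return type; });

    // Rebuild one value of the original type from the N replacement values.
    // Here this is just an unrealized_conversion_cast; a real converter
    // would emit a dialect-specific cast (e.g. genTuple in the sparse code).
    auto materialize = [](OpBuilder &builder, Type type, ValueRange inputs,
                          Location loc) -> Value {
      return builder.create<UnrealizedConversionCastOp>(loc, type, inputs)
          .getResult(0);
    };
    addSourceMaterialization(materialize);
    // Workaround until the dialect conversion has full 1:N support.
    addArgumentMaterialization(materialize);
  }
};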
