[mlir][Transforms] Dialect Conversion: Add replaceOpWithMultiple #115816

Conversation
@llvm/pr-subscribers-mlir-func @llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

This commit adds a new function ConversionPatternRewriter::replaceOpWithMultiple. This function is similar to replaceOp, but it accepts multiple ValueRange replacements, one per op result.

Note: This function is not an overload of replaceOp because of ambiguous overload resolution that would make the API difficult to use.

This commit aligns "block signature conversions" with "op replacements": both support 1:N replacements now. Due to incomplete 1:N support in the dialect conversion driver, an argument materialization is inserted when an SSA value is replaced with multiple values; this is the same workaround that block signature conversions already use. These argument materializations are going to be removed in a subsequent commit that adds full 1:N support. The purpose of this PR is to add missing features gradually in small increments.

This commit also updates two MLIR transformations that have their own custom workarounds for missing 1:N support. These can already start using replaceOpWithMultiple.

Depends on #114940.

Patch is 29.03 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/115816.diff

7 Files Affected:
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 7ef03b87179523..78729376507208 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -353,7 +353,7 @@ class OpBuilder : public Builder {
/// selected insertion point. (E.g., because they are defined in a nested
/// region or because they are not visible in an IsolatedFromAbove region.)
static InsertPoint after(ArrayRef<Value> values,
- const PostDominanceInfo &domInfo);
+ const PostDominanceInfo *domInfo = nullptr);
/// Returns true if this insert point is set.
bool isSet() const { return (block != nullptr); }
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 5e5957170e646c..e461b7d11602a0 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -795,12 +795,32 @@ class ConversionPatternRewriter final : public PatternRewriter {
/// patterns even if a failure is encountered during the rewrite step.
bool canRecoverFromRewriteFailure() const override { return true; }
- /// PatternRewriter hook for replacing an operation.
+ /// Replace the given operation with the new values. The number of op results
+ /// and replacement values must match. The types may differ: the dialect
+ /// conversion driver will reconcile any surviving type mismatches at the end
+ /// of the conversion process with source materializations. The given
+ /// operation is erased.
void replaceOp(Operation *op, ValueRange newValues) override;
- /// PatternRewriter hook for replacing an operation.
+ /// Replace the given operation with the results of the new op. The number of
+ /// op results must match. The types may differ: the dialect conversion
+ /// driver will reconcile any surviving type mismatches at the end of the
+ /// conversion process with source materializations. The original operation
+ /// is erased.
void replaceOp(Operation *op, Operation *newOp) override;
+ /// Replace the given operation with the new value range. The number of op
+ /// results and value ranges must match. If an original SSA value is replaced
+ /// by multiple SSA values (i.e., value range has more than 1 element), the
+ /// conversion driver will insert an argument materialization to convert the
+ /// N SSA values back into 1 SSA value of the original type. The given
+ /// operation is erased.
+ ///
+ /// Note: The argument materialization is a workaround until we have full 1:N
+ /// support in the dialect conversion. (It is going to disappear from both
+ /// `replaceOpWithMultiple` and `applySignatureConversion`.)
+ void replaceOpWithMultiple(Operation *op, ArrayRef<ValueRange> newValues);
+
/// PatternRewriter hook for erasing a dead operation. The uses of this
/// operation *must* be made dead by the end of the conversion process,
/// otherwise an assert will be issued.
diff --git a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
index de4aba2ed327db..a08764326a80b6 100644
--- a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
@@ -141,47 +141,31 @@ struct DecomposeCallGraphTypesForCallOp : public OpConversionPattern<CallOp> {
getTypeConverter()));
}
- // Create the new result types for the new `CallOp` and track the indices in
- // the new call op's results that correspond to the old call op's results.
- //
- // expandedResultIndices[i] = "list of new result indices that old result i
- // expanded to".
+ // Create the new result types for the new `CallOp` and track the number of
+ // replacement types for each original op result.
SmallVector<Type, 2> newResultTypes;
- SmallVector<SmallVector<unsigned, 2>, 4> expandedResultIndices;
+ SmallVector<unsigned> expandedResultSizes;
for (Type resultType : op.getResultTypes()) {
unsigned oldSize = newResultTypes.size();
if (failed(typeConverter->convertType(resultType, newResultTypes)))
return failure();
- auto &resultMapping = expandedResultIndices.emplace_back();
- for (unsigned i = oldSize, e = newResultTypes.size(); i < e; i++)
- resultMapping.push_back(i);
+ expandedResultSizes.push_back(newResultTypes.size() - oldSize);
}
CallOp newCallOp = rewriter.create<CallOp>(op.getLoc(), op.getCalleeAttr(),
newResultTypes, newOperands);
- // Build a replacement value for each result to replace its uses. If a
- // result has multiple mapping values, it needs to be materialized as a
- // single value.
- SmallVector<Value, 2> replacedValues;
+ // Build a replacement value for each result to replace its uses.
+ SmallVector<ValueRange> replacedValues;
replacedValues.reserve(op.getNumResults());
+ unsigned startIdx = 0;
for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) {
- auto decomposedValues = llvm::to_vector<6>(
- llvm::map_range(expandedResultIndices[i],
- [&](unsigned i) { return newCallOp.getResult(i); }));
- if (decomposedValues.empty()) {
- // No replacement is required.
- replacedValues.push_back(nullptr);
- } else if (decomposedValues.size() == 1) {
- replacedValues.push_back(decomposedValues.front());
- } else {
- // Materialize a single Value to replace the original Value.
- Value materialized = getTypeConverter()->materializeArgumentConversion(
- rewriter, op.getLoc(), op.getType(i), decomposedValues);
- replacedValues.push_back(materialized);
- }
+ ValueRange repl =
+ newCallOp.getResults().slice(startIdx, expandedResultSizes[i]);
+ replacedValues.push_back(repl);
+ startIdx += expandedResultSizes[i];
}
- rewriter.replaceOp(op, replacedValues);
+ rewriter.replaceOpWithMultiple(op, replacedValues);
return success();
}
};
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 062a0ea6cc47cb..09509278d7749a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -600,8 +600,8 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
flattenOperands(adaptor.getOperands(), flattened);
auto newCall = rewriter.create<func::CallOp>(loc, op.getCallee(),
finalRetTy, flattened);
- // (2) Create cast operation for sparse tensor returns.
- SmallVector<Value> castedRet;
+ // (2) Gather sparse tensor returns.
+ SmallVector<SmallVector<Value>> packedResultVals;
// Tracks the offset of current return value (of the original call)
// relative to the new call (after sparse tensor flattening);
unsigned retOffset = 0;
@@ -618,21 +618,27 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
assert(!sparseFlat.empty());
if (sparseFlat.size() > 1) {
auto flatSize = sparseFlat.size();
- ValueRange fields(iterator_range<ResultRange::iterator>(
- newCall.result_begin() + retOffset,
- newCall.result_begin() + retOffset + flatSize));
- castedRet.push_back(genTuple(rewriter, loc, retType, fields));
+ packedResultVals.push_back(SmallVector<Value>());
+ llvm::append_range(packedResultVals.back(),
+ iterator_range<ResultRange::iterator>(
+ newCall.result_begin() + retOffset,
+ newCall.result_begin() + retOffset + flatSize));
retOffset += flatSize;
} else {
// If this is an 1:1 conversion, no need for casting.
- castedRet.push_back(newCall.getResult(retOffset));
+ packedResultVals.push_back(SmallVector<Value>());
+ packedResultVals.back().push_back(newCall.getResult(retOffset));
retOffset++;
}
sparseFlat.clear();
}
- assert(castedRet.size() == op.getNumResults());
- rewriter.replaceOp(op, castedRet);
+ assert(packedResultVals.size() == op.getNumResults());
+ SmallVector<ValueRange> ranges;
+ ranges.reserve(packedResultVals.size());
+ for (const SmallVector<Value> &vec : packedResultVals)
+ ranges.push_back(ValueRange(vec));
+ rewriter.replaceOpWithMultiple(op, ranges);
return success();
}
};
@@ -776,7 +782,7 @@ class SparseTensorAllocConverter
// Reuses specifier.
fields.push_back(desc.getSpecifier());
assert(fields.size() == desc.getNumFields());
- rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
+ rewriter.replaceOpWithMultiple(op, {fields});
return success();
}
@@ -796,7 +802,7 @@ class SparseTensorAllocConverter
sizeHint, lvlSizesValues, fields);
// Replace operation with resulting memrefs.
- rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
+ rewriter.replaceOpWithMultiple(op, {fields});
return success();
}
@@ -837,7 +843,7 @@ class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
sizeHint, lvlSizesValues, fields);
// Replace operation with resulting memrefs.
- rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
+ rewriter.replaceOpWithMultiple(op, {fields});
return success();
}
@@ -893,7 +899,7 @@ class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
if (op.getHasInserts())
genEndInsert(rewriter, op.getLoc(), desc);
// Replace operation with resulting memrefs.
- rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), desc));
+ rewriter.replaceOpWithMultiple(op, {desc.getFields()});
return success();
}
};
@@ -1006,7 +1012,6 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
rewriter.create<scf::YieldOp>(loc, insertRet);
rewriter.setInsertionPointAfter(loop);
- Value result = genTuple(rewriter, loc, dstType, loop->getResults());
// Deallocate the buffers on exit of the full loop nest.
Operation *parent = getTop(op);
rewriter.setInsertionPointAfter(parent);
@@ -1014,7 +1019,7 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
rewriter.create<memref::DeallocOp>(loc, filled);
rewriter.create<memref::DeallocOp>(loc, added);
// Replace operation with resulting memrefs.
- rewriter.replaceOp(op, result);
+ rewriter.replaceOpWithMultiple(op, {loop->getResults()});
return success();
}
};
@@ -1041,8 +1046,7 @@ class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
params, /*genCall=*/true);
SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
// Replace operation with resulting memrefs.
- rewriter.replaceOp(op,
- genTuple(rewriter, loc, op.getDest().getType(), ret));
+ rewriter.replaceOpWithMultiple(op, {ret});
return success();
}
};
@@ -1215,8 +1219,7 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
return true;
});
- rewriter.replaceOp(
- op, genTuple(rewriter, loc, op.getResult().getType(), fields));
+ rewriter.replaceOpWithMultiple(op, {fields});
return success();
}
};
@@ -1271,8 +1274,7 @@ class SparseExtractSliceConverter
// NOTE: we can not generate tuples directly from descriptor here, as the
// descriptor is holding the original type, yet we want the slice type
// here (they shared every memref but with an updated specifier).
- rewriter.replaceOp(op, genTuple(rewriter, loc, op.getResult().getType(),
- desc.getFields()));
+ rewriter.replaceOpWithMultiple(op, {desc.getFields()});
return success();
}
};
@@ -1403,7 +1405,7 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
}
desc.setValMemSize(rewriter, loc, memSize);
- rewriter.replaceOp(op, genTuple(rewriter, loc, desc));
+ rewriter.replaceOpWithMultiple(op, {desc.getFields()});
return success();
}
};
@@ -1577,7 +1579,7 @@ struct SparseNewConverter : public OpConversionPattern<NewOp> {
EmitCInterface::Off);
// Replace operation with resulting memrefs.
- rewriter.replaceOp(op, genTuple(rewriter, loc, dstTp, fields));
+ rewriter.replaceOpWithMultiple(op, {fields});
return success();
}
};
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
index a3db50573c2720..834e3634cc130d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
@@ -54,19 +54,24 @@ convertSparseTensorType(RankedTensorType rtp, SmallVectorImpl<Type> &fields) {
// The sparse tensor type converter (defined in Passes.h).
//===----------------------------------------------------------------------===//
+static Value materializeTuple(OpBuilder &builder, RankedTensorType tp,
+ ValueRange inputs, Location loc) {
+ if (!getSparseTensorEncoding(tp))
+ // Not a sparse tensor.
+ return Value();
+ // Sparsifier knows how to cancel out these casts.
+ return genTuple(builder, loc, tp, inputs);
+}
+
SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
addConversion([](Type type) { return type; });
addConversion(convertSparseTensorType);
// Required by scf.for 1:N type conversion.
- addSourceMaterialization([](OpBuilder &builder, RankedTensorType tp,
- ValueRange inputs, Location loc) -> Value {
- if (!getSparseTensorEncoding(tp))
- // Not a sparse tensor.
- return Value();
- // Sparsifier knows how to cancel out these casts.
- return genTuple(builder, loc, tp, inputs);
- });
+ addSourceMaterialization(materializeTuple);
+
+ // Required as a workaround until we have full 1:N support.
+ addArgumentMaterialization(materializeTuple);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 4714c3cace6c78..e85a86e94282ec 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -645,7 +645,7 @@ void OpBuilder::cloneRegionBefore(Region ®ion, Block *before) {
OpBuilder::InsertPoint
OpBuilder::InsertPoint::after(ArrayRef<Value> values,
- const PostDominanceInfo &domInfo) {
+ const PostDominanceInfo *domInfo) {
// Helper function that computes the point after v's definition.
auto computeAfterIp = [](Value v) -> std::pair<Block *, Block::iterator> {
if (auto blockArg = dyn_cast<BlockArgument>(v))
@@ -658,12 +658,18 @@ OpBuilder::InsertPoint::after(ArrayRef<Value> values,
assert(!values.empty() && "expected at least one Value");
auto [block, blockIt] = computeAfterIp(values.front());
+ if (values.size() == 1) {
+ // Fast path: There is only one value.
+ return InsertPoint(block, blockIt);
+ }
+
// Check the other values one-by-one and update the insertion point if
// needed.
+ assert(domInfo && "domInfo expected if >1 values");
for (Value v : values.drop_front()) {
auto [candidateBlock, candidateBlockIt] = computeAfterIp(v);
- if (domInfo.postDominantes(candidateBlock, candidateBlockIt, block,
- blockIt)) {
+ if (domInfo->postDominantes(candidateBlock, candidateBlockIt, block,
+ blockIt)) {
// The point after v's definition post-dominates the current (and all
// previous) insertion points. Note: Post-dominance is transitive.
block = candidateBlock;
@@ -671,8 +677,8 @@ OpBuilder::InsertPoint::after(ArrayRef<Value> values,
continue;
}
- if (!domInfo.postDominantes(block, blockIt, candidateBlock,
- candidateBlockIt)) {
+ if (!domInfo->postDominantes(block, blockIt, candidateBlock,
+ candidateBlockIt)) {
// The point after v's definition and the current insertion point do not
// post-dominate each other. Therefore, there is no insertion point that
// post-dominates all values.
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 0a62628b9ad240..2f6c0a1ab0bd3b 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -11,6 +11,7 @@
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Dominance.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Iterators.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
@@ -53,20 +54,14 @@ static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
});
}
-/// Helper function that computes an insertion point where the given value is
-/// defined and can be used without a dominance violation.
-static OpBuilder::InsertPoint computeInsertPoint(Value value) {
- Block *insertBlock = value.getParentBlock();
- Block::iterator insertPt = insertBlock->begin();
- if (OpResult inputRes = dyn_cast<OpResult>(value))
- insertPt = ++inputRes.getOwner()->getIterator();
- return OpBuilder::InsertPoint(insertBlock, insertPt);
-}
-
//===----------------------------------------------------------------------===//
// ConversionValueMapping
//===----------------------------------------------------------------------===//
+/// A list of replacement SSA values. Optimized for the common case of a single
+/// SSA value.
+using ReplacementValues = SmallVector<Value, 1>;
+
namespace {
/// This class wraps a IRMapping to provide recursive lookup
/// functionality, i.e. we will traverse if the mapped value also has a mapping.
@@ -818,6 +813,22 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
Type originalType,
const TypeConverter *converter);
+ /// Build an N:1 materialization for the given original value that was
+ /// replaced with the given replacement values.
+ ///
+ /// This is a workaround around incomplete 1:N support in the dialect
+ /// conversion driver. The conversion mapping can store only 1:1 replacements
+ /// and the conversion patterns only support single Value replacements in the
+ /// adaptor, so N values must be converted back to a single value. This
+ /// function will be deleted when full 1:N support has been added.
+ ///
+ /// This function inserts an argument materialization back to the original
+ /// type, followed by a target materialization to the legalized type (if
+ /// applicable).
+ void insertNTo1Materialization(OpBuilder::InsertPoint ip, Location loc,
+ ValueRange replacements, Value originalValue,
+ const TypeConverter *converter);
+
//===--------------------------------------------------------------------===//
// Rewriter Notification Hooks
//===--------------------------------------------------------------------===//
@@ -827,7 +838,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
OpBuilder::InsertPoint previous) override;
/// Notifies that an op is about to be replaced with the given values.
- void notifyOpReplaced(Operation *op, ValueRange newValues);
+ void notifyOpReplaced(Operation *op, ArrayRef<ReplacementValues> newValues);
/// Notifies that a block is about to be erased.
void notifyBlockIsBeingErased(Block *block);
@@ -1147,8 +1158,9 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
// that the value was replaced with a value of different type and no
// ...
[truncated]
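As a rough illustration of the new API (not part of the patch): a conversion pattern whose single op result decomposes into several values could call the new hook as sketched below. The op, the adaptor contents, and the buildParts helper are hypothetical; only the replaceOpWithMultiple(Operation *, ArrayRef<ValueRange>) signature comes from this PR.

// Hedged sketch: replace an op with one result by N values (hypothetical op).
struct UnpackPattern : public OpConversionPattern<my::UnpackableOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(my::UnpackableOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Hypothetical helper that computes the N replacement values.
    SmallVector<Value> parts = buildParts(rewriter, op.getLoc(), adaptor);
    // One ValueRange per original op result; here: 1 result -> N values. The
    // driver inserts an argument materialization as a temporary workaround.
    rewriter.replaceOpWithMultiple(op, {ValueRange(parts)});
    return success();
  }
};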
Looks mostly good to me besides some code nits 🙂
The only very high-level concern I have is the use of, and need for, InsertionPoint::after. As I understand it, this is a not-yet-implemented optimization for the future, for the purpose of reusing materializations: by calculating a better insertion point, a materialization may dominate more operations that might need exactly that materialization after pattern application as well.
Is the current complexity of calculating the insertion point (and potentially calculating the post-dominator tree, an O(n log n) operation) worth the gain of the optimization? I am thinking that the better insertion-point logic should probably be postponed until we can measure its effectiveness (and avoid the risk of a premature optimization), and that we should have something simpler and working that does not worsen the status quo for now.
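For reference, the patch changes OpBuilder::InsertPoint::after to take an optional PostDominanceInfo pointer. A minimal, hedged sketch of how it could be called (variable names are assumptions; as the discussion below concludes, the conversion driver ultimately dropped this insertion-point computation):

// Requires "mlir/IR/Builders.h" and "mlir/IR/Dominance.h".
// Single value: the fast path added by the patch needs no dominance info.
OpBuilder::InsertPoint ipSingle = OpBuilder::InsertPoint::after({someValue});

// Multiple values: the patch asserts that a PostDominanceInfo is provided.
PostDominanceInfo postDomInfo(funcOp); // `funcOp` is an assumed enclosing op.
OpBuilder::InsertPoint ipMulti =
    OpBuilder::InsertPoint::after({valueA, valueB}, &postDomInfo);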
Yes, that's correct. But this optimization is already implemented, and it is actually quite deeply ingrained into the dialect conversion driver: SSA values are not replaced directly, but in a delayed fashion; the replacement value is stored in the conversion value mapping.
Note: Computing an insertion point for a 1:1 materialization (or a 1:N block signature conversion) is much easier because there is only one SSA value (or one block). That's why computing the insertion point was trivial until now.
This PR also makes the PostDominanceInfo parameter of OpBuilder::InsertPoint::after optional. In case of N:1 materializations, the implementation uses InsertPoint::after to compute a suitable insertion point for the materialization.
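The N:1 workaround that these comments refer to hinges on argument materializations registered with the type converter, as in the SparseTensorDescriptor.cpp hunk above. A generic, hedged sketch of that registration pattern (MyPackedType and my::PackOp are placeholders, not from the patch):

TypeConverter converter;
// Callback that packs the N replacement values back into one value of the
// original type; stands in for genTuple() in the sparse-tensor code above.
auto packBack = [](OpBuilder &builder, MyPackedType type, ValueRange inputs,
                   Location loc) -> Value {
  return builder.create<my::PackOp>(loc, type, inputs);
};
// Reconciles leftover type mismatches at the end of the conversion.
converter.addSourceMaterialization(packBack);
// Temporary workaround until full 1:N support: used for N:1 packing.
converter.addArgumentMaterialization(packBack);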
You can test this locally with the following command:
git-clang-format --diff 6c9256dc5cda9184e295bc8d00be35e61b3be892 001453c12e18c8e5f8d1b762cf5f67d56ec542e9 --extensions cpp,h -- mlir/include/mlir/Transforms/DialectConversion.h mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp mlir/lib/Transforms/Utils/DialectConversion.cpp
View the diff from clang-format here.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 6b154dd161..bf7b3f9bec 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -618,7 +618,7 @@ public:
assert(!sparseFlat.empty());
if (sparseFlat.size() > 1) {
auto flatSize = sparseFlat.size();
- packedResultVals.emplace_back();
+ packedResultVals.emplace_back();
llvm::append_range(packedResultVals.back(),
newCall.getResults().slice(retOffset, flatSize));
retOffset += flatSize;
One question I have here is why not just insert right before the operation that is being replaced? What is the need for trying to insert after one of the values? Wouldn't that remove the need for all of this complexity? It gives you a singular place to insert.
Actually, I think you are right. I was thinking of a case where two ops (in different locations) were replaced with the same 1:N replacement values. And I thought that we store one mapping pair in that case. But we actually store one per original value (and not per replacement value). I just noticed this as I was trying to write down a counter-example... Let me update the PR, I think this will simplify things quite a bit.
Edit: Dropped the insertion point computation. It is actually not needed, as @zero9178 and @River707 commented.
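A hedged sketch of the simpler approach the thread converges on: insert the N:1 materialization right before the op that is being replaced, so no post-dominance computation is needed. All variables here are assumptions, not code from the patch.

// Materialize immediately before the op being replaced; this point is always
// valid, so no PostDominanceInfo is required.
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(op);
Value packed = typeConverter->materializeArgumentConversion(
    rewriter, op->getLoc(), originalType, replacementValues);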
LGTM! Thank you so much for the discussion. Makes perfect sense to me now.
Apply suggestions from code review
Co-authored-by: Markus Böck <[email protected]>
address comments
This commit adds a new function ConversionPatternRewriter::replaceOpWithMultiple. This function is similar to replaceOp, but it accepts multiple ValueRange replacements, one per op result.

Note: This function is not an overload of replaceOp because of ambiguous overload resolution that would make the API difficult to use.

This commit aligns "block signature conversions" with "op replacements": both support 1:N replacements now. Due to incomplete 1:N support in the dialect conversion driver, an argument materialization is inserted when an SSA value is replaced with multiple values; this is the same workaround that block signature conversions already use. These argument materializations are going to be removed in a subsequent commit that adds full 1:N support. The purpose of this PR is to add missing features gradually in small increments.

This commit also updates two MLIR transformations that have their own custom workarounds for missing 1:N support. These can already start using replaceOpWithMultiple.
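As a rough sketch of the call-site shape this enables (modeled on the DecomposeCallGraphTypes change above; variable names follow that diff and are otherwise assumptions), one ValueRange is built per original result and the whole list is handed to the rewriter:

// Result i of the old call expands to expandedResultSizes[i] consecutive
// results of the new call op.
SmallVector<ValueRange> replacements;
replacements.reserve(op.getNumResults());
unsigned startIdx = 0;
for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) {
  replacements.push_back(
      newCallOp.getResults().slice(startIdx, expandedResultSizes[i]));
  startIdx += expandedResultSizes[i];
}
rewriter.replaceOpWithMultiple(op, replacements);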