[mlir][Transforms] Dialect Conversion: Add replaceOpWithMultiple #115816

Merged

matthias-springer (Member) commented Nov 12, 2024

This commit adds a new function ConversionPatternRewriter::replaceOpWithMultiple. This function is similar to replaceOp, but it accepts multiple ValueRange replacements, one per op result.

Note: This function is not an overload of replaceOp because of ambiguous overload resolution that would make the API difficult to use.
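
The PR does not say which call sites would have been ambiguous. As a hypothetical illustration only, the sketch below uses stand-in types (Value, ValueRange and RangeList here are simplified models, not the MLIR classes) to show one way a braced argument such as {fields} can match two overloads equally well, and how a distinct function name avoids the problem.

```cpp
#include <cassert>
#include <cstddef>
#include <initializer_list>
#include <vector>

// Stand-in types for illustration only; these are NOT the MLIR classes.
struct Value {
  int id;
};

// Models a view over a list of values (only the size is recorded here).
struct ValueRange {
  ValueRange(const std::vector<Value> &vals) : size(vals.size()) {}
  std::size_t size;
};

// Models a list of value ranges, one range per op result.
struct RangeList {
  RangeList(std::initializer_list<ValueRange> ranges) : size(ranges.size()) {}
  std::size_t size;
};

// Hypothetical single-range entry point: returns the number of values.
std::size_t replaceOp(ValueRange values) { return values.size; }

// If the overload below were also declared, the call `replaceOp({fields})`
// would be ambiguous: `{fields}` converts to ValueRange and to RangeList,
// each through a different user-defined conversion, and neither conversion
// is ranked better than the other.
// std::size_t replaceOp(RangeList ranges);

// A distinct name takes the call out of that overload set entirely.
// Returns the number of ranges (i.e., the number of op results covered).
std::size_t replaceOpWithMultiple(RangeList ranges) { return ranges.size; }
```

With the separate name, a call like replaceOpWithMultiple({fields}) reads unambiguously as "one result, replaced by all of fields", which is the shape used throughout the SparseTensorCodegen changes in this patch.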

This commit aligns "block signature conversions" with "op replacements": both now support 1:N replacements. Because 1:N support in the dialect conversion driver is still incomplete, an argument materialization is inserted whenever an SSA value is replaced with multiple values; block signature conversions already work around the problem in the same way. These argument materializations will be removed in a subsequent commit that adds full 1:N support. The purpose of this PR is to add the missing features gradually, in small increments.
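
The workaround described above can be modeled in a few lines. This is a simplified, standalone sketch and not the driver's code: Value is a plain int standing in for an SSA value, Materializer is a hypothetical callback standing in for the argument materialization, and treating an empty replacement as "no mapping" is an assumption made for the sketch.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <vector>

using Value = int;                       // stand-in for an SSA value handle
using Replacement = std::vector<Value>;  // replacement values for one result

// Hypothetical callback that packs N values into a single value of the
// original type (the role played by an argument materialization).
using Materializer = std::function<Value(const Replacement &)>;

// For each original result, compute the single value that a 1:1 mapping can
// store. A 1:1 replacement is used directly; a 1:N replacement goes through
// the materializer first.
std::vector<std::optional<Value>>
mapResults(const std::vector<Replacement> &replacements,
           const Materializer &materialize) {
  std::vector<std::optional<Value>> mapped;
  mapped.reserve(replacements.size());
  for (const Replacement &repl : replacements) {
    if (repl.empty())
      mapped.push_back(std::nullopt);  // assumption: nothing to map
    else if (repl.size() == 1)
      mapped.push_back(repl.front());  // 1:1, no materialization needed
    else
      mapped.push_back(materialize(repl));  // 1:N, pack back into one value
  }
  return mapped;
}
```

The point of the model is that the materializer runs only for results with more than one replacement value, so existing 1:1 patterns are unaffected by the workaround.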

This commit also updates two MLIR transformations that carried their own workarounds for the missing 1:N support (DecomposeCallGraphTypes and SparseTensorCodegen). Both now use replaceOpWithMultiple.
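
The grouping step that the updated DecomposeCallGraphTypes pattern performs (slice the new op's flat result list into one range per original result) can be sketched standalone. This is a simplified model under stated assumptions: Value is a plain int, groupBySizes is a hypothetical helper name, and the groups are copied into vectors, whereas the patch keeps non-owning ValueRange slices of the new call's results.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Value = int;  // stand-in for an SSA value handle

// Split `flat` into consecutive groups. sizes[i] is the number of replacement
// values that original result i expanded to (0, 1, or more).
std::vector<std::vector<Value>>
groupBySizes(const std::vector<Value> &flat,
             const std::vector<std::size_t> &sizes) {
  std::vector<std::vector<Value>> groups;
  groups.reserve(sizes.size());
  std::size_t start = 0;
  for (std::size_t n : sizes) {
    assert(start + n <= flat.size() && "sizes exceed the flat result list");
    auto first = flat.begin() + static_cast<std::ptrdiff_t>(start);
    auto last = first + static_cast<std::ptrdiff_t>(n);
    groups.emplace_back(first, last);
    start += n;  // the next group begins where this one ended
  }
  assert(start == flat.size() && "every flat value must belong to a group");
  return groups;
}
```

The result has exactly one entry per original op result, which is the precondition that the replaceOpWithMultiple documentation in this patch states ("The number of op results and value ranges must match").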

llvmbot added the mlir:core, mlir:sparse, mlir, and mlir:func labels Nov 12, 2024
llvmbot (Member) commented Nov 12, 2024

@llvm/pr-subscribers-mlir-func

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

Depends on #114940.


Patch is 29.03 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/115816.diff

7 Files Affected:

  • (modified) mlir/include/mlir/IR/Builders.h (+1-1)
  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+22-2)
  • (modified) mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp (+12-28)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp (+25-23)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp (+13-8)
  • (modified) mlir/lib/IR/Builders.cpp (+11-5)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+102-51)
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 7ef03b87179523..78729376507208 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -353,7 +353,7 @@ class OpBuilder : public Builder {
     /// selected insertion point. (E.g., because they are defined in a nested
     /// region or because they are not visible in an IsolatedFromAbove region.)
     static InsertPoint after(ArrayRef<Value> values,
-                             const PostDominanceInfo &domInfo);
+                             const PostDominanceInfo *domInfo = nullptr);
 
     /// Returns true if this insert point is set.
     bool isSet() const { return (block != nullptr); }
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 5e5957170e646c..e461b7d11602a0 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -795,12 +795,32 @@ class ConversionPatternRewriter final : public PatternRewriter {
   /// patterns even if a failure is encountered during the rewrite step.
   bool canRecoverFromRewriteFailure() const override { return true; }
 
-  /// PatternRewriter hook for replacing an operation.
+  /// Replace the given operation with the new values. The number of op results
+  /// and replacement values must match. The types may differ: the dialect
+  /// conversion driver will reconcile any surviving type mismatches at the end
+  /// of the conversion process with source materializations. The given
+  /// operation is erased.
   void replaceOp(Operation *op, ValueRange newValues) override;
 
-  /// PatternRewriter hook for replacing an operation.
+  /// Replace the given operation with the results of the new op. The number of
+  /// op results must match. The types may differ: the dialect conversion
+  /// driver will reconcile any surviving type mismatches at the end of the
+  /// conversion process with source materializations. The original operation
+  /// is erased.
   void replaceOp(Operation *op, Operation *newOp) override;
 
+  /// Replace the given operation with the new value range. The number of op
+  /// results and value ranges must match. If an original SSA value is replaced
+  /// by multiple SSA values (i.e., value range has more than 1 element), the
+  /// conversion driver will insert an argument materialization to convert the
+  /// N SSA values back into 1 SSA value of the original type. The given
+  /// operation is erased.
+  ///
+  /// Note: The argument materialization is a workaround until we have full 1:N
+  /// support in the dialect conversion. (It is going to disappear from both
+  /// `replaceOpWithMultiple` and `applySignatureConversion`.)
+  void replaceOpWithMultiple(Operation *op, ArrayRef<ValueRange> newValues);
+
   /// PatternRewriter hook for erasing a dead operation. The uses of this
   /// operation *must* be made dead by the end of the conversion process,
   /// otherwise an assert will be issued.
diff --git a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
index de4aba2ed327db..a08764326a80b6 100644
--- a/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
+++ b/mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp
@@ -141,47 +141,31 @@ struct DecomposeCallGraphTypesForCallOp : public OpConversionPattern<CallOp> {
                                         getTypeConverter()));
     }
 
-    // Create the new result types for the new `CallOp` and track the indices in
-    // the new call op's results that correspond to the old call op's results.
-    //
-    // expandedResultIndices[i] = "list of new result indices that old result i
-    // expanded to".
+    // Create the new result types for the new `CallOp` and track the number of
+    // replacement types for each original op result.
     SmallVector<Type, 2> newResultTypes;
-    SmallVector<SmallVector<unsigned, 2>, 4> expandedResultIndices;
+    SmallVector<unsigned> expandedResultSizes;
     for (Type resultType : op.getResultTypes()) {
       unsigned oldSize = newResultTypes.size();
       if (failed(typeConverter->convertType(resultType, newResultTypes)))
         return failure();
-      auto &resultMapping = expandedResultIndices.emplace_back();
-      for (unsigned i = oldSize, e = newResultTypes.size(); i < e; i++)
-        resultMapping.push_back(i);
+      expandedResultSizes.push_back(newResultTypes.size() - oldSize);
     }
 
     CallOp newCallOp = rewriter.create<CallOp>(op.getLoc(), op.getCalleeAttr(),
                                                newResultTypes, newOperands);
 
-    // Build a replacement value for each result to replace its uses. If a
-    // result has multiple mapping values, it needs to be materialized as a
-    // single value.
-    SmallVector<Value, 2> replacedValues;
+    // Build a replacement value for each result to replace its uses.
+    SmallVector<ValueRange> replacedValues;
     replacedValues.reserve(op.getNumResults());
+    unsigned startIdx = 0;
     for (unsigned i = 0, e = op.getNumResults(); i < e; ++i) {
-      auto decomposedValues = llvm::to_vector<6>(
-          llvm::map_range(expandedResultIndices[i],
-                          [&](unsigned i) { return newCallOp.getResult(i); }));
-      if (decomposedValues.empty()) {
-        // No replacement is required.
-        replacedValues.push_back(nullptr);
-      } else if (decomposedValues.size() == 1) {
-        replacedValues.push_back(decomposedValues.front());
-      } else {
-        // Materialize a single Value to replace the original Value.
-        Value materialized = getTypeConverter()->materializeArgumentConversion(
-            rewriter, op.getLoc(), op.getType(i), decomposedValues);
-        replacedValues.push_back(materialized);
-      }
+      ValueRange repl =
+          newCallOp.getResults().slice(startIdx, expandedResultSizes[i]);
+      replacedValues.push_back(repl);
+      startIdx += expandedResultSizes[i];
     }
-    rewriter.replaceOp(op, replacedValues);
+    rewriter.replaceOpWithMultiple(op, replacedValues);
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 062a0ea6cc47cb..09509278d7749a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -600,8 +600,8 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
     flattenOperands(adaptor.getOperands(), flattened);
     auto newCall = rewriter.create<func::CallOp>(loc, op.getCallee(),
                                                  finalRetTy, flattened);
-    // (2) Create cast operation for sparse tensor returns.
-    SmallVector<Value> castedRet;
+    // (2) Gather sparse tensor returns.
+    SmallVector<SmallVector<Value>> packedResultVals;
     // Tracks the offset of current return value (of the original call)
     // relative to the new call (after sparse tensor flattening);
     unsigned retOffset = 0;
@@ -618,21 +618,27 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
       assert(!sparseFlat.empty());
       if (sparseFlat.size() > 1) {
         auto flatSize = sparseFlat.size();
-        ValueRange fields(iterator_range<ResultRange::iterator>(
-            newCall.result_begin() + retOffset,
-            newCall.result_begin() + retOffset + flatSize));
-        castedRet.push_back(genTuple(rewriter, loc, retType, fields));
+        packedResultVals.push_back(SmallVector<Value>());
+        llvm::append_range(packedResultVals.back(),
+                           iterator_range<ResultRange::iterator>(
+                               newCall.result_begin() + retOffset,
+                               newCall.result_begin() + retOffset + flatSize));
         retOffset += flatSize;
       } else {
         // If this is an 1:1 conversion, no need for casting.
-        castedRet.push_back(newCall.getResult(retOffset));
+        packedResultVals.push_back(SmallVector<Value>());
+        packedResultVals.back().push_back(newCall.getResult(retOffset));
         retOffset++;
       }
       sparseFlat.clear();
     }
 
-    assert(castedRet.size() == op.getNumResults());
-    rewriter.replaceOp(op, castedRet);
+    assert(packedResultVals.size() == op.getNumResults());
+    SmallVector<ValueRange> ranges;
+    ranges.reserve(packedResultVals.size());
+    for (const SmallVector<Value> &vec : packedResultVals)
+      ranges.push_back(ValueRange(vec));
+    rewriter.replaceOpWithMultiple(op, ranges);
     return success();
   }
 };
@@ -776,7 +782,7 @@ class SparseTensorAllocConverter
       // Reuses specifier.
       fields.push_back(desc.getSpecifier());
       assert(fields.size() == desc.getNumFields());
-      rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
+      rewriter.replaceOpWithMultiple(op, {fields});
       return success();
     }
 
@@ -796,7 +802,7 @@ class SparseTensorAllocConverter
                       sizeHint, lvlSizesValues, fields);
 
     // Replace operation with resulting memrefs.
-    rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
+    rewriter.replaceOpWithMultiple(op, {fields});
     return success();
   }
 
@@ -837,7 +843,7 @@ class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
                       sizeHint, lvlSizesValues, fields);
 
     // Replace operation with resulting memrefs.
-    rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
+    rewriter.replaceOpWithMultiple(op, {fields});
     return success();
   }
 
@@ -893,7 +899,7 @@ class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
     if (op.getHasInserts())
       genEndInsert(rewriter, op.getLoc(), desc);
     // Replace operation with resulting memrefs.
-    rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), desc));
+    rewriter.replaceOpWithMultiple(op, {desc.getFields()});
     return success();
   }
 };
@@ -1006,7 +1012,6 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
     rewriter.create<scf::YieldOp>(loc, insertRet);
 
     rewriter.setInsertionPointAfter(loop);
-    Value result = genTuple(rewriter, loc, dstType, loop->getResults());
     // Deallocate the buffers on exit of the full loop nest.
     Operation *parent = getTop(op);
     rewriter.setInsertionPointAfter(parent);
@@ -1014,7 +1019,7 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
     rewriter.create<memref::DeallocOp>(loc, filled);
     rewriter.create<memref::DeallocOp>(loc, added);
     // Replace operation with resulting memrefs.
-    rewriter.replaceOp(op, result);
+    rewriter.replaceOpWithMultiple(op, {loop->getResults()});
     return success();
   }
 };
@@ -1041,8 +1046,7 @@ class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
                                     params, /*genCall=*/true);
     SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
     // Replace operation with resulting memrefs.
-    rewriter.replaceOp(op,
-                       genTuple(rewriter, loc, op.getDest().getType(), ret));
+    rewriter.replaceOpWithMultiple(op, {ret});
     return success();
   }
 };
@@ -1215,8 +1219,7 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
           return true;
         });
 
-    rewriter.replaceOp(
-        op, genTuple(rewriter, loc, op.getResult().getType(), fields));
+    rewriter.replaceOpWithMultiple(op, {fields});
     return success();
   }
 };
@@ -1271,8 +1274,7 @@ class SparseExtractSliceConverter
     // NOTE: we can not generate tuples directly from descriptor here, as the
     // descriptor is holding the original type, yet we want the slice type
     // here (they shared every memref but with an updated specifier).
-    rewriter.replaceOp(op, genTuple(rewriter, loc, op.getResult().getType(),
-                                    desc.getFields()));
+    rewriter.replaceOpWithMultiple(op, {desc.getFields()});
     return success();
   }
 };
@@ -1403,7 +1405,7 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
     }
     desc.setValMemSize(rewriter, loc, memSize);
 
-    rewriter.replaceOp(op, genTuple(rewriter, loc, desc));
+    rewriter.replaceOpWithMultiple(op, {desc.getFields()});
     return success();
   }
 };
@@ -1577,7 +1579,7 @@ struct SparseNewConverter : public OpConversionPattern<NewOp> {
                    EmitCInterface::Off);
 
     // Replace operation with resulting memrefs.
-    rewriter.replaceOp(op, genTuple(rewriter, loc, dstTp, fields));
+    rewriter.replaceOpWithMultiple(op, {fields});
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
index a3db50573c2720..834e3634cc130d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
@@ -54,19 +54,24 @@ convertSparseTensorType(RankedTensorType rtp, SmallVectorImpl<Type> &fields) {
 // The sparse tensor type converter (defined in Passes.h).
 //===----------------------------------------------------------------------===//
 
+static Value materializeTuple(OpBuilder &builder, RankedTensorType tp,
+                              ValueRange inputs, Location loc) {
+  if (!getSparseTensorEncoding(tp))
+    // Not a sparse tensor.
+    return Value();
+  // Sparsifier knows how to cancel out these casts.
+  return genTuple(builder, loc, tp, inputs);
+}
+
 SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
   addConversion([](Type type) { return type; });
   addConversion(convertSparseTensorType);
 
   // Required by scf.for 1:N type conversion.
-  addSourceMaterialization([](OpBuilder &builder, RankedTensorType tp,
-                              ValueRange inputs, Location loc) -> Value {
-    if (!getSparseTensorEncoding(tp))
-      // Not a sparse tensor.
-      return Value();
-    // Sparsifier knows how to cancel out these casts.
-    return genTuple(builder, loc, tp, inputs);
-  });
+  addSourceMaterialization(materializeTuple);
+
+  // Required as a workaround until we have full 1:N support.
+  addArgumentMaterialization(materializeTuple);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 4714c3cace6c78..e85a86e94282ec 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -645,7 +645,7 @@ void OpBuilder::cloneRegionBefore(Region &region, Block *before) {
 
 OpBuilder::InsertPoint
 OpBuilder::InsertPoint::after(ArrayRef<Value> values,
-                              const PostDominanceInfo &domInfo) {
+                              const PostDominanceInfo *domInfo) {
   // Helper function that computes the point after v's definition.
   auto computeAfterIp = [](Value v) -> std::pair<Block *, Block::iterator> {
     if (auto blockArg = dyn_cast<BlockArgument>(v))
@@ -658,12 +658,18 @@ OpBuilder::InsertPoint::after(ArrayRef<Value> values,
   assert(!values.empty() && "expected at least one Value");
   auto [block, blockIt] = computeAfterIp(values.front());
 
+  if (values.size() == 1) {
+    // Fast path: There is only one value.
+    return InsertPoint(block, blockIt);
+  }
+
   // Check the other values one-by-one and update the insertion point if
   // needed.
+  assert(domInfo && "domInfo expected if >1 values");
   for (Value v : values.drop_front()) {
     auto [candidateBlock, candidateBlockIt] = computeAfterIp(v);
-    if (domInfo.postDominantes(candidateBlock, candidateBlockIt, block,
-                               blockIt)) {
+    if (domInfo->postDominantes(candidateBlock, candidateBlockIt, block,
+                                blockIt)) {
       // The point after v's definition post-dominates the current (and all
       // previous) insertion points. Note: Post-dominance is transitive.
       block = candidateBlock;
@@ -671,8 +677,8 @@ OpBuilder::InsertPoint::after(ArrayRef<Value> values,
       continue;
     }
 
-    if (!domInfo.postDominantes(block, blockIt, candidateBlock,
-                                candidateBlockIt)) {
+    if (!domInfo->postDominantes(block, blockIt, candidateBlock,
+                                 candidateBlockIt)) {
       // The point after v's definition and the current insertion point do not
       // post-dominate each other. Therefore, there is no insertion point that
       // post-dominates all values.
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 0a62628b9ad240..2f6c0a1ab0bd3b 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -11,6 +11,7 @@
 #include "mlir/IR/Block.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Dominance.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Iterators.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
@@ -53,20 +54,14 @@ static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
   });
 }
 
-/// Helper function that computes an insertion point where the given value is
-/// defined and can be used without a dominance violation.
-static OpBuilder::InsertPoint computeInsertPoint(Value value) {
-  Block *insertBlock = value.getParentBlock();
-  Block::iterator insertPt = insertBlock->begin();
-  if (OpResult inputRes = dyn_cast<OpResult>(value))
-    insertPt = ++inputRes.getOwner()->getIterator();
-  return OpBuilder::InsertPoint(insertBlock, insertPt);
-}
-
 //===----------------------------------------------------------------------===//
 // ConversionValueMapping
 //===----------------------------------------------------------------------===//
 
+/// A list of replacement SSA values. Optimized for the common case of a single
+/// SSA value.
+using ReplacementValues = SmallVector<Value, 1>;
+
 namespace {
 /// This class wraps a IRMapping to provide recursive lookup
 /// functionality, i.e. we will traverse if the mapped value also has a mapping.
@@ -818,6 +813,22 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
                                        Type originalType,
                                        const TypeConverter *converter);
 
+  /// Build an N:1 materialization for the given original value that was
+  /// replaced with the given replacement values.
+  ///
+  /// This is a workaround around incomplete 1:N support in the dialect
+  /// conversion driver. The conversion mapping can store only 1:1 replacements
+  /// and the conversion patterns only support single Value replacements in the
+  /// adaptor, so N values must be converted back to a single value. This
+  /// function will be deleted when full 1:N support has been added.
+  ///
+  /// This function inserts an argument materialization back to the original
+  /// type, followed by a target materialization to the legalized type (if
+  /// applicable).
+  void insertNTo1Materialization(OpBuilder::InsertPoint ip, Location loc,
+                                 ValueRange replacements, Value originalValue,
+                                 const TypeConverter *converter);
+
   //===--------------------------------------------------------------------===//
   // Rewriter Notification Hooks
   //===--------------------------------------------------------------------===//
@@ -827,7 +838,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
                                OpBuilder::InsertPoint previous) override;
 
   /// Notifies that an op is about to be replaced with the given values.
-  void notifyOpReplaced(Operation *op, ValueRange newValues);
+  void notifyOpReplaced(Operation *op, ArrayRef<ReplacementValues> newValues);
 
   /// Notifies that a block is about to be erased.
   void notifyBlockIsBeingErased(Block *block);
@@ -1147,8 +1158,9 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
       // that the value was replaced with a value of different type and no
       // ...
[truncated]

llvmbot (Member) commented Nov 12, 2024

@llvm/pr-subscribers-mlir

+        packedResultVals.push_back(SmallVector<Value>());
+        packedResultVals.back().push_back(newCall.getResult(retOffset));
         retOffset++;
       }
       sparseFlat.clear();
     }
 
-    assert(castedRet.size() == op.getNumResults());
-    rewriter.replaceOp(op, castedRet);
+    assert(packedResultVals.size() == op.getNumResults());
+    SmallVector<ValueRange> ranges;
+    ranges.reserve(packedResultVals.size());
+    for (const SmallVector<Value> &vec : packedResultVals)
+      ranges.push_back(ValueRange(vec));
+    rewriter.replaceOpWithMultiple(op, ranges);
     return success();
   }
 };
@@ -776,7 +782,7 @@ class SparseTensorAllocConverter
       // Reuses specifier.
       fields.push_back(desc.getSpecifier());
       assert(fields.size() == desc.getNumFields());
-      rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
+      rewriter.replaceOpWithMultiple(op, {fields});
       return success();
     }
 
@@ -796,7 +802,7 @@ class SparseTensorAllocConverter
                       sizeHint, lvlSizesValues, fields);
 
     // Replace operation with resulting memrefs.
-    rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
+    rewriter.replaceOpWithMultiple(op, {fields});
     return success();
   }
 
@@ -837,7 +843,7 @@ class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
                       sizeHint, lvlSizesValues, fields);
 
     // Replace operation with resulting memrefs.
-    rewriter.replaceOp(op, genTuple(rewriter, loc, resType, fields));
+    rewriter.replaceOpWithMultiple(op, {fields});
     return success();
   }
 
@@ -893,7 +899,7 @@ class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
     if (op.getHasInserts())
       genEndInsert(rewriter, op.getLoc(), desc);
     // Replace operation with resulting memrefs.
-    rewriter.replaceOp(op, genTuple(rewriter, op.getLoc(), desc));
+    rewriter.replaceOpWithMultiple(op, {desc.getFields()});
     return success();
   }
 };
@@ -1006,7 +1012,6 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
     rewriter.create<scf::YieldOp>(loc, insertRet);
 
     rewriter.setInsertionPointAfter(loop);
-    Value result = genTuple(rewriter, loc, dstType, loop->getResults());
     // Deallocate the buffers on exit of the full loop nest.
     Operation *parent = getTop(op);
     rewriter.setInsertionPointAfter(parent);
@@ -1014,7 +1019,7 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
     rewriter.create<memref::DeallocOp>(loc, filled);
     rewriter.create<memref::DeallocOp>(loc, added);
     // Replace operation with resulting memrefs.
-    rewriter.replaceOp(op, result);
+    rewriter.replaceOpWithMultiple(op, {loop->getResults()});
     return success();
   }
 };
@@ -1041,8 +1046,7 @@ class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
                                     params, /*genCall=*/true);
     SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
     // Replace operation with resulting memrefs.
-    rewriter.replaceOp(op,
-                       genTuple(rewriter, loc, op.getDest().getType(), ret));
+    rewriter.replaceOpWithMultiple(op, {ret});
     return success();
   }
 };
@@ -1215,8 +1219,7 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
           return true;
         });
 
-    rewriter.replaceOp(
-        op, genTuple(rewriter, loc, op.getResult().getType(), fields));
+    rewriter.replaceOpWithMultiple(op, {fields});
     return success();
   }
 };
@@ -1271,8 +1274,7 @@ class SparseExtractSliceConverter
     // NOTE: we can not generate tuples directly from descriptor here, as the
     // descriptor is holding the original type, yet we want the slice type
     // here (they shared every memref but with an updated specifier).
-    rewriter.replaceOp(op, genTuple(rewriter, loc, op.getResult().getType(),
-                                    desc.getFields()));
+    rewriter.replaceOpWithMultiple(op, {desc.getFields()});
     return success();
   }
 };
@@ -1403,7 +1405,7 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
     }
     desc.setValMemSize(rewriter, loc, memSize);
 
-    rewriter.replaceOp(op, genTuple(rewriter, loc, desc));
+    rewriter.replaceOpWithMultiple(op, {desc.getFields()});
     return success();
   }
 };
@@ -1577,7 +1579,7 @@ struct SparseNewConverter : public OpConversionPattern<NewOp> {
                    EmitCInterface::Off);
 
     // Replace operation with resulting memrefs.
-    rewriter.replaceOp(op, genTuple(rewriter, loc, dstTp, fields));
+    rewriter.replaceOpWithMultiple(op, {fields});
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
index a3db50573c2720..834e3634cc130d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp
@@ -54,19 +54,24 @@ convertSparseTensorType(RankedTensorType rtp, SmallVectorImpl<Type> &fields) {
 // The sparse tensor type converter (defined in Passes.h).
 //===----------------------------------------------------------------------===//
 
+static Value materializeTuple(OpBuilder &builder, RankedTensorType tp,
+                              ValueRange inputs, Location loc) {
+  if (!getSparseTensorEncoding(tp))
+    // Not a sparse tensor.
+    return Value();
+  // Sparsifier knows how to cancel out these casts.
+  return genTuple(builder, loc, tp, inputs);
+}
+
 SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
   addConversion([](Type type) { return type; });
   addConversion(convertSparseTensorType);
 
   // Required by scf.for 1:N type conversion.
-  addSourceMaterialization([](OpBuilder &builder, RankedTensorType tp,
-                              ValueRange inputs, Location loc) -> Value {
-    if (!getSparseTensorEncoding(tp))
-      // Not a sparse tensor.
-      return Value();
-    // Sparsifier knows how to cancel out these casts.
-    return genTuple(builder, loc, tp, inputs);
-  });
+  addSourceMaterialization(materializeTuple);
+
+  // Required as a workaround until we have full 1:N support.
+  addArgumentMaterialization(materializeTuple);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 4714c3cace6c78..e85a86e94282ec 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -645,7 +645,7 @@ void OpBuilder::cloneRegionBefore(Region &region, Block *before) {
 
 OpBuilder::InsertPoint
 OpBuilder::InsertPoint::after(ArrayRef<Value> values,
-                              const PostDominanceInfo &domInfo) {
+                              const PostDominanceInfo *domInfo) {
   // Helper function that computes the point after v's definition.
   auto computeAfterIp = [](Value v) -> std::pair<Block *, Block::iterator> {
     if (auto blockArg = dyn_cast<BlockArgument>(v))
@@ -658,12 +658,18 @@ OpBuilder::InsertPoint::after(ArrayRef<Value> values,
   assert(!values.empty() && "expected at least one Value");
   auto [block, blockIt] = computeAfterIp(values.front());
 
+  if (values.size() == 1) {
+    // Fast path: There is only one value.
+    return InsertPoint(block, blockIt);
+  }
+
   // Check the other values one-by-one and update the insertion point if
   // needed.
+  assert(domInfo && "domInfo expected if >1 values");
   for (Value v : values.drop_front()) {
     auto [candidateBlock, candidateBlockIt] = computeAfterIp(v);
-    if (domInfo.postDominantes(candidateBlock, candidateBlockIt, block,
-                               blockIt)) {
+    if (domInfo->postDominantes(candidateBlock, candidateBlockIt, block,
+                                blockIt)) {
       // The point after v's definition post-dominates the current (and all
       // previous) insertion points. Note: Post-dominance is transitive.
       block = candidateBlock;
@@ -671,8 +677,8 @@ OpBuilder::InsertPoint::after(ArrayRef<Value> values,
       continue;
     }
 
-    if (!domInfo.postDominantes(block, blockIt, candidateBlock,
-                                candidateBlockIt)) {
+    if (!domInfo->postDominantes(block, blockIt, candidateBlock,
+                                 candidateBlockIt)) {
       // The point after v's definition and the current insertion point do not
       // post-dominate each other. Therefore, there is no insertion point that
       // post-dominates all values.
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 0a62628b9ad240..2f6c0a1ab0bd3b 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -11,6 +11,7 @@
 #include "mlir/IR/Block.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Dominance.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Iterators.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
@@ -53,20 +54,14 @@ static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
   });
 }
 
-/// Helper function that computes an insertion point where the given value is
-/// defined and can be used without a dominance violation.
-static OpBuilder::InsertPoint computeInsertPoint(Value value) {
-  Block *insertBlock = value.getParentBlock();
-  Block::iterator insertPt = insertBlock->begin();
-  if (OpResult inputRes = dyn_cast<OpResult>(value))
-    insertPt = ++inputRes.getOwner()->getIterator();
-  return OpBuilder::InsertPoint(insertBlock, insertPt);
-}
-
 //===----------------------------------------------------------------------===//
 // ConversionValueMapping
 //===----------------------------------------------------------------------===//
 
+/// A list of replacement SSA values. Optimized for the common case of a single
+/// SSA value.
+using ReplacementValues = SmallVector<Value, 1>;
+
 namespace {
 /// This class wraps a IRMapping to provide recursive lookup
 /// functionality, i.e. we will traverse if the mapped value also has a mapping.
@@ -818,6 +813,22 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
                                        Type originalType,
                                        const TypeConverter *converter);
 
+  /// Build an N:1 materialization for the given original value that was
+  /// replaced with the given replacement values.
+  ///
+  /// This is a workaround around incomplete 1:N support in the dialect
+  /// conversion driver. The conversion mapping can store only 1:1 replacements
+  /// and the conversion patterns only support single Value replacements in the
+  /// adaptor, so N values must be converted back to a single value. This
+  /// function will be deleted when full 1:N support has been added.
+  ///
+  /// This function inserts an argument materialization back to the original
+  /// type, followed by a target materialization to the legalized type (if
+  /// applicable).
+  void insertNTo1Materialization(OpBuilder::InsertPoint ip, Location loc,
+                                 ValueRange replacements, Value originalValue,
+                                 const TypeConverter *converter);
+
   //===--------------------------------------------------------------------===//
   // Rewriter Notification Hooks
   //===--------------------------------------------------------------------===//
@@ -827,7 +838,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
                                OpBuilder::InsertPoint previous) override;
 
   /// Notifies that an op is about to be replaced with the given values.
-  void notifyOpReplaced(Operation *op, ValueRange newValues);
+  void notifyOpReplaced(Operation *op, ArrayRef<ReplacementValues> newValues);
 
   /// Notifies that a block is about to be erased.
   void notifyBlockIsBeingErased(Block *block);
@@ -1147,8 +1158,9 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
       // that the value was replaced with a value of different type and no
       // ...
[truncated]
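
The result-grouping loop in the `DecomposeCallGraphTypes.cpp` hunk above (one `expandedResultSizes` entry per original result, then consecutive slices of the new call's results) can be sketched without any MLIR types. This is only an illustration: `groupBySizes` is a hypothetical helper, not part of the PR, and it copies elements into owned vectors, whereas the real code builds non-owning `ValueRange` views.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not from the PR): split a flat list into consecutive
// groups, where sizes[i] is the number of elements that belong to group i.
// A size of 0 yields an empty group (an original result with no replacement).
template <typename T>
std::vector<std::vector<T>> groupBySizes(const std::vector<T> &flat,
                                         const std::vector<std::size_t> &sizes) {
  std::vector<std::vector<T>> groups;
  groups.reserve(sizes.size());
  std::size_t start = 0;
  for (std::size_t size : sizes) {
    assert(start + size <= flat.size() && "sizes exceed the flat list");
    auto first = flat.begin() + static_cast<std::ptrdiff_t>(start);
    // Copy the slice [start, start + size) into its own group.
    groups.emplace_back(first, first + static_cast<std::ptrdiff_t>(size));
    start += size;
  }
  return groups;
}
```

Each inner group corresponds to one `ValueRange` passed to `replaceOpWithMultiple`; the running `start` offset plays the role of `startIdx` in the hunk.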

Member

@zero9178 zero9178 left a comment

Looks mostly good to me besides some code nits 🙂

The only very high-level concern I have is the use and need of InsertPoint::after. As I understand it, this is a not-yet-implemented optimization for the future, meant to allow reusing materializations by calculating a better insertion point that may dominate more operations, which might need exactly that materialization after pattern application as well.

Is the current complexity of calculating the insertion point (and potentially calculating the post-dom tree, an O(n log n) operation) worth the gain of the optimization? I am thinking that the better insertion-point logic should probably be postponed until we can measure its effectiveness (and avoid the risk of a premature optimization), and that for now we should instead have something simpler that works and does not worsen the status quo.

@matthias-springer
Member Author

matthias-springer commented Nov 13, 2024

The only very high-level concern I have is the use and need of InsertPoint::after. As I understand it, this is a not-yet-implemented optimization for the future, meant to allow reusing materializations by calculating a better insertion point that may dominate more operations, which might need exactly that materialization after pattern application as well.

Yes, that's correct. But this optimization is already implemented. And it is actually quite deeply ingrained into the dialect conversion driver.

SSA values are not replaced directly, but in a delayed fashion. The replacement value is stored in the ConversionValueMapping (a kind of IRMapping). We store not only user-specified replacement values but also materializations in there. The data structure is kind of a linked list: for an original value, we store a list of replacement values: the specified replacement value and materializations to a different type. In the "finalize" phase, the driver goes through that list to find a replacement/materialization value with the correct type. We cannot store more than one replacement/materialization value with the same type; the driver would just pick the first one that matches. So we always reuse.
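
To make the lookup described above concrete, here is a minimal sketch of a chained mapping that returns the first value of the requested type. It is a toy model, not the driver's code: `ToyValue` and `ToyValueMapping` are hypothetical names, types are plain strings, and the chain is assumed to be acyclic.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for an SSA value: a unique id plus a type name.
struct ToyValue {
  int id;
  std::string type;
};

// Hypothetical stand-in for the mapping: each value maps to the next value in
// its chain (a user-specified replacement or a materialization).
class ToyValueMapping {
public:
  void map(const ToyValue &from, const ToyValue &to) {
    next.insert_or_assign(from.id, to);
  }

  // Walk the chain starting at `from` and return the first value whose type
  // matches `desiredType`, or nullopt if the chain ends without a match.
  // Assumes the chain contains no cycles.
  std::optional<ToyValue> lookup(const ToyValue &from,
                                 const std::string &desiredType) const {
    ToyValue current = from;
    while (true) {
      if (current.type == desiredType)
        return current;
      auto it = next.find(current.id);
      if (it == next.end())
        return std::nullopt;
      current = it->second;
    }
  }

private:
  std::unordered_map<int, ToyValue> next;
};
```

Because the walk stops at the first match, a second entry of the same type further down the chain would never be returned, which mirrors the "we always reuse" behavior.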

Note: Computing an insertion point for a 1:1 materialization (or a 1:N block signature conversion) is much easier because there is only one SSA value (or one block). That's why computing the insertion point was trivial until now.

Is the current complexity of calculating the insertion point (and potentially calculating the post-dom tree, an O(n log n) operation) worth the gain of the optimization? I am thinking that the better insertion-point logic should probably be postponed until we can measure its effectiveness (and avoid the risk of a premature optimization), and that for now we should instead have something simpler that works and does not worsen the status quo.

This PR also makes the DominanceInfo argument to InsertPoint::after optional in case of a single SSA value. (And no DominanceInfo is passed in that case.) For the most frequent case of 1:1 replacements, we do not compute a dominator tree at all. (And we are not doing any extra work.)

In case of N:1 materializations, the implementation uses DominanceInfo and we create it during ConversionPatternRewriter::replaceOpWithMultiple. Unfortunately, it is not safe to reuse the same DominanceInfo object because a pattern could have made IR changes that invalidate the dominator tree.

However, I believe DominanceInfo is quite "cheap" to use. The dominator tree is built lazily and it is built on a per-region basis. E.g., creating a new DominanceInfo and querying dominance for two ops in the same region will just build the dominator tree for that region (and only if the ops are in different blocks). In case of two ops from different regions (or different blocks in the same region), the implementation will find the common ancestors (in the same region) and then compute the dominator tree only for that region.
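
The laziness argument can be illustrated with a toy model that counts how many per-region trees get built. Everything here is an assumption for illustration: `ToyOp` and `LazyDominance` are hypothetical, both ops must already be in the same region, and blocks are treated as a straight-line chain so that a lower block index dominates a higher one.

```cpp
#include <cassert>
#include <cstddef>
#include <set>

// Hypothetical stand-in for an operation: its region, block, and position
// within the block.
struct ToyOp {
  int region;
  int block;
  int position;
};

// Toy model of a lazily built, per-region dominance analysis.
class LazyDominance {
public:
  bool properlyDominates(const ToyOp &a, const ToyOp &b) {
    assert(a.region == b.region && "toy model: same region only");
    // Same block: ordering within the block decides, no tree is needed.
    if (a.block == b.block)
      return a.position < b.position;
    // Different blocks: build this region's tree on first use only.
    buildTreeIfNeeded(a.region);
    // Toy CFG assumption: blocks form a straight-line chain.
    return a.block < b.block;
  }

  std::size_t numTreesBuilt() const { return builtRegions.size(); }

private:
  void buildTreeIfNeeded(int region) { builtRegions.insert(region); }

  std::set<int> builtRegions;
};
```

Querying two ops in the same block builds nothing; the first cross-block query in a region builds that region's tree once, and later queries reuse it.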

Long term, the ConversionValueMapping is going to disappear with the One-Shot Dialect Conversion. As part of that, I'm also thinking of removing the "materialization caching" mechanism and just building duplicate materializations (in the form of unrealized_conversion_cast, i.e., not calling the callback). These can then be CSE'd before calling the materialization callback. The CSE-ing will require DominanceInfo, but the same DominanceInfo can be reused for the entire dialect conversion because at this point of time we are done with pattern applications.

github-actions bot commented Nov 13, 2024

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff 6c9256dc5cda9184e295bc8d00be35e61b3be892 001453c12e18c8e5f8d1b762cf5f67d56ec542e9 --extensions cpp,h -- mlir/include/mlir/Transforms/DialectConversion.h mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp mlir/lib/Transforms/Utils/DialectConversion.cpp
View the diff from clang-format here.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 6b154dd161..bf7b3f9bec 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -618,7 +618,7 @@ public:
       assert(!sparseFlat.empty());
       if (sparseFlat.size() > 1) {
         auto flatSize = sparseFlat.size();
-	packedResultVals.emplace_back();
+        packedResultVals.emplace_back();
         llvm::append_range(packedResultVals.back(),
                            newCall.getResults().slice(retOffset, flatSize));
         retOffset += flatSize;

@River707
Contributor

River707 commented Nov 13, 2024

One question I have here is: why not just insert right before the operation that is being replaced? What is the need for trying to insert after one of the values? Wouldn't that remove the need for all of this complexity? It gives you a single place to insert.

@matthias-springer
Member Author

matthias-springer commented Nov 13, 2024

One question I have here is: why not just insert right before the operation that is being replaced? What is the need for trying to insert after one of the values? Wouldn't that remove the need for all of this complexity? It gives you a single place to insert.

Actually, I think you are right. I was thinking of a case where two ops (in different locations) were replaced with the same 1:N replacement values. And I thought that we store one mapping pair in that case. But we actually store one per original value (and not per replacement value). I just noticed this as I was trying to write down a counter-example...

Let me update the PR, I think this will simplify things quite a bit.

Edit: Dropped the insertion point computation. It is actually not needed, as @zero9178 and @River707 commented.

@matthias-springer matthias-springer changed the base branch from users/matthias-springer/insert_pt to main November 13, 2024 04:56
@matthias-springer matthias-springer force-pushed the users/matthias-springer/replace_with_multiple branch from b59db46 to d0d0632 Compare November 13, 2024 04:59
Member

@zero9178 zero9178 left a comment

LGTM! Thank you so much for the discussion. Makes perfect sense to me now.

@matthias-springer matthias-springer force-pushed the users/matthias-springer/replace_with_multiple branch from d0d0632 to 001453c Compare November 14, 2024 00:34
Apply suggestions from code review

Co-authored-by: Markus Böck <[email protected]>

address comments
@matthias-springer matthias-springer force-pushed the users/matthias-springer/replace_with_multiple branch from 001453c to f5ed959 Compare November 14, 2024 01:27
@matthias-springer matthias-springer merged commit aed4356 into main Nov 14, 2024
4 of 6 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/replace_with_multiple branch November 14, 2024 01:28
Labels
mlir:core MLIR Core Infrastructure mlir:func mlir:sparse Sparse compiler in MLIR mlir