Skip to content

[mlir][Transforms] Improve replaceOpWithMultiple API #132608

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Mar 23, 2025

This commit adds an additional overload to replaceOpWithMultiple that accepts additional container types. This has been brought up by users of the new replaceOpWithMultiple API.

In particular, one missing container type was SmallVector<SmallVector<Value>>. The "default" ArrayRef<ValueRange> container type can lead to use-after-scope errors in cases such as:

// Compute the replacement value ranges. Some replacements are single
// values, some are value ranges.
SmallVector<ValueRange> repl;
repl.push_back(someValueRange);  // OK
for (...) {
  // push_back(Value) triggers an implicit conversion to ValueRange,
  // which does not own the range.
  repl.push_back(someValue);  // triggers use-after-scope later
}
rewriter.replaceOpWithMultiple(op, repl);

In this example, users should use SmallVector<SmallVector<Value>> repl;.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir:sparse Sparse compiler in MLIR mlir labels Mar 23, 2025
@llvmbot
Copy link
Member

llvmbot commented Mar 23, 2025

@llvm/pr-subscribers-mlir-sme
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir-sparse

Author: Matthias Springer (matthias-springer)

Changes

This commit adds an additional overload to replaceOpWithMultiple that accepts additional container types. This has been brought up by users of the new replaceOpWithMultiple API.

In particular, one missing container type was SmallVector&lt;SmallVector&lt;Value&gt;&gt;. The "default" ArrayRef&lt;ValueRange&gt; container type can lead to use-after-scope errors in cases such as:

// Compute the replacement value ranges. Some replacements are single
// values, some are value ranges.
SmallVector&lt;ValueRange&gt; repl;
repl.push_back(someValueRange);  // OK
for (...) {
  // push_back(Value) triggers an implicit conversion to ValueRange,
  // which does not own the Value.
  repl.push_back(someValue);  // triggers use-after-scope later
}

In this example, users should use SmallVector&lt;SmallVector&lt;Value&gt;&gt; repl;.


Full diff: https://github.com/llvm/llvm-project/pull/132608.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+4)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp (+1-2)
  • (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+22)
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 8a70883293d91..cbf60b784af94 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -898,6 +898,10 @@ class ConversionPatternRewriter final : public PatternRewriter {
   /// Replace the given operation with the new value ranges. The number of op
   /// results and value ranges must match. The given  operation is erased.
   void replaceOpWithMultiple(Operation *op, ArrayRef<ValueRange> newValues);
+  template <typename RangeT>
+  void replaceOpWithMultiple(Operation *op, RangeT newValues) {
+    replaceOpWithMultiple(op, llvm::to_vector_of<ValueRange>(newValues));
+  }
 
   /// PatternRewriter hook for erasing a dead operation. The uses of this
   /// operation *must* be made dead by the end of the conversion process,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 6a66ad24a87b4..6291f3ea37230 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -616,8 +616,7 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
     }
 
     assert(packedResultVals.size() == op.getNumResults());
-    rewriter.replaceOpWithMultiple(
-        op, llvm::to_vector_of<ValueRange>(packedResultVals));
+    rewriter.replaceOpWithMultiple(op, packedResultVals);
     return success();
   }
 };
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index b868f1a3a08da..e325003f5384c 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1278,6 +1278,28 @@ class TestMultiple1ToNReplacement : public ConversionPattern {
   }
 };
 
+/// Test unambiguous overload resolution of replaceOpWithMultiple. This
+/// function is just to trigger compiler errors. It is never executed.
+[[maybe_unused]] void testReplaceOpWithMultipleOverloads(
+    ConversionPatternRewriter &rewriter, Operation *op, ArrayRef<ValueRange> r1,
+    SmallVector<ValueRange> r2, ArrayRef<SmallVector<Value>> r3,
+    SmallVector<SmallVector<Value>> r4, ArrayRef<ArrayRef<Value>> r5,
+    SmallVector<ArrayRef<Value>> r6, Value v, ValueRange vr,
+    ArrayRef<Value> ar) {
+  rewriter.replaceOpWithMultiple(op, r1);
+  rewriter.replaceOpWithMultiple(op, r2);
+  rewriter.replaceOpWithMultiple(op, r3);
+  rewriter.replaceOpWithMultiple(op, r4);
+  rewriter.replaceOpWithMultiple(op, r5);
+  rewriter.replaceOpWithMultiple(op, r6);
+  rewriter.replaceOpWithMultiple(op, {vr});
+  rewriter.replaceOpWithMultiple(op, {ar});
+  rewriter.replaceOpWithMultiple(op, {{v}});
+  rewriter.replaceOpWithMultiple(op, {{v, v}});
+  rewriter.replaceOpWithMultiple(op, {{v, v}, vr});
+  rewriter.replaceOpWithMultiple(op, {{v, v}, ar});
+  rewriter.replaceOpWithMultiple(op, {ar, {v, v}, vr});
+}
 } // namespace
 
 namespace {

@matthias-springer
Copy link
Member Author

@MaheshRavishankar I couldn't find the PR anymore where we were discussing this. I ran into similar problems with the API recently. Does this PR help?

@matthias-springer matthias-springer force-pushed the users/matthias-springer/replace_op_with_multiple_overloads branch from c4e8397 to 8d9755a Compare March 23, 2025 12:23
Copy link
Member

@zero9178 zero9178 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

FWIW I tried to fix the root cause of the example you gave previously here #121996 but ran out of bits in Type on 32-bit platforms :( Hoping I can come up with some way in the future to fix use-after-free

Copy link

github-actions bot commented Mar 24, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@matthias-springer matthias-springer force-pushed the users/matthias-springer/replace_op_with_multiple_overloads branch from 287bcbd to b6a28eb Compare March 25, 2025 13:33
Copy link
Member

@zero9178 zero9178 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Making things not copy in the ideal case is slightly more involved as it affects all functions that are sinks: functions that take data from parameters and store them directly into internal datastructures, including transitively and must be able to move them.

This sadly disqualifies using ArrayRef as parameter as it always returns a const reference, which will never call the move constructor. The good news is that SmallVector has a move constructor from SmallVectorImpl, i.e. we do not need SmallVector<Value, 1>.

I prototyped a solution in zero9178@b30cdb1 but it requires a few more changes including in ADT.
It makes all the sinks either have two versions: SmallVector<SmallVector<Value>>&& for when the user does a std::move or construct a copy to be moved. This is not ideal compared to e.g. having two overloads of every funciton (one for moving, one without), but an improvement nevertheless.

@@ -1520,7 +1520,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
}

void ConversionPatternRewriterImpl::notifyOpReplaced(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function is also a sink function i.e. needs to at the very least take SmallVector<SmallVector<Value>>&& such that we can move repl below into map

@joker-eph
Copy link
Collaborator

This sadly disqualifies using ArrayRef as parameter as it always returns a const reference, which will never call the move constructor.

What about MutableArrayRef ?

@zero9178
Copy link
Member

This sadly disqualifies using ArrayRef as parameter as it always returns a const reference, which will never call the move constructor.

What about MutableArrayRef ?

That would work. I just personally think behaviour is a bit unituative as the mutation is rather unexpected. i.e:

SmallVector replacement = {...};
replaceOpWithMutable(op, replacements);
// contents of replacement are suddenly all empty vectors

compare to having std::move show that this is a destructive operation.

Update mlir/include/mlir/Transforms/DialectConversion.h

Co-authored-by: Markus Böck <[email protected]>

improve api
@matthias-springer matthias-springer force-pushed the users/matthias-springer/replace_op_with_multiple_overloads branch from 1b6eddf to f3b2445 Compare March 28, 2025 10:41
@matthias-springer
Copy link
Member Author

matthias-springer commented Mar 28, 2025

I took a look at all the places where we currently call replaceOpWithMultiple in MLIR. (The API is not that widely used yet because it is new.)

  1. replaceOpWithMultiple(Operation *, { SmallVector<Value> }): 10x
  2. replaceOpWithMultiple(Operation *, SmallVector<ValueRange>): 2x
  3. replaceOpWithMultiple(Operation *, SmallVector<SmallVector<Value>>): 1x
  4. replaceOpWithMultiple(Operation *, SmallVector<ResultRange>): 1x
  5. replaceOpWithMultiple(Operation *, { ArrayRef<Value> }): 1x
  6. replaceOpWithMultiple(Operation *, { ValueRange }): 7x

Internally, a replacement for a single value is always stored as a SmallVector<Value, 1>.

(2), (4), (5), (6) always require a copy because the container does not own the range of values. It is not possible to move into the SmallVector<Value, 1>. I expect (6) to become even more common because the OneToNOpAdaptor returns ValueRange.

A copy could be avoided for (3) if the user passes the replacements with std::move.

I am not entirely sure about about (1). I think it is not possible to use move semantics here. Unless maybe the user writes replaceOpWithMultiple(op, { std::move(vec) }). Alternatively, moving would be possible when rewriting the code as (3). (See previous paragraph.)

SmallVector<SmallVector<Value>>&& for when the user does a std::move or construct a copy to be moved.

I implemented this to make it possible to use move semantics with (3). (The other overloads copy.)

Copy link
Member

@zero9178 zero9178 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still LGTM :)))

@matthias-springer matthias-springer merged commit 4abff4d into main Mar 28, 2025
11 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/replace_op_with_multiple_overloads branch March 28, 2025 13:18
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:sme mlir:sparse Sparse compiler in MLIR mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants