Skip to content

[mlir][Parser] Fix use-after-free when parsing invalid reference to nested definition #127778

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 20, 2025

Conversation

matthias-springer
Copy link
Member

This commit fixes a use-after-free crash when parsing the following invalid IR:

scf.for ... iter_args(%var = %foo) -> tensor<?xf32> {
  %foo = "test.inner"() : () -> (tensor<?xf32>)
  scf.yield %arg0 : tensor<?xf32>
}

The scf.for parser was implemented as follows:

  1. Resolve operands (including %foo).
  2. Parse the region.

During operand resolution, a forward reference (unrealized_conversion_cast) is added by the parser because %foo has not been defined yet. During region parsing, the definition of %foo is found and the forward reference is replaced with the actual definition. (And the forward reference is deleted.) However, the operand of the scf.for op is not updated because the scf.for op has not been created yet; all we have is an OperationState object.

All parsers should be written in such a way that they first parse the region and then resolve the operands. That way, no forward reference is inserted in the first place. Before parsing the region, it may be necessary to set the argument types if they are defined as part of the assembly format of the op (as is the case with scf.for). Note: Ops in generic format are parsed in the same way.

To make the parsing infrastructure more robust, this commit also delays the erase of forward references until the end of the lifetime of the parser. Instead of a use-after-free crash, users will then see more descriptive error messages such as:

error: operation's operand is unlinked

Note: The proper way to fix the parser is to first parse the region, then resolve the operands. The change to Parser.cpp is merely to help users finding the root cause of the problem.

@llvmbot
Copy link
Member

llvmbot commented Feb 19, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

This commit fixes a use-after-free crash when parsing the following invalid IR:

scf.for ... iter_args(%var = %foo) -&gt; tensor&lt;?xf32&gt; {
  %foo = "test.inner"() : () -&gt; (tensor&lt;?xf32&gt;)
  scf.yield %arg0 : tensor&lt;?xf32&gt;
}

The scf.for parser was implemented as follows:

  1. Resolve operands (including %foo).
  2. Parse the region.

During operand resolution, a forward reference (unrealized_conversion_cast) is added by the parser because %foo has not been defined yet. During region parsing, the definition of %foo is found and the forward reference is replaced with the actual definition. (And the forward reference is deleted.) However, the operand of the scf.for op is not updated because the scf.for op has not been created yet; all we have is an OperationState object.

All parsers should be written in such a way that they first parse the region and then resolve the operands. That way, no forward reference is inserted in the first place. Before parsing the region, it may be necessary to set the argument types if they are defined as part of the assembly format of the op (as is the case with scf.for). Note: Ops in generic format are parsed in the same way.

To make the parsing infrastructure more robust, this commit also delays the erase of forward references until the end of the lifetime of the parser. Instead of a use-after-free crash, users will then see more descriptive error messages such as:

error: operation's operand is unlinked

Note: The proper way to fix the parser is to first parse the region, then resolve the operands. The change to Parser.cpp is merely to help users finding the root cause of the problem.


Full diff: https://github.com/llvm/llvm-project/pull/127778.diff

3 Files Affected:

  • (modified) mlir/lib/AsmParser/Parser.cpp (+10-4)
  • (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+13-8)
  • (modified) mlir/test/Dialect/SCF/invalid.mlir (+10)
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index b5f1d2e27c9ba..2982757a6c5ce 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -820,6 +820,12 @@ class OperationParser : public Parser {
   /// their first reference, to allow checking for use of undefined values.
   DenseMap<Value, SMLoc> forwardRefPlaceholders;
 
+  /// Operations that define the placeholders. These are kept until the end of
+  /// of the lifetime of the parser because some custom parsers may store
+  /// references to them in local state and use them after forward references
+  /// have been resolved.
+  DenseSet<Operation *> forwardRefOps;
+
   /// Deffered locations: when parsing `loc(#loc42)` we add an entry to this
   /// map. After parsing the definition `#loc42 = ...` we'll patch back users
   /// of this location.
@@ -847,11 +853,11 @@ OperationParser::OperationParser(ParserState &state, ModuleOp topLevelOp)
 }
 
 OperationParser::~OperationParser() {
-  for (auto &fwd : forwardRefPlaceholders) {
+  for (Operation *op : forwardRefOps) {
     // Drop all uses of undefined forward declared reference and destroy
     // defining operation.
-    fwd.first.dropAllUses();
-    fwd.first.getDefiningOp()->destroy();
+    op->dropAllUses();
+    op->destroy();
   }
   for (const auto &scope : forwardRef) {
     for (const auto &fwd : scope) {
@@ -1007,7 +1013,6 @@ ParseResult OperationParser::addDefinition(UnresolvedOperand useInfo,
     // the actual definition instead, delete the forward ref, and remove it
     // from our set of forward references we track.
     existing.replaceAllUsesWith(value);
-    existing.getDefiningOp()->destroy();
     forwardRefPlaceholders.erase(existing);
 
     // If a definition of the value already exists, replace it in the assembly
@@ -1194,6 +1199,7 @@ Value OperationParser::createForwardRefPlaceholder(SMLoc loc, Type type) {
       /*attributes=*/std::nullopt, /*properties=*/nullptr, /*successors=*/{},
       /*numRegions=*/0);
   forwardRefPlaceholders[op->getResult(0)] = loc;
+  forwardRefOps.insert(op);
   return op->getResult(0);
 }
 
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 448141735ba7f..1f70ad57d986b 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -499,8 +499,20 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
   else if (parser.parseType(type))
     return failure();
 
-  // Resolve input operands.
+  // Set block argument types, so that they are known when parsing the region.
   regionArgs.front().type = type;
+  for (auto [iterArg, type] :
+       llvm::zip(llvm::drop_begin(regionArgs), result.types))
+    iterArg.type = type;
+
+  // Parse the body region.
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, regionArgs))
+    return failure();
+  ForOp::ensureTerminator(*body, builder, result.location);
+
+  // Resolve input operands. This should be done after parsing the region to
+  // catch invalid IR where operands were defined inside of the region.
   if (parser.resolveOperand(lb, type, result.operands) ||
       parser.resolveOperand(ub, type, result.operands) ||
       parser.resolveOperand(step, type, result.operands))
@@ -516,13 +528,6 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
     }
   }
 
-  // Parse the body region.
-  Region *body = result.addRegion();
-  if (parser.parseRegion(*body, regionArgs))
-    return failure();
-
-  ForOp::ensureTerminator(*body, builder, result.location);
-
   // Parse the optional attribute list.
   if (parser.parseOptionalAttrDict(result.attributes))
     return failure();
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index 80576be880127..76c785f3e6166 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -747,3 +747,13 @@ func.func @parallel_missing_terminator(%0 : index) {
   return
 }
 
+// -----
+
+func.func @invalid_reference(%a: index) {
+  // expected-error @below{{use of undeclared SSA value name}}
+  scf.for %x = %a to %a step %a iter_args(%var = %foo) -> tensor<?xf32> {
+    %foo = "test.inner"() : () -> (tensor<?xf32>)
+    scf.yield %foo : tensor<?xf32>
+  }
+  return
+}

Copy link
Collaborator

@joker-eph joker-eph left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Uh, I'm confused: how could we not hit this before?

Edit: oh, it would only fail on invalid IR right?

@matthias-springer
Copy link
Member Author

Yes, this just crashed on invalid IR.

Copy link
Contributor

@dtzSiFive dtzSiFive left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Small suggestion while reviewing.

This looks to fix crash with input IR from #107121, which now produces:

issue-107121.mlir:1:1: error: operation's operand is unlinked
scf.if %true {
^
issue-107121.mlir:1:1: note: see current operation:
"scf.if"(<<UNKNOWN SSA VALUE>>) ({
  %0 = "arith.constant"() <{value = true}> : () -> i1
  "scf.yield"() : () -> ()
}, {
}) : (i1) -> ()

Although maybe this still not ideal, it at least no longer crashes.

(probably scf.if's parser should be altered similarly? I can take a look 👍)

@matthias-springer
Copy link
Member Author

Yes, I think swapping resolveOperand and parseRegion should do the trick. Feel free to send a PR!

@matthias-springer matthias-springer merged commit 7d03c8e into main Feb 20, 2025
6 of 7 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/fix_parser branch February 20, 2025 07:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:scf mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[mlir] heap-use-after-free parsing example with forward-ref to result defined within
4 participants