-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][Parser] Fix use-after-free when parsing invalid reference to nested definition #127778
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-core Author: Matthias Springer (matthias-springer) ChangesThis commit fixes a use-after-free crash when parsing the following invalid IR: scf.for ... iter_args(%var = %foo) -> tensor<?xf32> {
%foo = "test.inner"() : () -> (tensor<?xf32>)
scf.yield %arg0 : tensor<?xf32>
} The
During operand resolution, a forward reference ( All parsers should be written in such a way that they first parse the region and then resolve the operands. That way, no forward reference is inserted in the first place. Before parsing the region, it may be necessary to set the argument types if they are defined as part of the assembly format of the op (as is the case with To make the parsing infrastructure more robust, this commit also delays the erase of forward references until the end of the lifetime of the parser. Instead of a use-after-free crash, users will then see more descriptive error messages such as:
Note: The proper way to fix the parser is to first parse the region, then resolve the operands. The change to Full diff: https://github.com/llvm/llvm-project/pull/127778.diff 3 Files Affected:
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index b5f1d2e27c9ba..2982757a6c5ce 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -820,6 +820,12 @@ class OperationParser : public Parser {
/// their first reference, to allow checking for use of undefined values.
DenseMap<Value, SMLoc> forwardRefPlaceholders;
+ /// Operations that define the placeholders. These are kept until the end of
+ /// of the lifetime of the parser because some custom parsers may store
+ /// references to them in local state and use them after forward references
+ /// have been resolved.
+ DenseSet<Operation *> forwardRefOps;
+
/// Deffered locations: when parsing `loc(#loc42)` we add an entry to this
/// map. After parsing the definition `#loc42 = ...` we'll patch back users
/// of this location.
@@ -847,11 +853,11 @@ OperationParser::OperationParser(ParserState &state, ModuleOp topLevelOp)
}
OperationParser::~OperationParser() {
- for (auto &fwd : forwardRefPlaceholders) {
+ for (Operation *op : forwardRefOps) {
// Drop all uses of undefined forward declared reference and destroy
// defining operation.
- fwd.first.dropAllUses();
- fwd.first.getDefiningOp()->destroy();
+ op->dropAllUses();
+ op->destroy();
}
for (const auto &scope : forwardRef) {
for (const auto &fwd : scope) {
@@ -1007,7 +1013,6 @@ ParseResult OperationParser::addDefinition(UnresolvedOperand useInfo,
// the actual definition instead, delete the forward ref, and remove it
// from our set of forward references we track.
existing.replaceAllUsesWith(value);
- existing.getDefiningOp()->destroy();
forwardRefPlaceholders.erase(existing);
// If a definition of the value already exists, replace it in the assembly
@@ -1194,6 +1199,7 @@ Value OperationParser::createForwardRefPlaceholder(SMLoc loc, Type type) {
/*attributes=*/std::nullopt, /*properties=*/nullptr, /*successors=*/{},
/*numRegions=*/0);
forwardRefPlaceholders[op->getResult(0)] = loc;
+ forwardRefOps.insert(op);
return op->getResult(0);
}
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 448141735ba7f..1f70ad57d986b 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -499,8 +499,20 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
else if (parser.parseType(type))
return failure();
- // Resolve input operands.
+ // Set block argument types, so that they are known when parsing the region.
regionArgs.front().type = type;
+ for (auto [iterArg, type] :
+ llvm::zip(llvm::drop_begin(regionArgs), result.types))
+ iterArg.type = type;
+
+ // Parse the body region.
+ Region *body = result.addRegion();
+ if (parser.parseRegion(*body, regionArgs))
+ return failure();
+ ForOp::ensureTerminator(*body, builder, result.location);
+
+ // Resolve input operands. This should be done after parsing the region to
+ // catch invalid IR where operands were defined inside of the region.
if (parser.resolveOperand(lb, type, result.operands) ||
parser.resolveOperand(ub, type, result.operands) ||
parser.resolveOperand(step, type, result.operands))
@@ -516,13 +528,6 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
}
}
- // Parse the body region.
- Region *body = result.addRegion();
- if (parser.parseRegion(*body, regionArgs))
- return failure();
-
- ForOp::ensureTerminator(*body, builder, result.location);
-
// Parse the optional attribute list.
if (parser.parseOptionalAttrDict(result.attributes))
return failure();
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index 80576be880127..76c785f3e6166 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -747,3 +747,13 @@ func.func @parallel_missing_terminator(%0 : index) {
return
}
+// -----
+
+func.func @invalid_reference(%a: index) {
+ // expected-error @below{{use of undeclared SSA value name}}
+ scf.for %x = %a to %a step %a iter_args(%var = %foo) -> tensor<?xf32> {
+ %foo = "test.inner"() : () -> (tensor<?xf32>)
+ scf.yield %foo : tensor<?xf32>
+ }
+ return
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Uh, I'm confused: how could we not hit this before?
Edit: oh, it would only fail on invalid IR right?
Yes, this just crashed on invalid IR. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Small suggestion while reviewing.
This looks to fix crash with input IR from #107121, which now produces:
issue-107121.mlir:1:1: error: operation's operand is unlinked
scf.if %true {
^
issue-107121.mlir:1:1: note: see current operation:
"scf.if"(<<UNKNOWN SSA VALUE>>) ({
%0 = "arith.constant"() <{value = true}> : () -> i1
"scf.yield"() : () -> ()
}, {
}) : (i1) -> ()
Although maybe this still not ideal, it at least no longer crashes.
(probably scf.if's parser should be altered similarly? I can take a look 👍)
Yes, I think swapping |
This commit fixes a use-after-free crash when parsing the following invalid IR:
The
scf.for
parser was implemented as follows:%foo
).During operand resolution, a forward reference (
unrealized_conversion_cast
) is added by the parser because%foo
has not been defined yet. During region parsing, the definition of%foo
is found and the forward reference is replaced with the actual definition. (And the forward reference is deleted.) However, the operand of thescf.for
op is not updated because thescf.for
op has not been created yet; all we have is anOperationState
object.All parsers should be written in such a way that they first parse the region and then resolve the operands. That way, no forward reference is inserted in the first place. Before parsing the region, it may be necessary to set the argument types if they are defined as part of the assembly format of the op (as is the case with
scf.for
). Note: Ops in generic format are parsed in the same way.To make the parsing infrastructure more robust, this commit also delays the erase of forward references until the end of the lifetime of the parser. Instead of a use-after-free crash, users will then see more descriptive error messages such as:
Note: The proper way to fix the parser is to first parse the region, then resolve the operands. The change to
Parser.cpp
is merely to help users finding the root cause of the problem.