Skip to content

Commit 7d03c8e

Browse files
[mlir][Parser] Fix use-after-free when parsing invalid reference to nested definition (#127778)
This commit fixes a use-after-free crash when parsing the following invalid IR: ```mlir scf.for ... iter_args(%var = %foo) -> tensor<?xf32> { %foo = "test.inner"() : () -> (tensor<?xf32>) scf.yield %arg0 : tensor<?xf32> } ``` The `scf.for` parser was implemented as follows: 1. Resolve operands (including `%foo`). 2. Parse the region. During operand resolution, a forward reference (`unrealized_conversion_cast`) is added by the parser because `%foo` has not been defined yet. During region parsing, the definition of `%foo` is found and the forward reference is replaced with the actual definition. (And the forward reference is deleted.) However, the operand of the `scf.for` op is not updated because the `scf.for` op has not been created yet; all we have is an `OperationState` object. All parsers should be written in such a way that they first parse the region and then resolve the operands. That way, no forward reference is inserted in the first place. Before parsing the region, it may be necessary to set the argument types if they are defined as part of the assembly format of the op (as is the case with `scf.for`). Note: Ops in generic format are parsed in the same way. To make the parsing infrastructure more robust, this commit also delays the erase of forward references until the end of the lifetime of the parser. Instead of a use-after-free crash, users will then see more descriptive error messages such as: ``` error: operation's operand is unlinked ``` Note: The proper way to fix the parser is to first parse the region, then resolve the operands. The change to `Parser.cpp` is merely to help users finding the root cause of the problem.
1 parent 77183a4 commit 7d03c8e

File tree

3 files changed

+35
-14
lines changed

3 files changed

+35
-14
lines changed

mlir/lib/AsmParser/Parser.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -820,6 +820,12 @@ class OperationParser : public Parser {
820820
/// their first reference, to allow checking for use of undefined values.
821821
DenseMap<Value, SMLoc> forwardRefPlaceholders;
822822

823+
/// Operations that define the placeholders. These are kept until the end of
824+
/// of the lifetime of the parser because some custom parsers may store
825+
/// references to them in local state and use them after forward references
826+
/// have been resolved.
827+
DenseSet<Operation *> forwardRefOps;
828+
823829
/// Deffered locations: when parsing `loc(#loc42)` we add an entry to this
824830
/// map. After parsing the definition `#loc42 = ...` we'll patch back users
825831
/// of this location.
@@ -847,11 +853,11 @@ OperationParser::OperationParser(ParserState &state, ModuleOp topLevelOp)
847853
}
848854

849855
OperationParser::~OperationParser() {
850-
for (auto &fwd : forwardRefPlaceholders) {
856+
for (Operation *op : forwardRefOps) {
851857
// Drop all uses of undefined forward declared reference and destroy
852858
// defining operation.
853-
fwd.first.dropAllUses();
854-
fwd.first.getDefiningOp()->destroy();
859+
op->dropAllUses();
860+
op->destroy();
855861
}
856862
for (const auto &scope : forwardRef) {
857863
for (const auto &fwd : scope) {
@@ -1007,7 +1013,6 @@ ParseResult OperationParser::addDefinition(UnresolvedOperand useInfo,
10071013
// the actual definition instead, delete the forward ref, and remove it
10081014
// from our set of forward references we track.
10091015
existing.replaceAllUsesWith(value);
1010-
existing.getDefiningOp()->destroy();
10111016
forwardRefPlaceholders.erase(existing);
10121017

10131018
// If a definition of the value already exists, replace it in the assembly
@@ -1194,6 +1199,7 @@ Value OperationParser::createForwardRefPlaceholder(SMLoc loc, Type type) {
11941199
/*attributes=*/std::nullopt, /*properties=*/nullptr, /*successors=*/{},
11951200
/*numRegions=*/0);
11961201
forwardRefPlaceholders[op->getResult(0)] = loc;
1202+
forwardRefOps.insert(op);
11971203
return op->getResult(0);
11981204
}
11991205

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -499,15 +499,27 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
499499
else if (parser.parseType(type))
500500
return failure();
501501

502-
// Resolve input operands.
502+
// Set block argument types, so that they are known when parsing the region.
503503
regionArgs.front().type = type;
504+
for (auto [iterArg, type] :
505+
llvm::zip_equal(llvm::drop_begin(regionArgs), result.types))
506+
iterArg.type = type;
507+
508+
// Parse the body region.
509+
Region *body = result.addRegion();
510+
if (parser.parseRegion(*body, regionArgs))
511+
return failure();
512+
ForOp::ensureTerminator(*body, builder, result.location);
513+
514+
// Resolve input operands. This should be done after parsing the region to
515+
// catch invalid IR where operands were defined inside of the region.
504516
if (parser.resolveOperand(lb, type, result.operands) ||
505517
parser.resolveOperand(ub, type, result.operands) ||
506518
parser.resolveOperand(step, type, result.operands))
507519
return failure();
508520
if (hasIterArgs) {
509-
for (auto argOperandType :
510-
llvm::zip(llvm::drop_begin(regionArgs), operands, result.types)) {
521+
for (auto argOperandType : llvm::zip_equal(llvm::drop_begin(regionArgs),
522+
operands, result.types)) {
511523
Type type = std::get<2>(argOperandType);
512524
std::get<0>(argOperandType).type = type;
513525
if (parser.resolveOperand(std::get<1>(argOperandType), type,
@@ -516,13 +528,6 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
516528
}
517529
}
518530

519-
// Parse the body region.
520-
Region *body = result.addRegion();
521-
if (parser.parseRegion(*body, regionArgs))
522-
return failure();
523-
524-
ForOp::ensureTerminator(*body, builder, result.location);
525-
526531
// Parse the optional attribute list.
527532
if (parser.parseOptionalAttrDict(result.attributes))
528533
return failure();

mlir/test/Dialect/SCF/invalid.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -747,3 +747,13 @@ func.func @parallel_missing_terminator(%0 : index) {
747747
return
748748
}
749749

750+
// -----
751+
752+
func.func @invalid_reference(%a: index) {
753+
// expected-error @below{{use of undeclared SSA value name}}
754+
scf.for %x = %a to %a step %a iter_args(%var = %foo) -> tensor<?xf32> {
755+
%foo = "test.inner"() : () -> (tensor<?xf32>)
756+
scf.yield %foo : tensor<?xf32>
757+
}
758+
return
759+
}

0 commit comments

Comments
 (0)