Skip to content

Commit 55d0610

Browse files
authored
[mlir] make parseCustomOperationName() API check token type (#136306)
Previously, this parser API call would accept any token and interpret its spelling as operation name, including tokens that are are not valid operation names. Make it accept only bare identifiers and keywords. The latter is questionable but consistent with current practices upstream. Fixes #132889.
1 parent 7318074 commit 55d0610

File tree

6 files changed

+58
-0
lines changed

6 files changed

+58
-0
lines changed

mlir/lib/AsmParser/Parser.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2005,6 +2005,11 @@ class CustomOpAsmParser : public AsmParserImpl<OpAsmParser> {
20052005

20062006
FailureOr<OperationName> OperationParser::parseCustomOperationName() {
20072007
Token nameTok = getToken();
2008+
// Accept keywords here as they may be interpreted as a shortened operation
2009+
// name, e.g., `dialect.keyword` can be spelled as just `keyword` within a
2010+
// region of an operation from `dialect`.
2011+
if (nameTok.getKind() != Token::bare_identifier && !nameTok.isKeyword())
2012+
return emitError("expected bare identifier or keyword");
20082013
StringRef opName = nameTok.getSpelling();
20092014
if (opName.empty())
20102015
return (emitError("empty operation name is invalid"), failure());

mlir/test/Dialect/Linalg/invalid.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1666,3 +1666,12 @@ func.func @unpack_static_inner_tile_size_and_dynamic_output_shape(
16661666
%0 = linalg.unpack %input inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %output : tensor<?x?x?x4xf32> -> tensor<?x?xf32>
16671667
return %0 : tensor<?x?xf32>
16681668
}
1669+
1670+
// -----
1671+
1672+
func.func @reduce_non_operation_name(%arg0: tensor<4xf32>, %arg1: tensor<f32>) -> tensor<f32> {
1673+
// expected-error @below {{expected bare identifier or keyword}}
1674+
%0 = linalg.reduce {@reduce_fusion_elementwise} ins(
1675+
%arg0: tensor<4xf32>) outs(%arg1: tensor<f32>) dimensions = [0]
1676+
return %0 : tensor<f32>
1677+
}

mlir/test/IR/invalid.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,16 @@ func.func @invalid_region_dominance_with_dominance_free_regions() {
642642

643643
// -----
644644

645+
// expected-error @below {{expected bare identifier or keyword}}
646+
test.parse_custom_operation_name_api(@foo) {}
647+
648+
// -----
649+
650+
// expected-error @below {{expected bare identifier or keyword}}
651+
test.parse_custom_operation_name_api(42) {}
652+
653+
// -----
654+
645655
// This makes sure we emit an error at the end of the correct line, the : is
646656
// expected at the end of foo, not on the return line.
647657
func.func @error_at_end_of_line() {

mlir/test/IR/parser.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1169,6 +1169,15 @@ func.func @op_with_region_args() {
11691169
return
11701170
}
11711171

1172+
// Test parsing an operation name from within another op custom syntax.
1173+
1174+
// CHECK-LABEL: @custom_name_api
1175+
func.func @custom_name_api() {
1176+
// CHECK: test.parse_custom_operation_name_api(builtin.module)
1177+
test.parse_custom_operation_name_api(builtin.module)
1178+
return
1179+
}
1180+
11721181
// Test allowing different name scopes for regions isolated from above.
11731182

11741183
// CHECK-LABEL: func @op_with_passthrough_region_args

mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,25 @@ static void printOptionalLoc(OpAsmPrinter &p, Operation *op, Attribute loc) {
487487
p.printOptionalLocationSpecifier(cast<LocationAttr>(loc));
488488
}
489489

490+
//===----------------------------------------------------------------------===//
491+
// ParseCustomOperationNameAPI
492+
//===----------------------------------------------------------------------===//
493+
494+
static ParseResult parseCustomOperationNameEntry(OpAsmParser &p,
495+
Attribute &name) {
496+
FailureOr<OperationName> opName = p.parseCustomOperationName();
497+
if (failed(opName))
498+
return ParseResult::failure();
499+
500+
name = p.getBuilder().getStringAttr(opName->getStringRef());
501+
return ParseResult::success();
502+
}
503+
504+
static void printCustomOperationNameEntry(OpAsmPrinter &p, Operation *op,
505+
Attribute name) {
506+
p << cast<StringAttr>(name).getValue();
507+
}
508+
490509
#define GET_OP_CLASSES
491510
#include "TestOpsSyntax.cpp.inc"
492511

mlir/test/lib/Dialect/Test/TestOpsSyntax.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,12 @@ def TestAttrWithLoc : TEST_Op<"attr_with_loc"> {
6969
let assemblyFormat = "`(` $value `` custom<OptionalLoc>($loc) `)` attr-dict";
7070
}
7171

72+
def ParseCustomOperationNameAPI : TEST_Op<"parse_custom_operation_name_api"> {
73+
let summary = "noop that exercises the parseCustomOperationName API";
74+
let arguments = (ins StrAttr:$name);
75+
let assemblyFormat = "`(` custom<CustomOperationNameEntry>($name) `)` attr-dict";
76+
}
77+
7278
// -----
7379

7480
// This is used to test that the fallback for a custom op's parser and printer

0 commit comments

Comments
 (0)