Skip to content

Commit c65f8d8

Browse files
authored
[mlir][linalg] Fix crashes in parser on linalg ops without operands (#97944)
`parseDstStyleOp` parses both `ins()` and `outs()` optionally. The parsers for `linalg.transpose`, `linalg.broadcast` and `linalg.map` however assume that at least one operand is present in the state, leading to crashes otherwise. This patch adds checks to the parsers which stop them from crashing if no operands were parsed. When the Ops are parsed successfuly, the verifiers can work on them. Fix #97857
1 parent 902fb1b commit c65f8d8

File tree

2 files changed

+83
-3
lines changed

2 files changed

+83
-3
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1356,8 +1356,12 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
13561356
return failure();
13571357

13581358
if (payloadOpName.has_value()) {
1359-
addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
1360-
ArrayRef(result.operands).drop_back());
1359+
if (!result.operands.empty())
1360+
addBodyWithPayloadOp(parser, result, payloadOpName.value(),
1361+
payloadOpAttrs,
1362+
ArrayRef(result.operands).drop_back());
1363+
else
1364+
result.addRegion();
13611365
} else {
13621366
SmallVector<OpAsmParser::Argument> regionArgs;
13631367
if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
@@ -1739,7 +1743,8 @@ static void buildIdentityRegion(OpBuilder &builder, Location loc,
17391743
ValueRange outputs) {
17401744
buildGenericRegion(builder, loc, region, inputs, outputs,
17411745
[](OpBuilder &b, Location loc, ValueRange args) {
1742-
b.create<linalg::YieldOp>(loc, args[0]);
1746+
if (!args.empty())
1747+
b.create<linalg::YieldOp>(loc, args[0]);
17431748
});
17441749
}
17451750

mlir/test/Dialect/Linalg/invalid.mlir

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,32 @@ func.func @map_input_output_shape_mismatch(
455455

456456
// -----
457457

458+
func.func @map_no_operands1() {
459+
// expected-error @+1 {{'linalg.map' op expected 1 or more operands, but found 0}}
460+
linalg.map { arith.addf }
461+
}
462+
463+
// -----
464+
465+
func.func @map_no_operands2() {
466+
// expected-error @+1 {{'linalg.map' op expected 1 or more operands, but found 0}}
467+
"linalg.map"() ({
468+
^bb0:
469+
}) : () -> ()
470+
}
471+
472+
// -----
473+
474+
func.func @map_no_operands3(
475+
%lhs: tensor<64xf32>, %rhs: tensor<64xf32>, %init: tensor<64xf32>)
476+
-> tensor<64xf32> {
477+
// expected-error @+1 {{cannot name an operation with no results}}
478+
%add = linalg.map { arith.addf }
479+
func.return %add : tensor<64xf32>
480+
}
481+
482+
// -----
483+
458484
func.func @reduce_input_vs_init_dimension_mismatch(
459485
%input: tensor<16x32x64xf32>,
460486
%init: tensor<16x64xf32>) -> tensor<16x64xf32> {
@@ -676,6 +702,30 @@ func.func @transpose_input_init_rank_mismatch(%input: tensor<16x32xf32>,
676702

677703
// -----
678704

705+
func.func @transpose_no_operands1() {
706+
// expected-error @+1 {{'linalg.transpose' op expected 2 operands, but found 0}}
707+
linalg.transpose permutation = [1, 0, 2]
708+
}
709+
710+
// -----
711+
712+
func.func @transpose_no_operands2() {
713+
// expected-error @+1 {{'linalg.transpose' op expected 2 operands, but found 0}}
714+
"linalg.transpose"() <{permutation = array<i64: 1, 0, 2>}> ({
715+
^bb0:
716+
}) : () -> ()
717+
}
718+
719+
// -----
720+
721+
func.func @transpose_no_operands3() -> tensor<32x64x16xf32> {
722+
// expected-error @+1 {{cannot name an operation with no results}}
723+
%transpose = linalg.transpose permutation = [1, 0, 2]
724+
func.return %transpose : tensor<32x64x16xf32>
725+
}
726+
727+
// -----
728+
679729
func.func @broadcast_input_dims_rank_mismatch(
680730
%input: tensor<4x16xf32>, %init: tensor<4x8x16xf32>)
681731
-> tensor<4x8x16xf32> {
@@ -728,6 +778,31 @@ func.func @broadcast_size_1_extension_not_supported(
728778

729779
// -----
730780

781+
func.func @broadcast_no_operands1() {
782+
// expected-error @+1 {{'linalg.broadcast' op expected 2 operands, but found 0}}
783+
linalg.broadcast dimensions = [1]
784+
}
785+
786+
// -----
787+
788+
func.func @broadcast_no_operands2() {
789+
// expected-error @+1 {{'linalg.broadcast' op expected 2 operands, but found 0}}
790+
"linalg.broadcast"() <{dimensions = array<i64: 1>}> ({
791+
^bb0:
792+
}) : () -> ()
793+
}
794+
795+
// -----
796+
797+
func.func @broadcast_no_operands3()
798+
-> tensor<4x?x16xf32> {
799+
// expected-error @+1 {{cannot name an operation with no results}}
800+
%broadcast = linalg.broadcast dimensions = [1]
801+
func.return %broadcast : tensor<32x64x16xf32>
802+
}
803+
804+
// -----
805+
731806
func.func @missing_iterator_types() {
732807
// expected-error @below {{expected "iterator_types" array attribute}}
733808
linalg.generic {} ins() outs()

0 commit comments

Comments
 (0)