Skip to content

Commit bd56160

Browse files
committed
[mlir][linalg] Fix crashes in parser on linalg ops without operands
`parseDstStyleOp` parses both `ins()` and `outs()` optionally. The parsers for `linalg.transpose`, `linalg.broadcast` and `linalg.map` however assume that at least one operand is present in the state, leading to crashes otherwise. This patch adds checks to the parsers which stop them from crashing if no operands were parsed. After the Ops are parsed successfuly, the verifier takes it from there. Fix #97857
1 parent f4c7811 commit bd56160

File tree

2 files changed

+37
-3
lines changed

2 files changed

+37
-3
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1356,8 +1356,10 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
13561356
return failure();
13571357

13581358
if (payloadOpName.has_value()) {
1359-
addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
1360-
ArrayRef(result.operands).drop_back());
1359+
if (!result.operands.empty())
1360+
addBodyWithPayloadOp(parser, result, payloadOpName.value(),
1361+
payloadOpAttrs,
1362+
ArrayRef(result.operands).drop_back());
13611363
} else {
13621364
SmallVector<OpAsmParser::Argument> regionArgs;
13631365
if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
@@ -1739,7 +1741,8 @@ static void buildIdentityRegion(OpBuilder &builder, Location loc,
17391741
ValueRange outputs) {
17401742
buildGenericRegion(builder, loc, region, inputs, outputs,
17411743
[](OpBuilder &b, Location loc, ValueRange args) {
1742-
b.create<linalg::YieldOp>(loc, args[0]);
1744+
if (!args.empty())
1745+
b.create<linalg::YieldOp>(loc, args[0]);
17431746
});
17441747
}
17451748

mlir/test/Dialect/Linalg/invalid.mlir

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,18 @@ func.func @map_input_output_shape_mismatch(
455455

456456
// -----
457457

458+
func.func @map_no_operands(
459+
%lhs: tensor<64xf32>, %rhs: tensor<64xf32>, %init: tensor<64xf32>)
460+
-> tensor<64xf32> {
461+
// This must not crash the parser.
462+
linalg.map { arith.addf }
463+
// expected-error @+1 {{cannot name an operation with no results}}
464+
%add = linalg.map { arith.addf }
465+
func.return %add : tensor<64xf32>
466+
}
467+
468+
// -----
469+
458470
func.func @reduce_input_vs_init_dimension_mismatch(
459471
%input: tensor<16x32x64xf32>,
460472
%init: tensor<16x64xf32>) -> tensor<16x64xf32> {
@@ -676,6 +688,16 @@ func.func @transpose_input_init_rank_mismatch(%input: tensor<16x32xf32>,
676688

677689
// -----
678690

691+
func.func @transpose_no_operands() -> tensor<32x64x16xf32> {
692+
// This must not crash the parser.
693+
linalg.transpose permutation = [1, 0, 2]
694+
// expected-error @+1 {{cannot name an operation with no results}}
695+
%transpose = linalg.transpose permutation = [1, 0, 2]
696+
func.return %transpose : tensor<32x64x16xf32>
697+
}
698+
699+
// -----
700+
679701
func.func @broadcast_input_dims_rank_mismatch(
680702
%input: tensor<4x16xf32>, %init: tensor<4x8x16xf32>)
681703
-> tensor<4x8x16xf32> {
@@ -725,6 +747,15 @@ func.func @broadcast_size_1_extension_not_supported(
725747
dimensions = [1]
726748
func.return %bcast : tensor<4x?x16xf32>
727749
}
750+
// -----
751+
752+
func.func @broadcast_no_operands()
753+
-> tensor<4x?x16xf32> {
754+
linalg.broadcast dimensions = [1]
755+
// expected-error @+1 {{cannot name an operation with no results}}
756+
%broadcast = linalg.broadcast dimensions = [1]
757+
func.return %broadcast : tensor<32x64x16xf32>
758+
}
728759

729760
// -----
730761

0 commit comments

Comments
 (0)