Skip to content

Commit 51e5f67

Browse files
authored
[mlir][vector] Fix crash on invalid permutation_map (#74925)
Without this patch, MLIR crashes with ``` Assertion failed: (getNumDims() == map.getNumResults() && "Number of results mismatch"), function compose, file AffineMap.cpp, line 537. ``` during parsing.
1 parent 35ebd92 commit 51e5f67

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3815,6 +3815,11 @@ ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
38153815
if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
38163816
return parser.emitError(
38173817
maskInfo.location, "does not support masks with vector element type");
3818+
if (vectorType.getRank() != permMap.getNumResults()) {
3819+
return parser.emitError(typesLoc,
3820+
"expected the same rank for the vector and the "
3821+
"results of the permutation map");
3822+
}
38183823
// Instead of adding the mask type as an op type, compute it based on the
38193824
// vector type and the permutation map (to keep the type signature small).
38203825
auto maskType = inferTransferOpMaskType(vectorType, permMap);
@@ -4181,6 +4186,11 @@ ParseResult TransferWriteOp::parse(OpAsmParser &parser,
41814186
if (llvm::dyn_cast<VectorType>(shapedType.getElementType()))
41824187
return parser.emitError(
41834188
maskInfo.location, "does not support masks with vector element type");
4189+
if (vectorType.getRank() != permMap.getNumResults()) {
4190+
return parser.emitError(typesLoc,
4191+
"expected the same rank for the vector and the "
4192+
"results of the permutation map");
4193+
}
41844194
auto maskType = inferTransferOpMaskType(vectorType, permMap);
41854195
if (parser.resolveOperand(maskInfo, maskType, result.operands))
41864196
return failure();

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,28 @@ func.func @test_vector.transfer_read(%arg0: memref<?x?xf32>) {
332332

333333
// -----
334334

335+
#map1 = affine_map<(d0, d1, d2) -> (d0, 0, 0)>
336+
func.func @main(%m: memref<1xi32>, %2: vector<1x32xi1>) -> vector<1x32xi32> {
337+
%0 = arith.constant 1 : index
338+
%1 = arith.constant 1 : i32
339+
// expected-error@+1 {{expected the same rank for the vector and the results of the permutation map}}
340+
%3 = vector.transfer_read %m[%0], %1, %2 { permutation_map = #map1 } : memref<1xi32>, vector<1x32xi32>
341+
return %3 : vector<1x32xi32>
342+
}
343+
344+
// -----
345+
346+
#map1 = affine_map<(d0, d1, d2) -> (d0, 0, 0)>
347+
func.func @test_vector.transfer_write(%m: memref<1xi32>, %2: vector<1x32xi32>) -> vector<1x32xi32> {
348+
%0 = arith.constant 1 : index
349+
%1 = arith.constant 1 : i32
350+
// expected-error@+1 {{expected the same rank for the vector and the results of the permutation map}}
351+
%3 = vector.transfer_write %2, %m[%0], %1 { permutation_map = #map1 } : vector<1x32xi32>, memref<1xi32>
352+
return %3 : vector<1x32xi32>
353+
}
354+
355+
// -----
356+
335357
func.func @test_vector.transfer_read(%arg0: vector<4x3xf32>) {
336358
%c3 = arith.constant 3 : index
337359
%f0 = arith.constant 0.0 : f32

0 commit comments

Comments
 (0)