Skip to content

Commit 2491867

Browse files
authored
[mlir][nvgpu] Improve verifier of ldmatrix (#77807)
PR improves the verifier of `nvgpu.ldmatrix` Op, so `nvgpu-to-nvvm` lowering does not crash.
1 parent 2e78c22 commit 2491867

File tree

2 files changed

+11
-0
lines changed

2 files changed

+11
-0
lines changed

mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,9 @@ LogicalResult LdMatrixOp::verify() {
321321
if (isTranspose && !(elementBitWidth == 16))
322322
return emitError()
323323
<< "nvgpu.ldmatrix transpose works only at 16b granularity";
324+
if (resShape.size() != 2) {
325+
return emitError() << "results must be 2 dimensional vector";
326+
}
324327
if (!(resShape[1] == numElementsPer32b))
325328
return emitError() << "expected vector register shape[1] = "
326329
<< numElementsPer32b;

mlir/test/Dialect/NVGPU/invalid.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,14 @@ func.func @ldmatrix_trans_f32_x4(%arg0: memref<128x128xf32, 3>) -> vector<4x1xf
4040
}
4141
// -----
4242

43+
func.func @ldmatrix_trans_f32_x4(%arg0: memref<128x128xf32, 3>) -> vector<4x1xf32> {
44+
%c0 = arith.constant 0 : index
45+
// expected-error @+1 {{results must be 2 dimensional vector}}
46+
%a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf32, 3> -> vector<4xf32>
47+
return %a : vector<4xf32>
48+
}
49+
// -----
50+
4351
func.func @ldmatrix_type_x4(%arg0: memref<128x128xf32, 3>) -> vector<4x2xf16> {
4452
%c0 = arith.constant 0 : index
4553
// expected-error @+1 {{'nvgpu.ldmatrix' op failed to verify that srcMemref and res have same element type}}

0 commit comments

Comments
 (0)