Skip to content

Commit 53c7fe5

Browse files
[mlir][nvgpu]add dim check test to nvgpu.mma op. (#122864)
add shape checks of matrixA, matrixB, and matrixC to the nvgpu.mma's verify.
1 parent 95f7c2f commit 53c7fe5

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed

mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,18 @@ static LogicalResult verifyMmaSyncOp(Operation *op,
203203
// Basic verification
204204
//
205205

206+
if (aShape.size() != 2) {
207+
return op->emitError() << "matrixA must be 2 dimensional vector";
208+
}
209+
210+
if (bShape.size() != 2) {
211+
return op->emitError() << "matrixB must be 2 dimensional vector";
212+
}
213+
214+
if (cShape.size() != 2) {
215+
return op->emitError() << "matrixC must be 2 dimensional vector";
216+
}
217+
206218
auto [m, n, k] = mmaShape;
207219

208220
// verify warp-wide size for vector a

mlir/test/Dialect/NVGPU/invalid.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,3 +354,27 @@ func.func @rcp_unsupported_ftz(%in : vector<16xf32>) {
354354
// expected-error @+1 {{'nvgpu.rcp' op has a limitation. #nvgpu<rcp_rounding_mode approx> or non-ftz is not supported yet.}}
355355
%out = nvgpu.rcp %in {rounding = approx} : vector<16xf32>
356356
}
357+
358+
// -----
359+
360+
func.func @check_matrixA_dim(%arg0: vector<16xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
361+
// expected-error @+1 {{matrixA must be 2 dimensional vector}}
362+
%d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<16xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
363+
return %d : vector<2x2xf16>
364+
}
365+
366+
// -----
367+
368+
func.func @check_matrixB_dim(%arg0: vector<4x4xf16>, %arg1: vector<4xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
369+
// expected-error @+1 {{matrixB must be 2 dimensional vector}}
370+
%d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x4xf16>, vector<4xf16>, vector<2x2xf16>) -> vector<2x2xf16>
371+
return %d : vector<2x2xf16>
372+
}
373+
374+
// -----
375+
376+
func.func @check_matrixC_dim(%arg0: vector<4x4xf16>, %arg1: vector<2x2xf16>, %arg2: vector<4xf16>) -> vector<2x2xf16> {
377+
// expected-error @+1 {{matrixC must be 2 dimensional vector}}
378+
%d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x4xf16>, vector<2x2xf16>, vector<4xf16>) -> vector<2x2xf16>
379+
return %d : vector<2x2xf16>
380+
}

0 commit comments

Comments
 (0)