Skip to content

[mlir][nvgpu]add dim check test to nvgpu.mma op. #122864

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 14, 2025

Conversation

linuxlonelyeagle
Copy link
Member

as tile.
In the one-dimensional case, mlir-opt crashes directly, and I added more checks in nvgpu.mmaOp 's verify.

@llvmbot
Copy link
Member

llvmbot commented Jan 14, 2025

@llvm/pr-subscribers-mlir-nvgpu

@llvm/pr-subscribers-mlir-gpu

Author: lonely eagle (linuxlonelyeagle)

Changes

as tile.
In the one-dimensional case, mlir-opt crashes directly, and I added more checks in nvgpu.mmaOp 's verify.


Full diff: https://github.com/llvm/llvm-project/pull/122864.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp (+12)
  • (modified) mlir/test/Dialect/NVGPU/invalid.mlir (+24)
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index de9bbcbace6924..a027350e8a5f70 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -203,6 +203,18 @@ static LogicalResult verifyMmaSyncOp(Operation *op,
   // Basic verification
   //
 
+  if (aShape.size() != 2) {
+    return op->emitError() << "matrixA must be 2 dimensional vector";
+  }
+
+  if (bShape.size() != 2) {
+    return op->emitError() << "matrixB must be 2 dimensional vector";
+  }
+
+  if (cShape.size() != 2) {
+    return op->emitError() << "matrixC must be 2 dimensional vector";
+  }
+
   auto [m, n, k] = mmaShape;
 
   // verify warp-wide size for vector a
diff --git a/mlir/test/Dialect/NVGPU/invalid.mlir b/mlir/test/Dialect/NVGPU/invalid.mlir
index f7db1140794e54..b5bfbe9ff27b79 100644
--- a/mlir/test/Dialect/NVGPU/invalid.mlir
+++ b/mlir/test/Dialect/NVGPU/invalid.mlir
@@ -354,3 +354,27 @@ func.func @rcp_unsupported_ftz(%in : vector<16xf32>) {
   // expected-error @+1 {{'nvgpu.rcp' op has a limitation. #nvgpu<rcp_rounding_mode approx> or non-ftz is not supported yet.}}
   %out = nvgpu.rcp %in {rounding = approx} : vector<16xf32>
 }
+
+// -----
+
+func.func @check_matrixA_dim(%arg0: vector<16xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+  // expected-error @+1 {{matrixA must be 2 dimensional vector}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<16xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  return %d : vector<2x2xf16>
+}
+
+// -----
+
+func.func @check_matrixB_dim(%arg0: vector<4x4xf16>, %arg1: vector<4xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+  // expected-error @+1 {{matrixB must be 2 dimensional vector}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x4xf16>, vector<4xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  return %d : vector<2x2xf16>
+}
+
+// -----
+
+func.func @check_matrixC_dim(%arg0: vector<4x4xf16>, %arg1: vector<2x2xf16>, %arg2: vector<4xf16>) -> vector<2x2xf16> {
+  // expected-error @+1 {{matrixC must be 2 dimensional vector}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x4xf16>, vector<2x2xf16>, vector<4xf16>) -> vector<2x2xf16>
+  return %d : vector<2x2xf16>
+}

@linuxlonelyeagle linuxlonelyeagle merged commit 53c7fe5 into llvm:main Jan 14, 2025
12 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants