Skip to content

[mlir][spirv] Fix remaining coop matrix verification corner cases #66137

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 12, 2023

Conversation

kuhar
Copy link
Member

@kuhar kuhar commented Sep 12, 2023

  • Check MakePointer* load/store attribute values.
  • Support coop matrix types in MatrixTimesScalar verification.
  • Add test cases for all the remaining ops that accept coop matrix types.
  • Split NV and KHR tests.

- Check `MakePointer*` load/store attribute values.
- Suuport coop matrix types in `MatrixTimesScalar` verification.
- Add testcases for all the remaining ops that accept coop matrix types.
- Split NV and KHR tests.
@kuhar kuhar requested a review from a team as a code owner September 12, 2023 20:08
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir:spirv mlir labels Sep 12, 2023
@llvmbot
Copy link
Member

llvmbot commented Sep 12, 2023

@llvm/pr-subscribers-mlir-spirv

Changes
  • Check MakePointer* load/store attribute values.
  • Support coop matrix types in MatrixTimesScalar verification.
  • Add test cases for all the remaining ops that accept coop matrix types.
  • Split NV and KHR tests.
    --

Patch is 33.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/66137.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td (+2-3)
  • (modified) mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp (+29-10)
  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp (+12-11)
  • (renamed) mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir (+114-119)
  • (added) mlir/test/Dialect/SPIRV/IR/nv-cooperative-matrix-ops.mlir (+177)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
index a21fc0ce2f9299c..a055cadc756a7e6 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
@@ -75,10 +75,9 @@ def SPIRV_MatrixTimesScalarOp : SPIRV_Op<
   let summary = "Scale a floating-point matrix.";
 
   let description = [{
-    Result Type must be an OpTypeMatrix whose Column Type is a vector of
-    floating-point type.
+    Result Type must be a matrix type with a float component type.
 
-     The type of Matrix must be the same as Result Type. Each component in
+    The type of Matrix must be the same as Result Type. Each component in
     each column in Matrix is multiplied by Scalar.
 
     Scalar must have the same type as the Component Type in Result Type.
diff --git a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
index d43f7a1823e912b..c8b274ceec3e59d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
@@ -20,9 +20,6 @@
 using namespace mlir::spirv::AttrNames;
 
 namespace mlir::spirv {
-//===----------------------------------------------------------------------===//
-// spirv.KHR.CooperativeMatrixLoad
-//===----------------------------------------------------------------------===//
 
 static LogicalResult
 verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
@@ -35,13 +32,31 @@ verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
            << pointeeType;
   }
 
-  // The 'Aligned' memory operand requires an alignment literal to follow, which
-  // needs to be implemented on the level of op parsing and (de-)serialization.
-  // TODO: Consider adding support for this attribute value.
-  if (memoryOperand &&
-      spirv::bitEnumContainsAll(memoryOperand.getValue(),
-                                spirv::MemoryAccess::Aligned)) {
-    return op->emitOpError("has unhandled memory operand 'Aligned'");
+  if (memoryOperand) {
+    spirv::MemoryAccess operandSet = memoryOperand.getValue();
+
+    if (isa(op) &&
+        spirv::bitEnumContainsAll(operandSet,
+                                  spirv::MemoryAccess::MakePointerAvailable)) {
+      return op->emitOpError(
+          "not compatible with memory operand 'MakePointerAvailable'");
+    }
+
+    if (isa(op) &&
+        spirv::bitEnumContainsAll(operandSet,
+                                  spirv::MemoryAccess::MakePointerVisible)) {
+      return op->emitOpError(
+          "not compatible with memory operand 'MakePointerVisible'");
+    }
+
+    // The 'Aligned' memory operand requires an alignment literal to follow,
+    // which needs to be implemented on the level of op parsing and
+    // (de-)serialization.
+    // TODO: Consider adding support for this attribute value.
+    if (spirv::bitEnumContainsAll(memoryOperand.getValue(),
+                                  spirv::MemoryAccess::Aligned)) {
+      return op->emitOpError("has unhandled memory operand 'Aligned'");
+    }
   }
 
   // TODO: Verify the memory object behind the pointer:
@@ -51,6 +66,10 @@ verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.KHR.CooperativeMatrixLoad
+//===----------------------------------------------------------------------===//
+
 LogicalResult KHRCooperativeMatrixLoadOp::verify() {
   return verifyCoopMatrixAccess(*this, getPointer().getType(),
                                 getResult().getType(), getMemoryOperandAttr());
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 6ebd8515caf037d..6cd75ee6d9cba48 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -34,6 +34,7 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include 
 #include 
 #include 
@@ -1604,19 +1605,19 @@ LogicalResult spirv::VectorShuffleOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult spirv::MatrixTimesScalarOp::verify() {
-  if (auto inputCoopmat = llvm::dyn_cast(
-          getMatrix().getType())) {
-    if (inputCoopmat.getElementType() != getScalar().getType())
-      return emitError("input matrix components' type and scaling value must "
-                       "have the same type");
-    return success();
-  }
+  Type elementType =
+      llvm::TypeSwitch(getMatrix().getType())
+          .Case(
+              [](auto matrixType) { return matrixType.getElementType(); })
+          .Default([](Type) { return nullptr; });
+
+  assert(elementType && "Unhandld type");
 
   // Check that the scalar type is the same as the matrix element type.
-  auto inputMatrix = llvm::cast(getMatrix().getType());
-  if (getScalar().getType() != inputMatrix.getElementType())
-    return emitError("input matrix components' type and scaling value must "
-                     "have the same type");
+  if (getScalar().getType() != elementType)
+    return emitOpError("input matrix components' type and scaling value must "
+                       "have the same type");
 
   return success();
 }
diff --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
similarity index 68%
rename from mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
rename to mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
index 3adcd711f74a8f8..445ab8a48d3ce64 100644
--- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt --split-input-file --verify-diagnostics %s | FileCheck %s
 
 //===----------------------------------------------------------------------===//
-// CooperativeMatrix (KHR)
+// CooperativeMatrix (KHR) extension ops.
 //===----------------------------------------------------------------------===//
 
 // CHECK-LABEL: @cooperative_matrix_length
@@ -136,6 +136,15 @@ spirv.func @cooperative_matrix_load_bad_result(%ptr : !spirv.ptr, %stride : i32) "None" {
+  // expected-error @+1 {{op not compatible with memory operand 'MakePointerAvailable'}}
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, ,  :
+    !spirv.ptr, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
+  spirv.Return
+}
+
+// -----
+
 spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr, %stride : i32) "None" {
   // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
   %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, ,  :
@@ -184,6 +193,16 @@ spirv.func @cooperative_matrix_store_bad_object_type(%ptr : !spirv.ptr, %stride : i32,
+                                                 %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
+  // expected-error @+1 {{op not compatible with memory operand 'MakePointerVisible'}}
+  spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, ,  :
+    !spirv.ptr, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
+  spirv.Return
+}
+
+// -----
+
 spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr, %stride : i32,
                                      %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
   // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
@@ -406,177 +425,153 @@ spirv.func @cooperative_matrix_muladd_matrix_operands(%a : !spirv.coopmatrix<8x1
 // -----
 
 //===----------------------------------------------------------------------===//
-// NV.CooperativeMatrix
+// Standard ops that can be used CooperativeMatrix types
 //===----------------------------------------------------------------------===//
 
-// CHECK-LABEL: @cooperative_matrix_load
-spirv.func @cooperative_matrix_load(%ptr : !spirv.ptr, %stride : i32, %b : i1) "None" {
-  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr as !spirv.NV.coopmatrix<16x8xi32, Workgroup>
-  %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr as !spirv.NV.coopmatrix<16x8xi32, Workgroup>
-  spirv.Return
-}
+!matA_i32 = !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
+!matB_i32 = !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>
 
-// CHECK-LABEL: @cooperative_matrix_load_memaccess
-spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr, %stride : i32, %b : i1) "None" {
-  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  spirv.Return
-}
+!matA_f32 = !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>
+!matB_f32 = !spirv.coopmatrix<2x2xf32, Subgroup, MatrixB>
 
-// CHECK-LABEL: @cooperative_matrix_load_diff_ptr_type
-spirv.func @cooperative_matrix_load_diff_ptr_type(%ptr : !spirv.ptr, StorageBuffer>, %stride : i32, %b : i1) "None" {
-  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  spirv.Return
-}
+// These tests are kept in the same order as the list of compatible ops in the
+// SPV_KHR_cooperative_matrix extension spec.
 
-// CHECK-LABEL: @cooperative_matrix_store
-spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr, %stride : i32, %m : !spirv.NV.coopmatrix<8x16xi32, Workgroup>, %b : i1) "None" {
-  // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr, !spirv.NV.coopmatrix<8x16xi32, Workgroup>
-  spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b : !spirv.ptr, !spirv.NV.coopmatrix<8x16xi32, Workgroup>
+// CHECK-LABEL: @snegate
+spirv.func @snegate(%a: !matA_i32, %b: !matB_i32) "None" {
+  // CHECK:      spirv.SNegate {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.SNegate {{%.*}} : !spirv.coopmatrix
+  %p = spirv.SNegate %a : !matA_i32
+  %q = spirv.SNegate %b : !matB_i32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_store_memaccess
-spirv.func @cooperative_matrix_store_memaccess(%ptr : !spirv.ptr, %m : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" {
-  // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr, !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b ["Volatile"] : !spirv.ptr, !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+// CHECK-LABEL: @fnegate
+spirv.func @fnegate(%a: !matA_f32, %b: !matB_f32) "None" {
+  // CHECK:      spirv.FNegate {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.FNegate {{%.*}} : !spirv.coopmatrix
+  %p = spirv.FNegate %a : !matA_f32
+  %q = spirv.FNegate %b : !matB_f32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_length
-spirv.func @cooperative_matrix_length() -> i32 "None" {
-  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %0 = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  spirv.ReturnValue %0 : i32
-}
-
-// CHECK-LABEL: @cooperative_matrix_muladd
-spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, %b : !spirv.NV.coopmatrix<32x8xi8, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}}  : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, !spirv.NV.coopmatrix<32x8xi8, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
-  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, !spirv.NV.coopmatrix<32x8xi8, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
+// CHECK-LABEL: @iadd
+spirv.func @iadd(%a: !matA_i32, %b: !matB_i32) "None" {
+  // CHECK:      spirv.IAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.IAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.IAdd %a, %a : !matA_i32
+  %q = spirv.IAdd %b, %b : !matB_i32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_add
-spirv.func @cooperative_matrix_add(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.IAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %r = spirv.IAdd %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+// CHECK-LABEL: @fadd
+spirv.func @fadd(%a: !matA_f32, %b: !matB_f32) "None" {
+  // CHECK:      spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.FAdd %a, %a : !matA_f32
+  %q = spirv.FAdd %b, %b : !matB_f32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_sub
-spirv.func @cooperative_matrix_sub(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.ISub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %r = spirv.ISub %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+// CHECK-LABEL: @isub
+spirv.func @isub(%a: !matA_i32, %b: !matB_i32) "None" {
+  // CHECK:      spirv.ISub {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.ISub {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.ISub %a, %a : !matA_i32
+  %q = spirv.ISub %b, %b : !matB_i32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_sdiv
-spirv.func @cooperative_matrix_sdiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.SDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %r = spirv.SDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+// CHECK-LABEL: @fsub
+spirv.func @fsub(%a: !matA_f32, %b: !matB_f32) "None" {
+  // CHECK:      spirv.FSub {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.FSub {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.FSub %a, %a : !matA_f32
+  %q = spirv.FSub %b, %b : !matB_f32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_udiv
-spirv.func @cooperative_matrix_udiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.UDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %r = spirv.UDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+// CHECK-LABEL: @fmul
+spirv.func @fmul(%a: !matA_f32, %b: !matB_f32) "None" {
+  // CHECK:      spirv.FMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.FMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.FMul %a, %a : !matA_f32
+  %q = spirv.FMul %b, %b : !matB_f32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_fadd
-spirv.func @cooperative_matrix_fadd(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
-  %r = spirv.FAdd %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+// CHECK-LABEL: @imul
+spirv.func @imul(%a: !matA_i32, %b: !matB_i32) "None" {
+  // CHECK:      spirv.IMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.IMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.IMul %a, %a : !matA_i32
+  %q = spirv.IMul %b, %b : !matB_i32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_fsub
-spirv.func @cooperative_matrix_fsub(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.FSub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
-  %r = spirv.FSub %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+// CHECK-LABEL: @fdiv
+spirv.func @fdiv(%a: !matA_f32, %b: !matB_f32) "None" {
+  // CHECK:      spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.FDiv %a, %a : !matA_f32
+  %q = spirv.FDiv %b, %b : !matB_f32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_fdiv
-spirv.func @cooperative_matrix_fdiv(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
-  %r = spirv.FDiv %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+// CHECK-LABEL: @sdiv
+spirv.func @sdiv(%a: !matA_i32, %b: !matB_i32) "None" {
+  // CHECK:      spirv.SDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.SDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.SDiv %a, %a : !matA_i32
+  %q = spirv.SDiv %b, %b : !matB_i32
   spirv.Return
 }
 
-// -----
-
-// CHECK-LABEL: @cooperative_matrix_access_chain
-spirv.func @cooperative_matrix_access_chain(%a : !spirv.ptr, Function>) -> !spirv.ptr "None" {
-  %0 = spirv.Constant 0: i32
-  // CHECK: {{%.*}} = spirv.AccessChain {{%.*}}[{{%.*}}] : !spirv.ptr, Function>, i32
-  %1 = spirv.AccessChain %a[%0] : !spirv.ptr, Function>, i32
-  spirv.ReturnValue %1 : !spirv.ptr
-}
-
-// -----
-
-spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<16x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
-  // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix size must match}}
-  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<16x16xi32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
+// CHECK-LABEL: @udiv
+spirv.func @udiv(%a: !matA_i32, %b: !matB_i32) "None" {
+  // CHECK:      spirv.UDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.UDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.UDiv %a, %a : !matA_i32
+  %q = spirv.UDiv %b, %b : !matB_i32
   spirv.Return
 }
 
-// -----
-
-spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
-  // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix size must match}}
-  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
-  spirv.Return
-}
-
-// -----
-
-spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Workgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
-  // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix scope must match}}
-  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xi32, Subgroup...

@llvmbot
Copy link
Member

llvmbot commented Sep 12, 2023

@llvm/pr-subscribers-mlir

Changes
  • Check MakePointer* load/store attribute values.
  • Support coop matrix types in MatrixTimesScalar verification.
  • Add test cases for all the remaining ops that accept coop matrix types.
  • Split NV and KHR tests.
    --

Patch is 33.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/66137.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td (+2-3)
  • (modified) mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp (+29-10)
  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp (+12-11)
  • (renamed) mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir (+114-119)
  • (added) mlir/test/Dialect/SPIRV/IR/nv-cooperative-matrix-ops.mlir (+177)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
index a21fc0ce2f9299c..a055cadc756a7e6 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
@@ -75,10 +75,9 @@ def SPIRV_MatrixTimesScalarOp : SPIRV_Op<
   let summary = "Scale a floating-point matrix.";
 
   let description = [{
-    Result Type must be an OpTypeMatrix whose Column Type is a vector of
-    floating-point type.
+    Result Type must be a matrix type with a float component type.
 
-     The type of Matrix must be the same as Result Type. Each component in
+    The type of Matrix must be the same as Result Type. Each component in
     each column in Matrix is multiplied by Scalar.
 
     Scalar must have the same type as the Component Type in Result Type.
diff --git a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
index d43f7a1823e912b..c8b274ceec3e59d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
@@ -20,9 +20,6 @@
 using namespace mlir::spirv::AttrNames;
 
 namespace mlir::spirv {
-//===----------------------------------------------------------------------===//
-// spirv.KHR.CooperativeMatrixLoad
-//===----------------------------------------------------------------------===//
 
 static LogicalResult
 verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
@@ -35,13 +32,31 @@ verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
            << pointeeType;
   }
 
-  // The 'Aligned' memory operand requires an alignment literal to follow, which
-  // needs to be implemented on the level of op parsing and (de-)serialization.
-  // TODO: Consider adding support for this attribute value.
-  if (memoryOperand &&
-      spirv::bitEnumContainsAll(memoryOperand.getValue(),
-                                spirv::MemoryAccess::Aligned)) {
-    return op->emitOpError("has unhandled memory operand 'Aligned'");
+  if (memoryOperand) {
+    spirv::MemoryAccess operandSet = memoryOperand.getValue();
+
+    if (isa(op) &&
+        spirv::bitEnumContainsAll(operandSet,
+                                  spirv::MemoryAccess::MakePointerAvailable)) {
+      return op->emitOpError(
+          "not compatible with memory operand 'MakePointerAvailable'");
+    }
+
+    if (isa(op) &&
+        spirv::bitEnumContainsAll(operandSet,
+                                  spirv::MemoryAccess::MakePointerVisible)) {
+      return op->emitOpError(
+          "not compatible with memory operand 'MakePointerVisible'");
+    }
+
+    // The 'Aligned' memory operand requires an alignment literal to follow,
+    // which needs to be implemented on the level of op parsing and
+    // (de-)serialization.
+    // TODO: Consider adding support for this attribute value.
+    if (spirv::bitEnumContainsAll(memoryOperand.getValue(),
+                                  spirv::MemoryAccess::Aligned)) {
+      return op->emitOpError("has unhandled memory operand 'Aligned'");
+    }
   }
 
   // TODO: Verify the memory object behind the pointer:
@@ -51,6 +66,10 @@ verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.KHR.CooperativeMatrixLoad
+//===----------------------------------------------------------------------===//
+
 LogicalResult KHRCooperativeMatrixLoadOp::verify() {
   return verifyCoopMatrixAccess(*this, getPointer().getType(),
                                 getResult().getType(), getMemoryOperandAttr());
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 6ebd8515caf037d..6cd75ee6d9cba48 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -34,6 +34,7 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include 
 #include 
 #include 
@@ -1604,19 +1605,19 @@ LogicalResult spirv::VectorShuffleOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult spirv::MatrixTimesScalarOp::verify() {
-  if (auto inputCoopmat = llvm::dyn_cast(
-          getMatrix().getType())) {
-    if (inputCoopmat.getElementType() != getScalar().getType())
-      return emitError("input matrix components' type and scaling value must "
-                       "have the same type");
-    return success();
-  }
+  Type elementType =
+      llvm::TypeSwitch(getMatrix().getType())
+          .Case(
+              [](auto matrixType) { return matrixType.getElementType(); })
+          .Default([](Type) { return nullptr; });
+
+  assert(elementType && "Unhandld type");
 
   // Check that the scalar type is the same as the matrix element type.
-  auto inputMatrix = llvm::cast(getMatrix().getType());
-  if (getScalar().getType() != inputMatrix.getElementType())
-    return emitError("input matrix components' type and scaling value must "
-                     "have the same type");
+  if (getScalar().getType() != elementType)
+    return emitOpError("input matrix components' type and scaling value must "
+                       "have the same type");
 
   return success();
 }
diff --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
similarity index 68%
rename from mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
rename to mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
index 3adcd711f74a8f8..445ab8a48d3ce64 100644
--- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt --split-input-file --verify-diagnostics %s | FileCheck %s
 
 //===----------------------------------------------------------------------===//
-// CooperativeMatrix (KHR)
+// CooperativeMatrix (KHR) extension ops.
 //===----------------------------------------------------------------------===//
 
 // CHECK-LABEL: @cooperative_matrix_length
@@ -136,6 +136,15 @@ spirv.func @cooperative_matrix_load_bad_result(%ptr : !spirv.ptr, %stride : i32) "None" {
+  // expected-error @+1 {{op not compatible with memory operand 'MakePointerAvailable'}}
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, ,  :
+    !spirv.ptr, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
+  spirv.Return
+}
+
+// -----
+
 spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr, %stride : i32) "None" {
   // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
   %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, ,  :
@@ -184,6 +193,16 @@ spirv.func @cooperative_matrix_store_bad_object_type(%ptr : !spirv.ptr, %stride : i32,
+                                                 %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
+  // expected-error @+1 {{op not compatible with memory operand 'MakePointerVisible'}}
+  spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, ,  :
+    !spirv.ptr, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
+  spirv.Return
+}
+
+// -----
+
 spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr, %stride : i32,
                                      %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
   // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
@@ -406,177 +425,153 @@ spirv.func @cooperative_matrix_muladd_matrix_operands(%a : !spirv.coopmatrix<8x1
 // -----
 
 //===----------------------------------------------------------------------===//
-// NV.CooperativeMatrix
+// Standard ops that can be used CooperativeMatrix types
 //===----------------------------------------------------------------------===//
 
-// CHECK-LABEL: @cooperative_matrix_load
-spirv.func @cooperative_matrix_load(%ptr : !spirv.ptr, %stride : i32, %b : i1) "None" {
-  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr as !spirv.NV.coopmatrix<16x8xi32, Workgroup>
-  %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr as !spirv.NV.coopmatrix<16x8xi32, Workgroup>
-  spirv.Return
-}
+!matA_i32 = !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
+!matB_i32 = !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>
 
-// CHECK-LABEL: @cooperative_matrix_load_memaccess
-spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr, %stride : i32, %b : i1) "None" {
-  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  spirv.Return
-}
+!matA_f32 = !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>
+!matB_f32 = !spirv.coopmatrix<2x2xf32, Subgroup, MatrixB>
 
-// CHECK-LABEL: @cooperative_matrix_load_diff_ptr_type
-spirv.func @cooperative_matrix_load_diff_ptr_type(%ptr : !spirv.ptr, StorageBuffer>, %stride : i32, %b : i1) "None" {
-  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  spirv.Return
-}
+// These tests are kept in the same order as the list of compatible ops in the
+// SPV_KHR_cooperative_matrix extension spec.
 
-// CHECK-LABEL: @cooperative_matrix_store
-spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr, %stride : i32, %m : !spirv.NV.coopmatrix<8x16xi32, Workgroup>, %b : i1) "None" {
-  // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr, !spirv.NV.coopmatrix<8x16xi32, Workgroup>
-  spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b : !spirv.ptr, !spirv.NV.coopmatrix<8x16xi32, Workgroup>
+// CHECK-LABEL: @snegate
+spirv.func @snegate(%a: !matA_i32, %b: !matB_i32) "None" {
+  // CHECK:      spirv.SNegate {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.SNegate {{%.*}} : !spirv.coopmatrix
+  %p = spirv.SNegate %a : !matA_i32
+  %q = spirv.SNegate %b : !matB_i32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_store_memaccess
-spirv.func @cooperative_matrix_store_memaccess(%ptr : !spirv.ptr, %m : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" {
-  // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr, !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b ["Volatile"] : !spirv.ptr, !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+// CHECK-LABEL: @fnegate
+spirv.func @fnegate(%a: !matA_f32, %b: !matB_f32) "None" {
+  // CHECK:      spirv.FNegate {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.FNegate {{%.*}} : !spirv.coopmatrix
+  %p = spirv.FNegate %a : !matA_f32
+  %q = spirv.FNegate %b : !matB_f32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_length
-spirv.func @cooperative_matrix_length() -> i32 "None" {
-  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %0 = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  spirv.ReturnValue %0 : i32
-}
-
-// CHECK-LABEL: @cooperative_matrix_muladd
-spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, %b : !spirv.NV.coopmatrix<32x8xi8, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}}  : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, !spirv.NV.coopmatrix<32x8xi8, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
-  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, !spirv.NV.coopmatrix<32x8xi8, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
+// CHECK-LABEL: @iadd
+spirv.func @iadd(%a: !matA_i32, %b: !matB_i32) "None" {
+  // CHECK:      spirv.IAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.IAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.IAdd %a, %a : !matA_i32
+  %q = spirv.IAdd %b, %b : !matB_i32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_add
-spirv.func @cooperative_matrix_add(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.IAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %r = spirv.IAdd %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+// CHECK-LABEL: @fadd
+spirv.func @fadd(%a: !matA_f32, %b: !matB_f32) "None" {
+  // CHECK:      spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.FAdd %a, %a : !matA_f32
+  %q = spirv.FAdd %b, %b : !matB_f32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_sub
-spirv.func @cooperative_matrix_sub(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.ISub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %r = spirv.ISub %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+// CHECK-LABEL: @isub
+spirv.func @isub(%a: !matA_i32, %b: !matB_i32) "None" {
+  // CHECK:      spirv.ISub {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.ISub {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.ISub %a, %a : !matA_i32
+  %q = spirv.ISub %b, %b : !matB_i32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_sdiv
-spirv.func @cooperative_matrix_sdiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.SDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %r = spirv.SDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+// CHECK-LABEL: @fsub
+spirv.func @fsub(%a: !matA_f32, %b: !matB_f32) "None" {
+  // CHECK:      spirv.FSub {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.FSub {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.FSub %a, %a : !matA_f32
+  %q = spirv.FSub %b, %b : !matB_f32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_udiv
-spirv.func @cooperative_matrix_udiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.UDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %r = spirv.UDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+// CHECK-LABEL: @fmul
+spirv.func @fmul(%a: !matA_f32, %b: !matB_f32) "None" {
+  // CHECK:      spirv.FMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.FMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.FMul %a, %a : !matA_f32
+  %q = spirv.FMul %b, %b : !matB_f32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_fadd
-spirv.func @cooperative_matrix_fadd(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
-  %r = spirv.FAdd %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+// CHECK-LABEL: @imul
+spirv.func @imul(%a: !matA_i32, %b: !matB_i32) "None" {
+  // CHECK:      spirv.IMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.IMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.IMul %a, %a : !matA_i32
+  %q = spirv.IMul %b, %b : !matB_i32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_fsub
-spirv.func @cooperative_matrix_fsub(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.FSub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
-  %r = spirv.FSub %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+// CHECK-LABEL: @fdiv
+spirv.func @fdiv(%a: !matA_f32, %b: !matB_f32) "None" {
+  // CHECK:      spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.FDiv %a, %a : !matA_f32
+  %q = spirv.FDiv %b, %b : !matB_f32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_fdiv
-spirv.func @cooperative_matrix_fdiv(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
-  %r = spirv.FDiv %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+// CHECK-LABEL: @sdiv
+spirv.func @sdiv(%a: !matA_i32, %b: !matB_i32) "None" {
+  // CHECK:      spirv.SDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.SDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.SDiv %a, %a : !matA_i32
+  %q = spirv.SDiv %b, %b : !matB_i32
   spirv.Return
 }
 
-// -----
-
-// CHECK-LABEL: @cooperative_matrix_access_chain
-spirv.func @cooperative_matrix_access_chain(%a : !spirv.ptr, Function>) -> !spirv.ptr "None" {
-  %0 = spirv.Constant 0: i32
-  // CHECK: {{%.*}} = spirv.AccessChain {{%.*}}[{{%.*}}] : !spirv.ptr, Function>, i32
-  %1 = spirv.AccessChain %a[%0] : !spirv.ptr, Function>, i32
-  spirv.ReturnValue %1 : !spirv.ptr
-}
-
-// -----
-
-spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<16x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
-  // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix size must match}}
-  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<16x16xi32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
+// CHECK-LABEL: @udiv
+spirv.func @udiv(%a: !matA_i32, %b: !matB_i32) "None" {
+  // CHECK:      spirv.UDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.UDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.UDiv %a, %a : !matA_i32
+  %q = spirv.UDiv %b, %b : !matB_i32
   spirv.Return
 }
 
-// -----
-
-spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
-  // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix size must match}}
-  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
-  spirv.Return
-}
-
-// -----
-
-spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Workgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
-  // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix scope must match}}
-  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xi32, Subgroup...

@llvmbot
Copy link
Member

llvmbot commented Sep 12, 2023

@llvm/pr-subscribers-mlir-core

Changes
  • Check MakePointer* load/store attribute values.
  • Support coop matrix types in MatrixTimesScalar verification.
  • Add test cases for all the remaining ops that accept coop matrix types.
  • Split NV and KHR tests.
    --

Patch is 33.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/66137.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td (+2-3)
  • (modified) mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp (+29-10)
  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp (+12-11)
  • (renamed) mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir (+114-119)
  • (added) mlir/test/Dialect/SPIRV/IR/nv-cooperative-matrix-ops.mlir (+177)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
index a21fc0ce2f9299c..a055cadc756a7e6 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
@@ -75,10 +75,9 @@ def SPIRV_MatrixTimesScalarOp : SPIRV_Op<
   let summary = "Scale a floating-point matrix.";
 
   let description = [{
-    Result Type must be an OpTypeMatrix whose Column Type is a vector of
-    floating-point type.
+    Result Type must be a matrix type with a float component type.
 
-     The type of Matrix must be the same as Result Type. Each component in
+    The type of Matrix must be the same as Result Type. Each component in
     each column in Matrix is multiplied by Scalar.
 
     Scalar must have the same type as the Component Type in Result Type.
diff --git a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
index d43f7a1823e912b..c8b274ceec3e59d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
@@ -20,9 +20,6 @@
 using namespace mlir::spirv::AttrNames;
 
 namespace mlir::spirv {
-//===----------------------------------------------------------------------===//
-// spirv.KHR.CooperativeMatrixLoad
-//===----------------------------------------------------------------------===//
 
 static LogicalResult
 verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
@@ -35,13 +32,31 @@ verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
            << pointeeType;
   }
 
-  // The 'Aligned' memory operand requires an alignment literal to follow, which
-  // needs to be implemented on the level of op parsing and (de-)serialization.
-  // TODO: Consider adding support for this attribute value.
-  if (memoryOperand &&
-      spirv::bitEnumContainsAll(memoryOperand.getValue(),
-                                spirv::MemoryAccess::Aligned)) {
-    return op->emitOpError("has unhandled memory operand 'Aligned'");
+  if (memoryOperand) {
+    spirv::MemoryAccess operandSet = memoryOperand.getValue();
+
+    if (isa(op) &&
+        spirv::bitEnumContainsAll(operandSet,
+                                  spirv::MemoryAccess::MakePointerAvailable)) {
+      return op->emitOpError(
+          "not compatible with memory operand 'MakePointerAvailable'");
+    }
+
+    if (isa(op) &&
+        spirv::bitEnumContainsAll(operandSet,
+                                  spirv::MemoryAccess::MakePointerVisible)) {
+      return op->emitOpError(
+          "not compatible with memory operand 'MakePointerVisible'");
+    }
+
+    // The 'Aligned' memory operand requires an alignment literal to follow,
+    // which needs to be implemented on the level of op parsing and
+    // (de-)serialization.
+    // TODO: Consider adding support for this attribute value.
+    if (spirv::bitEnumContainsAll(memoryOperand.getValue(),
+                                  spirv::MemoryAccess::Aligned)) {
+      return op->emitOpError("has unhandled memory operand 'Aligned'");
+    }
   }
 
   // TODO: Verify the memory object behind the pointer:
@@ -51,6 +66,10 @@ verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.KHR.CooperativeMatrixLoad
+//===----------------------------------------------------------------------===//
+
 LogicalResult KHRCooperativeMatrixLoadOp::verify() {
   return verifyCoopMatrixAccess(*this, getPointer().getType(),
                                 getResult().getType(), getMemoryOperandAttr());
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 6ebd8515caf037d..6cd75ee6d9cba48 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -34,6 +34,7 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include 
 #include 
 #include 
@@ -1604,19 +1605,19 @@ LogicalResult spirv::VectorShuffleOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult spirv::MatrixTimesScalarOp::verify() {
-  if (auto inputCoopmat = llvm::dyn_cast(
-          getMatrix().getType())) {
-    if (inputCoopmat.getElementType() != getScalar().getType())
-      return emitError("input matrix components' type and scaling value must "
-                       "have the same type");
-    return success();
-  }
+  Type elementType =
+      llvm::TypeSwitch(getMatrix().getType())
+          .Case(
+              [](auto matrixType) { return matrixType.getElementType(); })
+          .Default([](Type) { return nullptr; });
+
+  assert(elementType && "Unhandld type");
 
   // Check that the scalar type is the same as the matrix element type.
-  auto inputMatrix = llvm::cast(getMatrix().getType());
-  if (getScalar().getType() != inputMatrix.getElementType())
-    return emitError("input matrix components' type and scaling value must "
-                     "have the same type");
+  if (getScalar().getType() != elementType)
+    return emitOpError("input matrix components' type and scaling value must "
+                       "have the same type");
 
   return success();
 }
diff --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
similarity index 68%
rename from mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
rename to mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
index 3adcd711f74a8f8..445ab8a48d3ce64 100644
--- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
@@ -1,7 +1,7 @@
 // RUN: mlir-opt --split-input-file --verify-diagnostics %s | FileCheck %s
 
 //===----------------------------------------------------------------------===//
-// CooperativeMatrix (KHR)
+// CooperativeMatrix (KHR) extension ops.
 //===----------------------------------------------------------------------===//
 
 // CHECK-LABEL: @cooperative_matrix_length
@@ -136,6 +136,15 @@ spirv.func @cooperative_matrix_load_bad_result(%ptr : !spirv.ptr, %stride : i32) "None" {
+  // expected-error @+1 {{op not compatible with memory operand 'MakePointerAvailable'}}
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, ,  :
+    !spirv.ptr, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
+  spirv.Return
+}
+
+// -----
+
 spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr, %stride : i32) "None" {
   // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
   %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, ,  :
@@ -184,6 +193,16 @@ spirv.func @cooperative_matrix_store_bad_object_type(%ptr : !spirv.ptr, %stride : i32,
+                                                 %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
+  // expected-error @+1 {{op not compatible with memory operand 'MakePointerVisible'}}
+  spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, ,  :
+    !spirv.ptr, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
+  spirv.Return
+}
+
+// -----
+
 spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr, %stride : i32,
                                      %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
   // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
@@ -406,177 +425,153 @@ spirv.func @cooperative_matrix_muladd_matrix_operands(%a : !spirv.coopmatrix<8x1
 // -----
 
 //===----------------------------------------------------------------------===//
-// NV.CooperativeMatrix
+// Standard ops that can be used CooperativeMatrix types
 //===----------------------------------------------------------------------===//
 
-// CHECK-LABEL: @cooperative_matrix_load
-spirv.func @cooperative_matrix_load(%ptr : !spirv.ptr, %stride : i32, %b : i1) "None" {
-  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr as !spirv.NV.coopmatrix<16x8xi32, Workgroup>
-  %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr as !spirv.NV.coopmatrix<16x8xi32, Workgroup>
-  spirv.Return
-}
+!matA_i32 = !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
+!matB_i32 = !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>
 
-// CHECK-LABEL: @cooperative_matrix_load_memaccess
-spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr, %stride : i32, %b : i1) "None" {
-  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  spirv.Return
-}
+!matA_f32 = !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>
+!matB_f32 = !spirv.coopmatrix<2x2xf32, Subgroup, MatrixB>
 
-// CHECK-LABEL: @cooperative_matrix_load_diff_ptr_type
-spirv.func @cooperative_matrix_load_diff_ptr_type(%ptr : !spirv.ptr, StorageBuffer>, %stride : i32, %b : i1) "None" {
-  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  spirv.Return
-}
+// These tests are kept in the same order as the list of compatible ops in the
+// SPV_KHR_cooperative_matrix extension spec.
 
-// CHECK-LABEL: @cooperative_matrix_store
-spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr, %stride : i32, %m : !spirv.NV.coopmatrix<8x16xi32, Workgroup>, %b : i1) "None" {
-  // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr, !spirv.NV.coopmatrix<8x16xi32, Workgroup>
-  spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b : !spirv.ptr, !spirv.NV.coopmatrix<8x16xi32, Workgroup>
+// CHECK-LABEL: @snegate
+spirv.func @snegate(%a: !matA_i32, %b: !matB_i32) "None" {
+  // CHECK:      spirv.SNegate {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.SNegate {{%.*}} : !spirv.coopmatrix
+  %p = spirv.SNegate %a : !matA_i32
+  %q = spirv.SNegate %b : !matB_i32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_store_memaccess
-spirv.func @cooperative_matrix_store_memaccess(%ptr : !spirv.ptr, %m : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" {
-  // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr, !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b ["Volatile"] : !spirv.ptr, !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+// CHECK-LABEL: @fnegate
+spirv.func @fnegate(%a: !matA_f32, %b: !matB_f32) "None" {
+  // CHECK:      spirv.FNegate {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.FNegate {{%.*}} : !spirv.coopmatrix
+  %p = spirv.FNegate %a : !matA_f32
+  %q = spirv.FNegate %b : !matB_f32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_length
-spirv.func @cooperative_matrix_length() -> i32 "None" {
-  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %0 = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  spirv.ReturnValue %0 : i32
-}
-
-// CHECK-LABEL: @cooperative_matrix_muladd
-spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, %b : !spirv.NV.coopmatrix<32x8xi8, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}}  : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, !spirv.NV.coopmatrix<32x8xi8, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
-  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, !spirv.NV.coopmatrix<32x8xi8, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
+// CHECK-LABEL: @iadd
+spirv.func @iadd(%a: !matA_i32, %b: !matB_i32) "None" {
+  // CHECK:      spirv.IAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.IAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.IAdd %a, %a : !matA_i32
+  %q = spirv.IAdd %b, %b : !matB_i32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_add
-spirv.func @cooperative_matrix_add(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.IAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %r = spirv.IAdd %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+// CHECK-LABEL: @fadd
+spirv.func @fadd(%a: !matA_f32, %b: !matB_f32) "None" {
+  // CHECK:      spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.FAdd %a, %a : !matA_f32
+  %q = spirv.FAdd %b, %b : !matB_f32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_sub
-spirv.func @cooperative_matrix_sub(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.ISub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %r = spirv.ISub %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+// CHECK-LABEL: @isub
+spirv.func @isub(%a: !matA_i32, %b: !matB_i32) "None" {
+  // CHECK:      spirv.ISub {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.ISub {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.ISub %a, %a : !matA_i32
+  %q = spirv.ISub %b, %b : !matB_i32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_sdiv
-spirv.func @cooperative_matrix_sdiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.SDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %r = spirv.SDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+// CHECK-LABEL: @fsub
+spirv.func @fsub(%a: !matA_f32, %b: !matB_f32) "None" {
+  // CHECK:      spirv.FSub {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.FSub {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.FSub %a, %a : !matA_f32
+  %q = spirv.FSub %b, %b : !matB_f32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_udiv
-spirv.func @cooperative_matrix_udiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.UDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
-  %r = spirv.UDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
+// CHECK-LABEL: @fmul
+spirv.func @fmul(%a: !matA_f32, %b: !matB_f32) "None" {
+  // CHECK:      spirv.FMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.FMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.FMul %a, %a : !matA_f32
+  %q = spirv.FMul %b, %b : !matB_f32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_fadd
-spirv.func @cooperative_matrix_fadd(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
-  %r = spirv.FAdd %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+// CHECK-LABEL: @imul
+spirv.func @imul(%a: !matA_i32, %b: !matB_i32) "None" {
+  // CHECK:      spirv.IMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.IMul {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.IMul %a, %a : !matA_i32
+  %q = spirv.IMul %b, %b : !matB_i32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_fsub
-spirv.func @cooperative_matrix_fsub(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.FSub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
-  %r = spirv.FSub %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+// CHECK-LABEL: @fdiv
+spirv.func @fdiv(%a: !matA_f32, %b: !matB_f32) "None" {
+  // CHECK:      spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.FDiv %a, %a : !matA_f32
+  %q = spirv.FDiv %b, %b : !matB_f32
   spirv.Return
 }
 
-// CHECK-LABEL: @cooperative_matrix_fdiv
-spirv.func @cooperative_matrix_fdiv(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" {
-  // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
-  %r = spirv.FDiv %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>
+// CHECK-LABEL: @sdiv
+spirv.func @sdiv(%a: !matA_i32, %b: !matB_i32) "None" {
+  // CHECK:      spirv.SDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.SDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.SDiv %a, %a : !matA_i32
+  %q = spirv.SDiv %b, %b : !matB_i32
   spirv.Return
 }
 
-// -----
-
-// CHECK-LABEL: @cooperative_matrix_access_chain
-spirv.func @cooperative_matrix_access_chain(%a : !spirv.ptr, Function>) -> !spirv.ptr "None" {
-  %0 = spirv.Constant 0: i32
-  // CHECK: {{%.*}} = spirv.AccessChain {{%.*}}[{{%.*}}] : !spirv.ptr, Function>, i32
-  %1 = spirv.AccessChain %a[%0] : !spirv.ptr, Function>, i32
-  spirv.ReturnValue %1 : !spirv.ptr
-}
-
-// -----
-
-spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<16x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
-  // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix size must match}}
-  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<16x16xi32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
+// CHECK-LABEL: @udiv
+spirv.func @udiv(%a: !matA_i32, %b: !matB_i32) "None" {
+  // CHECK:      spirv.UDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  // CHECK-NEXT: spirv.UDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix
+  %p = spirv.UDiv %a, %a : !matA_i32
+  %q = spirv.UDiv %b, %b : !matB_i32
   spirv.Return
 }
 
-// -----
-
-spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
-  // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix size must match}}
-  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup>
-  spirv.Return
-}
-
-// -----
-
-spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Workgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" {
-  // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix scope must match}}
-  %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xi32, Subgroup...

@kuhar kuhar merged commit 4c4bdf0 into llvm:main Sep 12, 2023
ZijunZhaoCCK pushed a commit to ZijunZhaoCCK/llvm-project that referenced this pull request Sep 19, 2023
…vm#66137)

- Check `MakePointer*` load/store attribute values.
- Support coop matrix types in `MatrixTimesScalar` verification.
- Add test cases for all the remaining ops that accept coop matrix types.
- Split NV and KHR tests.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:spirv mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants