Skip to content

[mlir][spirv] Improve coop matrix attribute handling #66020

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 12, 2023
Merged

Conversation

kuhar
Copy link
Member

@kuhar kuhar commented Sep 11, 2023

  • Fix values of Matrix Operand bit enums.
  • Add verification for the aligned Memory Operand attributes. Mark the 'Aligned' enumerant as not supported.

The target test passes validation with spirv-val.

@kuhar kuhar requested a review from a team as a code owner September 11, 2023 22:15
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir:spirv mlir labels Sep 11, 2023
@llvmbot
Copy link
Member

llvmbot commented Sep 11, 2023

@llvm/pr-subscribers-mlir-spirv

Changes
  • Fix values of Matrix Operand bit enums.
  • Add verification for the aligned Memory Operand attributes. Mark the 'Aligned' enumerant as not supported.

The target test passes validation with spirv-val.

Full diff: https://github.com/llvm/llvm-project/pull/66020.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td (+5-5)
  • (modified) mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp (+18-7)
  • (modified) mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir (+28)
  • (modified) mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir (+19-2)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 2ce3ad875fa45d1..1013cbc8ca562b7 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4077,11 +4077,11 @@ def SPIRV_KHR_CooperativeMatrixLayoutAttr :
 
 // Cooperative Matrix Operands for the SPV_KHR_cooperative_matrix extension.
 def SPIRV_KHR_CMO_None           : I32BitEnumAttrCaseNone<"None">;
-def SPIRV_KHR_CMO_MatrixA_Signed : I32BitEnumAttrCaseBit<"ASigned", 1>;
-def SPIRV_KHR_CMO_MatrixB_Signed : I32BitEnumAttrCaseBit<"BSigned", 2>;
-def SPIRV_KHR_CMO_MatrixC_Signed : I32BitEnumAttrCaseBit<"CSigned", 4>;
-def SPIRV_KHR_CMO_Result_Signed  : I32BitEnumAttrCaseBit<"ResultSigned", 8>;
-def SPIRV_KHR_CMO_AccSat         : I32BitEnumAttrCaseBit<"AccSat", 16>;
+def SPIRV_KHR_CMO_MatrixA_Signed : I32BitEnumAttrCaseBit<"ASigned", 0>;
+def SPIRV_KHR_CMO_MatrixB_Signed : I32BitEnumAttrCaseBit<"BSigned", 1>;
+def SPIRV_KHR_CMO_MatrixC_Signed : I32BitEnumAttrCaseBit<"CSigned", 2>;
+def SPIRV_KHR_CMO_Result_Signed  : I32BitEnumAttrCaseBit<"ResultSigned", 3>;
+def SPIRV_KHR_CMO_AccSat         : I32BitEnumAttrCaseBit<"AccSat", 4>;
 
 def SPIRV_KHR_CooperativeMatrixOperandsAttr :
     SPIRV_BitEnumAttr<"CooperativeMatrixOperandsKHR",
diff --git a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
index 77dbf130c777857..d43f7a1823e912b 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "SPIRVParsingUtils.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "llvm/ADT/STLExtras.h"
@@ -23,16 +24,26 @@ namespace mlir::spirv {
 // spirv.KHR.CooperativeMatrixLoad
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
-                                                    Type coopMatrix) {
+static LogicalResult
+verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
+                       spirv::MemoryAccessAttr memoryOperand) {
   auto pointerType = cast(pointer);
   Type pointeeType = pointerType.getPointeeType();
   if (!isa(pointeeType)) {
-    return op->emitError(
+    return op->emitOpError(
                "Pointer must point to a scalar or vector type but provided ")
            << pointeeType;
   }
 
+  // The 'Aligned' memory operand requires an alignment literal to follow, which
+  // needs to be implemented on the level of op parsing and (de-)serialization.
+  // TODO: Consider adding support for this attribute value.
+  if (memoryOperand &&
+      spirv::bitEnumContainsAll(memoryOperand.getValue(),
+                                spirv::MemoryAccess::Aligned)) {
+    return op->emitOpError("has unhandled memory operand 'Aligned'");
+  }
+
   // TODO: Verify the memory object behind the pointer:
   // > If the Shader capability was declared, Pointer must point into an array
   // > and any ArrayStride decoration on Pointer is ignored.
@@ -41,8 +52,8 @@ static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
 }
 
 LogicalResult KHRCooperativeMatrixLoadOp::verify() {
-  return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
-                                        getResult().getType());
+  return verifyCoopMatrixAccess(*this, getPointer().getType(),
+                                getResult().getType(), getMemoryOperandAttr());
 }
 
 //===----------------------------------------------------------------------===//
@@ -50,8 +61,8 @@ LogicalResult KHRCooperativeMatrixLoadOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult KHRCooperativeMatrixStoreOp::verify() {
-  return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
-                                        getObject().getType());
+  return verifyCoopMatrixAccess(*this, getPointer().getType(),
+                                getObject().getType(), getMemoryOperandAttr());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
index 40736367520e843..3adcd711f74a8f8 100644
--- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
@@ -136,6 +136,24 @@ spirv.func @cooperative_matrix_load_bad_result(%ptr : !spirv.ptr, %stride : i32) "None" {
+  // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, ,  :
+    !spirv.ptr, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr, %stride : i32) "None" {
+  // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, ,  :
+    !spirv.ptr, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
+  spirv.Return
+}
+
+// -----
+
 spirv.func @cooperative_matrix_store_missing_attr(%ptr : !spirv.ptr, %stride : i32,
                                                   %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
   // expected-error @+1 {{expected ','}}
@@ -166,6 +184,16 @@ spirv.func @cooperative_matrix_store_bad_object_type(%ptr : !spirv.ptr, %stride : i32,
+                                     %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
+  // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
+  spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, ,  :
+    !spirv.ptr, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
+  spirv.Return
+}
+
+// -----
+
 spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
                                       %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
                                       %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
diff --git a/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir b/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir
index 8546172f4f797b5..153ff4793797267 100644
--- a/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir
+++ b/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir
@@ -37,6 +37,10 @@ spirv.module Logical GLSL450 requires
     // CHECK-SAME:   : !spirv.ptr, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
     spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride,  :
       !spirv.ptr, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
+
+    // CHECK-NEXT:  spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, , 
+    spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, ,  :
+      !spirv.ptr, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
     spirv.Return
   }
 
@@ -62,16 +66,29 @@ spirv.module Logical GLSL450 requires
                                                         !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
                                                         -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
 
-    // CHECK-NEXT: {{%.+}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}},  :
+    // CHECK-NEXT: {{%.+}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}},  :
     // CHECK-SAME:   !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
     // CHECK-SAME:   !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
     // CHECK-SAME:   -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
     %q = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c,
+                                            : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
+                                                       !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
+                                                       -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
+
+    // CHECK-NEXT: {{%.+}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}},  :
+    %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c,
                                             : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
                                                        !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
                                                        -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
 
-    // TODO: Handle multiple matrix operands and add relevant testcases here.
+    // CHECK-NEXT: {{%.+}} = spirv.KHR.CooperativeMatrixMulAdd
+    // CHECK-SAME:            :
+    %s = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c,
+                                            :
+                                           !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
+                                           !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
+                                           -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
+
     spirv.Return
   }
 

@llvmbot
Copy link
Member

llvmbot commented Sep 11, 2023

@llvm/pr-subscribers-mlir-core

Changes
  • Fix values of Matrix Operand bit enums.
  • Add verification for the aligned Memory Operand attributes. Mark the 'Aligned' enumerant as not supported.

The target test passes validation with spirv-val.

Full diff: https://github.com/llvm/llvm-project/pull/66020.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td (+5-5)
  • (modified) mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp (+18-7)
  • (modified) mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir (+28)
  • (modified) mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir (+19-2)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 2ce3ad875fa45d1..1013cbc8ca562b7 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4077,11 +4077,11 @@ def SPIRV_KHR_CooperativeMatrixLayoutAttr :
 
 // Cooperative Matrix Operands for the SPV_KHR_cooperative_matrix extension.
 def SPIRV_KHR_CMO_None           : I32BitEnumAttrCaseNone<"None">;
-def SPIRV_KHR_CMO_MatrixA_Signed : I32BitEnumAttrCaseBit<"ASigned", 1>;
-def SPIRV_KHR_CMO_MatrixB_Signed : I32BitEnumAttrCaseBit<"BSigned", 2>;
-def SPIRV_KHR_CMO_MatrixC_Signed : I32BitEnumAttrCaseBit<"CSigned", 4>;
-def SPIRV_KHR_CMO_Result_Signed  : I32BitEnumAttrCaseBit<"ResultSigned", 8>;
-def SPIRV_KHR_CMO_AccSat         : I32BitEnumAttrCaseBit<"AccSat", 16>;
+def SPIRV_KHR_CMO_MatrixA_Signed : I32BitEnumAttrCaseBit<"ASigned", 0>;
+def SPIRV_KHR_CMO_MatrixB_Signed : I32BitEnumAttrCaseBit<"BSigned", 1>;
+def SPIRV_KHR_CMO_MatrixC_Signed : I32BitEnumAttrCaseBit<"CSigned", 2>;
+def SPIRV_KHR_CMO_Result_Signed  : I32BitEnumAttrCaseBit<"ResultSigned", 3>;
+def SPIRV_KHR_CMO_AccSat         : I32BitEnumAttrCaseBit<"AccSat", 4>;
 
 def SPIRV_KHR_CooperativeMatrixOperandsAttr :
     SPIRV_BitEnumAttr<"CooperativeMatrixOperandsKHR",
diff --git a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
index 77dbf130c777857..d43f7a1823e912b 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "SPIRVParsingUtils.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "llvm/ADT/STLExtras.h"
@@ -23,16 +24,26 @@ namespace mlir::spirv {
 // spirv.KHR.CooperativeMatrixLoad
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
-                                                    Type coopMatrix) {
+static LogicalResult
+verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
+                       spirv::MemoryAccessAttr memoryOperand) {
   auto pointerType = cast(pointer);
   Type pointeeType = pointerType.getPointeeType();
   if (!isa(pointeeType)) {
-    return op->emitError(
+    return op->emitOpError(
                "Pointer must point to a scalar or vector type but provided ")
            << pointeeType;
   }
 
+  // The 'Aligned' memory operand requires an alignment literal to follow, which
+  // needs to be implemented on the level of op parsing and (de-)serialization.
+  // TODO: Consider adding support for this attribute value.
+  if (memoryOperand &&
+      spirv::bitEnumContainsAll(memoryOperand.getValue(),
+                                spirv::MemoryAccess::Aligned)) {
+    return op->emitOpError("has unhandled memory operand 'Aligned'");
+  }
+
   // TODO: Verify the memory object behind the pointer:
   // > If the Shader capability was declared, Pointer must point into an array
   // > and any ArrayStride decoration on Pointer is ignored.
@@ -41,8 +52,8 @@ static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
 }
 
 LogicalResult KHRCooperativeMatrixLoadOp::verify() {
-  return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
-                                        getResult().getType());
+  return verifyCoopMatrixAccess(*this, getPointer().getType(),
+                                getResult().getType(), getMemoryOperandAttr());
 }
 
 //===----------------------------------------------------------------------===//
@@ -50,8 +61,8 @@ LogicalResult KHRCooperativeMatrixLoadOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult KHRCooperativeMatrixStoreOp::verify() {
-  return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
-                                        getObject().getType());
+  return verifyCoopMatrixAccess(*this, getPointer().getType(),
+                                getObject().getType(), getMemoryOperandAttr());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
index 40736367520e843..3adcd711f74a8f8 100644
--- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
@@ -136,6 +136,24 @@ spirv.func @cooperative_matrix_load_bad_result(%ptr : !spirv.ptr, %stride : i32) "None" {
+  // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, ,  :
+    !spirv.ptr, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr, %stride : i32) "None" {
+  // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, ,  :
+    !spirv.ptr, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
+  spirv.Return
+}
+
+// -----
+
 spirv.func @cooperative_matrix_store_missing_attr(%ptr : !spirv.ptr, %stride : i32,
                                                   %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
   // expected-error @+1 {{expected ','}}
@@ -166,6 +184,16 @@ spirv.func @cooperative_matrix_store_bad_object_type(%ptr : !spirv.ptr, %stride : i32,
+                                     %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
+  // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
+  spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, ,  :
+    !spirv.ptr, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
+  spirv.Return
+}
+
+// -----
+
 spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
                                       %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
                                       %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
diff --git a/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir b/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir
index 8546172f4f797b5..153ff4793797267 100644
--- a/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir
+++ b/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir
@@ -37,6 +37,10 @@ spirv.module Logical GLSL450 requires
     // CHECK-SAME:   : !spirv.ptr, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
     spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride,  :
       !spirv.ptr, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
+
+    // CHECK-NEXT:  spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, , 
+    spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, ,  :
+      !spirv.ptr, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
     spirv.Return
   }
 
@@ -62,16 +66,29 @@ spirv.module Logical GLSL450 requires
                                                         !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
                                                         -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
 
-    // CHECK-NEXT: {{%.+}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}},  :
+    // CHECK-NEXT: {{%.+}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}},  :
     // CHECK-SAME:   !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
     // CHECK-SAME:   !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
     // CHECK-SAME:   -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
     %q = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c,
+                                            : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
+                                                       !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
+                                                       -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
+
+    // CHECK-NEXT: {{%.+}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}},  :
+    %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c,
                                             : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
                                                        !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
                                                        -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
 
-    // TODO: Handle multiple matrix operands and add relevant testcases here.
+    // CHECK-NEXT: {{%.+}} = spirv.KHR.CooperativeMatrixMulAdd
+    // CHECK-SAME:            :
+    %s = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c,
+                                            :
+                                           !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
+                                           !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
+                                           -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
+
     spirv.Return
   }
 

@llvmbot
Copy link
Member

llvmbot commented Sep 11, 2023

@llvm/pr-subscribers-mlir

Changes
  • Fix values of Matrix Operand bit enums.
  • Add verification for the aligned Memory Operand attributes. Mark the 'Aligned' enumerant as not supported.

The target test passes validation with spirv-val.

Full diff: https://github.com/llvm/llvm-project/pull/66020.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td (+5-5)
  • (modified) mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp (+18-7)
  • (modified) mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir (+28)
  • (modified) mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir (+19-2)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 2ce3ad875fa45d1..1013cbc8ca562b7 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4077,11 +4077,11 @@ def SPIRV_KHR_CooperativeMatrixLayoutAttr :
 
 // Cooperative Matrix Operands for the SPV_KHR_cooperative_matrix extension.
 def SPIRV_KHR_CMO_None           : I32BitEnumAttrCaseNone<"None">;
-def SPIRV_KHR_CMO_MatrixA_Signed : I32BitEnumAttrCaseBit<"ASigned", 1>;
-def SPIRV_KHR_CMO_MatrixB_Signed : I32BitEnumAttrCaseBit<"BSigned", 2>;
-def SPIRV_KHR_CMO_MatrixC_Signed : I32BitEnumAttrCaseBit<"CSigned", 4>;
-def SPIRV_KHR_CMO_Result_Signed  : I32BitEnumAttrCaseBit<"ResultSigned", 8>;
-def SPIRV_KHR_CMO_AccSat         : I32BitEnumAttrCaseBit<"AccSat", 16>;
+def SPIRV_KHR_CMO_MatrixA_Signed : I32BitEnumAttrCaseBit<"ASigned", 0>;
+def SPIRV_KHR_CMO_MatrixB_Signed : I32BitEnumAttrCaseBit<"BSigned", 1>;
+def SPIRV_KHR_CMO_MatrixC_Signed : I32BitEnumAttrCaseBit<"CSigned", 2>;
+def SPIRV_KHR_CMO_Result_Signed  : I32BitEnumAttrCaseBit<"ResultSigned", 3>;
+def SPIRV_KHR_CMO_AccSat         : I32BitEnumAttrCaseBit<"AccSat", 4>;
 
 def SPIRV_KHR_CooperativeMatrixOperandsAttr :
     SPIRV_BitEnumAttr<"CooperativeMatrixOperandsKHR",
diff --git a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
index 77dbf130c777857..d43f7a1823e912b 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "SPIRVParsingUtils.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "llvm/ADT/STLExtras.h"
@@ -23,16 +24,26 @@ namespace mlir::spirv {
 // spirv.KHR.CooperativeMatrixLoad
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
-                                                    Type coopMatrix) {
+static LogicalResult
+verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
+                       spirv::MemoryAccessAttr memoryOperand) {
   auto pointerType = cast(pointer);
   Type pointeeType = pointerType.getPointeeType();
   if (!isa(pointeeType)) {
-    return op->emitError(
+    return op->emitOpError(
                "Pointer must point to a scalar or vector type but provided ")
            << pointeeType;
   }
 
+  // The 'Aligned' memory operand requires an alignment literal to follow, which
+  // needs to be implemented on the level of op parsing and (de-)serialization.
+  // TODO: Consider adding support for this attribute value.
+  if (memoryOperand &&
+      spirv::bitEnumContainsAll(memoryOperand.getValue(),
+                                spirv::MemoryAccess::Aligned)) {
+    return op->emitOpError("has unhandled memory operand 'Aligned'");
+  }
+
   // TODO: Verify the memory object behind the pointer:
   // > If the Shader capability was declared, Pointer must point into an array
   // > and any ArrayStride decoration on Pointer is ignored.
@@ -41,8 +52,8 @@ static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
 }
 
 LogicalResult KHRCooperativeMatrixLoadOp::verify() {
-  return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
-                                        getResult().getType());
+  return verifyCoopMatrixAccess(*this, getPointer().getType(),
+                                getResult().getType(), getMemoryOperandAttr());
 }
 
 //===----------------------------------------------------------------------===//
@@ -50,8 +61,8 @@ LogicalResult KHRCooperativeMatrixLoadOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult KHRCooperativeMatrixStoreOp::verify() {
-  return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
-                                        getObject().getType());
+  return verifyCoopMatrixAccess(*this, getPointer().getType(),
+                                getObject().getType(), getMemoryOperandAttr());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
index 40736367520e843..3adcd711f74a8f8 100644
--- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir
@@ -136,6 +136,24 @@ spirv.func @cooperative_matrix_load_bad_result(%ptr : !spirv.ptr, %stride : i32) "None" {
+  // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, ,  :
+    !spirv.ptr, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
+  spirv.Return
+}
+
+// -----
+
+spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr, %stride : i32) "None" {
+  // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
+  %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, ,  :
+    !spirv.ptr, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
+  spirv.Return
+}
+
+// -----
+
 spirv.func @cooperative_matrix_store_missing_attr(%ptr : !spirv.ptr, %stride : i32,
                                                   %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
   // expected-error @+1 {{expected ','}}
@@ -166,6 +184,16 @@ spirv.func @cooperative_matrix_store_bad_object_type(%ptr : !spirv.ptr, %stride : i32,
+                                     %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
+  // expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
+  spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, ,  :
+    !spirv.ptr, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
+  spirv.Return
+}
+
+// -----
+
 spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
                                       %b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
                                       %c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {
diff --git a/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir b/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir
index 8546172f4f797b5..153ff4793797267 100644
--- a/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir
+++ b/mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir
@@ -37,6 +37,10 @@ spirv.module Logical GLSL450 requires
     // CHECK-SAME:   : !spirv.ptr, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
     spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride,  :
       !spirv.ptr, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
+
+    // CHECK-NEXT:  spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, , 
+    spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, ,  :
+      !spirv.ptr, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
     spirv.Return
   }
 
@@ -62,16 +66,29 @@ spirv.module Logical GLSL450 requires
                                                         !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
                                                         -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
 
-    // CHECK-NEXT: {{%.+}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}},  :
+    // CHECK-NEXT: {{%.+}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}},  :
     // CHECK-SAME:   !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
     // CHECK-SAME:   !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
     // CHECK-SAME:   -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
     %q = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c,
+                                            : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
+                                                       !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
+                                                       -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
+
+    // CHECK-NEXT: {{%.+}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}},  :
+    %r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c,
                                             : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
                                                        !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
                                                        -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
 
-    // TODO: Handle multiple matrix operands and add relevant testcases here.
+    // CHECK-NEXT: {{%.+}} = spirv.KHR.CooperativeMatrixMulAdd
+    // CHECK-SAME:            :
+    %s = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c,
+                                            :
+                                           !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
+                                           !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
+                                           -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
+
     spirv.Return
   }
 

@kuhar
Copy link
Member Author

kuhar commented Sep 11, 2023

Seems like the failure is just buildkite having a bad day.

- Fix values of Matrix Operand bit enums
- Add verification for the aligned Memory Operand attributes. Mark the
  'Aligned' enumerant as not supported.

The target test passes vaidation with `spirv-val`.
@kuhar
Copy link
Member Author

kuhar commented Sep 12, 2023

(Rebased to retrigger CI)

@kuhar kuhar merged commit 08425de into llvm:main Sep 12, 2023
ZijunZhaoCCK pushed a commit to ZijunZhaoCCK/llvm-project that referenced this pull request Sep 19, 2023
- Fix values of Matrix Operand bit enums.
- Add verification for the aligned Memory Operand attributes. Mark the
'Aligned' enumerant as not supported.

The target test passes validation with `spirv-val`.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:spirv mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants