Skip to content

Commit 08425de

Browse files
authored
[mlir][spirv] Improve coop matrix attribute handling (#66020)
- Fix values of Matrix Operand bit enums. - Add verification for the aligned Memory Operand attributes. Mark the 'Aligned' enumerant as not supported. The target test passes validation with `spirv-val`.
1 parent c1796be commit 08425de

File tree

4 files changed

+70
-14
lines changed

4 files changed

+70
-14
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4077,11 +4077,11 @@ def SPIRV_KHR_CooperativeMatrixLayoutAttr :
40774077

40784078
// Cooperative Matrix Operands for the SPV_KHR_cooperative_matrix extension.
40794079
def SPIRV_KHR_CMO_None : I32BitEnumAttrCaseNone<"None">;
4080-
def SPIRV_KHR_CMO_MatrixA_Signed : I32BitEnumAttrCaseBit<"ASigned", 1>;
4081-
def SPIRV_KHR_CMO_MatrixB_Signed : I32BitEnumAttrCaseBit<"BSigned", 2>;
4082-
def SPIRV_KHR_CMO_MatrixC_Signed : I32BitEnumAttrCaseBit<"CSigned", 4>;
4083-
def SPIRV_KHR_CMO_Result_Signed : I32BitEnumAttrCaseBit<"ResultSigned", 8>;
4084-
def SPIRV_KHR_CMO_AccSat : I32BitEnumAttrCaseBit<"AccSat", 16>;
4080+
def SPIRV_KHR_CMO_MatrixA_Signed : I32BitEnumAttrCaseBit<"ASigned", 0>;
4081+
def SPIRV_KHR_CMO_MatrixB_Signed : I32BitEnumAttrCaseBit<"BSigned", 1>;
4082+
def SPIRV_KHR_CMO_MatrixC_Signed : I32BitEnumAttrCaseBit<"CSigned", 2>;
4083+
def SPIRV_KHR_CMO_Result_Signed : I32BitEnumAttrCaseBit<"ResultSigned", 3>;
4084+
def SPIRV_KHR_CMO_AccSat : I32BitEnumAttrCaseBit<"AccSat", 4>;
40854085

40864086
def SPIRV_KHR_CooperativeMatrixOperandsAttr :
40874087
SPIRV_BitEnumAttr<"CooperativeMatrixOperandsKHR",

mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "SPIRVParsingUtils.h"
14+
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
1415
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
1516
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
1617
#include "llvm/ADT/STLExtras.h"
@@ -23,16 +24,26 @@ namespace mlir::spirv {
2324
// spirv.KHR.CooperativeMatrixLoad
2425
//===----------------------------------------------------------------------===//
2526

26-
static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
27-
Type coopMatrix) {
27+
static LogicalResult
28+
verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
29+
spirv::MemoryAccessAttr memoryOperand) {
2830
auto pointerType = cast<PointerType>(pointer);
2931
Type pointeeType = pointerType.getPointeeType();
3032
if (!isa<ScalarType, VectorType>(pointeeType)) {
31-
return op->emitError(
33+
return op->emitOpError(
3234
"Pointer must point to a scalar or vector type but provided ")
3335
<< pointeeType;
3436
}
3537

38+
// The 'Aligned' memory operand requires an alignment literal to follow, which
39+
// needs to be implemented on the level of op parsing and (de-)serialization.
40+
// TODO: Consider adding support for this attribute value.
41+
if (memoryOperand &&
42+
spirv::bitEnumContainsAll(memoryOperand.getValue(),
43+
spirv::MemoryAccess::Aligned)) {
44+
return op->emitOpError("has unhandled memory operand 'Aligned'");
45+
}
46+
3647
// TODO: Verify the memory object behind the pointer:
3748
// > If the Shader capability was declared, Pointer must point into an array
3849
// > and any ArrayStride decoration on Pointer is ignored.
@@ -41,17 +52,17 @@ static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
4152
}
4253

4354
LogicalResult KHRCooperativeMatrixLoadOp::verify() {
44-
return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
45-
getResult().getType());
55+
return verifyCoopMatrixAccess(*this, getPointer().getType(),
56+
getResult().getType(), getMemoryOperandAttr());
4657
}
4758

4859
//===----------------------------------------------------------------------===//
4960
// spirv.KHR.CooperativeMatrixStore
5061
//===----------------------------------------------------------------------===//
5162

5263
LogicalResult KHRCooperativeMatrixStoreOp::verify() {
53-
return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
54-
getObject().getType());
64+
return verifyCoopMatrixAccess(*this, getPointer().getType(),
65+
getObject().getType(), getMemoryOperandAttr());
5566
}
5667

5768
//===----------------------------------------------------------------------===//

mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,24 @@ spirv.func @cooperative_matrix_load_bad_result(%ptr : !spirv.ptr<i32, StorageBuf
136136

137137
// -----
138138

139+
spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
140+
// expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
141+
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <Aligned> :
142+
!spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
143+
spirv.Return
144+
}
145+
146+
// -----
147+
148+
spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32) "None" {
149+
// expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
150+
%0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, <ColumnMajor>, <Volatile|Aligned> :
151+
!spirv.ptr<i32, StorageBuffer>, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>
152+
spirv.Return
153+
}
154+
155+
// -----
156+
139157
spirv.func @cooperative_matrix_store_missing_attr(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
140158
%m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
141159
// expected-error @+1 {{expected ','}}
@@ -166,6 +184,16 @@ spirv.func @cooperative_matrix_store_bad_object_type(%ptr : !spirv.ptr<i32, Stor
166184

167185
// -----
168186

187+
spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr<i32, StorageBuffer>, %stride : i32,
188+
%m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" {
189+
// expected-error @+1 {{op has unhandled memory operand 'Aligned'}}
190+
spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor>, <Aligned> :
191+
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32
192+
spirv.Return
193+
}
194+
195+
// -----
196+
169197
spirv.func @cooperative_matrix_muladd(%a : !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA>,
170198
%b : !spirv.coopmatrix<16x4xi32, Subgroup, MatrixB>,
171199
%c : !spirv.coopmatrix<8x4xi32, Subgroup, MatrixAcc>) "None" {

mlir/test/Target/SPIRV/khr-cooperative-matrix-ops.mlir

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ spirv.module Logical GLSL450 requires
3737
// CHECK-SAME: : !spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
3838
spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor> :
3939
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
40+
41+
// CHECK-NEXT: spirv.KHR.CooperativeMatrixStore {{%.*}}, {{%.*}}, <RowMajor>, <Volatile|Nontemporal>
42+
spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, <RowMajor>, <Volatile|Nontemporal> :
43+
!spirv.ptr<i32, StorageBuffer>, !spirv.coopmatrix<16x8xi32, Workgroup, MatrixA>, i32
4044
spirv.Return
4145
}
4246

@@ -62,16 +66,29 @@ spirv.module Logical GLSL450 requires
6266
!spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
6367
-> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
6468

65-
// CHECK-NEXT: {{%.+}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}}, <BSigned> :
69+
// CHECK-NEXT: {{%.+}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}}, <ASigned> :
6670
// CHECK-SAME: !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
6771
// CHECK-SAME: !spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
6872
// CHECK-SAME: -> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
6973
%q = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c,
74+
<ASigned> : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
75+
!spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
76+
-> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
77+
78+
// CHECK-NEXT: {{%.+}} = spirv.KHR.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}}, <BSigned> :
79+
%r = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c,
7080
<BSigned> : !spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
7181
!spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
7282
-> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
7383

74-
// TODO: Handle multiple matrix operands and add relevant testcases here.
84+
// CHECK-NEXT: {{%.+}} = spirv.KHR.CooperativeMatrixMulAdd
85+
// CHECK-SAME: <ASigned|BSigned|ResultSigned|AccSat> :
86+
%s = spirv.KHR.CooperativeMatrixMulAdd %a, %b, %c,
87+
<ASigned|BSigned|ResultSigned|AccSat> :
88+
!spirv.coopmatrix<8x16xi8, Subgroup, MatrixA>,
89+
!spirv.coopmatrix<16x8xi16, Subgroup, MatrixB>
90+
-> !spirv.coopmatrix<8x8xi32, Subgroup, MatrixAcc>
91+
7592
spirv.Return
7693
}
7794

0 commit comments

Comments
 (0)