20
20
using namespace mlir ::spirv::AttrNames;
21
21
22
22
namespace mlir ::spirv {
23
- // ===----------------------------------------------------------------------===//
24
- // spirv.KHR.CooperativeMatrixLoad
25
- // ===----------------------------------------------------------------------===//
26
23
27
24
static LogicalResult
28
25
verifyCoopMatrixAccess (Operation *op, Type pointer, Type coopMatrix,
@@ -35,13 +32,31 @@ verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
35
32
<< pointeeType;
36
33
}
37
34
38
- // The 'Aligned' memory operand requires an alignment literal to follow, which
39
- // needs to be implemented on the level of op parsing and (de-)serialization.
40
- // TODO: Consider adding support for this attribute value.
41
- if (memoryOperand &&
42
- spirv::bitEnumContainsAll (memoryOperand.getValue (),
43
- spirv::MemoryAccess::Aligned)) {
44
- return op->emitOpError (" has unhandled memory operand 'Aligned'" );
35
+ if (memoryOperand) {
36
+ spirv::MemoryAccess operandSet = memoryOperand.getValue ();
37
+
38
+ if (isa<spirv::KHRCooperativeMatrixLoadOp>(op) &&
39
+ spirv::bitEnumContainsAll (operandSet,
40
+ spirv::MemoryAccess::MakePointerAvailable)) {
41
+ return op->emitOpError (
42
+ " not compatible with memory operand 'MakePointerAvailable'" );
43
+ }
44
+
45
+ if (isa<spirv::KHRCooperativeMatrixStoreOp>(op) &&
46
+ spirv::bitEnumContainsAll (operandSet,
47
+ spirv::MemoryAccess::MakePointerVisible)) {
48
+ return op->emitOpError (
49
+ " not compatible with memory operand 'MakePointerVisible'" );
50
+ }
51
+
52
+ // The 'Aligned' memory operand requires an alignment literal to follow,
53
+ // which needs to be implemented on the level of op parsing and
54
+ // (de-)serialization.
55
+ // TODO: Consider adding support for this attribute value.
56
+ if (spirv::bitEnumContainsAll (memoryOperand.getValue (),
57
+ spirv::MemoryAccess::Aligned)) {
58
+ return op->emitOpError (" has unhandled memory operand 'Aligned'" );
59
+ }
45
60
}
46
61
47
62
// TODO: Verify the memory object behind the pointer:
@@ -51,6 +66,10 @@ verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
51
66
return success ();
52
67
}
53
68
69
+ // ===----------------------------------------------------------------------===//
70
+ // spirv.KHR.CooperativeMatrixLoad
71
+ // ===----------------------------------------------------------------------===//
72
+
54
73
LogicalResult KHRCooperativeMatrixLoadOp::verify () {
55
74
return verifyCoopMatrixAccess (*this , getPointer ().getType (),
56
75
getResult ().getType (), getMemoryOperandAttr ());
0 commit comments