Skip to content

Commit 4c4bdf0

Browse files
authored
[mlir][spirv] Fix remaining coop matrix verification corner cases (llvm#66137)
- Check `MakePointer*` load/store attribute values. - Support coop matrix types in `MatrixTimesScalar` verification. - Add test cases for all the remaining ops that accept coop matrix types. - Split NV and KHR tests.
1 parent a9c7ba9 commit 4c4bdf0

File tree

5 files changed

+334
-143
lines changed

5 files changed

+334
-143
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,9 @@ def SPIRV_MatrixTimesScalarOp : SPIRV_Op<
7575
let summary = "Scale a floating-point matrix.";
7676

7777
let description = [{
78-
Result Type must be an OpTypeMatrix whose Column Type is a vector of
79-
floating-point type.
78+
Result Type must be a matrix type with a float component type.
8079

81-
The type of Matrix must be the same as Result Type. Each component in
80+
The type of Matrix must be the same as Result Type. Each component in
8281
each column in Matrix is multiplied by Scalar.
8382

8483
Scalar must have the same type as the Component Type in Result Type.

mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,6 @@
2020
using namespace mlir::spirv::AttrNames;
2121

2222
namespace mlir::spirv {
23-
//===----------------------------------------------------------------------===//
24-
// spirv.KHR.CooperativeMatrixLoad
25-
//===----------------------------------------------------------------------===//
2623

2724
static LogicalResult
2825
verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
@@ -35,13 +32,31 @@ verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
3532
<< pointeeType;
3633
}
3734

38-
// The 'Aligned' memory operand requires an alignment literal to follow, which
39-
// needs to be implemented on the level of op parsing and (de-)serialization.
40-
// TODO: Consider adding support for this attribute value.
41-
if (memoryOperand &&
42-
spirv::bitEnumContainsAll(memoryOperand.getValue(),
43-
spirv::MemoryAccess::Aligned)) {
44-
return op->emitOpError("has unhandled memory operand 'Aligned'");
35+
if (memoryOperand) {
36+
spirv::MemoryAccess operandSet = memoryOperand.getValue();
37+
38+
if (isa<spirv::KHRCooperativeMatrixLoadOp>(op) &&
39+
spirv::bitEnumContainsAll(operandSet,
40+
spirv::MemoryAccess::MakePointerAvailable)) {
41+
return op->emitOpError(
42+
"not compatible with memory operand 'MakePointerAvailable'");
43+
}
44+
45+
if (isa<spirv::KHRCooperativeMatrixStoreOp>(op) &&
46+
spirv::bitEnumContainsAll(operandSet,
47+
spirv::MemoryAccess::MakePointerVisible)) {
48+
return op->emitOpError(
49+
"not compatible with memory operand 'MakePointerVisible'");
50+
}
51+
52+
// The 'Aligned' memory operand requires an alignment literal to follow,
53+
// which needs to be implemented on the level of op parsing and
54+
// (de-)serialization.
55+
// TODO: Consider adding support for this attribute value.
56+
if (spirv::bitEnumContainsAll(memoryOperand.getValue(),
57+
spirv::MemoryAccess::Aligned)) {
58+
return op->emitOpError("has unhandled memory operand 'Aligned'");
59+
}
4560
}
4661

4762
// TODO: Verify the memory object behind the pointer:
@@ -51,6 +66,10 @@ verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix,
5166
return success();
5267
}
5368

69+
//===----------------------------------------------------------------------===//
70+
// spirv.KHR.CooperativeMatrixLoad
71+
//===----------------------------------------------------------------------===//
72+
5473
LogicalResult KHRCooperativeMatrixLoadOp::verify() {
5574
return verifyCoopMatrixAccess(*this, getPointer().getType(),
5675
getResult().getType(), getMemoryOperandAttr());

mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "llvm/ADT/ArrayRef.h"
3535
#include "llvm/ADT/STLExtras.h"
3636
#include "llvm/ADT/StringExtras.h"
37+
#include "llvm/ADT/TypeSwitch.h"
3738
#include <cassert>
3839
#include <numeric>
3940
#include <optional>
@@ -1604,19 +1605,19 @@ LogicalResult spirv::VectorShuffleOp::verify() {
16041605
//===----------------------------------------------------------------------===//
16051606

16061607
LogicalResult spirv::MatrixTimesScalarOp::verify() {
1607-
if (auto inputCoopmat = llvm::dyn_cast<spirv::CooperativeMatrixNVType>(
1608-
getMatrix().getType())) {
1609-
if (inputCoopmat.getElementType() != getScalar().getType())
1610-
return emitError("input matrix components' type and scaling value must "
1611-
"have the same type");
1612-
return success();
1613-
}
1608+
Type elementType =
1609+
llvm::TypeSwitch<Type, Type>(getMatrix().getType())
1610+
.Case<spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType,
1611+
spirv::MatrixType>(
1612+
[](auto matrixType) { return matrixType.getElementType(); })
1613+
.Default([](Type) { return nullptr; });
1614+
1615+
assert(elementType && "Unhandled type");
16141616

16151617
// Check that the scalar type is the same as the matrix element type.
1616-
auto inputMatrix = llvm::cast<spirv::MatrixType>(getMatrix().getType());
1617-
if (getScalar().getType() != inputMatrix.getElementType())
1618-
return emitError("input matrix components' type and scaling value must "
1619-
"have the same type");
1618+
if (getScalar().getType() != elementType)
1619+
return emitOpError("input matrix components' type and scaling value must "
1620+
"have the same type");
16201621

16211622
return success();
16221623
}

0 commit comments

Comments
 (0)