Skip to content

Commit cc2d5fa

Browse files
authored
[mlir][spirv] Make CooperativeMatrixType a ShapedType (#142784)
This is to enable `CooperativeMatrixType` to be used with `DenseElementsAttr`, so that a `spirv.Constant` can be easily built from `OpConstantComposite`. For example: ```mlir %cst = spirv.Constant dense<0.000000e+00> : !spirv.coopmatrix<1x1xf32, Subgroup, MatrixAcc> ``` Constraints of arithmetic operations are changed, as `SameOperandsAndResultType` can no longer fully verify CoopMatrices. This is because for shaped types the verifier only checks element type and shapes, whereas for any other arbitrary type it looks for an exact match. This patch does not enable the actual deserialization. This will be done in a subsequent PR.
1 parent bc5d827 commit cc2d5fa

File tree

4 files changed

+51
-13
lines changed

4 files changed

+51
-13
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class SPIRV_ArithmeticBinaryOp<string mnemonic, Type type,
2323
// Operands type same as result type.
2424
SPIRV_BinaryOp<mnemonic, type, type,
2525
!listconcat(traits,
26-
[Pure, SameOperandsAndResultType])> {
26+
[Pure, AllTypesMatch<["operand1", "operand2", "result"]>])> {
2727
// In addition to normal types arithmetic instructions can support cooperative
2828
// matrix.
2929
let arguments = (ins
@@ -42,7 +42,7 @@ class SPIRV_ArithmeticUnaryOp<string mnemonic, Type type,
4242
// Operand type same as result type.
4343
SPIRV_UnaryOp<mnemonic, type, type,
4444
!listconcat(traits,
45-
[Pure, SameOperandsAndResultType])> {
45+
[Pure, AllTypesMatch<["operand", "result"]>])> {
4646
// In addition to normal types arithmetic instructions can support cooperative
4747
// matrix.
4848
let arguments = (ins

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,8 @@ hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo);
394394
// SPIR-V KHR cooperative matrix type
395395
class CooperativeMatrixType
396396
: public Type::TypeBase<CooperativeMatrixType, CompositeType,
397-
detail::CooperativeMatrixTypeStorage> {
397+
detail::CooperativeMatrixTypeStorage,
398+
ShapedType::Trait> {
398399
public:
399400
using Base::Base;
400401

@@ -418,6 +419,22 @@ class CooperativeMatrixType
418419
std::optional<StorageClass> storage = std::nullopt);
419420
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
420421
std::optional<StorageClass> storage = std::nullopt);
422+
423+
operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
424+
425+
ArrayRef<int64_t> getShape() const;
426+
427+
bool hasRank() const { return true; }
428+
429+
CooperativeMatrixType cloneWith(std::optional<ArrayRef<int64_t>> shape,
430+
Type elementType) const {
431+
if (!shape)
432+
return get(elementType, getRows(), getColumns(), getScope(), getUse());
433+
434+
assert(shape.value().size() == 2);
435+
return get(elementType, shape.value()[0], shape.value()[1], getScope(),
436+
getUse());
437+
}
421438
};
422439

423440
// SPIR-V matrix type

mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,21 @@ std::optional<int64_t> CompositeType::getSizeInBytes() {
194194
//===----------------------------------------------------------------------===//
195195

196196
struct spirv::detail::CooperativeMatrixTypeStorage final : TypeStorage {
197+
// In the specification dimensions of the Cooperative Matrix are 32-bit
198+
// integers --- the initial implementation kept those values as such. However,
199+
// the `ShapedType` expects the shape to be `int64_t`. We could keep the shape
200+
// as 32-bits and expose it as int64_t through `getShape`, however, this
201+
// method returns an `ArrayRef`, so returning `ArrayRef<int64_t>` having two
202+
// 32-bits integers would require an extra logic and storage. So, we diverge
203+
// from the spec and internally represent the dimensions as 64-bit integers,
204+
// so we can easily return an `ArrayRef` from `getShape` without any extra
205+
// logic. Alternatively, we could store both rows and columns (both 32-bits)
206+
// and shape (64-bits), assigning rows and columns to shape whenever
207+
// `getShape` is called. This would be at the cost of extra logic and storage.
208+
// Note: Because `ArrayRef` is returned we cannot construct an object in
209+
// `getShape` on the fly.
197210
using KeyTy =
198-
std::tuple<Type, uint32_t, uint32_t, Scope, CooperativeMatrixUseKHR>;
211+
std::tuple<Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR>;
199212

200213
static CooperativeMatrixTypeStorage *
201214
construct(TypeStorageAllocator &allocator, const KeyTy &key) {
@@ -204,17 +217,17 @@ struct spirv::detail::CooperativeMatrixTypeStorage final : TypeStorage {
204217
}
205218

206219
bool operator==(const KeyTy &key) const {
207-
return key == KeyTy(elementType, rows, columns, scope, use);
220+
return key == KeyTy(elementType, shape[0], shape[1], scope, use);
208221
}
209222

210223
CooperativeMatrixTypeStorage(const KeyTy &key)
211-
: elementType(std::get<0>(key)), rows(std::get<1>(key)),
212-
columns(std::get<2>(key)), scope(std::get<3>(key)),
224+
: elementType(std::get<0>(key)),
225+
shape({std::get<1>(key), std::get<2>(key)}), scope(std::get<3>(key)),
213226
use(std::get<4>(key)) {}
214227

215228
Type elementType;
216-
uint32_t rows;
217-
uint32_t columns;
229+
// [#rows, #columns]
230+
std::array<int64_t, 2> shape;
218231
Scope scope;
219232
CooperativeMatrixUseKHR use;
220233
};
@@ -231,10 +244,18 @@ Type CooperativeMatrixType::getElementType() const {
231244
return getImpl()->elementType;
232245
}
233246

234-
uint32_t CooperativeMatrixType::getRows() const { return getImpl()->rows; }
247+
uint32_t CooperativeMatrixType::getRows() const {
248+
assert(getImpl()->shape[0] != ShapedType::kDynamic);
249+
return static_cast<uint32_t>(getImpl()->shape[0]);
250+
}
235251

236252
uint32_t CooperativeMatrixType::getColumns() const {
237-
return getImpl()->columns;
253+
assert(getImpl()->shape[1] != ShapedType::kDynamic);
254+
return static_cast<uint32_t>(getImpl()->shape[1]);
255+
}
256+
257+
ArrayRef<int64_t> CooperativeMatrixType::getShape() const {
258+
return getImpl()->shape;
238259
}
239260

240261
Scope CooperativeMatrixType::getScope() const { return getImpl()->scope; }

mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -524,7 +524,7 @@ spirv.func @matrix_times_scalar(%a: !matA_f32, %b: f32) "None" {
524524

525525
spirv.func @iadd(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>,
526526
%b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>) "None" {
527-
// expected-error @+1 {{op requires the same type for all operands and results}}
527+
// expected-error @+1 {{failed to verify that all of {operand1, operand2, result} have same type}}
528528
%q = "spirv.IAdd"(%a, %b) :
529529
(!spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>)
530530
-> !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
@@ -535,7 +535,7 @@ spirv.func @iadd(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>,
535535

536536
spirv.func @fadd(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>,
537537
%b: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>) "None" {
538-
// expected-error @+1 {{op requires the same type for all operands and results}}
538+
// expected-error @+1 {{failed to verify that all of {operand1, operand2, result} have same type}}
539539
%q = "spirv.FAdd"(%a, %b) :
540540
(!spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>)
541541
-> !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>

0 commit comments

Comments
 (0)