Skip to content

Commit 3c278e5

Browse files
committed
[mlir][spirv] Fix spirv.MatrixTimesScalar for cooperative matrix
spirv.MatrixTimesScalar is allowed to use cooperative matrix. Reviewed By: kuhar Differential Revision: https://reviews.llvm.org/D139279
1 parent 8a7e69d commit 3c278e5

File tree

5 files changed

+54
-61
lines changed

5 files changed

+54
-61
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4119,6 +4119,9 @@ class SPIRV_ScalarOrVectorOrCoopMatrixOf<Type type> :
41194119
AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>,
41204120
SPIRV_CoopMatrixOfType<[type]>]>;
41214121

4122+
class SPIRV_MatrixOrCoopMatrixOf<Type type> :
4123+
AnyTypeOf<[SPIRV_AnyMatrix, SPIRV_CoopMatrixOfType<[type]>]>;
4124+
41224125
def SPIRV_ScalarOrVector : AnyTypeOf<[SPIRV_Scalar, SPIRV_Vector]>;
41234126
def SPIRV_ScalarOrVectorOrPtr : AnyTypeOf<[SPIRV_ScalarOrVector, SPIRV_AnyPtr]>;
41244127

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,8 @@ def SPIRV_MatrixTimesMatrixOp : SPIRV_Op<"MatrixTimesMatrix", [Pure]> {
7070

7171
// -----
7272

73-
def SPIRV_MatrixTimesScalarOp : SPIRV_Op<"MatrixTimesScalar", [Pure]> {
73+
def SPIRV_MatrixTimesScalarOp : SPIRV_Op<
74+
"MatrixTimesScalar", [Pure, AllTypesMatch<["matrix", "result"]>]> {
7475
let summary = "Scale a floating-point matrix.";
7576

7677
let description = [{
@@ -108,18 +109,16 @@ def SPIRV_MatrixTimesScalarOp : SPIRV_Op<"MatrixTimesScalar", [Pure]> {
108109
];
109110

110111
let arguments = (ins
111-
SPIRV_AnyMatrix:$matrix,
112+
SPIRV_MatrixOrCoopMatrixOf<SPIRV_Float>:$matrix,
112113
SPIRV_Float:$scalar
113114
);
114115

115116
let results = (outs
116-
SPIRV_AnyMatrix:$result
117+
SPIRV_MatrixOrCoopMatrixOf<SPIRV_Float>:$result
117118
);
118119

119-
// TODO: we need just one matrix type given that the input and result are the
120-
// same and the scalar's type can be deduced from it.
121120
let assemblyFormat = [{
122-
operands attr-dict `:` type($matrix) `,` type($scalar) `->` type($result)
121+
operands attr-dict `:` type($matrix) `,` type($scalar)
123122
}];
124123

125124
let availability = [

mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp

Lines changed: 8 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4128,35 +4128,20 @@ LogicalResult spirv::INTELJointMatrixMadOp::verify() {
41284128
//===----------------------------------------------------------------------===//
41294129

41304130
LogicalResult spirv::MatrixTimesScalarOp::verify() {
4131-
// We already checked that result and matrix are both of matrix type in the
4132-
// auto-generated verify method.
4133-
4134-
auto inputMatrix = getMatrix().getType().cast<spirv::MatrixType>();
4135-
auto resultMatrix = getResult().getType().cast<spirv::MatrixType>();
4131+
if (auto inputCoopmat =
4132+
getMatrix().getType().dyn_cast<spirv::CooperativeMatrixNVType>()) {
4133+
if (inputCoopmat.getElementType() != getScalar().getType())
4134+
return emitError("input matrix components' type and scaling value must "
4135+
"have the same type");
4136+
return success();
4137+
}
41364138

41374139
// Check that the scalar type is the same as the matrix element type.
4140+
auto inputMatrix = getMatrix().getType().cast<spirv::MatrixType>();
41384141
if (getScalar().getType() != inputMatrix.getElementType())
41394142
return emitError("input matrix components' type and scaling value must "
41404143
"have the same type");
41414144

4142-
// Note that the next three checks could be done using the AllTypesMatch
4143-
// trait in the Op definition file but it generates a vague error message.
4144-
4145-
// Check that the input and result matrices have the same columns' count
4146-
if (inputMatrix.getNumColumns() != resultMatrix.getNumColumns())
4147-
return emitError("input and result matrices must have the same "
4148-
"number of columns");
4149-
4150-
// Check that the input and result matrices' have the same rows count
4151-
if (inputMatrix.getNumRows() != resultMatrix.getNumRows())
4152-
return emitError("input and result matrices' columns must have "
4153-
"the same size");
4154-
4155-
// Check that the input and result matrices' have the same component type
4156-
if (inputMatrix.getElementType() != resultMatrix.getElementType())
4157-
return emitError("input and result matrices' columns must have "
4158-
"the same component type");
4159-
41604145
return success();
41614146
}
41624147

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,20 @@
11
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -verify-diagnostics %s | FileCheck %s
22

33
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
4-
// CHECK-LABEL: @matrix_times_scalar
5-
spirv.func @matrix_times_scalar(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> !spirv.matrix<3 x vector<3xf32>> "None" {
6-
// CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, f32 -> !spirv.matrix<3 x vector<3xf32>>
7-
%result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f32 -> !spirv.matrix<3 x vector<3xf32>>
4+
// CHECK-LABEL: @matrix_times_scalar_1
5+
spirv.func @matrix_times_scalar_1(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> !spirv.matrix<3 x vector<3xf32>> "None" {
6+
// CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, f32
7+
%result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f32
88
spirv.ReturnValue %result : !spirv.matrix<3 x vector<3xf32>>
99
}
1010

11+
// CHECK-LABEL: @matrix_times_scalar_2
12+
spirv.func @matrix_times_scalar_2(%arg0 : !spirv.coopmatrix<16x16xf16, Subgroup>, %arg1 : f16) -> !spirv.coopmatrix<16x16xf16, Subgroup> "None" {
13+
// CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>, f16
14+
%result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.coopmatrix<16x16xf16, Subgroup>, f16
15+
spirv.ReturnValue %result : !spirv.coopmatrix<16x16xf16, Subgroup>
16+
}
17+
1118
// CHECK-LABEL: @matrix_transpose_1
1219
spirv.func @matrix_transpose_1(%arg0 : !spirv.matrix<3 x vector<2xf32>>) -> !spirv.matrix<2 x vector<3xf32>> "None" {
1320
// CHECK: {{%.*}} = spirv.Transpose {{%.*}} : !spirv.matrix<3 x vector<2xf32>> -> !spirv.matrix<2 x vector<3xf32>>
@@ -39,82 +46,74 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
3946

4047
// -----
4148

42-
func.func @input_type_mismatch(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f16) -> () {
49+
func.func @input_type_mismatch(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f16) {
4350
// expected-error @+1 {{input matrix components' type and scaling value must have the same type}}
44-
%result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f16 -> !spirv.matrix<3 x vector<3xf32>>
51+
%result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f16
52+
return
4553
}
4654

4755
// -----
4856

49-
func.func @input_type_mismatch(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f64) -> () {
57+
func.func @input_type_mismatch(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f64) {
5058
// expected-error @+1 {{input matrix components' type and scaling value must have the same type}}
51-
%result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f64 -> !spirv.matrix<3 x vector<3xf32>>
52-
}
53-
54-
// -----
55-
56-
func.func @input_output_component_type_mismatch(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> () {
57-
// expected-error @+1 {{input and result matrices' columns must have the same component type}}
58-
%result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f32 -> !spirv.matrix<3 x vector<3xf64>>
59-
}
60-
61-
// -----
62-
63-
func.func @input_output_size_mismatch(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> () {
64-
// expected-error @+1 {{input and result matrices must have the same number of columns}}
65-
%result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f32 -> !spirv.matrix<4 x vector<3xf32>>
59+
%result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f64
60+
return
6661
}
6762

6863
// -----
6964

70-
func.func @transpose_op_shape_mismatch_1(%arg0 : !spirv.matrix<3 x vector<4xf32>>) -> () {
65+
func.func @transpose_op_shape_mismatch_1(%arg0 : !spirv.matrix<3 x vector<4xf32>>) {
7166
// expected-error @+1 {{input matrix rows count must be equal to output matrix columns count}}
7267
%result = spirv.Transpose %arg0 : !spirv.matrix<3 x vector<4xf32>> -> !spirv.matrix<3 x vector<3xf32>>
73-
spirv.Return
68+
return
7469
}
7570

7671
// -----
7772

78-
func.func @transpose_op_shape_mismatch_2(%arg0 : !spirv.matrix<3 x vector<4xf32>>) -> () {
73+
func.func @transpose_op_shape_mismatch_2(%arg0 : !spirv.matrix<3 x vector<4xf32>>) {
7974
// expected-error @+1 {{input matrix rows count must be equal to output matrix columns count}}
8075
%result = spirv.Transpose %arg0 : !spirv.matrix<3 x vector<4xf32>> -> !spirv.matrix<2 x vector<4xf32>>
81-
spirv.Return
76+
return
8277
}
8378

8479
// -----
8580

86-
func.func @transpose_op_type_mismatch(%arg0 : !spirv.matrix<3 x vector<4xf32>>) -> () {
81+
func.func @transpose_op_type_mismatch(%arg0 : !spirv.matrix<3 x vector<4xf32>>) {
8782
// expected-error @+1 {{input and output matrices must have the same component type}}
8883
%result = spirv.Transpose %arg0 : !spirv.matrix<3 x vector<4xf32>> -> !spirv.matrix<4 x vector<3xf16>>
89-
spirv.Return
84+
return
9085
}
9186

9287
// -----
9388

9489
func.func @matrix_times_matrix_invalid_input_shape_1(%arg0 : !spirv.matrix<3 x vector<2xf32>>, %arg1 : !spirv.matrix<2 x vector<3xf32>>){
9590
// expected-error @+1 {{right and result matrices must have equal columns' count}}
9691
%result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<2xf32>>, !spirv.matrix<2 x vector<3xf32>> -> !spirv.matrix<3 x vector<2xf32>>
92+
return
9793
}
9894

9995
// -----
10096

10197
func.func @matrix_times_matrix_invalid_input_shape_2(%arg0 : !spirv.matrix<3 x vector<2xf32>>, %arg1 : !spirv.matrix<2 x vector<3xf32>>){
10298
// expected-error @+1 {{left and result matrices must have equal rows' count}}
10399
%result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<2xf32>>, !spirv.matrix<2 x vector<3xf32>> -> !spirv.matrix<2 x vector<3xf32>>
100+
return
104101
}
105102

106103
// -----
107104

108105
func.func @matrix_times_matrix_inputs_shape_mismatch(%arg0 : !spirv.matrix<3 x vector<2xf32>>, %arg1 : !spirv.matrix<2 x vector<2xf32>>){
109106
// expected-error @+1 {{left matrix columns' count must be equal to the right matrix rows' count}}
110107
%result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<2xf32>>, !spirv.matrix<2 x vector<2xf32>> -> !spirv.matrix<2 x vector<2xf32>>
108+
return
111109
}
112110

113111
// -----
114112

115113
func.func @matrix_times_matrix_component_type_mismatch_1(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : !spirv.matrix<3x vector<3xf32>>){
116114
// expected-error @+1 {{right and result matrices' component type must be the same}}
117115
%result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf64>>
116+
return
118117
}
119118

120119

@@ -123,4 +122,5 @@ func.func @matrix_times_matrix_component_type_mismatch_1(%arg0 : !spirv.matrix<3
123122
func.func @matrix_times_matrix_component_type_mismatch_2(%arg0 : !spirv.matrix<3 x vector<3xf64>>, %arg1 : !spirv.matrix<3x vector<3xf32>>){
124123
// expected-error @+1 {{left and result matrices' component type must be the same}}
125124
%result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<3xf64>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>>
125+
return
126126
}

mlir/test/Target/SPIRV/matrix.mlir

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,23 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
1010

1111
// CHECK-LABEL: @matrix_times_scalar_1
1212
spirv.func @matrix_times_scalar_1(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> !spirv.matrix<3 x vector<3xf32>> "None" {
13-
// CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, f32 -> !spirv.matrix<3 x vector<3xf32>>
14-
%result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f32 -> !spirv.matrix<3 x vector<3xf32>>
13+
// CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, f32
14+
%result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f32
1515
spirv.ReturnValue %result : !spirv.matrix<3 x vector<3xf32>>
1616
}
1717

1818
// CHECK-LABEL: @matrix_times_scalar_2
1919
spirv.func @matrix_times_scalar_2(%arg0 : !spirv.matrix<3 x vector<3xf16>>, %arg1 : f16) -> !spirv.matrix<3 x vector<3xf16>> "None" {
20-
// CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf16>>, f16 -> !spirv.matrix<3 x vector<3xf16>>
21-
%result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf16>>, f16 -> !spirv.matrix<3 x vector<3xf16>>
20+
// CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf16>>, f16
21+
%result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf16>>, f16
2222
spirv.ReturnValue %result : !spirv.matrix<3 x vector<3xf16>>
23+
}
2324

25+
// CHECK-LABEL: @matrix_times_scalar_3
26+
spirv.func @matrix_times_scalar_3(%arg0 : !spirv.coopmatrix<16x16xf16, Subgroup>, %arg1 : f16) -> !spirv.coopmatrix<16x16xf16, Subgroup> "None" {
27+
// CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>, f16
28+
%result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.coopmatrix<16x16xf16, Subgroup>, f16
29+
spirv.ReturnValue %result : !spirv.coopmatrix<16x16xf16, Subgroup>
2430
}
2531

2632
// CHECK-LABEL: @matrix_transpose_1

0 commit comments

Comments
 (0)