Skip to content

Commit a01097f

Browse files
authored
[mlir][spirv] Add definition for VectorTimesMatrixOp (#124571)
Adding op as defined in section 3.52.13. (Arithmetic Instructions) of the SPIR-V specification.
1 parent f226cab commit a01097f

File tree

5 files changed

+144
-16
lines changed

5 files changed

+144
-16
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4234,8 +4234,13 @@ class SPIRV_CoopMatrixOfType<list<Type> allowedTypes> :
42344234
"::llvm::cast<::mlir::spirv::CooperativeMatrixType>($_self).getElementType()",
42354235
"Cooperative Matrix">;
42364236

4237+
class SPIRV_MatrixOfType<list<Type> allowedTypes> :
4238+
ContainerType<AnyTypeOf<allowedTypes>, SPIRV_IsMatrixType,
4239+
"::llvm::cast<::mlir::spirv::MatrixType>($_self).getElementType()",
4240+
"Matrix">;
4241+
42374242
class SPIRV_VectorOf<Type type> :
4238-
VectorOfLengthAndType<[2, 3, 4, 8,16], [type]>;
4243+
VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>;
42394244

42404245
class SPIRV_ScalarOrVectorOf<Type type> :
42414246
AnyTypeOf<[type, SPIRV_VectorOf<type>]>;
@@ -4248,6 +4253,9 @@ class SPIRV_MatrixOrCoopMatrixOf<Type type> :
42484253
AnyTypeOf<[SPIRV_AnyMatrix,
42494254
SPIRV_CoopMatrixOfType<[type]>]>;
42504255

4256+
class SPIRV_MatrixOf<Type type> :
4257+
SPIRV_MatrixOfType<[type]>;
4258+
42514259
def SPIRV_ScalarOrVector : AnyTypeOf<[SPIRV_Scalar, SPIRV_Vector]>;
42524260
def SPIRV_ScalarOrVectorOrPtr : AnyTypeOf<[SPIRV_ScalarOrVector, SPIRV_AnyPtr]>;
42534261

@@ -4387,7 +4395,8 @@ def SPIRV_OC_OpFRem : I32EnumAttrCase<"OpFRem", 140>;
43874395
def SPIRV_OC_OpFMod : I32EnumAttrCase<"OpFMod", 141>;
43884396
def SPIRV_OC_OpVectorTimesScalar : I32EnumAttrCase<"OpVectorTimesScalar", 142>;
43894397
def SPIRV_OC_OpMatrixTimesScalar : I32EnumAttrCase<"OpMatrixTimesScalar", 143>;
4390-
def SPIRV_OC_OpMatrixTimesVector : I32EnumAttrCase<"OpMatrixTimesVector", 145>;
4398+
def SPIRV_OC_OpVectorTimesMatrix : I32EnumAttrCase<"OpVectorTimesMatrix", 144>;
4399+
def SPIRV_OC_OpMatrixTimesVector : I32EnumAttrCase<"OpMatrixTimesVector", 145>;
43914400
def SPIRV_OC_OpMatrixTimesMatrix : I32EnumAttrCase<"OpMatrixTimesMatrix", 146>;
43924401
def SPIRV_OC_OpDot : I32EnumAttrCase<"OpDot", 148>;
43934402
def SPIRV_OC_OpIAddCarry : I32EnumAttrCase<"OpIAddCarry", 149>;
@@ -4559,7 +4568,8 @@ def SPIRV_OpcodeAttr :
45594568
SPIRV_OC_OpFSub, SPIRV_OC_OpIMul, SPIRV_OC_OpFMul, SPIRV_OC_OpUDiv,
45604569
SPIRV_OC_OpSDiv, SPIRV_OC_OpFDiv, SPIRV_OC_OpUMod, SPIRV_OC_OpSRem,
45614570
SPIRV_OC_OpSMod, SPIRV_OC_OpFRem, SPIRV_OC_OpFMod,
4562-
SPIRV_OC_OpVectorTimesScalar, SPIRV_OC_OpMatrixTimesScalar, SPIRV_OC_OpMatrixTimesVector,
4571+
SPIRV_OC_OpVectorTimesScalar, SPIRV_OC_OpMatrixTimesScalar,
4572+
SPIRV_OC_OpVectorTimesMatrix, SPIRV_OC_OpMatrixTimesVector,
45634573
SPIRV_OC_OpMatrixTimesMatrix, SPIRV_OC_OpDot, SPIRV_OC_OpIAddCarry,
45644574
SPIRV_OC_OpISubBorrow, SPIRV_OC_OpUMulExtended, SPIRV_OC_OpSMulExtended,
45654575
SPIRV_OC_OpIsNan, SPIRV_OC_OpIsInf, SPIRV_OC_OpOrdered, SPIRV_OC_OpUnordered,

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td

Lines changed: 58 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ def SPIRV_MatrixTimesMatrixOp : SPIRV_Op<"MatrixTimesMatrix", [Pure]> {
6363

6464
// -----
6565

66-
def SPIRV_MatrixTimesScalarOp : SPIRV_Op<
67-
"MatrixTimesScalar", [Pure, AllTypesMatch<["matrix", "result"]>]> {
66+
def SPIRV_MatrixTimesScalarOp : SPIRV_Op<"MatrixTimesScalar", [Pure, AllTypesMatch<["matrix", "result"]>]> {
6867
let summary = "Scale a floating-point matrix.";
6968

7069
let description = [{
@@ -114,8 +113,11 @@ def SPIRV_MatrixTimesScalarOp : SPIRV_Op<
114113

115114
// -----
116115

117-
def SPIRV_MatrixTimesVectorOp : SPIRV_Op<"MatrixTimesVector", [Pure]> {
118-
let summary = "Linear-algebraic multiply of matrix X vector.";
116+
def SPIRV_MatrixTimesVectorOp : SPIRV_Op<"MatrixTimesVector", [
117+
Pure,
118+
AllElementTypesMatch<["vector", "result"]>
119+
]> {
120+
let summary = "Linear-algebraic Matrix X Vector.";
119121

120122
let description = [{
121123
Result Type must be a vector of floating-point type.
@@ -140,12 +142,12 @@ def SPIRV_MatrixTimesVectorOp : SPIRV_Op<"MatrixTimesVector", [Pure]> {
140142
];
141143

142144
let arguments = (ins
143-
SPIRV_AnyMatrix:$matrix,
144-
SPIRV_AnyVector:$vector
145+
SPIRV_MatrixOf<SPIRV_Float>:$matrix,
146+
SPIRV_VectorOf<SPIRV_Float>:$vector
145147
);
146148

147149
let results = (outs
148-
SPIRV_AnyVector:$result
150+
SPIRV_VectorOf<SPIRV_Float>:$result
149151
);
150152

151153
let assemblyFormat = [{
@@ -198,4 +200,53 @@ def SPIRV_TransposeOp : SPIRV_Op<"Transpose", [Pure]> {
198200

199201
// -----
200202

203+
def SPIRV_VectorTimesMatrixOp : SPIRV_Op<"VectorTimesMatrix", [
204+
Pure,
205+
AllElementTypesMatch<["vector", "result"]>
206+
]> {
207+
let summary = "Linear-algebraic Vector X Matrix.";
208+
209+
let description = [{
210+
Result Type must be a vector of floating-point type.
211+
212+
Vector must be a vector with the same Component Type as the Component
213+
Type in Result Type. Its number of components must equal the number of
214+
components in each column in Matrix.
215+
216+
Matrix must be a matrix with the same Component Type as the Component
217+
Type in Result Type. Its number of columns must equal the number of
218+
components in Result Type.
219+
220+
<!-- End of AutoGen section -->
221+
222+
#### Example:
223+
224+
```mlir
225+
%result = spirv.VectorTimesMatrix %vector, %matrix : vector<4xf32>, !spirv.matrix<4 x vector<4xf32>> -> vector<4xf32>
226+
```
227+
}];
228+
229+
let availability = [
230+
MinVersion<SPIRV_V_1_0>,
231+
MaxVersion<SPIRV_V_1_6>,
232+
Extension<[]>,
233+
Capability<[SPIRV_C_Matrix]>
234+
];
235+
236+
let arguments = (ins
237+
SPIRV_VectorOf<SPIRV_Float>:$vector,
238+
SPIRV_MatrixOf<SPIRV_Float>:$matrix
239+
);
240+
241+
let results = (outs
242+
SPIRV_VectorOf<SPIRV_Float>:$result
243+
);
244+
245+
let assemblyFormat = [{
246+
operands attr-dict `:` type($vector) `,` type($matrix) `->` type($result)
247+
}];
248+
}
249+
250+
// -----
251+
201252
#endif // MLIR_DIALECT_SPIRV_IR_MATRIX_OPS

mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1717,10 +1717,32 @@ LogicalResult spirv::MatrixTimesVectorOp::verify() {
17171717
<< resultType.getNumElements() << ") must match the matrix rows ("
17181718
<< matrixType.getNumRows() << ")";
17191719

1720-
auto matrixElementType = matrixType.getElementType();
1721-
if (matrixElementType != vectorType.getElementType() ||
1722-
matrixElementType != resultType.getElementType())
1723-
return emitOpError("matrix, vector, and result element types must match");
1720+
if (matrixType.getElementType() != resultType.getElementType())
1721+
return emitOpError("matrix and result element types must match");
1722+
1723+
return success();
1724+
}
1725+
1726+
//===----------------------------------------------------------------------===//
1727+
// spirv.VectorTimesMatrix
1728+
//===----------------------------------------------------------------------===//
1729+
1730+
LogicalResult spirv::VectorTimesMatrixOp::verify() {
1731+
auto vectorType = llvm::cast<VectorType>(getVector().getType());
1732+
auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().getType());
1733+
auto resultType = llvm::cast<VectorType>(getType());
1734+
1735+
if (matrixType.getNumRows() != vectorType.getNumElements())
1736+
return emitOpError("number of components in vector must equal the number "
1737+
"of components in each column in matrix");
1738+
1739+
if (resultType.getNumElements() != matrixType.getNumColumns())
1740+
return emitOpError("number of columns in matrix must equal the number of "
1741+
"components in result");
1742+
1743+
if (matrixType.getElementType() != resultType.getElementType())
1744+
return emitOpError("matrix must be a matrix with the same component type "
1745+
"as the component type in result");
17241746

17251747
return success();
17261748
}

mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
3636
spirv.ReturnValue %result : vector<4xf32>
3737
}
3838

39+
// CHECK-LABEL: @vector_times_matrix_1
40+
spirv.func @vector_times_matrix_1(%arg0: vector<3xf32>, %arg1: !spirv.matrix<4 x vector<3xf32>>) -> vector<4xf32> "None" {
41+
// CHECK: {{%.*}} = spirv.VectorTimesMatrix {{%.*}}, {{%.*}} : vector<3xf32>, !spirv.matrix<4 x vector<3xf32>> -> vector<4xf32>
42+
%result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<3xf32>, !spirv.matrix<4 x vector<3xf32>> -> vector<4xf32>
43+
spirv.ReturnValue %result : vector<4xf32>
44+
}
45+
3946
// CHECK-LABEL: @matrix_times_matrix_1
4047
spirv.func @matrix_times_matrix_1(%arg0: !spirv.matrix<3 x vector<3xf32>>, %arg1: !spirv.matrix<3 x vector<3xf32>>) -> !spirv.matrix<3 x vector<3xf32>> "None"{
4148
// CHECK: {{%.*}} = spirv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>>
@@ -123,7 +130,6 @@ func.func @matrix_times_matrix_component_type_mismatch_1(%arg0 : !spirv.matrix<3
123130
return
124131
}
125132

126-
127133
// -----
128134

129135
func.func @matrix_times_matrix_component_type_mismatch_2(%arg0 : !spirv.matrix<3 x vector<3xf64>>, %arg1 : !spirv.matrix<3x vector<3xf32>>){
@@ -135,7 +141,7 @@ func.func @matrix_times_matrix_component_type_mismatch_2(%arg0 : !spirv.matrix<3
135141
// -----
136142

137143
func.func @matrix_times_vector_element_type_mismatch(%arg0: !spirv.matrix<4 x vector<3xf32>>, %arg1: vector<4xf16>) {
138-
// expected-error @+1 {{matrix, vector, and result element types must match}}
144+
// expected-error @+1 {{op failed to verify that all of {vector, result} have same element type}}
139145
%result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<4 x vector<3xf32>>, vector<4xf16> -> vector<3xf32>
140146
return
141147
}
@@ -155,3 +161,35 @@ func.func @matrix_times_vector_column_mismatch(%arg0: !spirv.matrix<4 x vector<3
155161
%result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<4 x vector<3xf32>>, vector<3xf32> -> vector<3xf32>
156162
return
157163
}
164+
165+
// -----
166+
167+
func.func @vector_times_matrix_vector_matrix_mismatch(%arg0: vector<4xf32>, %arg1: !spirv.matrix<4 x vector<3xf32>>) {
168+
// expected-error @+1 {{number of components in vector must equal the number of components in each column in matrix}}
169+
%result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<4xf32>, !spirv.matrix<4 x vector<3xf32>> -> vector<3xf32>
170+
return
171+
}
172+
173+
// -----
174+
175+
func.func @vector_times_matrix_result_matrix_mismatch(%arg0: vector<3xf32>, %arg1: !spirv.matrix<4 x vector<3xf32>>) {
176+
// expected-error @+1 {{number of columns in matrix must equal the number of components in result}}
177+
%result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<3xf32>, !spirv.matrix<4 x vector<3xf32>> -> vector<3xf32>
178+
return
179+
}
180+
181+
// -----
182+
183+
func.func @vector_times_matrix_vector_type_mismatch(%arg0: vector<3xf16>, %arg1: !spirv.matrix<4 x vector<3xf32>>) {
184+
// expected-error @+1 {{op failed to verify that all of {vector, result} have same element type}}
185+
%result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<3xf16>, !spirv.matrix<4 x vector<3xf32>> -> vector<4xf32>
186+
return
187+
}
188+
189+
// -----
190+
191+
func.func @vector_times_matrix_matrix_type_mismatch(%arg0: vector<3xf32>, %arg1: !spirv.matrix<4 x vector<3xf16>>) {
192+
// expected-error @+1 {{matrix must be a matrix with the same component type as the component type in result}}
193+
%result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<3xf32>, !spirv.matrix<4 x vector<3xf16>> -> vector<4xf32>
194+
return
195+
}

mlir/test/Target/SPIRV/matrix.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
4242
%result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32>
4343
spirv.ReturnValue %result : vector<4xf32>
4444
}
45+
46+
// CHECK-LABEL: @vector_times_matrix_1
47+
spirv.func @vector_times_matrix_1(%arg0: vector<3xf32>, %arg1: !spirv.matrix<4 x vector<3xf32>>) -> vector<4xf32> "None" {
48+
// CHECK: {{%.*}} = spirv.VectorTimesMatrix {{%.*}}, {{%.*}} : vector<3xf32>, !spirv.matrix<4 x vector<3xf32>> -> vector<4xf32>
49+
%result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<3xf32>, !spirv.matrix<4 x vector<3xf32>> -> vector<4xf32>
50+
spirv.ReturnValue %result : vector<4xf32>
51+
}
4552

4653
// CHECK-LABEL: @matrix_times_matrix_1
4754
spirv.func @matrix_times_matrix_1(%arg0: !spirv.matrix<3 x vector<3xf32>>, %arg1: !spirv.matrix<3 x vector<3xf32>>) -> !spirv.matrix<3 x vector<3xf32>> "None"{

0 commit comments

Comments
 (0)