Skip to content

Commit e5a28a3

Browse files
authored
[mlir][spirv] Add MatrixTimesVector Op (#122302)
(From SPIRV reference here : https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpMatrixTimesVector)
1 parent 7b3a353 commit e5a28a3

File tree

5 files changed

+111
-1
lines changed

5 files changed

+111
-1
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4171,6 +4171,7 @@ def SPIRV_IsArrayType : CPred<"::llvm::isa<::mlir::spirv::ArrayType>($_self)">;
41714171
def SPIRV_IsCooperativeMatrixType :
41724172
CPred<"::llvm::isa<::mlir::spirv::CooperativeMatrixType>($_self)">;
41734173
def SPIRV_IsImageType : CPred<"::llvm::isa<::mlir::spirv::ImageType>($_self)">;
4174+
def SPIRV_IsVectorType : CPred<"::llvm::isa<::mlir::VectorType>($_self)">;
41744175
def SPIRV_IsMatrixType : CPred<"::llvm::isa<::mlir::spirv::MatrixType>($_self)">;
41754176
def SPIRV_IsPtrType : CPred<"::llvm::isa<::mlir::spirv::PointerType>($_self)">;
41764177
def SPIRV_IsRTArrayType : CPred<"::llvm::isa<::mlir::spirv::RuntimeArrayType>($_self)">;
@@ -4202,6 +4203,8 @@ def SPIRV_AnyCooperativeMatrix : DialectType<SPIRV_Dialect,
42024203
"any SPIR-V cooperative matrix type">;
42034204
def SPIRV_AnyImage : DialectType<SPIRV_Dialect, SPIRV_IsImageType,
42044205
"any SPIR-V image type">;
4206+
def SPIRV_AnyVector : DialectType<SPIRV_Dialect, SPIRV_IsVectorType,
4207+
"any SPIR-V vector type">;
42054208
def SPIRV_AnyMatrix : DialectType<SPIRV_Dialect, SPIRV_IsMatrixType,
42064209
"any SPIR-V matrix type">;
42074210
def SPIRV_AnyRTArray : DialectType<SPIRV_Dialect, SPIRV_IsRTArrayType,
@@ -4384,6 +4387,7 @@ def SPIRV_OC_OpFRem : I32EnumAttrCase<"OpFRem", 140>;
43844387
def SPIRV_OC_OpFMod : I32EnumAttrCase<"OpFMod", 141>;
43854388
def SPIRV_OC_OpVectorTimesScalar : I32EnumAttrCase<"OpVectorTimesScalar", 142>;
43864389
def SPIRV_OC_OpMatrixTimesScalar : I32EnumAttrCase<"OpMatrixTimesScalar", 143>;
4390+
def SPIRV_OC_OpMatrixTimesVector : I32EnumAttrCase<"OpMatrixTimesVector", 145>;
43874391
def SPIRV_OC_OpMatrixTimesMatrix : I32EnumAttrCase<"OpMatrixTimesMatrix", 146>;
43884392
def SPIRV_OC_OpDot : I32EnumAttrCase<"OpDot", 148>;
43894393
def SPIRV_OC_OpIAddCarry : I32EnumAttrCase<"OpIAddCarry", 149>;
@@ -4553,7 +4557,7 @@ def SPIRV_OpcodeAttr :
45534557
SPIRV_OC_OpFSub, SPIRV_OC_OpIMul, SPIRV_OC_OpFMul, SPIRV_OC_OpUDiv,
45544558
SPIRV_OC_OpSDiv, SPIRV_OC_OpFDiv, SPIRV_OC_OpUMod, SPIRV_OC_OpSRem,
45554559
SPIRV_OC_OpSMod, SPIRV_OC_OpFRem, SPIRV_OC_OpFMod,
4556-
SPIRV_OC_OpVectorTimesScalar, SPIRV_OC_OpMatrixTimesScalar,
4560+
SPIRV_OC_OpVectorTimesScalar, SPIRV_OC_OpMatrixTimesScalar, SPIRV_OC_OpMatrixTimesVector,
45574561
SPIRV_OC_OpMatrixTimesMatrix, SPIRV_OC_OpDot, SPIRV_OC_OpIAddCarry,
45584562
SPIRV_OC_OpISubBorrow, SPIRV_OC_OpUMulExtended, SPIRV_OC_OpSMulExtended,
45594563
SPIRV_OC_OpIsNan, SPIRV_OC_OpIsInf, SPIRV_OC_OpOrdered, SPIRV_OC_OpUnordered,

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,47 @@ def SPIRV_MatrixTimesScalarOp : SPIRV_Op<
114114

115115
// -----
116116

117+
def SPIRV_MatrixTimesVectorOp : SPIRV_Op<"MatrixTimesVector", [Pure]> {
118+
let summary = "Linear-algebraic multiply of matrix X vector.";
119+
120+
let description = [{
121+
Result Type must be a vector of floating-point type.
122+
123+
Matrix must be an OpTypeMatrix whose Column Type is Result Type.
124+
125+
Vector must be a vector with the same Component Type as the Component Type in Result Type. Its number of components must equal the number of columns in Matrix.
126+
127+
#### Example:
128+
129+
```mlir
130+
%0 = spirv.MatrixTimesVector %matrix, %vector :
131+
!spirv.matrix<3 x vector<2xf32>>, vector<3xf32> -> vector<2xf32>
132+
```
133+
}];
134+
135+
let availability = [
136+
MinVersion<SPIRV_V_1_0>,
137+
MaxVersion<SPIRV_V_1_6>,
138+
Extension<[]>,
139+
Capability<[SPIRV_C_Matrix]>
140+
];
141+
142+
let arguments = (ins
143+
SPIRV_AnyMatrix:$matrix,
144+
SPIRV_AnyVector:$vector
145+
);
146+
147+
let results = (outs
148+
SPIRV_AnyVector:$result
149+
);
150+
151+
let assemblyFormat = [{
152+
operands attr-dict `:` type($matrix) `,` type($vector) `->` type($result)
153+
}];
154+
}
155+
156+
// -----
157+
117158
def SPIRV_TransposeOp : SPIRV_Op<"Transpose", [Pure]> {
118159
let summary = "Transpose a matrix.";
119160

mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1698,6 +1698,33 @@ LogicalResult spirv::TransposeOp::verify() {
16981698
return success();
16991699
}
17001700

1701+
//===----------------------------------------------------------------------===//
1702+
// spirv.MatrixTimesVector
1703+
//===----------------------------------------------------------------------===//
1704+
1705+
LogicalResult spirv::MatrixTimesVectorOp::verify() {
1706+
auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().getType());
1707+
auto vectorType = llvm::cast<VectorType>(getVector().getType());
1708+
auto resultType = llvm::cast<VectorType>(getType());
1709+
1710+
if (matrixType.getNumColumns() != vectorType.getNumElements())
1711+
return emitOpError("matrix columns (")
1712+
<< matrixType.getNumColumns() << ") must match vector operand size ("
1713+
<< vectorType.getNumElements() << ")";
1714+
1715+
if (resultType.getNumElements() != matrixType.getNumRows())
1716+
return emitOpError("result size (")
1717+
<< resultType.getNumElements() << ") must match the matrix rows ("
1718+
<< matrixType.getNumRows() << ")";
1719+
1720+
auto matrixElementType = matrixType.getElementType();
1721+
if (matrixElementType != vectorType.getElementType() ||
1722+
matrixElementType != resultType.getElementType())
1723+
return emitOpError("matrix, vector, and result element types must match");
1724+
1725+
return success();
1726+
}
1727+
17011728
//===----------------------------------------------------------------------===//
17021729
// spirv.MatrixTimesMatrix
17031730
//===----------------------------------------------------------------------===//

mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
2929
spirv.ReturnValue %result : !spirv.matrix<3 x vector<3xf32>>
3030
}
3131

32+
// CHECK-LABEL: @matrix_times_vector_1
33+
spirv.func @matrix_times_vector_1(%arg0: !spirv.matrix<3 x vector<4xf32>>, %arg1: vector<3xf32>) -> vector<4xf32> "None" {
34+
// CHECK: {{%.*}} = spirv.MatrixTimesVector {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32>
35+
%result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32>
36+
spirv.ReturnValue %result : vector<4xf32>
37+
}
38+
3239
// CHECK-LABEL: @matrix_times_matrix_1
3340
spirv.func @matrix_times_matrix_1(%arg0: !spirv.matrix<3 x vector<3xf32>>, %arg1: !spirv.matrix<3 x vector<3xf32>>) -> !spirv.matrix<3 x vector<3xf32>> "None"{
3441
// CHECK: {{%.*}} = spirv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>>
@@ -124,3 +131,27 @@ func.func @matrix_times_matrix_component_type_mismatch_2(%arg0 : !spirv.matrix<3
124131
%result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<3xf64>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>>
125132
return
126133
}
134+
135+
// -----
136+
137+
func.func @matrix_times_vector_element_type_mismatch(%arg0: !spirv.matrix<4 x vector<3xf32>>, %arg1: vector<4xf16>) {
138+
// expected-error @+1 {{matrix, vector, and result element types must match}}
139+
%result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<4 x vector<3xf32>>, vector<4xf16> -> vector<3xf32>
140+
return
141+
}
142+
143+
// -----
144+
145+
func.func @matrix_times_vector_row_mismatch(%arg0: !spirv.matrix<4 x vector<3xf32>>, %arg1: vector<4xf32>) {
146+
// expected-error @+1 {{spirv.MatrixTimesVector' op result size (4) must match the matrix rows (3)}}
147+
%result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<4 x vector<3xf32>>, vector<4xf32> -> vector<4xf32>
148+
return
149+
}
150+
151+
// -----
152+
153+
func.func @matrix_times_vector_column_mismatch(%arg0: !spirv.matrix<4 x vector<3xf32>>, %arg1: vector<3xf32>) {
154+
// expected-error @+1 {{spirv.MatrixTimesVector' op matrix columns (4) must match vector operand size (3)}}
155+
%result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<4 x vector<3xf32>>, vector<3xf32> -> vector<3xf32>
156+
return
157+
}

mlir/test/Target/SPIRV/matrix.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
3636
spirv.ReturnValue %result : !spirv.matrix<2 x vector<3xf32>>
3737
}
3838

39+
// CHECK-LABEL: @matrix_times_vector_1
40+
spirv.func @matrix_times_vector_1(%arg0: !spirv.matrix<3 x vector<4xf32>>, %arg1: vector<3xf32>) -> vector<4xf32> "None" {
41+
// CHECK: {{%.*}} = spirv.MatrixTimesVector {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32>
42+
%result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32>
43+
spirv.ReturnValue %result : vector<4xf32>
44+
}
45+
3946
// CHECK-LABEL: @matrix_times_matrix_1
4047
spirv.func @matrix_times_matrix_1(%arg0: !spirv.matrix<3 x vector<3xf32>>, %arg1: !spirv.matrix<3 x vector<3xf32>>) -> !spirv.matrix<3 x vector<3xf32>> "None"{
4148
// CHECK: {{%.*}} = spirv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>>

0 commit comments

Comments
 (0)