Skip to content

[mlir][spirv] Add MatrixTimesVector Op #122302

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -4171,6 +4171,7 @@ def SPIRV_IsArrayType : CPred<"::llvm::isa<::mlir::spirv::ArrayType>($_self)">;
def SPIRV_IsCooperativeMatrixType :
CPred<"::llvm::isa<::mlir::spirv::CooperativeMatrixType>($_self)">;
def SPIRV_IsImageType : CPred<"::llvm::isa<::mlir::spirv::ImageType>($_self)">;
def SPIRV_IsVectorType : CPred<"::llvm::isa<::mlir::VectorType>($_self)">;
def SPIRV_IsMatrixType : CPred<"::llvm::isa<::mlir::spirv::MatrixType>($_self)">;
def SPIRV_IsPtrType : CPred<"::llvm::isa<::mlir::spirv::PointerType>($_self)">;
def SPIRV_IsRTArrayType : CPred<"::llvm::isa<::mlir::spirv::RuntimeArrayType>($_self)">;
Expand Down Expand Up @@ -4202,6 +4203,8 @@ def SPIRV_AnyCooperativeMatrix : DialectType<SPIRV_Dialect,
"any SPIR-V cooperative matrix type">;
def SPIRV_AnyImage : DialectType<SPIRV_Dialect, SPIRV_IsImageType,
"any SPIR-V image type">;
def SPIRV_AnyVector : DialectType<SPIRV_Dialect, SPIRV_IsVectorType,
"any SPIR-V vector type">;
def SPIRV_AnyMatrix : DialectType<SPIRV_Dialect, SPIRV_IsMatrixType,
"any SPIR-V matrix type">;
def SPIRV_AnyRTArray : DialectType<SPIRV_Dialect, SPIRV_IsRTArrayType,
Expand Down Expand Up @@ -4384,6 +4387,7 @@ def SPIRV_OC_OpFRem : I32EnumAttrCase<"OpFRem", 140>;
def SPIRV_OC_OpFMod : I32EnumAttrCase<"OpFMod", 141>;
def SPIRV_OC_OpVectorTimesScalar : I32EnumAttrCase<"OpVectorTimesScalar", 142>;
def SPIRV_OC_OpMatrixTimesScalar : I32EnumAttrCase<"OpMatrixTimesScalar", 143>;
def SPIRV_OC_OpMatrixTimesVector : I32EnumAttrCase<"OpMatrixTimesVector", 145>;
def SPIRV_OC_OpMatrixTimesMatrix : I32EnumAttrCase<"OpMatrixTimesMatrix", 146>;
def SPIRV_OC_OpDot : I32EnumAttrCase<"OpDot", 148>;
def SPIRV_OC_OpIAddCarry : I32EnumAttrCase<"OpIAddCarry", 149>;
Expand Down Expand Up @@ -4553,7 +4557,7 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpFSub, SPIRV_OC_OpIMul, SPIRV_OC_OpFMul, SPIRV_OC_OpUDiv,
SPIRV_OC_OpSDiv, SPIRV_OC_OpFDiv, SPIRV_OC_OpUMod, SPIRV_OC_OpSRem,
SPIRV_OC_OpSMod, SPIRV_OC_OpFRem, SPIRV_OC_OpFMod,
SPIRV_OC_OpVectorTimesScalar, SPIRV_OC_OpMatrixTimesScalar,
SPIRV_OC_OpVectorTimesScalar, SPIRV_OC_OpMatrixTimesScalar, SPIRV_OC_OpMatrixTimesVector,
SPIRV_OC_OpMatrixTimesMatrix, SPIRV_OC_OpDot, SPIRV_OC_OpIAddCarry,
SPIRV_OC_OpISubBorrow, SPIRV_OC_OpUMulExtended, SPIRV_OC_OpSMulExtended,
SPIRV_OC_OpIsNan, SPIRV_OC_OpIsInf, SPIRV_OC_OpOrdered, SPIRV_OC_OpUnordered,
Expand Down
41 changes: 41 additions & 0 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,47 @@ def SPIRV_MatrixTimesScalarOp : SPIRV_Op<

// -----

def SPIRV_MatrixTimesVectorOp : SPIRV_Op<"MatrixTimesVector", [Pure]> {
let summary = "Linear-algebraic multiply of matrix X vector.";

let description = [{
Result Type must be a vector of floating-point type.

Matrix must be an OpTypeMatrix whose Column Type is Result Type.

Vector must be a vector with the same Component Type as the Component Type in Result Type. Its number of components must equal the number of columns in Matrix.

#### Example:

```mlir
%0 = spirv.MatrixTimesVector %matrix, %vector :
!spirv.matrix<3 x vector<2xf32>>, vector<3xf32> -> vector<2xf32>
```
}];

let availability = [
MinVersion<SPIRV_V_1_0>,
MaxVersion<SPIRV_V_1_6>,
Extension<[]>,
Capability<[SPIRV_C_Matrix]>
];

let arguments = (ins
SPIRV_AnyMatrix:$matrix,
SPIRV_AnyVector:$vector
);

let results = (outs
SPIRV_AnyVector:$result
);

let assemblyFormat = [{
operands attr-dict `:` type($matrix) `,` type($vector) `->` type($result)
}];
}

// -----

def SPIRV_TransposeOp : SPIRV_Op<"Transpose", [Pure]> {
let summary = "Transpose a matrix.";

Expand Down
27 changes: 27 additions & 0 deletions mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1698,6 +1698,33 @@ LogicalResult spirv::TransposeOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// spirv.MatrixTimesVector
//===----------------------------------------------------------------------===//

LogicalResult spirv::MatrixTimesVectorOp::verify() {
auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().getType());
auto vectorType = llvm::cast<VectorType>(getVector().getType());
auto resultType = llvm::cast<VectorType>(getType());

if (matrixType.getNumColumns() != vectorType.getNumElements())
return emitOpError("matrix columns (")
<< matrixType.getNumColumns() << ") must match vector operand size ("
<< vectorType.getNumElements() << ")";

if (resultType.getNumElements() != matrixType.getNumRows())
return emitOpError("result size (")
<< resultType.getNumElements() << ") must match the matrix rows ("
<< matrixType.getNumRows() << ")";

auto matrixElementType = matrixType.getElementType();
if (matrixElementType != vectorType.getElementType() ||
matrixElementType != resultType.getElementType())
return emitOpError("matrix, vector, and result element types must match");

return success();
}

//===----------------------------------------------------------------------===//
// spirv.MatrixTimesMatrix
//===----------------------------------------------------------------------===//
Expand Down
31 changes: 31 additions & 0 deletions mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
spirv.ReturnValue %result : !spirv.matrix<3 x vector<3xf32>>
}

// CHECK-LABEL: @matrix_times_vector_1
spirv.func @matrix_times_vector_1(%arg0: !spirv.matrix<3 x vector<4xf32>>, %arg1: vector<3xf32>) -> vector<4xf32> "None" {
// CHECK: {{%.*}} = spirv.MatrixTimesVector {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32>
%result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32>
spirv.ReturnValue %result : vector<4xf32>
}

// CHECK-LABEL: @matrix_times_matrix_1
spirv.func @matrix_times_matrix_1(%arg0: !spirv.matrix<3 x vector<3xf32>>, %arg1: !spirv.matrix<3 x vector<3xf32>>) -> !spirv.matrix<3 x vector<3xf32>> "None"{
// CHECK: {{%.*}} = spirv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>>
Expand Down Expand Up @@ -124,3 +131,27 @@ func.func @matrix_times_matrix_component_type_mismatch_2(%arg0 : !spirv.matrix<3
%result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<3xf64>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>>
return
}

// -----

func.func @matrix_times_vector_element_type_mismatch(%arg0: !spirv.matrix<4 x vector<3xf32>>, %arg1: vector<4xf16>) {
// expected-error @+1 {{matrix, vector, and result element types must match}}
%result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<4 x vector<3xf32>>, vector<4xf16> -> vector<3xf32>
return
}

// -----

func.func @matrix_times_vector_row_mismatch(%arg0: !spirv.matrix<4 x vector<3xf32>>, %arg1: vector<4xf32>) {
// expected-error @+1 {{spirv.MatrixTimesVector' op result size (4) must match the matrix rows (3)}}
%result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<4 x vector<3xf32>>, vector<4xf32> -> vector<4xf32>
return
}

// -----

func.func @matrix_times_vector_column_mismatch(%arg0: !spirv.matrix<4 x vector<3xf32>>, %arg1: vector<3xf32>) {
// expected-error @+1 {{spirv.MatrixTimesVector' op matrix columns (4) must match vector operand size (3)}}
%result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<4 x vector<3xf32>>, vector<3xf32> -> vector<3xf32>
return
}
7 changes: 7 additions & 0 deletions mlir/test/Target/SPIRV/matrix.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
spirv.ReturnValue %result : !spirv.matrix<2 x vector<3xf32>>
}

// CHECK-LABEL: @matrix_times_vector_1
spirv.func @matrix_times_vector_1(%arg0: !spirv.matrix<3 x vector<4xf32>>, %arg1: vector<3xf32>) -> vector<4xf32> "None" {
// CHECK: {{%.*}} = spirv.MatrixTimesVector {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32>
%result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32>
spirv.ReturnValue %result : vector<4xf32>
}

// CHECK-LABEL: @matrix_times_matrix_1
spirv.func @matrix_times_matrix_1(%arg0: !spirv.matrix<3 x vector<3xf32>>, %arg1: !spirv.matrix<3 x vector<3xf32>>) -> !spirv.matrix<3 x vector<3xf32>> "None"{
// CHECK: {{%.*}} = spirv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>>
Expand Down
Loading