-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][spirv] Add definition for VectorTimesMatrixOp #124571
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-spirv Author: Igor Wodiany (IgWod-IMG) ChangesAdding op as defined in section 3.52.13. (Arithmetic Instructions) of the SPIR-V specification. Full diff: https://github.com/llvm/llvm-project/pull/124571.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index c84677d26a8b69..2f50f9b6111822 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4387,7 +4387,8 @@ def SPIRV_OC_OpFRem : I32EnumAttrCase<"OpFRem", 140>;
def SPIRV_OC_OpFMod : I32EnumAttrCase<"OpFMod", 141>;
def SPIRV_OC_OpVectorTimesScalar : I32EnumAttrCase<"OpVectorTimesScalar", 142>;
def SPIRV_OC_OpMatrixTimesScalar : I32EnumAttrCase<"OpMatrixTimesScalar", 143>;
-def SPIRV_OC_OpMatrixTimesVector : I32EnumAttrCase<"OpMatrixTimesVector", 145>;
+def SPIRV_OC_OpVectorTimesMatrix : I32EnumAttrCase<"OpVectorTimesMatrix", 144>;
+def SPIRV_OC_OpMatrixTimesVector : I32EnumAttrCase<"OpMatrixTimesVector", 145>;
def SPIRV_OC_OpMatrixTimesMatrix : I32EnumAttrCase<"OpMatrixTimesMatrix", 146>;
def SPIRV_OC_OpDot : I32EnumAttrCase<"OpDot", 148>;
def SPIRV_OC_OpIAddCarry : I32EnumAttrCase<"OpIAddCarry", 149>;
@@ -4559,7 +4560,8 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpFSub, SPIRV_OC_OpIMul, SPIRV_OC_OpFMul, SPIRV_OC_OpUDiv,
SPIRV_OC_OpSDiv, SPIRV_OC_OpFDiv, SPIRV_OC_OpUMod, SPIRV_OC_OpSRem,
SPIRV_OC_OpSMod, SPIRV_OC_OpFRem, SPIRV_OC_OpFMod,
- SPIRV_OC_OpVectorTimesScalar, SPIRV_OC_OpMatrixTimesScalar, SPIRV_OC_OpMatrixTimesVector,
+ SPIRV_OC_OpVectorTimesScalar, SPIRV_OC_OpMatrixTimesScalar,
+ SPIRV_OC_OpVectorTimesMatrix, SPIRV_OC_OpMatrixTimesVector,
SPIRV_OC_OpMatrixTimesMatrix, SPIRV_OC_OpDot, SPIRV_OC_OpIAddCarry,
SPIRV_OC_OpISubBorrow, SPIRV_OC_OpUMulExtended, SPIRV_OC_OpSMulExtended,
SPIRV_OC_OpIsNan, SPIRV_OC_OpIsInf, SPIRV_OC_OpOrdered, SPIRV_OC_OpUnordered,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
index 5bd99386e00858..78b5fa2c228dc2 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
@@ -63,8 +63,7 @@ def SPIRV_MatrixTimesMatrixOp : SPIRV_Op<"MatrixTimesMatrix", [Pure]> {
// -----
-def SPIRV_MatrixTimesScalarOp : SPIRV_Op<
- "MatrixTimesScalar", [Pure, AllTypesMatch<["matrix", "result"]>]> {
+def SPIRV_MatrixTimesScalarOp : SPIRV_Op<"MatrixTimesScalar", [Pure, AllTypesMatch<["matrix", "result"]>]> {
let summary = "Scale a floating-point matrix.";
let description = [{
@@ -115,7 +114,7 @@ def SPIRV_MatrixTimesScalarOp : SPIRV_Op<
// -----
def SPIRV_MatrixTimesVectorOp : SPIRV_Op<"MatrixTimesVector", [Pure]> {
- let summary = "Linear-algebraic multiply of matrix X vector.";
+ let summary = "Linear-algebraic Matrix X Vector.";
let description = [{
Result Type must be a vector of floating-point type.
@@ -198,4 +197,50 @@ def SPIRV_TransposeOp : SPIRV_Op<"Transpose", [Pure]> {
// -----
+def SPIRV_VectorTimesMatrixOp : SPIRV_Op<"VectorTimesMatrix", [Pure]> {
+ let summary = "Linear-algebraic Vector X Matrix.";
+
+ let description = [{
+ Result Type must be a vector of floating-point type.
+
+ Vector must be a vector with the same Component Type as the Component
+ Type in Result Type. Its number of components must equal the number of
+ components in each column in Matrix.
+
+ Matrix must be a matrix with the same Component Type as the Component
+ Type in Result Type. Its number of columns must equal the number of
+ components in Result Type.
+
+ <!-- End of AutoGen section -->
+
+ #### Example:
+
+ ```mlir
+ %result = spirv.VectorTimesMatrix %vector, %matrix : vector<4xf32>, !spirv.matrix<4 x vector<4xf32>> -> vector<4xf32>
+ ```
+ }];
+
+ let availability = [
+ MinVersion<SPIRV_V_1_0>,
+ MaxVersion<SPIRV_V_1_6>,
+ Extension<[]>,
+ Capability<[SPIRV_C_Matrix]>
+ ];
+
+ let arguments = (ins
+ SPIRV_AnyVector:$vector,
+ SPIRV_AnyMatrix:$matrix
+ );
+
+ let results = (outs
+ SPIRV_VectorOf<SPIRV_Float>:$result
+ );
+
+ let assemblyFormat = [{
+ operands attr-dict `:` type($vector) `,` type($matrix) `->` type($result)
+ }];
+}
+
+// -----
+
#endif // MLIR_DIALECT_SPIRV_IR_MATRIX_OPS
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 040bf6a34cea78..2273e8073503d4 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1725,6 +1725,34 @@ LogicalResult spirv::MatrixTimesVectorOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// spirv.VectorTimesMatrix
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::VectorTimesMatrixOp::verify() {
+ auto vectorType = llvm::cast<VectorType>(getVector().getType());
+ auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().getType());
+ auto resultType = llvm::cast<VectorType>(getType());
+
+ if (matrixType.getNumRows() != vectorType.getNumElements())
+ return emitOpError("number of components in vector must equal the number "
+ "of components in each column in matrix");
+
+ if (resultType.getNumElements() != matrixType.getNumColumns())
+ return emitOpError("number of columns in matrix must equal the number of "
+ "components in result");
+
+ if (resultType.getElementType() != vectorType.getElementType())
+ return emitOpError("vector must be a vector with the same component type "
+ "as the component type in result");
+
+ if (matrixType.getElementType() != resultType.getElementType())
+ return emitOpError("matrix must be a matrix with the same component type "
+ "as the component type in result");
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// spirv.MatrixTimesMatrix
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
index 37e7514d664ef0..79379b45805ac4 100644
--- a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
@@ -36,6 +36,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
spirv.ReturnValue %result : vector<4xf32>
}
+ // CHECK-LABEL: @vector_times_matrix_1
+ spirv.func @vector_times_matrix_1(%arg0: vector<3xf32>, %arg1: !spirv.matrix<4 x vector<3xf32>>) -> vector<4xf32> "None" {
+ // CHECK: {{%.*}} = spirv.VectorTimesMatrix {{%.*}}, {{%.*}} : vector<3xf32>, !spirv.matrix<4 x vector<3xf32>> -> vector<4xf32>
+ %result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<3xf32>, !spirv.matrix<4 x vector<3xf32>> -> vector<4xf32>
+ spirv.ReturnValue %result : vector<4xf32>
+ }
+
// CHECK-LABEL: @matrix_times_matrix_1
spirv.func @matrix_times_matrix_1(%arg0: !spirv.matrix<3 x vector<3xf32>>, %arg1: !spirv.matrix<3 x vector<3xf32>>) -> !spirv.matrix<3 x vector<3xf32>> "None"{
// CHECK: {{%.*}} = spirv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>>
@@ -123,7 +130,6 @@ func.func @matrix_times_matrix_component_type_mismatch_1(%arg0 : !spirv.matrix<3
return
}
-
// -----
func.func @matrix_times_matrix_component_type_mismatch_2(%arg0 : !spirv.matrix<3 x vector<3xf64>>, %arg1 : !spirv.matrix<3x vector<3xf32>>){
@@ -155,3 +161,35 @@ func.func @matrix_times_vector_column_mismatch(%arg0: !spirv.matrix<4 x vector<3
%result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<4 x vector<3xf32>>, vector<3xf32> -> vector<3xf32>
return
}
+
+// -----
+
+func.func @vector_times_matrix_vector_matrix_mismatch(%arg0: vector<4xf32>, %arg1: !spirv.matrix<4 x vector<3xf32>>) {
+ // expected-error @+1 {{number of components in vector must equal the number of components in each column in matrix}}
+ %result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<4xf32>, !spirv.matrix<4 x vector<3xf32>> -> vector<3xf32>
+ return
+}
+
+// -----
+
+func.func @vector_times_matrix_result_matrix_mismatch(%arg0: vector<3xf32>, %arg1: !spirv.matrix<4 x vector<3xf32>>) {
+ // expected-error @+1 {{number of columns in matrix must equal the number of components in result}}
+ %result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<3xf32>, !spirv.matrix<4 x vector<3xf32>> -> vector<3xf32>
+ return
+}
+
+// -----
+
+func.func @vector_times_matrix_vector_type_mismatch(%arg0: vector<3xi32>, %arg1: !spirv.matrix<4 x vector<3xf32>>) {
+ // expected-error @+1 {{vector must be a vector with the same component type as the component type in result}}
+ %result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<3xi32>, !spirv.matrix<4 x vector<3xf32>> -> vector<4xf32>
+ return
+}
+
+// -----
+
+func.func @vector_times_matrix_matrix_type_mismatch(%arg0: vector<3xf32>, %arg1: !spirv.matrix<4 x vector<3xf16>>) {
+ // expected-error @+1 {{matrix must be a matrix with the same component type as the component type in result}}
+ %result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<3xf32>, !spirv.matrix<4 x vector<3xf16>> -> vector<4xf32>
+ return
+}
diff --git a/mlir/test/Target/SPIRV/matrix.mlir b/mlir/test/Target/SPIRV/matrix.mlir
index 0ec1dc27e4e932..452f8fc16f2588 100644
--- a/mlir/test/Target/SPIRV/matrix.mlir
+++ b/mlir/test/Target/SPIRV/matrix.mlir
@@ -42,6 +42,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
%result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32>
spirv.ReturnValue %result : vector<4xf32>
}
+
+ // CHECK-LABEL: @vector_times_matrix_1
+ spirv.func @vector_times_matrix_1(%arg0: vector<3xf32>, %arg1: !spirv.matrix<4 x vector<3xf32>>) -> vector<4xf32> "None" {
+ // CHECK: {{%.*}} = spirv.VectorTimesMatrix {{%.*}}, {{%.*}} : vector<3xf32>, !spirv.matrix<4 x vector<3xf32>> -> vector<4xf32>
+ %result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<3xf32>, !spirv.matrix<4 x vector<3xf32>> -> vector<4xf32>
+ spirv.ReturnValue %result : vector<4xf32>
+ }
// CHECK-LABEL: @matrix_times_matrix_1
spirv.func @matrix_times_matrix_1(%arg0: !spirv.matrix<3 x vector<3xf32>>, %arg1: !spirv.matrix<3 x vector<3xf32>>) -> !spirv.matrix<3 x vector<3xf32>> "None"{
|
Nothing to add, however I will need someone to merge it, as I don't have committers rights. |
I have pushed an updated patch. I made |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good now, but I'd like to keep matvec and vecmat in sync if we can. This is a little bit outside of the initial scope of this PR, so I'd be fine with either changing both in one go or leaving a TODO for matvec.
Adding op as defined in section 3.52.13. (Arithmetic Instructions) of the SPIR-V specification.
Should be ready now. I removed Please merge if happy. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
@kuhar Could you please commit it, as I don't have rights to do it. I'll request committer rights next week, so I don't have to keep asking - I think I have strong enough case for it to be granted. |
Adding op as defined in section 3.52.13. (Arithmetic Instructions) of the SPIR-V specification.