Skip to content

[mlir][spirv] Add definition for VectorTimesMatrixOp #124571

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 29, 2025

Conversation

IgWod-IMG
Copy link
Contributor

Adding op as defined in section 3.52.13. (Arithmetic Instructions) of the SPIR-V specification.

@llvmbot
Copy link
Member

llvmbot commented Jan 27, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-spirv

Author: Igor Wodiany (IgWod-IMG)

Changes

Adding op as defined in section 3.52.13. (Arithmetic Instructions) of the SPIR-V specification.


Full diff: https://github.com/llvm/llvm-project/pull/124571.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td (+4-2)
  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td (+48-3)
  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp (+28)
  • (modified) mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir (+39-1)
  • (modified) mlir/test/Target/SPIRV/matrix.mlir (+7)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index c84677d26a8b69..2f50f9b6111822 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4387,7 +4387,8 @@ def SPIRV_OC_OpFRem                         : I32EnumAttrCase<"OpFRem", 140>;
 def SPIRV_OC_OpFMod                         : I32EnumAttrCase<"OpFMod", 141>;
 def SPIRV_OC_OpVectorTimesScalar            : I32EnumAttrCase<"OpVectorTimesScalar", 142>;
 def SPIRV_OC_OpMatrixTimesScalar            : I32EnumAttrCase<"OpMatrixTimesScalar", 143>;
-def SPIRV_OC_OpMatrixTimesVector           : I32EnumAttrCase<"OpMatrixTimesVector", 145>;
+def SPIRV_OC_OpVectorTimesMatrix            : I32EnumAttrCase<"OpVectorTimesMatrix", 144>;
+def SPIRV_OC_OpMatrixTimesVector            : I32EnumAttrCase<"OpMatrixTimesVector", 145>;
 def SPIRV_OC_OpMatrixTimesMatrix            : I32EnumAttrCase<"OpMatrixTimesMatrix", 146>;
 def SPIRV_OC_OpDot                          : I32EnumAttrCase<"OpDot", 148>;
 def SPIRV_OC_OpIAddCarry                    : I32EnumAttrCase<"OpIAddCarry", 149>;
@@ -4559,7 +4560,8 @@ def SPIRV_OpcodeAttr :
       SPIRV_OC_OpFSub, SPIRV_OC_OpIMul, SPIRV_OC_OpFMul, SPIRV_OC_OpUDiv,
       SPIRV_OC_OpSDiv, SPIRV_OC_OpFDiv, SPIRV_OC_OpUMod, SPIRV_OC_OpSRem,
       SPIRV_OC_OpSMod, SPIRV_OC_OpFRem, SPIRV_OC_OpFMod,
-      SPIRV_OC_OpVectorTimesScalar, SPIRV_OC_OpMatrixTimesScalar, SPIRV_OC_OpMatrixTimesVector,
+      SPIRV_OC_OpVectorTimesScalar, SPIRV_OC_OpMatrixTimesScalar,
+      SPIRV_OC_OpVectorTimesMatrix, SPIRV_OC_OpMatrixTimesVector,
       SPIRV_OC_OpMatrixTimesMatrix, SPIRV_OC_OpDot, SPIRV_OC_OpIAddCarry,
       SPIRV_OC_OpISubBorrow, SPIRV_OC_OpUMulExtended, SPIRV_OC_OpSMulExtended,
       SPIRV_OC_OpIsNan, SPIRV_OC_OpIsInf, SPIRV_OC_OpOrdered, SPIRV_OC_OpUnordered,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
index 5bd99386e00858..78b5fa2c228dc2 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td
@@ -63,8 +63,7 @@ def SPIRV_MatrixTimesMatrixOp : SPIRV_Op<"MatrixTimesMatrix", [Pure]> {
 
 // -----
 
-def SPIRV_MatrixTimesScalarOp : SPIRV_Op<
-    "MatrixTimesScalar", [Pure, AllTypesMatch<["matrix", "result"]>]> {
+def SPIRV_MatrixTimesScalarOp : SPIRV_Op<"MatrixTimesScalar", [Pure, AllTypesMatch<["matrix", "result"]>]> {
   let summary = "Scale a floating-point matrix.";
 
   let description = [{
@@ -115,7 +114,7 @@ def SPIRV_MatrixTimesScalarOp : SPIRV_Op<
 // -----
 
 def SPIRV_MatrixTimesVectorOp : SPIRV_Op<"MatrixTimesVector", [Pure]> {
-  let summary = "Linear-algebraic multiply of matrix X vector.";
+  let summary = "Linear-algebraic Matrix X Vector.";
 
   let description = [{
     Result Type must be a vector of floating-point type.
@@ -198,4 +197,50 @@ def SPIRV_TransposeOp : SPIRV_Op<"Transpose", [Pure]> {
 
 // -----
 
+def SPIRV_VectorTimesMatrixOp : SPIRV_Op<"VectorTimesMatrix", [Pure]> {
+  let summary = "Linear-algebraic Vector X Matrix.";
+
+  let description = [{
+    Result Type must be a vector of floating-point type.
+
+    Vector must be a vector with the same Component Type as the Component
+    Type in Result Type. Its number of components must equal the number of
+    components in each column in Matrix.
+
+    Matrix must be a matrix with the same Component Type as the Component
+    Type in Result Type. Its number of columns must equal the number of
+    components in Result Type.
+
+    <!-- End of AutoGen section -->
+
+    #### Example:
+
+    ```mlir
+    %result = spirv.VectorTimesMatrix %vector, %matrix : vector<4xf32>, !spirv.matrix<4 x vector<4xf32>> -> vector<4xf32>
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPIRV_V_1_0>,
+    MaxVersion<SPIRV_V_1_6>,
+    Extension<[]>,
+    Capability<[SPIRV_C_Matrix]>
+  ];
+
+  let arguments = (ins
+    SPIRV_AnyVector:$vector,
+    SPIRV_AnyMatrix:$matrix
+  );
+
+  let results = (outs
+    SPIRV_VectorOf<SPIRV_Float>:$result
+  );
+
+  let assemblyFormat = [{
+    operands attr-dict `:` type($vector) `,` type($matrix) `->` type($result)
+  }];
+}
+
+// -----
+
 #endif // MLIR_DIALECT_SPIRV_IR_MATRIX_OPS
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 040bf6a34cea78..2273e8073503d4 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1725,6 +1725,34 @@ LogicalResult spirv::MatrixTimesVectorOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.VectorTimesMatrix
+//===----------------------------------------------------------------------===//
+
+LogicalResult spirv::VectorTimesMatrixOp::verify() {
+  auto vectorType = llvm::cast<VectorType>(getVector().getType());
+  auto matrixType = llvm::cast<spirv::MatrixType>(getMatrix().getType());
+  auto resultType = llvm::cast<VectorType>(getType());
+
+  if (matrixType.getNumRows() != vectorType.getNumElements())
+    return emitOpError("number of components in vector must equal the number "
+                       "of components in each column in matrix");
+
+  if (resultType.getNumElements() != matrixType.getNumColumns())
+    return emitOpError("number of columns in matrix must equal the number of "
+                       "components in result");
+
+  if (resultType.getElementType() != vectorType.getElementType())
+    return emitOpError("vector must be a vector with the same component type "
+                       "as the component type in result");
+
+  if (matrixType.getElementType() != resultType.getElementType())
+    return emitOpError("matrix must be a matrix with the same component type "
+                       "as the component type in result");
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.MatrixTimesMatrix
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
index 37e7514d664ef0..79379b45805ac4 100644
--- a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir
@@ -36,6 +36,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     spirv.ReturnValue %result : vector<4xf32>
   }
 
+  // CHECK-LABEL: @vector_times_matrix_1
+  spirv.func @vector_times_matrix_1(%arg0: vector<3xf32>, %arg1: !spirv.matrix<4 x vector<3xf32>>) -> vector<4xf32> "None" {
+    // CHECK: {{%.*}} = spirv.VectorTimesMatrix {{%.*}}, {{%.*}} : vector<3xf32>, !spirv.matrix<4 x vector<3xf32>> -> vector<4xf32>
+    %result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<3xf32>, !spirv.matrix<4 x vector<3xf32>> -> vector<4xf32>
+    spirv.ReturnValue %result : vector<4xf32>
+  }
+
   // CHECK-LABEL: @matrix_times_matrix_1
   spirv.func @matrix_times_matrix_1(%arg0: !spirv.matrix<3 x vector<3xf32>>, %arg1: !spirv.matrix<3 x vector<3xf32>>) -> !spirv.matrix<3 x vector<3xf32>> "None"{
     // CHECK: {{%.*}} = spirv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>>
@@ -123,7 +130,6 @@ func.func @matrix_times_matrix_component_type_mismatch_1(%arg0 : !spirv.matrix<3
    return
 }
 
-
 // -----
 
 func.func @matrix_times_matrix_component_type_mismatch_2(%arg0 : !spirv.matrix<3 x vector<3xf64>>, %arg1 : !spirv.matrix<3x vector<3xf32>>){
@@ -155,3 +161,35 @@ func.func @matrix_times_vector_column_mismatch(%arg0: !spirv.matrix<4 x vector<3
   %result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<4 x vector<3xf32>>, vector<3xf32> -> vector<3xf32>
   return
 }
+
+// -----
+
+func.func @vector_times_matrix_vector_matrix_mismatch(%arg0: vector<4xf32>, %arg1: !spirv.matrix<4 x vector<3xf32>>) {
+  // expected-error @+1 {{number of components in vector must equal the number of components in each column in matrix}}
+  %result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<4xf32>, !spirv.matrix<4 x vector<3xf32>> -> vector<3xf32>
+  return
+}
+
+// -----
+
+func.func @vector_times_matrix_result_matrix_mismatch(%arg0: vector<3xf32>, %arg1: !spirv.matrix<4 x vector<3xf32>>) {
+  // expected-error @+1 {{number of columns in matrix must equal the number of components in result}}
+  %result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<3xf32>, !spirv.matrix<4 x vector<3xf32>> -> vector<3xf32>
+  return
+}
+
+// -----
+
+func.func @vector_times_matrix_vector_type_mismatch(%arg0: vector<3xi32>, %arg1: !spirv.matrix<4 x vector<3xf32>>) {
+  // expected-error @+1 {{vector must be a vector with the same component type as the component type in result}}
+  %result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<3xi32>, !spirv.matrix<4 x vector<3xf32>> -> vector<4xf32>
+  return
+}
+
+// -----
+
+func.func @vector_times_matrix_matrix_type_mismatch(%arg0: vector<3xf32>, %arg1: !spirv.matrix<4 x vector<3xf16>>) {
+  // expected-error @+1 {{matrix must be a matrix with the same component type as the component type in result}}
+  %result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<3xf32>, !spirv.matrix<4 x vector<3xf16>> -> vector<4xf32>
+  return
+}
diff --git a/mlir/test/Target/SPIRV/matrix.mlir b/mlir/test/Target/SPIRV/matrix.mlir
index 0ec1dc27e4e932..452f8fc16f2588 100644
--- a/mlir/test/Target/SPIRV/matrix.mlir
+++ b/mlir/test/Target/SPIRV/matrix.mlir
@@ -42,6 +42,13 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
     %result = spirv.MatrixTimesVector %arg0, %arg1 : !spirv.matrix<3 x vector<4xf32>>, vector<3xf32> -> vector<4xf32>
     spirv.ReturnValue %result : vector<4xf32>
   }
+
+  // CHECK-LABEL: @vector_times_matrix_1
+  spirv.func @vector_times_matrix_1(%arg0: vector<3xf32>, %arg1: !spirv.matrix<4 x vector<3xf32>>) -> vector<4xf32> "None" {
+    // CHECK: {{%.*}} = spirv.VectorTimesMatrix {{%.*}}, {{%.*}} : vector<3xf32>, !spirv.matrix<4 x vector<3xf32>> -> vector<4xf32>
+    %result = spirv.VectorTimesMatrix %arg0, %arg1 : vector<3xf32>, !spirv.matrix<4 x vector<3xf32>> -> vector<4xf32>
+    spirv.ReturnValue %result : vector<4xf32>
+  }
   
   // CHECK-LABEL: @matrix_times_matrix_1
   spirv.func @matrix_times_matrix_1(%arg0: !spirv.matrix<3 x vector<3xf32>>, %arg1: !spirv.matrix<3 x vector<3xf32>>) -> !spirv.matrix<3 x vector<3xf32>> "None"{

@IgWod-IMG
Copy link
Contributor Author

Nothing to add, however I will need someone to merge it, as I don't have committers rights.

@IgWod-IMG
Copy link
Contributor Author

I have pushed an updated patch. I made matrix SPIRV_MatrixOrCoopMatrixOf<SPIRV_Float>. Not sure the coop part is 100% correct, but it seems to be fairly consistent with other ops, hence the choice. Let me know if there is anything else that needs addressing.

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good now, but I'd like to keep matvec and vecmat in sync if we can. This is a little bit outside of the initial scope of this PR, so I'd be fine with either changing both in one go or leaving a TODO for matvec.

Adding op as defined in section 3.52.13. (Arithmetic Instructions)
of the SPIR-V specification.
@IgWod-IMG
Copy link
Contributor Author

Should be ready now. I removed CoopMatrix. I didn't realise MatrixTimesScalar was the only supported op for it. I had to add MatrixOf, as it seems it wasn't something that was available - hope I did it right! I also updated MatrixTimesVector, probably the code could be further cleaned up by extracting common code in verification, but I don't think it's worth the effort. I think we should eventually replace all verification with ODS here - I'll try to look into it when I have a bit more time.

Please merge if happy.

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@IgWod-IMG
Copy link
Contributor Author

@kuhar Could you please commit it, as I don't have rights to do it. I'll request committer rights next week, so I don't have to keep asking - I think I have strong enough case for it to be granted.

@kuhar kuhar merged commit a01097f into llvm:main Jan 29, 2025
8 checks passed
@IgWod-IMG IgWod-IMG deleted the vec-x-mat branch January 29, 2025 15:19
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants