Skip to content

[mlir][spirv] Support coop matrix in spirv.CompositeConstruct #66399

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 14, 2023

Conversation

kuhar
Copy link
Member

@kuhar kuhar commented Sep 14, 2023

Also improve the documentation (code and website).

@llvmbot
Copy link
Member

llvmbot commented Sep 14, 2023

@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-spirv

Changes Also improve the documentation (code and website). -- Full diff: https://github.com//pull/66399.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td (+9-1)
  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp (+20-16)
  • (modified) mlir/test/Dialect/SPIRV/IR/composite-ops.mlir (+28-6)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
index b8307b488af6fa5..8216814d9f99598 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td
@@ -53,7 +53,15 @@ def SPIRV_CompositeConstructOp : SPIRV_Op<"CompositeConstruct", [Pure]> {
     #### Example:
 
     ```mlir
-    %0 = spirv.CompositeConstruct %1, %2, %3 : vector<3xf32>
+    %a = spirv.CompositeConstruct %1, %2, %3 : vector<3xf32>
+    %b = spirv.CompositeConstruct %a, %1 : (vector<3xf32>, f32) -> vector<4xf32>
+
+    %c = spirv.CompositeConstruct %1 :
+      !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA>
+
+    %d = spirv.CompositeConstruct %a, %4, %5 :
+      (vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>) ->
+        !spirv.struct<(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>)>
     ```
   }];
 
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 1f07b0b9e85bff6..3906bf74ea72235 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -29,6 +29,7 @@
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/FunctionImplementation.h"
+#include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/ArrayRef.h"
@@ -363,31 +364,35 @@ LogicalResult spirv::AddressOfOp::verify() {
 //===----------------------------------------------------------------------===//
 
 LogicalResult spirv::CompositeConstructOp::verify() {
-  auto cType = llvm::cast<spirv::CompositeType>(getType());
   operand_range constituents = this->getConstituents();
 
-  if (auto coopType = llvm::dyn_cast<spirv::CooperativeMatrixNVType>(cType)) {
-    if (constituents.size() != 1)
-      return emitOpError("has incorrect number of operands: expected ")
-             << "1, but provided " << constituents.size();
-    if (coopType.getElementType() != constituents.front().getType())
-      return emitOpError("operand type mismatch: expected operand type ")
-             << coopType.getElementType() << ", but provided "
-             << constituents.front().getType();
-    return success();
-  }
+  // There are 4 cases with varying verification rules:
+  // 1. Cooperative Matrices (1 constituent)
+  // 2. Structs (1 constituent for each member)
+  // 3. Arrays (1 constituent for each array element)
+  // 4. Vectors (1 constituent (sub-)element for each vector element)
+
+  auto coopElementType =
+      llvm::TypeSwitch<Type, Type>(getType())
+          .Case<spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType,
+                spirv::JointMatrixINTELType>(
+              [](auto coopType) { return coopType.getElementType(); })
+          .Default([](Type) { return nullptr; });
 
-  if (auto jointType = llvm::dyn_cast<spirv::JointMatrixINTELType>(cType)) {
+  // Case 1. -- matrices.
+  if (coopElementType) {
     if (constituents.size() != 1)
       return emitOpError("has incorrect number of operands: expected ")
              << "1, but provided " << constituents.size();
-    if (jointType.getElementType() != constituents.front().getType())
+    if (coopElementType != constituents.front().getType())
       return emitOpError("operand type mismatch: expected operand type ")
-             << jointType.getElementType() << ", but provided "
+             << coopElementType << ", but provided "
              << constituents.front().getType();
     return success();
   }
 
+  // Case 2./3./4. -- number of constituents matches the number of elements.
+  auto cType = llvm::cast<spirv::CompositeType>(getType());
   if (constituents.size() == cType.getNumElements()) {
     for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
       if (constituents[index].getType() != cType.getElementType(index)) {
@@ -399,8 +404,7 @@ LogicalResult spirv::CompositeConstructOp::verify() {
     return success();
   }
 
-  // If not constructing a cooperative matrix type, then we must be constructing
-  // a vector type.
+  // Case 4. -- check that all constituents add up tp the expected vector type.
   auto resultType = llvm::dyn_cast<VectorType>(cType);
   if (!resultType)
     return emitOpError(
diff --git a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
index ce7f6bc6118b316..2891513961d5e2a 100644
--- a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir
@@ -4,22 +4,20 @@
 // spirv.CompositeConstruct
 //===----------------------------------------------------------------------===//
 
+// CHECK-LABEL: func @composite_construct_vector
 func.func @composite_construct_vector(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xf32> {
   // CHECK: spirv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : (f32, f32, f32) -> vector<3xf32>
   %0 = spirv.CompositeConstruct %arg0, %arg1, %arg2 : (f32, f32, f32) -> vector<3xf32>
   return %0: vector<3xf32>
 }
 
-// -----
-
+// CHECK-LABEL: func @composite_construct_struct
 func.func @composite_construct_struct(%arg0: vector<3xf32>, %arg1: !spirv.array<4xf32>, %arg2 : !spirv.struct<(f32)>) -> !spirv.struct<(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>)> {
   // CHECK: spirv.CompositeConstruct
   %0 = spirv.CompositeConstruct %arg0, %arg1, %arg2 : (vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>) -> !spirv.struct<(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>)>
   return %0: !spirv.struct<(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>)>
 }
 
-// -----
-
 // CHECK-LABEL: func @composite_construct_mixed_scalar_vector
 func.func @composite_construct_mixed_scalar_vector(%arg0: f32, %arg1: f32, %arg2 : vector<2xf32>) -> vector<4xf32> {
   // CHECK: spirv.CompositeConstruct %{{.+}}, %{{.+}}, %{{.+}} : (f32, vector<2xf32>, f32) -> vector<4xf32>
@@ -27,9 +25,15 @@ func.func @composite_construct_mixed_scalar_vector(%arg0: f32, %arg1: f32, %arg2
   return %0: vector<4xf32>
 }
 
-// -----
+// CHECK-LABEL: func @composite_construct_coopmatrix_khr
+func.func @composite_construct_coopmatrix_khr(%arg0 : f32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> {
+  // CHECK: spirv.CompositeConstruct {{%.*}} : (f32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>
+  %0 = spirv.CompositeConstruct %arg0 : (f32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>
+  return %0: !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>
+}
 
-func.func @composite_construct_NV.coopmatrix(%arg0 : f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> {
+// CHECK-LABEL: func @composite_construct_coopmatrix_nv
+func.func @composite_construct_coopmatrix_nv(%arg0 : f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> {
   // CHECK: spirv.CompositeConstruct {{%.*}} : (f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup>
   %0 = spirv.CompositeConstruct %arg0 : (f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup>
   return %0: !spirv.NV.coopmatrix<8x16xf32, Subgroup>
@@ -53,6 +57,24 @@ func.func @composite_construct_invalid_operand_type(%arg0: f32, %arg1: f32, %arg
 
 // -----
 
+func.func @composite_construct_khr_coopmatrix_incorrect_operand_count(%arg0 : f32, %arg1 : f32) ->
+  !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> {
+  // expected-error @+1 {{has incorrect number of operands: expected 1, but provided 2}}
+  %0 = spirv.CompositeConstruct %arg0, %arg1 : (f32, f32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>
+  return %0: !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA>
+}
+
+// -----
+
+func.func @composite_construct_khr_coopmatrix_incorrect_element_type(%arg0 : i32) ->
+  !spirv.coopmatrix<8x16xf32, Subgroup, MatrixB> {
+  // expected-error @+1 {{operand type mismatch: expected operand type 'f32', but provided 'i32'}}
+  %0 = spirv.CompositeConstruct %arg0 : (i32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixB>
+  return %0: !spirv.coopmatrix<8x16xf32, Subgroup, MatrixB>
+}
+
+// -----
+
 func.func @composite_construct_NV.coopmatrix_incorrect_operand_count(%arg0 : f32, %arg1 : f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> {
   // expected-error @+1 {{has incorrect number of operands: expected 1, but provided 2}}
   %0 = spirv.CompositeConstruct %arg0, %arg1 : (f32, f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup>

Also improve the documentation (code and website).
@kuhar kuhar force-pushed the coop-composite-construct branch from 0b8c7de to f379064 Compare September 14, 2023 16:46
Copy link
Member

@antiagainst antiagainst left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

@kuhar kuhar merged commit 12175bc into llvm:main Sep 14, 2023
ZijunZhaoCCK pushed a commit to ZijunZhaoCCK/llvm-project that referenced this pull request Sep 19, 2023
…#66399)

Also improve the documentation (code and website).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:spirv mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants