-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][spirv] Support coop matrix in spirv.CompositeConstruct
#66399
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir-spirv ChangesAlso improve the documentation (code and website). -- Full diff: https://github.com//pull/66399.diff3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td index b8307b488af6fa5..8216814d9f99598 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td @@ -53,7 +53,15 @@ def SPIRV_CompositeConstructOp : SPIRV_Op<"CompositeConstruct", [Pure]> { #### Example: ```mlir - %0 = spirv.CompositeConstruct %1, %2, %3 : vector<3xf32> + %a = spirv.CompositeConstruct %1, %2, %3 : vector<3xf32> + %b = spirv.CompositeConstruct %a, %1 : (vector<3xf32>, f32) -> vector<4xf32> + + %c = spirv.CompositeConstruct %1 : + !spirv.coopmatrix<4x4xf32, Subgroup, MatrixA> + + %d = spirv.CompositeConstruct %a, %4, %5 : + (vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>) -> + !spirv.struct<(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>)> ``` }]; diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index 1f07b0b9e85bff6..3906bf74ea72235 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -29,6 +29,7 @@ #include "mlir/IR/Operation.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/FunctionImplementation.h" +#include "mlir/Support/LogicalResult.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" @@ -363,31 +364,35 @@ LogicalResult spirv::AddressOfOp::verify() { //===----------------------------------------------------------------------===// LogicalResult spirv::CompositeConstructOp::verify() { - auto cType = llvm::cast<spirv::CompositeType>(getType()); operand_range constituents = this->getConstituents(); - if (auto coopType = llvm::dyn_cast<spirv::CooperativeMatrixNVType>(cType)) { - if (constituents.size() != 1) - return emitOpError("has incorrect number of operands: expected ") - << "1, but provided " << constituents.size(); - if (coopType.getElementType() != constituents.front().getType()) - return emitOpError("operand type mismatch: expected operand type ") - << coopType.getElementType() << ", but provided " - << constituents.front().getType(); - return success(); - } + // There are 4 cases with varying verification rules: + // 1. Cooperative Matrices (1 constituent) + // 2. Structs (1 constituent for each member) + // 3. Arrays (1 constituent for each array element) + // 4. Vectors (1 constituent (sub-)element for each vector element) + + auto coopElementType = + llvm::TypeSwitch<Type, Type>(getType()) + .Case<spirv::CooperativeMatrixType, spirv::CooperativeMatrixNVType, + spirv::JointMatrixINTELType>( + [](auto coopType) { return coopType.getElementType(); }) + .Default([](Type) { return nullptr; }); - if (auto jointType = llvm::dyn_cast<spirv::JointMatrixINTELType>(cType)) { + // Case 1. -- matrices. + if (coopElementType) { if (constituents.size() != 1) return emitOpError("has incorrect number of operands: expected ") << "1, but provided " << constituents.size(); - if (jointType.getElementType() != constituents.front().getType()) + if (coopElementType != constituents.front().getType()) return emitOpError("operand type mismatch: expected operand type ") - << jointType.getElementType() << ", but provided " + << coopElementType << ", but provided " << constituents.front().getType(); return success(); } + // Case 2./3./4. -- number of constituents matches the number of elements. + auto cType = llvm::cast<spirv::CompositeType>(getType()); if (constituents.size() == cType.getNumElements()) { for (auto index : llvm::seq<uint32_t>(0, constituents.size())) { if (constituents[index].getType() != cType.getElementType(index)) { @@ -399,8 +404,7 @@ LogicalResult spirv::CompositeConstructOp::verify() { return success(); } - // If not constructing a cooperative matrix type, then we must be constructing - // a vector type. + // Case 4. -- check that all constituents add up tp the expected vector type. auto resultType = llvm::dyn_cast<VectorType>(cType); if (!resultType) return emitOpError( diff --git a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir index ce7f6bc6118b316..2891513961d5e2a 100644 --- a/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/composite-ops.mlir @@ -4,22 +4,20 @@ // spirv.CompositeConstruct //===----------------------------------------------------------------------===// +// CHECK-LABEL: func @composite_construct_vector func.func @composite_construct_vector(%arg0: f32, %arg1: f32, %arg2 : f32) -> vector<3xf32> { // CHECK: spirv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : (f32, f32, f32) -> vector<3xf32> %0 = spirv.CompositeConstruct %arg0, %arg1, %arg2 : (f32, f32, f32) -> vector<3xf32> return %0: vector<3xf32> } -// ----- - +// CHECK-LABEL: func @composite_construct_struct func.func @composite_construct_struct(%arg0: vector<3xf32>, %arg1: !spirv.array<4xf32>, %arg2 : !spirv.struct<(f32)>) -> !spirv.struct<(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>)> { // CHECK: spirv.CompositeConstruct %0 = spirv.CompositeConstruct %arg0, %arg1, %arg2 : (vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>) -> !spirv.struct<(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>)> return %0: !spirv.struct<(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>)> } -// ----- - // CHECK-LABEL: func @composite_construct_mixed_scalar_vector func.func @composite_construct_mixed_scalar_vector(%arg0: f32, %arg1: f32, %arg2 : vector<2xf32>) -> vector<4xf32> { // CHECK: spirv.CompositeConstruct %{{.+}}, %{{.+}}, %{{.+}} : (f32, vector<2xf32>, f32) -> vector<4xf32> @@ -27,9 +25,15 @@ func.func @composite_construct_mixed_scalar_vector(%arg0: f32, %arg1: f32, %arg2 return %0: vector<4xf32> } -// ----- +// CHECK-LABEL: func @composite_construct_coopmatrix_khr +func.func @composite_construct_coopmatrix_khr(%arg0 : f32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> { + // CHECK: spirv.CompositeConstruct {{%.*}} : (f32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> + %0 = spirv.CompositeConstruct %arg0 : (f32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> + return %0: !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> +} -func.func @composite_construct_NV.coopmatrix(%arg0 : f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> { +// CHECK-LABEL: func @composite_construct_coopmatrix_nv +func.func @composite_construct_coopmatrix_nv(%arg0 : f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> { // CHECK: spirv.CompositeConstruct {{%.*}} : (f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> %0 = spirv.CompositeConstruct %arg0 : (f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> return %0: !spirv.NV.coopmatrix<8x16xf32, Subgroup> @@ -53,6 +57,24 @@ func.func @composite_construct_invalid_operand_type(%arg0: f32, %arg1: f32, %arg // ----- +func.func @composite_construct_khr_coopmatrix_incorrect_operand_count(%arg0 : f32, %arg1 : f32) -> + !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> { + // expected-error @+1 {{has incorrect number of operands: expected 1, but provided 2}} + %0 = spirv.CompositeConstruct %arg0, %arg1 : (f32, f32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> + return %0: !spirv.coopmatrix<8x16xf32, Subgroup, MatrixA> +} + +// ----- + +func.func @composite_construct_khr_coopmatrix_incorrect_element_type(%arg0 : i32) -> + !spirv.coopmatrix<8x16xf32, Subgroup, MatrixB> { + // expected-error @+1 {{operand type mismatch: expected operand type 'f32', but provided 'i32'}} + %0 = spirv.CompositeConstruct %arg0 : (i32) -> !spirv.coopmatrix<8x16xf32, Subgroup, MatrixB> + return %0: !spirv.coopmatrix<8x16xf32, Subgroup, MatrixB> +} + +// ----- + func.func @composite_construct_NV.coopmatrix_incorrect_operand_count(%arg0 : f32, %arg1 : f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> { // expected-error @+1 {{has incorrect number of operands: expected 1, but provided 2}} %0 = spirv.CompositeConstruct %arg0, %arg1 : (f32, f32) -> !spirv.NV.coopmatrix<8x16xf32, Subgroup> |
Also improve the documentation (code and website).
0b8c7de
to
f379064
Compare
antiagainst
approved these changes
Sep 14, 2023
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!
ZijunZhaoCCK
pushed a commit
to ZijunZhaoCCK/llvm-project
that referenced
this pull request
Sep 19, 2023
…#66399) Also improve the documentation (code and website).
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Also improve the documentation (code and website).