-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][spirv] Fix remaining coop matrix verification corner cases #66137
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
- Check `MakePointer*` load/store attribute values. - Suuport coop matrix types in `MatrixTimesScalar` verification. - Add testcases for all the remaining ops that accept coop matrix types. - Split NV and KHR tests.
@llvm/pr-subscribers-mlir-spirv Changes
Patch is 33.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/66137.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td index a21fc0ce2f9299c..a055cadc756a7e6 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td @@ -75,10 +75,9 @@ def SPIRV_MatrixTimesScalarOp : SPIRV_Op< let summary = "Scale a floating-point matrix."; let description = [{ - Result Type must be an OpTypeMatrix whose Column Type is a vector of - floating-point type. + Result Type must be a matrix type with a float component type. - The type of Matrix must be the same as Result Type. Each component in + The type of Matrix must be the same as Result Type. Each component in each column in Matrix is multiplied by Scalar. Scalar must have the same type as the Component Type in Result Type. diff --git a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp index d43f7a1823e912b..c8b274ceec3e59d 100644 --- a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp @@ -20,9 +20,6 @@ using namespace mlir::spirv::AttrNames; namespace mlir::spirv { -//===----------------------------------------------------------------------===// -// spirv.KHR.CooperativeMatrixLoad -//===----------------------------------------------------------------------===// static LogicalResult verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix, @@ -35,13 +32,31 @@ verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix, << pointeeType; } - // The 'Aligned' memory operand requires an alignment literal to follow, which - // needs to be implemented on the level of op parsing and (de-)serialization. - // TODO: Consider adding support for this attribute value. - if (memoryOperand && - spirv::bitEnumContainsAll(memoryOperand.getValue(), - spirv::MemoryAccess::Aligned)) { - return op->emitOpError("has unhandled memory operand 'Aligned'"); + if (memoryOperand) { + spirv::MemoryAccess operandSet = memoryOperand.getValue(); + + if (isa(op) && + spirv::bitEnumContainsAll(operandSet, + spirv::MemoryAccess::MakePointerAvailable)) { + return op->emitOpError( + "not compatible with memory operand 'MakePointerAvailable'"); + } + + if (isa(op) && + spirv::bitEnumContainsAll(operandSet, + spirv::MemoryAccess::MakePointerVisible)) { + return op->emitOpError( + "not compatible with memory operand 'MakePointerVisible'"); + } + + // The 'Aligned' memory operand requires an alignment literal to follow, + // which needs to be implemented on the level of op parsing and + // (de-)serialization. + // TODO: Consider adding support for this attribute value. + if (spirv::bitEnumContainsAll(memoryOperand.getValue(), + spirv::MemoryAccess::Aligned)) { + return op->emitOpError("has unhandled memory operand 'Aligned'"); + } } // TODO: Verify the memory object behind the pointer: @@ -51,6 +66,10 @@ verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix, return success(); } +//===----------------------------------------------------------------------===// +// spirv.KHR.CooperativeMatrixLoad +//===----------------------------------------------------------------------===// + LogicalResult KHRCooperativeMatrixLoadOp::verify() { return verifyCoopMatrixAccess(*this, getPointer().getType(), getResult().getType(), getMemoryOperandAttr()); diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index 6ebd8515caf037d..6cd75ee6d9cba48 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -34,6 +34,7 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/TypeSwitch.h" #include #include #include @@ -1604,19 +1605,19 @@ LogicalResult spirv::VectorShuffleOp::verify() { //===----------------------------------------------------------------------===// LogicalResult spirv::MatrixTimesScalarOp::verify() { - if (auto inputCoopmat = llvm::dyn_cast( - getMatrix().getType())) { - if (inputCoopmat.getElementType() != getScalar().getType()) - return emitError("input matrix components' type and scaling value must " - "have the same type"); - return success(); - } + Type elementType = + llvm::TypeSwitch(getMatrix().getType()) + .Case( + [](auto matrixType) { return matrixType.getElementType(); }) + .Default([](Type) { return nullptr; }); + + assert(elementType && "Unhandld type"); // Check that the scalar type is the same as the matrix element type. - auto inputMatrix = llvm::cast(getMatrix().getType()); - if (getScalar().getType() != inputMatrix.getElementType()) - return emitError("input matrix components' type and scaling value must " - "have the same type"); + if (getScalar().getType() != elementType) + return emitOpError("input matrix components' type and scaling value must " + "have the same type"); return success(); } diff --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir similarity index 68% rename from mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir rename to mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir index 3adcd711f74a8f8..445ab8a48d3ce64 100644 --- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt --split-input-file --verify-diagnostics %s | FileCheck %s //===----------------------------------------------------------------------===// -// CooperativeMatrix (KHR) +// CooperativeMatrix (KHR) extension ops. //===----------------------------------------------------------------------===// // CHECK-LABEL: @cooperative_matrix_length @@ -136,6 +136,15 @@ spirv.func @cooperative_matrix_load_bad_result(%ptr : !spirv.ptr, %stride : i32) "None" { + // expected-error @+1 {{op not compatible with memory operand 'MakePointerAvailable'}} + %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, , : + !spirv.ptr, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA> + spirv.Return +} + +// ----- + spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr, %stride : i32) "None" { // expected-error @+1 {{op has unhandled memory operand 'Aligned'}} %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, , : @@ -184,6 +193,16 @@ spirv.func @cooperative_matrix_store_bad_object_type(%ptr : !spirv.ptr, %stride : i32, + %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" { + // expected-error @+1 {{op not compatible with memory operand 'MakePointerVisible'}} + spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, , : + !spirv.ptr, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32 + spirv.Return +} + +// ----- + spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr, %stride : i32, %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" { // expected-error @+1 {{op has unhandled memory operand 'Aligned'}} @@ -406,177 +425,153 @@ spirv.func @cooperative_matrix_muladd_matrix_operands(%a : !spirv.coopmatrix<8x1 // ----- //===----------------------------------------------------------------------===// -// NV.CooperativeMatrix +// Standard ops that can be used CooperativeMatrix types //===----------------------------------------------------------------------===// -// CHECK-LABEL: @cooperative_matrix_load -spirv.func @cooperative_matrix_load(%ptr : !spirv.ptr, %stride : i32, %b : i1) "None" { - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr as !spirv.NV.coopmatrix<16x8xi32, Workgroup> - %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr as !spirv.NV.coopmatrix<16x8xi32, Workgroup> - spirv.Return -} +!matA_i32 = !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA> +!matB_i32 = !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB> -// CHECK-LABEL: @cooperative_matrix_load_memaccess -spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr, %stride : i32, %b : i1) "None" { - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr as !spirv.NV.coopmatrix<8x16xi32, Subgroup> - %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr as !spirv.NV.coopmatrix<8x16xi32, Subgroup> - spirv.Return -} +!matA_f32 = !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA> +!matB_f32 = !spirv.coopmatrix<2x2xf32, Subgroup, MatrixB> -// CHECK-LABEL: @cooperative_matrix_load_diff_ptr_type -spirv.func @cooperative_matrix_load_diff_ptr_type(%ptr : !spirv.ptr, StorageBuffer>, %stride : i32, %b : i1) "None" { - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup> - %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup> - spirv.Return -} +// These tests are kept in the same order as the list of compatible ops in the +// SPV_KHR_cooperative_matrix extension spec. -// CHECK-LABEL: @cooperative_matrix_store -spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr, %stride : i32, %m : !spirv.NV.coopmatrix<8x16xi32, Workgroup>, %b : i1) "None" { - // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr, !spirv.NV.coopmatrix<8x16xi32, Workgroup> - spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b : !spirv.ptr, !spirv.NV.coopmatrix<8x16xi32, Workgroup> +// CHECK-LABEL: @snegate +spirv.func @snegate(%a: !matA_i32, %b: !matB_i32) "None" { + // CHECK: spirv.SNegate {{%.*}} : !spirv.coopmatrix + // CHECK-NEXT: spirv.SNegate {{%.*}} : !spirv.coopmatrix + %p = spirv.SNegate %a : !matA_i32 + %q = spirv.SNegate %b : !matB_i32 spirv.Return } -// CHECK-LABEL: @cooperative_matrix_store_memaccess -spirv.func @cooperative_matrix_store_memaccess(%ptr : !spirv.ptr, %m : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" { - // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr, !spirv.NV.coopmatrix<8x16xi32, Subgroup> - spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b ["Volatile"] : !spirv.ptr, !spirv.NV.coopmatrix<8x16xi32, Subgroup> +// CHECK-LABEL: @fnegate +spirv.func @fnegate(%a: !matA_f32, %b: !matB_f32) "None" { + // CHECK: spirv.FNegate {{%.*}} : !spirv.coopmatrix + // CHECK-NEXT: spirv.FNegate {{%.*}} : !spirv.coopmatrix + %p = spirv.FNegate %a : !matA_f32 + %q = spirv.FNegate %b : !matB_f32 spirv.Return } -// CHECK-LABEL: @cooperative_matrix_length -spirv.func @cooperative_matrix_length() -> i32 "None" { - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup> - %0 = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup> - spirv.ReturnValue %0 : i32 -} - -// CHECK-LABEL: @cooperative_matrix_muladd -spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, %b : !spirv.NV.coopmatrix<32x8xi8, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, !spirv.NV.coopmatrix<32x8xi8, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup> - %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, !spirv.NV.coopmatrix<32x8xi8, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup> +// CHECK-LABEL: @iadd +spirv.func @iadd(%a: !matA_i32, %b: !matB_i32) "None" { + // CHECK: spirv.IAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix + // CHECK-NEXT: spirv.IAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix + %p = spirv.IAdd %a, %a : !matA_i32 + %q = spirv.IAdd %b, %b : !matB_i32 spirv.Return } -// CHECK-LABEL: @cooperative_matrix_add -spirv.func @cooperative_matrix_add(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.IAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup> - %r = spirv.IAdd %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup> +// CHECK-LABEL: @fadd +spirv.func @fadd(%a: !matA_f32, %b: !matB_f32) "None" { + // CHECK: spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix + // CHECK-NEXT: spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix + %p = spirv.FAdd %a, %a : !matA_f32 + %q = spirv.FAdd %b, %b : !matB_f32 spirv.Return } -// CHECK-LABEL: @cooperative_matrix_sub -spirv.func @cooperative_matrix_sub(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.ISub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup> - %r = spirv.ISub %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup> +// CHECK-LABEL: @isub +spirv.func @isub(%a: !matA_i32, %b: !matB_i32) "None" { + // CHECK: spirv.ISub {{%.*}}, {{%.*}} : !spirv.coopmatrix + // CHECK-NEXT: spirv.ISub {{%.*}}, {{%.*}} : !spirv.coopmatrix + %p = spirv.ISub %a, %a : !matA_i32 + %q = spirv.ISub %b, %b : !matB_i32 spirv.Return } -// CHECK-LABEL: @cooperative_matrix_sdiv -spirv.func @cooperative_matrix_sdiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.SDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup> - %r = spirv.SDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup> +// CHECK-LABEL: @fsub +spirv.func @fsub(%a: !matA_f32, %b: !matB_f32) "None" { + // CHECK: spirv.FSub {{%.*}}, {{%.*}} : !spirv.coopmatrix + // CHECK-NEXT: spirv.FSub {{%.*}}, {{%.*}} : !spirv.coopmatrix + %p = spirv.FSub %a, %a : !matA_f32 + %q = spirv.FSub %b, %b : !matB_f32 spirv.Return } -// CHECK-LABEL: @cooperative_matrix_udiv -spirv.func @cooperative_matrix_udiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.UDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup> - %r = spirv.UDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup> +// CHECK-LABEL: @fmul +spirv.func @fmul(%a: !matA_f32, %b: !matB_f32) "None" { + // CHECK: spirv.FMul {{%.*}}, {{%.*}} : !spirv.coopmatrix + // CHECK-NEXT: spirv.FMul {{%.*}}, {{%.*}} : !spirv.coopmatrix + %p = spirv.FMul %a, %a : !matA_f32 + %q = spirv.FMul %b, %b : !matB_f32 spirv.Return } -// CHECK-LABEL: @cooperative_matrix_fadd -spirv.func @cooperative_matrix_fadd(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> - %r = spirv.FAdd %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup> +// CHECK-LABEL: @imul +spirv.func @imul(%a: !matA_i32, %b: !matB_i32) "None" { + // CHECK: spirv.IMul {{%.*}}, {{%.*}} : !spirv.coopmatrix + // CHECK-NEXT: spirv.IMul {{%.*}}, {{%.*}} : !spirv.coopmatrix + %p = spirv.IMul %a, %a : !matA_i32 + %q = spirv.IMul %b, %b : !matB_i32 spirv.Return } -// CHECK-LABEL: @cooperative_matrix_fsub -spirv.func @cooperative_matrix_fsub(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.FSub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> - %r = spirv.FSub %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup> +// CHECK-LABEL: @fdiv +spirv.func @fdiv(%a: !matA_f32, %b: !matB_f32) "None" { + // CHECK: spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix + // CHECK-NEXT: spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix + %p = spirv.FDiv %a, %a : !matA_f32 + %q = spirv.FDiv %b, %b : !matB_f32 spirv.Return } -// CHECK-LABEL: @cooperative_matrix_fdiv -spirv.func @cooperative_matrix_fdiv(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> - %r = spirv.FDiv %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup> +// CHECK-LABEL: @sdiv +spirv.func @sdiv(%a: !matA_i32, %b: !matB_i32) "None" { + // CHECK: spirv.SDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix + // CHECK-NEXT: spirv.SDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix + %p = spirv.SDiv %a, %a : !matA_i32 + %q = spirv.SDiv %b, %b : !matB_i32 spirv.Return } -// ----- - -// CHECK-LABEL: @cooperative_matrix_access_chain -spirv.func @cooperative_matrix_access_chain(%a : !spirv.ptr, Function>) -> !spirv.ptr "None" { - %0 = spirv.Constant 0: i32 - // CHECK: {{%.*}} = spirv.AccessChain {{%.*}}[{{%.*}}] : !spirv.ptr, Function>, i32 - %1 = spirv.AccessChain %a[%0] : !spirv.ptr, Function>, i32 - spirv.ReturnValue %1 : !spirv.ptr -} - -// ----- - -spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<16x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" { - // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix size must match}} - %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<16x16xi32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup> +// CHECK-LABEL: @udiv +spirv.func @udiv(%a: !matA_i32, %b: !matB_i32) "None" { + // CHECK: spirv.UDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix + // CHECK-NEXT: spirv.UDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix + %p = spirv.UDiv %a, %a : !matA_i32 + %q = spirv.UDiv %b, %b : !matB_i32 spirv.Return } -// ----- - -spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" { - // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix size must match}} - %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup> - spirv.Return -} - -// ----- - -spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Workgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" { - // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix scope must match}} - %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xi32, Subgroup... |
@llvm/pr-subscribers-mlir Changes
Patch is 33.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/66137.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td index a21fc0ce2f9299c..a055cadc756a7e6 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td @@ -75,10 +75,9 @@ def SPIRV_MatrixTimesScalarOp : SPIRV_Op< let summary = "Scale a floating-point matrix."; let description = [{ - Result Type must be an OpTypeMatrix whose Column Type is a vector of - floating-point type. + Result Type must be a matrix type with a float component type. - The type of Matrix must be the same as Result Type. Each component in + The type of Matrix must be the same as Result Type. Each component in each column in Matrix is multiplied by Scalar. Scalar must have the same type as the Component Type in Result Type. diff --git a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp index d43f7a1823e912b..c8b274ceec3e59d 100644 --- a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp @@ -20,9 +20,6 @@ using namespace mlir::spirv::AttrNames; namespace mlir::spirv { -//===----------------------------------------------------------------------===// -// spirv.KHR.CooperativeMatrixLoad -//===----------------------------------------------------------------------===// static LogicalResult verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix, @@ -35,13 +32,31 @@ verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix, << pointeeType; } - // The 'Aligned' memory operand requires an alignment literal to follow, which - // needs to be implemented on the level of op parsing and (de-)serialization. - // TODO: Consider adding support for this attribute value. - if (memoryOperand && - spirv::bitEnumContainsAll(memoryOperand.getValue(), - spirv::MemoryAccess::Aligned)) { - return op->emitOpError("has unhandled memory operand 'Aligned'"); + if (memoryOperand) { + spirv::MemoryAccess operandSet = memoryOperand.getValue(); + + if (isa(op) && + spirv::bitEnumContainsAll(operandSet, + spirv::MemoryAccess::MakePointerAvailable)) { + return op->emitOpError( + "not compatible with memory operand 'MakePointerAvailable'"); + } + + if (isa(op) && + spirv::bitEnumContainsAll(operandSet, + spirv::MemoryAccess::MakePointerVisible)) { + return op->emitOpError( + "not compatible with memory operand 'MakePointerVisible'"); + } + + // The 'Aligned' memory operand requires an alignment literal to follow, + // which needs to be implemented on the level of op parsing and + // (de-)serialization. + // TODO: Consider adding support for this attribute value. + if (spirv::bitEnumContainsAll(memoryOperand.getValue(), + spirv::MemoryAccess::Aligned)) { + return op->emitOpError("has unhandled memory operand 'Aligned'"); + } } // TODO: Verify the memory object behind the pointer: @@ -51,6 +66,10 @@ verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix, return success(); } +//===----------------------------------------------------------------------===// +// spirv.KHR.CooperativeMatrixLoad +//===----------------------------------------------------------------------===// + LogicalResult KHRCooperativeMatrixLoadOp::verify() { return verifyCoopMatrixAccess(*this, getPointer().getType(), getResult().getType(), getMemoryOperandAttr()); diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index 6ebd8515caf037d..6cd75ee6d9cba48 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -34,6 +34,7 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/TypeSwitch.h" #include #include #include @@ -1604,19 +1605,19 @@ LogicalResult spirv::VectorShuffleOp::verify() { //===----------------------------------------------------------------------===// LogicalResult spirv::MatrixTimesScalarOp::verify() { - if (auto inputCoopmat = llvm::dyn_cast( - getMatrix().getType())) { - if (inputCoopmat.getElementType() != getScalar().getType()) - return emitError("input matrix components' type and scaling value must " - "have the same type"); - return success(); - } + Type elementType = + llvm::TypeSwitch(getMatrix().getType()) + .Case( + [](auto matrixType) { return matrixType.getElementType(); }) + .Default([](Type) { return nullptr; }); + + assert(elementType && "Unhandld type"); // Check that the scalar type is the same as the matrix element type. - auto inputMatrix = llvm::cast(getMatrix().getType()); - if (getScalar().getType() != inputMatrix.getElementType()) - return emitError("input matrix components' type and scaling value must " - "have the same type"); + if (getScalar().getType() != elementType) + return emitOpError("input matrix components' type and scaling value must " + "have the same type"); return success(); } diff --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir similarity index 68% rename from mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir rename to mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir index 3adcd711f74a8f8..445ab8a48d3ce64 100644 --- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt --split-input-file --verify-diagnostics %s | FileCheck %s //===----------------------------------------------------------------------===// -// CooperativeMatrix (KHR) +// CooperativeMatrix (KHR) extension ops. //===----------------------------------------------------------------------===// // CHECK-LABEL: @cooperative_matrix_length @@ -136,6 +136,15 @@ spirv.func @cooperative_matrix_load_bad_result(%ptr : !spirv.ptr, %stride : i32) "None" { + // expected-error @+1 {{op not compatible with memory operand 'MakePointerAvailable'}} + %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, , : + !spirv.ptr, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA> + spirv.Return +} + +// ----- + spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr, %stride : i32) "None" { // expected-error @+1 {{op has unhandled memory operand 'Aligned'}} %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, , : @@ -184,6 +193,16 @@ spirv.func @cooperative_matrix_store_bad_object_type(%ptr : !spirv.ptr, %stride : i32, + %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" { + // expected-error @+1 {{op not compatible with memory operand 'MakePointerVisible'}} + spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, , : + !spirv.ptr, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32 + spirv.Return +} + +// ----- + spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr, %stride : i32, %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" { // expected-error @+1 {{op has unhandled memory operand 'Aligned'}} @@ -406,177 +425,153 @@ spirv.func @cooperative_matrix_muladd_matrix_operands(%a : !spirv.coopmatrix<8x1 // ----- //===----------------------------------------------------------------------===// -// NV.CooperativeMatrix +// Standard ops that can be used CooperativeMatrix types //===----------------------------------------------------------------------===// -// CHECK-LABEL: @cooperative_matrix_load -spirv.func @cooperative_matrix_load(%ptr : !spirv.ptr, %stride : i32, %b : i1) "None" { - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr as !spirv.NV.coopmatrix<16x8xi32, Workgroup> - %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr as !spirv.NV.coopmatrix<16x8xi32, Workgroup> - spirv.Return -} +!matA_i32 = !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA> +!matB_i32 = !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB> -// CHECK-LABEL: @cooperative_matrix_load_memaccess -spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr, %stride : i32, %b : i1) "None" { - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr as !spirv.NV.coopmatrix<8x16xi32, Subgroup> - %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr as !spirv.NV.coopmatrix<8x16xi32, Subgroup> - spirv.Return -} +!matA_f32 = !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA> +!matB_f32 = !spirv.coopmatrix<2x2xf32, Subgroup, MatrixB> -// CHECK-LABEL: @cooperative_matrix_load_diff_ptr_type -spirv.func @cooperative_matrix_load_diff_ptr_type(%ptr : !spirv.ptr, StorageBuffer>, %stride : i32, %b : i1) "None" { - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup> - %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup> - spirv.Return -} +// These tests are kept in the same order as the list of compatible ops in the +// SPV_KHR_cooperative_matrix extension spec. -// CHECK-LABEL: @cooperative_matrix_store -spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr, %stride : i32, %m : !spirv.NV.coopmatrix<8x16xi32, Workgroup>, %b : i1) "None" { - // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr, !spirv.NV.coopmatrix<8x16xi32, Workgroup> - spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b : !spirv.ptr, !spirv.NV.coopmatrix<8x16xi32, Workgroup> +// CHECK-LABEL: @snegate +spirv.func @snegate(%a: !matA_i32, %b: !matB_i32) "None" { + // CHECK: spirv.SNegate {{%.*}} : !spirv.coopmatrix + // CHECK-NEXT: spirv.SNegate {{%.*}} : !spirv.coopmatrix + %p = spirv.SNegate %a : !matA_i32 + %q = spirv.SNegate %b : !matB_i32 spirv.Return } -// CHECK-LABEL: @cooperative_matrix_store_memaccess -spirv.func @cooperative_matrix_store_memaccess(%ptr : !spirv.ptr, %m : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" { - // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr, !spirv.NV.coopmatrix<8x16xi32, Subgroup> - spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b ["Volatile"] : !spirv.ptr, !spirv.NV.coopmatrix<8x16xi32, Subgroup> +// CHECK-LABEL: @fnegate +spirv.func @fnegate(%a: !matA_f32, %b: !matB_f32) "None" { + // CHECK: spirv.FNegate {{%.*}} : !spirv.coopmatrix + // CHECK-NEXT: spirv.FNegate {{%.*}} : !spirv.coopmatrix + %p = spirv.FNegate %a : !matA_f32 + %q = spirv.FNegate %b : !matB_f32 spirv.Return } -// CHECK-LABEL: @cooperative_matrix_length -spirv.func @cooperative_matrix_length() -> i32 "None" { - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup> - %0 = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup> - spirv.ReturnValue %0 : i32 -} - -// CHECK-LABEL: @cooperative_matrix_muladd -spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, %b : !spirv.NV.coopmatrix<32x8xi8, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, !spirv.NV.coopmatrix<32x8xi8, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup> - %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, !spirv.NV.coopmatrix<32x8xi8, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup> +// CHECK-LABEL: @iadd +spirv.func @iadd(%a: !matA_i32, %b: !matB_i32) "None" { + // CHECK: spirv.IAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix + // CHECK-NEXT: spirv.IAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix + %p = spirv.IAdd %a, %a : !matA_i32 + %q = spirv.IAdd %b, %b : !matB_i32 spirv.Return } -// CHECK-LABEL: @cooperative_matrix_add -spirv.func @cooperative_matrix_add(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.IAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup> - %r = spirv.IAdd %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup> +// CHECK-LABEL: @fadd +spirv.func @fadd(%a: !matA_f32, %b: !matB_f32) "None" { + // CHECK: spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix + // CHECK-NEXT: spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix + %p = spirv.FAdd %a, %a : !matA_f32 + %q = spirv.FAdd %b, %b : !matB_f32 spirv.Return } -// CHECK-LABEL: @cooperative_matrix_sub -spirv.func @cooperative_matrix_sub(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.ISub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup> - %r = spirv.ISub %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup> +// CHECK-LABEL: @isub +spirv.func @isub(%a: !matA_i32, %b: !matB_i32) "None" { + // CHECK: spirv.ISub {{%.*}}, {{%.*}} : !spirv.coopmatrix + // CHECK-NEXT: spirv.ISub {{%.*}}, {{%.*}} : !spirv.coopmatrix + %p = spirv.ISub %a, %a : !matA_i32 + %q = spirv.ISub %b, %b : !matB_i32 spirv.Return } -// CHECK-LABEL: @cooperative_matrix_sdiv -spirv.func @cooperative_matrix_sdiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.SDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup> - %r = spirv.SDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup> +// CHECK-LABEL: @fsub +spirv.func @fsub(%a: !matA_f32, %b: !matB_f32) "None" { + // CHECK: spirv.FSub {{%.*}}, {{%.*}} : !spirv.coopmatrix + // CHECK-NEXT: spirv.FSub {{%.*}}, {{%.*}} : !spirv.coopmatrix + %p = spirv.FSub %a, %a : !matA_f32 + %q = spirv.FSub %b, %b : !matB_f32 spirv.Return } -// CHECK-LABEL: @cooperative_matrix_udiv -spirv.func @cooperative_matrix_udiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.UDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup> - %r = spirv.UDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup> +// CHECK-LABEL: @fmul +spirv.func @fmul(%a: !matA_f32, %b: !matB_f32) "None" { + // CHECK: spirv.FMul {{%.*}}, {{%.*}} : !spirv.coopmatrix + // CHECK-NEXT: spirv.FMul {{%.*}}, {{%.*}} : !spirv.coopmatrix + %p = spirv.FMul %a, %a : !matA_f32 + %q = spirv.FMul %b, %b : !matB_f32 spirv.Return } -// CHECK-LABEL: @cooperative_matrix_fadd -spirv.func @cooperative_matrix_fadd(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> - %r = spirv.FAdd %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup> +// CHECK-LABEL: @imul +spirv.func @imul(%a: !matA_i32, %b: !matB_i32) "None" { + // CHECK: spirv.IMul {{%.*}}, {{%.*}} : !spirv.coopmatrix + // CHECK-NEXT: spirv.IMul {{%.*}}, {{%.*}} : !spirv.coopmatrix + %p = spirv.IMul %a, %a : !matA_i32 + %q = spirv.IMul %b, %b : !matB_i32 spirv.Return } -// CHECK-LABEL: @cooperative_matrix_fsub -spirv.func @cooperative_matrix_fsub(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.FSub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> - %r = spirv.FSub %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup> +// CHECK-LABEL: @fdiv +spirv.func @fdiv(%a: !matA_f32, %b: !matB_f32) "None" { + // CHECK: spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix + // CHECK-NEXT: spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix + %p = spirv.FDiv %a, %a : !matA_f32 + %q = spirv.FDiv %b, %b : !matB_f32 spirv.Return } -// CHECK-LABEL: @cooperative_matrix_fdiv -spirv.func @cooperative_matrix_fdiv(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> - %r = spirv.FDiv %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup> +// CHECK-LABEL: @sdiv +spirv.func @sdiv(%a: !matA_i32, %b: !matB_i32) "None" { + // CHECK: spirv.SDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix + // CHECK-NEXT: spirv.SDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix + %p = spirv.SDiv %a, %a : !matA_i32 + %q = spirv.SDiv %b, %b : !matB_i32 spirv.Return } -// ----- - -// CHECK-LABEL: @cooperative_matrix_access_chain -spirv.func @cooperative_matrix_access_chain(%a : !spirv.ptr, Function>) -> !spirv.ptr "None" { - %0 = spirv.Constant 0: i32 - // CHECK: {{%.*}} = spirv.AccessChain {{%.*}}[{{%.*}}] : !spirv.ptr, Function>, i32 - %1 = spirv.AccessChain %a[%0] : !spirv.ptr, Function>, i32 - spirv.ReturnValue %1 : !spirv.ptr -} - -// ----- - -spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<16x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" { - // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix size must match}} - %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<16x16xi32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup> +// CHECK-LABEL: @udiv +spirv.func @udiv(%a: !matA_i32, %b: !matB_i32) "None" { + // CHECK: spirv.UDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix + // CHECK-NEXT: spirv.UDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix + %p = spirv.UDiv %a, %a : !matA_i32 + %q = spirv.UDiv %b, %b : !matB_i32 spirv.Return } -// ----- - -spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" { - // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix size must match}} - %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup> - spirv.Return -} - -// ----- - -spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Workgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" { - // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix scope must match}} - %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xi32, Subgroup... |
@llvm/pr-subscribers-mlir-core Changes
Patch is 33.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/66137.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td index a21fc0ce2f9299c..a055cadc756a7e6 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td @@ -75,10 +75,9 @@ def SPIRV_MatrixTimesScalarOp : SPIRV_Op< let summary = "Scale a floating-point matrix."; let description = [{ - Result Type must be an OpTypeMatrix whose Column Type is a vector of - floating-point type. + Result Type must be a matrix type with a float component type. - The type of Matrix must be the same as Result Type. Each component in + The type of Matrix must be the same as Result Type. Each component in each column in Matrix is multiplied by Scalar. Scalar must have the same type as the Component Type in Result Type. diff --git a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp index d43f7a1823e912b..c8b274ceec3e59d 100644 --- a/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/CooperativeMatrixOps.cpp @@ -20,9 +20,6 @@ using namespace mlir::spirv::AttrNames; namespace mlir::spirv { -//===----------------------------------------------------------------------===// -// spirv.KHR.CooperativeMatrixLoad -//===----------------------------------------------------------------------===// static LogicalResult verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix, @@ -35,13 +32,31 @@ verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix, << pointeeType; } - // The 'Aligned' memory operand requires an alignment literal to follow, which - // needs to be implemented on the level of op parsing and (de-)serialization. - // TODO: Consider adding support for this attribute value. - if (memoryOperand && - spirv::bitEnumContainsAll(memoryOperand.getValue(), - spirv::MemoryAccess::Aligned)) { - return op->emitOpError("has unhandled memory operand 'Aligned'"); + if (memoryOperand) { + spirv::MemoryAccess operandSet = memoryOperand.getValue(); + + if (isa(op) && + spirv::bitEnumContainsAll(operandSet, + spirv::MemoryAccess::MakePointerAvailable)) { + return op->emitOpError( + "not compatible with memory operand 'MakePointerAvailable'"); + } + + if (isa(op) && + spirv::bitEnumContainsAll(operandSet, + spirv::MemoryAccess::MakePointerVisible)) { + return op->emitOpError( + "not compatible with memory operand 'MakePointerVisible'"); + } + + // The 'Aligned' memory operand requires an alignment literal to follow, + // which needs to be implemented on the level of op parsing and + // (de-)serialization. + // TODO: Consider adding support for this attribute value. + if (spirv::bitEnumContainsAll(memoryOperand.getValue(), + spirv::MemoryAccess::Aligned)) { + return op->emitOpError("has unhandled memory operand 'Aligned'"); + } } // TODO: Verify the memory object behind the pointer: @@ -51,6 +66,10 @@ verifyCoopMatrixAccess(Operation *op, Type pointer, Type coopMatrix, return success(); } +//===----------------------------------------------------------------------===// +// spirv.KHR.CooperativeMatrixLoad +//===----------------------------------------------------------------------===// + LogicalResult KHRCooperativeMatrixLoadOp::verify() { return verifyCoopMatrixAccess(*this, getPointer().getType(), getResult().getType(), getMemoryOperandAttr()); diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index 6ebd8515caf037d..6cd75ee6d9cba48 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -34,6 +34,7 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/TypeSwitch.h" #include #include #include @@ -1604,19 +1605,19 @@ LogicalResult spirv::VectorShuffleOp::verify() { //===----------------------------------------------------------------------===// LogicalResult spirv::MatrixTimesScalarOp::verify() { - if (auto inputCoopmat = llvm::dyn_cast( - getMatrix().getType())) { - if (inputCoopmat.getElementType() != getScalar().getType()) - return emitError("input matrix components' type and scaling value must " - "have the same type"); - return success(); - } + Type elementType = + llvm::TypeSwitch(getMatrix().getType()) + .Case( + [](auto matrixType) { return matrixType.getElementType(); }) + .Default([](Type) { return nullptr; }); + + assert(elementType && "Unhandld type"); // Check that the scalar type is the same as the matrix element type. - auto inputMatrix = llvm::cast(getMatrix().getType()); - if (getScalar().getType() != inputMatrix.getElementType()) - return emitError("input matrix components' type and scaling value must " - "have the same type"); + if (getScalar().getType() != elementType) + return emitOpError("input matrix components' type and scaling value must " + "have the same type"); return success(); } diff --git a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir similarity index 68% rename from mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir rename to mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir index 3adcd711f74a8f8..445ab8a48d3ce64 100644 --- a/mlir/test/Dialect/SPIRV/IR/cooperative-matrix-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt --split-input-file --verify-diagnostics %s | FileCheck %s //===----------------------------------------------------------------------===// -// CooperativeMatrix (KHR) +// CooperativeMatrix (KHR) extension ops. //===----------------------------------------------------------------------===// // CHECK-LABEL: @cooperative_matrix_length @@ -136,6 +136,15 @@ spirv.func @cooperative_matrix_load_bad_result(%ptr : !spirv.ptr, %stride : i32) "None" { + // expected-error @+1 {{op not compatible with memory operand 'MakePointerAvailable'}} + %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, , : + !spirv.ptr, i32 -> !spirv.coopmatrix<8x16xi32, Subgroup, MatrixA> + spirv.Return +} + +// ----- + spirv.func @cooperative_matrix_load_aligned(%ptr : !spirv.ptr, %stride : i32) "None" { // expected-error @+1 {{op has unhandled memory operand 'Aligned'}} %0 = spirv.KHR.CooperativeMatrixLoad %ptr, %stride, , : @@ -184,6 +193,16 @@ spirv.func @cooperative_matrix_store_bad_object_type(%ptr : !spirv.ptr, %stride : i32, + %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" { + // expected-error @+1 {{op not compatible with memory operand 'MakePointerVisible'}} + spirv.KHR.CooperativeMatrixStore %ptr, %m, %stride, , : + !spirv.ptr, !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>, i32 + spirv.Return +} + +// ----- + spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr, %stride : i32, %m : !spirv.coopmatrix<8x16xi32, Workgroup, MatrixA>) "None" { // expected-error @+1 {{op has unhandled memory operand 'Aligned'}} @@ -406,177 +425,153 @@ spirv.func @cooperative_matrix_muladd_matrix_operands(%a : !spirv.coopmatrix<8x1 // ----- //===----------------------------------------------------------------------===// -// NV.CooperativeMatrix +// Standard ops that can be used CooperativeMatrix types //===----------------------------------------------------------------------===// -// CHECK-LABEL: @cooperative_matrix_load -spirv.func @cooperative_matrix_load(%ptr : !spirv.ptr, %stride : i32, %b : i1) "None" { - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr as !spirv.NV.coopmatrix<16x8xi32, Workgroup> - %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b : !spirv.ptr as !spirv.NV.coopmatrix<16x8xi32, Workgroup> - spirv.Return -} +!matA_i32 = !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA> +!matB_i32 = !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB> -// CHECK-LABEL: @cooperative_matrix_load_memaccess -spirv.func @cooperative_matrix_load_memaccess(%ptr : !spirv.ptr, %stride : i32, %b : i1) "None" { - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr as !spirv.NV.coopmatrix<8x16xi32, Subgroup> - %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr as !spirv.NV.coopmatrix<8x16xi32, Subgroup> - spirv.Return -} +!matA_f32 = !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA> +!matB_f32 = !spirv.coopmatrix<2x2xf32, Subgroup, MatrixB> -// CHECK-LABEL: @cooperative_matrix_load_diff_ptr_type -spirv.func @cooperative_matrix_load_diff_ptr_type(%ptr : !spirv.ptr, StorageBuffer>, %stride : i32, %b : i1) "None" { - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLoad {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup> - %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %b ["Volatile"] : !spirv.ptr, StorageBuffer> as !spirv.NV.coopmatrix<8x16xi32, Subgroup> - spirv.Return -} +// These tests are kept in the same order as the list of compatible ops in the +// SPV_KHR_cooperative_matrix extension spec. -// CHECK-LABEL: @cooperative_matrix_store -spirv.func @cooperative_matrix_store(%ptr : !spirv.ptr, %stride : i32, %m : !spirv.NV.coopmatrix<8x16xi32, Workgroup>, %b : i1) "None" { - // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} : !spirv.ptr, !spirv.NV.coopmatrix<8x16xi32, Workgroup> - spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b : !spirv.ptr, !spirv.NV.coopmatrix<8x16xi32, Workgroup> +// CHECK-LABEL: @snegate +spirv.func @snegate(%a: !matA_i32, %b: !matB_i32) "None" { + // CHECK: spirv.SNegate {{%.*}} : !spirv.coopmatrix + // CHECK-NEXT: spirv.SNegate {{%.*}} : !spirv.coopmatrix + %p = spirv.SNegate %a : !matA_i32 + %q = spirv.SNegate %b : !matB_i32 spirv.Return } -// CHECK-LABEL: @cooperative_matrix_store_memaccess -spirv.func @cooperative_matrix_store_memaccess(%ptr : !spirv.ptr, %m : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" { - // CHECK: spirv.NV.CooperativeMatrixStore {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spirv.ptr, !spirv.NV.coopmatrix<8x16xi32, Subgroup> - spirv.NV.CooperativeMatrixStore %ptr, %m, %stride, %b ["Volatile"] : !spirv.ptr, !spirv.NV.coopmatrix<8x16xi32, Subgroup> +// CHECK-LABEL: @fnegate +spirv.func @fnegate(%a: !matA_f32, %b: !matB_f32) "None" { + // CHECK: spirv.FNegate {{%.*}} : !spirv.coopmatrix + // CHECK-NEXT: spirv.FNegate {{%.*}} : !spirv.coopmatrix + %p = spirv.FNegate %a : !matA_f32 + %q = spirv.FNegate %b : !matB_f32 spirv.Return } -// CHECK-LABEL: @cooperative_matrix_length -spirv.func @cooperative_matrix_length() -> i32 "None" { - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup> - %0 = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup> - spirv.ReturnValue %0 : i32 -} - -// CHECK-LABEL: @cooperative_matrix_muladd -spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, %b : !spirv.NV.coopmatrix<32x8xi8, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.NV.CooperativeMatrixMulAdd {{%.*}}, {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, !spirv.NV.coopmatrix<32x8xi8, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup> - %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x32xi8, Subgroup>, !spirv.NV.coopmatrix<32x8xi8, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup> +// CHECK-LABEL: @iadd +spirv.func @iadd(%a: !matA_i32, %b: !matB_i32) "None" { + // CHECK: spirv.IAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix + // CHECK-NEXT: spirv.IAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix + %p = spirv.IAdd %a, %a : !matA_i32 + %q = spirv.IAdd %b, %b : !matB_i32 spirv.Return } -// CHECK-LABEL: @cooperative_matrix_add -spirv.func @cooperative_matrix_add(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.IAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup> - %r = spirv.IAdd %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup> +// CHECK-LABEL: @fadd +spirv.func @fadd(%a: !matA_f32, %b: !matB_f32) "None" { + // CHECK: spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix + // CHECK-NEXT: spirv.FAdd {{%.*}}, {{%.*}} : !spirv.coopmatrix + %p = spirv.FAdd %a, %a : !matA_f32 + %q = spirv.FAdd %b, %b : !matB_f32 spirv.Return } -// CHECK-LABEL: @cooperative_matrix_sub -spirv.func @cooperative_matrix_sub(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.ISub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup> - %r = spirv.ISub %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup> +// CHECK-LABEL: @isub +spirv.func @isub(%a: !matA_i32, %b: !matB_i32) "None" { + // CHECK: spirv.ISub {{%.*}}, {{%.*}} : !spirv.coopmatrix + // CHECK-NEXT: spirv.ISub {{%.*}}, {{%.*}} : !spirv.coopmatrix + %p = spirv.ISub %a, %a : !matA_i32 + %q = spirv.ISub %b, %b : !matB_i32 spirv.Return } -// CHECK-LABEL: @cooperative_matrix_sdiv -spirv.func @cooperative_matrix_sdiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.SDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup> - %r = spirv.SDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup> +// CHECK-LABEL: @fsub +spirv.func @fsub(%a: !matA_f32, %b: !matB_f32) "None" { + // CHECK: spirv.FSub {{%.*}}, {{%.*}} : !spirv.coopmatrix + // CHECK-NEXT: spirv.FSub {{%.*}}, {{%.*}} : !spirv.coopmatrix + %p = spirv.FSub %a, %a : !matA_f32 + %q = spirv.FSub %b, %b : !matB_f32 spirv.Return } -// CHECK-LABEL: @cooperative_matrix_udiv -spirv.func @cooperative_matrix_udiv(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.UDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xi32, Subgroup> - %r = spirv.UDiv %a, %b : !spirv.NV.coopmatrix<8x16xi32, Subgroup> +// CHECK-LABEL: @fmul +spirv.func @fmul(%a: !matA_f32, %b: !matB_f32) "None" { + // CHECK: spirv.FMul {{%.*}}, {{%.*}} : !spirv.coopmatrix + // CHECK-NEXT: spirv.FMul {{%.*}}, {{%.*}} : !spirv.coopmatrix + %p = spirv.FMul %a, %a : !matA_f32 + %q = spirv.FMul %b, %b : !matB_f32 spirv.Return } -// CHECK-LABEL: @cooperative_matrix_fadd -spirv.func @cooperative_matrix_fadd(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.FAdd {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> - %r = spirv.FAdd %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup> +// CHECK-LABEL: @imul +spirv.func @imul(%a: !matA_i32, %b: !matB_i32) "None" { + // CHECK: spirv.IMul {{%.*}}, {{%.*}} : !spirv.coopmatrix + // CHECK-NEXT: spirv.IMul {{%.*}}, {{%.*}} : !spirv.coopmatrix + %p = spirv.IMul %a, %a : !matA_i32 + %q = spirv.IMul %b, %b : !matB_i32 spirv.Return } -// CHECK-LABEL: @cooperative_matrix_fsub -spirv.func @cooperative_matrix_fsub(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.FSub {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> - %r = spirv.FSub %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup> +// CHECK-LABEL: @fdiv +spirv.func @fdiv(%a: !matA_f32, %b: !matB_f32) "None" { + // CHECK: spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix + // CHECK-NEXT: spirv.FDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix + %p = spirv.FDiv %a, %a : !matA_f32 + %q = spirv.FDiv %b, %b : !matB_f32 spirv.Return } -// CHECK-LABEL: @cooperative_matrix_fdiv -spirv.func @cooperative_matrix_fdiv(%a : !spirv.NV.coopmatrix<8x16xf32, Subgroup>, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup>) "None" { - // CHECK: {{%.*}} = spirv.FDiv {{%.*}}, {{%.*}} : !spirv.NV.coopmatrix<8x16xf32, Subgroup> - %r = spirv.FDiv %a, %b : !spirv.NV.coopmatrix<8x16xf32, Subgroup> +// CHECK-LABEL: @sdiv +spirv.func @sdiv(%a: !matA_i32, %b: !matB_i32) "None" { + // CHECK: spirv.SDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix + // CHECK-NEXT: spirv.SDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix + %p = spirv.SDiv %a, %a : !matA_i32 + %q = spirv.SDiv %b, %b : !matB_i32 spirv.Return } -// ----- - -// CHECK-LABEL: @cooperative_matrix_access_chain -spirv.func @cooperative_matrix_access_chain(%a : !spirv.ptr, Function>) -> !spirv.ptr "None" { - %0 = spirv.Constant 0: i32 - // CHECK: {{%.*}} = spirv.AccessChain {{%.*}}[{{%.*}}] : !spirv.ptr, Function>, i32 - %1 = spirv.AccessChain %a[%0] : !spirv.ptr, Function>, i32 - spirv.ReturnValue %1 : !spirv.ptr -} - -// ----- - -spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<16x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" { - // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix size must match}} - %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<16x16xi32, Subgroup>, !spirv.NV.coopmatrix<16x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup> +// CHECK-LABEL: @udiv +spirv.func @udiv(%a: !matA_i32, %b: !matB_i32) "None" { + // CHECK: spirv.UDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix + // CHECK-NEXT: spirv.UDiv {{%.*}}, {{%.*}} : !spirv.coopmatrix + %p = spirv.UDiv %a, %a : !matA_i32 + %q = spirv.UDiv %b, %b : !matB_i32 spirv.Return } -// ----- - -spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<8x8xi32, Subgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" { - // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix size must match}} - %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, !spirv.NV.coopmatrix<8x8xi32, Subgroup> -> !spirv.NV.coopmatrix<8x8xi32, Subgroup> - spirv.Return -} - -// ----- - -spirv.func @cooperative_matrix_muladd(%a : !spirv.NV.coopmatrix<8x16xi32, Subgroup>, %b : !spirv.NV.coopmatrix<16x8xi32, Workgroup>, %c : !spirv.NV.coopmatrix<8x8xi32, Subgroup>) "None" { - // expected-error @+1 {{'spirv.NV.CooperativeMatrixMulAdd' op matrix scope must match}} - %r = spirv.NV.CooperativeMatrixMulAdd %a, %b, %c : !spirv.NV.coopmatrix<8x16xi32, Subgroup... |
…vm#66137) - Check `MakePointer*` load/store attribute values. - Support coop matrix types in `MatrixTimesScalar` verification. - Add test cases for all the remaining ops that accept coop matrix types. - Split NV and KHR tests.
MakePointer*
load/store attribute values.MatrixTimesScalar
verification.