[SPIRV] Support for the SPV_INTEL_subgroup_matrix_multiply_accumulate SPIR-V extension #135225
Conversation
@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

Changes

Adds support for the SPV_INTEL_subgroup_matrix_multiply_accumulate SPIR-V extension according to https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_subgroup_matrix_multiply_accumulate.asciidoc

Patch is 29.95 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/135225.diff

8 Files Affected:
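As a quick orientation, here is a condensed sketch of the usage pattern exercised by the new test in this patch (the function name "example" and its parameters are illustrative only; see the full test in the diff below for all supported type combinations). The SPIR-V friendly builtin takes the K dimension, the A, B and C matrices, and an optional Operands literal, which the lowering emits as an immediate operand of the instruction:

    // Declaration, as in the test's generating source (mma.cl):
    int __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, int Matrix_A, int8 Matrix_B, int Matrix_C, int Operands);

    void example(int iM, int8 iM8) {
      const int i = 42; // K dimension
      // The trailing 0xA literal is the optional Matrix Multiply Accumulate Operands mask.
      int D = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, iM, iM8, iM, 0xA);
      // Expected lowering (matching the CHECK lines in the new test):
      //   %D = OpSubgroupMatrixMultiplyAccumulateINTEL %int32 %const42 %iM %iM8 %iM 10
    }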
diff --git a/llvm/docs/SPIRVUsage.rst b/llvm/docs/SPIRVUsage.rst
index 406dfbea20b73..6ff8034cac00c 100644
--- a/llvm/docs/SPIRVUsage.rst
+++ b/llvm/docs/SPIRVUsage.rst
@@ -211,6 +211,8 @@ list of supported SPIR-V extensions, sorted alphabetically by their extension na
- Adds the ability to specify the maximum error for floating-point operations.
* - ``SPV_INTEL_ternary_bitwise_function``
- Adds a bitwise instruction on three operands and a look-up table index for specifying the bitwise operation to perform.
+ * - ``SPV_INTEL_subgroup_matrix_multiply_accumulate``
+ - Adds an instruction to compute the matrix product of an M x K matrix with a K x N matrix and then add an M x N matrix.
To enable multiple extensions, list them separated by comma. For example, to enable support for atomic operations on floating-point numbers and arbitrary precision integers, use:
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 16364ab30f280..e090fb67b3231 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -1161,9 +1161,15 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call,
MachineRegisterInfo *MRI = MIRBuilder.getMRI();
if (Call->isSpirvOp()) {
- if (GroupBuiltin->NoGroupOperation)
+ if (GroupBuiltin->NoGroupOperation) {
+ SmallVector<uint32_t, 1> ImmArgs;
+ if (GroupBuiltin->Opcode ==
+ SPIRV::OpSubgroupMatrixMultiplyAccumulateINTEL &&
+ Call->Arguments.size() > 4)
+ ImmArgs.push_back(getConstFromIntrinsic(Call->Arguments[4], MRI));
return buildOpFromWrapper(MIRBuilder, GroupBuiltin->Opcode, Call,
- GR->getSPIRVTypeID(Call->ReturnType));
+ GR->getSPIRVTypeID(Call->ReturnType), ImmArgs);
+ }
// Group Operation is a literal
Register GroupOpReg = Call->Arguments[1];
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
index b504e7b04d336..a3f27dde76b65 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
@@ -763,6 +763,7 @@ class GroupBuiltin<string name, Op operation> {
bit NoGroupOperation = !or(IsElect, IsAllOrAny, IsAllEqual,
IsBallot, IsInverseBallot,
IsBallotBitExtract, IsBallotFindBit,
+ !eq(operation, OpSubgroupMatrixMultiplyAccumulateINTEL),
!eq(operation, OpGroupNonUniformShuffle),
!eq(operation, OpGroupNonUniformShuffleXor),
!eq(operation, OpGroupNonUniformShuffleUp),
@@ -847,6 +848,9 @@ defm : DemangledGroupBuiltinWrapper<"__spirv_GroupNonUniformBallotFindLSB", 2, 2
defm : DemangledGroupBuiltin<"group_ballot_find_msb", OnlySub, OpGroupNonUniformBallotFindMSB>;
defm : DemangledGroupBuiltinWrapper<"__spirv_GroupNonUniformBallotFindMSB", 2, 2, OpGroupNonUniformBallotFindMSB>;
+// SPV_INTEL_subgroup_matrix_multiply_accumulate
+defm : DemangledGroupBuiltinWrapper<"__spirv_SubgroupMatrixMultiplyAccumulateINTEL", 4, 5, OpSubgroupMatrixMultiplyAccumulateINTEL>;
+
// cl_khr_subgroup_shuffle
defm : DemangledGroupBuiltin<"group_shuffle", OnlySub, OpGroupNonUniformShuffle>;
defm : DemangledGroupBuiltinWrapper<"__spirv_GroupNonUniformShuffle", 3, 3, OpGroupNonUniformShuffle>;
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index 53e88aa485568..ad0bc5a904682 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -93,6 +93,9 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
SPIRV::Extension::Extension::SPV_INTEL_long_composites},
{"SPV_INTEL_fp_max_error",
SPIRV::Extension::Extension::SPV_INTEL_fp_max_error},
+ {"SPV_INTEL_subgroup_matrix_multiply_accumulate",
+ SPIRV::Extension::Extension::
+ SPV_INTEL_subgroup_matrix_multiply_accumulate},
{"SPV_INTEL_ternary_bitwise_function",
SPIRV::Extension::Extension::SPV_INTEL_ternary_bitwise_function}};
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index 53064ebb51271..6d8c84945d7d4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -736,6 +736,10 @@ def OpGroupFMax: OpGroup<"FMax", 269>;
def OpGroupUMax: OpGroup<"UMax", 270>;
def OpGroupSMax: OpGroup<"SMax", 271>;
+def OpSubgroupMatrixMultiplyAccumulateINTEL: Op<6237, (outs ID:$res),
+ (ins TYPE:$ty, ID:$KDim, ID:$A, ID:$B, ID:$C, variable_ops),
+ "$res = OpSubgroupMatrixMultiplyAccumulateINTEL $ty $KDim $A $B $C">;
+
// TODO: 3.42.22. Device-Side Enqueue Instructions
def OpEnqueueKernel: Op<292, (outs ID:$res), (ins TYPE:$type, ID:$queue, ID:$flags, ID:$NDR, ID:$nevents, ID:$wevents,
ID:$revent, ID:$invoke, ID:$param, ID:$psize, ID:$palign, variable_ops),
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index b1e5e4328cd32..6e1c41d9f20cb 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1799,6 +1799,20 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::LongCompositesINTEL);
break;
}
+ case SPIRV::OpSubgroupMatrixMultiplyAccumulateINTEL: {
+ if (!ST.canUseExtension(
+ SPIRV::Extension::SPV_INTEL_subgroup_matrix_multiply_accumulate))
+ report_fatal_error(
+ "OpSubgroupMatrixMultiplyAccumulateINTEL instruction requires the "
+ "following SPIR-V "
+ "extension: SPV_INTEL_subgroup_matrix_multiply_accumulate",
+ false);
+ Reqs.addExtension(
+ SPIRV::Extension::SPV_INTEL_subgroup_matrix_multiply_accumulate);
+ Reqs.addCapability(
+ SPIRV::Capability::SubgroupMatrixMultiplyAccumulateINTEL);
+ break;
+ }
case SPIRV::OpBitwiseFunctionINTEL: {
if (!ST.canUseExtension(
SPIRV::Extension::SPV_INTEL_ternary_bitwise_function))
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index 0db8a37f8683c..afd3a5206926c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -314,6 +314,7 @@ defm SPV_INTEL_long_composites : ExtensionOperand<117>;
defm SPV_INTEL_memory_access_aliasing : ExtensionOperand<118>;
defm SPV_INTEL_fp_max_error : ExtensionOperand<119>;
defm SPV_INTEL_ternary_bitwise_function : ExtensionOperand<120>;
+defm SPV_INTEL_subgroup_matrix_multiply_accumulate : ExtensionOperand<121>;
//===----------------------------------------------------------------------===//
// Multiclass used to define Capabilities enum values and at the same time
@@ -515,6 +516,7 @@ defm BindlessImagesINTEL : CapabilityOperand<6528, 0, 0, [SPV_INTEL_bindless_ima
defm MemoryAccessAliasingINTEL : CapabilityOperand<5910, 0, 0, [SPV_INTEL_memory_access_aliasing], []>;
defm FPMaxErrorINTEL : CapabilityOperand<6169, 0, 0, [SPV_INTEL_fp_max_error], []>;
defm TernaryBitwiseFunctionINTEL : CapabilityOperand<6241, 0, 0, [SPV_INTEL_ternary_bitwise_function], []>;
+defm SubgroupMatrixMultiplyAccumulateINTEL : CapabilityOperand<6236, 0, 0, [SPV_INTEL_subgroup_matrix_multiply_accumulate], []>;
//===----------------------------------------------------------------------===//
// Multiclass used to define SourceLanguage enum values and at the same time
diff --git a/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_subgroup_matrix_multiply_accumulate/subgroup_matrix_multiply_accumulate_generic.ll b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_subgroup_matrix_multiply_accumulate/subgroup_matrix_multiply_accumulate_generic.ll
new file mode 100644
index 0000000000000..0cd6992936eeb
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_subgroup_matrix_multiply_accumulate/subgroup_matrix_multiply_accumulate_generic.ll
@@ -0,0 +1,229 @@
+; Adapted from Khronos Translator: subgroup_matrix_multiply_accumulate_generic.ll
+
+; generated with mma.cl:
+; #pragma OPENCL EXTENSION cl_khr_fp16 : enable
+;
+; // all combinations of parameter types
+; int __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, int Matrix_A, int8 Matrix_B, int Matrix_C, int Operands);
+; int2 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, int2 Matrix_A, int8 Matrix_B, int2 Matrix_C, int Operands);
+; int4 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, int4 Matrix_A, int8 Matrix_B, int4 Matrix_C, int Operands);
+; int8 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, int8 Matrix_A, int8 Matrix_B, int8 Matrix_C, int Operands);
+;
+; float __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, int Matrix_A, int8 Matrix_B, float Matrix_C, int Operands);
+; float2 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, int2 Matrix_A, int8 Matrix_B, float2 Matrix_C, int Operands);
+; float4 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, int4 Matrix_A, int8 Matrix_B, float4 Matrix_C, int Operands);
+; float8 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, int8 Matrix_A, int8 Matrix_B, float8 Matrix_C, int Operands);
+;
+; int __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, short Matrix_A, int8 Matrix_B, int Matrix_C, int Operands);
+; int2 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, short2 Matrix_A, int8 Matrix_B, int2 Matrix_C, int Operands);
+; int4 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, short4 Matrix_A, int8 Matrix_B, int4 Matrix_C, int Operands);
+; int8 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, short8 Matrix_A, int8 Matrix_B, int8 Matrix_C, int Operands);
+;
+; float __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, short Matrix_A, int8 Matrix_B, float Matrix_C, int Operands);
+; float2 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, short2 Matrix_A, int8 Matrix_B, float2 Matrix_C, int Operands);
+; float4 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, short4 Matrix_A, int8 Matrix_B, float4 Matrix_C, int Operands);
+; float8 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, short8 Matrix_A, int8 Matrix_B, float8 Matrix_C, int Operands);
+;
+; half __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, short Matrix_A, int8 Matrix_B, half Matrix_C, int Operands);
+; half2 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, short2 Matrix_A, int8 Matrix_B, half2 Matrix_C, int Operands);
+; half4 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, short4 Matrix_A, int8 Matrix_B, half4 Matrix_C, int Operands);
+; half8 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, short8 Matrix_A, int8 Matrix_B, half8 Matrix_C, int Operands);
+;
+; short __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, short Matrix_A, int8 Matrix_B, short Matrix_C, int Operands);
+; short2 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, short2 Matrix_A, int8 Matrix_B, short2 Matrix_C, int Operands);
+; short4 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, short4 Matrix_A, int8 Matrix_B, short4 Matrix_C, int Operands);
+; short8 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, short8 Matrix_A, int8 Matrix_B, short8 Matrix_C, int Operands);
+;
+; float __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, float Matrix_A, float8 Matrix_B, float Matrix_C, int Operands);
+; float2 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, float2 Matrix_A, float8 Matrix_B, float2 Matrix_C, int Operands);
+; float4 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, float4 Matrix_A, float8 Matrix_B, float4 Matrix_C, int Operands);
+; float8 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, float8 Matrix_A, float8 Matrix_B, float8 Matrix_C, int Operands);
+;
+; // no operands
+; float4 __spirv_SubgroupMatrixMultiplyAccumulateINTEL(int K_Dim, short4 Matrix_A, int8 Matrix_B, float4 Matrix_C);
+;
+; void foo(int iM, int2 iM2, int4 iM4, int8 iM8,
+; short sM, short2 sM2, short4 sM4, short8 sM8,
+; float fM, float2 fM2, float4 fM4, float8 fM8,
+; half hM, half2 hM2, half4 hM4, half8 hM8) {
+; const int i = 42;
+; int D = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, iM, iM8, iM, 0xA);
+; int2 D2 = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, iM2, iM8, iM2, 0xA);
+; int4 D4 = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, iM4, iM8, iM4, 0xA);
+; int8 D8 = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, iM8, iM8, iM8, 0xA);
+;
+; float fD = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, iM, iM8, fM, 0xA);
+; float2 fD2 = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, iM2, iM8, fM2, 0xA);
+; float4 fD4 = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, iM4, iM8, fM4, 0xA);
+; float8 fD8 = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, iM8, iM8, fM8, 0xA);
+;
+; int sD = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, sM, iM8, iM, 0xA);
+; int2 sD2 = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, sM2, iM8, iM2, 0xA);
+; int4 sD4 = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, sM4, iM8, iM4, 0xA);
+; int8 sD8 = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, sM8, iM8, iM8, 0xA);
+;
+; float sfD = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, sM, iM8, fM, 0xA);
+; float2 sfD2 = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, sM2, iM8, fM2, 0xA);
+; float4 sfD4 = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, sM4, iM8, fM4, 0xA);
+; float8 sfD8 = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, sM8, iM8, fM8, 0xA);
+;
+; half hD = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, sM, iM8, hM, 0xA);
+; half2 hD2 = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, sM2, iM8, hM2, 0xA);
+; half4 hD4 = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, sM4, iM8, hM4, 0xA);
+; half8 hD8 = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, sM8, iM8, hM8, 0xA);
+;
+; short ssD = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, sM, iM8, sM, 0xA);
+; short2 ssD2 = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, sM2, iM8, sM2, 0xA);
+; short4 ssD4 = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, sM4, iM8, sM4, 0xA);
+; short8 ssD8 = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, sM8, iM8, sM8, 0xA);
+;
+; float ffD = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, fM, fM8, fM, 0xA);
+; float2 ffD2 = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, fM2, fM8, fM2, 0xA);
+; float4 ffD4 = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, fM4, fM8, fM4, 0xA);
+; float8 ffD8 = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, fM8, fM8, fM8, 0xA);
+;
+; float4 noOpD4 = __spirv_SubgroupMatrixMultiplyAccumulateINTEL(i, sM4, iM8, fM4);
+; }
+; clang -cc1 -cl-std=clc++2021 -triple spir64-unknown-unknown -emit-llvm -finclude-default-header mma.cl -o tmp.ll
+
+; RUN: not llc -O0 -mtriple=spirv32-unknown-unknown %s -o %t.spvt 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
+; CHECK-ERROR: requires the following SPIR-V extension: SPV_INTEL_subgroup_matrix_multiply_accumulate
+
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_subgroup_matrix_multiply_accumulate %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown --spirv-ext=+SPV_INTEL_subgroup_matrix_multiply_accumulate %s -o - -filetype=obj | spirv-val %}
+
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_subgroup_matrix_multiply_accumulate %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown --spirv-ext=+SPV_INTEL_subgroup_matrix_multiply_accumulate %s -o - -filetype=obj | spirv-val %}
+
+; CHECK: OpCapability SubgroupMatrixMultiplyAccumulateINTEL
+; CHECK: OpExtension "SPV_INTEL_subgroup_matrix_multiply_accumulate"
+; CHECK-DAG: %[[#Int32Ty:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#Int16Ty:]] = OpTypeInt 16 0
+; CHECK-DAG: %[[#Const42:]] = OpConstant %[[#Int32Ty]] 42
+; CHECK-DAG: %[[#VoidTy:]] = OpTypeVoid
+; CHECK-DAG: %[[#Vec2Int32Ty:]] = OpTypeVector %[[#Int32Ty]] 2
+; CHECK-DAG: %[[#Vec4Int32Ty:]] = OpTypeVector %[[#Int32Ty]] 4
+; CHECK-DAG: %[[#Vec8Int32Ty:]] = OpTypeVector %[[#Int32Ty]] 8
+; CHECK-DAG: %[[#Vec2Int16Ty:]] = OpTypeVector %[[#Int16Ty]] 2
+; CHECK-DAG: %[[#Vec4Int16Ty:]] = OpTypeVector %[[#Int16Ty]] 4
+; CHECK-DAG: %[[#Vec8Int16Ty:]] = OpTypeVector %[[#Int16Ty]] 8
+; CHECK-DAG: %[[#FloatTy:]] = OpTypeFloat 32
+; CHECK-DAG: %[[#Vec2FloatTy:]] = OpTypeVector %[[#FloatTy]] 2
+; CHECK-DAG: %[[#Vec4FloatTy:]] = OpTypeVector %[[#FloatTy]] 4
+; CHECK-DAG: %[[#Vec8FloatTy:]] = OpTypeVector %[[#FloatTy]] 8
+; CHECK-DAG: %[[#HalfTy:]] = OpTypeFloat 16
+; CHECK-DAG: %[[#Vec2HalfTy:]] = OpTypeVector %[[#HalfTy]] 2
+; CHECK-DAG: %[[#Vec4HalfTy:]] = OpTypeVector %[[#HalfTy]] 4
+; CHECK-DAG: %[[#Vec8HalfTy:]] = OpTypeVector %[[#HalfTy]] 8
+; CHECK: %[[#iM:]] = OpFunctionParameter %[[#Int32Ty]]
+; CHECK: %[[#iM2:]] = OpFunctionParameter %[[#Vec2Int32Ty]]
+; CHECK: %[[#iM4:]] = OpFunctionParameter %[[#Vec4Int32Ty]]
+; CHECK: %[[#iM8:]] = OpFunctionParameter %[[#Vec8Int32Ty]]
+; CHECK: %[[#sM:]] = OpFunctionParameter %[[#Int16Ty]]
+; CHECK: %[[#sM2:]] = OpFunctionParameter %[[#Vec2Int16Ty]]
+; CHECK: %[[#sM4:]] = OpFunctionParameter %[[#Vec4Int16Ty]]
+; CHECK: %[[#sM8:]] = OpFunctionParameter %[[#Vec8Int16Ty]]
+; CHECK: %[[#fM:]] = OpFunctionParameter %[[#FloatTy]]
+; CHECK: %[[#fM2:]] = OpFunctionParameter %[[#Vec2FloatTy]]
+; CHECK: %[[#fM4:]] = OpFunctionParameter %[[#Vec4FloatTy]]
+; CHECK: %[[#fM8:]] = OpFunctionParameter %[[#Vec8FloatTy]]
+; CHECK: %[[#hM:]] = OpFunctionParameter %[[#HalfTy]]
+; CHECK: %[[#hM2:]] = OpFunctionParameter %[[#Vec2HalfTy]]
+; CHECK: %[[#hM4:]] = OpFunctionParameter %[[#Vec4HalfTy]]
+; CHECK: %[[#hM8:]] = OpFunctionParameter %[[#Vec8HalfTy]]
+; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Int32Ty]] %[[#Const42]] %[[#iM]] %[[#iM8]] %[[#iM]] 10
+; CHECK: %[[#]] = OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec2Int32Ty]] %[[#Const42]] %[[#iM2]] %[[#iM8]] %[[#iM2]] 10
+; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec4Int32Ty]] %[[#Const42]] %[[#iM4]] %[[#iM8]] %[[#iM4]] 10
+; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec8Int32Ty]] %[[#Const42]] %[[#iM8]] %[[#iM8]] %[[#iM8]] 10
+; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#FloatTy]] %[[#Const42]] %[[#iM]] %[[#iM8]] %[[#fM]] 10
+; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec2FloatTy]] %[[#Const42]] %[[#iM2]] %[[#iM8]] %[[#fM2]] 10
+; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec4FloatTy]] %[[#Const42]] %[[#iM4]] %[[#iM8]] %[[#fM4]] 10
+; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec8FloatTy]] %[[#Const42]] %[[#iM8]] %[[#iM8]] %[[#fM8]] 10
+; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Int32Ty]] %[[#Const42]] %[[#sM]] %[[#iM8]] %[[#iM]] 10
+; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec2Int32Ty]] %[[#Const42]] %[[#sM2]] %[[#iM8]] %[[#iM2]] 10
+; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec4Int32Ty]] %[[#Const42]] %[[#sM4]] %[[#iM8]] %[[#iM4]] 10
+; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec8Int32Ty]] %[[#Const42]] %[[#sM8]] %[[#iM8]] %[[#iM8]] 10
+; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#FloatTy]] %[[#Const42]] %[[#sM]] %[[#iM8]] %[[#fM]] 10
+; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec2FloatTy]] %[[#Const42]] %[[#sM2]] %[[#iM8]] %[[#fM2]] 10
+; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec4FloatTy]] %[[#Const42]] %[[#sM4]] %[[#iM8]] %[[#fM4]] 10
+; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec8FloatTy]] %[[#Const42]] %[[#sM8]] %[[#iM8]] %[[#fM8]] 10
+; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#HalfTy]] %[[#Const42]] %[[#sM]] %[[#iM8]] %[[#hM]] 10
+; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec2HalfTy]] %[[#Const42]] %[[#sM2]] %[[#iM8]] %[[#hM2]] 10
+; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec4HalfTy]] %[[#Const42]] %[[#sM4]] %[[#iM8]] %[[#hM4]] 10
+; CHECK: OpSubgroupMatrixMultiplyAccumulateINTEL %[[#Vec8HalfTy]] %[[#Const42]] %[[#sM8]] %[[#iM8]] %[[#hM8]]...
[truncated]
@MrSidims can you please have a look