Skip to content

Commit 8c47f23

Browse files
[SPIRV] Support for the SPV_INTEL_subgroup_matrix_multiply_accumulate SPIR-V extension (#135225)
Adds support for the SPV_INTEL_subgroup_matrix_multiply_accumulate SPIR-V extension according to https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_subgroup_matrix_multiply_accumulate.asciidoc
1 parent 3c3fb35 commit 8c47f23

File tree

8 files changed

+266
-2
lines changed

8 files changed

+266
-2
lines changed

llvm/docs/SPIRVUsage.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,8 @@ list of supported SPIR-V extensions, sorted alphabetically by their extension na
211211
- Adds the ability to specify the maximum error for floating-point operations.
212212
* - ``SPV_INTEL_ternary_bitwise_function``
213213
- Adds a bitwise instruction on three operands and a look-up table index for specifying the bitwise operation to perform.
214+
* - ``SPV_INTEL_subgroup_matrix_multiply_accumulate``
215+
- Adds an instruction to compute the matrix product of an M x K matrix with a K x N matrix and then add an M x N matrix.
214216

215217
To enable multiple extensions, list them separated by comma. For example, to enable support for atomic operations on floating-point numbers and arbitrary precision integers, use:
216218

llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1161,9 +1161,15 @@ static bool generateGroupInst(const SPIRV::IncomingCall *Call,
11611161

11621162
MachineRegisterInfo *MRI = MIRBuilder.getMRI();
11631163
if (Call->isSpirvOp()) {
1164-
if (GroupBuiltin->NoGroupOperation)
1164+
if (GroupBuiltin->NoGroupOperation) {
1165+
SmallVector<uint32_t, 1> ImmArgs;
1166+
if (GroupBuiltin->Opcode ==
1167+
SPIRV::OpSubgroupMatrixMultiplyAccumulateINTEL &&
1168+
Call->Arguments.size() > 4)
1169+
ImmArgs.push_back(getConstFromIntrinsic(Call->Arguments[4], MRI));
11651170
return buildOpFromWrapper(MIRBuilder, GroupBuiltin->Opcode, Call,
1166-
GR->getSPIRVTypeID(Call->ReturnType));
1171+
GR->getSPIRVTypeID(Call->ReturnType), ImmArgs);
1172+
}
11671173

11681174
// Group Operation is a literal
11691175
Register GroupOpReg = Call->Arguments[1];

llvm/lib/Target/SPIRV/SPIRVBuiltins.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -763,6 +763,7 @@ class GroupBuiltin<string name, Op operation> {
763763
bit NoGroupOperation = !or(IsElect, IsAllOrAny, IsAllEqual,
764764
IsBallot, IsInverseBallot,
765765
IsBallotBitExtract, IsBallotFindBit,
766+
!eq(operation, OpSubgroupMatrixMultiplyAccumulateINTEL),
766767
!eq(operation, OpGroupNonUniformShuffle),
767768
!eq(operation, OpGroupNonUniformShuffleXor),
768769
!eq(operation, OpGroupNonUniformShuffleUp),
@@ -847,6 +848,9 @@ defm : DemangledGroupBuiltinWrapper<"__spirv_GroupNonUniformBallotFindLSB", 2, 2
847848
defm : DemangledGroupBuiltin<"group_ballot_find_msb", OnlySub, OpGroupNonUniformBallotFindMSB>;
848849
defm : DemangledGroupBuiltinWrapper<"__spirv_GroupNonUniformBallotFindMSB", 2, 2, OpGroupNonUniformBallotFindMSB>;
849850

851+
// SPV_INTEL_subgroup_matrix_multiply_accumulate
852+
defm : DemangledGroupBuiltinWrapper<"__spirv_SubgroupMatrixMultiplyAccumulateINTEL", 4, 5, OpSubgroupMatrixMultiplyAccumulateINTEL>;
853+
850854
// cl_khr_subgroup_shuffle
851855
defm : DemangledGroupBuiltin<"group_shuffle", OnlySub, OpGroupNonUniformShuffle>;
852856
defm : DemangledGroupBuiltinWrapper<"__spirv_GroupNonUniformShuffle", 3, 3, OpGroupNonUniformShuffle>;

llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,9 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
9393
SPIRV::Extension::Extension::SPV_INTEL_long_composites},
9494
{"SPV_INTEL_fp_max_error",
9595
SPIRV::Extension::Extension::SPV_INTEL_fp_max_error},
96+
{"SPV_INTEL_subgroup_matrix_multiply_accumulate",
97+
SPIRV::Extension::Extension::
98+
SPV_INTEL_subgroup_matrix_multiply_accumulate},
9699
{"SPV_INTEL_ternary_bitwise_function",
97100
SPIRV::Extension::Extension::SPV_INTEL_ternary_bitwise_function}};
98101

llvm/lib/Target/SPIRV/SPIRVInstrInfo.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -736,6 +736,10 @@ def OpGroupFMax: OpGroup<"FMax", 269>;
736736
def OpGroupUMax: OpGroup<"UMax", 270>;
737737
def OpGroupSMax: OpGroup<"SMax", 271>;
738738

739+
def OpSubgroupMatrixMultiplyAccumulateINTEL: Op<6237, (outs ID:$res),
740+
(ins TYPE:$ty, ID:$KDim, ID:$A, ID:$B, ID:$C, variable_ops),
741+
"$res = OpSubgroupMatrixMultiplyAccumulateINTEL $ty $KDim $A $B $C">;
742+
739743
// TODO: 3.42.22. Device-Side Enqueue Instructions
740744
def OpEnqueueKernel: Op<292, (outs ID:$res), (ins TYPE:$type, ID:$queue, ID:$flags, ID:$NDR, ID:$nevents, ID:$wevents,
741745
ID:$revent, ID:$invoke, ID:$param, ID:$psize, ID:$palign, variable_ops),

llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1799,6 +1799,20 @@ void addInstrRequirements(const MachineInstr &MI,
17991799
Reqs.addCapability(SPIRV::Capability::LongCompositesINTEL);
18001800
break;
18011801
}
1802+
case SPIRV::OpSubgroupMatrixMultiplyAccumulateINTEL: {
1803+
if (!ST.canUseExtension(
1804+
SPIRV::Extension::SPV_INTEL_subgroup_matrix_multiply_accumulate))
1805+
report_fatal_error(
1806+
"OpSubgroupMatrixMultiplyAccumulateINTEL instruction requires the "
1807+
"following SPIR-V "
1808+
"extension: SPV_INTEL_subgroup_matrix_multiply_accumulate",
1809+
false);
1810+
Reqs.addExtension(
1811+
SPIRV::Extension::SPV_INTEL_subgroup_matrix_multiply_accumulate);
1812+
Reqs.addCapability(
1813+
SPIRV::Capability::SubgroupMatrixMultiplyAccumulateINTEL);
1814+
break;
1815+
}
18021816
case SPIRV::OpBitwiseFunctionINTEL: {
18031817
if (!ST.canUseExtension(
18041818
SPIRV::Extension::SPV_INTEL_ternary_bitwise_function))

llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,7 @@ defm SPV_INTEL_long_composites : ExtensionOperand<117>;
314314
defm SPV_INTEL_memory_access_aliasing : ExtensionOperand<118>;
315315
defm SPV_INTEL_fp_max_error : ExtensionOperand<119>;
316316
defm SPV_INTEL_ternary_bitwise_function : ExtensionOperand<120>;
317+
defm SPV_INTEL_subgroup_matrix_multiply_accumulate : ExtensionOperand<121>;
317318

318319
//===----------------------------------------------------------------------===//
319320
// Multiclass used to define Capabilities enum values and at the same time
@@ -515,6 +516,7 @@ defm BindlessImagesINTEL : CapabilityOperand<6528, 0, 0, [SPV_INTEL_bindless_ima
515516
defm MemoryAccessAliasingINTEL : CapabilityOperand<5910, 0, 0, [SPV_INTEL_memory_access_aliasing], []>;
516517
defm FPMaxErrorINTEL : CapabilityOperand<6169, 0, 0, [SPV_INTEL_fp_max_error], []>;
517518
defm TernaryBitwiseFunctionINTEL : CapabilityOperand<6241, 0, 0, [SPV_INTEL_ternary_bitwise_function], []>;
519+
defm SubgroupMatrixMultiplyAccumulateINTEL : CapabilityOperand<6236, 0, 0, [SPV_INTEL_subgroup_matrix_multiply_accumulate], []>;
518520

519521
//===----------------------------------------------------------------------===//
520522
// Multiclass used to define SourceLanguage enum values and at the same time

0 commit comments

Comments
 (0)