Skip to content

Commit d057b53

Browse files
authored
[SPIR-V] Add SPV_INTEL_joint_matrix extension (#118578)
The spec is available here: intel/llvm#12497 The PR doesn't add OpCooperativeMatrixApplyFunctionINTEL instruction as it's still experimental and not properly tested E2E. The PR also fixes few bugs in the related code: 1. CooperativeMatrixMulAddKHR optional operand must be literal, not a constant; 2. Fixed available capabilities table creation for a case, when a single extension adds few capabilities, that occupy not contiguous op codes. --------- Signed-off-by: Sidorov, Dmitry <[email protected]>
1 parent fe4bba6 commit d057b53

17 files changed

+531
-8
lines changed

llvm/docs/SPIRVUsage.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,8 @@ list of supported SPIR-V extensions, sorted alphabetically by their extension na
179179
- Introduces two new storage classes that are subclasses of the CrossWorkgroup storage class that provides additional information that can enable optimization.
180180
* - ``SPV_INTEL_variable_length_array``
181181
- Allows to allocate local arrays whose number of elements is unknown at compile time.
182+
* - ``SPV_INTEL_joint_matrix``
183+
- Adds few matrix capabilities on top of SPV_KHR_cooperative_matrix extension, such as matrix prefetch, get element coordinate and checked load/store/construct instructions, tensor float 32 and bfloat type interpretations for multuply-add instruction.
182184
* - ``SPV_KHR_bit_instructions``
183185
- Enables bit instructions to be used by SPIR-V modules without requiring the Shader capability.
184186
* - ``SPV_KHR_expect_assume``

llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,12 @@ getCapabilitiesEnabledByExtension(SPIRV::Extension::Extension Extension) {
137137

138138
CapabilityList Capabilities;
139139
while (Entry &&
140-
Entry->Category == SPIRV::OperandCategory::CapabilityOperand &&
141-
Entry->ReqExtension == Extension) {
140+
Entry->Category == SPIRV::OperandCategory::CapabilityOperand) {
141+
// Some capabilities' codes might go not in order.
142+
if (Entry->ReqExtension != Extension) {
143+
++Entry;
144+
continue;
145+
}
142146
Capabilities.push_back(
143147
static_cast<SPIRV::Capability::Capability>(Entry->Value));
144148
++Entry;

llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,16 @@ namespace Opcode {
207207
#include "SPIRVGenTables.inc"
208208
} // namespace Opcode
209209

210+
namespace CooperativeMatrixLayout {
211+
#define GET_CooperativeMatrixLayout_DECL
212+
#include "SPIRVGenTables.inc"
213+
} // namespace CooperativeMatrixLayout
214+
215+
namespace CooperativeMatrixOperands {
216+
#define GET_CooperativeMatrixOperands_DECL
217+
#include "SPIRVGenTables.inc"
218+
} // namespace CooperativeMatrixOperands
219+
210220
struct ExtendedBuiltin {
211221
StringRef Name;
212222
InstructionSet::InstructionSet Set;

llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVInstPrinter.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,34 @@ void SPIRVInstPrinter::printInst(const MCInst *MI, uint64_t Address,
211211
// are part of the variable value.
212212
printOpConstantVarOps(MI, NumFixedOps - 1, OS);
213213
break;
214+
case SPIRV::OpCooperativeMatrixMulAddKHR: {
215+
const unsigned NumOps = MI->getNumOperands();
216+
if (NumFixedOps == NumOps)
217+
break;
218+
219+
OS << ' ';
220+
const unsigned MulAddOp = MI->getOperand(FirstVariableIndex).getImm();
221+
if (MulAddOp == 0) {
222+
printSymbolicOperand<
223+
OperandCategory::CooperativeMatrixOperandsOperand>(
224+
MI, FirstVariableIndex, OS);
225+
} else {
226+
std::string Buffer;
227+
for (unsigned Mask = 0x1;
228+
Mask != SPIRV::CooperativeMatrixOperands::
229+
MatrixResultBFloat16ComponentsINTEL;
230+
Mask <<= 1) {
231+
if (MulAddOp & Mask) {
232+
if (!Buffer.empty())
233+
Buffer += '|';
234+
Buffer += getSymbolicOperandMnemonic(
235+
OperandCategory::CooperativeMatrixOperandsOperand, Mask);
236+
}
237+
}
238+
OS << Buffer;
239+
}
240+
break;
241+
}
214242
default:
215243
printRemainingVariableOps(MI, NumFixedOps, OS);
216244
break;

llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1969,15 +1969,49 @@ static bool generateCoopMatrInst(const SPIRV::IncomingCall *Call,
19691969
const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
19701970
unsigned Opcode =
19711971
SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
1972-
bool IsSet = Opcode != SPIRV::OpCooperativeMatrixStoreKHR;
1972+
bool IsSet = Opcode != SPIRV::OpCooperativeMatrixStoreKHR &&
1973+
Opcode != SPIRV::OpCooperativeMatrixStoreCheckedINTEL &&
1974+
Opcode != SPIRV::OpCooperativeMatrixPrefetchINTEL;
19731975
unsigned ArgSz = Call->Arguments.size();
19741976
unsigned LiteralIdx = 0;
1975-
if (Opcode == SPIRV::OpCooperativeMatrixLoadKHR && ArgSz > 3)
1976-
LiteralIdx = 3;
1977-
else if (Opcode == SPIRV::OpCooperativeMatrixStoreKHR && ArgSz > 4)
1978-
LiteralIdx = 4;
1977+
switch (Opcode) {
1978+
// Memory operand is optional and is literal.
1979+
case SPIRV::OpCooperativeMatrixLoadKHR:
1980+
LiteralIdx = ArgSz > 3 ? 3 : 0;
1981+
break;
1982+
case SPIRV::OpCooperativeMatrixStoreKHR:
1983+
LiteralIdx = ArgSz > 4 ? 4 : 0;
1984+
break;
1985+
case SPIRV::OpCooperativeMatrixLoadCheckedINTEL:
1986+
LiteralIdx = ArgSz > 7 ? 7 : 0;
1987+
break;
1988+
case SPIRV::OpCooperativeMatrixStoreCheckedINTEL:
1989+
LiteralIdx = ArgSz > 8 ? 8 : 0;
1990+
break;
1991+
// Cooperative Matrix Operands operand is optional and is literal.
1992+
case SPIRV::OpCooperativeMatrixMulAddKHR:
1993+
LiteralIdx = ArgSz > 3 ? 3 : 0;
1994+
break;
1995+
};
1996+
19791997
SmallVector<uint32_t, 1> ImmArgs;
19801998
MachineRegisterInfo *MRI = MIRBuilder.getMRI();
1999+
if (Opcode == SPIRV::OpCooperativeMatrixPrefetchINTEL) {
2000+
const uint32_t CacheLevel = getConstFromIntrinsic(Call->Arguments[3], MRI);
2001+
auto MIB = MIRBuilder.buildInstr(SPIRV::OpCooperativeMatrixPrefetchINTEL)
2002+
.addUse(Call->Arguments[0]) // pointer
2003+
.addUse(Call->Arguments[1]) // rows
2004+
.addUse(Call->Arguments[2]) // columns
2005+
.addImm(CacheLevel) // cache level
2006+
.addUse(Call->Arguments[4]); // memory layout
2007+
if (ArgSz > 5)
2008+
MIB.addUse(Call->Arguments[5]); // stride
2009+
if (ArgSz > 6) {
2010+
const uint32_t MemOp = getConstFromIntrinsic(Call->Arguments[6], MRI);
2011+
MIB.addImm(MemOp); // memory operand
2012+
}
2013+
return true;
2014+
}
19812015
if (LiteralIdx > 0)
19822016
ImmArgs.push_back(getConstFromIntrinsic(Call->Arguments[LiteralIdx], MRI));
19832017
Register TypeReg = GR->getSPIRVTypeID(Call->ReturnType);

llvm/lib/Target/SPIRV/SPIRVBuiltins.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -695,6 +695,13 @@ defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixStoreKHR", OpenCL_std, C
695695
defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixMulAddKHR", OpenCL_std, CoopMatr, 3, 4, OpCooperativeMatrixMulAddKHR>;
696696
defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixLengthKHR", OpenCL_std, CoopMatr, 1, 1, OpCooperativeMatrixLengthKHR>;
697697

698+
// Cooperative Matrix Intel builtin records:
699+
defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixPrefetchINTEL", OpenCL_std, CoopMatr, 5, 7, OpCooperativeMatrixPrefetchINTEL>;
700+
defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixLoadCheckedINTEL", OpenCL_std, CoopMatr, 6, 8, OpCooperativeMatrixLoadCheckedINTEL>;
701+
defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixStoreCheckedINTEL", OpenCL_std, CoopMatr, 7, 9, OpCooperativeMatrixStoreCheckedINTEL>;
702+
defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixConstructCheckedINTEL", OpenCL_std, CoopMatr, 5, 5, OpCooperativeMatrixConstructCheckedINTEL>;
703+
defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixGetElementCoordINTEL", OpenCL_std, CoopMatr, 2, 2, OpCooperativeMatrixGetElementCoordINTEL>;
704+
698705
//===----------------------------------------------------------------------===//
699706
// Class defining a work/sub group builtin that should be translated into a
700707
// SPIR-V instruction using the defined properties.

llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ static const std::map<std::string, SPIRV::Extension::Extension, std::less<>>
5151
SPIRV::Extension::Extension::SPV_INTEL_subgroups},
5252
{"SPV_INTEL_media_block_io",
5353
SPIRV::Extension::Extension::SPV_INTEL_media_block_io},
54+
{"SPV_INTEL_joint_matrix",
55+
SPIRV::Extension::Extension::SPV_INTEL_joint_matrix},
5456
{"SPV_KHR_uniform_group_instructions",
5557
SPIRV::Extension::Extension::SPV_KHR_uniform_group_instructions},
5658
{"SPV_KHR_no_integer_wrap_decoration",

llvm/lib/Target/SPIRV/SPIRVInstrInfo.td

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -895,6 +895,23 @@ def OpCooperativeMatrixMulAddKHR: Op<4459, (outs ID:$res),
895895
def OpCooperativeMatrixLengthKHR: Op<4460, (outs ID:$res), (ins TYPE:$type, ID:$coop_matr_type),
896896
"$res = OpCooperativeMatrixLengthKHR $type $coop_matr_type">;
897897

898+
// SPV_INTEL_joint_matrix
899+
def OpCooperativeMatrixLoadCheckedINTEL: Op<6193, (outs ID:$res),
900+
(ins TYPE:$resType, ID:$pointer, ID:$xOffset, ID:$yOffset, ID:$memory_layout, ID:$height, ID:$width, variable_ops),
901+
"$res = OpCooperativeMatrixLoadCheckedINTEL $resType $pointer $xOffset $yOffset $memory_layout $height $width">;
902+
def OpCooperativeMatrixStoreCheckedINTEL: Op<6194, (outs),
903+
(ins ID:$pointer, ID:$xOffset, ID:$yOffset, ID:$objectToStore, ID:$memory_layout, ID:$height, ID:$width, variable_ops),
904+
"OpCooperativeMatrixStoreCheckedINTEL $pointer $xOffset $yOffset $objectToStore $memory_layout $height $width">;
905+
def OpCooperativeMatrixConstructCheckedINTEL: Op<6195, (outs ID:$res),
906+
(ins TYPE:$resType, ID:$xOffset, ID:$yOffset, ID:$height, ID:$width, ID:$value),
907+
"$res = OpCooperativeMatrixConstructCheckedINTEL $resType $xOffset $yOffset $height $width $value">;
908+
def OpCooperativeMatrixGetElementCoordINTEL: Op<6440, (outs ID:$res),
909+
(ins TYPE:$resType, ID:$matrix, ID:$index),
910+
"$res = OpCooperativeMatrixGetElementCoordINTEL $resType $matrix $index">;
911+
def OpCooperativeMatrixPrefetchINTEL: Op<6449, (outs),
912+
(ins ID:$pointer, ID:$rows, ID:$columns, i32imm:$cacheLevel, ID:$memory_layout, variable_ops),
913+
"OpCooperativeMatrixPrefetchINTEL $pointer $rows $columns $cacheLevel $memory_layout">;
914+
898915
// SPV_EXT_arithmetic_fence
899916
def OpArithmeticFenceEXT: Op<6145, (outs ID:$res), (ins TYPE:$type, ID:$target),
900917
"$res = OpArithmeticFenceEXT $type $target">;

llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1437,6 +1437,138 @@ void addInstrRequirements(const MachineInstr &MI,
14371437
Reqs.addCapability(SPIRV::Capability::SplitBarrierINTEL);
14381438
}
14391439
break;
1440+
case SPIRV::OpCooperativeMatrixMulAddKHR: {
1441+
if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
1442+
report_fatal_error("Cooperative matrix instructions require the "
1443+
"following SPIR-V extension: "
1444+
"SPV_KHR_cooperative_matrix",
1445+
false);
1446+
Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
1447+
Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
1448+
constexpr unsigned MulAddMaxSize = 6;
1449+
if (MI.getNumOperands() != MulAddMaxSize)
1450+
break;
1451+
const int64_t CoopOperands = MI.getOperand(MulAddMaxSize - 1).getImm();
1452+
if (CoopOperands &
1453+
SPIRV::CooperativeMatrixOperands::MatrixAAndBTF32ComponentsINTEL) {
1454+
if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
1455+
report_fatal_error("MatrixAAndBTF32ComponentsINTEL type interpretation "
1456+
"require the following SPIR-V extension: "
1457+
"SPV_INTEL_joint_matrix",
1458+
false);
1459+
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
1460+
Reqs.addCapability(
1461+
SPIRV::Capability::CooperativeMatrixTF32ComponentTypeINTEL);
1462+
}
1463+
if (CoopOperands & SPIRV::CooperativeMatrixOperands::
1464+
MatrixAAndBBFloat16ComponentsINTEL ||
1465+
CoopOperands &
1466+
SPIRV::CooperativeMatrixOperands::MatrixCBFloat16ComponentsINTEL ||
1467+
CoopOperands & SPIRV::CooperativeMatrixOperands::
1468+
MatrixResultBFloat16ComponentsINTEL) {
1469+
if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
1470+
report_fatal_error("***BF16ComponentsINTEL type interpretations "
1471+
"require the following SPIR-V extension: "
1472+
"SPV_INTEL_joint_matrix",
1473+
false);
1474+
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
1475+
Reqs.addCapability(
1476+
SPIRV::Capability::CooperativeMatrixBFloat16ComponentTypeINTEL);
1477+
}
1478+
break;
1479+
}
1480+
case SPIRV::OpCooperativeMatrixLoadKHR:
1481+
case SPIRV::OpCooperativeMatrixStoreKHR:
1482+
case SPIRV::OpCooperativeMatrixLoadCheckedINTEL:
1483+
case SPIRV::OpCooperativeMatrixStoreCheckedINTEL:
1484+
case SPIRV::OpCooperativeMatrixPrefetchINTEL: {
1485+
if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
1486+
report_fatal_error("Cooperative matrix instructions require the "
1487+
"following SPIR-V extension: "
1488+
"SPV_KHR_cooperative_matrix",
1489+
false);
1490+
Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
1491+
Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
1492+
1493+
// Check Layout operand in case if it's not a standard one and add the
1494+
// appropriate capability.
1495+
std::unordered_map<unsigned, unsigned> LayoutToInstMap = {
1496+
{SPIRV::OpCooperativeMatrixLoadKHR, 3},
1497+
{SPIRV::OpCooperativeMatrixStoreKHR, 2},
1498+
{SPIRV::OpCooperativeMatrixLoadCheckedINTEL, 5},
1499+
{SPIRV::OpCooperativeMatrixStoreCheckedINTEL, 4},
1500+
{SPIRV::OpCooperativeMatrixPrefetchINTEL, 4}};
1501+
1502+
const auto OpCode = MI.getOpcode();
1503+
const unsigned LayoutNum = LayoutToInstMap[OpCode];
1504+
Register RegLayout = MI.getOperand(LayoutNum).getReg();
1505+
const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
1506+
MachineInstr *MILayout = MRI.getUniqueVRegDef(RegLayout);
1507+
if (MILayout->getOpcode() == SPIRV::OpConstantI) {
1508+
const unsigned LayoutVal = MILayout->getOperand(2).getImm();
1509+
if (LayoutVal ==
1510+
static_cast<unsigned>(SPIRV::CooperativeMatrixLayout::PackedINTEL)) {
1511+
if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
1512+
report_fatal_error("PackedINTEL layout require the following SPIR-V "
1513+
"extension: SPV_INTEL_joint_matrix",
1514+
false);
1515+
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
1516+
Reqs.addCapability(SPIRV::Capability::PackedCooperativeMatrixINTEL);
1517+
}
1518+
}
1519+
1520+
// Nothing to do.
1521+
if (OpCode == SPIRV::OpCooperativeMatrixLoadKHR ||
1522+
OpCode == SPIRV::OpCooperativeMatrixStoreKHR)
1523+
break;
1524+
1525+
std::string InstName;
1526+
switch (OpCode) {
1527+
case SPIRV::OpCooperativeMatrixPrefetchINTEL:
1528+
InstName = "OpCooperativeMatrixPrefetchINTEL";
1529+
break;
1530+
case SPIRV::OpCooperativeMatrixLoadCheckedINTEL:
1531+
InstName = "OpCooperativeMatrixLoadCheckedINTEL";
1532+
break;
1533+
case SPIRV::OpCooperativeMatrixStoreCheckedINTEL:
1534+
InstName = "OpCooperativeMatrixStoreCheckedINTEL";
1535+
break;
1536+
}
1537+
1538+
if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix)) {
1539+
const std::string ErrorMsg =
1540+
InstName + " instruction requires the "
1541+
"following SPIR-V extension: SPV_INTEL_joint_matrix";
1542+
report_fatal_error(ErrorMsg.c_str(), false);
1543+
}
1544+
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
1545+
if (OpCode == SPIRV::OpCooperativeMatrixPrefetchINTEL) {
1546+
Reqs.addCapability(SPIRV::Capability::CooperativeMatrixPrefetchINTEL);
1547+
break;
1548+
}
1549+
Reqs.addCapability(
1550+
SPIRV::Capability::CooperativeMatrixCheckedInstructionsINTEL);
1551+
break;
1552+
}
1553+
case SPIRV::OpCooperativeMatrixConstructCheckedINTEL:
1554+
if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
1555+
report_fatal_error("OpCooperativeMatrixConstructCheckedINTEL "
1556+
"instructions require the following SPIR-V extension: "
1557+
"SPV_INTEL_joint_matrix",
1558+
false);
1559+
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
1560+
Reqs.addCapability(
1561+
SPIRV::Capability::CooperativeMatrixCheckedInstructionsINTEL);
1562+
break;
1563+
case SPIRV::OpCooperativeMatrixGetElementCoordINTEL:
1564+
if (!ST.canUseExtension(SPIRV::Extension::SPV_INTEL_joint_matrix))
1565+
report_fatal_error("OpCooperativeMatrixGetElementCoordINTEL requires the "
1566+
"following SPIR-V extension: SPV_INTEL_joint_matrix",
1567+
false);
1568+
Reqs.addExtension(SPIRV::Extension::SPV_INTEL_joint_matrix);
1569+
Reqs.addCapability(
1570+
SPIRV::Capability::CooperativeMatrixInvocationInstructionsINTEL);
1571+
break;
14401572
case SPIRV::OpKill: {
14411573
Reqs.addCapability(SPIRV::Capability::Shader);
14421574
} break;

0 commit comments

Comments
 (0)