Skip to content

Commit a7bf89f

Browse files
svenvhvmaksimo
authored andcommitted
Add SPV_KHR_integer_dot_product support
Add support for mapping the SPV_KHR_integer_dot_product extension operations to and from SPIR-V friendly IR. Original commit: KhronosGroup/SPIRV-LLVM-Translator@faea0eb
1 parent f381bde commit a7bf89f

File tree

6 files changed

+676
-0
lines changed

6 files changed

+676
-0
lines changed

llvm-spirv/include/LLVMSPIRVExtensions.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ EXT(SPV_EXT_shader_atomic_float_min_max)
77
EXT(SPV_KHR_no_integer_wrap_decoration)
88
EXT(SPV_KHR_float_controls)
99
EXT(SPV_KHR_linkonce_odr)
10+
EXT(SPV_KHR_integer_dot_product)
1011
EXT(SPV_INTEL_subgroups)
1112
EXT(SPV_INTEL_media_block_io)
1213
EXT(SPV_INTEL_device_side_avc_motion_estimation)

llvm-spirv/lib/SPIRV/SPIRVReader.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3578,6 +3578,9 @@ Instruction *SPIRVToLLVM::transSPIRVBuiltinFromInst(SPIRVInstruction *BI,
35783578
if (OC == OpImageSampleExplicitLod)
35793579
AddRetTypePostfix = true;
35803580

3581+
if (OpSDotKHR <= BI->getOpCode() && BI->getOpCode() <= OpSUDotAccSatKHR)
3582+
AddRetTypePostfix = true;
3583+
35813584
if (AddRetTypePostfix) {
35823585
const Type *RetTy =
35833586
BI->hasType() ? transType(BI->getType()) : Type::getVoidTy(*Context);

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVInstruction.h

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2818,6 +2818,88 @@ class SPIRVExpectINTELInstBase : public SPIRVInstTemplateBase {
28182818
_SPIRV_OP_INTERNAL(ExpectINTEL, true, 5)
28192819
#undef _SPIRV_OP_INTERNAL
28202820

2821+
class SPIRVDotKHRBase : public SPIRVInstTemplateBase {
2822+
protected:
2823+
SPIRVCapVec getRequiredCapability() const override {
2824+
// Both vector operands must have the same type, so analyzing the
2825+
// first operand will suffice.
2826+
SPIRVCapabilityKind ArgCap = getRequiredCapabilityForOperand(Ops[0]);
2827+
return getVec(ArgCap, CapabilityDotProductKHR);
2828+
}
2829+
2830+
llvm::Optional<ExtensionID> getRequiredExtension() const override {
2831+
return ExtensionID::SPV_KHR_integer_dot_product;
2832+
}
2833+
2834+
void validate() const override {
2835+
SPIRVInstruction::validate();
2836+
SPIRVId Vec1 = Ops[0];
2837+
SPIRVId Vec2 = Ops[1];
2838+
(void)Vec1;
2839+
(void)Vec2;
2840+
2841+
assert(getValueType(Vec1) == getValueType(Vec2) &&
2842+
"Input vectors must have the same type");
2843+
assert(getType()->isTypeInt() && "Result type must be an integer type");
2844+
assert(!getType()->isTypeVector() && "Result type must be scalar");
2845+
}
2846+
2847+
private:
2848+
bool isAccSat() const {
2849+
return (OpCode == OpSDotAccSatKHR || OpCode == OpUDotAccSatKHR ||
2850+
OpCode == OpSUDotAccSatKHR);
2851+
}
2852+
2853+
Optional<PackedVectorFormat> getPackedVectorFormat() const {
2854+
size_t PackFmtIdx = 2;
2855+
if (isAccSat()) {
2856+
// AccSat instructions have an additional Accumulator operand.
2857+
PackFmtIdx++;
2858+
}
2859+
2860+
if (PackFmtIdx == Ops.size() - 1)
2861+
return static_cast<PackedVectorFormat>(Ops[PackFmtIdx]);
2862+
2863+
return None;
2864+
}
2865+
2866+
SPIRVCapabilityKind getRequiredCapabilityForOperand(SPIRVId ArgId) const {
2867+
const SPIRVType *T = getValueType(ArgId);
2868+
if (auto PackFmt = getPackedVectorFormat()) {
2869+
switch (*PackFmt) {
2870+
case PackedVectorFormatPackedVectorFormat4x8BitKHR:
2871+
assert(!T->isTypeVector() && T->isTypeInt() && T->getBitWidth() == 32 &&
2872+
"Type does not match pack format");
2873+
return CapabilityDotProductInput4x8BitPackedKHR;
2874+
case PackedVectorFormatMax:
2875+
break;
2876+
}
2877+
llvm_unreachable("Unknown Packed Vector Format");
2878+
}
2879+
2880+
if (T->isTypeVector()) {
2881+
const SPIRVType *EltT = T->getVectorComponentType();
2882+
if (T->getVectorComponentCount() == 4 && EltT->isTypeInt() &&
2883+
EltT->getBitWidth() == 8)
2884+
return CapabilityDotProductInput4x8BitKHR;
2885+
if (EltT->isTypeInt())
2886+
return CapabilityDotProductInputAllKHR;
2887+
}
2888+
2889+
llvm_unreachable("No mapping for argument type to capability.");
2890+
}
2891+
};
2892+
2893+
#define _SPIRV_OP(x, ...) \
2894+
typedef SPIRVInstTemplate<SPIRVDotKHRBase, Op##x, __VA_ARGS__> SPIRV##x;
2895+
_SPIRV_OP(SDotKHR, true, 5, true, 2)
2896+
_SPIRV_OP(UDotKHR, true, 5, true, 2)
2897+
_SPIRV_OP(SUDotKHR, true, 5, true, 2)
2898+
_SPIRV_OP(SDotAccSatKHR, true, 6, true, 3)
2899+
_SPIRV_OP(UDotAccSatKHR, true, 6, true, 3)
2900+
_SPIRV_OP(SUDotAccSatKHR, true, 6, true, 3)
2901+
#undef _SPIRV_OP
2902+
28212903
class SPIRVSubgroupShuffleINTELInstBase : public SPIRVInstTemplateBase {
28222904
protected:
28232905
SPIRVCapVec getRequiredCapability() const override {

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVOpCodeEnum.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,12 @@ _SPIRV_OP(GroupNonUniformBitwiseXor, 361)
327327
_SPIRV_OP(GroupNonUniformLogicalAnd, 362)
328328
_SPIRV_OP(GroupNonUniformLogicalOr, 363)
329329
_SPIRV_OP(GroupNonUniformLogicalXor, 364)
330+
_SPIRV_OP(SDotKHR, 4450)
331+
_SPIRV_OP(UDotKHR, 4451)
332+
_SPIRV_OP(SUDotKHR, 4452)
333+
_SPIRV_OP(SDotAccSatKHR, 4453)
334+
_SPIRV_OP(UDotAccSatKHR, 4454)
335+
_SPIRV_OP(SUDotAccSatKHR, 4455)
330336
_SPIRV_OP(SubgroupShuffleINTEL, 5571)
331337
_SPIRV_OP(SubgroupShuffleDownINTEL, 5572)
332338
_SPIRV_OP(SubgroupShuffleUpINTEL, 5573)

0 commit comments

Comments
 (0)