Skip to content

Commit b7c5218

Browse files
authored
Add ComponentTypeInterpretation for joint matrix type (#1835)
It specifies how to interpret 'Component Type' when components of a joint matrix are storages for values of different types, for example float for TF32, unsigned short for bfloat16. At this point only tf32 type interpretation is added during SPIR-V generation. Adding it to bf16 is a breaking change and requires adaptation across drivers. Spec update: intel/llvm#8175 Signed-off-by: Sidorov, Dmitry [email protected]
1 parent 9858104 commit b7c5218

File tree

7 files changed

+292
-24
lines changed

7 files changed

+292
-24
lines changed

lib/SPIRV/SPIRVReader.cpp

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -445,9 +445,28 @@ Type *SPIRVToLLVM::transType(SPIRVType *T, bool UseTPT) {
445445
(unsigned)S};
446446
if (auto *Use = MT->getUse())
447447
Params.push_back(static_cast<SPIRVConstant *>(Use)->getZExtIntValue());
448+
auto *CTI = MT->getComponentTypeInterpretation();
449+
if (!CTI)
450+
return mapType(T, getSPIRVType(internal::OpTypeJointMatrixINTEL,
451+
transTypeToOCLTypeName(MT->getCompType()),
452+
Params, !UseTPT));
453+
std::string ComponentTypeName;
454+
switch (static_cast<SPIRVConstant *>(CTI)->getZExtIntValue()) {
455+
case internal::InternalJointMatrixCTI::TF32:
456+
ComponentTypeName = "tf32";
457+
break;
458+
case internal::InternalJointMatrixCTI::Bfloat16:
459+
ComponentTypeName = "bfloat16";
460+
break;
461+
case internal::InternalJointMatrixCTI::PackedInt2:
462+
case internal::InternalJointMatrixCTI::PackedInt4:
463+
// Do nothing just now
464+
break;
465+
default:
466+
llvm_unreachable("Unexpected joint matrix component type");
467+
}
448468
return mapType(T, getSPIRVType(internal::OpTypeJointMatrixINTEL,
449-
transTypeToOCLTypeName(MT->getCompType()),
450-
Params, !UseTPT));
469+
ComponentTypeName, Params, !UseTPT));
451470
}
452471
case OpTypeForwardPointer: {
453472
SPIRVTypeForwardPointer *FP =

lib/SPIRV/SPIRVWriter.cpp

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -610,7 +610,7 @@ SPIRVType *LLVMToSPIRVBase::transPointerType(SPIRVType *ET, unsigned AddrSpc) {
610610

611611
// Representation in LLVM IR before the translator is a pointer to an opaque
612612
// structure:
613-
// %spirv.JointMatrixINTEL._%element_type%_%rows%_%cols%_%scope%_%use%
613+
// %spirv.JointMatrixINTEL._%element_type%_%rows%_%cols%_%layout%_%scope%_%use%
614614
// Here we check the structure name yet again. Another option would be to
615615
// check SPIR-V friendly function calls (by their name) and obtain return
616616
// or their parameter types, assuming, that the appropriate types are Matrix
@@ -621,6 +621,18 @@ SPIRVType *LLVMToSPIRVBase::transPointerType(SPIRVType *ET, unsigned AddrSpc) {
621621
// simply not true.
622622
SPIRVType *LLVMToSPIRVBase::transSPIRVJointMatrixINTELType(
623623
SmallVector<std::string, 8> Postfixes) {
624+
auto ParseInteger = [this](StringRef Postfix) -> ConstantInt * {
625+
unsigned long long N = 0;
626+
if (consumeUnsignedInteger(Postfix, 10, N))
627+
BM->getErrorLog().checkError(
628+
false, SPIRVEC_InvalidLlvmModule,
629+
"TypeJointMatrixINTEL expects integer parameters");
630+
return getUInt32(M, N);
631+
};
632+
std::vector<SPIRVValue *> Args;
633+
for (size_t I = 1; I != Postfixes.size(); ++I)
634+
Args.emplace_back(transConstant(ParseInteger(Postfixes[I])));
635+
624636
Type *ElemTy = nullptr;
625637
StringRef Ty{Postfixes[0]};
626638
auto NumBits = llvm::StringSwitch<unsigned>(Ty)
@@ -629,32 +641,30 @@ SPIRVType *LLVMToSPIRVBase::transSPIRVJointMatrixINTELType(
629641
.Case("int", 32)
630642
.Case("long", 64)
631643
.Default(0);
632-
if (NumBits)
644+
if (NumBits) {
633645
ElemTy = IntegerType::get(M->getContext(), NumBits);
634-
else if (Ty == "half")
646+
} else if (Ty == "half") {
635647
ElemTy = Type::getHalfTy(M->getContext());
636-
else if (Ty == "float")
648+
} else if (Ty == "float") {
637649
ElemTy = Type::getFloatTy(M->getContext());
638-
else if (Ty == "double")
650+
} else if (Ty == "double") {
639651
ElemTy = Type::getDoubleTy(M->getContext());
640-
else if (Ty == "bfloat16")
652+
} else if (Ty == "bfloat16") {
641653
ElemTy = Type::getInt16Ty(M->getContext());
642-
else
654+
// TODO: add BF16 CTI when we do breaking change
655+
// auto *CTI = transConstant(getUInt32(M, static_cast<uint64_t>(
656+
// internal::InternalJointMatrixCTI::Bfloat16)));
657+
// Args.push_back(CTI);
658+
// BM->addCapability(internal::CapabilityJointMatrixBF16ComponentTypeINTEL);
659+
} else if (Ty == "tf32") {
660+
ElemTy = Type::getFloatTy(M->getContext());
661+
auto *CTI = transConstant(getUInt32(
662+
M, static_cast<uint64_t>(internal::InternalJointMatrixCTI::TF32)));
663+
Args.push_back(CTI);
664+
BM->addCapability(internal::CapabilityJointMatrixTF32ComponentTypeINTEL);
665+
} else {
643666
llvm_unreachable("Unexpected type for matrix!");
644-
645-
auto ParseInteger = [this](StringRef Postfix) -> ConstantInt * {
646-
unsigned long long N = 0;
647-
if (consumeUnsignedInteger(Postfix, 10, N)) {
648-
BM->getErrorLog().checkError(
649-
false, SPIRVEC_InvalidLlvmModule,
650-
"TypeJointMatrixINTEL expects integer parameters");
651-
return 0;
652-
}
653-
return getUInt32(M, N);
654-
};
655-
std::vector<SPIRVValue *> Args;
656-
for (size_t I = 1; I != Postfixes.size(); ++I)
657-
Args.emplace_back(transConstant(ParseInteger(Postfixes[I])));
667+
}
658668
return BM->addJointMatrixINTELType(transType(ElemTy), Args);
659669
}
660670

lib/SPIRV/libSPIRV/SPIRVEnum.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,14 @@ template <> inline void SPIRVMap<SPIRVCapabilityKind, SPIRVCapVec>::init() {
205205
{CapabilitySubgroupAvcMotionEstimationIntraINTEL});
206206
ADD_VEC_INIT(internal::CapabilityJointMatrixWIInstructionsINTEL,
207207
{internal::CapabilityJointMatrixINTEL});
208+
ADD_VEC_INIT(internal::CapabilityJointMatrixTF32ComponentTypeINTEL,
209+
{internal::CapabilityJointMatrixINTEL});
210+
ADD_VEC_INIT(internal::CapabilityJointMatrixBF16ComponentTypeINTEL,
211+
{internal::CapabilityJointMatrixINTEL});
212+
ADD_VEC_INIT(internal::CapabilityJointMatrixPackedInt2ComponentTypeINTEL,
213+
{internal::CapabilityJointMatrixINTEL});
214+
ADD_VEC_INIT(internal::CapabilityJointMatrixPackedInt4ComponentTypeINTEL,
215+
{internal::CapabilityJointMatrixINTEL});
208216
}
209217

210218
template <> inline void SPIRVMap<SPIRVExecutionModelKind, SPIRVCapVec>::init() {

lib/SPIRV/libSPIRV/SPIRVNameMapEnum.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,14 @@ template <> inline void SPIRVMap<Capability, std::string>::init() {
633633
"TensorFloat32ConversionINTEL");
634634
add(internal::CapabilityJointMatrixWIInstructionsINTEL,
635635
"JointMatrixWIInstructionsINTEL");
636+
add(internal::CapabilityJointMatrixTF32ComponentTypeINTEL,
637+
"JointMatrixTF32ComponentTypeINTEL");
638+
add(internal::CapabilityJointMatrixBF16ComponentTypeINTEL,
639+
"JointMatrixBF16ComponentTypeINTEL");
640+
add(internal::CapabilityJointMatrixPackedInt2ComponentTypeINTEL,
641+
"JointMatrixPackedInt2ComponentTypeINTEL");
642+
add(internal::CapabilityJointMatrixPackedInt4ComponentTypeINTEL,
643+
"JointMatrixPackedInt4ComponentTypeINTEL");
636644
}
637645
SPIRV_DEF_NAMEMAP(Capability, SPIRVCapabilityNameMap)
638646

lib/SPIRV/libSPIRV/SPIRVType.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1088,6 +1088,9 @@ class SPIRVTypeJointMatrixINTEL : public SPIRVType {
10881088
SPIRVValue *getLayout() const { return Args[2]; }
10891089
SPIRVValue *getScope() const { return Args[3]; }
10901090
SPIRVValue *getUse() const { return Args.size() > 4 ? Args[4] : nullptr; }
1091+
SPIRVValue *getComponentTypeInterpretation() const {
1092+
return Args.size() > 5 ? Args[5] : nullptr;
1093+
}
10911094
};
10921095

10931096
} // namespace SPIRV

lib/SPIRV/libSPIRV/spirv_internal.hpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,11 @@ enum InternalCapability {
9797
ICapabilityComplexFloatMulDivINTEL = 6414,
9898
ICapabilityTensorFloat32ConversionINTEL = 6425,
9999
ICapabilityMaskedGatherScatterINTEL = 6427,
100-
ICapabilityJointMatrixWIInstructionsINTEL = 6435
100+
ICapabilityJointMatrixWIInstructionsINTEL = 6435,
101+
ICapabilityJointMatrixTF32ComponentTypeINTEL = 6436,
102+
ICapabilityJointMatrixBF16ComponentTypeINTEL = 6437,
103+
ICapabilityJointMatrixPackedInt2ComponentTypeINTEL = 6438,
104+
ICapabilityJointMatrixPackedInt4ComponentTypeINTEL = 6439
101105
};
102106

103107
enum InternalFunctionControlMask { IFunctionControlOptNoneINTELMask = 0x10000 };
@@ -118,6 +122,14 @@ enum InternalJointMatrixLayout {
118122

119123
enum InternalJointMatrixUse { MatrixA = 0, MatrixB = 1, Accumulator = 2 };
120124

125+
enum InternalJointMatrixCTI {
126+
None = 0,
127+
TF32 = 1,
128+
Bfloat16 = 2,
129+
PackedInt2 = 3,
130+
PackedInt4 = 4
131+
};
132+
121133
enum InternalBuiltIn {
122134
IBuiltInSubDeviceIDINTEL = 6135,
123135
IBuiltInGlobalHWThreadIDINTEL = 6136,
@@ -126,6 +138,10 @@ enum InternalBuiltIn {
126138
#define _SPIRV_OP(x, y) constexpr x x##y = static_cast<x>(I##x##y);
127139
_SPIRV_OP(Capability, JointMatrixINTEL)
128140
_SPIRV_OP(Capability, JointMatrixWIInstructionsINTEL)
141+
_SPIRV_OP(Capability, JointMatrixTF32ComponentTypeINTEL)
142+
_SPIRV_OP(Capability, JointMatrixBF16ComponentTypeINTEL)
143+
_SPIRV_OP(Capability, JointMatrixPackedInt2ComponentTypeINTEL)
144+
_SPIRV_OP(Capability, JointMatrixPackedInt4ComponentTypeINTEL)
129145
_SPIRV_OP(Op, TypeJointMatrixINTEL)
130146
_SPIRV_OP(Op, JointMatrixLoadINTEL)
131147
_SPIRV_OP(Op, JointMatrixStoreINTEL)

0 commit comments

Comments
 (0)