Skip to content

Commit f6fb55c

Browse files
[Backport to 14] Support SPV_INTEL_maximum_registers extension (#2398)
* Add infrastructure for translating ExecutionModeId (#2297) * [Backport to 14] Support SPV_INTEL_maximum_registers extension (#2344) Co-authored-by: Viktoria Maximova <[email protected]>
1 parent c6506ba commit f6fb55c

File tree

14 files changed

+294
-44
lines changed

14 files changed

+294
-44
lines changed

include/LLVMSPIRVExtensions.inc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,4 @@ EXT(SPV_INTEL_tensor_float32_conversion) // TODO: to remove old extension
6464
EXT(SPV_INTEL_tensor_float32_rounding)
6565
EXT(SPV_EXT_relaxed_printf_string_address_space)
6666
EXT(SPV_INTEL_cache_controls)
67+
EXT(SPV_INTEL_maximum_registers)

lib/SPIRV/SPIRVReader.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4184,6 +4184,50 @@ bool SPIRVToLLVM::transMetadata() {
41844184
F->setMetadata(kSPIR2MD::IntelFPGAIPInterface,
41854185
MDNode::get(*Context, InterfaceMDVec));
41864186
}
4187+
if (auto *EM = BF->getExecutionMode(
4188+
internal::ExecutionModeMaximumRegistersINTEL)) {
4189+
NamedMDNode *ExecModeMD =
4190+
M->getOrInsertNamedMetadata(kSPIRVMD::ExecutionMode);
4191+
4192+
SmallVector<Metadata *, 4> ValueVec;
4193+
ValueVec.push_back(ConstantAsMetadata::get(F));
4194+
ValueVec.push_back(
4195+
ConstantAsMetadata::get(getUInt32(M, EM->getExecutionMode())));
4196+
ValueVec.push_back(
4197+
ConstantAsMetadata::get(getUInt32(M, EM->getLiterals()[0])));
4198+
ExecModeMD->addOperand(MDNode::get(*Context, ValueVec));
4199+
}
4200+
if (auto *EM = BF->getExecutionMode(
4201+
internal::ExecutionModeMaximumRegistersIdINTEL)) {
4202+
NamedMDNode *ExecModeMD =
4203+
M->getOrInsertNamedMetadata(kSPIRVMD::ExecutionMode);
4204+
4205+
SmallVector<Metadata *, 4> ValueVec;
4206+
ValueVec.push_back(ConstantAsMetadata::get(F));
4207+
ValueVec.push_back(
4208+
ConstantAsMetadata::get(getUInt32(M, EM->getExecutionMode())));
4209+
4210+
auto *ExecOp = BF->getModule()->getValue(EM->getLiterals()[0]);
4211+
ValueVec.push_back(
4212+
MDNode::get(*Context, ConstantAsMetadata::get(cast<ConstantInt>(
4213+
transValue(ExecOp, nullptr, nullptr)))));
4214+
ExecModeMD->addOperand(MDNode::get(*Context, ValueVec));
4215+
}
4216+
if (auto *EM = BF->getExecutionMode(
4217+
internal::ExecutionModeNamedMaximumRegistersINTEL)) {
4218+
NamedMDNode *ExecModeMD =
4219+
M->getOrInsertNamedMetadata(kSPIRVMD::ExecutionMode);
4220+
4221+
SmallVector<Metadata *, 4> ValueVec;
4222+
ValueVec.push_back(ConstantAsMetadata::get(F));
4223+
ValueVec.push_back(
4224+
ConstantAsMetadata::get(getUInt32(M, EM->getExecutionMode())));
4225+
4226+
assert(EM->getLiterals()[0] == 0 &&
4227+
"Invalid named maximum number of registers");
4228+
ValueVec.push_back(MDString::get(*Context, "AutoINTEL"));
4229+
ExecModeMD->addOperand(MDNode::get(*Context, ValueVec));
4230+
}
41874231
}
41884232
NamedMDNode *MemoryModelMD =
41894233
M->getOrInsertNamedMetadata(kSPIRVMD::MemoryModel);

lib/SPIRV/SPIRVWriter.cpp

Lines changed: 53 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -897,6 +897,9 @@ SPIRVFunction *LLVMToSPIRVBase::transFunctionDecl(Function *F) {
897897

898898
transFPGAFunctionMetadata(BF, F);
899899

900+
if (BM->isAllowedToUseExtension(ExtensionID::SPV_INTEL_maximum_registers))
901+
transFunctionMetadataAsExecutionMode(BF, F);
902+
900903
transAuxDataInst(BF, F);
901904

902905
SPIRVDBG(dbgs() << "[transFunction] " << *F << " => ";
@@ -1029,6 +1032,38 @@ void LLVMToSPIRVBase::transFPGAFunctionMetadata(SPIRVFunction *BF,
10291032
transMetadataDecorations(FDecoMD, BF);
10301033
}
10311034

1035+
void LLVMToSPIRVBase::transFunctionMetadataAsExecutionMode(SPIRVFunction *BF,
1036+
Function *F) {
1037+
SmallVector<MDNode *, 1> RegisterAllocModeMDs;
1038+
F->getMetadata("RegisterAllocMode", RegisterAllocModeMDs);
1039+
1040+
for (unsigned I = 0; I < RegisterAllocModeMDs.size(); I++) {
1041+
auto *RegisterAllocMode = RegisterAllocModeMDs[I]->getOperand(0).get();
1042+
if (isa<MDString>(RegisterAllocMode)) {
1043+
const StringRef Str = getMDOperandAsString(RegisterAllocModeMDs[I], 0);
1044+
const internal::InternalNamedMaximumNumberOfRegisters NamedValue =
1045+
SPIRVNamedMaximumNumberOfRegistersNameMap::rmap(Str.str());
1046+
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
1047+
OpExecutionMode, BF,
1048+
internal::ExecutionModeNamedMaximumRegistersINTEL, NamedValue)));
1049+
} else if (isa<MDNode>(RegisterAllocMode)) {
1050+
auto *RegisterAllocNodeMDOp =
1051+
getMDOperandAsMDNode(RegisterAllocModeMDs[I], 0);
1052+
const int Num = getMDOperandAsInt(RegisterAllocNodeMDOp, 0);
1053+
auto *Const =
1054+
BM->addConstant(transType(Type::getInt32Ty(F->getContext())), Num);
1055+
BF->addExecutionMode(BM->add(new SPIRVExecutionModeId(
1056+
BF, internal::ExecutionModeMaximumRegistersIdINTEL, Const->getId())));
1057+
} else {
1058+
const int64_t RegisterAllocVal =
1059+
mdconst::dyn_extract<ConstantInt>(RegisterAllocMode)->getZExtValue();
1060+
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
1061+
OpExecutionMode, BF, internal::ExecutionModeMaximumRegistersINTEL,
1062+
RegisterAllocVal)));
1063+
}
1064+
}
1065+
}
1066+
10321067
void LLVMToSPIRVBase::transAuxDataInst(SPIRVFunction *BF, Function *F) {
10331068
auto *BM = BF->getModule();
10341069
if (!BM->preserveAuxData())
@@ -4766,19 +4801,20 @@ bool LLVMToSPIRVBase::transExecutionMode() {
47664801
auto AddSingleArgExecutionMode = [&](ExecutionMode EMode) {
47674802
uint32_t Arg = 0;
47684803
N.get(Arg);
4769-
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(BF, EMode, Arg)));
4804+
BF->addExecutionMode(
4805+
BM->add(new SPIRVExecutionMode(OpExecutionMode, BF, EMode, Arg)));
47704806
};
47714807

47724808
switch (EMode) {
47734809
case spv::ExecutionModeContractionOff:
4774-
BF->addExecutionMode(BM->add(
4775-
new SPIRVExecutionMode(BF, static_cast<ExecutionMode>(EMode))));
4810+
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
4811+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode))));
47764812
break;
47774813
case spv::ExecutionModeInitializer:
47784814
case spv::ExecutionModeFinalizer:
47794815
if (BM->isAllowedToUseVersion(VersionNumber::SPIRV_1_1)) {
4780-
BF->addExecutionMode(BM->add(
4781-
new SPIRVExecutionMode(BF, static_cast<ExecutionMode>(EMode))));
4816+
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
4817+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode))));
47824818
} else {
47834819
getErrorLog().checkError(false, SPIRVEC_Requires1_1,
47844820
"Initializer/Finalizer Execution Mode");
@@ -4790,15 +4826,16 @@ bool LLVMToSPIRVBase::transExecutionMode() {
47904826
unsigned X, Y, Z;
47914827
N.get(X).get(Y).get(Z);
47924828
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
4793-
BF, static_cast<ExecutionMode>(EMode), X, Y, Z)));
4829+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode), X, Y, Z)));
47944830
} break;
47954831
case spv::ExecutionModeMaxWorkgroupSizeINTEL: {
47964832
if (BM->isAllowedToUseExtension(
47974833
ExtensionID::SPV_INTEL_kernel_attributes)) {
47984834
unsigned X, Y, Z;
47994835
N.get(X).get(Y).get(Z);
48004836
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
4801-
BF, static_cast<ExecutionMode>(EMode), X, Y, Z)));
4837+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode), X, Y,
4838+
Z)));
48024839
BM->addExtension(ExtensionID::SPV_INTEL_kernel_attributes);
48034840
BM->addCapability(CapabilityKernelAttributesINTEL);
48044841
}
@@ -4807,8 +4844,8 @@ bool LLVMToSPIRVBase::transExecutionMode() {
48074844
if (!BM->isAllowedToUseExtension(
48084845
ExtensionID::SPV_INTEL_kernel_attributes))
48094846
break;
4810-
BF->addExecutionMode(BM->add(
4811-
new SPIRVExecutionMode(BF, static_cast<ExecutionMode>(EMode))));
4847+
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
4848+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode))));
48124849
BM->addExtension(ExtensionID::SPV_INTEL_kernel_attributes);
48134850
BM->addCapability(CapabilityKernelAttributesINTEL);
48144851
} break;
@@ -4851,8 +4888,9 @@ bool LLVMToSPIRVBase::transExecutionMode() {
48514888
break;
48524889
unsigned NBarrierCnt = 0;
48534890
N.get(NBarrierCnt);
4854-
BF->addExecutionMode(new SPIRVExecutionMode(
4855-
BF, static_cast<ExecutionMode>(EMode), NBarrierCnt));
4891+
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
4892+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode),
4893+
NBarrierCnt)));
48564894
BM->addExtension(ExtensionID::SPV_INTEL_vector_compute);
48574895
BM->addCapability(CapabilityVectorComputeINTEL);
48584896
} break;
@@ -4883,8 +4921,8 @@ bool LLVMToSPIRVBase::transExecutionMode() {
48834921
} break;
48844922
case spv::internal::ExecutionModeFastCompositeKernelINTEL: {
48854923
if (BM->isAllowedToUseExtension(ExtensionID::SPV_INTEL_fast_composite))
4886-
BF->addExecutionMode(BM->add(
4887-
new SPIRVExecutionMode(BF, static_cast<ExecutionMode>(EMode))));
4924+
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
4925+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode))));
48884926
} break;
48894927
default:
48904928
llvm_unreachable("invalid execution mode");
@@ -4929,8 +4967,8 @@ void LLVMToSPIRVBase::transFPContract() {
49294967
}
49304968

49314969
if (DisableContraction) {
4932-
BF->addExecutionMode(BF->getModule()->add(
4933-
new SPIRVExecutionMode(BF, spv::ExecutionModeContractionOff)));
4970+
BF->addExecutionMode(BF->getModule()->add(new SPIRVExecutionMode(
4971+
OpExecutionMode, BF, spv::ExecutionModeContractionOff)));
49344972
}
49354973
}
49364974
}

lib/SPIRV/SPIRVWriter.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ class LLVMToSPIRVBase {
122122
void transVectorComputeMetadata(Function *F);
123123
void transFPGAFunctionMetadata(SPIRVFunction *BF, Function *F);
124124
void transAuxDataInst(SPIRVFunction *BF, Function *F);
125-
125+
void transFunctionMetadataAsExecutionMode(SPIRVFunction *BF, Function *F);
126126
bool transGlobalVariables();
127127

128128
Op transBoolOpCode(SPIRVValue *Opn, Op OC);

lib/SPIRV/libSPIRV/SPIRVEntry.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -580,7 +580,7 @@ SPIRVEntryPoint::SPIRVEntryPoint(SPIRVModule *TheModule,
580580
SPIRVExecutionModelKind TheExecModel,
581581
SPIRVId TheId, const std::string &TheName,
582582
std::vector<SPIRVId> Variables)
583-
: SPIRVAnnotation(TheModule->get<SPIRVFunction>(TheId),
583+
: SPIRVAnnotation(OpEntryPoint, TheModule->get<SPIRVFunction>(TheId),
584584
getSizeInWords(TheName) + Variables.size() + 3),
585585
ExecModel(TheExecModel), Name(TheName), Variables(Variables) {}
586586

@@ -628,6 +628,9 @@ void SPIRVExecutionMode::decode(std::istream &I) {
628628
case ExecutionModeSchedulerTargetFmaxMhzINTEL:
629629
case ExecutionModeRegisterMapInterfaceINTEL:
630630
case internal::ExecutionModeStreamingInterfaceINTEL:
631+
case internal::ExecutionModeMaximumRegistersINTEL:
632+
case internal::ExecutionModeMaximumRegistersIdINTEL:
633+
case internal::ExecutionModeNamedMaximumRegistersINTEL:
631634
WordLiterals.resize(1);
632635
break;
633636
default:
@@ -649,7 +652,8 @@ SPIRVForward *SPIRVAnnotationGeneric::getOrCreateTarget() const {
649652
}
650653

651654
SPIRVName::SPIRVName(const SPIRVEntry *TheTarget, const std::string &TheStr)
652-
: SPIRVAnnotation(TheTarget, getSizeInWords(TheStr) + 2), Str(TheStr) {}
655+
: SPIRVAnnotation(OpName, TheTarget, getSizeInWords(TheStr) + 2),
656+
Str(TheStr) {}
653657

654658
void SPIRVName::encode(spv_ostream &O) const { getEncoder(O) << Target << Str; }
655659

0 commit comments

Comments
 (0)