Skip to content

Commit 445e1c9

Browse files
vmaksimosys-ce-bb
authored andcommitted
Add infrastructure for translating ExecutionModeId (#2297)
This functionality was added in SPIR-V 1.2 and allows using an <id> to set the execution modes SubgroupsPerWorkgroupId, LocalSizeId, and LocalSizeHintI, and others. Original commit: KhronosGroup/SPIRV-LLVM-Translator@10b0aab
1 parent 28a8b6b commit 445e1c9

File tree

4 files changed

+68
-38
lines changed

4 files changed

+68
-38
lines changed

llvm-spirv/lib/SPIRV/SPIRVWriter.cpp

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5721,19 +5721,20 @@ bool LLVMToSPIRVBase::transExecutionMode() {
57215721
auto AddSingleArgExecutionMode = [&](ExecutionMode EMode) {
57225722
uint32_t Arg = ~0u;
57235723
N.get(Arg);
5724-
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(BF, EMode, Arg)));
5724+
BF->addExecutionMode(
5725+
BM->add(new SPIRVExecutionMode(OpExecutionMode, BF, EMode, Arg)));
57255726
};
57265727

57275728
switch (EMode) {
57285729
case spv::ExecutionModeContractionOff:
5729-
BF->addExecutionMode(BM->add(
5730-
new SPIRVExecutionMode(BF, static_cast<ExecutionMode>(EMode))));
5730+
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
5731+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode))));
57315732
break;
57325733
case spv::ExecutionModeInitializer:
57335734
case spv::ExecutionModeFinalizer:
57345735
if (BM->isAllowedToUseVersion(VersionNumber::SPIRV_1_1)) {
5735-
BF->addExecutionMode(BM->add(
5736-
new SPIRVExecutionMode(BF, static_cast<ExecutionMode>(EMode))));
5736+
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
5737+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode))));
57375738
} else {
57385739
getErrorLog().checkError(false, SPIRVEC_Requires1_1,
57395740
"Initializer/Finalizer Execution Mode");
@@ -5745,15 +5746,16 @@ bool LLVMToSPIRVBase::transExecutionMode() {
57455746
unsigned X = 0, Y = 0, Z = 0;
57465747
N.get(X).get(Y).get(Z);
57475748
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
5748-
BF, static_cast<ExecutionMode>(EMode), X, Y, Z)));
5749+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode), X, Y, Z)));
57495750
} break;
57505751
case spv::ExecutionModeMaxWorkgroupSizeINTEL: {
57515752
if (BM->isAllowedToUseExtension(
57525753
ExtensionID::SPV_INTEL_kernel_attributes)) {
57535754
unsigned X = 0, Y = 0, Z = 0;
57545755
N.get(X).get(Y).get(Z);
57555756
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
5756-
BF, static_cast<ExecutionMode>(EMode), X, Y, Z)));
5757+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode), X, Y,
5758+
Z)));
57575759
BM->addExtension(ExtensionID::SPV_INTEL_kernel_attributes);
57585760
BM->addCapability(CapabilityKernelAttributesINTEL);
57595761
}
@@ -5762,8 +5764,8 @@ bool LLVMToSPIRVBase::transExecutionMode() {
57625764
if (!BM->isAllowedToUseExtension(
57635765
ExtensionID::SPV_INTEL_kernel_attributes))
57645766
break;
5765-
BF->addExecutionMode(BM->add(
5766-
new SPIRVExecutionMode(BF, static_cast<ExecutionMode>(EMode))));
5767+
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
5768+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode))));
57675769
BM->addExtension(ExtensionID::SPV_INTEL_kernel_attributes);
57685770
BM->addCapability(CapabilityKernelAttributesINTEL);
57695771
} break;
@@ -5807,7 +5809,8 @@ bool LLVMToSPIRVBase::transExecutionMode() {
58075809
unsigned NBarrierCnt = 0;
58085810
N.get(NBarrierCnt);
58095811
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
5810-
BF, static_cast<ExecutionMode>(EMode), NBarrierCnt)));
5812+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode),
5813+
NBarrierCnt)));
58115814
BM->addExtension(ExtensionID::SPV_INTEL_vector_compute);
58125815
BM->addCapability(CapabilityVectorComputeINTEL);
58135816
} break;
@@ -5837,8 +5840,8 @@ bool LLVMToSPIRVBase::transExecutionMode() {
58375840
} break;
58385841
case spv::internal::ExecutionModeFastCompositeKernelINTEL: {
58395842
if (BM->isAllowedToUseExtension(ExtensionID::SPV_INTEL_fast_composite))
5840-
BF->addExecutionMode(BM->add(
5841-
new SPIRVExecutionMode(BF, static_cast<ExecutionMode>(EMode))));
5843+
BF->addExecutionMode(BM->add(new SPIRVExecutionMode(
5844+
OpExecutionMode, BF, static_cast<ExecutionMode>(EMode))));
58425845
} break;
58435846
default:
58445847
llvm_unreachable("invalid execution mode");
@@ -5883,8 +5886,8 @@ void LLVMToSPIRVBase::transFPContract() {
58835886
}
58845887

58855888
if (DisableContraction) {
5886-
BF->addExecutionMode(BF->getModule()->add(
5887-
new SPIRVExecutionMode(BF, spv::ExecutionModeContractionOff)));
5889+
BF->addExecutionMode(BF->getModule()->add(new SPIRVExecutionMode(
5890+
OpExecutionMode, BF, spv::ExecutionModeContractionOff)));
58885891
}
58895892
}
58905893
}

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVEntry.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -612,7 +612,7 @@ SPIRVEntryPoint::SPIRVEntryPoint(SPIRVModule *TheModule,
612612
SPIRVExecutionModelKind TheExecModel,
613613
SPIRVId TheId, const std::string &TheName,
614614
std::vector<SPIRVId> Variables)
615-
: SPIRVAnnotation(TheModule->get<SPIRVFunction>(TheId),
615+
: SPIRVAnnotation(OpEntryPoint, TheModule->get<SPIRVFunction>(TheId),
616616
getSizeInWords(TheName) + Variables.size() + 3),
617617
ExecModel(TheExecModel), Name(TheName), Variables(Variables) {}
618618

@@ -681,7 +681,8 @@ SPIRVForward *SPIRVAnnotationGeneric::getOrCreateTarget() const {
681681
}
682682

683683
SPIRVName::SPIRVName(const SPIRVEntry *TheTarget, const std::string &TheStr)
684-
: SPIRVAnnotation(TheTarget, getSizeInWords(TheStr) + 2), Str(TheStr) {}
684+
: SPIRVAnnotation(OpName, TheTarget, getSizeInWords(TheStr) + 2),
685+
Str(TheStr) {}
685686

686687
void SPIRVName::encode(spv_ostream &O) const { getEncoder(O) << Target << Str; }
687688

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVEntry.h

Lines changed: 47 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -521,23 +521,24 @@ class SPIRVAnnotationGeneric : public SPIRVEntryNoIdGeneric {
521521
SPIRVId Target;
522522
};
523523

524-
template <Op OC> class SPIRVAnnotation : public SPIRVAnnotationGeneric {
524+
class SPIRVAnnotation : public SPIRVAnnotationGeneric {
525525
public:
526526
// Complete constructor
527-
SPIRVAnnotation(const SPIRVEntry *TheTarget, unsigned TheWordCount)
527+
SPIRVAnnotation(Op OC, const SPIRVEntry *TheTarget, unsigned TheWordCount)
528528
: SPIRVAnnotationGeneric(TheTarget->getModule(), TheWordCount, OC,
529529
TheTarget->getId()) {}
530-
// Incomplete constructor
531-
SPIRVAnnotation() : SPIRVAnnotationGeneric(OC) {}
530+
// Incomplete constructors
531+
SPIRVAnnotation(Op OC) : SPIRVAnnotationGeneric(OC) {}
532+
SPIRVAnnotation() : SPIRVAnnotationGeneric(OpNop) {}
532533
};
533534

534-
class SPIRVEntryPoint : public SPIRVAnnotation<OpEntryPoint> {
535+
class SPIRVEntryPoint : public SPIRVAnnotation {
535536
public:
536537
static const SPIRVWord FixedWC = 4;
537538
SPIRVEntryPoint(SPIRVModule *TheModule, SPIRVExecutionModelKind,
538539
SPIRVId TheId, const std::string &TheName,
539540
std::vector<SPIRVId> Variables);
540-
SPIRVEntryPoint() {}
541+
SPIRVEntryPoint() : SPIRVAnnotation(OpEntryPoint) {}
541542

542543
_SPIRV_DCL_ENCDEC
543544
protected:
@@ -548,12 +549,12 @@ class SPIRVEntryPoint : public SPIRVAnnotation<OpEntryPoint> {
548549
std::vector<SPIRVId> Variables;
549550
};
550551

551-
class SPIRVName : public SPIRVAnnotation<OpName> {
552+
class SPIRVName : public SPIRVAnnotation {
552553
public:
553554
// Complete constructor
554555
SPIRVName(const SPIRVEntry *TheTarget, const std::string &TheStr);
555556
// Incomplete constructor
556-
SPIRVName() {}
557+
SPIRVName() : SPIRVAnnotation(OpName) {}
557558

558559
protected:
559560
_SPIRV_DCL_ENCDEC
@@ -562,18 +563,18 @@ class SPIRVName : public SPIRVAnnotation<OpName> {
562563
std::string Str;
563564
};
564565

565-
class SPIRVMemberName : public SPIRVAnnotation<OpName> {
566+
class SPIRVMemberName : public SPIRVAnnotation {
566567
public:
567568
static const SPIRVWord FixedWC = 3;
568569
// Complete constructor
569570
SPIRVMemberName(const SPIRVEntry *TheTarget, SPIRVWord TheMemberNumber,
570571
const std::string &TheStr)
571-
: SPIRVAnnotation(TheTarget, FixedWC + getSizeInWords(TheStr)),
572+
: SPIRVAnnotation(OpName, TheTarget, FixedWC + getSizeInWords(TheStr)),
572573
MemberNumber(TheMemberNumber), Str(TheStr) {
573574
validate();
574575
}
575576
// Incomplete constructor
576-
SPIRVMemberName() : MemberNumber(SPIRVWORD_MAX) {}
577+
SPIRVMemberName() : SPIRVAnnotation(OpName), MemberNumber(SPIRVWORD_MAX) {}
577578

578579
protected:
579580
_SPIRV_DCL_ENCDEC
@@ -649,31 +650,33 @@ class SPIRVLine : public SPIRVEntry {
649650
SPIRVWord Column;
650651
};
651652

652-
class SPIRVExecutionMode : public SPIRVAnnotation<OpExecutionMode> {
653+
class SPIRVExecutionMode : public SPIRVAnnotation {
653654
public:
654655
// Complete constructor for LocalSize, LocalSizeHint
655-
SPIRVExecutionMode(SPIRVEntry *TheTarget, SPIRVExecutionModeKind TheExecMode,
656-
SPIRVWord X, SPIRVWord Y, SPIRVWord Z)
657-
: SPIRVAnnotation(TheTarget, 6), ExecMode(TheExecMode) {
656+
SPIRVExecutionMode(Op OC, SPIRVEntry *TheTarget,
657+
SPIRVExecutionModeKind TheExecMode, SPIRVWord X,
658+
SPIRVWord Y, SPIRVWord Z)
659+
: SPIRVAnnotation(OC, TheTarget, 6), ExecMode(TheExecMode) {
658660
WordLiterals.push_back(X);
659661
WordLiterals.push_back(Y);
660662
WordLiterals.push_back(Z);
661663
updateModuleVersion();
662664
}
663665
// Complete constructor for VecTypeHint, SubgroupSize, SubgroupsPerWorkgroup
664-
SPIRVExecutionMode(SPIRVEntry *TheTarget, SPIRVExecutionModeKind TheExecMode,
665-
SPIRVWord Code)
666-
: SPIRVAnnotation(TheTarget, 4), ExecMode(TheExecMode) {
666+
SPIRVExecutionMode(Op OC, SPIRVEntry *TheTarget,
667+
SPIRVExecutionModeKind TheExecMode, SPIRVWord Code)
668+
: SPIRVAnnotation(OC, TheTarget, 4), ExecMode(TheExecMode) {
667669
WordLiterals.push_back(Code);
668-
updateModuleVersion();
669670
}
670671
// Complete constructor for ContractionOff
671-
SPIRVExecutionMode(SPIRVEntry *TheTarget, SPIRVExecutionModeKind TheExecMode)
672-
: SPIRVAnnotation(TheTarget, 3), ExecMode(TheExecMode) {
672+
SPIRVExecutionMode(Op OC, SPIRVEntry *TheTarget,
673+
SPIRVExecutionModeKind TheExecMode)
674+
: SPIRVAnnotation(OC, TheTarget, 3), ExecMode(TheExecMode) {
673675
updateModuleVersion();
674676
}
675677
// Incomplete constructor
676-
SPIRVExecutionMode() : ExecMode(ExecutionModeInvocations) {}
678+
SPIRVExecutionMode()
679+
: SPIRVAnnotation(OpExecutionMode), ExecMode(ExecutionModeInvocations) {}
677680
SPIRVExecutionModeKind getExecutionMode() const { return ExecMode; }
678681
const std::vector<SPIRVWord> &getLiterals() const { return WordLiterals; }
679682
SPIRVCapVec getRequiredCapability() const override {
@@ -699,6 +702,28 @@ class SPIRVExecutionMode : public SPIRVAnnotation<OpExecutionMode> {
699702
std::vector<SPIRVWord> WordLiterals;
700703
};
701704

705+
class SPIRVExecutionModeId : public SPIRVExecutionMode {
706+
public:
707+
// Complete constructor for LocalSizeId, LocalSizeHintId
708+
SPIRVExecutionModeId(SPIRVEntry *TheTarget,
709+
SPIRVExecutionModeKind TheExecMode, SPIRVWord X,
710+
SPIRVWord Y, SPIRVWord Z)
711+
: SPIRVExecutionMode(OpExecutionModeId, TheTarget, TheExecMode, X, Y, Z) {
712+
updateModuleVersion();
713+
}
714+
// Complete constructor for SubgroupsPerWorkgroupId
715+
SPIRVExecutionModeId(SPIRVEntry *TheTarget,
716+
SPIRVExecutionModeKind TheExecMode, SPIRVWord Code)
717+
: SPIRVExecutionMode(OpExecutionModeId, TheTarget, TheExecMode, Code) {
718+
updateModuleVersion();
719+
}
720+
// Incomplete constructor
721+
SPIRVExecutionModeId() : SPIRVExecutionMode() {}
722+
SPIRVWord getRequiredSPIRVVersion() const override {
723+
return static_cast<SPIRVWord>(VersionNumber::SPIRV_1_2);
724+
}
725+
};
726+
702727
class SPIRVComponentExecutionModes {
703728
typedef std::multimap<SPIRVExecutionModeKind, SPIRVExecutionMode *>
704729
SPIRVExecutionModeMap;

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVOpCodeEnum.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,7 @@ _SPIRV_OP(TypePipeStorage, 322)
295295
_SPIRV_OP(ConstantPipeStorage, 323)
296296
_SPIRV_OP(CreatePipeFromPipeStorage, 324)
297297
_SPIRV_OP(ModuleProcessed, 330)
298+
_SPIRV_OP(ExecutionModeId, 331)
298299
_SPIRV_OP(DecorateId, 332)
299300
_SPIRV_OP(GroupNonUniformElect, 333)
300301
_SPIRV_OP(GroupNonUniformAll, 334)

0 commit comments

Comments
 (0)