@@ -265,6 +265,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
265
265
bool selectSpvThreadId (Register ResVReg, const SPIRVType *ResType,
266
266
MachineInstr &I) const ;
267
267
268
+ bool selectSpvGroupThreadId (Register ResVReg, const SPIRVType *ResType,
269
+ MachineInstr &I) const ;
270
+
268
271
bool selectWaveOpInst (Register ResVReg, const SPIRVType *ResType,
269
272
MachineInstr &I, unsigned Opcode) const ;
270
273
@@ -309,6 +312,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
309
312
SPIRVType *widenTypeToVec4 (const SPIRVType *Type, MachineInstr &I) const ;
310
313
void extractSubvector (Register &ResVReg, const SPIRVType *ResType,
311
314
Register &ReadReg, MachineInstr &InsertionPoint) const ;
315
+ bool loadVec3BuiltinInputID (SPIRV::BuiltIn::BuiltIn BuiltInValue,
316
+ Register ResVReg, const SPIRVType *ResType,
317
+ MachineInstr &I) const ;
312
318
};
313
319
314
320
} // end anonymous namespace
@@ -2852,6 +2858,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
2852
2858
break ;
2853
2859
case Intrinsic::spv_thread_id:
2854
2860
return selectSpvThreadId (ResVReg, ResType, I);
2861
+ case Intrinsic::spv_thread_id_in_group:
2862
+ return selectSpvGroupThreadId (ResVReg, ResType, I);
2855
2863
case Intrinsic::spv_fdot:
2856
2864
return selectFloatDot (ResVReg, ResType, I);
2857
2865
case Intrinsic::spv_udot:
@@ -3551,30 +3559,29 @@ bool SPIRVInstructionSelector::selectLog10(Register ResVReg,
3551
3559
.constrainAllUses (TII, TRI, RBI);
3552
3560
}
3553
3561
3554
- bool SPIRVInstructionSelector::selectSpvThreadId (Register ResVReg,
3555
- const SPIRVType *ResType,
3556
- MachineInstr &I) const {
3557
- // DX intrinsic: @llvm.dx.thread.id(i32)
3558
- // ID Name Description
3559
- // 93 ThreadId reads the thread ID
3560
-
3562
+ // Generate the instructions to load 3-element vector builtin input
3563
+ // IDs/Indices.
3564
+ // Like: SV_DispatchThreadID, SV_GroupThreadID, etc....
3565
+ bool SPIRVInstructionSelector::loadVec3BuiltinInputID (
3566
+ SPIRV::BuiltIn::BuiltIn BuiltInValue, Register ResVReg,
3567
+ const SPIRVType *ResType, MachineInstr &I) const {
3561
3568
MachineIRBuilder MIRBuilder (I);
3562
3569
const SPIRVType *U32Type = GR.getOrCreateSPIRVIntegerType (32 , MIRBuilder);
3563
3570
const SPIRVType *Vec3Ty =
3564
3571
GR.getOrCreateSPIRVVectorType (U32Type, 3 , MIRBuilder);
3565
3572
const SPIRVType *PtrType = GR.getOrCreateSPIRVPointerType (
3566
3573
Vec3Ty, MIRBuilder, SPIRV::StorageClass::Input);
3567
3574
3568
- // Create new register for GlobalInvocationID builtin variable.
3575
+ // Create new register for the input ID builtin variable.
3569
3576
Register NewRegister =
3570
3577
MIRBuilder.getMRI ()->createVirtualRegister (&SPIRV::iIDRegClass);
3571
3578
MIRBuilder.getMRI ()->setType (NewRegister, LLT::pointer (0 , 64 ));
3572
3579
GR.assignSPIRVTypeToVReg (PtrType, NewRegister, MIRBuilder.getMF ());
3573
3580
3574
- // Build GlobalInvocationID global variable with the necessary decorations.
3581
+ // Build global variable with the necessary decorations for the input ID
3582
+ // builtin variable.
3575
3583
Register Variable = GR.buildGlobalVariable (
3576
- NewRegister, PtrType,
3577
- getLinkStringForBuiltIn (SPIRV::BuiltIn::GlobalInvocationId), nullptr ,
3584
+ NewRegister, PtrType, getLinkStringForBuiltIn (BuiltInValue), nullptr ,
3578
3585
SPIRV::StorageClass::Input, nullptr , true , true ,
3579
3586
SPIRV::LinkageType::Import, MIRBuilder, false );
3580
3587
@@ -3591,12 +3598,12 @@ bool SPIRVInstructionSelector::selectSpvThreadId(Register ResVReg,
3591
3598
.addUse (GR.getSPIRVTypeID (Vec3Ty))
3592
3599
.addUse (Variable);
3593
3600
3594
- // Get Thread ID index. Expecting operand is a constant immediate value,
3601
+ // Get the input ID index. Expecting operand is a constant immediate value,
3595
3602
// wrapped in a type assignment.
3596
3603
assert (I.getOperand (2 ).isReg ());
3597
3604
const uint32_t ThreadId = foldImm (I.getOperand (2 ), MRI);
3598
3605
3599
- // Extract the thread ID from the loaded vector value.
3606
+ // Extract the input ID from the loaded vector value.
3600
3607
MachineBasicBlock &BB = *I.getParent ();
3601
3608
auto MIB = BuildMI (BB, I, I.getDebugLoc (), TII.get (SPIRV::OpCompositeExtract))
3602
3609
.addDef (ResVReg)
@@ -3606,6 +3613,32 @@ bool SPIRVInstructionSelector::selectSpvThreadId(Register ResVReg,
3606
3613
return Result && MIB.constrainAllUses (TII, TRI, RBI);
3607
3614
}
3608
3615
3616
+ bool SPIRVInstructionSelector::selectSpvThreadId (Register ResVReg,
3617
+ const SPIRVType *ResType,
3618
+ MachineInstr &I) const {
3619
+ // DX intrinsic: @llvm.dx.thread.id(i32)
3620
+ // ID Name Description
3621
+ // 93 ThreadId reads the thread ID
3622
+ //
3623
+ // In SPIR-V, llvm.dx.thread.id maps to a `GlobalInvocationId` builtin
3624
+ // variable
3625
+ return loadVec3BuiltinInputID (SPIRV::BuiltIn::GlobalInvocationId, ResVReg,
3626
+ ResType, I);
3627
+ }
3628
+
3629
+ bool SPIRVInstructionSelector::selectSpvGroupThreadId (Register ResVReg,
3630
+ const SPIRVType *ResType,
3631
+ MachineInstr &I) const {
3632
+ // DX intrinsic: @llvm.dx.thread.id.in.group(i32)
3633
+ // ID Name Description
3634
+ // 95 GroupThreadId Reads the thread ID within the group
3635
+ //
3636
+ // In SPIR-V, llvm.dx.thread.id.in.group maps to a `LocalInvocationId` builtin
3637
+ // variable
3638
+ return loadVec3BuiltinInputID (SPIRV::BuiltIn::LocalInvocationId, ResVReg,
3639
+ ResType, I);
3640
+ }
3641
+
3609
3642
SPIRVType *SPIRVInstructionSelector::widenTypeToVec4 (const SPIRVType *Type,
3610
3643
MachineInstr &I) const {
3611
3644
MachineIRBuilder MIRBuilder (I);
0 commit comments