@@ -265,6 +265,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
265
265
bool selectSpvThreadId (Register ResVReg, const SPIRVType *ResType,
266
266
MachineInstr &I) const ;
267
267
268
+ bool selectSpvGroupThreadId (Register ResVReg, const SPIRVType *ResType,
269
+ MachineInstr &I) const ;
270
+
268
271
bool selectWaveOpInst (Register ResVReg, const SPIRVType *ResType,
269
272
MachineInstr &I, unsigned Opcode) const ;
270
273
@@ -310,6 +313,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
310
313
void extractSubvector (Register &ResVReg, const SPIRVType *ResType,
311
314
Register &ReadReg, MachineInstr &InsertionPoint) const ;
312
315
bool BuildCOPY (Register DestReg, Register SrcReg, MachineInstr &I) const ;
316
+ bool loadVec3BuiltinInputID (SPIRV::BuiltIn::BuiltIn BuiltInValue,
317
+ Register ResVReg, const SPIRVType *ResType,
318
+ MachineInstr &I) const ;
313
319
};
314
320
315
321
} // end anonymous namespace
@@ -2826,6 +2832,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
2826
2832
break ;
2827
2833
case Intrinsic::spv_thread_id:
2828
2834
return selectSpvThreadId (ResVReg, ResType, I);
2835
+ case Intrinsic::spv_thread_id_in_group:
2836
+ return selectSpvGroupThreadId (ResVReg, ResType, I);
2829
2837
case Intrinsic::spv_fdot:
2830
2838
return selectFloatDot (ResVReg, ResType, I);
2831
2839
case Intrinsic::spv_udot:
@@ -3525,30 +3533,29 @@ bool SPIRVInstructionSelector::selectLog10(Register ResVReg,
3525
3533
.constrainAllUses (TII, TRI, RBI);
3526
3534
}
3527
3535
3528
- bool SPIRVInstructionSelector::selectSpvThreadId (Register ResVReg,
3529
- const SPIRVType *ResType,
3530
- MachineInstr &I) const {
3531
- // DX intrinsic: @llvm.dx.thread.id(i32)
3532
- // ID Name Description
3533
- // 93 ThreadId reads the thread ID
3534
-
3536
+ // Generate the instructions to load 3-element vector builtin input
3537
+ // IDs/Indices.
3538
+ // Like: SV_DispatchThreadID, SV_GroupThreadID, etc....
3539
+ bool SPIRVInstructionSelector::loadVec3BuiltinInputID (
3540
+ SPIRV::BuiltIn::BuiltIn BuiltInValue, Register ResVReg,
3541
+ const SPIRVType *ResType, MachineInstr &I) const {
3535
3542
MachineIRBuilder MIRBuilder (I);
3536
3543
const SPIRVType *U32Type = GR.getOrCreateSPIRVIntegerType (32 , MIRBuilder);
3537
3544
const SPIRVType *Vec3Ty =
3538
3545
GR.getOrCreateSPIRVVectorType (U32Type, 3 , MIRBuilder);
3539
3546
const SPIRVType *PtrType = GR.getOrCreateSPIRVPointerType (
3540
3547
Vec3Ty, MIRBuilder, SPIRV::StorageClass::Input);
3541
3548
3542
- // Create new register for GlobalInvocationID builtin variable.
3549
+ // Create new register for the input ID builtin variable.
3543
3550
Register NewRegister =
3544
3551
MIRBuilder.getMRI ()->createVirtualRegister (&SPIRV::iIDRegClass);
3545
3552
MIRBuilder.getMRI ()->setType (NewRegister, LLT::pointer (0 , 64 ));
3546
3553
GR.assignSPIRVTypeToVReg (PtrType, NewRegister, MIRBuilder.getMF ());
3547
3554
3548
- // Build GlobalInvocationID global variable with the necessary decorations.
3555
+ // Build global variable with the necessary decorations for the input ID
3556
+ // builtin variable.
3549
3557
Register Variable = GR.buildGlobalVariable (
3550
- NewRegister, PtrType,
3551
- getLinkStringForBuiltIn (SPIRV::BuiltIn::GlobalInvocationId), nullptr ,
3558
+ NewRegister, PtrType, getLinkStringForBuiltIn (BuiltInValue), nullptr ,
3552
3559
SPIRV::StorageClass::Input, nullptr , true , true ,
3553
3560
SPIRV::LinkageType::Import, MIRBuilder, false );
3554
3561
@@ -3565,12 +3572,12 @@ bool SPIRVInstructionSelector::selectSpvThreadId(Register ResVReg,
3565
3572
.addUse (GR.getSPIRVTypeID (Vec3Ty))
3566
3573
.addUse (Variable);
3567
3574
3568
- // Get Thread ID index. Expecting operand is a constant immediate value,
3575
+ // Get the input ID index. Expecting operand is a constant immediate value,
3569
3576
// wrapped in a type assignment.
3570
3577
assert (I.getOperand (2 ).isReg ());
3571
3578
const uint32_t ThreadId = foldImm (I.getOperand (2 ), MRI);
3572
3579
3573
- // Extract the thread ID from the loaded vector value.
3580
+ // Extract the input ID from the loaded vector value.
3574
3581
MachineBasicBlock &BB = *I.getParent ();
3575
3582
auto MIB = BuildMI (BB, I, I.getDebugLoc (), TII.get (SPIRV::OpCompositeExtract))
3576
3583
.addDef (ResVReg)
@@ -3580,6 +3587,32 @@ bool SPIRVInstructionSelector::selectSpvThreadId(Register ResVReg,
3580
3587
return Result && MIB.constrainAllUses (TII, TRI, RBI);
3581
3588
}
3582
3589
3590
+ bool SPIRVInstructionSelector::selectSpvThreadId (Register ResVReg,
3591
+ const SPIRVType *ResType,
3592
+ MachineInstr &I) const {
3593
+ // DX intrinsic: @llvm.dx.thread.id(i32)
3594
+ // ID Name Description
3595
+ // 93 ThreadId reads the thread ID
3596
+ //
3597
+ // In SPIR-V, llvm.dx.thread.id maps to a `GlobalInvocationId` builtin
3598
+ // variable
3599
+ return loadVec3BuiltinInputID (SPIRV::BuiltIn::GlobalInvocationId, ResVReg,
3600
+ ResType, I);
3601
+ }
3602
+
3603
+ bool SPIRVInstructionSelector::selectSpvGroupThreadId (Register ResVReg,
3604
+ const SPIRVType *ResType,
3605
+ MachineInstr &I) const {
3606
+ // DX intrinsic: @llvm.dx.thread.id.in.group(i32)
3607
+ // ID Name Description
3608
+ // 95 GroupThreadId Reads the thread ID within the group
3609
+ //
3610
+ // In SPIR-V, llvm.dx.thread.id.in.group maps to a `LocalInvocationId` builtin
3611
+ // variable
3612
+ return loadVec3BuiltinInputID (SPIRV::BuiltIn::LocalInvocationId, ResVReg,
3613
+ ResType, I);
3614
+ }
3615
+
3583
3616
SPIRVType *SPIRVInstructionSelector::widenTypeToVec4 (const SPIRVType *Type,
3584
3617
MachineInstr &I) const {
3585
3618
MachineIRBuilder MIRBuilder (I);
0 commit comments