@@ -223,10 +223,10 @@ static void validateLifetimeStart(const SPIRVSubtarget &STI,
223
223
doInsertBitcast (STI, MRI, GR, I, PtrReg, 0 , NewPtrType);
224
224
}
225
225
226
- static void validateGroupAsyncCopyPtr (const SPIRVSubtarget &STI,
227
- MachineRegisterInfo *MRI,
228
- SPIRVGlobalRegistry &GR, MachineInstr &I ,
229
- unsigned OpIdx) {
226
+ static void validatePtrUnwrapStructField (const SPIRVSubtarget &STI,
227
+ MachineRegisterInfo *MRI,
228
+ SPIRVGlobalRegistry &GR ,
229
+ MachineInstr &I, unsigned OpIdx) {
230
230
MachineFunction *MF = I.getParent ()->getParent ();
231
231
Register OpReg = I.getOperand (OpIdx).getReg ();
232
232
Register OpTypeReg = getTypeReg (MRI, OpReg);
@@ -440,8 +440,8 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
440
440
validateLifetimeStart (STI, MRI, GR, MI);
441
441
break ;
442
442
case SPIRV::OpGroupAsyncCopy:
443
- validateGroupAsyncCopyPtr (STI, MRI, GR, MI, 3 );
444
- validateGroupAsyncCopyPtr (STI, MRI, GR, MI, 4 );
443
+ validatePtrUnwrapStructField (STI, MRI, GR, MI, 3 );
444
+ validatePtrUnwrapStructField (STI, MRI, GR, MI, 4 );
445
445
break ;
446
446
case SPIRV::OpGroupWaitEvents:
447
447
// OpGroupWaitEvents ..., ..., <pointer to OpTypeEvent>
@@ -467,6 +467,49 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
467
467
if (Type->getParent () == Curr && !Curr->pred_empty ())
468
468
ToMove.insert (const_cast <MachineInstr *>(Type));
469
469
} break ;
470
+ case SPIRV::OpExtInst: {
471
+ // prefetch
472
+ if (!MI.getOperand (2 ).isImm () || !MI.getOperand (3 ).isImm () ||
473
+ MI.getOperand (2 ).getImm () != SPIRV::InstructionSet::OpenCL_std)
474
+ continue ;
475
+ switch (MI.getOperand (3 ).getImm ()) {
476
+ case SPIRV::OpenCLExtInst::remquo: {
477
+ // The last operand must be of a pointer to the return type.
478
+ MachineIRBuilder MIB (MI);
479
+ SPIRVType *Int32Type = GR.getOrCreateSPIRVIntegerType (32 , MIB);
480
+ SPIRVType *RetType = MRI->getVRegDef (MI.getOperand (1 ).getReg ());
481
+ assert (RetType && " Expected return type" );
482
+ validatePtrTypes (
483
+ STI, MRI, GR, MI, MI.getNumOperands () - 1 ,
484
+ RetType->getOpcode () != SPIRV::OpTypeVector
485
+ ? Int32Type
486
+ : GR.getOrCreateSPIRVVectorType (
487
+ Int32Type, RetType->getOperand (2 ).getImm (), MIB));
488
+ } break ;
489
+ case SPIRV::OpenCLExtInst::fract:
490
+ case SPIRV::OpenCLExtInst::frexp:
491
+ case SPIRV::OpenCLExtInst::lgamma_r:
492
+ case SPIRV::OpenCLExtInst::modf:
493
+ case SPIRV::OpenCLExtInst::sincos:
494
+ // The last operand must be of a pointer to the base type represented
495
+ // by the previous operand.
496
+ assert (MI.getOperand (MI.getNumOperands () - 2 ).isReg () &&
497
+ " Expected v-reg" );
498
+ validatePtrTypes (
499
+ STI, MRI, GR, MI, MI.getNumOperands () - 1 ,
500
+ GR.getSPIRVTypeForVReg (
501
+ MI.getOperand (MI.getNumOperands () - 2 ).getReg ()));
502
+ break ;
503
+ case SPIRV::OpenCLExtInst::prefetch:
504
+ // Expected `ptr` type is a pointer to float, integer or vector, but
505
+ // the pontee value can be wrapped into a struct.
506
+ assert (MI.getOperand (MI.getNumOperands () - 2 ).isReg () &&
507
+ " Expected v-reg" );
508
+ validatePtrUnwrapStructField (STI, MRI, GR, MI,
509
+ MI.getNumOperands () - 2 );
510
+ break ;
511
+ }
512
+ } break ;
470
513
}
471
514
}
472
515
for (MachineInstr *MI : ToMove) {
0 commit comments