Skip to content

Commit 5a06219

Browse files
[SPIR-V] Ensure correct pointee types of some OpenCL Extended Instructions' pointer arguments (#114846)
OpenCL Extended Instruction Set Specification defines relations between return/operand types and pointee type of pointer arguments in case of remquo, fract, frexp, lgamma_r, modf, sincos and prefetch instructions (https://registry.khronos.org/SPIR-V/specs/unified1/OpenCL.ExtendedInstructionSet.100.html). This PR ensures correct pointee types of those OpenCL Extended Instructions' pointer arguments.
1 parent d8139ae commit 5a06219

File tree

2 files changed

+83
-6
lines changed

2 files changed

+83
-6
lines changed

llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -223,10 +223,10 @@ static void validateLifetimeStart(const SPIRVSubtarget &STI,
223223
doInsertBitcast(STI, MRI, GR, I, PtrReg, 0, NewPtrType);
224224
}
225225

226-
static void validateGroupAsyncCopyPtr(const SPIRVSubtarget &STI,
227-
MachineRegisterInfo *MRI,
228-
SPIRVGlobalRegistry &GR, MachineInstr &I,
229-
unsigned OpIdx) {
226+
static void validatePtrUnwrapStructField(const SPIRVSubtarget &STI,
227+
MachineRegisterInfo *MRI,
228+
SPIRVGlobalRegistry &GR,
229+
MachineInstr &I, unsigned OpIdx) {
230230
MachineFunction *MF = I.getParent()->getParent();
231231
Register OpReg = I.getOperand(OpIdx).getReg();
232232
Register OpTypeReg = getTypeReg(MRI, OpReg);
@@ -440,8 +440,8 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
440440
validateLifetimeStart(STI, MRI, GR, MI);
441441
break;
442442
case SPIRV::OpGroupAsyncCopy:
443-
validateGroupAsyncCopyPtr(STI, MRI, GR, MI, 3);
444-
validateGroupAsyncCopyPtr(STI, MRI, GR, MI, 4);
443+
validatePtrUnwrapStructField(STI, MRI, GR, MI, 3);
444+
validatePtrUnwrapStructField(STI, MRI, GR, MI, 4);
445445
break;
446446
case SPIRV::OpGroupWaitEvents:
447447
// OpGroupWaitEvents ..., ..., <pointer to OpTypeEvent>
@@ -467,6 +467,49 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
467467
if (Type->getParent() == Curr && !Curr->pred_empty())
468468
ToMove.insert(const_cast<MachineInstr *>(Type));
469469
} break;
470+
case SPIRV::OpExtInst: {
471+
// prefetch
472+
if (!MI.getOperand(2).isImm() || !MI.getOperand(3).isImm() ||
473+
MI.getOperand(2).getImm() != SPIRV::InstructionSet::OpenCL_std)
474+
continue;
475+
switch (MI.getOperand(3).getImm()) {
476+
case SPIRV::OpenCLExtInst::remquo: {
477+
// The last operand must be of a pointer to the return type.
478+
MachineIRBuilder MIB(MI);
479+
SPIRVType *Int32Type = GR.getOrCreateSPIRVIntegerType(32, MIB);
480+
SPIRVType *RetType = MRI->getVRegDef(MI.getOperand(1).getReg());
481+
assert(RetType && "Expected return type");
482+
validatePtrTypes(
483+
STI, MRI, GR, MI, MI.getNumOperands() - 1,
484+
RetType->getOpcode() != SPIRV::OpTypeVector
485+
? Int32Type
486+
: GR.getOrCreateSPIRVVectorType(
487+
Int32Type, RetType->getOperand(2).getImm(), MIB));
488+
} break;
489+
case SPIRV::OpenCLExtInst::fract:
490+
case SPIRV::OpenCLExtInst::frexp:
491+
case SPIRV::OpenCLExtInst::lgamma_r:
492+
case SPIRV::OpenCLExtInst::modf:
493+
case SPIRV::OpenCLExtInst::sincos:
494+
// The last operand must be of a pointer to the base type represented
495+
// by the previous operand.
496+
assert(MI.getOperand(MI.getNumOperands() - 2).isReg() &&
497+
"Expected v-reg");
498+
validatePtrTypes(
499+
STI, MRI, GR, MI, MI.getNumOperands() - 1,
500+
GR.getSPIRVTypeForVReg(
501+
MI.getOperand(MI.getNumOperands() - 2).getReg()));
502+
break;
503+
case SPIRV::OpenCLExtInst::prefetch:
504+
// Expected `ptr` type is a pointer to float, integer or vector, but
505+
// the pontee value can be wrapped into a struct.
506+
assert(MI.getOperand(MI.getNumOperands() - 2).isReg() &&
507+
"Expected v-reg");
508+
validatePtrUnwrapStructField(STI, MRI, GR, MI,
509+
MI.getNumOperands() - 2);
510+
break;
511+
}
512+
} break;
470513
}
471514
}
472515
for (MachineInstr *MI : ToMove) {
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
; The goal of the test is to ensure that the output SPIR-V is valid from the perspective of the spirv-val tool.
2+
; RUN: %if spirv-tools %{ llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
3+
4+
%clsid = type { %arr }
5+
%arr = type { [1 x i64] }
6+
%struct_half = type { half }
7+
8+
define weak_odr dso_local spir_kernel void @foo(ptr addrspace(1) noundef readonly align 2 %_acc, ptr noundef byval(%clsid) align 8 %_acc_id, ptr addrspace(3) noundef align 2 %_arg_loc) {
9+
entry:
10+
%r1 = load i64, ptr %_acc_id, align 8
11+
%add.ptr.i41 = getelementptr inbounds %struct_half, ptr addrspace(1) %_acc, i64 %r1
12+
%idx = addrspacecast ptr addrspace(1) %add.ptr.i41 to ptr addrspace(4)
13+
%call.i.i290 = call spir_func noundef ptr addrspace(1) @_Z41__spirv_GenericCastToPtrExplicit_ToGlobalPKvi(ptr addrspace(4) noundef %idx, i32 noundef 5)
14+
call spir_func void @_Z20__spirv_ocl_prefetchPU3AS1Kcm(ptr addrspace(1) noundef %call.i.i290, i64 noundef 0)
15+
16+
%locidx = addrspacecast ptr addrspace(3) %_arg_loc to ptr addrspace(4)
17+
%ptr1 = tail call spir_func noundef ptr addrspace(3) @_Z40__spirv_GenericCastToPtrExplicit_ToLocalPvi(ptr addrspace(4) noundef %locidx, i32 noundef 4)
18+
%sincos_r = tail call spir_func noundef half @_Z18__spirv_ocl_sincosDF16_PU3AS3DF16_(half noundef 0xH3145, ptr addrspace(3) noundef %ptr1)
19+
20+
%p1 = addrspacecast ptr addrspace(1) %_acc to ptr addrspace(4)
21+
%ptr2 = tail call spir_func noundef ptr addrspace(1) @_Z41__spirv_GenericCastToPtrExplicit_ToGlobalPvi(ptr addrspace(4) noundef %p1, i32 noundef 5)
22+
%remquo_r = tail call spir_func noundef half @_Z18__spirv_ocl_remquoDF16_DF16_PU3AS1i(half noundef 0xH3A37, half noundef 0xH32F4, ptr addrspace(1) noundef %ptr2)
23+
24+
ret void
25+
}
26+
27+
declare dso_local spir_func void @_Z20__spirv_ocl_prefetchPU3AS1Kcm(ptr addrspace(1) noundef, i64 noundef)
28+
declare dso_local spir_func noundef ptr addrspace(1) @_Z41__spirv_GenericCastToPtrExplicit_ToGlobalPKvi(ptr addrspace(4) noundef, i32 noundef)
29+
30+
declare dso_local spir_func noundef half @_Z18__spirv_ocl_sincosDF16_PU3AS3DF16_(half noundef, ptr addrspace(3) noundef)
31+
declare dso_local spir_func noundef ptr addrspace(3) @_Z40__spirv_GenericCastToPtrExplicit_ToLocalPvi(ptr addrspace(4) noundef, i32 noundef)
32+
33+
declare dso_local spir_func noundef ptr addrspace(1) @_Z41__spirv_GenericCastToPtrExplicit_ToGlobalPvi(ptr addrspace(4) noundef, i32 noundef)
34+
declare dso_local spir_func noundef half @_Z18__spirv_ocl_remquoDF16_DF16_PU3AS1i(half noundef, half noundef, ptr addrspace(1) noundef)

0 commit comments

Comments
 (0)