Skip to content

[SPIR-V] Ensure correct pointee types of some OpenCL Extended Instructions' pointer arguments #114846

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 49 additions & 6 deletions llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,10 +223,10 @@ static void validateLifetimeStart(const SPIRVSubtarget &STI,
doInsertBitcast(STI, MRI, GR, I, PtrReg, 0, NewPtrType);
}

static void validateGroupAsyncCopyPtr(const SPIRVSubtarget &STI,
MachineRegisterInfo *MRI,
SPIRVGlobalRegistry &GR, MachineInstr &I,
unsigned OpIdx) {
static void validatePtrUnwrapStructField(const SPIRVSubtarget &STI,
MachineRegisterInfo *MRI,
SPIRVGlobalRegistry &GR,
MachineInstr &I, unsigned OpIdx) {
MachineFunction *MF = I.getParent()->getParent();
Register OpReg = I.getOperand(OpIdx).getReg();
Register OpTypeReg = getTypeReg(MRI, OpReg);
Expand Down Expand Up @@ -440,8 +440,8 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
validateLifetimeStart(STI, MRI, GR, MI);
break;
case SPIRV::OpGroupAsyncCopy:
validateGroupAsyncCopyPtr(STI, MRI, GR, MI, 3);
validateGroupAsyncCopyPtr(STI, MRI, GR, MI, 4);
validatePtrUnwrapStructField(STI, MRI, GR, MI, 3);
validatePtrUnwrapStructField(STI, MRI, GR, MI, 4);
break;
case SPIRV::OpGroupWaitEvents:
// OpGroupWaitEvents ..., ..., <pointer to OpTypeEvent>
Expand All @@ -467,6 +467,49 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
if (Type->getParent() == Curr && !Curr->pred_empty())
ToMove.insert(const_cast<MachineInstr *>(Type));
} break;
case SPIRV::OpExtInst: {
// prefetch
if (!MI.getOperand(2).isImm() || !MI.getOperand(3).isImm() ||
MI.getOperand(2).getImm() != SPIRV::InstructionSet::OpenCL_std)
continue;
switch (MI.getOperand(3).getImm()) {
case SPIRV::OpenCLExtInst::remquo: {
// The last operand must be of a pointer to the return type.
MachineIRBuilder MIB(MI);
SPIRVType *Int32Type = GR.getOrCreateSPIRVIntegerType(32, MIB);
SPIRVType *RetType = MRI->getVRegDef(MI.getOperand(1).getReg());
assert(RetType && "Expected return type");
validatePtrTypes(
STI, MRI, GR, MI, MI.getNumOperands() - 1,
RetType->getOpcode() != SPIRV::OpTypeVector
? Int32Type
: GR.getOrCreateSPIRVVectorType(
Int32Type, RetType->getOperand(2).getImm(), MIB));
} break;
case SPIRV::OpenCLExtInst::fract:
case SPIRV::OpenCLExtInst::frexp:
case SPIRV::OpenCLExtInst::lgamma_r:
case SPIRV::OpenCLExtInst::modf:
case SPIRV::OpenCLExtInst::sincos:
// The last operand must be of a pointer to the base type represented
// by the previous operand.
assert(MI.getOperand(MI.getNumOperands() - 2).isReg() &&
"Expected v-reg");
validatePtrTypes(
STI, MRI, GR, MI, MI.getNumOperands() - 1,
GR.getSPIRVTypeForVReg(
MI.getOperand(MI.getNumOperands() - 2).getReg()));
break;
case SPIRV::OpenCLExtInst::prefetch:
// Expected `ptr` type is a pointer to float, integer or vector, but
// the pontee value can be wrapped into a struct.
assert(MI.getOperand(MI.getNumOperands() - 2).isReg() &&
"Expected v-reg");
validatePtrUnwrapStructField(STI, MRI, GR, MI,
MI.getNumOperands() - 2);
break;
}
} break;
}
}
for (MachineInstr *MI : ToMove) {
Expand Down
34 changes: 34 additions & 0 deletions llvm/test/CodeGen/SPIRV/pointers/OpExtInst-OpenCL_std-ptr-types.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
; The goal of the test is to ensure that the output SPIR-V is valid from the perspective of the spirv-val tool.
; RUN: %if spirv-tools %{ llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}

%clsid = type { %arr }
%arr = type { [1 x i64] }
%struct_half = type { half }

define weak_odr dso_local spir_kernel void @foo(ptr addrspace(1) noundef readonly align 2 %_acc, ptr noundef byval(%clsid) align 8 %_acc_id, ptr addrspace(3) noundef align 2 %_arg_loc) {
entry:
%r1 = load i64, ptr %_acc_id, align 8
%add.ptr.i41 = getelementptr inbounds %struct_half, ptr addrspace(1) %_acc, i64 %r1
%idx = addrspacecast ptr addrspace(1) %add.ptr.i41 to ptr addrspace(4)
%call.i.i290 = call spir_func noundef ptr addrspace(1) @_Z41__spirv_GenericCastToPtrExplicit_ToGlobalPKvi(ptr addrspace(4) noundef %idx, i32 noundef 5)
call spir_func void @_Z20__spirv_ocl_prefetchPU3AS1Kcm(ptr addrspace(1) noundef %call.i.i290, i64 noundef 0)

%locidx = addrspacecast ptr addrspace(3) %_arg_loc to ptr addrspace(4)
%ptr1 = tail call spir_func noundef ptr addrspace(3) @_Z40__spirv_GenericCastToPtrExplicit_ToLocalPvi(ptr addrspace(4) noundef %locidx, i32 noundef 4)
%sincos_r = tail call spir_func noundef half @_Z18__spirv_ocl_sincosDF16_PU3AS3DF16_(half noundef 0xH3145, ptr addrspace(3) noundef %ptr1)

%p1 = addrspacecast ptr addrspace(1) %_acc to ptr addrspace(4)
%ptr2 = tail call spir_func noundef ptr addrspace(1) @_Z41__spirv_GenericCastToPtrExplicit_ToGlobalPvi(ptr addrspace(4) noundef %p1, i32 noundef 5)
%remquo_r = tail call spir_func noundef half @_Z18__spirv_ocl_remquoDF16_DF16_PU3AS1i(half noundef 0xH3A37, half noundef 0xH32F4, ptr addrspace(1) noundef %ptr2)

ret void
}

declare dso_local spir_func void @_Z20__spirv_ocl_prefetchPU3AS1Kcm(ptr addrspace(1) noundef, i64 noundef)
declare dso_local spir_func noundef ptr addrspace(1) @_Z41__spirv_GenericCastToPtrExplicit_ToGlobalPKvi(ptr addrspace(4) noundef, i32 noundef)

declare dso_local spir_func noundef half @_Z18__spirv_ocl_sincosDF16_PU3AS3DF16_(half noundef, ptr addrspace(3) noundef)
declare dso_local spir_func noundef ptr addrspace(3) @_Z40__spirv_GenericCastToPtrExplicit_ToLocalPvi(ptr addrspace(4) noundef, i32 noundef)

declare dso_local spir_func noundef ptr addrspace(1) @_Z41__spirv_GenericCastToPtrExplicit_ToGlobalPvi(ptr addrspace(4) noundef, i32 noundef)
declare dso_local spir_func noundef half @_Z18__spirv_ocl_remquoDF16_DF16_PU3AS1i(half noundef, half noundef, ptr addrspace(1) noundef)
Loading