-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[SPIR-V] Ensure correct pointee types of some OpenCL Extended Instructions' pointer arguments #114846
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPIR-V] Ensure correct pointee types of some OpenCL Extended Instructions' pointer arguments #114846
Conversation
@llvm/pr-subscribers-backend-spir-v Author: Vyacheslav Levytskyy (VyacheslavLevytskyy) ChangesOpenCL Extended Instruction Set Specification defines relations between return/operand types and pointee type of pointer arguments in case of remquo, fract, frexp, lgamma_r, modf, sincos and prefetch instructions (https://registry.khronos.org/SPIR-V/specs/unified1/OpenCL.ExtendedInstructionSet.100.html). This PR ensures correct pointee types of those OpenCL Extended Instructions' pointer arguments. Full diff: https://github.com/llvm/llvm-project/pull/114846.diff 2 Files Affected:
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index 682fca7cc7747c..a0b7a27a109ba0 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -223,10 +223,10 @@ static void validateLifetimeStart(const SPIRVSubtarget &STI,
doInsertBitcast(STI, MRI, GR, I, PtrReg, 0, NewPtrType);
}
-static void validateGroupAsyncCopyPtr(const SPIRVSubtarget &STI,
- MachineRegisterInfo *MRI,
- SPIRVGlobalRegistry &GR, MachineInstr &I,
- unsigned OpIdx) {
+static void validatePtrUnwrapStructField(const SPIRVSubtarget &STI,
+ MachineRegisterInfo *MRI,
+ SPIRVGlobalRegistry &GR,
+ MachineInstr &I, unsigned OpIdx) {
MachineFunction *MF = I.getParent()->getParent();
Register OpReg = I.getOperand(OpIdx).getReg();
Register OpTypeReg = getTypeReg(MRI, OpReg);
@@ -440,8 +440,8 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
validateLifetimeStart(STI, MRI, GR, MI);
break;
case SPIRV::OpGroupAsyncCopy:
- validateGroupAsyncCopyPtr(STI, MRI, GR, MI, 3);
- validateGroupAsyncCopyPtr(STI, MRI, GR, MI, 4);
+ validatePtrUnwrapStructField(STI, MRI, GR, MI, 3);
+ validatePtrUnwrapStructField(STI, MRI, GR, MI, 4);
break;
case SPIRV::OpGroupWaitEvents:
// OpGroupWaitEvents ..., ..., <pointer to OpTypeEvent>
@@ -467,6 +467,49 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
if (Type->getParent() == Curr && !Curr->pred_empty())
ToMove.insert(const_cast<MachineInstr *>(Type));
} break;
+ case SPIRV::OpExtInst: {
+ // prefetch
+ if (!MI.getOperand(2).isImm() || !MI.getOperand(3).isImm() ||
+ MI.getOperand(2).getImm() != SPIRV::InstructionSet::OpenCL_std)
+ continue;
+ switch (MI.getOperand(3).getImm()) {
+ case SPIRV::OpenCLExtInst::remquo: {
+ // The last operand must be of a pointer to the return type.
+ MachineIRBuilder MIB(MI);
+ SPIRVType *Int32Type = GR.getOrCreateSPIRVIntegerType(32, MIB);
+ SPIRVType *RetType = MRI->getVRegDef(MI.getOperand(1).getReg());
+ assert(RetType && "Expect return type");
+ validatePtrTypes(
+ STI, MRI, GR, MI, MI.getNumOperands() - 1,
+ RetType->getOpcode() != SPIRV::OpTypeVector
+ ? Int32Type
+ : GR.getOrCreateSPIRVVectorType(
+ Int32Type, RetType->getOperand(2).getImm(), MIB));
+ } break;
+ case SPIRV::OpenCLExtInst::fract:
+ case SPIRV::OpenCLExtInst::frexp:
+ case SPIRV::OpenCLExtInst::lgamma_r:
+ case SPIRV::OpenCLExtInst::modf:
+ case SPIRV::OpenCLExtInst::sincos:
+ // The last operand must be of a pointer to the base type represented
+ // by the previous operand.
+ assert(MI.getOperand(MI.getNumOperands() - 2).isReg() &&
+ "Expect v-reg");
+ validatePtrTypes(
+ STI, MRI, GR, MI, MI.getNumOperands() - 1,
+ GR.getSPIRVTypeForVReg(
+ MI.getOperand(MI.getNumOperands() - 2).getReg()));
+ break;
+ case SPIRV::OpenCLExtInst::prefetch:
+ // Expected `ptr` type is a pointer to float, integer or vector, but
+ // the pontee value can be wrapped into a struct.
+ assert(MI.getOperand(MI.getNumOperands() - 2).isReg() &&
+ "Expect v-reg");
+ validatePtrUnwrapStructField(STI, MRI, GR, MI,
+ MI.getNumOperands() - 2);
+ break;
+ }
+ } break;
}
}
for (MachineInstr *MI : ToMove) {
diff --git a/llvm/test/CodeGen/SPIRV/pointers/OpExtInst-OpenCL_std-ptr-types.ll b/llvm/test/CodeGen/SPIRV/pointers/OpExtInst-OpenCL_std-ptr-types.ll
new file mode 100644
index 00000000000000..8e29876d61d339
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/OpExtInst-OpenCL_std-ptr-types.ll
@@ -0,0 +1,34 @@
+; The goal of the test is to ensure that the output SPIR-V is valid from the perspective of the spirv-val tool.
+; RUN: %if spirv-tools %{ llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+%clsid = type { %arr }
+%arr = type { [1 x i64] }
+%struct_half = type { half }
+
+define weak_odr dso_local spir_kernel void @foo(ptr addrspace(1) noundef readonly align 2 %_acc, ptr noundef byval(%clsid) align 8 %_acc_id, ptr addrspace(3) noundef align 2 %_arg_loc) {
+entry:
+ %r1 = load i64, ptr %_acc_id, align 8
+ %add.ptr.i41 = getelementptr inbounds %struct_half, ptr addrspace(1) %_acc, i64 %r1
+ %idx = addrspacecast ptr addrspace(1) %add.ptr.i41 to ptr addrspace(4)
+ %call.i.i290 = call spir_func noundef ptr addrspace(1) @_Z41__spirv_GenericCastToPtrExplicit_ToGlobalPKvi(ptr addrspace(4) noundef %idx, i32 noundef 5)
+ call spir_func void @_Z20__spirv_ocl_prefetchPU3AS1Kcm(ptr addrspace(1) noundef %call.i.i290, i64 noundef 0)
+
+ %locidx = addrspacecast ptr addrspace(3) %_arg_loc to ptr addrspace(4)
+ %ptr1 = tail call spir_func noundef ptr addrspace(3) @_Z40__spirv_GenericCastToPtrExplicit_ToLocalPvi(ptr addrspace(4) noundef %locidx, i32 noundef 4)
+ %sincos_r = tail call spir_func noundef half @_Z18__spirv_ocl_sincosDF16_PU3AS3DF16_(half noundef 0xH3145, ptr addrspace(3) noundef %ptr1)
+
+ %p1 = addrspacecast ptr addrspace(1) %_acc to ptr addrspace(4)
+ %ptr2 = tail call spir_func noundef ptr addrspace(1) @_Z41__spirv_GenericCastToPtrExplicit_ToGlobalPvi(ptr addrspace(4) noundef %p1, i32 noundef 5)
+ %remquo_r = tail call spir_func noundef half @_Z18__spirv_ocl_remquoDF16_DF16_PU3AS1i(half noundef 0xH3A37, half noundef 0xH32F4, ptr addrspace(1) noundef %ptr2)
+
+ ret void
+}
+
+declare dso_local spir_func void @_Z20__spirv_ocl_prefetchPU3AS1Kcm(ptr addrspace(1) noundef, i64 noundef)
+declare dso_local spir_func noundef ptr addrspace(1) @_Z41__spirv_GenericCastToPtrExplicit_ToGlobalPKvi(ptr addrspace(4) noundef, i32 noundef)
+
+declare dso_local spir_func noundef half @_Z18__spirv_ocl_sincosDF16_PU3AS3DF16_(half noundef, ptr addrspace(3) noundef)
+declare dso_local spir_func noundef ptr addrspace(3) @_Z40__spirv_GenericCastToPtrExplicit_ToLocalPvi(ptr addrspace(4) noundef, i32 noundef)
+
+declare dso_local spir_func noundef ptr addrspace(1) @_Z41__spirv_GenericCastToPtrExplicit_ToGlobalPvi(ptr addrspace(4) noundef, i32 noundef)
+declare dso_local spir_func noundef half @_Z18__spirv_ocl_remquoDF16_DF16_PU3AS1i(half noundef, half noundef, ptr addrspace(1) noundef)
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, added just a comment about three typos
// Expected `ptr` type is a pointer to float, integer or vector, but | ||
// the pontee value can be wrapped into a struct. | ||
assert(MI.getOperand(MI.getNumOperands() - 2).isReg() && | ||
"Expect v-reg"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Typo expect
-> expected
, 2 more typos in other asserts above
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you!
OpenCL Extended Instruction Set Specification defines relations between return/operand types and pointee type of pointer arguments in case of remquo, fract, frexp, lgamma_r, modf, sincos and prefetch instructions (https://registry.khronos.org/SPIR-V/specs/unified1/OpenCL.ExtendedInstructionSet.100.html). This PR ensures correct pointee types of those OpenCL Extended Instructions' pointer arguments.