Skip to content

[SPIR-V] Ensure correct pointee types of some OpenCL Extended Instructions' pointer arguments #114846

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

VyacheslavLevytskyy
Copy link
Contributor

OpenCL Extended Instruction Set Specification defines relations between return/operand types and pointee type of pointer arguments in case of remquo, fract, frexp, lgamma_r, modf, sincos and prefetch instructions (https://registry.khronos.org/SPIR-V/specs/unified1/OpenCL.ExtendedInstructionSet.100.html). This PR ensures correct pointee types of those OpenCL Extended Instructions' pointer arguments.

@llvmbot
Copy link
Member

llvmbot commented Nov 4, 2024

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

Changes

OpenCL Extended Instruction Set Specification defines relations between return/operand types and pointee type of pointer arguments in case of remquo, fract, frexp, lgamma_r, modf, sincos and prefetch instructions (https://registry.khronos.org/SPIR-V/specs/unified1/OpenCL.ExtendedInstructionSet.100.html). This PR ensures correct pointee types of those OpenCL Extended Instructions' pointer arguments.


Full diff: https://github.com/llvm/llvm-project/pull/114846.diff

2 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp (+49-6)
  • (added) llvm/test/CodeGen/SPIRV/pointers/OpExtInst-OpenCL_std-ptr-types.ll (+34)
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index 682fca7cc7747c..a0b7a27a109ba0 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -223,10 +223,10 @@ static void validateLifetimeStart(const SPIRVSubtarget &STI,
   doInsertBitcast(STI, MRI, GR, I, PtrReg, 0, NewPtrType);
 }
 
-static void validateGroupAsyncCopyPtr(const SPIRVSubtarget &STI,
-                                      MachineRegisterInfo *MRI,
-                                      SPIRVGlobalRegistry &GR, MachineInstr &I,
-                                      unsigned OpIdx) {
+static void validatePtrUnwrapStructField(const SPIRVSubtarget &STI,
+                                         MachineRegisterInfo *MRI,
+                                         SPIRVGlobalRegistry &GR,
+                                         MachineInstr &I, unsigned OpIdx) {
   MachineFunction *MF = I.getParent()->getParent();
   Register OpReg = I.getOperand(OpIdx).getReg();
   Register OpTypeReg = getTypeReg(MRI, OpReg);
@@ -440,8 +440,8 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
           validateLifetimeStart(STI, MRI, GR, MI);
         break;
       case SPIRV::OpGroupAsyncCopy:
-        validateGroupAsyncCopyPtr(STI, MRI, GR, MI, 3);
-        validateGroupAsyncCopyPtr(STI, MRI, GR, MI, 4);
+        validatePtrUnwrapStructField(STI, MRI, GR, MI, 3);
+        validatePtrUnwrapStructField(STI, MRI, GR, MI, 4);
         break;
       case SPIRV::OpGroupWaitEvents:
         // OpGroupWaitEvents ..., ..., <pointer to OpTypeEvent>
@@ -467,6 +467,49 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
         if (Type->getParent() == Curr && !Curr->pred_empty())
           ToMove.insert(const_cast<MachineInstr *>(Type));
       } break;
+      case SPIRV::OpExtInst: {
+        // prefetch
+        if (!MI.getOperand(2).isImm() || !MI.getOperand(3).isImm() ||
+            MI.getOperand(2).getImm() != SPIRV::InstructionSet::OpenCL_std)
+          continue;
+        switch (MI.getOperand(3).getImm()) {
+        case SPIRV::OpenCLExtInst::remquo: {
+          // The last operand must be of a pointer to the return type.
+          MachineIRBuilder MIB(MI);
+          SPIRVType *Int32Type = GR.getOrCreateSPIRVIntegerType(32, MIB);
+          SPIRVType *RetType = MRI->getVRegDef(MI.getOperand(1).getReg());
+          assert(RetType && "Expect return type");
+          validatePtrTypes(
+              STI, MRI, GR, MI, MI.getNumOperands() - 1,
+              RetType->getOpcode() != SPIRV::OpTypeVector
+                  ? Int32Type
+                  : GR.getOrCreateSPIRVVectorType(
+                        Int32Type, RetType->getOperand(2).getImm(), MIB));
+        } break;
+        case SPIRV::OpenCLExtInst::fract:
+        case SPIRV::OpenCLExtInst::frexp:
+        case SPIRV::OpenCLExtInst::lgamma_r:
+        case SPIRV::OpenCLExtInst::modf:
+        case SPIRV::OpenCLExtInst::sincos:
+          // The last operand must be of a pointer to the base type represented
+          // by the previous operand.
+          assert(MI.getOperand(MI.getNumOperands() - 2).isReg() &&
+                 "Expect v-reg");
+          validatePtrTypes(
+              STI, MRI, GR, MI, MI.getNumOperands() - 1,
+              GR.getSPIRVTypeForVReg(
+                  MI.getOperand(MI.getNumOperands() - 2).getReg()));
+          break;
+        case SPIRV::OpenCLExtInst::prefetch:
+          // Expected `ptr` type is a pointer to float, integer or vector, but
+          // the pontee value can be wrapped into a struct.
+          assert(MI.getOperand(MI.getNumOperands() - 2).isReg() &&
+                 "Expect v-reg");
+          validatePtrUnwrapStructField(STI, MRI, GR, MI,
+                                       MI.getNumOperands() - 2);
+          break;
+        }
+      } break;
       }
     }
     for (MachineInstr *MI : ToMove) {
diff --git a/llvm/test/CodeGen/SPIRV/pointers/OpExtInst-OpenCL_std-ptr-types.ll b/llvm/test/CodeGen/SPIRV/pointers/OpExtInst-OpenCL_std-ptr-types.ll
new file mode 100644
index 00000000000000..8e29876d61d339
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/OpExtInst-OpenCL_std-ptr-types.ll
@@ -0,0 +1,34 @@
+; The goal of the test is to ensure that the output SPIR-V is valid from the perspective of the spirv-val tool.
+; RUN: %if spirv-tools %{ llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+%clsid = type { %arr }
+%arr = type { [1 x i64] }
+%struct_half = type { half }
+
+define weak_odr dso_local spir_kernel void @foo(ptr addrspace(1) noundef readonly align 2 %_acc, ptr noundef byval(%clsid) align 8 %_acc_id, ptr addrspace(3) noundef align 2 %_arg_loc) {
+entry:
+  %r1 = load i64, ptr %_acc_id, align 8
+  %add.ptr.i41 = getelementptr inbounds %struct_half, ptr addrspace(1) %_acc, i64 %r1
+  %idx = addrspacecast ptr addrspace(1) %add.ptr.i41 to ptr addrspace(4)
+  %call.i.i290 = call spir_func noundef ptr addrspace(1) @_Z41__spirv_GenericCastToPtrExplicit_ToGlobalPKvi(ptr addrspace(4) noundef %idx, i32 noundef 5)
+  call spir_func void @_Z20__spirv_ocl_prefetchPU3AS1Kcm(ptr addrspace(1) noundef %call.i.i290, i64 noundef 0)
+
+  %locidx = addrspacecast ptr addrspace(3) %_arg_loc to ptr addrspace(4)
+  %ptr1 = tail call spir_func noundef ptr addrspace(3) @_Z40__spirv_GenericCastToPtrExplicit_ToLocalPvi(ptr addrspace(4) noundef %locidx, i32 noundef 4)
+  %sincos_r = tail call spir_func noundef half @_Z18__spirv_ocl_sincosDF16_PU3AS3DF16_(half noundef 0xH3145, ptr addrspace(3) noundef %ptr1)
+
+  %p1 = addrspacecast ptr addrspace(1) %_acc to ptr addrspace(4)
+  %ptr2 = tail call spir_func noundef ptr addrspace(1) @_Z41__spirv_GenericCastToPtrExplicit_ToGlobalPvi(ptr addrspace(4) noundef %p1, i32 noundef 5)
+  %remquo_r = tail call spir_func noundef half @_Z18__spirv_ocl_remquoDF16_DF16_PU3AS1i(half noundef 0xH3A37, half noundef 0xH32F4, ptr addrspace(1) noundef %ptr2)
+
+  ret void
+}
+
+declare dso_local spir_func void @_Z20__spirv_ocl_prefetchPU3AS1Kcm(ptr addrspace(1) noundef, i64 noundef)
+declare dso_local spir_func noundef ptr addrspace(1) @_Z41__spirv_GenericCastToPtrExplicit_ToGlobalPKvi(ptr addrspace(4) noundef, i32 noundef)
+
+declare dso_local spir_func noundef half @_Z18__spirv_ocl_sincosDF16_PU3AS3DF16_(half noundef, ptr addrspace(3) noundef)
+declare dso_local spir_func noundef ptr addrspace(3) @_Z40__spirv_GenericCastToPtrExplicit_ToLocalPvi(ptr addrspace(4) noundef, i32 noundef)
+
+declare dso_local spir_func noundef ptr addrspace(1) @_Z41__spirv_GenericCastToPtrExplicit_ToGlobalPvi(ptr addrspace(4) noundef, i32 noundef)
+declare dso_local spir_func noundef half @_Z18__spirv_ocl_remquoDF16_DF16_PU3AS1i(half noundef, half noundef, ptr addrspace(1) noundef)

Copy link
Member

@michalpaszkowski michalpaszkowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, added just a comment about three typos

// Expected `ptr` type is a pointer to float, integer or vector, but
// the pontee value can be wrapped into a struct.
assert(MI.getOperand(MI.getNumOperands() - 2).isReg() &&
"Expect v-reg");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo expect -> expected, 2 more typos in other asserts above

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you!

@VyacheslavLevytskyy VyacheslavLevytskyy merged commit 5a06219 into llvm:main Nov 6, 2024
6 of 8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants