Skip to content

[SPIR-V] Validate type of the last parameter of OpGroupWaitEvents #93661

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

VyacheslavLevytskyy
Copy link
Contributor

This PR fixes invalid OpGroupWaitEvents emission to ensure that SPIR-V Backend inserts a bitcast before OpGroupWaitEvents if the last argument is a pointer that doesn't point to OpTypeEvent.

@llvmbot
Copy link
Member

llvmbot commented May 29, 2024

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

Changes

This PR fixes invalid OpGroupWaitEvents emission to ensure that SPIR-V Backend inserts a bitcast before OpGroupWaitEvents if the last argument is a pointer that doesn't point to OpTypeEvent.


Full diff: https://github.com/llvm/llvm-project/pull/93661.diff

2 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp (+73-26)
  • (added) llvm/test/CodeGen/SPIRV/event-wait-ptr-type.ll (+28)
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index 2bd22bbd63169..5ccbaf12ddee2 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -104,6 +104,47 @@ SPIRVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
   return std::make_pair(0u, RC);
 }
 
+inline Register getTypeReg(MachineRegisterInfo *MRI, Register OpReg) {
+  SPIRVType *TypeInst = MRI->getVRegDef(OpReg);
+  return TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter
+             ? TypeInst->getOperand(1).getReg()
+             : OpReg;
+}
+
+static void doInsertBitcast(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
+                            SPIRVGlobalRegistry &GR, MachineInstr &I,
+                            Register OpReg, unsigned OpIdx,
+                            SPIRVType *NewPtrType) {
+  Register NewReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
+  MachineIRBuilder MIB(I);
+  bool Res = MIB.buildInstr(SPIRV::OpBitcast)
+                 .addDef(NewReg)
+                 .addUse(GR.getSPIRVTypeID(NewPtrType))
+                 .addUse(OpReg)
+                 .constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(),
+                                   *STI.getRegBankInfo());
+  if (!Res)
+    report_fatal_error("insert validation bitcast: cannot constrain all uses");
+  MRI->setRegClass(NewReg, &SPIRV::IDRegClass);
+  GR.assignSPIRVTypeToVReg(NewPtrType, NewReg, MIB.getMF());
+  I.getOperand(OpIdx).setReg(NewReg);
+}
+
+static SPIRVType *createNewPtrType(SPIRVGlobalRegistry &GR, MachineInstr &I,
+                                   SPIRVType *OpType, bool ReuseType,
+                                   bool EmitIR, SPIRVType *ResType,
+                                   const Type *ResTy) {
+  SPIRV::StorageClass::StorageClass SC =
+      static_cast<SPIRV::StorageClass::StorageClass>(
+          OpType->getOperand(1).getImm());
+  MachineIRBuilder MIB(I);
+  SPIRVType *NewBaseType =
+      ReuseType ? ResType
+                : GR.getOrCreateSPIRVType(
+                      ResTy, MIB, SPIRV::AccessQualifier::ReadWrite, EmitIR);
+  return GR.getOrCreateSPIRVPointerType(NewBaseType, MIB, SC);
+}
+
 // Insert a bitcast before the instruction to keep SPIR-V code valid
 // when there is a type mismatch between results and operand types.
 static void validatePtrTypes(const SPIRVSubtarget &STI,
@@ -113,11 +154,7 @@ static void validatePtrTypes(const SPIRVSubtarget &STI,
   // Get operand type
   MachineFunction *MF = I.getParent()->getParent();
   Register OpReg = I.getOperand(OpIdx).getReg();
-  SPIRVType *TypeInst = MRI->getVRegDef(OpReg);
-  Register OpTypeReg =
-      TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter
-          ? TypeInst->getOperand(1).getReg()
-          : OpReg;
+  Register OpTypeReg = getTypeReg(MRI, OpReg);
   SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
   if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
     return;
@@ -134,30 +171,36 @@ static void validatePtrTypes(const SPIRVSubtarget &STI,
     return;
   // There is a type mismatch between results and operand types
   // and we insert a bitcast before the instruction to keep SPIR-V code valid
-  SPIRV::StorageClass::StorageClass SC =
-      static_cast<SPIRV::StorageClass::StorageClass>(
-          OpType->getOperand(1).getImm());
-  MachineIRBuilder MIB(I);
-  SPIRVType *NewBaseType =
-      IsSameMF ? ResType
-               : GR.getOrCreateSPIRVType(
-                     ResTy, MIB, SPIRV::AccessQualifier::ReadWrite, false);
-  SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(NewBaseType, MIB, SC);
+  SPIRVType *NewPtrType =
+      createNewPtrType(GR, I, OpType, IsSameMF, false, ResType, ResTy);
   if (!GR.isBitcastCompatible(NewPtrType, OpType))
     report_fatal_error(
         "insert validation bitcast: incompatible result and operand types");
-  Register NewReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
-  bool Res = MIB.buildInstr(SPIRV::OpBitcast)
-                 .addDef(NewReg)
-                 .addUse(GR.getSPIRVTypeID(NewPtrType))
-                 .addUse(OpReg)
-                 .constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(),
-                                   *STI.getRegBankInfo());
-  if (!Res)
-    report_fatal_error("insert validation bitcast: cannot constrain all uses");
-  MRI->setRegClass(NewReg, &SPIRV::IDRegClass);
-  GR.assignSPIRVTypeToVReg(NewPtrType, NewReg, MIB.getMF());
-  I.getOperand(OpIdx).setReg(NewReg);
+  doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
+}
+
+// Insert a bitcast before OpGroupWaitEvents if the last argument is a pointer
+// that doesn't point to OpTypeEvent.
+static void validateGroupWaitEventsPtr(const SPIRVSubtarget &STI,
+                                       MachineRegisterInfo *MRI,
+                                       SPIRVGlobalRegistry &GR,
+                                       MachineInstr &I) {
+  constexpr unsigned OpIdx = 2;
+  MachineFunction *MF = I.getParent()->getParent();
+  Register OpReg = I.getOperand(OpIdx).getReg();
+  Register OpTypeReg = getTypeReg(MRI, OpReg);
+  SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
+  if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
+    return;
+  SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
+  if (!ElemType || ElemType->getOpcode() == SPIRV::OpTypeEvent)
+    return;
+  // Insert a bitcast before the instruction to keep SPIR-V code valid.
+  LLVMContext &Context = MF->getMMI().getModule()->getContext();
+  SPIRVType *NewPtrType =
+      createNewPtrType(GR, I, OpType, false, true, nullptr,
+                       TargetExtType::get(Context, "spirv.Event"));
+  doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
 }
 
 // Insert a bitcast before the function call instruction to keep SPIR-V code
@@ -336,6 +379,10 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
                                       SPIRV::OpTypeBool))
           MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
         break;
+      case SPIRV::OpGroupWaitEvents:
+        // OpGroupWaitEvents ..., ..., <pointer to OpTypeEvent>
+        validateGroupWaitEventsPtr(STI, MRI, GR, MI);
+        break;
       case SPIRV::OpConstantI: {
         SPIRVType *Type = GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg());
         if (Type->getOpcode() != SPIRV::OpTypeInt && MI.getOperand(2).isImm() &&
diff --git a/llvm/test/CodeGen/SPIRV/event-wait-ptr-type.ll b/llvm/test/CodeGen/SPIRV/event-wait-ptr-type.ll
new file mode 100644
index 0000000000000..d6fb70bb59a7e
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/event-wait-ptr-type.ll
@@ -0,0 +1,28 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK: %[[#EventTy:]] = OpTypeEvent
+; CHECK: %[[#StructEventTy:]] = OpTypeStruct %[[#EventTy]]
+; CHECK: %[[#GenPtrStructEventTy:]] = OpTypePointer Generic %[[#StructEventTy]]
+; CHECK: %[[#FunPtrStructEventTy:]] = OpTypePointer Function %[[#StructEventTy]]
+; CHECK: %[[#GenPtrEventTy:]] = OpTypePointer Generic %[[#EventTy:]]
+; CHECK: OpFunction
+; CHECK: %[[#Var:]] = OpVariable %[[#FunPtrStructEventTy]] Function
+; CHECK-NEXT: %[[#AddrspacecastVar:]] = OpPtrCastToGeneric %[[#GenPtrStructEventTy]] %[[#Var]]
+; CHECK-NEXT: %[[#BitcastVar:]] = OpBitcast %[[#GenPtrEventTy]] %[[#AddrspacecastVar]]
+; CHECK-NEXT: OpGroupWaitEvents %[[#]] %[[#]] %[[#BitcastVar]]
+
+%"class.sycl::_V1::device_event" = type { target("spirv.Event") }
+
+define weak_odr dso_local spir_kernel void @foo() {
+entry:
+  %var = alloca %"class.sycl::_V1::device_event"
+  %eventptr = addrspacecast ptr %var to ptr addrspace(4)
+  call spir_func void @_Z23__spirv_GroupWaitEventsjiP9ocl_event(i32 2, i32 1, ptr addrspace(4) %eventptr)
+  ret void
+}
+
+declare dso_local spir_func void @_Z23__spirv_GroupWaitEventsjiP9ocl_event(i32, i32, ptr addrspace(4))

@VyacheslavLevytskyy VyacheslavLevytskyy changed the title [SPIR-V] Insert a bitcast before OpGroupWaitEvents when the last parameter requires [SPIR-V] Validate type of the last parameter of OpGroupWaitEvents May 29, 2024
@VyacheslavLevytskyy VyacheslavLevytskyy merged commit ce73e17 into llvm:main Jun 3, 2024
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants