-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[SPIR-V] Validate type of the last parameter of OpGroupWaitEvents #93661
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
VyacheslavLevytskyy
merged 2 commits into
llvm:main
from
VyacheslavLevytskyy:validate_OpGroupWaitEvents
Jun 3, 2024
Merged
[SPIR-V] Validate type of the last parameter of OpGroupWaitEvents #93661
VyacheslavLevytskyy
merged 2 commits into
llvm:main
from
VyacheslavLevytskyy:validate_OpGroupWaitEvents
Jun 3, 2024
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@llvm/pr-subscribers-backend-spir-v Author: Vyacheslav Levytskyy (VyacheslavLevytskyy) ChangesThis PR fixes invalid OpGroupWaitEvents emission to ensure that SPIR-V Backend inserts a bitcast before OpGroupWaitEvents if the last argument is a pointer that doesn't point to OpTypeEvent. Full diff: https://github.com/llvm/llvm-project/pull/93661.diff 2 Files Affected:
diff --git a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
index 2bd22bbd63169..5ccbaf12ddee2 100644
--- a/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
@@ -104,6 +104,47 @@ SPIRVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
return std::make_pair(0u, RC);
}
+inline Register getTypeReg(MachineRegisterInfo *MRI, Register OpReg) {
+ SPIRVType *TypeInst = MRI->getVRegDef(OpReg);
+ return TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter
+ ? TypeInst->getOperand(1).getReg()
+ : OpReg;
+}
+
+static void doInsertBitcast(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
+ SPIRVGlobalRegistry &GR, MachineInstr &I,
+ Register OpReg, unsigned OpIdx,
+ SPIRVType *NewPtrType) {
+ Register NewReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
+ MachineIRBuilder MIB(I);
+ bool Res = MIB.buildInstr(SPIRV::OpBitcast)
+ .addDef(NewReg)
+ .addUse(GR.getSPIRVTypeID(NewPtrType))
+ .addUse(OpReg)
+ .constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(),
+ *STI.getRegBankInfo());
+ if (!Res)
+ report_fatal_error("insert validation bitcast: cannot constrain all uses");
+ MRI->setRegClass(NewReg, &SPIRV::IDRegClass);
+ GR.assignSPIRVTypeToVReg(NewPtrType, NewReg, MIB.getMF());
+ I.getOperand(OpIdx).setReg(NewReg);
+}
+
+static SPIRVType *createNewPtrType(SPIRVGlobalRegistry &GR, MachineInstr &I,
+ SPIRVType *OpType, bool ReuseType,
+ bool EmitIR, SPIRVType *ResType,
+ const Type *ResTy) {
+ SPIRV::StorageClass::StorageClass SC =
+ static_cast<SPIRV::StorageClass::StorageClass>(
+ OpType->getOperand(1).getImm());
+ MachineIRBuilder MIB(I);
+ SPIRVType *NewBaseType =
+ ReuseType ? ResType
+ : GR.getOrCreateSPIRVType(
+ ResTy, MIB, SPIRV::AccessQualifier::ReadWrite, EmitIR);
+ return GR.getOrCreateSPIRVPointerType(NewBaseType, MIB, SC);
+}
+
// Insert a bitcast before the instruction to keep SPIR-V code valid
// when there is a type mismatch between results and operand types.
static void validatePtrTypes(const SPIRVSubtarget &STI,
@@ -113,11 +154,7 @@ static void validatePtrTypes(const SPIRVSubtarget &STI,
// Get operand type
MachineFunction *MF = I.getParent()->getParent();
Register OpReg = I.getOperand(OpIdx).getReg();
- SPIRVType *TypeInst = MRI->getVRegDef(OpReg);
- Register OpTypeReg =
- TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter
- ? TypeInst->getOperand(1).getReg()
- : OpReg;
+ Register OpTypeReg = getTypeReg(MRI, OpReg);
SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
return;
@@ -134,30 +171,36 @@ static void validatePtrTypes(const SPIRVSubtarget &STI,
return;
// There is a type mismatch between results and operand types
// and we insert a bitcast before the instruction to keep SPIR-V code valid
- SPIRV::StorageClass::StorageClass SC =
- static_cast<SPIRV::StorageClass::StorageClass>(
- OpType->getOperand(1).getImm());
- MachineIRBuilder MIB(I);
- SPIRVType *NewBaseType =
- IsSameMF ? ResType
- : GR.getOrCreateSPIRVType(
- ResTy, MIB, SPIRV::AccessQualifier::ReadWrite, false);
- SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(NewBaseType, MIB, SC);
+ SPIRVType *NewPtrType =
+ createNewPtrType(GR, I, OpType, IsSameMF, false, ResType, ResTy);
if (!GR.isBitcastCompatible(NewPtrType, OpType))
report_fatal_error(
"insert validation bitcast: incompatible result and operand types");
- Register NewReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
- bool Res = MIB.buildInstr(SPIRV::OpBitcast)
- .addDef(NewReg)
- .addUse(GR.getSPIRVTypeID(NewPtrType))
- .addUse(OpReg)
- .constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(),
- *STI.getRegBankInfo());
- if (!Res)
- report_fatal_error("insert validation bitcast: cannot constrain all uses");
- MRI->setRegClass(NewReg, &SPIRV::IDRegClass);
- GR.assignSPIRVTypeToVReg(NewPtrType, NewReg, MIB.getMF());
- I.getOperand(OpIdx).setReg(NewReg);
+ doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
+}
+
+// Insert a bitcast before OpGroupWaitEvents if the last argument is a pointer
+// that doesn't point to OpTypeEvent.
+static void validateGroupWaitEventsPtr(const SPIRVSubtarget &STI,
+ MachineRegisterInfo *MRI,
+ SPIRVGlobalRegistry &GR,
+ MachineInstr &I) {
+ constexpr unsigned OpIdx = 2;
+ MachineFunction *MF = I.getParent()->getParent();
+ Register OpReg = I.getOperand(OpIdx).getReg();
+ Register OpTypeReg = getTypeReg(MRI, OpReg);
+ SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF);
+ if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer)
+ return;
+ SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg());
+ if (!ElemType || ElemType->getOpcode() == SPIRV::OpTypeEvent)
+ return;
+ // Insert a bitcast before the instruction to keep SPIR-V code valid.
+ LLVMContext &Context = MF->getMMI().getModule()->getContext();
+ SPIRVType *NewPtrType =
+ createNewPtrType(GR, I, OpType, false, true, nullptr,
+ TargetExtType::get(Context, "spirv.Event"));
+ doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType);
}
// Insert a bitcast before the function call instruction to keep SPIR-V code
@@ -336,6 +379,10 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
SPIRV::OpTypeBool))
MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
break;
+ case SPIRV::OpGroupWaitEvents:
+ // OpGroupWaitEvents ..., ..., <pointer to OpTypeEvent>
+ validateGroupWaitEventsPtr(STI, MRI, GR, MI);
+ break;
case SPIRV::OpConstantI: {
SPIRVType *Type = GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg());
if (Type->getOpcode() != SPIRV::OpTypeInt && MI.getOperand(2).isImm() &&
diff --git a/llvm/test/CodeGen/SPIRV/event-wait-ptr-type.ll b/llvm/test/CodeGen/SPIRV/event-wait-ptr-type.ll
new file mode 100644
index 0000000000000..d6fb70bb59a7e
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/event-wait-ptr-type.ll
@@ -0,0 +1,28 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK: %[[#EventTy:]] = OpTypeEvent
+; CHECK: %[[#StructEventTy:]] = OpTypeStruct %[[#EventTy]]
+; CHECK: %[[#GenPtrStructEventTy:]] = OpTypePointer Generic %[[#StructEventTy]]
+; CHECK: %[[#FunPtrStructEventTy:]] = OpTypePointer Function %[[#StructEventTy]]
+; CHECK: %[[#GenPtrEventTy:]] = OpTypePointer Generic %[[#EventTy:]]
+; CHECK: OpFunction
+; CHECK: %[[#Var:]] = OpVariable %[[#FunPtrStructEventTy]] Function
+; CHECK-NEXT: %[[#AddrspacecastVar:]] = OpPtrCastToGeneric %[[#GenPtrStructEventTy]] %[[#Var]]
+; CHECK-NEXT: %[[#BitcastVar:]] = OpBitcast %[[#GenPtrEventTy]] %[[#AddrspacecastVar]]
+; CHECK-NEXT: OpGroupWaitEvents %[[#]] %[[#]] %[[#BitcastVar]]
+
+%"class.sycl::_V1::device_event" = type { target("spirv.Event") }
+
+define weak_odr dso_local spir_kernel void @foo() {
+entry:
+ %var = alloca %"class.sycl::_V1::device_event"
+ %eventptr = addrspacecast ptr %var to ptr addrspace(4)
+ call spir_func void @_Z23__spirv_GroupWaitEventsjiP9ocl_event(i32 2, i32 1, ptr addrspace(4) %eventptr)
+ ret void
+}
+
+declare dso_local spir_func void @_Z23__spirv_GroupWaitEventsjiP9ocl_event(i32, i32, ptr addrspace(4))
|
michalpaszkowski
approved these changes
May 30, 2024
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This PR fixes invalid OpGroupWaitEvents emission to ensure that SPIR-V Backend inserts a bitcast before OpGroupWaitEvents if the last argument is a pointer that doesn't point to OpTypeEvent.