Skip to content

[SPIR-V] Fix inconsistency between previously deduced element type of a pointer and function's return type #109660

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 38 additions & 2 deletions llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ class SPIRVEmitIntrinsics
Type *deduceFunParamElementType(Function *F, unsigned OpIdx);
Type *deduceFunParamElementType(Function *F, unsigned OpIdx,
std::unordered_set<Function *> &FVisited);
void replaceWithPtrcasted(Instruction *CI, Type *NewElemTy, Type *KnownElemTy,
CallInst *AssignCI);

public:
static char ID;
Expand Down Expand Up @@ -475,10 +477,11 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
if (DemangledName.length() > 0)
DemangledName = SPIRV::lookupBuiltinNameHelper(DemangledName);
auto AsArgIt = ResTypeByArg.find(DemangledName);
if (AsArgIt != ResTypeByArg.end()) {
if (AsArgIt != ResTypeByArg.end())
Ty = deduceElementTypeHelper(CI->getArgOperand(AsArgIt->second),
Visited, UnknownElemTypeI8);
}
else if (Type *KnownRetTy = GR->findDeducedElementType(CalledF))
Ty = KnownRetTy;
}
}

Expand Down Expand Up @@ -808,6 +811,7 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I,
CallInst *PtrCastI =
B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
I->setOperand(OpIt.second, PtrCastI);
buildAssignPtr(B, KnownElemTy, PtrCastI);
}
}
}
Expand Down Expand Up @@ -1706,6 +1710,26 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
return true;
}

void SPIRVEmitIntrinsics::replaceWithPtrcasted(Instruction *CI, Type *NewElemTy,
Type *KnownElemTy,
CallInst *AssignCI) {
updateAssignType(AssignCI, CI, PoisonValue::get(NewElemTy));
IRBuilder<> B(CI->getContext());
B.SetInsertPoint(*CI->getInsertionPointAfterDef());
B.SetCurrentDebugLocation(CI->getDebugLoc());
Type *OpTy = CI->getType();
SmallVector<Type *, 2> Types = {OpTy, OpTy};
SmallVector<Value *, 2> Args = {CI, buildMD(PoisonValue::get(KnownElemTy)),
B.getInt32(getPointerAddressSpace(OpTy))};
CallInst *PtrCasted =
B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
SmallVector<User *> Users(CI->users());
for (auto *U : Users)
if (U != AssignCI && U != PtrCasted)
U->replaceUsesOfWith(CI, PtrCasted);
buildAssignPtr(B, KnownElemTy, PtrCasted);
}

// Try to deduce a better type for pointers to untyped ptr.
bool SPIRVEmitIntrinsics::postprocessTypes() {
bool Changed = false;
Expand All @@ -1717,6 +1741,18 @@ bool SPIRVEmitIntrinsics::postprocessTypes() {
Type *KnownTy = GR->findDeducedElementType(*IB);
if (!KnownTy || !AssignCI || !isa<Instruction>(AssignCI->getArgOperand(0)))
continue;
// Try to improve the type deduced after all Functions are processed.
if (auto *CI = dyn_cast<CallInst>(*IB)) {
if (Function *CalledF = CI->getCalledFunction()) {
Type *RetElemTy = GR->findDeducedElementType(CalledF);
// Fix inconsistency between known type and function's return type.
if (RetElemTy && RetElemTy != KnownTy) {
replaceWithPtrcasted(CI, RetElemTy, KnownTy, AssignCI);
Changed = true;
continue;
}
}
}
Instruction *I = cast<Instruction>(AssignCI->getArgOperand(0));
for (User *U : I->users()) {
Instruction *Inst = dyn_cast<Instruction>(U);
Expand Down
17 changes: 13 additions & 4 deletions llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,17 @@ createNewIdReg(SPIRVType *SpvType, Register SrcReg, MachineRegisterInfo &MRI,
return {Reg, GetIdOp};
}

static void setInsertPtAfterDef(MachineIRBuilder &MIB, MachineInstr *Def) {
MachineBasicBlock &MBB = *Def->getParent();
MachineBasicBlock::iterator DefIt =
Def->getNextNode() ? Def->getNextNode()->getIterator() : MBB.end();
// Skip all the PHI and debug instructions.
while (DefIt != MBB.end() &&
(DefIt->isPHI() || DefIt->isDebugOrPseudoInstr()))
DefIt = std::next(DefIt);
MIB.setInsertPt(MBB, DefIt);
}

// Insert ASSIGN_TYPE instuction between Reg and its definition, set NewReg as
// a dst of the definition, assign SPIRVType to both registers. If SpvType is
// provided, use it as SPIRVType in ASSIGN_TYPE, otherwise create it from Ty.
Expand All @@ -350,11 +361,9 @@ namespace llvm {
Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpvType,
SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB,
MachineRegisterInfo &MRI) {
MachineInstr *Def = MRI.getVRegDef(Reg);
assert((Ty || SpvType) && "Either LLVM or SPIRV type is expected.");
MIB.setInsertPt(*Def->getParent(),
(Def->getNextNode() ? Def->getNextNode()->getIterator()
: Def->getParent()->end()));
MachineInstr *Def = MRI.getVRegDef(Reg);
setInsertPtAfterDef(MIB, Def);
SpvType = SpvType ? SpvType : GR->getOrCreateSPIRVType(Ty, MIB);
Register NewReg = MRI.createGenericVirtualRegister(MRI.getType(Reg));
if (auto *RC = MRI.getRegClassOrNull(Reg)) {
Expand Down
55 changes: 55 additions & 0 deletions llvm/test/CodeGen/SPIRV/pointers/phi-valid-operand-types-rev.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
; The goal of the test case is to ensure that OpPhi is consistent with respect to operand types.
; -verify-machineinstrs is not available due to mutually exclusive requirements for G_BITCAST and G_PHI.

; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}

; CHECK: %[[#Char:]] = OpTypeInt 8 0
; CHECK: %[[#PtrChar:]] = OpTypePointer Function %[[#Char]]
; CHECK: %[[#Int:]] = OpTypeInt 32 0
; CHECK: %[[#PtrInt:]] = OpTypePointer Function %[[#Int]]
; CHECK: %[[#R1:]] = OpFunctionCall %[[#PtrChar]] %[[#]]
; CHECK: %[[#R2:]] = OpFunctionCall %[[#PtrInt]] %[[#]]
; CHECK-DAG: %[[#Casted1:]] = OpBitcast %[[#PtrChar]] %[[#R2]]
; CHECK-DAG: %[[#Casted2:]] = OpBitcast %[[#PtrChar]] %[[#R2]]
; CHECK: OpBranchConditional
; CHECK-DAG: OpPhi %[[#PtrChar]] %[[#R1]] %[[#]] %[[#Casted1]] %[[#]]
; CHECK-DAG: OpPhi %[[#PtrChar]] %[[#R1]] %[[#]] %[[#Casted2]] %[[#]]

define void @f0(ptr %arg) {
entry:
ret void
}

define ptr @f1() {
entry:
%p = alloca i8
store i8 8, ptr %p
ret ptr %p
}

define ptr @f2() {
entry:
%p = alloca i32
store i32 32, ptr %p
ret ptr %p
}

define ptr @foo(i1 %arg) {
entry:
%r1 = tail call ptr @f1()
%r2 = tail call ptr @f2()
br i1 %arg, label %l1, label %l2

l1:
br label %exit

l2:
br label %exit

exit:
%ret = phi ptr [ %r1, %l1 ], [ %r2, %l2 ]
%ret2 = phi ptr [ %r1, %l1 ], [ %r2, %l2 ]
tail call void @f0(ptr %ret)
ret ptr %ret2
}
53 changes: 53 additions & 0 deletions llvm/test/CodeGen/SPIRV/pointers/phi-valid-operand-types.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
; The goal of the test case is to ensure that OpPhi is consistent with respect to operand types.
; -verify-machineinstrs is not available due to mutually exclusive requirements for G_BITCAST and G_PHI.

; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}

; CHECK: %[[#Char:]] = OpTypeInt 8 0
; CHECK: %[[#PtrChar:]] = OpTypePointer Function %[[#Char]]
; CHECK: %[[#Int:]] = OpTypeInt 32 0
; CHECK: %[[#PtrInt:]] = OpTypePointer Function %[[#Int]]
; CHECK: %[[#R1:]] = OpFunctionCall %[[#PtrChar]] %[[#]]
; CHECK: %[[#R2:]] = OpFunctionCall %[[#PtrInt]] %[[#]]
; CHECK: %[[#Casted:]] = OpBitcast %[[#PtrChar]] %[[#R2]]
; CHECK: OpPhi %[[#PtrChar]] %[[#R1]] %[[#]] %[[#Casted]] %[[#]]
; CHECK: OpPhi %[[#PtrChar]] %[[#R1]] %[[#]] %[[#Casted]] %[[#]]

define ptr @foo(i1 %arg) {
entry:
%r1 = tail call ptr @f1()
%r2 = tail call ptr @f2()
br i1 %arg, label %l1, label %l2

l1:
br label %exit

l2:
br label %exit

exit:
%ret = phi ptr [ %r1, %l1 ], [ %r2, %l2 ]
%ret2 = phi ptr [ %r1, %l1 ], [ %r2, %l2 ]
tail call void @f0(ptr %ret)
ret ptr %ret2
}

define void @f0(ptr %arg) {
entry:
ret void
}

define ptr @f1() {
entry:
%p = alloca i8
store i8 8, ptr %p
ret ptr %p
}

define ptr @f2() {
entry:
%p = alloca i32
store i32 32, ptr %p
ret ptr %p
}
Loading