Skip to content

[SPIR-V] Improve type inference in SPIR-V Backend for opaque pointers #86283

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 77 additions & 46 deletions llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ class SPIRVEmitIntrinsics
void insertPtrCastOrAssignTypeInstr(Instruction *I, IRBuilder<> &B);
void processGlobalValue(GlobalVariable &GV, IRBuilder<> &B);
void processParamTypes(Function *F, IRBuilder<> &B);
Type *deduceFunParamType(Function *F, unsigned OpIdx);
Type *deduceFunParamType(Function *F, unsigned OpIdx,
std::unordered_set<Function *> &FVisited);

public:
static char ID;
Expand Down Expand Up @@ -169,6 +172,10 @@ static inline void reportFatalOnTokenType(const Instruction *I) {
static Type *deduceElementTypeHelper(Value *I,
std::unordered_set<Value *> &Visited,
DenseMap<Value *, Type *> &DeducedElTys) {
// a null value may be passed as the argument; nothing can be deduced from it
if (!I)
return nullptr;

// maybe already known
auto It = DeducedElTys.find(I);
if (It != DeducedElTys.end())
Expand All @@ -182,15 +189,20 @@ static Type *deduceElementTypeHelper(Value *I,
// fallback value in case when we fail to deduce a type
Type *Ty = nullptr;
// look for known basic patterns of type inference
if (auto *Ref = dyn_cast<AllocaInst>(I))
if (auto *Ref = dyn_cast<AllocaInst>(I)) {
Ty = Ref->getAllocatedType();
else if (auto *Ref = dyn_cast<GetElementPtrInst>(I))
} else if (auto *Ref = dyn_cast<GetElementPtrInst>(I)) {
Ty = Ref->getResultElementType();
else if (auto *Ref = dyn_cast<GlobalValue>(I))
} else if (auto *Ref = dyn_cast<GlobalValue>(I)) {
Ty = Ref->getValueType();
else if (auto *Ref = dyn_cast<AddrSpaceCastInst>(I))
} else if (auto *Ref = dyn_cast<AddrSpaceCastInst>(I)) {
Ty = deduceElementTypeHelper(Ref->getPointerOperand(), Visited,
DeducedElTys);
} else if (auto *Ref = dyn_cast<BitCastInst>(I)) {
if (Type *Src = Ref->getSrcTy(), *Dest = Ref->getDestTy();
isPointerTy(Src) && isPointerTy(Dest))
Ty = deduceElementTypeHelper(Ref->getOperand(0), Visited, DeducedElTys);
}

// remember the found relationship
if (Ty)
Expand Down Expand Up @@ -795,61 +807,80 @@ void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I,
}
}

void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
DenseMap<unsigned, Argument *> Args;
unsigned i = 0;
for (Argument &Arg : F->args()) {
if (isUntypedPointerTy(Arg.getType()) &&
DeducedElTys.find(&Arg) == DeducedElTys.end() &&
!HasPointeeTypeAttr(&Arg))
Args[i] = &Arg;
i++;
}
if (Args.size() == 0)
return;
// Entry-point overload for deducing the type associated with parameter OpIdx
// of function F. It creates an empty set of visited functions and delegates to
// the three-argument overload, which uses that set to stop on call cycles.
// Returns whatever the delegate returns (nullptr when nothing was deduced).
Type *SPIRVEmitIntrinsics::deduceFunParamType(Function *F, unsigned OpIdx) {
// fresh per-query set: each top-level deduction starts with no visited functions
std::unordered_set<Function *> FVisited;
return deduceFunParamType(F, OpIdx, FVisited);
}

Type *SPIRVEmitIntrinsics::deduceFunParamType(
Function *F, unsigned OpIdx, std::unordered_set<Function *> &FVisited) {
// maybe a cycle
if (FVisited.find(F) != FVisited.end())
return nullptr;
FVisited.insert(F);

// Args contains opaque pointers without element type definition
B.SetInsertPointPastAllocas(F);
std::unordered_set<Value *> Visited;
SmallVector<std::pair<Function *, unsigned>> Lookup;
// search in function's call sites
for (User *U : F->users()) {
CallInst *CI = dyn_cast<CallInst>(U);
if (!CI)
if (!CI || OpIdx >= CI->arg_size())
continue;
for (unsigned OpIdx = 0; OpIdx < CI->arg_size() && Args.size() > 0;
OpIdx++) {
auto It = Args.find(OpIdx);
Argument *Arg = It == Args.end() ? nullptr : It->second;
if (!Arg)
continue;
Value *OpArg = CI->getArgOperand(OpIdx);
if (!isPointerTy(OpArg->getType()))
Value *OpArg = CI->getArgOperand(OpIdx);
if (!isPointerTy(OpArg->getType()))
continue;
// maybe we already know operand's element type
if (auto It = DeducedElTys.find(OpArg); It != DeducedElTys.end())
return It->second;
// search in actual parameter's users
for (User *OpU : OpArg->users()) {
Instruction *Inst = dyn_cast<Instruction>(OpU);
if (!Inst || Inst == CI)
continue;
// maybe we already know the operand's element type
auto DeducedIt = DeducedElTys.find(OpArg);
Type *ElemTy =
DeducedIt == DeducedElTys.end() ? nullptr : DeducedIt->second;
if (!ElemTy) {
for (User *OpU : OpArg->users()) {
if (Instruction *Inst = dyn_cast<Instruction>(OpU)) {
Visited.clear();
ElemTy = deduceElementTypeHelper(Inst, Visited, DeducedElTys);
if (ElemTy)
break;
}
}
Visited.clear();
if (Type *Ty = deduceElementTypeHelper(Inst, Visited, DeducedElTys))
return Ty;
}
// check if it's a formal parameter of the outer function
if (!CI->getParent() || !CI->getParent()->getParent())
continue;
Function *OuterF = CI->getParent()->getParent();
if (FVisited.find(OuterF) != FVisited.end())
continue;
for (unsigned i = 0; i < OuterF->arg_size(); ++i) {
if (OuterF->getArg(i) == OpArg) {
Lookup.push_back(std::make_pair(OuterF, i));
break;
}
if (ElemTy) {
unsigned AddressSpace = getPointerAddressSpace(Arg->getType());
}
}

// search in function parameters
for (auto &Pair : Lookup) {
if (Type *Ty = deduceFunParamType(Pair.first, Pair.second, FVisited))
return Ty;
}

return nullptr;
}

void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
B.SetInsertPointPastAllocas(F);
DenseMap<Argument *, Type *> Args;
for (unsigned OpIdx = 0; OpIdx < F->arg_size(); ++OpIdx) {
Argument *Arg = F->getArg(OpIdx);
if (isUntypedPointerTy(Arg->getType()) &&
DeducedElTys.find(Arg) == DeducedElTys.end() &&
!HasPointeeTypeAttr(Arg)) {
if (Type *ElemTy = deduceFunParamType(F, OpIdx)) {
CallInst *AssignPtrTyCI = buildIntrWithMD(
Intrinsic::spv_assign_ptr_type, {Arg->getType()},
Constant::getNullValue(ElemTy), Arg, {B.getInt32(AddressSpace)}, B);
Constant::getNullValue(ElemTy), Arg,
{B.getInt32(getPointerAddressSpace(Arg->getType()))}, B);
DeducedElTys[AssignPtrTyCI] = ElemTy;
DeducedElTys[Arg] = ElemTy;
Args.erase(It);
}
}
if (Args.size() == 0)
break;
}
}

Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,7 @@ Register SPIRVGlobalRegistry::buildGlobalVariable(
GVar = M->getGlobalVariable(Name);
if (GVar == nullptr) {
const Type *Ty = getTypeForSPIRVType(BaseType); // TODO: check type.
// Module takes ownership of the global var.
GVar = new GlobalVariable(*M, const_cast<Type *>(Ty), false,
GlobalValue::ExternalLinkage, nullptr,
Twine(Name));
Expand Down
12 changes: 8 additions & 4 deletions llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,7 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
assert(I.getOperand(1).isReg() && I.getOperand(2).isReg());
Register GV = I.getOperand(1).getReg();
MachineRegisterInfo::def_instr_iterator II = MRI->def_instr_begin(GV);
(void)II;
assert(((*II).getOpcode() == TargetOpcode::G_GLOBAL_VALUE ||
(*II).getOpcode() == TargetOpcode::COPY ||
(*II).getOpcode() == SPIRV::OpVariable) &&
Expand Down Expand Up @@ -771,10 +772,13 @@ bool SPIRVInstructionSelector::selectMemOperation(Register ResVReg,
SPIRVType *VarTy = GR.getOrCreateSPIRVPointerType(
ArrTy, I, TII, SPIRV::StorageClass::UniformConstant);
// TODO: check if we have such GV, add init, use buildGlobalVariable.
Type *LLVMArrTy = ArrayType::get(
IntegerType::get(GR.CurMF->getFunction().getContext(), 8), Num);
GlobalVariable *GV =
new GlobalVariable(LLVMArrTy, true, GlobalValue::InternalLinkage);
Function &CurFunction = GR.CurMF->getFunction();
Type *LLVMArrTy =
ArrayType::get(IntegerType::get(CurFunction.getContext(), 8), Num);
// Module takes ownership of the global var.
GlobalVariable *GV = new GlobalVariable(*CurFunction.getParent(), LLVMArrTy,
true, GlobalValue::InternalLinkage,
Constant::getNullValue(LLVMArrTy));
Register VarReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
GR.add(GV, GR.CurMF, VarReg);

Expand Down
52 changes: 52 additions & 0 deletions llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-chain.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}

; CHECK-SPIRV-DAG: OpName %[[ArgCum:.*]] "_arg_cum"
; CHECK-SPIRV-DAG: OpName %[[FunTest:.*]] "test"
; CHECK-SPIRV-DAG: OpName %[[Addr:.*]] "addr"
; CHECK-SPIRV-DAG: OpName %[[StubObj:.*]] "stub_object"
; CHECK-SPIRV-DAG: OpName %[[MemOrder:.*]] "mem_order"
; CHECK-SPIRV-DAG: OpName %[[FooStub:.*]] "foo_stub"
; CHECK-SPIRV-DAG: OpName %[[FooObj:.*]] "foo_object"
; CHECK-SPIRV-DAG: OpName %[[FooMemOrder:.*]] "mem_order"
; CHECK-SPIRV-DAG: OpName %[[FooFunc:.*]] "foo"
; CHECK-SPIRV-DAG: %[[TyLong:.*]] = OpTypeInt 32 0
; CHECK-SPIRV-DAG: %[[TyVoid:.*]] = OpTypeVoid
; CHECK-SPIRV-DAG: %[[TyPtrLong:.*]] = OpTypePointer CrossWorkgroup %[[TyLong]]
; CHECK-SPIRV-DAG: %[[TyFunPtrLong:.*]] = OpTypeFunction %[[TyVoid]] %[[TyPtrLong]]
; CHECK-SPIRV-DAG: %[[TyGenPtrLong:.*]] = OpTypePointer Generic %[[TyLong]]
; CHECK-SPIRV-DAG: %[[TyFunGenPtrLongLong:.*]] = OpTypeFunction %[[TyVoid]] %[[TyGenPtrLong]] %[[TyLong]]
; CHECK-SPIRV-DAG: %[[Const3:.*]] = OpConstant %[[TyLong]] 3
; CHECK-SPIRV: %[[FunTest]] = OpFunction %[[TyVoid]] None %[[TyFunPtrLong]]
; CHECK-SPIRV: %[[ArgCum]] = OpFunctionParameter %[[TyPtrLong]]
; CHECK-SPIRV: OpFunctionCall %[[TyVoid]] %[[FooFunc]] %[[Addr]] %[[Const3]]
; CHECK-SPIRV: %[[FooStub]] = OpFunction %[[TyVoid]] None %[[TyFunGenPtrLongLong]]
; CHECK-SPIRV: %[[StubObj]] = OpFunctionParameter %[[TyGenPtrLong]]
; CHECK-SPIRV: %[[MemOrder]] = OpFunctionParameter %[[TyLong]]
; CHECK-SPIRV: %[[FooFunc]] = OpFunction %[[TyVoid]] None %[[TyFunGenPtrLongLong]]
; CHECK-SPIRV: %[[FooObj]] = OpFunctionParameter %[[TyGenPtrLong]]
; CHECK-SPIRV: %[[FooMemOrder]] = OpFunctionParameter %[[TyLong]]
; CHECK-SPIRV: OpFunctionCall %[[TyVoid]] %[[FooStub]] %[[FooObj]] %[[FooMemOrder]]

; Kernel at the top of the call chain. This is the only function in the test
; whose body names a pointee type for the pointer: the GEP below indexes
; %_arg_cum as i32.
define spir_kernel void @test(ptr addrspace(1) noundef align 4 %_arg_cum) {
entry:
; i32-typed GEP on the addrspace(1) kernel argument
%lptr = getelementptr inbounds i32, ptr addrspace(1) %_arg_cum, i64 1
; move the pointer from addrspace(1) to addrspace(4)
%addr = addrspacecast ptr addrspace(1) %lptr to ptr addrspace(4)
; pointer-to-pointer bitcast with identical source and destination types;
; the checks above expect the call to receive %addr itself
%object = bitcast ptr addrspace(4) %addr to ptr addrspace(4)
; call into @foo, which forwards its arguments to @foo_stub
call spir_func void @foo(ptr addrspace(4) %object, i32 3)
ret void
}

; End of the call chain. The body only stores the incoming pointer into a
; stack slot, so no instruction here names a pointee type for %stub_object.
; The checks above still expect it to be a Generic pointer to the 32-bit
; integer type; presumably that type is deduced through the callers
; (@foo, then @test), as the test's file name suggests.
define void @foo_stub(ptr addrspace(4) noundef %stub_object, i32 noundef %mem_order) {
entry:
; stack slot holding a ptr addrspace(4) value
%object.addr = alloca ptr addrspace(4)
; address of the slot, cast to addrspace(4)
%object.addr.ascast = addrspacecast ptr %object.addr to ptr addrspace(4)
; the only use of %stub_object: it is stored as a value, never dereferenced
store ptr addrspace(4) %stub_object, ptr addrspace(4) %object.addr.ascast
ret void
}

; Middle link of the call chain: forwards both arguments unchanged to
; @foo_stub. %foo_object is used only as a call operand, so this body gives
; no pointee type for it either.
define void @foo(ptr addrspace(4) noundef %foo_object, i32 noundef %mem_order) {
; pass-through tail call; the formal parameter becomes the callee's actual argument
tail call void @foo_stub(ptr addrspace(4) noundef %foo_object, i32 noundef %mem_order)
ret void
}