Skip to content

[SPIR-V] Improve type inference in SPIR-V Backend for opaque pointers #86283

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 25, 2024

Conversation

VyacheslavLevytskyy
Copy link
Contributor

This PR improves type inference in SPIR-V Backend for opaque pointers, accounting or a case when there is a chain of function calls that allows to deduce formal parameter types from actual arguments. The attached test demonstrates the case.

@llvmbot
Copy link
Member

llvmbot commented Mar 22, 2024

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

Changes

This PR improves type inference in SPIR-V Backend for opaque pointers, accounting or a case when there is a chain of function calls that allows to deduce formal parameter types from actual arguments. The attached test demonstrates the case.


Full diff: https://github.com/llvm/llvm-project/pull/86283.diff

4 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp (+77-46)
  • (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+8-4)
  • (added) llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-chain.ll (+52)
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 458af9229ed7b1..5828db6669ff18 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -92,6 +92,9 @@ class SPIRVEmitIntrinsics
   void insertPtrCastOrAssignTypeInstr(Instruction *I, IRBuilder<> &B);
   void processGlobalValue(GlobalVariable &GV, IRBuilder<> &B);
   void processParamTypes(Function *F, IRBuilder<> &B);
+  Type *deduceFunParamType(Function *F, unsigned OpIdx);
+  Type *deduceFunParamType(Function *F, unsigned OpIdx,
+                           std::unordered_set<Function *> &FVisited);
 
 public:
   static char ID;
@@ -169,6 +172,10 @@ static inline void reportFatalOnTokenType(const Instruction *I) {
 static Type *deduceElementTypeHelper(Value *I,
                                      std::unordered_set<Value *> &Visited,
                                      DenseMap<Value *, Type *> &DeducedElTys) {
+  // allow to pass nullptr as an argument
+  if (!I)
+    return nullptr;
+
   // maybe already known
   auto It = DeducedElTys.find(I);
   if (It != DeducedElTys.end())
@@ -182,15 +189,20 @@ static Type *deduceElementTypeHelper(Value *I,
   // fallback value in case when we fail to deduce a type
   Type *Ty = nullptr;
   // look for known basic patterns of type inference
-  if (auto *Ref = dyn_cast<AllocaInst>(I))
+  if (auto *Ref = dyn_cast<AllocaInst>(I)) {
     Ty = Ref->getAllocatedType();
-  else if (auto *Ref = dyn_cast<GetElementPtrInst>(I))
+  } else if (auto *Ref = dyn_cast<GetElementPtrInst>(I)) {
     Ty = Ref->getResultElementType();
-  else if (auto *Ref = dyn_cast<GlobalValue>(I))
+  } else if (auto *Ref = dyn_cast<GlobalValue>(I)) {
     Ty = Ref->getValueType();
-  else if (auto *Ref = dyn_cast<AddrSpaceCastInst>(I))
+  } else if (auto *Ref = dyn_cast<AddrSpaceCastInst>(I)) {
     Ty = deduceElementTypeHelper(Ref->getPointerOperand(), Visited,
                                  DeducedElTys);
+  } else if (auto *Ref = dyn_cast<BitCastInst>(I)) {
+    if (Type *Src = Ref->getSrcTy(), *Dest = Ref->getDestTy();
+        isPointerTy(Src) && isPointerTy(Dest))
+      Ty = deduceElementTypeHelper(Ref->getOperand(0), Visited, DeducedElTys);
+  }
 
   // remember the found relationship
   if (Ty)
@@ -795,61 +807,80 @@ void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I,
   }
 }
 
-void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
-  DenseMap<unsigned, Argument *> Args;
-  unsigned i = 0;
-  for (Argument &Arg : F->args()) {
-    if (isUntypedPointerTy(Arg.getType()) &&
-        DeducedElTys.find(&Arg) == DeducedElTys.end() &&
-        !HasPointeeTypeAttr(&Arg))
-      Args[i] = &Arg;
-    i++;
-  }
-  if (Args.size() == 0)
-    return;
+Type *SPIRVEmitIntrinsics::deduceFunParamType(Function *F, unsigned OpIdx) {
+  std::unordered_set<Function *> FVisited;
+  return deduceFunParamType(F, OpIdx, FVisited);
+}
+
+Type *SPIRVEmitIntrinsics::deduceFunParamType(
+    Function *F, unsigned OpIdx, std::unordered_set<Function *> &FVisited) {
+  // maybe a cycle
+  if (FVisited.find(F) != FVisited.end())
+    return nullptr;
+  FVisited.insert(F);
 
-  // Args contains opaque pointers without element type definition
-  B.SetInsertPointPastAllocas(F);
   std::unordered_set<Value *> Visited;
+  SmallVector<std::pair<Function *, unsigned>> Lookup;
+  // search in function's call sites
   for (User *U : F->users()) {
     CallInst *CI = dyn_cast<CallInst>(U);
-    if (!CI)
+    if (!CI || OpIdx >= CI->arg_size())
       continue;
-    for (unsigned OpIdx = 0; OpIdx < CI->arg_size() && Args.size() > 0;
-         OpIdx++) {
-      auto It = Args.find(OpIdx);
-      Argument *Arg = It == Args.end() ? nullptr : It->second;
-      if (!Arg)
-        continue;
-      Value *OpArg = CI->getArgOperand(OpIdx);
-      if (!isPointerTy(OpArg->getType()))
+    Value *OpArg = CI->getArgOperand(OpIdx);
+    if (!isPointerTy(OpArg->getType()))
+      continue;
+    // maybe we already know operand's element type
+    if (auto It = DeducedElTys.find(OpArg); It != DeducedElTys.end())
+      return It->second;
+    // search in actual parameter's users
+    for (User *OpU : OpArg->users()) {
+      Instruction *Inst = dyn_cast<Instruction>(OpU);
+      if (!Inst || Inst == CI)
         continue;
-      // maybe we already know the operand's element type
-      auto DeducedIt = DeducedElTys.find(OpArg);
-      Type *ElemTy =
-          DeducedIt == DeducedElTys.end() ? nullptr : DeducedIt->second;
-      if (!ElemTy) {
-        for (User *OpU : OpArg->users()) {
-          if (Instruction *Inst = dyn_cast<Instruction>(OpU)) {
-            Visited.clear();
-            ElemTy = deduceElementTypeHelper(Inst, Visited, DeducedElTys);
-            if (ElemTy)
-              break;
-          }
-        }
+      Visited.clear();
+      if (Type *Ty = deduceElementTypeHelper(Inst, Visited, DeducedElTys))
+        return Ty;
+    }
+    // check if it's a formal parameter of the outer function
+    if (!CI->getParent() || !CI->getParent()->getParent())
+      continue;
+    Function *OuterF = CI->getParent()->getParent();
+    if (FVisited.find(OuterF) != FVisited.end())
+      continue;
+    for (unsigned i = 0; i < OuterF->arg_size(); ++i) {
+      if (OuterF->getArg(i) == OpArg) {
+        Lookup.push_back(std::make_pair(OuterF, i));
+        break;
       }
-      if (ElemTy) {
-        unsigned AddressSpace = getPointerAddressSpace(Arg->getType());
+    }
+  }
+
+  // search in function parameters
+  for (auto &Pair : Lookup) {
+    if (Type *Ty = deduceFunParamType(Pair.first, Pair.second, FVisited))
+      return Ty;
+  }
+
+  return nullptr;
+}
+
+void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
+  B.SetInsertPointPastAllocas(F);
+  DenseMap<Argument *, Type *> Args;
+  for (unsigned OpIdx = 0; OpIdx < F->arg_size(); ++OpIdx) {
+    Argument *Arg = F->getArg(OpIdx);
+    if (isUntypedPointerTy(Arg->getType()) &&
+        DeducedElTys.find(Arg) == DeducedElTys.end() &&
+        !HasPointeeTypeAttr(Arg)) {
+      if (Type *ElemTy = deduceFunParamType(F, OpIdx)) {
         CallInst *AssignPtrTyCI = buildIntrWithMD(
             Intrinsic::spv_assign_ptr_type, {Arg->getType()},
-            Constant::getNullValue(ElemTy), Arg, {B.getInt32(AddressSpace)}, B);
+            Constant::getNullValue(ElemTy), Arg,
+            {B.getInt32(getPointerAddressSpace(Arg->getType()))}, B);
         DeducedElTys[AssignPtrTyCI] = ElemTy;
         DeducedElTys[Arg] = ElemTy;
-        Args.erase(It);
       }
     }
-    if (Args.size() == 0)
-      break;
   }
 }
 
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 42f8397a3023b1..ee52163a5d127f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -479,6 +479,7 @@ Register SPIRVGlobalRegistry::buildGlobalVariable(
     GVar = M->getGlobalVariable(Name);
     if (GVar == nullptr) {
       const Type *Ty = getTypeForSPIRVType(BaseType); // TODO: check type.
+      // Module takes ownership of the global var.
       GVar = new GlobalVariable(*M, const_cast<Type *>(Ty), false,
                                 GlobalValue::ExternalLinkage, nullptr,
                                 Twine(Name));
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 5bb8f6084f9671..39228e2196b3af 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -499,6 +499,7 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
     assert(I.getOperand(1).isReg() && I.getOperand(2).isReg());
     Register GV = I.getOperand(1).getReg();
     MachineRegisterInfo::def_instr_iterator II = MRI->def_instr_begin(GV);
+    (void)II;
     assert(((*II).getOpcode() == TargetOpcode::G_GLOBAL_VALUE ||
             (*II).getOpcode() == TargetOpcode::COPY ||
             (*II).getOpcode() == SPIRV::OpVariable) &&
@@ -771,10 +772,13 @@ bool SPIRVInstructionSelector::selectMemOperation(Register ResVReg,
     SPIRVType *VarTy = GR.getOrCreateSPIRVPointerType(
         ArrTy, I, TII, SPIRV::StorageClass::UniformConstant);
     // TODO: check if we have such GV, add init, use buildGlobalVariable.
-    Type *LLVMArrTy = ArrayType::get(
-        IntegerType::get(GR.CurMF->getFunction().getContext(), 8), Num);
-    GlobalVariable *GV =
-        new GlobalVariable(LLVMArrTy, true, GlobalValue::InternalLinkage);
+    Function &CurFunction = GR.CurMF->getFunction();
+    Type *LLVMArrTy =
+        ArrayType::get(IntegerType::get(CurFunction.getContext(), 8), Num);
+    // Module takes ownership of the global var.
+    GlobalVariable *GV = new GlobalVariable(*CurFunction.getParent(), LLVMArrTy,
+                                            true, GlobalValue::InternalLinkage,
+                                            Constant::getNullValue(LLVMArrTy));
     Register VarReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
     GR.add(GV, GR.CurMF, VarReg);
 
diff --git a/llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-chain.ll b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-chain.ll
new file mode 100644
index 00000000000000..703f1e22a0321a
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/type-deduce-by-call-chain.ll
@@ -0,0 +1,52 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV-DAG: OpName %[[ArgCum:.*]] "_arg_cum"
+; CHECK-SPIRV-DAG: OpName %[[FunTest:.*]] "test"
+; CHECK-SPIRV-DAG: OpName %[[Addr:.*]] "addr"
+; CHECK-SPIRV-DAG: OpName %[[StubObj:.*]] "stub_object"
+; CHECK-SPIRV-DAG: OpName %[[MemOrder:.*]] "mem_order"
+; CHECK-SPIRV-DAG: OpName %[[FooStub:.*]] "foo_stub"
+; CHECK-SPIRV-DAG: OpName %[[FooObj:.*]] "foo_object"
+; CHECK-SPIRV-DAG: OpName %[[FooMemOrder:.*]] "mem_order"
+; CHECK-SPIRV-DAG: OpName %[[FooFunc:.*]] "foo"
+; CHECK-SPIRV-DAG: %[[TyLong:.*]] = OpTypeInt 32 0
+; CHECK-SPIRV-DAG: %[[TyVoid:.*]] = OpTypeVoid
+; CHECK-SPIRV-DAG: %[[TyPtrLong:.*]] = OpTypePointer CrossWorkgroup %[[TyLong]]
+; CHECK-SPIRV-DAG: %[[TyFunPtrLong:.*]] = OpTypeFunction %[[TyVoid]] %[[TyPtrLong]]
+; CHECK-SPIRV-DAG: %[[TyGenPtrLong:.*]] = OpTypePointer Generic %[[TyLong]]
+; CHECK-SPIRV-DAG: %[[TyFunGenPtrLongLong:.*]] = OpTypeFunction %[[TyVoid]] %[[TyGenPtrLong]] %[[TyLong]]
+; CHECK-SPIRV-DAG: %[[Const3:.*]] = OpConstant %[[TyLong]] 3
+; CHECK-SPIRV: %[[FunTest]] = OpFunction %[[TyVoid]] None %[[TyFunPtrLong]]
+; CHECK-SPIRV: %[[ArgCum]] = OpFunctionParameter %[[TyPtrLong]]
+; CHECK-SPIRV: OpFunctionCall %[[TyVoid]] %[[FooFunc]] %[[Addr]] %[[Const3]]
+; CHECK-SPIRV: %[[FooStub]] = OpFunction %[[TyVoid]] None %[[TyFunGenPtrLongLong]]
+; CHECK-SPIRV: %[[StubObj]] = OpFunctionParameter %[[TyGenPtrLong]]
+; CHECK-SPIRV: %[[MemOrder]] = OpFunctionParameter %[[TyLong]]
+; CHECK-SPIRV: %[[FooFunc]] = OpFunction %[[TyVoid]] None %[[TyFunGenPtrLongLong]]
+; CHECK-SPIRV: %[[FooObj]] = OpFunctionParameter %[[TyGenPtrLong]]
+; CHECK-SPIRV: %[[FooMemOrder]] = OpFunctionParameter %[[TyLong]]
+; CHECK-SPIRV: OpFunctionCall %[[TyVoid]] %[[FooStub]] %[[FooObj]] %[[FooMemOrder]]
+
+define spir_kernel void @test(ptr addrspace(1) noundef align 4 %_arg_cum) {
+entry:
+  %lptr = getelementptr inbounds i32, ptr addrspace(1) %_arg_cum, i64 1
+  %addr = addrspacecast ptr addrspace(1) %lptr to ptr addrspace(4)
+  %object = bitcast ptr addrspace(4) %addr to ptr addrspace(4)
+  call spir_func void @foo(ptr addrspace(4) %object, i32 3)
+  ret void
+}
+
+define void @foo_stub(ptr addrspace(4) noundef %stub_object, i32 noundef %mem_order) {
+entry:
+  %object.addr = alloca ptr addrspace(4)
+  %object.addr.ascast = addrspacecast ptr %object.addr to ptr addrspace(4)
+  store ptr addrspace(4) %stub_object, ptr addrspace(4) %object.addr.ascast
+  ret void
+}
+
+define void @foo(ptr addrspace(4) noundef %foo_object, i32 noundef %mem_order) {
+  tail call void @foo_stub(ptr addrspace(4) noundef %foo_object, i32 noundef %mem_order)
+  ret void
+}
+

@VyacheslavLevytskyy VyacheslavLevytskyy merged commit 1d250d9 into llvm:main Mar 25, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants