Skip to content

[SPIR-V] Improve type inference: fix types of return values in call lowering #116609

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Nov 29, 2024

Conversation

VyacheslavLevytskyy
Copy link
Contributor

@VyacheslavLevytskyy VyacheslavLevytskyy commented Nov 18, 2024

Goals of the PR are:

  • to ensure that correct types are applied to virtual registers which were used as return values in call lowering. A reproducer is attached as a new test case, before the PR it fails because spirv-val considers output invalid due to wrong result/operand types in OpPhi's;
  • improve type inference by speeding up postprocessing of types: by limiting iterations by checking what remains to process, and processing each instruction just once for any number of operands with uncomplete types;
  • improve type inference by more accurate work with uncomplete types (pass uncomplete property to dependent operands, ensure consistency of uncomplete-types data structure);
  • change processing order and add traversing of PHI nodes when type inference apply instructions results to specify/update/cast operands type (fixes an issue with OpPhi's result type mismatch with operand types).

@VyacheslavLevytskyy VyacheslavLevytskyy marked this pull request as ready for review November 20, 2024 13:08
@llvmbot
Copy link
Member

llvmbot commented Nov 20, 2024

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

Changes

Goals of the PR are:

  • to ensure that correct types are applied to virtual registers which were used as return values in call lowering. A reproducer is attached as a new test case, before the PR it fails because spirv-val considers output invalid due to wrong result/operand types in OpPhi's;
  • improve type inference by speeding up postprocessing of types: by limiting iterations by checking what remains to process, and processing each instruction just once for any number of operands with uncomplete types;
  • improve type inference by more accurate work with uncomplete types (pass uncomplete property to dependent operands, ensure consistency of uncomplete-types data structure);
  • change processing order and add traversing of PHI nodes when type inference apply instructions results to specify/update/cast operands type (fixes an issue with OpPhi's result type mismatch with operand types).

Patch is 46.39 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116609.diff

14 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp (+6-22)
  • (modified) llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp (+17)
  • (modified) llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp (+187-93)
  • (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+2-2)
  • (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h (+1-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp (+1-3)
  • (modified) llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp (+3-2)
  • (modified) llvm/lib/Target/SPIRV/SPIRVPostLegalizer.cpp (+2-7)
  • (modified) llvm/lib/Target/SPIRV/SPIRVUtils.cpp (+49)
  • (modified) llvm/lib/Target/SPIRV/SPIRVUtils.h (+16)
  • (modified) llvm/test/CodeGen/SPIRV/extensions/SPV_INTEL_function_pointers/fp_two_calls.ll (+7-3)
  • (added) llvm/test/CodeGen/SPIRV/pointers/builtin-ret-reg-type.ll (+55)
  • (added) llvm/test/CodeGen/SPIRV/pointers/phi-chain-types.ll (+82)
  • (modified) llvm/test/CodeGen/SPIRV/transcoding/OpGenericCastToPtr.ll (-2)
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index 06a37f1f559d44..bed34b83d2e546 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -447,12 +447,8 @@ static Register buildLoadInst(SPIRVType *BaseType, Register PtrRegister,
                               MachineIRBuilder &MIRBuilder,
                               SPIRVGlobalRegistry *GR, LLT LowLevelType,
                               Register DestinationReg = Register(0)) {
-  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
-  if (!DestinationReg.isValid()) {
-    DestinationReg = MRI->createVirtualRegister(&SPIRV::iIDRegClass);
-    MRI->setType(DestinationReg, LLT::scalar(64));
-    GR->assignSPIRVTypeToVReg(BaseType, DestinationReg, MIRBuilder.getMF());
-  }
+  if (!DestinationReg.isValid())
+    DestinationReg = createVirtualRegister(BaseType, GR, MIRBuilder);
   // TODO: consider using correct address space and alignment (p0 is canonical
   // type for selection though).
   MachinePointerInfo PtrInfo = MachinePointerInfo();
@@ -2129,7 +2125,7 @@ static bool buildEnqueueKernel(const SPIRV::IncomingCall *Call,
     const SPIRVType *PointerSizeTy = GR->getOrCreateSPIRVPointerType(
         Int32Ty, MIRBuilder, SPIRV::StorageClass::Function);
     for (unsigned I = 0; I < LocalSizeNum; ++I) {
-      Register Reg = MRI->createVirtualRegister(&SPIRV::iIDRegClass);
+      Register Reg = MRI->createVirtualRegister(&SPIRV::pIDRegClass);
       MRI->setType(Reg, LLType);
       GR->assignSPIRVTypeToVReg(PointerSizeTy, Reg, MIRBuilder.getMF());
       auto GEPInst = MIRBuilder.buildIntrinsic(
@@ -2517,23 +2513,11 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
                                  SPIRVGlobalRegistry *GR) {
   LLVM_DEBUG(dbgs() << "Lowering builtin call: " << DemangledCall << "\n");
 
-  // SPIR-V type and return register.
-  Register ReturnRegister = OrigRet;
-  SPIRVType *ReturnType = nullptr;
-  if (OrigRetTy && !OrigRetTy->isVoidTy()) {
-    ReturnType = GR->assignTypeToVReg(OrigRetTy, ReturnRegister, MIRBuilder);
-    if (!MIRBuilder.getMRI()->getRegClassOrNull(ReturnRegister))
-      MIRBuilder.getMRI()->setRegClass(ReturnRegister,
-                                       GR->getRegClass(ReturnType));
-  } else if (OrigRetTy && OrigRetTy->isVoidTy()) {
-    ReturnRegister = MIRBuilder.getMRI()->createVirtualRegister(&IDRegClass);
-    MIRBuilder.getMRI()->setType(ReturnRegister, LLT::scalar(64));
-    ReturnType = GR->assignTypeToVReg(OrigRetTy, ReturnRegister, MIRBuilder);
-  }
-
   // Lookup the builtin in the TableGen records.
+  SPIRVType *SpvType = GR->getSPIRVTypeForVReg(OrigRet);
+  assert(SpvType && "Inconsistent return register: expected valid type info");
   std::unique_ptr<const IncomingCall> Call =
-      lookupBuiltin(DemangledCall, Set, ReturnRegister, ReturnType, Args);
+      lookupBuiltin(DemangledCall, Set, OrigRet, SpvType, Args);
 
   if (!Call) {
     LLVM_DEBUG(dbgs() << "Builtin record was not found!\n");
diff --git a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
index 3c5397319aaf21..3fdaa6aa3257ea 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
@@ -539,6 +539,23 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
 
   if (isFunctionDecl && !DemangledName.empty() &&
       (canUseGLSL || canUseOpenCL)) {
+    if (ResVReg.isValid()) {
+      if (!GR->getSPIRVTypeForVReg(ResVReg)) {
+        const Type *RetTy = OrigRetTy;
+        if (auto *PtrRetTy = dyn_cast<PointerType>(OrigRetTy)) {
+          const Value *OrigValue = Info.OrigRet.OrigValue;
+          if (!OrigValue)
+            OrigValue = Info.CB;
+          if (OrigValue)
+            if (Type *ElemTy = GR->findDeducedElementType(OrigValue))
+              RetTy =
+                  TypedPointerType::get(ElemTy, PtrRetTy->getAddressSpace());
+        }
+        setRegClassType(ResVReg, RetTy, GR, MIRBuilder);
+      }
+    } else {
+      ResVReg = createVirtualRegister(OrigRetTy, GR, MIRBuilder);
+    }
     SmallVector<Register, 8> ArgVRegs;
     for (auto Arg : Info.OrigArgs) {
       assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index e6ef40e010dc20..7460e0a71aae51 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -67,7 +67,7 @@ class SPIRVEmitIntrinsics
       public InstVisitor<SPIRVEmitIntrinsics, Instruction *> {
   SPIRVTargetMachine *TM = nullptr;
   SPIRVGlobalRegistry *GR = nullptr;
-  Function *F = nullptr;
+  Function *CurrF = nullptr;
   bool TrackConstants = true;
   bool HaveFunPtrs = false;
   DenseMap<Instruction *, Constant *> AggrConsts;
@@ -76,8 +76,27 @@ class SPIRVEmitIntrinsics
   SPIRV::InstructionSet::InstructionSet InstrSet;
 
   // a register of Instructions that don't have a complete type definition
-  DenseMap<Value *, unsigned> UncompleteTypeInfo;
-  SmallVector<Value *> PostprocessWorklist;
+  bool CanTodoType = true;
+  unsigned TodoTypeSz = 0;
+  DenseMap<Value *, bool> TodoType;
+  void insertTodoType(Value *Op) {
+    if (CanTodoType) {
+      auto It = TodoType.try_emplace(Op, true);
+      if (It.second)
+        ++TodoTypeSz;
+    }
+  }
+  void eraseTodoType(Value *Op) {
+    auto It = TodoType.find(Op);
+    if (It != TodoType.end() && It->second) {
+      TodoType[Op] = false;
+      --TodoTypeSz;
+    }
+  }
+  bool isTodoType(Value *Op) {
+    auto It = TodoType.find(Op);
+    return It != TodoType.end() && It->second;
+  }
 
   // well known result types of builtins
   enum WellKnownTypes { Event };
@@ -105,8 +124,9 @@ class SPIRVEmitIntrinsics
                                bool UnknownElemTypeI8);
 
   // deduce Types of operands of the Instruction if possible
-  void deduceOperandElementType(Instruction *I, Instruction *AskOp = 0,
-                                Type *AskTy = 0, CallInst *AssignCI = 0);
+  void deduceOperandElementType(Instruction *I,
+                                const SmallPtrSet<Value *, 4> *AskOps = nullptr,
+                                bool IsPostprocessing = false);
 
   void preprocessCompositeConstants(IRBuilder<> &B);
   void preprocessUndefs(IRBuilder<> &B);
@@ -145,12 +165,20 @@ class SPIRVEmitIntrinsics
   Type *deduceFunParamElementType(Function *F, unsigned OpIdx);
   Type *deduceFunParamElementType(Function *F, unsigned OpIdx,
                                   std::unordered_set<Function *> &FVisited);
+
+  bool deduceOperandElementTypeCalledFunction(
+      SPIRV::InstructionSet::InstructionSet InstrSet, CallInst *CI,
+      SmallVector<std::pair<Value *, unsigned>> &Ops, Type *&KnownElemTy);
+  void deduceOperandElementTypeFunctionPointer(
+      CallInst *CI, SmallVector<std::pair<Value *, unsigned>> &Ops,
+      Type *&KnownElemTy, bool IsPostprocessing);
+
   void replaceWithPtrcasted(Instruction *CI, Type *NewElemTy, Type *KnownElemTy,
                             CallInst *AssignCI);
   void replaceAllUsesWith(Value *Src, Value *Dest, bool DeleteOld = true);
 
   bool runOnFunction(Function &F);
-  bool postprocessTypes();
+  bool postprocessTypes(Module &M);
   bool processFunctionPointers(Module &M);
 
 public:
@@ -280,17 +308,10 @@ void SPIRVEmitIntrinsics::replaceAllUsesWith(Value *Src, Value *Dest,
   GR->updateIfExistDeducedElementType(Src, Dest, DeleteOld);
   GR->updateIfExistAssignPtrTypeInstr(Src, Dest, DeleteOld);
   // Update uncomplete type records if any
-  auto It = UncompleteTypeInfo.find(Src);
-  if (It == UncompleteTypeInfo.end())
-    return;
-  if (DeleteOld) {
-    unsigned Pos = It->second;
-    UncompleteTypeInfo.erase(Src);
-    UncompleteTypeInfo[Dest] = Pos;
-    PostprocessWorklist[Pos] = Dest;
-  } else {
-    UncompleteTypeInfo[Dest] = PostprocessWorklist.size();
-    PostprocessWorklist.push_back(Dest);
+  if (isTodoType(Src)) {
+    if (DeleteOld)
+      eraseTodoType(Src);
+    insertTodoType(Dest);
   }
 }
 
@@ -354,7 +375,7 @@ void SPIRVEmitIntrinsics::buildAssignPtr(IRBuilder<> &B, Type *ElemTy,
   Value *OfType = PoisonValue::get(ElemTy);
   CallInst *AssignPtrTyCI = GR->findAssignPtrTypeInstr(Arg);
   if (AssignPtrTyCI == nullptr ||
-      AssignPtrTyCI->getParent()->getParent() != F) {
+      AssignPtrTyCI->getParent()->getParent() != CurrF) {
     AssignPtrTyCI = buildIntrWithMD(
         Intrinsic::spv_assign_ptr_type, {Arg->getType()}, OfType, Arg,
         {B.getInt32(getPointerAddressSpace(Arg->getType()))}, B);
@@ -455,10 +476,7 @@ void SPIRVEmitIntrinsics::maybeAssignPtrType(Type *&Ty, Value *Op, Type *RefTy,
   if (isUntypedPointerTy(RefTy)) {
     if (!UnknownElemTypeI8)
       return;
-    if (auto *I = dyn_cast<Instruction>(Op)) {
-      UncompleteTypeInfo[I] = PostprocessWorklist.size();
-      PostprocessWorklist.push_back(I);
-    }
+    insertTodoType(Op);
   }
   Ty = RefTy;
 }
@@ -661,10 +679,7 @@ Type *SPIRVEmitIntrinsics::deduceElementType(Value *I, bool UnknownElemTypeI8) {
     return Ty;
   if (!UnknownElemTypeI8)
     return nullptr;
-  if (auto *Instr = dyn_cast<Instruction>(I)) {
-    UncompleteTypeInfo[Instr] = PostprocessWorklist.size();
-    PostprocessWorklist.push_back(Instr);
-  }
+  insertTodoType(I);
   return IntegerType::getInt8Ty(I->getContext());
 }
 
@@ -683,8 +698,7 @@ static inline Type *getAtomicElemTy(SPIRVGlobalRegistry *GR, Instruction *I,
 
 // Try to deduce element type for a call base. Returns false if this is an
 // indirect function invocation, and true otherwise.
-static bool deduceOperandElementTypeCalledFunction(
-    SPIRVGlobalRegistry *GR, Instruction *I,
+bool SPIRVEmitIntrinsics::deduceOperandElementTypeCalledFunction(
     SPIRV::InstructionSet::InstructionSet InstrSet, CallInst *CI,
     SmallVector<std::pair<Value *, unsigned>> &Ops, Type *&KnownElemTy) {
   Function *CalledF = CI->getCalledFunction();
@@ -726,7 +740,7 @@ static bool deduceOperandElementTypeCalledFunction(
       case SPIRV::OpAtomicUMax:
       case SPIRV::OpAtomicSMin:
       case SPIRV::OpAtomicSMax: {
-        KnownElemTy = getAtomicElemTy(GR, I, Op);
+        KnownElemTy = getAtomicElemTy(GR, CI, Op);
         if (!KnownElemTy)
           return true;
         Ops.push_back(std::make_pair(Op, 0));
@@ -738,32 +752,44 @@ static bool deduceOperandElementTypeCalledFunction(
 }
 
 // Try to deduce element type for a function pointer.
-static void deduceOperandElementTypeFunctionPointer(
-    SPIRVGlobalRegistry *GR, Instruction *I, CallInst *CI,
-    SmallVector<std::pair<Value *, unsigned>> &Ops, Type *&KnownElemTy) {
+void SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionPointer(
+    CallInst *CI, SmallVector<std::pair<Value *, unsigned>> &Ops,
+    Type *&KnownElemTy, bool IsPostprocessing) {
   Value *Op = CI->getCalledOperand();
   if (!Op || !isPointerTy(Op->getType()))
     return;
   Ops.push_back(std::make_pair(Op, std::numeric_limits<unsigned>::max()));
   FunctionType *FTy = CI->getFunctionType();
-  bool IsNewFTy = false;
+  bool IsNewFTy = false, IsUncomplete = false;
   SmallVector<Type *, 4> ArgTys;
   for (Value *Arg : CI->args()) {
     Type *ArgTy = Arg->getType();
-    if (ArgTy->isPointerTy())
+    if (ArgTy->isPointerTy()) {
       if (Type *ElemTy = GR->findDeducedElementType(Arg)) {
         IsNewFTy = true;
         ArgTy = TypedPointerType::get(ElemTy, getPointerAddressSpace(ArgTy));
+        if (isTodoType(Arg))
+          IsUncomplete = true;
+      } else {
+        IsUncomplete = true;
       }
+    }
     ArgTys.push_back(ArgTy);
   }
   Type *RetTy = FTy->getReturnType();
-  if (I->getType()->isPointerTy())
-    if (Type *ElemTy = GR->findDeducedElementType(I)) {
+  if (CI->getType()->isPointerTy()) {
+    if (Type *ElemTy = GR->findDeducedElementType(CI)) {
       IsNewFTy = true;
       RetTy =
-          TypedPointerType::get(ElemTy, getPointerAddressSpace(I->getType()));
+          TypedPointerType::get(ElemTy, getPointerAddressSpace(CI->getType()));
+      if (isTodoType(CI))
+        IsUncomplete = true;
+    } else {
+      IsUncomplete = true;
     }
+  }
+  if (!IsPostprocessing && IsUncomplete)
+    insertTodoType(Op);
   KnownElemTy =
       IsNewFTy ? FunctionType::get(RetTy, ArgTys, FTy->isVarArg()) : FTy;
 }
@@ -772,17 +798,18 @@ static void deduceOperandElementTypeFunctionPointer(
 // tries to deduce them. If the Instruction has Pointer operands with known
 // types which differ from expected, this function tries to insert a bitcast to
 // resolve the issue.
-void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I,
-                                                   Instruction *AskOp,
-                                                   Type *AskTy,
-                                                   CallInst *AskCI) {
+void SPIRVEmitIntrinsics::deduceOperandElementType(
+    Instruction *I, const SmallPtrSet<Value *, 4> *AskOps,
+    bool IsPostprocessing) {
   SmallVector<std::pair<Value *, unsigned>> Ops;
   Type *KnownElemTy = nullptr;
+  bool Uncomplete = false;
   // look for known basic patterns of type inference
   if (auto *Ref = dyn_cast<PHINode>(I)) {
     if (!isPointerTy(I->getType()) ||
         !(KnownElemTy = GR->findDeducedElementType(I)))
       return;
+    Uncomplete = isTodoType(I);
     for (unsigned i = 0; i < Ref->getNumIncomingValues(); i++) {
       Value *Op = Ref->getIncomingValue(i);
       if (isPointerTy(Op->getType()))
@@ -792,6 +819,7 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I,
     KnownElemTy = GR->findDeducedElementType(I);
     if (!KnownElemTy)
       return;
+    Uncomplete = isTodoType(I);
     Ops.push_back(std::make_pair(Ref->getPointerOperand(), 0));
   } else if (auto *Ref = dyn_cast<GetElementPtrInst>(I)) {
     KnownElemTy = Ref->getSourceElementType();
@@ -837,27 +865,29 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I,
     if (!isPointerTy(I->getType()) ||
         !(KnownElemTy = GR->findDeducedElementType(I)))
       return;
+    Uncomplete = isTodoType(I);
     for (unsigned i = 0; i < Ref->getNumOperands(); i++) {
       Value *Op = Ref->getOperand(i);
       if (isPointerTy(Op->getType()))
         Ops.push_back(std::make_pair(Op, i));
     }
   } else if (auto *Ref = dyn_cast<ReturnInst>(I)) {
-    Type *RetTy = F->getReturnType();
+    Type *RetTy = CurrF->getReturnType();
     if (!isPointerTy(RetTy))
       return;
     Value *Op = Ref->getReturnValue();
     if (!Op)
       return;
-    if (!(KnownElemTy = GR->findDeducedElementType(F))) {
+    if (!(KnownElemTy = GR->findDeducedElementType(CurrF))) {
       if (Type *OpElemTy = GR->findDeducedElementType(Op)) {
-        GR->addDeducedElementType(F, OpElemTy);
+        GR->addDeducedElementType(CurrF, OpElemTy);
         TypedPointerType *DerivedTy =
             TypedPointerType::get(OpElemTy, getPointerAddressSpace(RetTy));
-        GR->addReturnType(F, DerivedTy);
+        GR->addReturnType(CurrF, DerivedTy);
       }
       return;
     }
+    Uncomplete = isTodoType(CurrF);
     Ops.push_back(std::make_pair(Op, 0));
   } else if (auto *Ref = dyn_cast<ICmpInst>(I)) {
     if (!isPointerTy(Ref->getOperand(0)->getType()))
@@ -868,37 +898,52 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I,
     Type *ElemTy1 = GR->findDeducedElementType(Op1);
     if (ElemTy0) {
       KnownElemTy = ElemTy0;
+      Uncomplete = isTodoType(Op0);
       Ops.push_back(std::make_pair(Op1, 1));
     } else if (ElemTy1) {
       KnownElemTy = ElemTy1;
+      Uncomplete = isTodoType(Op1);
       Ops.push_back(std::make_pair(Op0, 0));
     }
   } else if (CallInst *CI = dyn_cast<CallInst>(I)) {
     if (!CI->isIndirectCall())
-      deduceOperandElementTypeCalledFunction(GR, I, InstrSet, CI, Ops,
-                                             KnownElemTy);
+      deduceOperandElementTypeCalledFunction(InstrSet, CI, Ops, KnownElemTy);
     else if (HaveFunPtrs)
-      deduceOperandElementTypeFunctionPointer(GR, I, CI, Ops, KnownElemTy);
+      deduceOperandElementTypeFunctionPointer(CI, Ops, KnownElemTy,
+                                              IsPostprocessing);
   }
 
   // There is no enough info to deduce types or all is valid.
   if (!KnownElemTy || Ops.size() == 0)
     return;
 
-  LLVMContext &Ctx = F->getContext();
+  LLVMContext &Ctx = CurrF->getContext();
   IRBuilder<> B(Ctx);
   for (auto &OpIt : Ops) {
     Value *Op = OpIt.first;
-    if (Op->use_empty() || (AskOp && Op != AskOp))
+    if (Op->use_empty())
       continue;
-    Type *Ty = AskOp ? AskTy : GR->findDeducedElementType(Op);
+    if (AskOps && !AskOps->contains(Op))
+      continue;
+    Type *AskTy = nullptr;
+    CallInst *AskCI = nullptr;
+    if (IsPostprocessing && AskOps) {
+      AskTy = GR->findDeducedElementType(Op);
+      AskCI = GR->findAssignPtrTypeInstr(Op);
+      assert(AskTy && AskCI);
+    }
+    Type *Ty = AskTy ? AskTy : GR->findDeducedElementType(Op);
     if (Ty == KnownElemTy)
       continue;
     Value *OpTyVal = PoisonValue::get(KnownElemTy);
     Type *OpTy = Op->getType();
-    if (!Ty || AskTy || isUntypedPointerTy(Ty) ||
-        UncompleteTypeInfo.contains(Op)) {
+    if (!Ty || AskTy || isUntypedPointerTy(Ty) || isTodoType(Op)) {
       GR->addDeducedElementType(Op, KnownElemTy);
+      // check if KnownElemTy is complete
+      if (!Uncomplete)
+        eraseTodoType(Op);
+      else if (!IsPostprocessing)
+        insertTodoType(Op);
       // check if there is existing Intrinsic::spv_assign_ptr_type instruction
       CallInst *AssignCI = AskCI ? AskCI : GR->findAssignPtrTypeInstr(Op);
       if (AssignCI == nullptr) {
@@ -912,6 +957,7 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I,
         updateAssignType(AssignCI, Op, OpTyVal);
       }
     } else {
+      eraseTodoType(Op);
       if (auto *OpI = dyn_cast<Instruction>(Op)) {
         // spv_ptrcast's argument Op denotes an instruction that generates
         // a value, and we may use getInsertionPointAfterDef()
@@ -921,7 +967,7 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I,
         B.SetInsertPointPastAllocas(OpA->getParent());
         B.SetCurrentDebugLocation(DebugLoc());
       } else {
-        B.SetInsertPoint(F->getEntryBlock().getFirstNonPHIOrDbgOrAlloca());
+        B.SetInsertPoint(CurrF->getEntryBlock().getFirstNonPHIOrDbgOrAlloca());
       }
       SmallVector<Type *, 2> Types = {OpTy, OpTy};
       SmallVector<Value *, 2> Args = {Op, buildMD(OpTyVal),
@@ -961,7 +1007,7 @@ void SPIRVEmitIntrinsics::replaceMemInstrUses(Instruction *Old,
 
 void SPIRVEmitIntrinsics::preprocessUndefs(IRBuilder<> &B) {
   std::queue<Instruction *> Worklist;
-  for (auto &I : instructions(F))
+  for (auto &I : instructions(CurrF))
     Worklist.push(&I);
 
   while (!Worklist.empty()) {
@@ -989,7 +1035,7 @@ void SPIRVEmitIntrinsics::preprocessUndefs(IRBuilder<> &B) {
 
 void SPIRVEmitIntrinsics::preprocessCompositeConstants(IRBuilder<> &B) {
   std::queue<Instruction *> Worklist;
-  for (auto &I : instructions(F))
+  for (auto &I : instructions(CurrF))
     Worklist.push(&I);
 
   while (!Worklist.empty()) {
@@ -1048,7 +1094,7 @@ Instruction *SPIRVEmitIntrinsics::visitCallInst(CallInst &Call) {
     return &Call;
 
   const InlineAsm *IA = cast<InlineAsm>(Call.getCalledOperand());
-  LLVMContext &Ctx = F->getContext();
+  LLVMContext &Ctx = CurrF->getContext();
 
   Constant *TyC = UndefValue::get(IA->getFunctionType());
   MDString *ConstraintString = MDString::get(Ctx, IA->getConstraintString());
@@ -1249,10 +1295,10 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
                                                          IRBuilder<> &B) {
   // Handle basic instructions:
   StoreInst *SI = dyn_cast<StoreInst>(I);
-  if (IsKernelArgInt8(F, SI)) {
+  if (IsKernelArgInt8(CurrF, SI)) {
     return replacePointerOperandWithPtrCast(
-        I, SI->getValueOperand(), IntegerType::getInt8Ty(F->getContext()), 0,
-        B);
+        I, SI->getValueOperand(), IntegerType::getInt8Ty(CurrF->getContext()),
+        0, B);
   } else if (SI) {
     Value *Op = SI->getValueOperand();
     Type *OpTy = Op->getType();
@@ -1419,7 +1465,7 @@ Instruction *SPIRVEmitIntrinsics::visitLoadInst(LoadInst &I) {
   TrackConstants = false;
   ...
[truncated]

@VyacheslavLevytskyy VyacheslavLevytskyy force-pushed the fix_phi_gep_1811 branch 2 times, most recently from a2f92f5 to 6d0c311 Compare November 27, 2024 14:54
Copy link

github-actions bot commented Nov 27, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@VyacheslavLevytskyy
Copy link
Contributor Author

One test case that fails is not related to this PR and wasn't changed in the PR. The fails is about spirv-val complaining that debug-info/debug-type-basic.ll produces invalid SPIR-V code:

+ /mnt/build/bin/spirv-val
+ /mnt/build/bin/llc --verify-machineinstrs --spv-emit-nonsemantic-debug-info --spirv-ext=+SPV_KHR_non_semantic_info -O0 -mtriple=spirv64-unknown-unknown /__w/llvm-project/llvm-project/llvm/test/CodeGen/SPIRV/debug-info/debug-type-basic.ll -o - -filetype=obj
error: line 86: NonSemantic.Shader.DebugInfo.100 DebugTypeBasic: expected operand Flags must be a result id of 32-bit unsigned OpConstant
  %52 = OpExtInst %void %49 DebugTypeBasic %51 %uint_8 %uint_2 %20

Copy link
Member

@michalpaszkowski michalpaszkowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for working on this! LGTM!

%offset.addr.ascast = addrspacecast ptr %offset.addr to ptr addrspace(4)
store ptr addrspace(4) %offset, ptr addrspace(4) %offset.addr.ascast, align 8
ret void
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: missing new line at the end of file

@VyacheslavLevytskyy VyacheslavLevytskyy merged commit b5132b7 into llvm:main Nov 29, 2024
8 of 9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants