Skip to content

[SPIR-V] Prevent type change of GEP results in type inference #129250

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 28, 2025

Conversation

VyacheslavLevytskyy
Copy link
Contributor

The following reproducer demonstrates the issue with invalid definition of GEP results during type inference

define spir_kernel void @foo(i1 %fl, i64 %idx, ptr addrspace(1) %dest, ptr addrspace(3) %src) {
  %p1 = getelementptr inbounds i8, ptr addrspace(1) %dest, i64 %idx
  %res = tail call spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS1iPU3AS3Kimm9ocl_event(i32 2, ptr addrspace(1) %p1, ptr addrspace(3) %src, i64 128, i64 1, target("spirv.Event") zeroinitializer)
  ret void
}

declare dso_local spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS1iPU3AS3Kimm9ocl_event(i32, ptr addrspace(1), ptr addrspace(3), i64, i64, target("spirv.Event"))

Here OpGroupAsyncCopy expects i32* arguments and type inference fails to set a correct type of the GEP result %p1, because it is an argument of OpGroupAsyncCopy.

This PR fixes the issue by preventing type change of GEP results in type inference.

@llvmbot
Copy link
Member

llvmbot commented Feb 28, 2025

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

Changes

The following reproducer demonstrates the issue with invalid definition of GEP results during type inference

define spir_kernel void @<!-- -->foo(i1 %fl, i64 %idx, ptr addrspace(1) %dest, ptr addrspace(3) %src) {
  %p1 = getelementptr inbounds i8, ptr addrspace(1) %dest, i64 %idx
  %res = tail call spir_func target("spirv.Event") @<!-- -->_Z22__spirv_GroupAsyncCopyjPU3AS1iPU3AS3Kimm9ocl_event(i32 2, ptr addrspace(1) %p1, ptr addrspace(3) %src, i64 128, i64 1, target("spirv.Event") zeroinitializer)
  ret void
}

declare dso_local spir_func target("spirv.Event") @<!-- -->_Z22__spirv_GroupAsyncCopyjPU3AS1iPU3AS3Kimm9ocl_event(i32, ptr addrspace(1), ptr addrspace(3), i64, i64, target("spirv.Event"))

Here OpGroupAsyncCopy expects i32* arguments and type inference fails to set a correct type of the GEP result %p1, because it is an argument of OpGroupAsyncCopy.

This PR fixes the issue by preventing type change of GEP results in type inference.


Full diff: https://github.com/llvm/llvm-project/pull/129250.diff

2 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp (+28-9)
  • (added) llvm/test/CodeGen/SPIRV/pointers/ptr-access-chain-type.ll (+24)
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 5dfba8427258f..d6177058231d9 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -646,6 +646,20 @@ void SPIRVEmitIntrinsics::maybeAssignPtrType(Type *&Ty, Value *Op, Type *RefTy,
   Ty = RefTy;
 }
 
+Type *getGEPType(GetElementPtrInst *Ref) {
+  Type *Ty = nullptr;
+  // TODO: not sure if GetElementPtrInst::getTypeAtIndex() does anything
+  // useful here
+  if (isNestedPointer(Ref->getSourceElementType())) {
+    Ty = Ref->getSourceElementType();
+    for (Use &U : drop_begin(Ref->indices()))
+      Ty = GetElementPtrInst::getTypeAtIndex(Ty, U.get());
+  } else {
+    Ty = Ref->getResultElementType();
+  }
+  return Ty;
+}
+
 Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
     Value *I, std::unordered_set<Value *> &Visited, bool UnknownElemTypeI8,
     bool IgnoreKnownType) {
@@ -668,15 +682,7 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
   if (auto *Ref = dyn_cast<AllocaInst>(I)) {
     maybeAssignPtrType(Ty, I, Ref->getAllocatedType(), UnknownElemTypeI8);
   } else if (auto *Ref = dyn_cast<GetElementPtrInst>(I)) {
-    // TODO: not sure if GetElementPtrInst::getTypeAtIndex() does anything
-    // useful here
-    if (isNestedPointer(Ref->getSourceElementType())) {
-      Ty = Ref->getSourceElementType();
-      for (Use &U : drop_begin(Ref->indices()))
-        Ty = GetElementPtrInst::getTypeAtIndex(Ty, U.get());
-    } else {
-      Ty = Ref->getResultElementType();
-    }
+    Ty = getGEPType(Ref);
   } else if (auto *Ref = dyn_cast<LoadInst>(I)) {
     Value *Op = Ref->getPointerOperand();
     Type *KnownTy = GR->findDeducedElementType(Op);
@@ -2307,6 +2313,7 @@ bool SPIRVEmitIntrinsics::processFunctionPointers(Module &M) {
 
 // Apply types parsed from demangled function declarations.
 void SPIRVEmitIntrinsics::applyDemangledPtrArgTypes(IRBuilder<> &B) {
+  DenseMap<Function *, CallInst *> Ptrcasts;
   for (auto It : FDeclPtrTys) {
     Function *F = It.first;
     for (auto *U : F->users()) {
@@ -2326,6 +2333,9 @@ void SPIRVEmitIntrinsics::applyDemangledPtrArgTypes(IRBuilder<> &B) {
             B.SetCurrentDebugLocation(DebugLoc());
             buildAssignPtr(B, ElemTy, Arg);
           }
+        } else if (isa<GetElementPtrInst>(Param)) {
+          replaceUsesOfWithSpvPtrcast(Param, normalizeType(ElemTy), CI,
+                                      Ptrcasts);
         } else if (isa<Instruction>(Param)) {
           GR->addDeducedElementType(Param, normalizeType(ElemTy));
           // insertAssignTypeIntrs() will complete buildAssignPtr()
@@ -2370,6 +2380,15 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
   AggrConstTypes.clear();
   AggrStores.clear();
 
+  // fix GEP result types ahead of inference
+  for (auto &I : instructions(Func)) {
+    auto *Ref = dyn_cast<GetElementPtrInst>(&I);
+    if (!Ref || GR->findDeducedElementType(Ref))
+      continue;
+    if (Type *GepTy = getGEPType(Ref))
+      GR->addDeducedElementType(Ref, normalizeType(GepTy));
+  }
+
   processParamTypesByFunHeader(CurrF, B);
 
   // StoreInst's operand type can be changed during the next transformations,
diff --git a/llvm/test/CodeGen/SPIRV/pointers/ptr-access-chain-type.ll b/llvm/test/CodeGen/SPIRV/pointers/ptr-access-chain-type.ll
new file mode 100644
index 0000000000000..d69959609c9dc
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/ptr-access-chain-type.ll
@@ -0,0 +1,24 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: %[[#Char:]] = OpTypeInt 8 0
+; CHECK-DAG: %[[#Long:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#CharPtr:]] = OpTypePointer CrossWorkgroup %[[#Char]]
+; CHECK-DAG: %[[#LongPtr:]] = OpTypePointer CrossWorkgroup %[[#Long]]
+; CHECK-DAG: %[[#LongPtrWG:]] = OpTypePointer Workgroup %[[#Long]]
+; CHECK: OpFunction
+; CHECK: OpFunctionParameter
+; CHECK: %[[#Dest:]] = OpFunctionParameter %[[#CharPtr]]
+; CHECK: %[[#Src:]] = OpFunctionParameter %[[#LongPtrWG]]
+; CHECK: %[[#InDest:]] = OpInBoundsPtrAccessChain %[[#CharPtr]] %[[#Dest]] %[[#]]
+; CHECK: %[[#InDestCasted:]] = OpBitcast %[[#LongPtr]] %[[#InDest]]
+; CHECK: OpGroupAsyncCopy %[[#]] %[[#]] %[[#InDestCasted]] %[[#Src]] %[[#]] %[[#]] %[[#]]
+
+define spir_kernel void @foo(i64 %idx, ptr addrspace(1) %dest, ptr addrspace(3) %src) {
+  %p1 = getelementptr inbounds i8, ptr addrspace(1) %dest, i64 %idx
+  %res = tail call spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS1iPU3AS3Kimm9ocl_event(i32 2, ptr addrspace(1) %p1, ptr addrspace(3) %src, i64 128, i64 1, target("spirv.Event") zeroinitializer)
+  ret void
+}
+
+; For this test case the mangling is important.
+declare dso_local spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS1iPU3AS3Kimm9ocl_event(i32, ptr addrspace(1), ptr addrspace(3), i64, i64, target("spirv.Event"))

@VyacheslavLevytskyy VyacheslavLevytskyy merged commit 494f672 into llvm:main Feb 28, 2025
15 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants