Skip to content

Commit afec257

Browse files
[SPIRV] Add type inference of function parameters by call instances (#85077)
This PR adds type inference of function parameters by call instances. Two use cases that demonstrate the problem are added.
1 parent 2cf2ca3 commit afec257

File tree

8 files changed

+204
-8
lines changed

8 files changed

+204
-8
lines changed

llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
209209
// spv_assign_ptr_type intrinsic or otherwise use default pointer element
210210
// type.
211211
Argument *Arg = F.getArg(ArgIdx);
212-
if (Arg->hasByValAttr() || Arg->hasByRefAttr()) {
212+
if (HasPointeeTypeAttr(Arg)) {
213213
Type *ByValRefType = Arg->hasByValAttr() ? Arg->getParamByValType()
214214
: Arg->getParamByRefType();
215215
SPIRVType *ElementType = GR->getOrCreateSPIRVType(ByValRefType, MIRBuilder);
@@ -319,6 +319,12 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
319319
buildOpDecorate(VRegs[i][0], MIRBuilder,
320320
SPIRV::Decoration::FuncParamAttr, {Attr});
321321
}
322+
if (Arg.hasAttribute(Attribute::ByVal)) {
323+
auto Attr =
324+
static_cast<unsigned>(SPIRV::FunctionParameterAttribute::ByVal);
325+
buildOpDecorate(VRegs[i][0], MIRBuilder,
326+
SPIRV::Decoration::FuncParamAttr, {Attr});
327+
}
322328

323329
if (F.getCallingConv() == CallingConv::SPIR_KERNEL) {
324330
std::vector<SPIRV::Decoration::Decoration> ArgTypeQualDecs =

llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ class SPIRVEmitIntrinsics
9191
IRBuilder<> &B);
9292
void insertPtrCastOrAssignTypeInstr(Instruction *I, IRBuilder<> &B);
9393
void processGlobalValue(GlobalVariable &GV, IRBuilder<> &B);
94+
void processParamTypes(Function *F, IRBuilder<> &B);
9495

9596
public:
9697
static char ID;
@@ -794,6 +795,64 @@ void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I,
794795
}
795796
}
796797

798+
void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) {
799+
DenseMap<unsigned, Argument *> Args;
800+
unsigned i = 0;
801+
for (Argument &Arg : F->args()) {
802+
if (isUntypedPointerTy(Arg.getType()) &&
803+
DeducedElTys.find(&Arg) == DeducedElTys.end() &&
804+
!HasPointeeTypeAttr(&Arg))
805+
Args[i] = &Arg;
806+
i++;
807+
}
808+
if (Args.size() == 0)
809+
return;
810+
811+
// Args contains opaque pointers without element type definition
812+
B.SetInsertPointPastAllocas(F);
813+
std::unordered_set<Value *> Visited;
814+
for (User *U : F->users()) {
815+
CallInst *CI = dyn_cast<CallInst>(U);
816+
if (!CI)
817+
continue;
818+
for (unsigned OpIdx = 0; OpIdx < CI->arg_size() && Args.size() > 0;
819+
OpIdx++) {
820+
auto It = Args.find(OpIdx);
821+
Argument *Arg = It == Args.end() ? nullptr : It->second;
822+
if (!Arg)
823+
continue;
824+
Value *OpArg = CI->getArgOperand(OpIdx);
825+
if (!isPointerTy(OpArg->getType()))
826+
continue;
827+
// maybe we already know the operand's element type
828+
auto DeducedIt = DeducedElTys.find(OpArg);
829+
Type *ElemTy =
830+
DeducedIt == DeducedElTys.end() ? nullptr : DeducedIt->second;
831+
if (!ElemTy) {
832+
for (User *OpU : OpArg->users()) {
833+
if (Instruction *Inst = dyn_cast<Instruction>(OpU)) {
834+
Visited.clear();
835+
ElemTy = deduceElementTypeHelper(Inst, Visited, DeducedElTys);
836+
if (ElemTy)
837+
break;
838+
}
839+
}
840+
}
841+
if (ElemTy) {
842+
unsigned AddressSpace = getPointerAddressSpace(Arg->getType());
843+
CallInst *AssignPtrTyCI = buildIntrWithMD(
844+
Intrinsic::spv_assign_ptr_type, {Arg->getType()},
845+
Constant::getNullValue(ElemTy), Arg, {B.getInt32(AddressSpace)}, B);
846+
DeducedElTys[AssignPtrTyCI] = ElemTy;
847+
DeducedElTys[Arg] = ElemTy;
848+
Args.erase(It);
849+
}
850+
}
851+
if (Args.size() == 0)
852+
break;
853+
}
854+
}
855+
797856
bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
798857
if (Func.isDeclaration())
799858
return false;
@@ -839,6 +898,11 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
839898
continue;
840899
processInstrAfterVisit(I, B);
841900
}
901+
902+
// check if function parameter types are set
903+
if (!F->isIntrinsic())
904+
processParamTypes(F, B);
905+
842906
return true;
843907
}
844908

llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,11 +104,7 @@ static void validatePtrTypes(const SPIRVSubtarget &STI,
104104
SPIRV::StorageClass::StorageClass SC =
105105
static_cast<SPIRV::StorageClass::StorageClass>(
106106
OpType->getOperand(1).getImm());
107-
MachineInstr *PrevI = I.getPrevNode();
108-
MachineBasicBlock &MBB = *I.getParent();
109-
MachineBasicBlock::iterator InsPt =
110-
PrevI ? PrevI->getIterator() : MBB.begin();
111-
MachineIRBuilder MIB(MBB, InsPt);
107+
MachineIRBuilder MIB(I);
112108
SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(ResType, MIB, SC);
113109
if (!GR.isBitcastCompatible(NewPtrType, OpType))
114110
report_fatal_error(

llvm/lib/Target/SPIRV/SPIRVUtils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,5 +126,10 @@ inline unsigned getPointerAddressSpace(const Type *T) {
126126
: cast<TypedPointerType>(SubT)->getAddressSpace();
127127
}
128128

129+
// Return true if the Argument is decorated with a pointee type
130+
inline bool HasPointeeTypeAttr(Argument *Arg) {
131+
return Arg->hasByValAttr() || Arg->hasByRefAttr();
132+
}
133+
129134
} // namespace llvm
130135
#endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H

llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-load.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
; CHECK-DAG: %[[#TYLONGPTR:]] = OpTypePointer Function %[[#TYLONG]]
1010
; CHECK: %[[#PTRTOSTRUCT:]] = OpFunctionParameter %[[#TYSTRUCTPTR]]
1111
; CHECK: %[[#PTRTOLONG:]] = OpBitcast %[[#TYLONGPTR]] %[[#PTRTOSTRUCT]]
12-
; CHECK: OpLoad %[[#TYLONG]] %[[#PTRTOLONG]]
12+
; CHECK-NEXT: OpLoad %[[#TYLONG]] %[[#PTRTOLONG]]
1313

1414
%struct.S = type { i32 }
1515
%struct.__wrapper_class = type { [7 x %struct.S] }

llvm/test/CodeGen/SPIRV/pointers/bitcast-fix-store.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
; CHECK: %[[#OBJ:]] = OpFunctionParameter %[[#TYSTRUCT]]
1414
; CHECK: %[[#ARGPTR2:]] = OpFunctionParameter %[[#TYLONGPTR]]
1515
; CHECK: %[[#PTRTOSTRUCT:]] = OpBitcast %[[#TYSTRUCTPTR]] %[[#ARGPTR2]]
16-
; CHECK: OpStore %[[#PTRTOSTRUCT]] %[[#OBJ]]
16+
; CHECK-NEXT: OpStore %[[#PTRTOSTRUCT]] %[[#OBJ]]
1717

1818
%struct.S = type { i32 }
1919
%struct.__wrapper_class = type { [7 x %struct.S] }
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
3+
4+
; CHECK-SPIRV-DAG: OpName %[[FooArg:.*]] "known_type_ptr"
5+
; CHECK-SPIRV-DAG: OpName %[[Foo:.*]] "foo"
6+
; CHECK-SPIRV-DAG: OpName %[[ArgToDeduce:.*]] "unknown_type_ptr"
7+
; CHECK-SPIRV-DAG: OpName %[[Bar:.*]] "bar"
8+
; CHECK-SPIRV-DAG: %[[Long:.*]] = OpTypeInt 32 0
9+
; CHECK-SPIRV-DAG: %[[Void:.*]] = OpTypeVoid
10+
; CHECK-SPIRV-DAG: %[[LongPtr:.*]] = OpTypePointer CrossWorkgroup %[[Long]]
11+
; CHECK-SPIRV-DAG: %[[Fun:.*]] = OpTypeFunction %[[Void]] %[[LongPtr]]
12+
; CHECK-SPIRV: %[[Bar]] = OpFunction %[[Void]] None %[[Fun]]
13+
; CHECK-SPIRV: %[[ArgToDeduce]] = OpFunctionParameter %[[LongPtr]]
14+
; CHECK-SPIRV: OpFunctionCall %[[Void]] %[[Foo]] %[[ArgToDeduce]]
15+
; CHECK-SPIRV: %[[Foo]] = OpFunction %[[Void]] None %[[Fun]]
16+
; CHECK-SPIRV: %[[FooArg]] = OpFunctionParameter %[[LongPtr]]
17+
18+
define spir_kernel void @bar(ptr addrspace(1) %unknown_type_ptr) {
19+
entry:
20+
%elem = getelementptr inbounds i32, ptr addrspace(1) %unknown_type_ptr, i64 0
21+
call void @foo(ptr addrspace(1) %unknown_type_ptr)
22+
ret void
23+
}
24+
25+
define void @foo(ptr addrspace(1) %known_type_ptr) {
26+
entry:
27+
ret void
28+
}
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
3+
4+
; CHECK-SPIRV-DAG: OpName %[[FooArg:.*]] "unknown_type_ptr"
5+
; CHECK-SPIRV-DAG: OpName %[[Foo:.*]] "foo"
6+
; CHECK-SPIRV-DAG: OpName %[[BarArg:.*]] "known_type_ptr"
7+
; CHECK-SPIRV-DAG: OpName %[[Bar:.*]] "bar"
8+
; CHECK-SPIRV-DAG: OpName %[[UntypedArg:.*]] "arg"
9+
; CHECK-SPIRV-DAG: OpName %[[FunUntypedArg:.*]] "foo_untyped_arg"
10+
; CHECK-SPIRV-DAG: OpName %[[UnusedArg1:.*]] "unused_arg1"
11+
; CHECK-SPIRV-DAG: OpName %[[Foo2Arg:.*]] "unknown_type_ptr"
12+
; CHECK-SPIRV-DAG: OpName %[[Foo2:.*]] "foo2"
13+
; CHECK-SPIRV-DAG: OpName %[[Bar2Arg:.*]] "known_type_ptr"
14+
; CHECK-SPIRV-DAG: OpName %[[Bar2:.*]] "bar2"
15+
; CHECK-SPIRV-DAG: OpName %[[Foo5Arg1:.*]] "unknown_type_ptr1"
16+
; CHECK-SPIRV-DAG: OpName %[[Foo5Arg2:.*]] "unknown_type_ptr2"
17+
; CHECK-SPIRV-DAG: OpName %[[Foo5:.*]] "foo5"
18+
; CHECK-SPIRV-DAG: OpName %[[Bar5Arg:.*]] "known_type_ptr"
19+
; CHECK-SPIRV-DAG: OpName %[[Bar5:.*]] "bar5"
20+
; CHECK-SPIRV-DAG: %[[Char:.*]] = OpTypeInt 8 0
21+
; CHECK-SPIRV-DAG: %[[Long:.*]] = OpTypeInt 32 0
22+
; CHECK-SPIRV-DAG: %[[Half:.*]] = OpTypeFloat 16
23+
; CHECK-SPIRV-DAG: %[[Void:.*]] = OpTypeVoid
24+
; CHECK-SPIRV-DAG: %[[HalfConst:.*]] = OpConstant %[[Half]] 15360
25+
; CHECK-SPIRV-DAG: %[[CharPtr:.*]] = OpTypePointer CrossWorkgroup %[[Char]]
26+
; CHECK-SPIRV-DAG: %[[LongPtr:.*]] = OpTypePointer CrossWorkgroup %[[Long]]
27+
; CHECK-SPIRV-DAG: %[[Fun:.*]] = OpTypeFunction %[[Void]] %[[LongPtr]]
28+
; CHECK-SPIRV-DAG: %[[Fun2:.*]] = OpTypeFunction %[[Void]] %[[Half]] %[[LongPtr]]
29+
; CHECK-SPIRV-DAG: %[[Fun5:.*]] = OpTypeFunction %[[Void]] %[[Half]] %[[LongPtr]] %[[Half]] %[[LongPtr]] %[[Half]]
30+
; CHECK-SPIRV-DAG: %[[FunUntyped:.*]] = OpTypeFunction %[[Void]] %[[CharPtr]]
31+
32+
; CHECK-SPIRV: %[[Foo]] = OpFunction %[[Void]] None %[[Fun]]
33+
; CHECK-SPIRV: %[[FooArg]] = OpFunctionParameter %[[LongPtr]]
34+
; CHECK-SPIRV: %[[Bar]] = OpFunction %[[Void]] None %[[Fun]]
35+
; CHECK-SPIRV: %[[BarArg]] = OpFunctionParameter %[[LongPtr]]
36+
; CHECK-SPIRV: OpFunctionCall %[[Void]] %[[Foo]] %[[BarArg]]
37+
38+
; CHECK-SPIRV: %[[FunUntypedArg]] = OpFunction %[[Void]] None %[[FunUntyped]]
39+
; CHECK-SPIRV: %[[UntypedArg]] = OpFunctionParameter %[[CharPtr]]
40+
41+
; CHECK-SPIRV: %[[Foo2]] = OpFunction %[[Void]] None %[[Fun2]]
42+
; CHECK-SPIRV: %[[UnusedArg1]] = OpFunctionParameter %[[Half]]
43+
; CHECK-SPIRV: %[[Foo2Arg]] = OpFunctionParameter %[[LongPtr]]
44+
; CHECK-SPIRV: %[[Bar2]] = OpFunction %[[Void]] None %[[Fun]]
45+
; CHECK-SPIRV: %[[Bar2Arg]] = OpFunctionParameter %[[LongPtr]]
46+
; CHECK-SPIRV: OpFunctionCall %[[Void]] %[[Foo2]] %[[HalfConst]] %[[Bar2Arg]]
47+
48+
; CHECK-SPIRV: %[[Foo5]] = OpFunction %[[Void]] None %[[Fun5]]
49+
; CHECK-SPIRV: OpFunctionParameter %[[Half]]
50+
; CHECK-SPIRV: %[[Foo5Arg1]] = OpFunctionParameter %[[LongPtr]]
51+
; CHECK-SPIRV: OpFunctionParameter %[[Half]]
52+
; CHECK-SPIRV: %[[Foo5Arg2]] = OpFunctionParameter %[[LongPtr]]
53+
; CHECK-SPIRV: OpFunctionParameter %[[Half]]
54+
; CHECK-SPIRV: %[[Bar5]] = OpFunction %[[Void]] None %[[Fun]]
55+
; CHECK-SPIRV: %[[Bar5Arg]] = OpFunctionParameter %[[LongPtr]]
56+
; CHECK-SPIRV: OpFunctionCall %[[Void]] %[[Foo5]] %[[HalfConst]] %[[Bar5Arg]] %[[HalfConst]] %[[Bar5Arg]] %[[HalfConst]]
57+
58+
define void @foo(ptr addrspace(1) %unknown_type_ptr) {
59+
entry:
60+
ret void
61+
}
62+
63+
define spir_kernel void @bar(ptr addrspace(1) %known_type_ptr) {
64+
entry:
65+
%elem = getelementptr inbounds i32, ptr addrspace(1) %known_type_ptr, i64 0
66+
call void @foo(ptr addrspace(1) %known_type_ptr)
67+
ret void
68+
}
69+
70+
define void @foo_untyped_arg(ptr addrspace(1) %arg) {
71+
entry:
72+
ret void
73+
}
74+
75+
define void @foo2(half %unused_arg1, ptr addrspace(1) %unknown_type_ptr) {
76+
entry:
77+
ret void
78+
}
79+
80+
define spir_kernel void @bar2(ptr addrspace(1) %known_type_ptr) {
81+
entry:
82+
%elem = getelementptr inbounds i32, ptr addrspace(1) %known_type_ptr, i64 0
83+
call void @foo2(half 1.0, ptr addrspace(1) %known_type_ptr)
84+
ret void
85+
}
86+
87+
define void @foo5(half %unused_arg1, ptr addrspace(1) %unknown_type_ptr1, half %unused_arg2, ptr addrspace(1) %unknown_type_ptr2, half %unused_arg3) {
88+
entry:
89+
ret void
90+
}
91+
92+
define spir_kernel void @bar5(ptr addrspace(1) %known_type_ptr) {
93+
entry:
94+
%elem = getelementptr inbounds i32, ptr addrspace(1) %known_type_ptr, i64 0
95+
call void @foo5(half 1.0, ptr addrspace(1) %known_type_ptr, half 1.0, ptr addrspace(1) %known_type_ptr, half 1.0)
96+
ret void
97+
}

0 commit comments

Comments
 (0)