Skip to content

Commit 5fb5a7f

Browse files
type inference: use types parsed from demangled function declarations
1 parent 8785813 commit 5fb5a7f

File tree

5 files changed

+159
-47
lines changed

5 files changed

+159
-47
lines changed

llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2664,16 +2664,7 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
26642664
return false;
26652665
}
26662666

2667-
Type *parseBuiltinCallArgumentBaseType(const StringRef DemangledCall,
2668-
unsigned ArgIdx, LLVMContext &Ctx) {
2669-
SmallVector<StringRef, 10> BuiltinArgsTypeStrs;
2670-
StringRef BuiltinArgs =
2671-
DemangledCall.slice(DemangledCall.find('(') + 1, DemangledCall.find(')'));
2672-
BuiltinArgs.split(BuiltinArgsTypeStrs, ',', -1, false);
2673-
if (ArgIdx >= BuiltinArgsTypeStrs.size())
2674-
return nullptr;
2675-
StringRef TypeStr = BuiltinArgsTypeStrs[ArgIdx].trim();
2676-
2667+
Type *parseBuiltinCallArgumentType(StringRef TypeStr, LLVMContext &Ctx) {
26772668
// Parse strings representing OpenCL builtin types.
26782669
if (hasBuiltinTypePrefix(TypeStr)) {
26792670
// OpenCL builtin types in demangled call strings have the following format:
@@ -2717,6 +2708,29 @@ Type *parseBuiltinCallArgumentBaseType(const StringRef DemangledCall,
27172708
return BaseType;
27182709
}
27192710

2711+
bool parseBuiltinTypeStr(SmallVector<StringRef, 10> &BuiltinArgsTypeStrs,
2712+
const StringRef DemangledCall, LLVMContext &Ctx) {
2713+
auto Pos1 = DemangledCall.find('(');
2714+
if (Pos1 == StringRef::npos)
2715+
return false;
2716+
auto Pos2 = DemangledCall.find(')');
2717+
if (Pos2 == StringRef::npos || Pos1 > Pos2)
2718+
return false;
2719+
DemangledCall.slice(Pos1 + 1, Pos2)
2720+
.split(BuiltinArgsTypeStrs, ',', -1, false);
2721+
return true;
2722+
}
2723+
2724+
Type *parseBuiltinCallArgumentBaseType(const StringRef DemangledCall,
2725+
unsigned ArgIdx, LLVMContext &Ctx) {
2726+
SmallVector<StringRef, 10> BuiltinArgsTypeStrs;
2727+
parseBuiltinTypeStr(BuiltinArgsTypeStrs, DemangledCall, Ctx);
2728+
if (ArgIdx >= BuiltinArgsTypeStrs.size())
2729+
return nullptr;
2730+
StringRef TypeStr = BuiltinArgsTypeStrs[ArgIdx].trim();
2731+
return parseBuiltinCallArgumentType(TypeStr, Ctx);
2732+
}
2733+
27202734
struct BuiltinType {
27212735
StringRef Name;
27222736
uint32_t Opcode;

llvm/lib/Target/SPIRV/SPIRVBuiltins.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ mapBuiltinToOpcode(const StringRef DemangledCall,
5656
/// \p ArgIdx is the index of the argument to parse.
5757
Type *parseBuiltinCallArgumentBaseType(const StringRef DemangledCall,
5858
unsigned ArgIdx, LLVMContext &Ctx);
59+
bool parseBuiltinTypeStr(SmallVector<StringRef, 10> &BuiltinArgsTypeStrs,
60+
const StringRef DemangledCall, LLVMContext &Ctx);
61+
Type *parseBuiltinCallArgumentType(StringRef TypeStr, LLVMContext &Ctx);
5962

6063
/// Translates a string representing a SPIR-V or OpenCL builtin type to a
6164
/// TargetExtType that can be further lowered with lowerBuiltinType().

llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp

Lines changed: 116 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,10 @@ class SPIRVEmitIntrinsics
7676
DenseSet<Instruction *> AggrStores;
7777
SPIRV::InstructionSet::InstructionSet InstrSet;
7878

79+
// map of function declarations to <pointer arg index => element type>
80+
DenseMap<const Function *, SmallVector<std::pair<unsigned, Type *>>>
81+
FDeclPtrTys;
82+
7983
// a register of Instructions that don't have a complete type definition
8084
bool CanTodoType = true;
8185
unsigned TodoTypeSz = 0;
@@ -184,6 +188,10 @@ class SPIRVEmitIntrinsics
184188
void deduceOperandElementTypeFunctionPointer(
185189
CallInst *CI, SmallVector<std::pair<Value *, unsigned>> &Ops,
186190
Type *&KnownElemTy, bool IsPostprocessing);
191+
bool deduceOperandElementTypeFunctionRet(
192+
Instruction *I, SmallPtrSet<Instruction *, 4> *UncompleteRets,
193+
const SmallPtrSet<Value *, 4> *AskOps, bool IsPostprocessing,
194+
Type *&KnownElemTy, Value *Op, Function *F);
187195

188196
CallInst *buildSpvPtrcast(Function *F, Value *Op, Type *ElemTy);
189197
void replaceUsesOfWithSpvPtrcast(Value *Op, Type *ElemTy, Instruction *I,
@@ -205,6 +213,7 @@ class SPIRVEmitIntrinsics
205213
bool runOnFunction(Function &F);
206214
bool postprocessTypes(Module &M);
207215
bool processFunctionPointers(Module &M);
216+
void parseFunDeclarations(Module &M);
208217

209218
public:
210219
static char ID;
@@ -957,6 +966,47 @@ void SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionPointer(
957966
IsNewFTy ? FunctionType::get(RetTy, ArgTys, FTy->isVarArg()) : FTy;
958967
}
959968

969+
bool SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionRet(
970+
Instruction *I, SmallPtrSet<Instruction *, 4> *UncompleteRets,
971+
const SmallPtrSet<Value *, 4> *AskOps, bool IsPostprocessing,
972+
Type *&KnownElemTy, Value *Op, Function *F) {
973+
KnownElemTy = GR->findDeducedElementType(F);
974+
if (KnownElemTy)
975+
return false;
976+
if (Type *OpElemTy = GR->findDeducedElementType(Op)) {
977+
GR->addDeducedElementType(F, OpElemTy);
978+
GR->addReturnType(
979+
F, TypedPointerType::get(OpElemTy,
980+
getPointerAddressSpace(F->getReturnType())));
981+
// non-recursive update of types in function uses
982+
DenseSet<std::pair<Value *, Value *>> VisitedSubst{std::make_pair(I, Op)};
983+
for (User *U : F->users()) {
984+
CallInst *CI = dyn_cast<CallInst>(U);
985+
if (!CI || CI->getCalledFunction() != F)
986+
continue;
987+
if (CallInst *AssignCI = GR->findAssignPtrTypeInstr(CI)) {
988+
if (Type *PrevElemTy = GR->findDeducedElementType(CI)) {
989+
updateAssignType(AssignCI, CI, PoisonValue::get(OpElemTy));
990+
propagateElemType(CI, PrevElemTy, VisitedSubst);
991+
}
992+
}
993+
}
994+
// Non-recursive update of types in the function uncomplete returns.
995+
// This may happen just once per a function, the latch is a pair of
996+
// findDeducedElementType(F) / addDeducedElementType(F, ...).
997+
// With or without the latch it is a non-recursive call due to
998+
// UncompleteRets set to nullptr in this call.
999+
if (UncompleteRets)
1000+
for (Instruction *UncompleteRetI : *UncompleteRets)
1001+
deduceOperandElementType(UncompleteRetI, nullptr, AskOps,
1002+
IsPostprocessing);
1003+
} else if (UncompleteRets) {
1004+
UncompleteRets->insert(I);
1005+
}
1006+
TypeValidated.insert(I);
1007+
return true;
1008+
}
1009+
9601010
// If the Instruction has Pointer operands with unresolved types, this function
9611011
// tries to deduce them. If the Instruction has Pointer operands with known
9621012
// types which differ from expected, this function tries to insert a bitcast to
@@ -1039,46 +1089,15 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(
10391089
Ops.push_back(std::make_pair(Op, i));
10401090
}
10411091
} else if (auto *Ref = dyn_cast<ReturnInst>(I)) {
1042-
Type *RetTy = CurrF->getReturnType();
1043-
if (!isPointerTy(RetTy))
1092+
if (!isPointerTy(CurrF->getReturnType()))
10441093
return;
10451094
Value *Op = Ref->getReturnValue();
10461095
if (!Op)
10471096
return;
1048-
if (!(KnownElemTy = GR->findDeducedElementType(CurrF))) {
1049-
if (Type *OpElemTy = GR->findDeducedElementType(Op)) {
1050-
GR->addDeducedElementType(CurrF, OpElemTy);
1051-
GR->addReturnType(CurrF, TypedPointerType::get(
1052-
OpElemTy, getPointerAddressSpace(RetTy)));
1053-
// non-recursive update of types in function uses
1054-
DenseSet<std::pair<Value *, Value *>> VisitedSubst{
1055-
std::make_pair(I, Op)};
1056-
for (User *U : CurrF->users()) {
1057-
CallInst *CI = dyn_cast<CallInst>(U);
1058-
if (!CI || CI->getCalledFunction() != CurrF)
1059-
continue;
1060-
if (CallInst *AssignCI = GR->findAssignPtrTypeInstr(CI)) {
1061-
if (Type *PrevElemTy = GR->findDeducedElementType(CI)) {
1062-
updateAssignType(AssignCI, CI, PoisonValue::get(OpElemTy));
1063-
propagateElemType(CI, PrevElemTy, VisitedSubst);
1064-
}
1065-
}
1066-
}
1067-
// Non-recursive update of types in the function uncomplete returns.
1068-
// This may happen just once per a function, the latch is a pair of
1069-
// findDeducedElementType(F) / addDeducedElementType(F, ...).
1070-
// With or without the latch it is a non-recursive call due to
1071-
// UncompleteRets set to nullptr in this call.
1072-
if (UncompleteRets)
1073-
for (Instruction *UncompleteRetI : *UncompleteRets)
1074-
deduceOperandElementType(UncompleteRetI, nullptr, AskOps,
1075-
IsPostprocessing);
1076-
} else if (UncompleteRets) {
1077-
UncompleteRets->insert(I);
1078-
}
1079-
TypeValidated.insert(I);
1097+
if (deduceOperandElementTypeFunctionRet(I, UncompleteRets, AskOps,
1098+
IsPostprocessing, KnownElemTy, Op,
1099+
CurrF))
10801100
return;
1081-
}
10821101
Uncomplete = isTodoType(CurrF);
10831102
Ops.push_back(std::make_pair(Op, 0));
10841103
} else if (auto *Ref = dyn_cast<ICmpInst>(I)) {
@@ -2157,6 +2176,8 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
21572176
AggrConstTypes.clear();
21582177
AggrStores.clear();
21592178

2179+
DenseMap<Function *, DenseMap<unsigned, Type *>> FDeclPtrTys;
2180+
21602181
processParamTypesByFunHeader(CurrF, B);
21612182

21622183
// StoreInst's operand type can be changed during the next transformations,
@@ -2180,6 +2201,31 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
21802201
for (auto &I : instructions(Func))
21812202
Worklist.push_back(&I);
21822203

2204+
// Apply types parsed from demangled function declarations.
2205+
for (auto &I : Worklist) {
2206+
CallInst *CI = dyn_cast<CallInst>(I);
2207+
if (!CI || !CI->getCalledFunction())
2208+
continue;
2209+
auto It = FDeclPtrTys.find(CI->getCalledFunction());
2210+
if (It == FDeclPtrTys.end())
2211+
continue;
2212+
unsigned Sz = CI->arg_size();
2213+
for (auto [Idx, ElemTy] : It->second)
2214+
if (Idx < Sz) {
2215+
Value *Arg = CI->getArgOperand(Idx);
2216+
GR->addDeducedElementType(Arg, ElemTy);
2217+
if (CallInst *Ref = dyn_cast<CallInst>(Arg))
2218+
if (Function *RefF = Ref->getCalledFunction();
2219+
RefF && isPointerTy(RefF->getReturnType()) &&
2220+
!GR->findDeducedElementType(RefF)) {
2221+
GR->addDeducedElementType(RefF, ElemTy);
2222+
GR->addReturnType(RefF, TypedPointerType::get(
2223+
ElemTy, getPointerAddressSpace(
2224+
RefF->getReturnType())));
2225+
}
2226+
}
2227+
}
2228+
21832229
// Pass forward: use operand to deduce instructions result.
21842230
for (auto &I : Worklist) {
21852231
// Don't emit intrinsincs for convergence intrinsics.
@@ -2287,9 +2333,44 @@ bool SPIRVEmitIntrinsics::postprocessTypes(Module &M) {
22872333
return SzTodo > TodoTypeSz;
22882334
}
22892335

2336+
// Parse and store argument types of function declarations where needed.
2337+
void SPIRVEmitIntrinsics::parseFunDeclarations(Module &M) {
2338+
for (auto &F : M) {
2339+
if (!F.isDeclaration() || F.isIntrinsic())
2340+
continue;
2341+
// get the demangled name
2342+
std::string DemangledName = getOclOrSpirvBuiltinDemangledName(F.getName());
2343+
if (DemangledName.empty())
2344+
continue;
2345+
// find pointer arguments
2346+
SmallVector<unsigned> Idxs;
2347+
for (unsigned OpIdx = 0; OpIdx < F.arg_size(); ++OpIdx)
2348+
if (isPointerTy(F.getArg(OpIdx)->getType()))
2349+
Idxs.push_back(OpIdx);
2350+
if (!Idxs.size())
2351+
continue;
2352+
// parse function arguments
2353+
LLVMContext &Ctx = F.getContext();
2354+
SmallVector<StringRef, 10> TypeStrs;
2355+
SPIRV::parseBuiltinTypeStr(TypeStrs, DemangledName, Ctx);
2356+
if (!TypeStrs.size())
2357+
continue;
2358+
// find type info for pointer arguments
2359+
for (unsigned Idx : Idxs) {
2360+
if (Idx >= TypeStrs.size())
2361+
continue;
2362+
if (Type *ElemTy =
2363+
SPIRV::parseBuiltinCallArgumentType(TypeStrs[Idx].trim(), Ctx))
2364+
FDeclPtrTys[&F].push_back(std::make_pair(Idx, ElemTy));
2365+
}
2366+
}
2367+
}
2368+
22902369
bool SPIRVEmitIntrinsics::runOnModule(Module &M) {
22912370
bool Changed = false;
22922371

2372+
parseFunDeclarations(M);
2373+
22932374
TodoType.clear();
22942375
for (auto &F : M)
22952376
Changed |= runOnFunction(F);

llvm/lib/Target/SPIRV/SPIRVUtils.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -447,25 +447,31 @@ Type *parseBasicTypeName(StringRef &TypeName, LLVMContext &Ctx) {
447447
TypeName.consume_front("atomic_");
448448
if (TypeName.consume_front("void"))
449449
return Type::getVoidTy(Ctx);
450-
else if (TypeName.consume_front("bool"))
450+
else if (TypeName.consume_front("bool") || TypeName.consume_front("_Bool"))
451451
return Type::getIntNTy(Ctx, 1);
452452
else if (TypeName.consume_front("char") ||
453+
TypeName.consume_front("signed char") ||
453454
TypeName.consume_front("unsigned char") ||
454455
TypeName.consume_front("uchar"))
455456
return Type::getInt8Ty(Ctx);
456457
else if (TypeName.consume_front("short") ||
458+
TypeName.consume_front("signed short") ||
457459
TypeName.consume_front("unsigned short") ||
458460
TypeName.consume_front("ushort"))
459461
return Type::getInt16Ty(Ctx);
460462
else if (TypeName.consume_front("int") ||
463+
TypeName.consume_front("signed int") ||
461464
TypeName.consume_front("unsigned int") ||
462465
TypeName.consume_front("uint"))
463466
return Type::getInt32Ty(Ctx);
464467
else if (TypeName.consume_front("long") ||
468+
TypeName.consume_front("signed long") ||
465469
TypeName.consume_front("unsigned long") ||
466470
TypeName.consume_front("ulong"))
467471
return Type::getInt64Ty(Ctx);
468-
else if (TypeName.consume_front("half"))
472+
else if (TypeName.consume_front("half") ||
473+
TypeName.consume_front("_Float16") ||
474+
TypeName.consume_front("__fp16"))
469475
return Type::getHalfTy(Ctx);
470476
else if (TypeName.consume_front("float"))
471477
return Type::getFloatTy(Ctx);

llvm/test/CodeGen/SPIRV/transcoding/spirv-event-null.ll

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,14 @@
3232

3333
%StructEvent = type { target("spirv.Event") }
3434

35+
define spir_kernel void @test_half(ptr addrspace(3) %_arg1, ptr addrspace(1) %_arg2) {
36+
entry:
37+
%r = tail call spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS3Dv2_DF16_PU3AS1KS_mm9ocl_event(i32 2, ptr addrspace(3) %_arg1, ptr addrspace(1) %_arg2, i64 16, i64 10, target("spirv.Event") zeroinitializer)
38+
ret void
39+
}
40+
41+
declare dso_local spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS3Dv2_DF16_PU3AS1KS_mm9ocl_event(i32 noundef, ptr addrspace(3) noundef, ptr addrspace(1) noundef, i64 noundef, i64 noundef, target("spirv.Event"))
42+
3543
define spir_kernel void @foo(ptr addrspace(1) %_arg_out_ptr, ptr addrspace(3) %_arg_local_acc) {
3644
entry:
3745
%var = alloca %StructEvent

0 commit comments

Comments
 (0)