Skip to content

Commit 4a9c78e

Browse files
Prepare SPIRVWriter for type conversion without opaque pointers. (#1499)
This changeset is, by itself, not yet enough to get most SPIR-V files to be emitted when the input module is in opaque pointer mode. However, this does remove all of the calls to `getPointerElementType` that SPIRVWriter makes (directly or indirectly), except for the ones that directly correspond to translating a pointer type. A later changeset will add a type scavenger that will be used to find the pointee type of a pointer. All calls to `getPointerElementType` that remain after this one will be instead shifted to query the type scavenger instead. To facilitate this change, several methods are added to avoid querying pointer element types, and they have been added in several places where their need is known. The most basic of basic kernels, those that do not use pointer types (other than declaring global values and functions) will work in opaque pointer mode with this changeset.
1 parent 3f5e65d commit 4a9c78e

File tree

10 files changed

+406
-328
lines changed

10 files changed

+406
-328
lines changed

lib/SPIRV/OCLToSPIRV.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -976,13 +976,18 @@ void OCLToSPIRVBase::visitCallReadImageWithSampler(CallInst *CI,
976976
StringRef MangledName) {
977977
assert(MangledName.find(kMangledName::Sampler) != StringRef::npos);
978978
assert(CI->getCalledFunction() && "Unexpected indirect call");
979-
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
979+
Function *Func = CI->getCalledFunction();
980+
AttributeList Attrs = Func->getAttributes();
980981
bool IsRetScalar = !CI->getType()->isVectorTy();
982+
SmallVector<StructType *, 3> ArgStructTys;
983+
getParameterTypes(CI, ArgStructTys);
981984
mutateCallInstSPIRV(
982985
M, CI,
983986
[=](CallInst *, std::vector<Value *> &Args, Type *&Ret) {
984987
auto *ImageTy =
985-
OCLTypeToSPIRVPtr->getAdaptedType(Args[0])->getPointerElementType();
988+
OCLTypeToSPIRVPtr->getAdaptedArgumentType(Func, 0).second;
989+
if (!ImageTy)
990+
ImageTy = ArgStructTys[0];
986991
ImageTy = adaptSPIRVImageType(M, ImageTy);
987992
auto SampledImgTy = getSPIRVTypeByChangeBaseTypeName(
988993
M, ImageTy, kSPIRVTypeName::Image, kSPIRVTypeName::SampledImg);
@@ -1696,8 +1701,12 @@ void OCLToSPIRVBase::visitSubgroupAVCBuiltinCallWithSampler(
16961701
if (!isOCLImageStructType(ParamTys[I]))
16971702
continue;
16981703

1699-
auto *ImageTy = OCLTypeToSPIRVPtr->getAdaptedType(Args[I])
1700-
->getPointerElementType();
1704+
auto *ImageTy =
1705+
OCLTypeToSPIRVPtr
1706+
->getAdaptedArgumentType(CI->getCalledFunction(), I)
1707+
.second;
1708+
if (!ImageTy)
1709+
ImageTy = ParamTys[I];
17011710
ImageTy = adaptSPIRVImageType(M, ImageTy);
17021711
auto *SampledImgTy = getSPIRVTypeByChangeBaseTypeName(
17031712
M, ImageTy, kSPIRVTypeName::Image, kSPIRVTypeName::VmeImageINTEL);

lib/SPIRV/OCLTypeToSPIRV.cpp

Lines changed: 46 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ bool OCLTypeToSPIRVLegacy::runOnModule(Module &M) {
7070
return runOCLTypeToSPIRV(M);
7171
}
7272

73-
OCLTypeToSPIRVBase OCLTypeToSPIRVPass::run(llvm::Module &M,
74-
llvm::ModuleAnalysisManager &MAM) {
73+
OCLTypeToSPIRVBase &OCLTypeToSPIRVPass::run(llvm::Module &M,
74+
llvm::ModuleAnalysisManager &MAM) {
7575
runOCLTypeToSPIRV(M);
7676
return *this;
7777
}
@@ -104,11 +104,12 @@ bool OCLTypeToSPIRVBase::runOCLTypeToSPIRV(Module &Module) {
104104
return false;
105105
}
106106

107-
void OCLTypeToSPIRVBase::addAdaptedType(Value *V, Type *T) {
107+
void OCLTypeToSPIRVBase::addAdaptedType(Value *V, Type *Ty,
108+
unsigned AddrSpace) {
108109
LLVM_DEBUG(dbgs() << "[add adapted type] ";
109110
V->printAsOperand(dbgs(), true, M);
110-
dbgs() << " => " << *T << '\n');
111-
AdaptedTy[V] = T;
111+
dbgs() << " => " << *Ty << '\n');
112+
AdaptedTy[V] = {Ty, AddrSpace};
112113
}
113114

114115
void OCLTypeToSPIRVBase::addWork(Function *F) {
@@ -117,34 +118,6 @@ void OCLTypeToSPIRVBase::addWork(Function *F) {
117118
WorkSet.insert(F);
118119
}
119120

120-
/// Find index of \param V as argument of function call \param CI.
121-
static unsigned getArgIndex(CallInst *CI, Value *V) {
122-
for (unsigned AI = 0, AE = CI->arg_size(); AI != AE; ++AI) {
123-
if (CI->getArgOperand(AI) == V)
124-
return AI;
125-
}
126-
llvm_unreachable("Not argument of function call");
127-
return ~0U;
128-
}
129-
130-
/// Find index of \param V as argument of function call \param CI.
131-
static unsigned getArgIndex(Function *F, Value *V) {
132-
auto A = F->arg_begin(), E = F->arg_end();
133-
for (unsigned I = 0; A != E; ++I, ++A) {
134-
if (&(*A) == V)
135-
return I;
136-
}
137-
llvm_unreachable("Not argument of function");
138-
return ~0U;
139-
}
140-
141-
/// Get i-th argument of a function.
142-
static Argument *getArg(Function *F, unsigned I) {
143-
auto AI = F->arg_begin();
144-
std::advance(AI, I);
145-
return &(*AI);
146-
}
147-
148121
/// Create a new function type if \param F has arguments in AdaptedTy, and
149122
/// propagates the adapted arguments to functions called by \param F.
150123
void OCLTypeToSPIRVBase::adaptFunction(Function *F) {
@@ -158,15 +131,17 @@ void OCLTypeToSPIRVBase::adaptFunction(Function *F) {
158131
auto Loc = AdaptedTy.find(&I);
159132
auto Found = (Loc != AdaptedTy.end());
160133
Changed |= Found;
161-
ArgTys.push_back(Found ? Loc->second : I.getType());
134+
ArgTys.push_back(Found ? Loc->second.first : I.getType());
162135

163136
if (Found) {
164-
for (auto U : I.users()) {
165-
if (auto CI = dyn_cast<CallInst>(U)) {
166-
auto ArgIndex = getArgIndex(CI, &I);
137+
auto *Ty = Loc->second.first;
138+
unsigned AddrSpace = Loc->second.second;
139+
for (auto &U : I.uses()) {
140+
if (auto *CI = dyn_cast<CallInst>(U.getUser())) {
141+
auto ArgIndex = CI->getArgOperandNo(&U);
167142
auto CF = CI->getCalledFunction();
168143
if (AdaptedTy.count(CF) == 0) {
169-
addAdaptedType(getArg(CF, ArgIndex), Loc->second);
144+
addAdaptedType(CF->getArg(ArgIndex), Ty, AddrSpace);
170145
addWork(CF);
171146
}
172147
}
@@ -179,7 +154,7 @@ void OCLTypeToSPIRVBase::adaptFunction(Function *F) {
179154

180155
auto FT = F->getFunctionType();
181156
FT = FunctionType::get(FT->getReturnType(), ArgTys, FT->isVarArg());
182-
addAdaptedType(F, FT);
157+
addAdaptedType(F, FT, 0);
183158
}
184159

185160
// Handle functions with sampler arguments that don't get called by
@@ -204,15 +179,10 @@ void OCLTypeToSPIRVBase::adaptArgumentsBySamplerUse(Module &M) {
204179
AdaptedTy.count(SamplerArg) != 0) // Already traced this, move on.
205180
continue;
206181

207-
if (SamplerArg->getType()->isPointerTy() &&
208-
isSPIRVStructType(SamplerArg->getType()->getPointerElementType(),
209-
kSPIRVTypeName::Sampler))
210-
return;
211-
212-
addAdaptedType(SamplerArg, getSamplerType(&M));
182+
addAdaptedType(SamplerArg, getSamplerStructType(&M), SPIRAS_Constant);
213183
auto Caller = cast<Argument>(SamplerArg)->getParent();
214184
addWork(Caller);
215-
TraceArg(Caller, getArgIndex(Caller, SamplerArg));
185+
TraceArg(Caller, Idx);
216186
}
217187
};
218188

@@ -235,20 +205,28 @@ void OCLTypeToSPIRVBase::adaptFunctionArguments(Function *F) {
235205
if (TypeMD)
236206
return;
237207
bool Changed = false;
238-
auto FT = F->getFunctionType();
239-
auto PI = FT->param_begin();
240208
auto Arg = F->arg_begin();
241-
for (unsigned I = 0; I < F->arg_size(); ++I, ++PI, ++Arg) {
242-
auto NewTy = *PI;
243-
if (isPointerToOpaqueStructType(NewTy)) {
244-
auto STName = NewTy->getPointerElementType()->getStructName();
209+
SmallVector<StructType *, 4> ParamTys;
210+
getParameterTypes(F, ParamTys);
211+
212+
// If we couldn't get any information from demangling, there is nothing that
213+
// can be done.
214+
if (ParamTys.empty())
215+
return;
216+
217+
for (unsigned I = 0; I < F->arg_size(); ++I, ++Arg) {
218+
StructType *NewTy = ParamTys[I];
219+
if (NewTy && NewTy->isOpaque()) {
220+
auto STName = NewTy->getStructName();
245221
if (!hasAccessQualifiedName(STName))
246222
continue;
247223
if (STName.startswith(kSPR2TypeName::ImagePrefix)) {
248224
auto Ty = STName.str();
249225
auto AccStr = getAccessQualifierFullName(Ty);
250-
addAdaptedType(&*Arg, getOrCreateOpaquePtrType(
251-
M, mapOCLTypeNameToSPIRV(Ty, AccStr)));
226+
addAdaptedType(
227+
&*Arg,
228+
getOrCreateOpaqueStructType(M, mapOCLTypeNameToSPIRV(Ty, AccStr)),
229+
SPIRAS_Global);
252230
Changed = true;
253231
}
254232
}
@@ -269,16 +247,18 @@ void OCLTypeToSPIRVBase::adaptArgumentsByMetadata(Function *F) {
269247
for (unsigned I = 0, E = TypeMD->getNumOperands(); I != E; ++I, ++Arg) {
270248
auto OCLTyStr = getMDOperandAsString(TypeMD, I);
271249
if (OCLTyStr == OCL_TYPE_NAME_SAMPLER_T) {
272-
addAdaptedType(&(*Arg), getSamplerType(M));
250+
addAdaptedType(&(*Arg), getSamplerStructType(M), SPIRAS_Constant);
273251
Changed = true;
274252
} else if (OCLTyStr.startswith("image") && OCLTyStr.endswith("_t")) {
275253
auto Ty = (Twine("opencl.") + OCLTyStr).str();
276254
if (StructType::getTypeByName(F->getContext(), Ty)) {
277255
auto AccMD = F->getMetadata(SPIR_MD_KERNEL_ARG_ACCESS_QUAL);
278256
assert(AccMD && "Invalid access qualifier metadata");
279257
auto AccStr = getMDOperandAsString(AccMD, I);
280-
addAdaptedType(&(*Arg), getOrCreateOpaquePtrType(
281-
M, mapOCLTypeNameToSPIRV(Ty, AccStr)));
258+
addAdaptedType(
259+
&(*Arg),
260+
getOrCreateOpaqueStructType(M, mapOCLTypeNameToSPIRV(Ty, AccStr)),
261+
SPIRAS_Global);
282262
Changed = true;
283263
}
284264
}
@@ -315,14 +295,15 @@ void OCLTypeToSPIRVBase::adaptArgumentsByMetadata(Function *F) {
315295
// opencl data type x and access qualifier y, and use opencl.image_x.y to
316296
// represent image_x type with access qualifier y.
317297
//
318-
Type *OCLTypeToSPIRVBase::getAdaptedType(Value *V) {
319-
auto Loc = AdaptedTy.find(V);
320-
if (Loc != AdaptedTy.end())
321-
return Loc->second;
322-
323-
if (auto F = dyn_cast<Function>(V))
324-
return F->getFunctionType();
325-
return V->getType();
298+
std::pair<Type *, Type *>
299+
OCLTypeToSPIRVBase::getAdaptedArgumentType(Function *F, unsigned ArgNo) {
300+
Value *Arg = F->getArg(ArgNo);
301+
auto Loc = AdaptedTy.find(Arg);
302+
if (Loc == AdaptedTy.end())
303+
return {nullptr, nullptr};
304+
Type *PointeeTy = Loc->second.first;
305+
Type *PointerTy = PointerType::get(PointeeTy, Loc->second.second);
306+
return {PointerTy, PointeeTy};
326307
}
327308

328309
} // namespace SPIRV

lib/SPIRV/OCLTypeToSPIRV.h

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
#include "llvm/IR/Function.h"
4747
#include "llvm/IR/LLVMContext.h"
4848
#include "llvm/IR/PassManager.h"
49+
#include "llvm/IR/ValueMap.h"
4950
#include "llvm/Pass.h"
5051

5152
#include <map>
@@ -58,23 +59,26 @@ class OCLTypeToSPIRVBase {
5859
OCLTypeToSPIRVBase();
5960

6061
bool runOCLTypeToSPIRV(llvm::Module &M);
61-
/// \return Adapted type based on kernel argument metadata. If \p V is
62-
/// a function, returns function type.
63-
/// E.g. for a function with argument of read only opencl.image_2d_t* type
64-
/// returns a function with argument of type opencl.image2d_t.read_only*.
65-
llvm::Type *getAdaptedType(llvm::Value *V);
62+
63+
/// Returns the adapted type of the corresponding argument for a function.
64+
/// The first value of the returned pair is the LLVM type of the argument.
65+
/// The second value of the returned pair is the pointer element type of the
66+
/// argument, if the type is a pointer.
67+
std::pair<llvm::Type *, llvm::Type *>
68+
getAdaptedArgumentType(llvm::Function *F, unsigned ArgNo);
6669

6770
private:
6871
llvm::Module *M;
6972
llvm::LLVMContext *Ctx;
70-
std::map<llvm::Value *, llvm::Type *> AdaptedTy; // Adapted types for values
71-
std::set<llvm::Function *> WorkSet; // Functions to be adapted
73+
// Map of argument/Function -> {pointee type, address space}
74+
llvm::ValueMap<llvm::Value *, std::pair<llvm::Type *, unsigned>> AdaptedTy;
75+
std::set<llvm::Function *> WorkSet; // Functions to be adapted
7276

7377
void adaptFunctionArguments(llvm::Function *F);
7478
void adaptArgumentsByMetadata(llvm::Function *F);
7579
void adaptArgumentsBySamplerUse(llvm::Module &M);
7680
void adaptFunction(llvm::Function *F);
77-
void addAdaptedType(llvm::Value *V, llvm::Type *T);
81+
void addAdaptedType(llvm::Value *V, llvm::Type *PointeeTy, unsigned AS);
7882
void addWork(llvm::Function *F);
7983
};
8084

@@ -92,7 +96,7 @@ class OCLTypeToSPIRVPass : public OCLTypeToSPIRVBase,
9296
public:
9397
using Result = OCLTypeToSPIRVBase;
9498
static llvm::AnalysisKey Key;
95-
OCLTypeToSPIRVBase run(llvm::Module &F, llvm::ModuleAnalysisManager &MAM);
99+
OCLTypeToSPIRVBase &run(llvm::Module &F, llvm::ModuleAnalysisManager &MAM);
96100
};
97101

98102
} // namespace SPIRV

lib/SPIRV/OCLUtil.cpp

Lines changed: 22 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1326,47 +1326,30 @@ Instruction *mutateCallInstOCL(
13261326
TakeFuncName);
13271327
}
13281328

1329-
static std::pair<StringRef, StringRef>
1330-
getSrcAndDstElememntTypeName(BitCastInst *BIC) {
1331-
if (!BIC)
1332-
return std::pair<StringRef, StringRef>("", "");
1333-
1334-
Type *SrcTy = BIC->getSrcTy();
1335-
Type *DstTy = BIC->getDestTy();
1336-
if (SrcTy->isPointerTy())
1337-
SrcTy = SrcTy->getPointerElementType();
1338-
if (DstTy->isPointerTy())
1339-
DstTy = DstTy->getPointerElementType();
1340-
auto SrcST = dyn_cast<StructType>(SrcTy);
1341-
auto DstST = dyn_cast<StructType>(DstTy);
1342-
if (!DstST || !DstST->hasName() || !SrcST || !SrcST->hasName())
1343-
return std::pair<StringRef, StringRef>("", "");
1344-
1345-
return std::make_pair(SrcST->getName(), DstST->getName());
1329+
static StringRef getStructName(Type *Ty) {
1330+
if (auto *STy = dyn_cast<StructType>(Ty))
1331+
return STy->isLiteral() ? "" : Ty->getStructName();
1332+
return "";
13461333
}
13471334

1348-
bool isSamplerInitializer(Instruction *Inst) {
1349-
BitCastInst *BIC = dyn_cast<BitCastInst>(Inst);
1350-
auto Names = getSrcAndDstElememntTypeName(BIC);
1351-
if (Names.second == getSPIRVTypeName(kSPIRVTypeName::Sampler) &&
1352-
Names.first == getSPIRVTypeName(kSPIRVTypeName::ConstantSampler))
1353-
return true;
1354-
1355-
return false;
1356-
}
1357-
1358-
bool isPipeStorageInitializer(Instruction *Inst) {
1359-
BitCastInst *BIC = dyn_cast<BitCastInst>(Inst);
1360-
auto Names = getSrcAndDstElememntTypeName(BIC);
1361-
if (Names.second == getSPIRVTypeName(kSPIRVTypeName::PipeStorage) &&
1362-
Names.first == getSPIRVTypeName(kSPIRVTypeName::ConstantPipeStorage))
1363-
return true;
1364-
1365-
return false;
1366-
}
1367-
1368-
bool isSpecialTypeInitializer(Instruction *Inst) {
1369-
return isSamplerInitializer(Inst) || isPipeStorageInitializer(Inst);
1335+
Value *unwrapSpecialTypeInitializer(Value *V) {
1336+
if (auto *BC = dyn_cast<BitCastOperator>(V)) {
1337+
Type *DestTy = BC->getDestTy();
1338+
Type *SrcTy = BC->getSrcTy();
1339+
if (SrcTy->isPointerTy() && !SrcTy->isOpaquePointerTy()) {
1340+
StringRef SrcName =
1341+
getStructName(SrcTy->getNonOpaquePointerElementType());
1342+
StringRef DestName =
1343+
getStructName(DestTy->getNonOpaquePointerElementType());
1344+
if (DestName == getSPIRVTypeName(kSPIRVTypeName::PipeStorage) &&
1345+
SrcName == getSPIRVTypeName(kSPIRVTypeName::ConstantPipeStorage))
1346+
return BC->getOperand(0);
1347+
if (DestName == getSPIRVTypeName(kSPIRVTypeName::Sampler) &&
1348+
SrcName == getSPIRVTypeName(kSPIRVTypeName::ConstantSampler))
1349+
return BC->getOperand(0);
1350+
}
1351+
}
1352+
return nullptr;
13701353
}
13711354

13721355
bool isSamplerStructTy(StructType *STy) {

lib/SPIRV/OCLUtil.h

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -495,15 +495,10 @@ Instruction *mutateCallInstOCL(
495495
std::function<Instruction *(CallInst *)> RetMutate,
496496
AttributeList *Attrs = nullptr, bool TakeFuncName = false);
497497

498-
/// Check if instruction is bitcast from spirv.ConstantSampler to spirv.Sampler
499-
bool isSamplerInitializer(Instruction *Inst);
500-
501-
/// Check if instruction is bitcast from spirv.ConstantPipeStorage
502-
/// to spirv.PipeStorage
503-
bool isPipeStorageInitializer(Instruction *Inst);
504-
505-
/// Check (isSamplerInitializer || isPipeStorageInitializer)
506-
bool isSpecialTypeInitializer(Instruction *Inst);
498+
/// If the value is a special type initializer (something that bitcasts from
499+
/// spirv.ConstantSampler to spirv.Sampler or likewise for PipeStorage), get the
500+
/// original type initializer, unwrap the bitcast. Otherwise, return nullptr.
501+
Value *unwrapSpecialTypeInitializer(Value *V);
507502

508503
bool isPipeOrAddressSpaceCastBI(const StringRef MangledName);
509504
bool isEnqueueKernelBI(const StringRef MangledName);

0 commit comments

Comments
 (0)