Skip to content

Commit 4911067

Browse files
[OpaquePointers] Adjust builtin variable tracking to support i8 geps. (#1892)
The existing logic for the replacement of builtin variables with calls to functions relies on relatively brittle tracking that is broken when opaque pointers is turned on, and will be even more thoroughly broken if/when typed geps are replaced with i8 geps or ptradd. This patch replaces that logic with a less brittle variant that is able to handle any sequence of bitcast, gep, or addrspacecast instructions between the global variable and the ultimate load instruction. It still will error out if the variable is used in too insane of a fashion (say, trying to load an i32 out of the i64, or a misaligned vector type).
1 parent 68855f6 commit 4911067

File tree

4 files changed

+97
-141
lines changed

4 files changed

+97
-141
lines changed

lib/SPIRV/SPIRVInternal.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -981,7 +981,7 @@ std::string decodeSPIRVTypeName(StringRef Name,
981981
SmallVectorImpl<std::string> &Strs);
982982

983983
// Copy attributes from function to call site.
984-
void setAttrByCalledFunc(CallInst *Call);
984+
CallInst *setAttrByCalledFunc(CallInst *Call);
985985
bool isSPIRVBuiltinVariable(GlobalVariable *GV, SPIRVBuiltinVariableKind *Kind);
986986
// Transform builtin variable from GlobalVariable to builtin call.
987987
// e.g.

lib/SPIRV/SPIRVUtil.cpp

Lines changed: 73 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -1907,14 +1907,15 @@ bool checkTypeForSPIRVExtendedInstLowering(IntrinsicInst *II, SPIRVModule *BM) {
19071907
return true;
19081908
}
19091909

1910-
void setAttrByCalledFunc(CallInst *Call) {
1910+
CallInst *setAttrByCalledFunc(CallInst *Call) {
19111911
Function *F = Call->getCalledFunction();
19121912
assert(F);
19131913
if (F->isIntrinsic()) {
1914-
return;
1914+
return Call;
19151915
}
19161916
Call->setCallingConv(F->getCallingConv());
19171917
Call->setAttributes(F->getAttributes());
1918+
return Call;
19181919
}
19191920

19201921
bool isSPIRVBuiltinVariable(GlobalVariable *GV,
@@ -1964,6 +1965,75 @@ bool isSPIRVBuiltinVariable(GlobalVariable *GV,
19641965
// %4 = call spir_func i64 @_Z13get_global_idj(i32 2) #1
19651966
// %5 = insertelement <3 x i64> %3, i64 %4, i32 2
19661967
// %6 = extractelement <3 x i64> %5, i32 0
1968+
1969+
/// Recursively look through the uses of a global variable, including casts or
1970+
/// gep offsets, to find all loads of the variable. Gep offsets that are non-0
1971+
/// are accumulated in the AccumulatedOffset parameter, which will eventually be
1972+
/// used to figure out which index of a variable is being used.
1973+
static void replaceUsesOfBuiltinVar(Value *V, const APInt &AccumulatedOffset,
1974+
Function *ReplacementFunc) {
1975+
const DataLayout &DL = ReplacementFunc->getParent()->getDataLayout();
1976+
SmallVector<Instruction *, 4> InstsToRemove;
1977+
for (User *U : V->users()) {
1978+
if (auto *Cast = dyn_cast<CastInst>(U)) {
1979+
replaceUsesOfBuiltinVar(Cast, AccumulatedOffset, ReplacementFunc);
1980+
InstsToRemove.push_back(Cast);
1981+
} else if (auto *GEP = dyn_cast<GetElementPtrInst>(U)) {
1982+
APInt NewOffset = AccumulatedOffset.sextOrTrunc(
1983+
DL.getIndexSizeInBits(GEP->getPointerAddressSpace()));
1984+
if (!GEP->accumulateConstantOffset(DL, NewOffset))
1985+
llvm_unreachable("Illegal GEP of a SPIR-V builtin variable");
1986+
replaceUsesOfBuiltinVar(GEP, NewOffset, ReplacementFunc);
1987+
InstsToRemove.push_back(GEP);
1988+
} else if (auto *Load = dyn_cast<LoadInst>(U)) {
1989+
// Figure out which index the accumulated offset corresponds to. If we
1990+
// have a weird offset (e.g., trying to load byte 7), bail out.
1991+
Type *ScalarTy = ReplacementFunc->getReturnType();
1992+
APInt Index;
1993+
uint64_t Remainder;
1994+
APInt::udivrem(AccumulatedOffset, ScalarTy->getScalarSizeInBits() / 8,
1995+
Index, Remainder);
1996+
if (Remainder != 0)
1997+
llvm_unreachable("Illegal GEP of a SPIR-V builtin variable");
1998+
1999+
IRBuilder<> Builder(Load);
2000+
Value *Replacement;
2001+
if (ReplacementFunc->getFunctionType()->getNumParams() == 0) {
2002+
if (Load->getType() != ScalarTy)
2003+
llvm_unreachable("Illegal use of a SPIR-V builtin variable");
2004+
Replacement =
2005+
setAttrByCalledFunc(Builder.CreateCall(ReplacementFunc, {}));
2006+
} else {
2007+
// The function has an index parameter.
2008+
if (auto *VecTy = dyn_cast<FixedVectorType>(Load->getType())) {
2009+
if (!Index.isZero())
2010+
llvm_unreachable("Illegal use of a SPIR-V builtin variable");
2011+
Replacement = UndefValue::get(VecTy);
2012+
for (unsigned I = 0; I < VecTy->getNumElements(); I++) {
2013+
Replacement = Builder.CreateInsertElement(
2014+
Replacement,
2015+
setAttrByCalledFunc(
2016+
Builder.CreateCall(ReplacementFunc, {Builder.getInt32(I)})),
2017+
Builder.getInt32(I));
2018+
}
2019+
} else if (Load->getType() == ScalarTy) {
2020+
Replacement = setAttrByCalledFunc(Builder.CreateCall(
2021+
ReplacementFunc, {Builder.getInt32(Index.getZExtValue())}));
2022+
} else {
2023+
llvm_unreachable("Illegal load type of a SPIR-V builtin variable");
2024+
}
2025+
}
2026+
Load->replaceAllUsesWith(Replacement);
2027+
InstsToRemove.push_back(Load);
2028+
} else {
2029+
llvm_unreachable("Illegal use of a SPIR-V builtin variable");
2030+
}
2031+
}
2032+
2033+
for (Instruction *I : InstsToRemove)
2034+
I->eraseFromParent();
2035+
}
2036+
19672037
bool lowerBuiltinVariableToCall(GlobalVariable *GV,
19682038
SPIRVBuiltinVariableKind Kind) {
19692039
// There might be dead constant users of GV (for example, SPIRVLowerConstExpr
@@ -1999,113 +2069,7 @@ bool lowerBuiltinVariableToCall(GlobalVariable *GV,
19992069
Func->setDoesNotAccessMemory();
20002070
}
20012071

2002-
// Collect instructions in these containers to remove them later.
2003-
std::vector<Instruction *> Loads;
2004-
std::vector<Instruction *> Casts;
2005-
std::vector<Instruction *> GEPs;
2006-
2007-
auto Replace = [&](std::vector<Value *> Arg, Instruction *I) {
2008-
auto *Call = CallInst::Create(Func, Arg, "", I);
2009-
Call->takeName(I);
2010-
setAttrByCalledFunc(Call);
2011-
SPIRVDBG(dbgs() << "[lowerBuiltinVariableToCall] " << *I << " -> " << *Call
2012-
<< '\n';)
2013-
I->replaceAllUsesWith(Call);
2014-
};
2015-
2016-
// If HasIndexArg is true, we create 3 built-in calls and insertelement to
2017-
// get 3-element vector filled with ids and replace uses of Load instruction
2018-
// with this vector.
2019-
// If HasIndexArg is false, the result of the Load instruction is the value
2020-
// which should be replaced with the Func.
2021-
// Returns true if Load was replaced, false otherwise.
2022-
auto ReplaceIfLoad = [&](User *I) {
2023-
auto *LD = dyn_cast<LoadInst>(I);
2024-
if (!LD)
2025-
return false;
2026-
std::vector<Value *> Vectors;
2027-
Loads.push_back(LD);
2028-
if (HasIndexArg) {
2029-
auto *VecTy = cast<FixedVectorType>(GVTy);
2030-
Value *EmptyVec = UndefValue::get(VecTy);
2031-
Vectors.push_back(EmptyVec);
2032-
const DebugLoc &DLoc = LD->getDebugLoc();
2033-
for (unsigned I = 0; I < VecTy->getNumElements(); ++I) {
2034-
auto *Idx = ConstantInt::get(Type::getInt32Ty(C), I);
2035-
auto *Call = CallInst::Create(Func, {Idx}, "", LD);
2036-
if (DLoc)
2037-
Call->setDebugLoc(DLoc);
2038-
setAttrByCalledFunc(Call);
2039-
auto *Insert = InsertElementInst::Create(Vectors.back(), Call, Idx);
2040-
if (DLoc)
2041-
Insert->setDebugLoc(DLoc);
2042-
Insert->insertAfter(Call);
2043-
Vectors.push_back(Insert);
2044-
}
2045-
2046-
Value *Ptr = LD->getPointerOperand();
2047-
2048-
if (isa<FixedVectorType>(LD->getType())) {
2049-
LD->replaceAllUsesWith(Vectors.back());
2050-
} else {
2051-
auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
2052-
assert(GEP && "Unexpected pattern!");
2053-
assert(GEP->getNumIndices() == 2 && "Unexpected pattern!");
2054-
Value *Idx = GEP->getOperand(2);
2055-
Value *Vec = Vectors.back();
2056-
auto *NewExtract = ExtractElementInst::Create(Vec, Idx);
2057-
NewExtract->insertAfter(cast<Instruction>(Vec));
2058-
LD->replaceAllUsesWith(NewExtract);
2059-
}
2060-
2061-
} else {
2062-
Replace({}, LD);
2063-
}
2064-
2065-
return true;
2066-
};
2067-
2068-
// Go over the GV users, find Load and ExtractElement instructions and
2069-
// replace them with the corresponding function call.
2070-
for (auto *UI : GV->users()) {
2071-
// There might or might not be an addrspacecast instruction.
2072-
if (auto *ASCast = dyn_cast<AddrSpaceCastInst>(UI)) {
2073-
Casts.push_back(ASCast);
2074-
for (auto *CastUser : ASCast->users()) {
2075-
if (ReplaceIfLoad(CastUser))
2076-
continue;
2077-
if (auto *GEP = dyn_cast<GetElementPtrInst>(CastUser)) {
2078-
GEPs.push_back(GEP);
2079-
for (auto *GEPUser : GEP->users()) {
2080-
if (!ReplaceIfLoad(GEPUser))
2081-
llvm_unreachable("Unexpected pattern!");
2082-
}
2083-
} else {
2084-
llvm_unreachable("Unexpected pattern!");
2085-
}
2086-
}
2087-
} else if (auto *GEP = dyn_cast<GetElementPtrInst>(UI)) {
2088-
GEPs.push_back(GEP);
2089-
for (auto *GEPUser : GEP->users()) {
2090-
if (!ReplaceIfLoad(GEPUser))
2091-
llvm_unreachable("Unexpected pattern!");
2092-
}
2093-
} else if (!ReplaceIfLoad(UI)) {
2094-
llvm_unreachable("Unexpected pattern!");
2095-
}
2096-
}
2097-
2098-
auto Erase = [](std::vector<Instruction *> &ToErase) {
2099-
for (Instruction *I : ToErase) {
2100-
assert(I->hasNUses(0));
2101-
I->eraseFromParent();
2102-
}
2103-
};
2104-
// Order of erasing is important.
2105-
Erase(Loads);
2106-
Erase(GEPs);
2107-
Erase(Casts);
2108-
2072+
replaceUsesOfBuiltinVar(GV, APInt(64, 0), Func);
21092073
return true;
21102074
}
21112075

test/builtin-vars-gep.ll

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,28 +14,32 @@ target triple = "spir64"
1414
define spir_func void @foo() {
1515
entry:
1616
%GroupID = alloca [3 x i64], align 8
17-
%0 = addrspacecast <3 x i64> addrspace(1)* @__spirv_BuiltInWorkgroupSize to <3 x i64> addrspace(4)*
18-
%1 = getelementptr <3 x i64>, <3 x i64> addrspace(4)* %0, i64 0, i64 0
17+
%0 = addrspacecast ptr addrspace(1) @__spirv_BuiltInWorkgroupSize to ptr addrspace(4)
18+
%1 = getelementptr <3 x i64>, ptr addrspace(4) %0, i64 0, i64 0
1919
; CHECK-LLVM: %[[GLocalSize0:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 0) #1
20-
; CHECK-LLVM: %[[Ins0:[0-9]+]] = insertelement <3 x i64> undef, i64 %[[GLocalSize0]], i32 0
21-
; CHECK-LLVM: %[[GLocalSize1:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 1) #1
22-
; CHECK-LLVM: %[[Ins1:[0-9]+]] = insertelement <3 x i64> %[[Ins0]], i64 %[[GLocalSize1]], i32 1
20+
%2 = addrspacecast ptr addrspace(1) @__spirv_BuiltInWorkgroupSize to ptr addrspace(4)
21+
%3 = getelementptr <3 x i64>, ptr addrspace(4) %2, i64 0, i64 2
22+
%4 = load i64, ptr addrspace(4) %1, align 32
23+
%5 = load i64, ptr addrspace(4) %3, align 8
2324
; CHECK-LLVM: %[[GLocalSize2:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 2) #1
24-
; CHECK-LLVM: %[[Ins2:[0-9]+]] = insertelement <3 x i64> %[[Ins1]], i64 %[[GLocalSize2]], i32 2
25-
; CHECK-LLVM: %[[Extract:[0-9]+]] = extractelement <3 x i64> %[[Ins2]], i64 0
26-
%2 = addrspacecast <3 x i64> addrspace(1)* @__spirv_BuiltInWorkgroupSize to <3 x i64> addrspace(4)*
27-
%3 = getelementptr <3 x i64>, <3 x i64> addrspace(4)* %2, i64 0, i64 2
28-
%4 = load i64, i64 addrspace(4)* %1, align 32
29-
%5 = load i64, i64 addrspace(4)* %3, align 8
30-
; CHECK-LLVM: %[[GLocalSize01:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 0) #1
31-
; CHECK-LLVM: %[[Ins01:[0-9]+]] = insertelement <3 x i64> undef, i64 %[[GLocalSize01]], i32 0
32-
; CHECK-LLVM: %[[GLocalSize11:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 1) #1
33-
; CHECK-LLVM: %[[Ins11:[0-9]+]] = insertelement <3 x i64> %[[Ins01]], i64 %[[GLocalSize11]], i32 1
34-
; CHECK-LLVM: %[[GLocalSize21:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 2) #1
35-
; CHECK-LLVM: %[[Ins21:[0-9]+]] = insertelement <3 x i64> %[[Ins11]], i64 %[[GLocalSize21]], i32 2
36-
; CHECK-LLVM: %[[Extract1:[0-9]+]] = extractelement <3 x i64> %[[Ins21]], i64 2
37-
; CHECK-LLVM: mul i64 %[[Extract]], %[[Extract1]]
25+
; CHECK-LLVM: mul i64 %[[GLocalSize0]], %[[GLocalSize2]]
3826
%mul = mul i64 %4, %5
3927
ret void
4028
}
4129

30+
; Function Attrs: alwaysinline convergent nounwind mustprogress
31+
define spir_func void @foo_i8gep() {
32+
entry:
33+
%GroupID = alloca [3 x i64], align 8
34+
%0 = addrspacecast ptr addrspace(1) @__spirv_BuiltInWorkgroupSize to ptr addrspace(4)
35+
%1 = getelementptr i8, ptr addrspace(4) %0, i64 0
36+
; CHECK-LLVM: %[[GLocalSize0:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 0) #1
37+
%2 = addrspacecast ptr addrspace(1) @__spirv_BuiltInWorkgroupSize to ptr addrspace(4)
38+
%3 = getelementptr i8, ptr addrspace(4) %2, i64 16
39+
%4 = load i64, ptr addrspace(4) %1, align 32
40+
%5 = load i64, ptr addrspace(4) %3, align 8
41+
; CHECK-LLVM: %[[GLocalSize2:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 2) #1
42+
; CHECK-LLVM: mul i64 %[[GLocalSize0]], %[[GLocalSize2]]
43+
%mul = mul i64 %4, %5
44+
ret void
45+
}

test/transcoding/builtin_vars_gep.ll

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,20 +23,8 @@ define spir_kernel void @f() {
2323
entry:
2424
%0 = load i64, i64 addrspace(1)* getelementptr (<3 x i64>, <3 x i64> addrspace(1)* @__spirv_BuiltInGlobalInvocationId, i64 0, i64 0), align 32
2525
; CHECK-OCL-IR: %[[#ID1:]] = call spir_func i64 @_Z13get_global_idj(i32 0)
26-
; CHECK-OCL-IR: %[[#VEC1:]] = insertelement <3 x i64> undef, i64 %[[#ID1]], i32 0
27-
; CHECK-OCL-IR: %[[#ID2:]] = call spir_func i64 @_Z13get_global_idj(i32 1)
28-
; CHECK-OCL-IR: %[[#VEC2:]] = insertelement <3 x i64> %[[#VEC1]], i64 %[[#ID2]], i32 1
29-
; CHECK-OCL-IR: %[[#ID3:]] = call spir_func i64 @_Z13get_global_idj(i32 2)
30-
; CHECK-OCL-IR: %[[#VEC3:]] = insertelement <3 x i64> %[[#VEC2]], i64 %[[#ID3]], i32 2
31-
; CHECK-OCL-IR: %[[#]] = extractelement <3 x i64> %[[#VEC3]], i64 0
3226

3327
; CHECK-SPV-IR: %[[#ID1:]] = call spir_func i64 @_Z33__spirv_BuiltInGlobalInvocationIdi(i32 0)
34-
; CHECK-SPV-IR: %[[#VEC1:]] = insertelement <3 x i64> undef, i64 %[[#ID1]], i32 0
35-
; CHECK-SPV-IR: %[[#ID2:]] = call spir_func i64 @_Z33__spirv_BuiltInGlobalInvocationIdi(i32 1)
36-
; CHECK-SPV-IR: %[[#VEC2:]] = insertelement <3 x i64> %[[#VEC1]], i64 %[[#ID2]], i32 1
37-
; CHECK-SPV-IR: %[[#ID3:]] = call spir_func i64 @_Z33__spirv_BuiltInGlobalInvocationIdi(i32 2)
38-
; CHECK-SPV-IR: %[[#VEC3:]] = insertelement <3 x i64> %[[#VEC2]], i64 %[[#ID3]], i32 2
39-
; CHECK-SPV-IR: %[[#]] = extractelement <3 x i64> %[[#VEC3]], i64 0
4028

4129
ret void
4230
}

0 commit comments

Comments
 (0)