Skip to content

Commit 7e15df5

Browse files
jcranmer-intelagainull
authored andcommitted
Adjust builtin variable tracking to support i8 geps. (#1892)
The existing logic for the replacement of builtin variables with calls to functions relies on relatively brittle tracking that is broken when opaque pointers is turned on, and will be even more thoroughly broken if/when typed geps are replaced with i8 geps or ptradd. This patch replaces that logic with a less brittle variant that is able to handle any sequence of bitcast, gep, or addrspacecast instructions between the global variable and the ultimate load instruction. It still will error out if the variable is used in too insane of a fashion (say, trying to load an i32 out of the i64, or a misaligned vector type). Original commit: KhronosGroup/SPIRV-LLVM-Translator@4911067
1 parent 5493588 commit 7e15df5

File tree

4 files changed

+97
-141
lines changed

4 files changed

+97
-141
lines changed

llvm-spirv/lib/SPIRV/SPIRVInternal.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -980,7 +980,7 @@ std::string decodeSPIRVTypeName(StringRef Name,
980980
SmallVectorImpl<std::string> &Strs);
981981

982982
// Copy attributes from function to call site.
983-
void setAttrByCalledFunc(CallInst *Call);
983+
CallInst *setAttrByCalledFunc(CallInst *Call);
984984
bool isSPIRVBuiltinVariable(GlobalVariable *GV, SPIRVBuiltinVariableKind *Kind);
985985
// Transform builtin variable from GlobalVariable to builtin call.
986986
// e.g.

llvm-spirv/lib/SPIRV/SPIRVUtil.cpp

Lines changed: 73 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -1909,14 +1909,15 @@ bool checkTypeForSPIRVExtendedInstLowering(IntrinsicInst *II, SPIRVModule *BM) {
19091909
return true;
19101910
}
19111911

1912-
void setAttrByCalledFunc(CallInst *Call) {
1912+
CallInst *setAttrByCalledFunc(CallInst *Call) {
19131913
Function *F = Call->getCalledFunction();
19141914
assert(F);
19151915
if (F->isIntrinsic()) {
1916-
return;
1916+
return Call;
19171917
}
19181918
Call->setCallingConv(F->getCallingConv());
19191919
Call->setAttributes(F->getAttributes());
1920+
return Call;
19201921
}
19211922

19221923
bool isSPIRVBuiltinVariable(GlobalVariable *GV,
@@ -1966,6 +1967,75 @@ bool isSPIRVBuiltinVariable(GlobalVariable *GV,
19661967
// %4 = call spir_func i64 @_Z13get_global_idj(i32 2) #1
19671968
// %5 = insertelement <3 x i64> %3, i64 %4, i32 2
19681969
// %6 = extractelement <3 x i64> %5, i32 0
1970+
1971+
/// Recursively look through the uses of a global variable, including casts or
1972+
/// gep offsets, to find all loads of the variable. Gep offsets that are non-0
1973+
/// are accumulated in the AccumulatedOffset parameter, which will eventually be
1974+
/// used to figure out which index of a variable is being used.
1975+
static void replaceUsesOfBuiltinVar(Value *V, const APInt &AccumulatedOffset,
1976+
Function *ReplacementFunc) {
1977+
const DataLayout &DL = ReplacementFunc->getParent()->getDataLayout();
1978+
SmallVector<Instruction *, 4> InstsToRemove;
1979+
for (User *U : V->users()) {
1980+
if (auto *Cast = dyn_cast<CastInst>(U)) {
1981+
replaceUsesOfBuiltinVar(Cast, AccumulatedOffset, ReplacementFunc);
1982+
InstsToRemove.push_back(Cast);
1983+
} else if (auto *GEP = dyn_cast<GetElementPtrInst>(U)) {
1984+
APInt NewOffset = AccumulatedOffset.sextOrTrunc(
1985+
DL.getIndexSizeInBits(GEP->getPointerAddressSpace()));
1986+
if (!GEP->accumulateConstantOffset(DL, NewOffset))
1987+
llvm_unreachable("Illegal GEP of a SPIR-V builtin variable");
1988+
replaceUsesOfBuiltinVar(GEP, NewOffset, ReplacementFunc);
1989+
InstsToRemove.push_back(GEP);
1990+
} else if (auto *Load = dyn_cast<LoadInst>(U)) {
1991+
// Figure out which index the accumulated offset corresponds to. If we
1992+
// have a weird offset (e.g., trying to load byte 7), bail out.
1993+
Type *ScalarTy = ReplacementFunc->getReturnType();
1994+
APInt Index;
1995+
uint64_t Remainder;
1996+
APInt::udivrem(AccumulatedOffset, ScalarTy->getScalarSizeInBits() / 8,
1997+
Index, Remainder);
1998+
if (Remainder != 0)
1999+
llvm_unreachable("Illegal GEP of a SPIR-V builtin variable");
2000+
2001+
IRBuilder<> Builder(Load);
2002+
Value *Replacement;
2003+
if (ReplacementFunc->getFunctionType()->getNumParams() == 0) {
2004+
if (Load->getType() != ScalarTy)
2005+
llvm_unreachable("Illegal use of a SPIR-V builtin variable");
2006+
Replacement =
2007+
setAttrByCalledFunc(Builder.CreateCall(ReplacementFunc, {}));
2008+
} else {
2009+
// The function has an index parameter.
2010+
if (auto *VecTy = dyn_cast<FixedVectorType>(Load->getType())) {
2011+
if (!Index.isZero())
2012+
llvm_unreachable("Illegal use of a SPIR-V builtin variable");
2013+
Replacement = UndefValue::get(VecTy);
2014+
for (unsigned I = 0; I < VecTy->getNumElements(); I++) {
2015+
Replacement = Builder.CreateInsertElement(
2016+
Replacement,
2017+
setAttrByCalledFunc(
2018+
Builder.CreateCall(ReplacementFunc, {Builder.getInt32(I)})),
2019+
Builder.getInt32(I));
2020+
}
2021+
} else if (Load->getType() == ScalarTy) {
2022+
Replacement = setAttrByCalledFunc(Builder.CreateCall(
2023+
ReplacementFunc, {Builder.getInt32(Index.getZExtValue())}));
2024+
} else {
2025+
llvm_unreachable("Illegal load type of a SPIR-V builtin variable");
2026+
}
2027+
}
2028+
Load->replaceAllUsesWith(Replacement);
2029+
InstsToRemove.push_back(Load);
2030+
} else {
2031+
llvm_unreachable("Illegal use of a SPIR-V builtin variable");
2032+
}
2033+
}
2034+
2035+
for (Instruction *I : InstsToRemove)
2036+
I->eraseFromParent();
2037+
}
2038+
19692039
bool lowerBuiltinVariableToCall(GlobalVariable *GV,
19702040
SPIRVBuiltinVariableKind Kind) {
19712041
// There might be dead constant users of GV (for example, SPIRVLowerConstExpr
@@ -2001,113 +2071,7 @@ bool lowerBuiltinVariableToCall(GlobalVariable *GV,
20012071
Func->setDoesNotAccessMemory();
20022072
}
20032073

2004-
// Collect instructions in these containers to remove them later.
2005-
std::vector<Instruction *> Loads;
2006-
std::vector<Instruction *> Casts;
2007-
std::vector<Instruction *> GEPs;
2008-
2009-
auto Replace = [&](std::vector<Value *> Arg, Instruction *I) {
2010-
auto *Call = CallInst::Create(Func, Arg, "", I);
2011-
Call->takeName(I);
2012-
setAttrByCalledFunc(Call);
2013-
SPIRVDBG(dbgs() << "[lowerBuiltinVariableToCall] " << *I << " -> " << *Call
2014-
<< '\n';)
2015-
I->replaceAllUsesWith(Call);
2016-
};
2017-
2018-
// If HasIndexArg is true, we create 3 built-in calls and insertelement to
2019-
// get 3-element vector filled with ids and replace uses of Load instruction
2020-
// with this vector.
2021-
// If HasIndexArg is false, the result of the Load instruction is the value
2022-
// which should be replaced with the Func.
2023-
// Returns true if Load was replaced, false otherwise.
2024-
auto ReplaceIfLoad = [&](User *I) {
2025-
auto *LD = dyn_cast<LoadInst>(I);
2026-
if (!LD)
2027-
return false;
2028-
std::vector<Value *> Vectors;
2029-
Loads.push_back(LD);
2030-
if (HasIndexArg) {
2031-
auto *VecTy = cast<FixedVectorType>(GVTy);
2032-
Value *EmptyVec = UndefValue::get(VecTy);
2033-
Vectors.push_back(EmptyVec);
2034-
const DebugLoc &DLoc = LD->getDebugLoc();
2035-
for (unsigned I = 0; I < VecTy->getNumElements(); ++I) {
2036-
auto *Idx = ConstantInt::get(Type::getInt32Ty(C), I);
2037-
auto *Call = CallInst::Create(Func, {Idx}, "", LD);
2038-
if (DLoc)
2039-
Call->setDebugLoc(DLoc);
2040-
setAttrByCalledFunc(Call);
2041-
auto *Insert = InsertElementInst::Create(Vectors.back(), Call, Idx);
2042-
if (DLoc)
2043-
Insert->setDebugLoc(DLoc);
2044-
Insert->insertAfter(Call);
2045-
Vectors.push_back(Insert);
2046-
}
2047-
2048-
Value *Ptr = LD->getPointerOperand();
2049-
2050-
if (isa<FixedVectorType>(LD->getType())) {
2051-
LD->replaceAllUsesWith(Vectors.back());
2052-
} else {
2053-
auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
2054-
assert(GEP && "Unexpected pattern!");
2055-
assert(GEP->getNumIndices() == 2 && "Unexpected pattern!");
2056-
Value *Idx = GEP->getOperand(2);
2057-
Value *Vec = Vectors.back();
2058-
auto *NewExtract = ExtractElementInst::Create(Vec, Idx);
2059-
NewExtract->insertAfter(cast<Instruction>(Vec));
2060-
LD->replaceAllUsesWith(NewExtract);
2061-
}
2062-
2063-
} else {
2064-
Replace({}, LD);
2065-
}
2066-
2067-
return true;
2068-
};
2069-
2070-
// Go over the GV users, find Load and ExtractElement instructions and
2071-
// replace them with the corresponding function call.
2072-
for (auto *UI : GV->users()) {
2073-
// There might or might not be an addrspacecast instruction.
2074-
if (auto *ASCast = dyn_cast<AddrSpaceCastInst>(UI)) {
2075-
Casts.push_back(ASCast);
2076-
for (auto *CastUser : ASCast->users()) {
2077-
if (ReplaceIfLoad(CastUser))
2078-
continue;
2079-
if (auto *GEP = dyn_cast<GetElementPtrInst>(CastUser)) {
2080-
GEPs.push_back(GEP);
2081-
for (auto *GEPUser : GEP->users()) {
2082-
if (!ReplaceIfLoad(GEPUser))
2083-
llvm_unreachable("Unexpected pattern!");
2084-
}
2085-
} else {
2086-
llvm_unreachable("Unexpected pattern!");
2087-
}
2088-
}
2089-
} else if (auto *GEP = dyn_cast<GetElementPtrInst>(UI)) {
2090-
GEPs.push_back(GEP);
2091-
for (auto *GEPUser : GEP->users()) {
2092-
if (!ReplaceIfLoad(GEPUser))
2093-
llvm_unreachable("Unexpected pattern!");
2094-
}
2095-
} else if (!ReplaceIfLoad(UI)) {
2096-
llvm_unreachable("Unexpected pattern!");
2097-
}
2098-
}
2099-
2100-
auto Erase = [](std::vector<Instruction *> &ToErase) {
2101-
for (Instruction *I : ToErase) {
2102-
assert(I->hasNUses(0));
2103-
I->eraseFromParent();
2104-
}
2105-
};
2106-
// Order of erasing is important.
2107-
Erase(Loads);
2108-
Erase(GEPs);
2109-
Erase(Casts);
2110-
2074+
replaceUsesOfBuiltinVar(GV, APInt(64, 0), Func);
21112075
return true;
21122076
}
21132077

llvm-spirv/test/builtin-vars-gep.ll

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,28 +14,32 @@ target triple = "spir64"
1414
define spir_func void @foo() {
1515
entry:
1616
%GroupID = alloca [3 x i64], align 8
17-
%0 = addrspacecast <3 x i64> addrspace(1)* @__spirv_BuiltInWorkgroupSize to <3 x i64> addrspace(4)*
18-
%1 = getelementptr <3 x i64>, <3 x i64> addrspace(4)* %0, i64 0, i64 0
17+
%0 = addrspacecast ptr addrspace(1) @__spirv_BuiltInWorkgroupSize to ptr addrspace(4)
18+
%1 = getelementptr <3 x i64>, ptr addrspace(4) %0, i64 0, i64 0
1919
; CHECK-LLVM: %[[GLocalSize0:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 0) #1
20-
; CHECK-LLVM: %[[Ins0:[0-9]+]] = insertelement <3 x i64> undef, i64 %[[GLocalSize0]], i32 0
21-
; CHECK-LLVM: %[[GLocalSize1:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 1) #1
22-
; CHECK-LLVM: %[[Ins1:[0-9]+]] = insertelement <3 x i64> %[[Ins0]], i64 %[[GLocalSize1]], i32 1
20+
%2 = addrspacecast ptr addrspace(1) @__spirv_BuiltInWorkgroupSize to ptr addrspace(4)
21+
%3 = getelementptr <3 x i64>, ptr addrspace(4) %2, i64 0, i64 2
22+
%4 = load i64, ptr addrspace(4) %1, align 32
23+
%5 = load i64, ptr addrspace(4) %3, align 8
2324
; CHECK-LLVM: %[[GLocalSize2:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 2) #1
24-
; CHECK-LLVM: %[[Ins2:[0-9]+]] = insertelement <3 x i64> %[[Ins1]], i64 %[[GLocalSize2]], i32 2
25-
; CHECK-LLVM: %[[Extract:[0-9]+]] = extractelement <3 x i64> %[[Ins2]], i64 0
26-
%2 = addrspacecast <3 x i64> addrspace(1)* @__spirv_BuiltInWorkgroupSize to <3 x i64> addrspace(4)*
27-
%3 = getelementptr <3 x i64>, <3 x i64> addrspace(4)* %2, i64 0, i64 2
28-
%4 = load i64, i64 addrspace(4)* %1, align 32
29-
%5 = load i64, i64 addrspace(4)* %3, align 8
30-
; CHECK-LLVM: %[[GLocalSize01:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 0) #1
31-
; CHECK-LLVM: %[[Ins01:[0-9]+]] = insertelement <3 x i64> undef, i64 %[[GLocalSize01]], i32 0
32-
; CHECK-LLVM: %[[GLocalSize11:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 1) #1
33-
; CHECK-LLVM: %[[Ins11:[0-9]+]] = insertelement <3 x i64> %[[Ins01]], i64 %[[GLocalSize11]], i32 1
34-
; CHECK-LLVM: %[[GLocalSize21:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 2) #1
35-
; CHECK-LLVM: %[[Ins21:[0-9]+]] = insertelement <3 x i64> %[[Ins11]], i64 %[[GLocalSize21]], i32 2
36-
; CHECK-LLVM: %[[Extract1:[0-9]+]] = extractelement <3 x i64> %[[Ins21]], i64 2
37-
; CHECK-LLVM: mul i64 %[[Extract]], %[[Extract1]]
25+
; CHECK-LLVM: mul i64 %[[GLocalSize0]], %[[GLocalSize2]]
3826
%mul = mul i64 %4, %5
3927
ret void
4028
}
4129

30+
; Function Attrs: alwaysinline convergent nounwind mustprogress
31+
define spir_func void @foo_i8gep() {
32+
entry:
33+
%GroupID = alloca [3 x i64], align 8
34+
%0 = addrspacecast ptr addrspace(1) @__spirv_BuiltInWorkgroupSize to ptr addrspace(4)
35+
%1 = getelementptr i8, ptr addrspace(4) %0, i64 0
36+
; CHECK-LLVM: %[[GLocalSize0:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 0) #1
37+
%2 = addrspacecast ptr addrspace(1) @__spirv_BuiltInWorkgroupSize to ptr addrspace(4)
38+
%3 = getelementptr i8, ptr addrspace(4) %2, i64 16
39+
%4 = load i64, ptr addrspace(4) %1, align 32
40+
%5 = load i64, ptr addrspace(4) %3, align 8
41+
; CHECK-LLVM: %[[GLocalSize2:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 2) #1
42+
; CHECK-LLVM: mul i64 %[[GLocalSize0]], %[[GLocalSize2]]
43+
%mul = mul i64 %4, %5
44+
ret void
45+
}

llvm-spirv/test/transcoding/builtin_vars_gep.ll

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,20 +23,8 @@ define spir_kernel void @f() {
2323
entry:
2424
%0 = load i64, i64 addrspace(1)* getelementptr (<3 x i64>, <3 x i64> addrspace(1)* @__spirv_BuiltInGlobalInvocationId, i64 0, i64 0), align 32
2525
; CHECK-OCL-IR: %[[#ID1:]] = call spir_func i64 @_Z13get_global_idj(i32 0)
26-
; CHECK-OCL-IR: %[[#VEC1:]] = insertelement <3 x i64> undef, i64 %[[#ID1]], i32 0
27-
; CHECK-OCL-IR: %[[#ID2:]] = call spir_func i64 @_Z13get_global_idj(i32 1)
28-
; CHECK-OCL-IR: %[[#VEC2:]] = insertelement <3 x i64> %[[#VEC1]], i64 %[[#ID2]], i32 1
29-
; CHECK-OCL-IR: %[[#ID3:]] = call spir_func i64 @_Z13get_global_idj(i32 2)
30-
; CHECK-OCL-IR: %[[#VEC3:]] = insertelement <3 x i64> %[[#VEC2]], i64 %[[#ID3]], i32 2
31-
; CHECK-OCL-IR: %[[#]] = extractelement <3 x i64> %[[#VEC3]], i32 0
3226

3327
; CHECK-SPV-IR: %[[#ID1:]] = call spir_func i64 @_Z33__spirv_BuiltInGlobalInvocationIdi(i32 0)
34-
; CHECK-SPV-IR: %[[#VEC1:]] = insertelement <3 x i64> undef, i64 %[[#ID1]], i32 0
35-
; CHECK-SPV-IR: %[[#ID2:]] = call spir_func i64 @_Z33__spirv_BuiltInGlobalInvocationIdi(i32 1)
36-
; CHECK-SPV-IR: %[[#VEC2:]] = insertelement <3 x i64> %[[#VEC1]], i64 %[[#ID2]], i32 1
37-
; CHECK-SPV-IR: %[[#ID3:]] = call spir_func i64 @_Z33__spirv_BuiltInGlobalInvocationIdi(i32 2)
38-
; CHECK-SPV-IR: %[[#VEC3:]] = insertelement <3 x i64> %[[#VEC2]], i64 %[[#ID3]], i32 2
39-
; CHECK-SPV-IR: %[[#]] = extractelement <3 x i64> %[[#VEC3]], i32 0
4028

4129
ret void
4230
}

0 commit comments

Comments
 (0)