Skip to content

Commit fc71302

Browse files
authored
Add support of new IR pattern to transOCLBuiltinFromVariable (#1091)
* Add support of new IR pattern to transOCLBuiltinFromVariable Some LLVM optimizations started generation of ascast -> gep -> load sequence, this patch adds support for it.
1 parent f8a2e49 commit fc71302

File tree

2 files changed

+94
-12
lines changed

2 files changed

+94
-12
lines changed

lib/SPIRV/SPIRVReader.cpp

Lines changed: 53 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,20 @@ Value *SPIRVToLLVM::mapFunction(SPIRVFunction *BF, Function *F) {
290290
// %5 = insertelement <3 x i64> %3, i64 %4, i32 2
291291
// %c = extractelement <3 x i64> %5, i32 idx
292292
// %d = extractelement <3 x i64> %5, i32 idx
293+
//
294+
// Replace the following pattern:
295+
// %0 = addrspacecast <3 x i64> addrspace(1)* @__spirv_BuiltInWorkgroupSize to
296+
// <3 x i64> addrspace(4)*
297+
// %1 = getelementptr <3 x i64>, <3 x i64> addrspace(4)* %0, i64 0, i64 0
298+
// %2 = load i64, i64 addrspace(4)* %1, align 32
299+
// With:
300+
// %0 = call spir_func i64 @_Z13get_global_idj(i32 0) #1
301+
// %1 = insertelement <3 x i64> undef, i64 %0, i32 0
302+
// %2 = call spir_func i64 @_Z13get_global_idj(i32 1) #1
303+
// %3 = insertelement <3 x i64> %1, i64 %2, i32 1
304+
// %4 = call spir_func i64 @_Z13get_global_idj(i32 2) #1
305+
// %5 = insertelement <3 x i64> %3, i64 %4, i32 2
306+
// %6 = extractelement <3 x i64> %5, i32 0
293307
bool SPIRVToLLVM::transOCLBuiltinFromVariable(GlobalVariable *GV,
294308
SPIRVBuiltinVariableKind Kind) {
295309
std::string FuncName;
@@ -300,7 +314,8 @@ bool SPIRVToLLVM::transOCLBuiltinFromVariable(GlobalVariable *GV,
300314
} else {
301315
FuncName = std::string(GV->getName());
302316
}
303-
Type *ReturnTy = GV->getType()->getPointerElementType();
317+
Type *GVTy = GV->getType()->getPointerElementType();
318+
Type *ReturnTy = GVTy;
304319
// Some SPIR-V builtin variables are translated to a function with an index
305320
// argument.
306321
bool HasIndexArg =
@@ -324,9 +339,9 @@ bool SPIRVToLLVM::transOCLBuiltinFromVariable(GlobalVariable *GV,
324339
}
325340

326341
// Collect instructions in these containers to remove them later.
327-
std::vector<Instruction *> Extracts;
328342
std::vector<Instruction *> Loads;
329343
std::vector<Instruction *> Casts;
344+
std::vector<Instruction *> GEPs;
330345

331346
auto Replace = [&](std::vector<Value *> Arg, Instruction *I) {
332347
auto Call = CallInst::Create(Func, Arg, "", I);
@@ -342,12 +357,15 @@ bool SPIRVToLLVM::transOCLBuiltinFromVariable(GlobalVariable *GV,
342357
// with this vector.
343358
// If HasIndexArg is false, the result of the Load instruction is the value
344359
// which should be replaced with the Func.
345-
auto FindAndReplace = [&](LoadInst *LD) {
360+
// Returns true if Load was replaced, false otherwise.
361+
auto ReplaceIfLoad = [&](User *I) {
362+
auto *LD = dyn_cast<LoadInst>(I);
363+
if (!LD)
364+
return false;
346365
std::vector<Value *> Vectors;
347366
Loads.push_back(LD);
348367
if (HasIndexArg) {
349-
auto *VecTy = cast<FixedVectorType>(
350-
LD->getPointerOperandType()->getPointerElementType());
368+
auto *VecTy = cast<FixedVectorType>(GVTy);
351369
Value *EmptyVec = UndefValue::get(VecTy);
352370
Vectors.push_back(EmptyVec);
353371
const DebugLoc &DLoc = LD->getDebugLoc();
@@ -363,10 +381,27 @@ bool SPIRVToLLVM::transOCLBuiltinFromVariable(GlobalVariable *GV,
363381
Insert->insertAfter(Call);
364382
Vectors.push_back(Insert);
365383
}
366-
LD->replaceAllUsesWith(Vectors.back());
384+
385+
Value *Ptr = LD->getPointerOperand();
386+
387+
if (isa<FixedVectorType>(Ptr->getType()->getPointerElementType())) {
388+
LD->replaceAllUsesWith(Vectors.back());
389+
} else {
390+
auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
391+
assert(GEP && "Unexpected pattern!");
392+
assert(GEP->getNumIndices() == 2 && "Unexpected pattern!");
393+
Value *Idx = GEP->getOperand(2);
394+
Value *Vec = Vectors.back();
395+
auto *NewExtract = ExtractElementInst::Create(Vec, Idx);
396+
NewExtract->insertAfter(cast<Instruction>(Vec));
397+
LD->replaceAllUsesWith(NewExtract);
398+
}
399+
367400
} else {
368401
Replace({}, LD);
369402
}
403+
404+
return true;
370405
};
371406

372407
// Go over the GV users, find Load and ExtractElement instructions and
@@ -376,13 +411,19 @@ bool SPIRVToLLVM::transOCLBuiltinFromVariable(GlobalVariable *GV,
376411
if (auto *ASCast = dyn_cast<AddrSpaceCastInst>(UI)) {
377412
Casts.push_back(ASCast);
378413
for (auto *CastUser : ASCast->users()) {
379-
if (auto *LD = dyn_cast<LoadInst>(CastUser)) {
380-
FindAndReplace(LD);
414+
if (ReplaceIfLoad(CastUser))
415+
continue;
416+
if (auto *GEP = dyn_cast<GetElementPtrInst>(CastUser)) {
417+
GEPs.push_back(GEP);
418+
for (auto *GEPUser : GEP->users()) {
419+
if (!ReplaceIfLoad(GEPUser))
420+
llvm_unreachable("Unexpected pattern!");
421+
}
422+
} else {
423+
llvm_unreachable("Unexpected pattern!");
381424
}
382425
}
383-
} else if (auto *LD = dyn_cast<LoadInst>(UI)) {
384-
FindAndReplace(LD);
385-
} else {
426+
} else if (!ReplaceIfLoad(UI)) {
386427
llvm_unreachable("Unexpected pattern!");
387428
}
388429
}
@@ -394,8 +435,8 @@ bool SPIRVToLLVM::transOCLBuiltinFromVariable(GlobalVariable *GV,
394435
}
395436
};
396437
// Order of erasing is important.
397-
Erase(Extracts);
398438
Erase(Loads);
439+
Erase(GEPs);
399440
Erase(Casts);
400441

401442
return true;

test/builtin-vars-gep.ll

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
; RUN: llvm-as < %s -o %t.bc
2+
; RUN: llvm-spirv %t.bc -o %t.spv
3+
; RUN: spirv-val %t.spv
4+
; RUN: llvm-spirv %t.spv -r -o %t.rev.bc
5+
; RUN: llvm-dis %t.rev.bc -o - | FileCheck %s --check-prefix=CHECK-LLVM
6+
7+
source_filename = "builtin-vars-gep.ll"
8+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
9+
target triple = "spir64"
10+
11+
@__spirv_BuiltInWorkgroupSize = external addrspace(1) constant <3 x i64>, align 32
12+
13+
; Function Attrs: alwaysinline convergent nounwind mustprogress
14+
define spir_func void @foo() {
15+
entry:
16+
%GroupID = alloca [3 x i64], align 8
17+
%0 = addrspacecast <3 x i64> addrspace(1)* @__spirv_BuiltInWorkgroupSize to <3 x i64> addrspace(4)*
18+
%1 = getelementptr <3 x i64>, <3 x i64> addrspace(4)* %0, i64 0, i64 0
19+
; CHECK-LLVM: %[[GLocalSize0:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 0) #1
20+
; CHECK-LLVM: %[[Ins0:[0-9]+]] = insertelement <3 x i64> undef, i64 %[[GLocalSize0]], i32 0
21+
; CHECK-LLVM: %[[GLocalSize1:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 1) #1
22+
; CHECK-LLVM: %[[Ins1:[0-9]+]] = insertelement <3 x i64> %[[Ins0]], i64 %[[GLocalSize1]], i32 1
23+
; CHECK-LLVM: %[[GLocalSize2:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 2) #1
24+
; CHECK-LLVM: %[[Ins2:[0-9]+]] = insertelement <3 x i64> %[[Ins1]], i64 %[[GLocalSize2]], i32 2
25+
; CHECK-LLVM: %[[Extract:[0-9]+]] = extractelement <3 x i64> %[[Ins2]], i64 0
26+
%2 = addrspacecast <3 x i64> addrspace(1)* @__spirv_BuiltInWorkgroupSize to <3 x i64> addrspace(4)*
27+
%3 = getelementptr <3 x i64>, <3 x i64> addrspace(4)* %2, i64 0, i64 2
28+
%4 = load i64, i64 addrspace(4)* %1, align 32
29+
%5 = load i64, i64 addrspace(4)* %3, align 8
30+
; CHECK-LLVM: %[[GLocalSize01:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 0) #1
31+
; CHECK-LLVM: %[[Ins01:[0-9]+]] = insertelement <3 x i64> undef, i64 %[[GLocalSize01]], i32 0
32+
; CHECK-LLVM: %[[GLocalSize11:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 1) #1
33+
; CHECK-LLVM: %[[Ins11:[0-9]+]] = insertelement <3 x i64> %[[Ins01]], i64 %[[GLocalSize11]], i32 1
34+
; CHECK-LLVM: %[[GLocalSize21:[0-9]+]] = call spir_func i64 @_Z14get_local_sizej(i32 2) #1
35+
; CHECK-LLVM: %[[Ins21:[0-9]+]] = insertelement <3 x i64> %[[Ins11]], i64 %[[GLocalSize21]], i32 2
36+
; CHECK-LLVM: %[[Extract1:[0-9]+]] = extractelement <3 x i64> %[[Ins21]], i64 2
37+
; CHECK-LLVM: mul i64 %[[Extract]], %[[Extract1]]
38+
%mul = mul i64 %4, %5
39+
ret void
40+
}
41+

0 commit comments

Comments
 (0)