Skip to content

Commit 246a4ef

Browse files
zuban32AlexeySachkov
authored andcommitted
Improve constant expressions lowering for function pointers.
Extend constexprs lowering support to lower constant vector of pure function pointers w/o any transformations inside.
1 parent 851a7f2 commit 246a4ef

File tree

4 files changed

+81
-18
lines changed

4 files changed

+81
-18
lines changed

llvm-spirv/lib/SPIRV/SPIRVLowerConstExpr.cpp

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,6 @@ bool SPIRVLowerConstExpr::runOnModule(Module &Module) {
113113

114114
void SPIRVLowerConstExpr::visit(Module *M) {
115115
for (auto &I : M->functions()) {
116-
std::map<ConstantExpr *, Instruction *> CMap;
117116
std::list<Instruction *> WorkList;
118117
for (auto &BI : I) {
119118
for (auto &II : BI) {
@@ -124,7 +123,10 @@ void SPIRVLowerConstExpr::visit(Module *M) {
124123
while (!WorkList.empty()) {
125124
auto II = WorkList.front();
126125

127-
auto LowerOp = [&II, &FBegin, &I](ConstantExpr *CE) {
126+
auto LowerOp = [&II, &FBegin, &I](Value *V) -> Value * {
127+
if (isa<Function>(V))
128+
return V;
129+
auto *CE = cast<ConstantExpr>(V);
128130
SPIRVDBG(dbgs() << "[lowerConstantExpressions] " << *CE;)
129131
auto ReplInst = CE->getAsInstruction();
130132
auto InsPoint = II->getParent() == &*FBegin ? II : &FBegin->back();
@@ -149,25 +151,30 @@ void SPIRVLowerConstExpr::visit(Module *M) {
149151
for (unsigned OI = 0, OE = II->getNumOperands(); OI != OE; ++OI) {
150152
auto Op = II->getOperand(OI);
151153
auto *Vec = dyn_cast<ConstantVector>(Op);
152-
if (Vec && std::all_of(Vec->op_begin(), Vec->op_end(),
153-
[](Value *V) { return isa<ConstantExpr>(V); })) {
154+
if (Vec && std::all_of(Vec->op_begin(), Vec->op_end(), [](Value *V) {
155+
return isa<ConstantExpr>(V) || isa<Function>(V);
156+
})) {
154157
// Expand a vector of constexprs and construct it back with series of
155158
// insertelement instructions
156-
std::list<Instruction *> ReplList;
157-
std::transform(
158-
Vec->op_begin(), Vec->op_end(), std::back_inserter(ReplList),
159-
[LowerOp](Value *V) { return LowerOp(cast<ConstantExpr>(V)); });
159+
std::list<Value *> OpList;
160+
std::transform(Vec->op_begin(), Vec->op_end(),
161+
std::back_inserter(OpList),
162+
[LowerOp](Value *V) { return LowerOp(V); });
160163
Value *Repl = nullptr;
161164
unsigned Idx = 0;
162-
for (auto V : ReplList)
165+
std::list<Instruction *> ReplList;
166+
for (auto V : OpList) {
167+
if (auto *Inst = dyn_cast<Instruction>(V))
168+
ReplList.push_back(Inst);
163169
Repl = InsertElementInst::Create(
164170
(Repl ? Repl : UndefValue::get(Vec->getType())), V,
165171
ConstantInt::get(Type::getInt32Ty(M->getContext()), Idx++), "",
166172
II);
173+
}
167174
II->replaceUsesOfWith(Op, Repl);
168175
WorkList.splice(WorkList.begin(), ReplList);
169176
} else if (auto CE = dyn_cast<ConstantExpr>(Op))
170-
WorkList.push_front(LowerOp(CE));
177+
WorkList.push_front(cast<Instruction>(LowerOp(CE)));
171178
}
172179
}
173180
}

llvm-spirv/lib/SPIRV/SPIRVWriter.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1302,13 +1302,22 @@ SPIRVValue *LLVMToSPIRV::transValueWithoutDecoration(Value *V,
13021302

13031303
if (auto Ins = dyn_cast<InsertElementInst>(V)) {
13041304
auto Index = Ins->getOperand(2);
1305-
if (auto Const = dyn_cast<ConstantInt>(Index))
1305+
if (auto Const = dyn_cast<ConstantInt>(Index)) {
1306+
SPIRVValue *InsVal = nullptr;
1307+
if (auto *F = dyn_cast<Function>(Ins->getOperand(1))) {
1308+
if (!BM->checkExtension(ExtensionID::SPV_INTEL_function_pointers,
1309+
SPIRVEC_FunctionPointers, toString(V)))
1310+
return nullptr;
1311+
InsVal = BM->addFunctionPointerINTELInst(
1312+
transType(F->getType()),
1313+
static_cast<SPIRVFunction *>(transValue(F, BB)), BB);
1314+
} else
1315+
InsVal = transValue(Ins->getOperand(1), BB);
13061316
return mapValue(V, BM->addCompositeInsertInst(
1307-
transValue(Ins->getOperand(1), BB),
1308-
transValue(Ins->getOperand(0), BB),
1317+
InsVal, transValue(Ins->getOperand(0), BB),
13091318
std::vector<SPIRVWord>(1, Const->getZExtValue()),
13101319
BB));
1311-
else
1320+
} else
13121321
return mapValue(
13131322
V, BM->addVectorInsertDynamicInst(transValue(Ins->getOperand(0), BB),
13141323
transValue(Ins->getOperand(1), BB),

llvm-spirv/test/constexpr_vector.ll

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
; RUN: llvm-as < %s | llvm-spirv -s | llvm-dis | FileCheck %s --check-prefix=CHECK-LLVM
22

33
; CHECK-LLVM: define dllexport void @vadd() {
4-
; CHECK-LLVM-NEXT: entry:
5-
; CHECK-LLVM-NEXT: %Funcs = alloca <16 x i8>, align 16
4+
; CHECK-LLVM: %Funcs = alloca <16 x i8>, align 16
65
; CHECK-LLVM-NEXT: %0 = ptrtoint i32 (i32)* @_Z2f1u2CMvb32_j to i64
76
; CHECK-LLVM-NEXT: %1 = bitcast i64 %0 to <8 x i8>
87
; CHECK-LLVM-NEXT: %2 = extractelement <8 x i8> %1, i32 0
@@ -40,8 +39,12 @@
4039
; CHECK-LLVM-NEXT: %34 = insertelement <16 x i8> %33, i8 %18, i32 14
4140
; CHECK-LLVM-NEXT: %35 = insertelement <16 x i8> %34, i8 %19, i32 15
4241
; CHECK-LLVM-NEXT: store <16 x i8> %35, <16 x i8>* %Funcs, align 16
43-
; CHECK-LLVM-NEXT: ret void
44-
; CHECK-LLVM-NEXT: }
42+
; CHECK-LLVM: %Funcs1 = alloca <2 x i64>, align 16
43+
; CHECK-LLVM-NEXT: %36 = ptrtoint i32 (i32)* @_Z2f1u2CMvb32_j to i64
44+
; CHECK-LLVM-NEXT: %37 = ptrtoint i32 (i32)* @_Z2f2u2CMvb32_j to i64
45+
; CHECK-LLVM-NEXT: %38 = insertelement <2 x i64> undef, i64 %36, i32 0
46+
; CHECK-LLVM-NEXT: %39 = insertelement <2 x i64> %38, i64 %37, i32 1
47+
; CHECK-LLVM-NEXT: store <2 x i64> %39, <2 x i64>* %Funcs1, align 16
4548

4649
; RUN: llvm-as < %s | llvm-spirv -spirv-text --spirv-ext=+SPV_INTEL_function_pointers | FileCheck %s --check-prefix=CHECK-SPIRV
4750

@@ -115,5 +118,7 @@ define dllexport void @vadd() {
115118
entry:
116119
%Funcs = alloca <16 x i8>, align 16
117120
store <16 x i8> <i8 extractelement (<8 x i8> bitcast (i64 ptrtoint (i32 (i32)* @_Z2f1u2CMvb32_j to i64) to <8 x i8>), i32 0), i8 extractelement (<8 x i8> bitcast (i64 ptrtoint (i32 (i32)* @_Z2f1u2CMvb32_j to i64) to <8 x i8>), i32 1), i8 extractelement (<8 x i8> bitcast (i64 ptrtoint (i32 (i32)* @_Z2f1u2CMvb32_j to i64) to <8 x i8>), i32 2), i8 extractelement (<8 x i8> bitcast (i64 ptrtoint (i32 (i32)* @_Z2f1u2CMvb32_j to i64) to <8 x i8>), i32 3), i8 extractelement (<8 x i8> bitcast (i64 ptrtoint (i32 (i32)* @_Z2f1u2CMvb32_j to i64) to <8 x i8>), i32 4), i8 extractelement (<8 x i8> bitcast (i64 ptrtoint (i32 (i32)* @_Z2f1u2CMvb32_j to i64) to <8 x i8>), i32 5), i8 extractelement (<8 x i8> bitcast (i64 ptrtoint (i32 (i32)* @_Z2f1u2CMvb32_j to i64) to <8 x i8>), i32 6), i8 extractelement (<8 x i8> bitcast (i64 ptrtoint (i32 (i32)* @_Z2f1u2CMvb32_j to i64) to <8 x i8>), i32 7), i8 extractelement (<8 x i8> bitcast (i64 ptrtoint (i32 (i32)* @_Z2f2u2CMvb32_j to i64) to <8 x i8>), i32 0), i8 extractelement (<8 x i8> bitcast (i64 ptrtoint (i32 (i32)* @_Z2f2u2CMvb32_j to i64) to <8 x i8>), i32 1), i8 extractelement (<8 x i8> bitcast (i64 ptrtoint (i32 (i32)* @_Z2f2u2CMvb32_j to i64) to <8 x i8>), i32 2), i8 extractelement (<8 x i8> bitcast (i64 ptrtoint (i32 (i32)* @_Z2f2u2CMvb32_j to i64) to <8 x i8>), i32 3), i8 extractelement (<8 x i8> bitcast (i64 ptrtoint (i32 (i32)* @_Z2f2u2CMvb32_j to i64) to <8 x i8>), i32 4), i8 extractelement (<8 x i8> bitcast (i64 ptrtoint (i32 (i32)* @_Z2f2u2CMvb32_j to i64) to <8 x i8>), i32 5), i8 extractelement (<8 x i8> bitcast (i64 ptrtoint (i32 (i32)* @_Z2f2u2CMvb32_j to i64) to <8 x i8>), i32 6), i8 extractelement (<8 x i8> bitcast (i64 ptrtoint (i32 (i32)* @_Z2f2u2CMvb32_j to i64) to <8 x i8>), i32 7)>, <16 x i8>* %Funcs, align 16
121+
%Funcs1 = alloca <2 x i64>, align 16
122+
store <2 x i64> <i64 ptrtoint (i32 (i32)* @_Z2f1u2CMvb32_j to i64), i64 ptrtoint (i32 (i32)* @_Z2f2u2CMvb32_j to i64)>, <2 x i64>* %Funcs1, align 16
118123
ret void
119124
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
; RUN: llvm-as < %s | llvm-spirv -spirv-text --spirv-ext=+SPV_INTEL_function_pointers | FileCheck %s --check-prefix=CHECK-SPIRV
2+
3+
; CHECK-SPIRV-DAG: 6 Name [[F1:[0-9+]]] "_Z2f1u2CMvb32_j"
4+
; CHECK-SPIRV-DAG: 6 Name [[F2:[0-9+]]] "_Z2f2u2CMvb32_j"
5+
; CHECK-SPIRV-DAG: 4 Name [[Funcs:[0-9]+]] "Funcs"
6+
7+
; CHECK-SPIRV: 4 TypeInt [[TypeInt32:[0-9]+]] 32 0
8+
; CHECK-SPIRV: 4 TypeFunction [[TypeFunc:[0-9]+]] [[TypeInt32]] [[TypeInt32]]
9+
; CHECK-SPIRV: 4 TypePointer [[TypePtr:[0-9]+]] {{[0-9]+}} [[TypeFunc]]
10+
; CHECK-SPIRV: 4 TypeVector [[TypeVec:[0-9]+]] [[TypePtr]] [[TypeInt32]]
11+
; CHECK-SPIRV: 3 Undef [[TypeVec]] [[TypeUndef:[0-9]+]]
12+
13+
; CHECK-SPIRV: 4 FunctionPointerINTEL [[TypePtr]] [[F1Ptr:[0-9]+]] [[F1]]
14+
; CHECK-SPIRV: 6 CompositeInsert [[TypeVec]] [[NewVec0:[0-9]+]] [[F1Ptr]] [[TypeUndef]] 0
15+
; CHECK-SPIRV: 4 FunctionPointerINTEL [[TypePtr]] [[F2Ptr:[0-9]+]] [[F2]]
16+
; CHECK-SPIRV: 6 CompositeInsert [[TypeVec]] [[NewVec1:[0-9]+]] [[F2Ptr]] [[NewVec0]] 1
17+
; CHECK-SPIRV: 5 Store [[Funcs]] [[NewVec1]] [[TypeInt32]] {{[0-9+]}}
18+
19+
20+
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
21+
target triple = "spir-unknown-unknown"
22+
23+
; Function Attrs: noinline norecurse nounwind readnone
24+
define internal i32 @_Z2f1u2CMvb32_j(i32 %x) {
25+
entry:
26+
ret i32 %x
27+
}
28+
; Function Attrs: noinline norecurse nounwind readnone
29+
define internal i32 @_Z2f2u2CMvb32_j(i32 %x) {
30+
entry:
31+
ret i32 %x
32+
}
33+
34+
; Function Attrs: noinline nounwind
35+
define dllexport void @vadd() {
36+
entry:
37+
%Funcs = alloca <2 x i32 (i32)*>, align 16
38+
%0 = insertelement <2 x i32 (i32)*> undef, i32 (i32)* @_Z2f1u2CMvb32_j, i32 0
39+
%1 = insertelement <2 x i32 (i32)*> %0, i32 (i32)* @_Z2f2u2CMvb32_j, i32 1
40+
store <2 x i32 (i32)*> %1, <2 x i32 (i32)*>* %Funcs, align 16
41+
ret void
42+
}

0 commit comments

Comments
 (0)