Skip to content

Commit 4926454

Browse files
authored
[ESIMD] ESIMDOptimizeVecArgCallConv: allow more IR patterns. (#6919)
Allow all-zero GEPs in optimized ptr param use-def chains. Signed-off-by: Konstantin S Bobrovsky <[email protected]>
1 parent 53d9c7b commit 4926454

File tree

4 files changed

+100
-6
lines changed

4 files changed

+100
-6
lines changed

llvm/include/llvm/SYCLLowerIR/ESIMD/ESIMDUtils.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,16 +68,34 @@ inline void assert_and_diag(bool Condition, StringRef Msg,
6868
/// Tells if this value is a bit cast or address space cast.
6969
bool isCast(const Value *V);
7070

71+
/// Tells if this value is a GEP instruction with all zero indices.
72+
bool isZeroGEP(const Value *V);
73+
7174
/// Climbs up the use-def chain of given value until a value which is not a
7275
/// bit cast or address space cast is met.
7376
const Value *stripCasts(const Value *V);
7477
Value *stripCasts(Value *V);
7578

79+
/// Climbs up the use-def chain of given value until a value is met which is
80+
/// neither of:
81+
/// - bit cast
82+
/// - address space cast
83+
/// - GEP instruction with all zero indices
84+
const Value *stripCastsAndZeroGEPs(const Value *V);
85+
Value *stripCastsAndZeroGEPs(Value *V);
86+
7687
/// Collects uses of given value "looking through" casts. I.e. if a use is a
7788
/// cast (chain), then uses of the result of the cast (chain) are collected.
7889
void collectUsesLookThroughCasts(const Value *V,
7990
SmallPtrSetImpl<const Use *> &Uses);
8091

92+
/// Collects uses of given pointer-typed value "looking through" casts and GEPs
93+
/// with all zero indices - those pointer transformation instructions which
94+
/// don't change pointed-to value. E.g. if a use is a cast (chain), then uses of
95+
/// the result of the cast (chain) are collected.
96+
void collectUsesLookThroughCastsAndZeroGEPs(const Value *V,
97+
SmallPtrSetImpl<const Use *> &Uses);
98+
8199
/// Unwraps a presumably simd* type to extract the native vector type encoded
82100
/// in it. Returns nullptr if failed to do so.
83101
Type *getVectorTyOrNull(StructType *STy);

llvm/lib/SYCLLowerIR/ESIMD/ESIMDOptimizeVecArgCallConv.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ getMemTypeIfSameAddressLoadsStores(SmallPtrSetImpl<const Use *> &Uses,
7676
if (Uses.size() == 0) {
7777
return nullptr;
7878
}
79-
Value *Addr = esimd::stripCasts((*Uses.begin())->get());
79+
Value *Addr = esimd::stripCastsAndZeroGEPs((*Uses.begin())->get());
8080

8181
for (const auto *UU : Uses) {
8282
const User *U = UU->getUser();
@@ -92,7 +92,7 @@ getMemTypeIfSameAddressLoadsStores(SmallPtrSetImpl<const Use *> &Uses,
9292
}
9393

9494
if (const auto *SI = dyn_cast<StoreInst>(U)) {
95-
if (esimd::stripCasts(SI->getPointerOperand()) != Addr) {
95+
if (esimd::stripCastsAndZeroGEPs(SI->getPointerOperand()) != Addr) {
9696
// the pointer escapes into memory
9797
return nullptr;
9898
}
@@ -167,7 +167,7 @@ Type *getPointedToTypeIfOptimizeable(const Argument &FormalParam) {
167167
// }
168168
{
169169
SmallPtrSet<const Use *, 4> Uses;
170-
esimd::collectUsesLookThroughCasts(&FormalParam, Uses);
170+
esimd::collectUsesLookThroughCastsAndZeroGEPs(&FormalParam, Uses);
171171
bool LoadMet = 0;
172172
bool StoreMet = 0;
173173
ContentT = getMemTypeIfSameAddressLoadsStores(Uses, LoadMet, StoreMet);
@@ -225,14 +225,14 @@ Type *getPointedToTypeIfOptimizeable(const Argument &FormalParam) {
225225
if (!Call || (Call->getCalledFunction() != F)) {
226226
return nullptr;
227227
}
228-
auto ArgNo = FormalParam.getArgNo();
229-
Value *ActualParam = esimd::stripCasts(Call->getArgOperand(ArgNo));
228+
Value *ActualParam = esimd::stripCastsAndZeroGEPs(
229+
Call->getArgOperand(FormalParam.getArgNo()));
230230

231231
if (!IsSret && !isa<AllocaInst>(ActualParam)) {
232232
return nullptr;
233233
}
234234
SmallPtrSet<const Use *, 4> Uses;
235-
esimd::collectUsesLookThroughCasts(ActualParam, Uses);
235+
esimd::collectUsesLookThroughCastsAndZeroGEPs(ActualParam, Uses);
236236
bool LoadMet = 0;
237237
bool StoreMet = 0;
238238

llvm/lib/SYCLLowerIR/ESIMD/ESIMDUtils.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,11 @@ bool isCast(const Value *V) {
8989
return (Opc == Instruction::BitCast) || (Opc == Instruction::AddrSpaceCast);
9090
}
9191

92+
/// Returns true only if \p V is a GetElementPtrInst and that instruction
/// reports all of its indices as zero; returns false for any other value.
/// NOTE(review): the check is against the instruction class, so GEP constant
/// expressions are presumably not matched - confirm this is intended.
bool isZeroGEP(const Value *V) {
93+
// Null when V is not a GetElementPtrInst; the && below guards against that.
const auto *GEPI = dyn_cast<GetElementPtrInst>(V);
94+
return GEPI && GEPI->hasAllZeroIndices();
95+
}
96+
9297
const Value *stripCasts(const Value *V) {
9398
if (!V->getType()->isPtrOrPtrVectorTy())
9499
return V;
@@ -110,6 +115,30 @@ Value *stripCasts(Value *V) {
110115
return const_cast<Value *>(stripCasts(const_cast<const Value *>(V)));
111116
}
112117

118+
/// Walks up the use-def chain starting at \p V, stepping to operand 0 for as
/// long as the current value is a cast (per isCast) or an all-zero-index GEP
/// instruction (per isZeroGEP), and returns the value where the walk stops.
/// A value whose type is not a pointer or a vector of pointers is returned
/// unchanged without any walking.
const Value *stripCastsAndZeroGEPs(const Value *V) {
119+
if (!V->getType()->isPtrOrPtrVectorTy())
120+
return V;
121+
// Even though we don't look through PHI nodes, we could be called on an
122+
// instruction in an unreachable block, which may be on a cycle.
123+
SmallPtrSet<const Value *, 4> Visited;
124+
// Seed the set with the starting value: a failed insert at the bottom of the
// loop is what ends the walk.
Visited.insert(V);
125+
126+
do {
127+
if (isCast(V)) {
128+
V = cast<Operator>(V)->getOperand(0);
129+
} else if (isZeroGEP(V)) {
130+
V = cast<GetElementPtrInst>(V)->getOperand(0);
131+
}
132+
// If neither branch fired, V is unchanged and the insert below fails.
assert(V->getType()->isPtrOrPtrVectorTy() && "Unexpected operand type!");
133+
// Stops either when V did not move (already in the set) or when the chain
// revisits a value, i.e. a cycle.
} while (Visited.insert(V).second);
134+
return V;
135+
}
136+
137+
/// Non-const overload: forwards to the const overload and casts the constness
/// of the result away again, so both overloads share one implementation.
Value *stripCastsAndZeroGEPs(Value *V) {
138+
return const_cast<Value *>(
139+
stripCastsAndZeroGEPs(const_cast<const Value *>(V)));
140+
}
141+
113142
void collectUsesLookThroughCasts(const Value *V,
114143
SmallPtrSetImpl<const Use *> &Uses) {
115144
for (const Use &U : V->uses()) {
@@ -123,6 +152,21 @@ void collectUsesLookThroughCasts(const Value *V,
123152
}
124153
}
125154

155+
/// Adds to \p Uses every use of \p V whose user is neither a cast (per isCast)
/// nor an all-zero-index GEP instruction (per isZeroGEP). When the user is one
/// of those, the use itself is not recorded; instead the uses of that user are
/// collected recursively by the same rule.
/// \p V must have pointer or vector-of-pointers type (asserted).
/// \p Uses is only inserted into, never cleared, so existing entries are kept.
void collectUsesLookThroughCastsAndZeroGEPs(
156+
const Value *V, SmallPtrSetImpl<const Use *> &Uses) {
157+
assert(V->getType()->isPtrOrPtrVectorTy() && "pointer type expected");
158+
159+
for (const Use &U : V->uses()) {
160+
// The value that consumes V through this use.
Value *VV = U.getUser();
161+
162+
if (isCast(VV) || isZeroGEP(VV)) {
163+
// Look through the cast / zero GEP: collect the uses of its result.
collectUsesLookThroughCastsAndZeroGEPs(VV, Uses);
164+
} else {
165+
Uses.insert(&U);
166+
}
167+
}
168+
}
169+
126170
Type *getVectorTyOrNull(StructType *STy) {
127171
Type *Res = nullptr;
128172
while (STy && (STy->getStructNumElements() == 1)) {

llvm/test/SYCLLowerIR/ESIMD/vec_arg_call_conv.ll

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,38 @@ entry:
255255
ret void
256256
}
257257

258+
;----- Test4: IR contains all-zero GEP instructions in parameter use-def chains
259+
; Based on Test2.
260+
define dso_local spir_func void @_Z23callee__sret__x_param_x1(ptr addrspace(4) noalias sret(%"class.sycl::_V1::ext::intel::esimd::simd.2") align 32 %agg.result, i32 noundef %i, ptr noundef %x, i32 noundef %j) local_unnamed_addr #3 !sycl_explicit_simd !8 !intel_reqd_sub_group_size !9 {
261+
; CHECK: define dso_local spir_func <8 x i32> @_Z23callee__sret__x_param_x1(i32 noundef %{{.*}}, <8 x i32> %{{.*}}, i32 noundef %{{.*}})
262+
entry:
263+
%x.ascast = addrspacecast ptr %x to ptr addrspace(4)
264+
%add = add nsw i32 %i, %j
265+
%splat.splatinsert.i.i.i = insertelement <8 x i32> poison, i32 %add, i64 0
266+
%splat.splat.i.i.i = shufflevector <8 x i32> %splat.splatinsert.i.i.i, <8 x i32> poison, <8 x i32> zeroinitializer
267+
; All-zero-index GEP sitting between the %x parameter (through its
; addrspacecast) and the load below - the pattern this test exercises.
%M_data.i.i.i = getelementptr inbounds %"class.sycl::_V1::ext::intel::esimd::detail::simd_obj_impl.3", ptr addrspace(4) %x.ascast, i64 0, i32 0
268+
%call.i.i.i1 = load <8 x i32>, ptr addrspace(4) %M_data.i.i.i, align 32
269+
%add.i.i.i.i.i = add <8 x i32> %call.i.i.i1, %splat.splat.i.i.i
270+
; The result is written through the sret pointer.
store <8 x i32> %add.i.i.i.i.i, ptr addrspace(4) %agg.result, align 32
271+
ret void
272+
}
273+
274+
;----- Test4 caller.
275+
; Function Attrs: convergent noinline norecurse
276+
define dso_local spir_func void @_Z21test__sret__x_param_x1(ptr addrspace(4) noalias sret(%"class.sycl::_V1::ext::intel::esimd::simd.2") align 32 %agg.result, ptr noundef %x) local_unnamed_addr #3 !sycl_explicit_simd !8 !intel_reqd_sub_group_size !9 {
277+
; CHECK: define dso_local spir_func <8 x i32> @_Z21test__sret__x_param_x1(<8 x i32> %{{.*}})
278+
entry:
279+
%agg.tmp = alloca %"class.sycl::_V1::ext::intel::esimd::simd.2", align 32
280+
%agg.tmp.ascast = addrspacecast ptr %agg.tmp to ptr addrspace(4)
281+
%x.ascast = addrspacecast ptr %x to ptr addrspace(4)
282+
; All-zero-index GEP on the caller's own %x parameter, feeding the load below.
%M_data.i.i.i = getelementptr inbounds %"class.sycl::_V1::ext::intel::esimd::detail::simd_obj_impl.3", ptr addrspace(4) %x.ascast, i64 0, i32 0
283+
%call.i.i.i1 = load <8 x i32>, ptr addrspace(4) %M_data.i.i.i, align 32
284+
; The loaded vector is copied into the %agg.tmp alloca (through its
; addrspacecast) before the call.
store <8 x i32> %call.i.i.i1, ptr addrspace(4) %agg.tmp.ascast, align 32
285+
; %agg.tmp is passed as the callee's %x argument (third operand).
call spir_func void @_Z23callee__sret__x_param_x1(ptr addrspace(4) sret(%"class.sycl::_V1::ext::intel::esimd::simd.2") align 32 %agg.result, i32 noundef 2, ptr noundef nonnull %agg.tmp, i32 noundef 1) #7
286+
; CHECK: %{{.*}} = call spir_func <8 x i32> @_Z23callee__sret__x_param_x1(i32 2, <8 x i32> %{{.*}}, i32 1)
287+
ret void
288+
}
289+
258290
attributes #0 = { convergent noinline norecurse "frame-pointer"="all" "min-legal-vector-width"="512" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "sycl-module-id"="../opaque_ptr.cpp" }
259291
attributes #1 = { alwaysinline convergent "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
260292
attributes #2 = { convergent noinline norecurse "frame-pointer"="all" "min-legal-vector-width"="12288" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "sycl-module-id"="../opaque_ptr.cpp" }

0 commit comments

Comments
 (0)