Skip to content

[ESIMD] ESIMDOptimizeVecArgCallConv: allow more IR patterns. #6919

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions llvm/include/llvm/SYCLLowerIR/ESIMD/ESIMDUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,16 +68,34 @@ inline void assert_and_diag(bool Condition, StringRef Msg,
/// Tells if this value is a bit cast or address space cast.
bool isCast(const Value *V);

/// Tells if this value is a GEP instruction with all zero indices.
bool isZeroGEP(const Value *V);

/// Climbs up the use-def chain of given value until a value which is not a
/// bit cast or address space cast is met.
const Value *stripCasts(const Value *V);
Value *stripCasts(Value *V);

/// Climbs up the use-def chain of given value until a value is met which is
/// neither of:
/// - bit cast
/// - address space cast
/// - GEP instruction with all zero indices
/// A value which is not of pointer (or vector-of-pointers) type is returned
/// as is.
const Value *stripCastsAndZeroGEPs(const Value *V);
Value *stripCastsAndZeroGEPs(Value *V);

/// Collects uses of given value "looking through" casts. I.e. if a use is a
/// cast (chain), then uses of the result of the cast (chain) are collected.
void collectUsesLookThroughCasts(const Value *V,
SmallPtrSetImpl<const Use *> &Uses);

/// Collects uses of given pointer-typed value "looking through" casts and GEPs
/// with all zero indices - those pointer transformation instructions which
/// don't change pointed-to value. E.g. if a use is a cast or an all-zero GEP
/// (or a chain of those), then the uses of the result of that instruction
/// (chain) are collected instead of the use itself.
void collectUsesLookThroughCastsAndZeroGEPs(const Value *V,
SmallPtrSetImpl<const Use *> &Uses);

/// Unwraps a presumably simd* type to extract the native vector type encoded
/// in it. Returns nullptr if failed to do so.
Type *getVectorTyOrNull(StructType *STy);
Expand Down
12 changes: 6 additions & 6 deletions llvm/lib/SYCLLowerIR/ESIMD/ESIMDOptimizeVecArgCallConv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ getMemTypeIfSameAddressLoadsStores(SmallPtrSetImpl<const Use *> &Uses,
if (Uses.size() == 0) {
return nullptr;
}
Value *Addr = esimd::stripCasts((*Uses.begin())->get());
Value *Addr = esimd::stripCastsAndZeroGEPs((*Uses.begin())->get());

for (const auto *UU : Uses) {
const User *U = UU->getUser();
Expand All @@ -92,7 +92,7 @@ getMemTypeIfSameAddressLoadsStores(SmallPtrSetImpl<const Use *> &Uses,
}

if (const auto *SI = dyn_cast<StoreInst>(U)) {
if (esimd::stripCasts(SI->getPointerOperand()) != Addr) {
if (esimd::stripCastsAndZeroGEPs(SI->getPointerOperand()) != Addr) {
// the pointer escapes into memory
return nullptr;
}
Expand Down Expand Up @@ -167,7 +167,7 @@ Type *getPointedToTypeIfOptimizeable(const Argument &FormalParam) {
// }
{
SmallPtrSet<const Use *, 4> Uses;
esimd::collectUsesLookThroughCasts(&FormalParam, Uses);
esimd::collectUsesLookThroughCastsAndZeroGEPs(&FormalParam, Uses);
bool LoadMet = 0;
bool StoreMet = 0;
ContentT = getMemTypeIfSameAddressLoadsStores(Uses, LoadMet, StoreMet);
Expand Down Expand Up @@ -225,14 +225,14 @@ Type *getPointedToTypeIfOptimizeable(const Argument &FormalParam) {
if (!Call || (Call->getCalledFunction() != F)) {
return nullptr;
}
auto ArgNo = FormalParam.getArgNo();
Value *ActualParam = esimd::stripCasts(Call->getArgOperand(ArgNo));
Value *ActualParam = esimd::stripCastsAndZeroGEPs(
Call->getArgOperand(FormalParam.getArgNo()));

if (!IsSret && !isa<AllocaInst>(ActualParam)) {
return nullptr;
}
SmallPtrSet<const Use *, 4> Uses;
esimd::collectUsesLookThroughCasts(ActualParam, Uses);
esimd::collectUsesLookThroughCastsAndZeroGEPs(ActualParam, Uses);
bool LoadMet = 0;
bool StoreMet = 0;

Expand Down
44 changes: 44 additions & 0 deletions llvm/lib/SYCLLowerIR/ESIMD/ESIMDUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ bool isCast(const Value *V) {
return (Opc == Instruction::BitCast) || (Opc == Instruction::AddrSpaceCast);
}

/// Tells whether \p V is a getelementptr instruction all of whose indices are
/// zero.
bool isZeroGEP(const Value *V) {
  if (const auto *GEP = dyn_cast<GetElementPtrInst>(V))
    return GEP->hasAllZeroIndices();
  return false;
}

const Value *stripCasts(const Value *V) {
if (!V->getType()->isPtrOrPtrVectorTy())
return V;
Expand All @@ -110,6 +115,30 @@ Value *stripCasts(Value *V) {
return const_cast<Value *>(stripCasts(const_cast<const Value *>(V)));
}

/// Walks up the use-def chain starting at \p V, stepping through casts (as
/// recognized by isCast) and all-zero-index GEP instructions, and returns the
/// value at which the walk stops. Non-pointer values are returned unchanged.
const Value *stripCastsAndZeroGEPs(const Value *V) {
  if (!V->getType()->isPtrOrPtrVectorTy())
    return V;

  // PHI nodes are not looked through, but the instruction may still live in
  // an unreachable block that forms a cycle, so remember what was seen.
  SmallPtrSet<const Value *, 4> Seen;
  Seen.insert(V);

  for (;;) {
    const Value *Next = V;
    if (isCast(V))
      Next = cast<Operator>(V)->getOperand(0);
    else if (isZeroGEP(V))
      Next = cast<GetElementPtrInst>(V)->getOperand(0);

    assert(Next->getType()->isPtrOrPtrVectorTy() &&
           "Unexpected operand type!");
    V = Next;

    // Stop when no step was taken (V is already in the set) or when a cycle
    // brought us back to a value seen earlier.
    if (!Seen.insert(V).second)
      return V;
  }
}

/// Non-const convenience overload: forwards to the const version and removes
/// the constness it added.
Value *stripCastsAndZeroGEPs(Value *V) {
  const Value *ConstV = V;
  const Value *Stripped = stripCastsAndZeroGEPs(ConstV);
  return const_cast<Value *>(Stripped);
}

void collectUsesLookThroughCasts(const Value *V,
SmallPtrSetImpl<const Use *> &Uses) {
for (const Use &U : V->uses()) {
Expand All @@ -123,6 +152,21 @@ void collectUsesLookThroughCasts(const Value *V,
}
}

/// Gathers into \p Uses every use of the pointer-typed value \p V. When the
/// user is a cast or an all-zero-index GEP, the uses of that user are gathered
/// (recursively) instead of the use itself.
void collectUsesLookThroughCastsAndZeroGEPs(
    const Value *V, SmallPtrSetImpl<const Use *> &Uses) {
  assert(V->getType()->isPtrOrPtrVectorTy() && "pointer type expected");

  for (const Use &TheUse : V->uses()) {
    const User *Usr = TheUse.getUser();
    const bool IsTransparent = isCast(Usr) || isZeroGEP(Usr);

    if (!IsTransparent) {
      Uses.insert(&TheUse);
      continue;
    }
    // The user merely re-expresses the same pointer - descend into its uses.
    collectUsesLookThroughCastsAndZeroGEPs(Usr, Uses);
  }
}

Type *getVectorTyOrNull(StructType *STy) {
Type *Res = nullptr;
while (STy && (STy->getStructNumElements() == 1)) {
Expand Down
32 changes: 32 additions & 0 deletions llvm/test/SYCLLowerIR/ESIMD/vec_arg_call_conv.ll
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,38 @@ entry:
ret void
}

;----- Test4: IR contains all-zero GEP instructions in parameter use-def chains
; Based on Test2.
; The callee reads its by-pointer parameter %x through an address space cast
; followed by a GEP whose indices are all zero, and writes its result through
; the sret pointer %agg.result. The expected signature below returns the
; <8 x i32> vector and accepts the former %x parameter as <8 x i32> by value.
define dso_local spir_func void @_Z23callee__sret__x_param_x1(ptr addrspace(4) noalias sret(%"class.sycl::_V1::ext::intel::esimd::simd.2") align 32 %agg.result, i32 noundef %i, ptr noundef %x, i32 noundef %j) local_unnamed_addr #3 !sycl_explicit_simd !8 !intel_reqd_sub_group_size !9 {
; CHECK: define dso_local spir_func <8 x i32> @_Z23callee__sret__x_param_x1(i32 noundef %{{.*}}, <8 x i32> %{{.*}}, i32 noundef %{{.*}})
entry:
%x.ascast = addrspacecast ptr %x to ptr addrspace(4)
%add = add nsw i32 %i, %j
%splat.splatinsert.i.i.i = insertelement <8 x i32> poison, i32 %add, i64 0
%splat.splat.i.i.i = shufflevector <8 x i32> %splat.splatinsert.i.i.i, <8 x i32> poison, <8 x i32> zeroinitializer
; All-zero GEP on the casted parameter pointer - the pattern under test.
%M_data.i.i.i = getelementptr inbounds %"class.sycl::_V1::ext::intel::esimd::detail::simd_obj_impl.3", ptr addrspace(4) %x.ascast, i64 0, i32 0
%call.i.i.i1 = load <8 x i32>, ptr addrspace(4) %M_data.i.i.i, align 32
%add.i.i.i.i.i = add <8 x i32> %call.i.i.i1, %splat.splat.i.i.i
store <8 x i32> %add.i.i.i.i.i, ptr addrspace(4) %agg.result, align 32
ret void
}

;----- Test4 caller.
; Loads a vector through an all-zero GEP off its own %x parameter, stores it
; into the local %agg.tmp and passes %agg.tmp by pointer to the Test4 callee.
; The expected output takes and returns <8 x i32> by value, and the call site
; passes the vector argument by value as well.
; Function Attrs: convergent noinline norecurse
define dso_local spir_func void @_Z21test__sret__x_param_x1(ptr addrspace(4) noalias sret(%"class.sycl::_V1::ext::intel::esimd::simd.2") align 32 %agg.result, ptr noundef %x) local_unnamed_addr #3 !sycl_explicit_simd !8 !intel_reqd_sub_group_size !9 {
; CHECK: define dso_local spir_func <8 x i32> @_Z21test__sret__x_param_x1(<8 x i32> %{{.*}})
entry:
%agg.tmp = alloca %"class.sycl::_V1::ext::intel::esimd::simd.2", align 32
%agg.tmp.ascast = addrspacecast ptr %agg.tmp to ptr addrspace(4)
%x.ascast = addrspacecast ptr %x to ptr addrspace(4)
; All-zero GEP on the casted %x pointer, feeding the load below.
%M_data.i.i.i = getelementptr inbounds %"class.sycl::_V1::ext::intel::esimd::detail::simd_obj_impl.3", ptr addrspace(4) %x.ascast, i64 0, i32 0
%call.i.i.i1 = load <8 x i32>, ptr addrspace(4) %M_data.i.i.i, align 32
store <8 x i32> %call.i.i.i1, ptr addrspace(4) %agg.tmp.ascast, align 32
call spir_func void @_Z23callee__sret__x_param_x1(ptr addrspace(4) sret(%"class.sycl::_V1::ext::intel::esimd::simd.2") align 32 %agg.result, i32 noundef 2, ptr noundef nonnull %agg.tmp, i32 noundef 1) #7
; CHECK: %{{.*}} = call spir_func <8 x i32> @_Z23callee__sret__x_param_x1(i32 2, <8 x i32> %{{.*}}, i32 1)
ret void
}

attributes #0 = { convergent noinline norecurse "frame-pointer"="all" "min-legal-vector-width"="512" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "sycl-module-id"="../opaque_ptr.cpp" }
attributes #1 = { alwaysinline convergent "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" }
attributes #2 = { convergent noinline norecurse "frame-pointer"="all" "min-legal-vector-width"="12288" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "sycl-module-id"="../opaque_ptr.cpp" }
Expand Down