Skip to content

[ESIMD] Add semi-dynamic SLM allocation - esimd::experimental::slm_allocator. #7759

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Dec 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 40 additions & 32 deletions llvm/include/llvm/SYCLLowerIR/ESIMD/ESIMDUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
// Utility functions for processing ESIMD code.
//===----------------------------------------------------------------------===//

#pragma once

#include "llvm/GenXIntrinsics/GenXMetadata.h"

#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Demangle/ItaniumDemangle.h"
Expand All @@ -17,7 +21,7 @@ namespace llvm {
namespace esimd {

constexpr char ESIMD_MARKER_MD[] = "sycl_explicit_simd";
// This is the prefixes of the names generated from
constexpr char GENX_KERNEL_METADATA[] = "genx.kernels";
// sycl/ext/oneapi/experimental/invoke_simd.hpp::__builtin_invoke_simd
// overloads instantiations:
constexpr char INVOKE_SIMD_PREF[] = "_Z33__regcall3____builtin_invoke_simd";
Expand All @@ -39,37 +43,6 @@ inline void assert_and_diag(bool Condition, StringRef Msg,
}
}

/// Tells if this value is a bit cast or address space cast.
bool isCast(const Value *V);

/// Tells if this value is a GEP instructions with all zero indices.
bool isZeroGEP(const Value *V);

/// Climbs up the use-def chain of given value until a value which is not a
/// bit cast or address space cast is met.
const Value *stripCasts(const Value *V);
Value *stripCasts(Value *V);

/// Climbs up the use-def chain of given value until a value is met which is
/// neither of:
/// - bit cast
/// - address space cast
/// - GEP instruction with all zero indices
const Value *stripCastsAndZeroGEPs(const Value *V);
Value *stripCastsAndZeroGEPs(Value *V);

/// Collects uses of given value "looking through" casts. I.e. if a use is a
/// cast (chain), then uses of the result of the cast (chain) are collected.
void collectUsesLookThroughCasts(const Value *V,
SmallPtrSetImpl<const Use *> &Uses);

/// Collects uses of given pointer-typed value "looking through" casts and GEPs
/// with all zero indices - those pointer transformation instructions which
/// don't change pointed-to value. E.g. if a use is a cast (chain), then uses of
/// the result of the cast (chain) are collected.
void collectUsesLookThroughCastsAndZeroGEPs(const Value *V,
SmallPtrSetImpl<const Use *> &Uses);

/// Unwraps a presumably simd* type to extract the native vector type encoded
/// in it. Returns nullptr if failed to do so.
Type *getVectorTyOrNull(StructType *STy);
Expand Down Expand Up @@ -104,5 +77,40 @@ class SimpleAllocator {
~SimpleAllocator() { reset(); }
};

// Turn a MDNode into llvm::value or its subclass.
// Return nullptr if the underlying value has type mismatch.
template <typename Ty = llvm::Value> Ty *getValue(llvm::Metadata *M) {
if (auto VM = dyn_cast<llvm::ValueAsMetadata>(M))
if (auto V = dyn_cast<Ty>(VM->getValue()))
return V;
return nullptr;
}

// Turn given Value into metadata.
inline llvm::Metadata *getMetadata(llvm::Value *V) {
return llvm::ValueAsMetadata::get(V);
}

// A functor which updates ESIMD kernel's uint64_t metadata in case it is less
// than the given one. Used in callgraph traversal to update nbarriers or SLM
// size metadata. Update is performed by the '()' operator and happens only
// when given function matches one of the kernels - thus, only reachable kernels
// are updated.
struct UpdateUint64MetaDataToMaxValue {
Module &M;
// The uint64_t metadata key to update.
genx::KernelMDOp Key;
// The new metadata value. Must be greater than the old for update to happen.
uint64_t NewVal;
// Pre-selected nodes from GENX_KERNEL_METADATA which can only potentially be
// updated.
SmallVector<MDNode *, 4> CandidatesToUpdate;

UpdateUint64MetaDataToMaxValue(Module &M, genx::KernelMDOp Key,
uint64_t NewVal);

void operator()(Function *F) const;
};

} // namespace esimd
} // namespace llvm
4 changes: 4 additions & 0 deletions llvm/include/llvm/SYCLLowerIR/ESIMD/LowerESIMD.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ class ESIMDOptimizeVecArgCallConvPass
PreservedAnalyses run(Module &M, ModuleAnalysisManager &);
};

// Lowers calls __esimd_slm_alloc, __esimd_slm_free and __esimd_slm_init APIs.
// See more details in the .cpp file.
size_t lowerSLMReservationCalls(Module &M);

// Lowers calls to __esimd_set_kernel_properties
class SYCLLowerESIMDKernelPropsPass
: public PassInfoMixin<SYCLLowerESIMDKernelPropsPass> {
Expand Down
58 changes: 55 additions & 3 deletions llvm/include/llvm/SYCLLowerIR/SYCLUtils.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
//===------------ SYCLUtils.h - SYCL utility functions
//------------------===//
//===------------ SYCLUtils.h - SYCL utility functions --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand All @@ -10,14 +9,20 @@
//===----------------------------------------------------------------------===//
#pragma once

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Operator.h"

#include <functional>

namespace llvm {
namespace sycl {
namespace utils {
using CallGraphNodeAction = std::function<void(Function *)>;
constexpr char ATTR_SYCL_MODULE_ID[] = "sycl-module-id";

using CallGraphNodeAction = ::std::function<void(Function *)>;
using CallGraphFunctionFilter =
std::function<bool(const Instruction *, const Function *)>;

Expand Down Expand Up @@ -63,6 +68,53 @@ void traverseCallgraphUp(
traverseCallgraphUp(F, CallGraphNodeAction(ActionF), Visited,
ErrorOnNonCallUse, functionFilter);
}

/// Tells if this value is a bit cast or address space cast.
bool isCast(const Value *V);

/// Tells if this value is a GEP instructions with all zero indices.
bool isZeroGEP(const Value *V);

/// Climbs up the use-def chain of given value until a value which is not a
/// bit cast or address space cast is met.
const Value *stripCasts(const Value *V);
Value *stripCasts(Value *V);

/// Climbs up the use-def chain of given value until a value is met which is
/// neither of:
/// - bit cast
/// - address space cast
/// - GEP instruction with all zero indices
const Value *stripCastsAndZeroGEPs(const Value *V);
Value *stripCastsAndZeroGEPs(Value *V);

/// Collects uses of given value "looking through" casts. I.e. if a use is a
/// cast (chain), then uses of the result of the cast (chain) are collected.
void collectUsesLookThroughCasts(const Value *V,
SmallPtrSetImpl<const Use *> &Uses);

/// Collects uses of given pointer-typed value "looking through" casts and GEPs
/// with all zero indices - those pointer transformation instructions which
/// don't change pointed-to value. E.g. if a use is a cast (chain), then uses of
/// the result of the cast (chain) are collected.
void collectUsesLookThroughCastsAndZeroGEPs(const Value *V,
SmallPtrSetImpl<const Use *> &Uses);

void collectUsesLookThroughCasts(const Value *V,
SmallPtrSetImpl<const Use *> &Uses);

void collectUsesLookThroughCastsAndZeroGEPs(const Value *V,
SmallPtrSetImpl<const Use *> &Uses);

bool collectPossibleStoredVals(
Value *Addr, SmallPtrSetImpl<Value *> &Vals,
std::function<bool(const CallInst *)> EscapesIfAddrIsArgOf =
[](const CallInst *) { return true; });

inline bool isSYCLExternalFunction(const Function *F) {
return F->hasFnAttribute(ATTR_SYCL_MODULE_ID);
}

} // namespace utils
} // namespace sycl
} // namespace llvm
11 changes: 6 additions & 5 deletions llvm/lib/SYCLLowerIR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,18 @@ set_property(GLOBAL PROPERTY LLVMGenXIntrinsics_SOURCE_PROP ${LLVMGenXIntrinsics
set_property(GLOBAL PROPERTY LLVMGenXIntrinsics_BINARY_PROP ${LLVMGenXIntrinsics_BINARY_DIR})

add_llvm_component_library(LLVMSYCLLowerIR
ESIMD/LowerESIMD.cpp
ESIMD/LowerESIMDVLoadVStore.cpp
ESIMD/LowerESIMDVecArg.cpp
ESIMD/ESIMDOptimizeVecArgCallConv.cpp
ESIMD/ESIMDUtils.cpp
ESIMD/ESIMDVerifier.cpp
ESIMD/LowerESIMD.cpp
ESIMD/LowerESIMDKernelAttrs.cpp
ESIMD/ESIMDOptimizeVecArgCallConv.cpp
ESIMD/LowerESIMDVecArg.cpp
ESIMD/LowerESIMDVLoadVStore.cpp
ESIMD/LowerESIMDSlmReservation.cpp
LowerInvokeSimd.cpp
LowerKernelProps.cpp
LowerWGScope.cpp
LowerWGLocalMemory.cpp
LowerWGScope.cpp
MutatePrintfAddrspace.cpp
SYCLPropagateAspectsUsage.cpp
SYCLUtils.cpp
Expand Down
11 changes: 6 additions & 5 deletions llvm/lib/SYCLLowerIR/ESIMD/ESIMDOptimizeVecArgCallConv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include "llvm/SYCLLowerIR/ESIMD/ESIMDUtils.h"
#include "llvm/SYCLLowerIR/ESIMD/LowerESIMD.h"
#include "llvm/SYCLLowerIR/SYCLUtils.h"

#include "llvm/GenXIntrinsics/GenXIntrinsics.h"

Expand Down Expand Up @@ -76,7 +77,7 @@ getMemTypeIfSameAddressLoadsStores(SmallPtrSetImpl<const Use *> &Uses,
if (Uses.size() == 0) {
return nullptr;
}
Value *Addr = esimd::stripCastsAndZeroGEPs((*Uses.begin())->get());
Value *Addr = sycl::utils::stripCastsAndZeroGEPs((*Uses.begin())->get());

for (const auto *UU : Uses) {
const User *U = UU->getUser();
Expand All @@ -92,7 +93,7 @@ getMemTypeIfSameAddressLoadsStores(SmallPtrSetImpl<const Use *> &Uses,
}

if (const auto *SI = dyn_cast<StoreInst>(U)) {
if (esimd::stripCastsAndZeroGEPs(SI->getPointerOperand()) != Addr) {
if (sycl::utils::stripCastsAndZeroGEPs(SI->getPointerOperand()) != Addr) {
// the pointer escapes into memory
return nullptr;
}
Expand Down Expand Up @@ -167,7 +168,7 @@ Type *getPointedToTypeIfOptimizeable(const Argument &FormalParam) {
// }
{
SmallPtrSet<const Use *, 4> Uses;
esimd::collectUsesLookThroughCastsAndZeroGEPs(&FormalParam, Uses);
sycl::utils::collectUsesLookThroughCastsAndZeroGEPs(&FormalParam, Uses);
bool LoadMet = 0;
bool StoreMet = 0;
ContentT = getMemTypeIfSameAddressLoadsStores(Uses, LoadMet, StoreMet);
Expand Down Expand Up @@ -225,14 +226,14 @@ Type *getPointedToTypeIfOptimizeable(const Argument &FormalParam) {
if (!Call || (Call->getCalledFunction() != F)) {
return nullptr;
}
Value *ActualParam = esimd::stripCastsAndZeroGEPs(
Value *ActualParam = sycl::utils::stripCastsAndZeroGEPs(
Call->getArgOperand(FormalParam.getArgNo()));

if (!IsSret && !isa<AllocaInst>(ActualParam)) {
return nullptr;
}
SmallPtrSet<const Use *, 4> Uses;
esimd::collectUsesLookThroughCastsAndZeroGEPs(ActualParam, Uses);
sycl::utils::collectUsesLookThroughCastsAndZeroGEPs(ActualParam, Uses);
bool LoadMet = 0;
bool StoreMet = 0;

Expand Down
Loading