Skip to content

[SYCL][Fusion][NoSTL] Hide JITContext behind interface #12189

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sycl-fusion/jit-compiler/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@

add_llvm_library(sycl-fusion
lib/KernelFusion.cpp
lib/JITContext.cpp
lib/translation/KernelTranslation.cpp
lib/translation/SPIRVLLVMTranslation.cpp
lib/fusion/FusionPipeline.cpp
lib/fusion/FusionHelper.cpp
lib/fusion/JITContext.cpp
lib/fusion/ModuleHelper.cpp
lib/helper/ConfigHelper.cpp

Expand Down
19 changes: 8 additions & 11 deletions sycl-fusion/jit-compiler/include/KernelFusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#ifndef SYCL_FUSION_JIT_COMPILER_KERNELFUSION_H
#define SYCL_FUSION_JIT_COMPILER_KERNELFUSION_H

#include "JITContext.h"
#include "Kernel.h"
#include "Options.h"
#include "Parameter.h"
Expand Down Expand Up @@ -55,16 +54,14 @@ class FusionResult {
class KernelFusion {

public:
static FusionResult
fuseKernels(JITContext &JITCtx, Config &&JITConfig,
const std::vector<SYCLKernelInfo> &KernelInformation,
const std::vector<std::string> &KernelsToFuse,
const std::string &FusedKernelName,
jit_compiler::ParamIdentList &Identities,
BarrierFlags BarriersFlags,
const std::vector<jit_compiler::ParameterInternalization>
&Internalization,
const std::vector<jit_compiler::JITConstant> &JITConstants);
static FusionResult fuseKernels(
Config &&JITConfig, const std::vector<SYCLKernelInfo> &KernelInformation,
const std::vector<std::string> &KernelsToFuse,
const std::string &FusedKernelName,
jit_compiler::ParamIdentList &Identities, BarrierFlags BarriersFlags,
const std::vector<jit_compiler::ParameterInternalization>
&Internalization,
const std::vector<jit_compiler::JITConstant> &JITConstants);
};

} // namespace jit_compiler
Expand Down
4 changes: 2 additions & 2 deletions sycl-fusion/jit-compiler/lib/KernelFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,7 @@ static bool isTargetFormatSupported(BinaryFormat TargetFormat) {
}

FusionResult KernelFusion::fuseKernels(
JITContext &JITCtx, Config &&JITConfig,
const std::vector<SYCLKernelInfo> &KernelInformation,
Config &&JITConfig, const std::vector<SYCLKernelInfo> &KernelInformation,
const std::vector<std::string> &KernelsToFuse,
const std::string &FusedKernelName, ParamIdentList &Identities,
BarrierFlags BarriersFlags,
Expand Down Expand Up @@ -103,6 +102,7 @@ FusionResult KernelFusion::fuseKernels(
"Fusion output target format not supported by this build");
}

auto &JITCtx = JITContext::getInstance();
bool CachingEnabled = ConfigHelper::get<option::JITEnableCaching>();
CacheKeyT CacheKey{TargetArch,
KernelsToFuse,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
//
//===----------------------------------------------------------------------===//

#ifndef SYCL_FUSION_JIT_COMPILER_HASHING_H
#define SYCL_FUSION_JIT_COMPILER_HASHING_H
#ifndef SYCL_FUSION_JIT_COMPILER_FUSION_HASHING_H
#define SYCL_FUSION_JIT_COMPILER_FUSION_HASHING_H

#include "Kernel.h"
#include "Parameter.h"
Expand Down Expand Up @@ -57,4 +57,4 @@ template <typename... T> struct hash<tuple<T...>> {
};
} // namespace std

#endif // SYCL_FUSION_JIT_COMPILER_HASHING_H
#endif // SYCL_FUSION_JIT_COMPILER_FUSION_HASHING_H
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ BinaryFormat KernelBinary::format() const { return Format; }

JITContext::JITContext() : LLVMCtx{new llvm::LLVMContext}, Binaries{} {}

JITContext::~JITContext() = default;

llvm::LLVMContext *JITContext::getLLVMContext() { return LLVMCtx.get(); }

std::optional<SYCLKernelInfo>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
//
//===----------------------------------------------------------------------===//

#ifndef SYCL_FUSION_JIT_COMPILER_JITCONTEXT_H
#define SYCL_FUSION_JIT_COMPILER_JITCONTEXT_H
#ifndef SYCL_FUSION_JIT_COMPILER_FUSION_JITCONTEXT_H
#define SYCL_FUSION_JIT_COMPILER_FUSION_JITCONTEXT_H

#include <memory>
#include <mutex>
Expand Down Expand Up @@ -61,9 +61,10 @@ class KernelBinary {
class JITContext {

public:
JITContext();

~JITContext();
static JITContext &getInstance() {
static JITContext Instance{};
return Instance;
}

llvm::LLVMContext *getLLVMContext();

Expand All @@ -77,6 +78,13 @@ class JITContext {
void addCacheEntry(CacheKeyT &Identifier, SYCLKernelInfo &Kernel);

private:
JITContext();
~JITContext() = default;
JITContext(const JITContext &) = delete;
JITContext(JITContext &&) = delete;
JITContext &operator=(const JITContext &) = delete;
JITContext &operator=(const JITContext &&) = delete;

// FIXME: Change this to std::shared_mutex after switching to C++17.
using MutexT = std::shared_timed_mutex;

Expand All @@ -96,4 +104,4 @@ class JITContext {
};
} // namespace jit_compiler

#endif // SYCL_FUSION_JIT_COMPILER_JITCONTEXT_H
#endif // SYCL_FUSION_JIT_COMPILER_FUSION_JITCONTEXT_H
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
#ifndef SYCL_FUSION_JIT_COMPILER_TRANSLATION_KERNELTRANSLATION_H
#define SYCL_FUSION_JIT_COMPILER_TRANSLATION_KERNELTRANSLATION_H

#include "JITContext.h"
#include "Kernel.h"
#include "fusion/JITContext.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/Error.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
#ifndef SYCL_FUSION_JIT_COMPILER_TRANSLATION_SPIRVLLVMTRANSLATION_H
#define SYCL_FUSION_JIT_COMPILER_TRANSLATION_SPIRVLLVMTRANSLATION_H

#include "JITContext.h"
#include "Kernel.h"
#include "LLVMSPIRVOpts.h"
#include "fusion/JITContext.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include <llvm/Support/Error.h>
Expand Down
6 changes: 1 addition & 5 deletions sycl/source/detail/jit_compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,6 @@ namespace sycl {
inline namespace _V1 {
namespace detail {

jit_compiler::jit_compiler() : MJITContext{new ::jit_compiler::JITContext{}} {}

jit_compiler::~jit_compiler() = default;

static ::jit_compiler::BinaryFormat
translateBinaryImageFormat(pi::PiDeviceBinaryType Type) {
switch (Type) {
Expand Down Expand Up @@ -836,7 +832,7 @@ jit_compiler::fuseKernels(QueueImplPtr Queue,
JITConfig.set<::jit_compiler::option::JITTargetInfo>(TargetInfo);

auto FusionResult = ::jit_compiler::KernelFusion::fuseKernels(
*MJITContext, std::move(JITConfig), InputKernelInfo, InputKernelNames,
std::move(JITConfig), InputKernelInfo, InputKernelNames,
FusedKernelName.str(), ParamIdentities, BarrierFlags, InternalizeParams,
JITConstants);

Expand Down
6 changes: 2 additions & 4 deletions sycl/source/detail/jit_compiler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ class jit_compiler {
}

private:
jit_compiler();
~jit_compiler();
jit_compiler() = default;
~jit_compiler() = default;
jit_compiler(const jit_compiler &) = delete;
jit_compiler(jit_compiler &&) = delete;
jit_compiler &operator=(const jit_compiler &) = delete;
Expand All @@ -61,8 +61,6 @@ class jit_compiler {

// Manages the lifetime of the PI structs for device binaries.
std::vector<DeviceBinariesCollection> JITDeviceBinaries;

std::unique_ptr<::jit_compiler::JITContext> MJITContext;
};

} // namespace detail
Expand Down