Skip to content

Commit 1dd8ea1

Browse files
authored
[SYCL][Fusion][NoSTL] Use free functions for configuration management (#12445)
Untangle the interaction between options and the `Config` class, and hide `Config` (which needs to store options in a map) from the KF interface. The idea is to introduce free functions to interact with the `thread_local` instance of `Config` held by the existing `ConfigHelper` class, instead of letting the client construct the `Config` object and handing it to over `ConfigHelper` for storing it. _This PR is part of a series of changes to remove uses of STL classes in the kernel fusion interface to prevent ABI issues in the future._ Signed-off-by: Julian Oppermann <[email protected]>
1 parent acf89a6 commit 1dd8ea1

File tree

5 files changed

+87
-70
lines changed

5 files changed

+87
-70
lines changed

sycl-fusion/jit-compiler/include/KernelFusion.h

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,26 @@ class FusionResult {
5454
class KernelFusion {
5555

5656
public:
57-
static FusionResult fuseKernels(
58-
Config &&JITConfig, const std::vector<SYCLKernelInfo> &KernelInformation,
59-
const char *FusedKernelName, jit_compiler::ParamIdentList &Identities,
60-
BarrierFlags BarriersFlags,
61-
const std::vector<jit_compiler::ParameterInternalization>
62-
&Internalization,
63-
const std::vector<jit_compiler::JITConstant> &JITConstants);
57+
static FusionResult
58+
fuseKernels(const std::vector<SYCLKernelInfo> &KernelInformation,
59+
const char *FusedKernelName,
60+
jit_compiler::ParamIdentList &Identities,
61+
BarrierFlags BarriersFlags,
62+
const std::vector<jit_compiler::ParameterInternalization>
63+
&Internalization,
64+
const std::vector<jit_compiler::JITConstant> &JITConstants);
65+
66+
/// Clear all previously set options.
67+
static void resetConfiguration();
68+
69+
/// Set \p Opt to the value built in-place by \p As.
70+
template <typename Opt, typename... Args> static void set(Args &&...As) {
71+
set(new Opt{std::forward<Args>(As)...});
72+
}
73+
74+
private:
75+
/// Take ownership of \p Option and include it in the current configuration.
76+
static void set(OptionPtrBase *Option);
6477
};
6578

6679
} // namespace jit_compiler

sycl-fusion/jit-compiler/include/Options.h

Lines changed: 18 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -11,77 +11,43 @@
1111

1212
#include "Kernel.h"
1313

14-
#include <memory>
15-
#include <unordered_map>
16-
1714
namespace jit_compiler {
1815

1916
enum OptionID { VerboseOutput, EnableCaching, TargetDeviceInfo };
2017

21-
class OptionPtrBase {};
22-
23-
class Config {
18+
class OptionPtrBase {
19+
protected:
20+
explicit OptionPtrBase(OptionID Id) : Id(Id) {}
2421

2522
public:
26-
template <typename Opt> void set(typename Opt::ValueType Value) {
27-
Opt::set(*this, Value);
28-
}
29-
30-
template <typename Opt> typename Opt::ValueType get() {
31-
return Opt::get(*this);
32-
}
33-
34-
private:
35-
std::unordered_map<OptionID, std::unique_ptr<OptionPtrBase>> OptionValues;
36-
37-
void set(OptionID ID, std::unique_ptr<OptionPtrBase> Value) {
38-
OptionValues[ID] = std::move(Value);
39-
}
40-
41-
OptionPtrBase *get(OptionID ID) {
42-
if (OptionValues.count(ID)) {
43-
return OptionValues.at(ID).get();
44-
}
45-
return nullptr;
46-
}
47-
48-
template <OptionID ID, typename T> friend class OptionBase;
23+
const OptionID Id;
4924
};
5025

51-
template <OptionID ID, typename T> class OptionBase : public OptionPtrBase {
52-
public:
26+
template <OptionID ID, typename T> struct OptionBase : public OptionPtrBase {
27+
static constexpr OptionID Id = ID;
5328
using ValueType = T;
5429

55-
protected:
56-
static void set(Config &Cfg, T Value) {
57-
Cfg.set(ID,
58-
std::unique_ptr<OptionBase<ID, T>>{new OptionBase<ID, T>{Value}});
59-
}
60-
61-
static const T get(Config &Cfg) {
62-
auto *ConfigValue = Cfg.get(ID);
63-
if (!ConfigValue) {
64-
return T{};
65-
}
66-
return static_cast<OptionBase<ID, T> *>(ConfigValue)->Value;
67-
}
30+
template <typename... Args>
31+
explicit OptionBase(Args &&...As)
32+
: OptionPtrBase{ID}, Value{std::forward<Args>(As)...} {}
6833

69-
private:
7034
T Value;
71-
72-
OptionBase(T Val) : Value{Val} {}
73-
74-
friend Config;
7535
};
7636

7737
namespace option {
7838

79-
struct JITEnableVerbose : public OptionBase<OptionID::VerboseOutput, bool> {};
39+
struct JITEnableVerbose : public OptionBase<OptionID::VerboseOutput, bool> {
40+
using OptionBase::OptionBase;
41+
};
8042

81-
struct JITEnableCaching : public OptionBase<OptionID::EnableCaching, bool> {};
43+
struct JITEnableCaching : public OptionBase<OptionID::EnableCaching, bool> {
44+
using OptionBase::OptionBase;
45+
};
8246

8347
struct JITTargetInfo
84-
: public OptionBase<OptionID::TargetDeviceInfo, TargetInfo> {};
48+
: public OptionBase<OptionID::TargetDeviceInfo, TargetInfo> {
49+
using OptionBase::OptionBase;
50+
};
8551

8652
} // namespace option
8753
} // namespace jit_compiler

sycl-fusion/jit-compiler/lib/KernelFusion.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,14 +71,11 @@ static bool isTargetFormatSupported(BinaryFormat TargetFormat) {
7171
}
7272

7373
FusionResult KernelFusion::fuseKernels(
74-
Config &&JITConfig, const std::vector<SYCLKernelInfo> &KernelInformation,
74+
const std::vector<SYCLKernelInfo> &KernelInformation,
7575
const char *FusedKernelName, ParamIdentList &Identities,
7676
BarrierFlags BarriersFlags,
7777
const std::vector<jit_compiler::ParameterInternalization> &Internalization,
7878
const std::vector<jit_compiler::JITConstant> &Constants) {
79-
// Initialize the configuration helper to make the options for this invocation
80-
// available (on a per-thread basis).
81-
ConfigHelper::setConfig(std::move(JITConfig));
8279

8380
std::vector<std::string> KernelsToFuse;
8481
llvm::transform(KernelInformation, std::back_inserter(KernelsToFuse),
@@ -194,3 +191,9 @@ FusionResult KernelFusion::fuseKernels(
194191

195192
return FusionResult{FusedKernelInfo};
196193
}
194+
195+
void KernelFusion::resetConfiguration() { ConfigHelper::reset(); }
196+
197+
void KernelFusion::set(OptionPtrBase *Option) {
198+
ConfigHelper::getConfig().set(Option);
199+
}

sycl-fusion/jit-compiler/lib/helper/ConfigHelper.h

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,44 @@
1111

1212
#include "Options.h"
1313

14+
#include <memory>
15+
#include <unordered_map>
16+
1417
namespace jit_compiler {
1518

19+
class Config {
20+
21+
public:
22+
template <typename Opt> typename Opt::ValueType get() const {
23+
using T = typename Opt::ValueType;
24+
25+
auto *ConfigValue = get(Opt::Id);
26+
if (!ConfigValue) {
27+
return T{};
28+
}
29+
return static_cast<const OptionBase<Opt::Id, T> *>(ConfigValue)->Value;
30+
}
31+
32+
void set(OptionPtrBase *Option) {
33+
OptionValues[Option->Id] = std::unique_ptr<OptionPtrBase>(Option);
34+
}
35+
36+
private:
37+
std::unordered_map<OptionID, std::unique_ptr<OptionPtrBase>> OptionValues;
38+
39+
const OptionPtrBase *get(OptionID ID) const {
40+
const auto Iter = OptionValues.find(ID);
41+
if (Iter == OptionValues.end()) {
42+
return nullptr;
43+
}
44+
return Iter->second.get();
45+
}
46+
};
47+
1648
class ConfigHelper {
1749
public:
18-
static void setConfig(Config &&JITConfig) { Cfg = std::move(JITConfig); }
50+
static void reset() { Cfg = {}; }
51+
static Config &getConfig() { return Cfg; }
1952

2053
template <typename Opt> static typename Opt::ValueType get() {
2154
return Cfg.get<Opt>();

sycl/source/detail/jit_compiler.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -823,20 +823,22 @@ jit_compiler::fuseKernels(QueueImplPtr Queue,
823823

824824
static size_t FusedKernelNameIndex = 0;
825825
auto FusedKernelName = "fused_" + std::to_string(FusedKernelNameIndex++);
826-
::jit_compiler::Config JITConfig;
826+
::jit_compiler::KernelFusion::resetConfiguration();
827827
bool DebugEnabled =
828828
detail::SYCLConfig<detail::SYCL_RT_WARNING_LEVEL>::get() > 0;
829-
JITConfig.set<::jit_compiler::option::JITEnableVerbose>(DebugEnabled);
830-
JITConfig.set<::jit_compiler::option::JITEnableCaching>(
829+
::jit_compiler::KernelFusion::set<::jit_compiler::option::JITEnableVerbose>(
830+
DebugEnabled);
831+
::jit_compiler::KernelFusion::set<::jit_compiler::option::JITEnableCaching>(
831832
detail::SYCLConfig<detail::SYCL_ENABLE_FUSION_CACHING>::get());
832833

833834
::jit_compiler::TargetInfo TargetInfo = getTargetInfo(Queue);
834835
::jit_compiler::BinaryFormat TargetFormat = TargetInfo.getFormat();
835-
JITConfig.set<::jit_compiler::option::JITTargetInfo>(TargetInfo);
836+
::jit_compiler::KernelFusion::set<::jit_compiler::option::JITTargetInfo>(
837+
std::move(TargetInfo));
836838

837839
auto FusionResult = ::jit_compiler::KernelFusion::fuseKernels(
838-
std::move(JITConfig), InputKernelInfo, FusedKernelName.c_str(),
839-
ParamIdentities, BarrierFlags, InternalizeParams, JITConstants);
840+
InputKernelInfo, FusedKernelName.c_str(), ParamIdentities, BarrierFlags,
841+
InternalizeParams, JITConstants);
840842

841843
if (FusionResult.failed()) {
842844
if (DebugEnabled) {

0 commit comments

Comments
 (0)