Skip to content

[SYCL] Add support for JIT-ing in AMD and NVIDIA backends #14280

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 31 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
4158578
[SYCL] Introduce SYCL_JIT_KERNELS env var
jchlanda Mar 15, 2024
9bf827f
[SYCL] Extend kernel fusion with JIT-ing
jchlanda Mar 15, 2024
0bc1146
[SYCL] Define JIT pipeline and introduce materializer pass
jchlanda Mar 15, 2024
1aaf3e5
[SYCL] Add functionality to create/cache/retrieve materialized kernels
jchlanda Jun 20, 2024
8221f0f
[SYCL] Introduce SYCL_JIT_TARGET_{CPU,FEATURES} env variables
jchlanda Jun 24, 2024
434bc9f
[SYCL] Document SYCL_JIT_{KERNELS,TARGET_CPU,TARGET_FEATURES} env vars
jchlanda Jun 27, 2024
c3c9abc
PR feedback
jchlanda Jul 5, 2024
df9133f
PR feedback 2
jchlanda Jul 10, 2024
b725f00
Merge remote-tracking branch 'upstream/sycl' into jakub/jit_spec_const
jchlanda Jul 10, 2024
c23a986
Merge fixes
jchlanda Jul 10, 2024
bb1e3f9
Merge remote-tracking branch 'upstream/sycl' into jakub/jit_spec_const
jchlanda Jul 10, 2024
93b07a7
Merge fixes
jchlanda Jul 11, 2024
cf4ec36
Debug printout fix
jchlanda Jul 11, 2024
f86d998
PR feedback 3
jchlanda Jul 11, 2024
e6169ce
Correct assert
jchlanda Jul 12, 2024
5385ad5
strstr returns a pointer on success
jchlanda Jul 12, 2024
0a3ecf1
Use default pipeline
jchlanda Jul 15, 2024
1bc67ce
Merge remote-tracking branch 'upstream/sycl' into jakub/jit_spec_const
jchlanda Jul 15, 2024
8bef1d4
Docs tidy-up
jchlanda Jul 16, 2024
681da06
Merge remote-tracking branch 'upstream/sycl' into jakub/jit_spec_const
jchlanda Jul 16, 2024
391ce43
Merge remote-tracking branch 'upstream/sycl' into jakub/jit_spec_const
jchlanda Jul 16, 2024
1fb459a
Merge remote-tracking branch 'upstream/sycl' into jakub/jit_spec_const
jchlanda Jul 22, 2024
5f3d2c9
Merge remote-tracking branch 'upstream/sycl' into jakub/jit_spec_const
jchlanda Jul 22, 2024
94b5ad5
Constexpr debug output in program manager
jchlanda Jul 24, 2024
a876daa
Merge remote-tracking branch 'upstream/sycl' into jakub/jit_spec_const
jchlanda Jul 24, 2024
21c814f
build fix
jchlanda Jul 24, 2024
87f1a87
orfer of includes
jchlanda Jul 24, 2024
509e3e6
JIT e2e test
jchlanda Jul 24, 2024
d8c6499
clang format the test
jchlanda Jul 24, 2024
66bd110
include fix in the test
jchlanda Jul 24, 2024
a403149
Merge remote-tracking branch 'upstream/sycl' into jakub/jit_spec_const
jchlanda Jul 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 21 additions & 16 deletions sycl-fusion/jit-compiler/include/KernelFusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,18 @@

namespace jit_compiler {

class FusionResult {
class JITResult {
public:
explicit FusionResult(const char *ErrorMessage)
: Type{FusionResultType::FAILED}, KernelInfo{},
ErrorMessage{ErrorMessage} {}
explicit JITResult(const char *ErrorMessage)
: Type{JITResultType::FAILED}, KernelInfo{}, ErrorMessage{ErrorMessage} {}

explicit FusionResult(const SYCLKernelInfo &KernelInfo, bool Cached = false)
: Type{(Cached) ? FusionResultType::CACHED : FusionResultType::NEW},
explicit JITResult(const SYCLKernelInfo &KernelInfo, bool Cached = false)
: Type{(Cached) ? JITResultType::CACHED : JITResultType::NEW},
KernelInfo(KernelInfo), ErrorMessage{} {}

bool failed() const { return Type == FusionResultType::FAILED; }
bool failed() const { return Type == JITResultType::FAILED; }

bool cached() const { return Type == FusionResultType::CACHED; }
bool cached() const { return Type == JITResultType::CACHED; }

const char *getErrorMessage() const {
assert(failed() && "No error message present");
Expand All @@ -44,9 +43,9 @@ class FusionResult {
}

private:
enum class FusionResultType { FAILED, CACHED, NEW };
enum class JITResultType { FAILED, CACHED, NEW };

FusionResultType Type;
JITResultType Type;
SYCLKernelInfo KernelInfo;
sycl::detail::string ErrorMessage;
};
Expand All @@ -56,12 +55,18 @@ extern "C" {
#ifdef __clang__
#pragma clang diagnostic ignored "-Wreturn-type-c-linkage"
#endif // __clang__
FusionResult fuseKernels(View<SYCLKernelInfo> KernelInformation,
const char *FusedKernelName,
View<ParameterIdentity> Identities,
BarrierFlags BarriersFlags,
View<ParameterInternalization> Internalization,
View<jit_compiler::JITConstant> JITConstants);
JITResult fuseKernels(View<SYCLKernelInfo> KernelInformation,
const char *FusedKernelName,
View<ParameterIdentity> Identities,
BarrierFlags BarriersFlags,
View<ParameterInternalization> Internalization,
View<jit_compiler::JITConstant> JITConstants);

JITResult materializeSpecConstants(const char *KernelName,
jit_compiler::SYCLKernelBinaryInfo &BinInfo,
View<unsigned char> SpecConstBlob,
const char *TargetCPU,
const char *TargetFeatures);

/// Clear all previously set options.
void resetJITConfiguration();
Expand Down
1 change: 1 addition & 0 deletions sycl-fusion/jit-compiler/ld-version-script.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
global:
/* Export the library entry points */
fuseKernels;
materializeSpecConstants;
resetJITConfiguration;
addToJITConfiguration;

Expand Down
74 changes: 60 additions & 14 deletions sycl-fusion/jit-compiler/lib/KernelFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ using namespace jit_compiler;
using FusedFunction = helper::FusionHelper::FusedFunction;
using FusedFunctionList = std::vector<FusedFunction>;

static FusionResult errorToFusionResult(llvm::Error &&Err,
const std::string &Msg) {
static JITResult errorToFusionResult(llvm::Error &&Err,
const std::string &Msg) {
std::stringstream ErrMsg;
ErrMsg << Msg << "\nDetailed information:\n";
llvm::handleAllErrors(std::move(Err),
Expand All @@ -34,7 +34,7 @@ static FusionResult errorToFusionResult(llvm::Error &&Err,
// compiled without exception support.
ErrMsg << "\t" << StrErr.getMessage() << "\n";
});
return FusionResult{ErrMsg.str().c_str()};
return JITResult{ErrMsg.str().c_str()};
}

static std::vector<jit_compiler::NDRange>
Expand Down Expand Up @@ -70,11 +70,58 @@ static bool isTargetFormatSupported(BinaryFormat TargetFormat) {
}
}

extern "C" FusionResult
fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
View<ParameterIdentity> Identities, BarrierFlags BarriersFlags,
View<ParameterInternalization> Internalization,
View<jit_compiler::JITConstant> Constants) {
extern "C" JITResult
materializeSpecConstants(const char *KernelName,
jit_compiler::SYCLKernelBinaryInfo &BinInfo,
View<unsigned char> SpecConstBlob,
const char *TargetCPU, const char *TargetFeatures) {
auto &JITCtx = JITContext::getInstance();

TargetInfo TargetInfo = ConfigHelper::get<option::JITTargetInfo>();
BinaryFormat TargetFormat = TargetInfo.getFormat();
if (TargetFormat != BinaryFormat::PTX &&
TargetFormat != BinaryFormat::AMDGCN) {
return JITResult("Output target format not supported by this build. "
"Available targets are: PTX or AMDGCN.");
}

::jit_compiler::SYCLKernelInfo KernelInfo{
KernelName, ::jit_compiler::SYCLArgumentDescriptor{},
::jit_compiler::NDRange{}, BinInfo};
SYCLModuleInfo ModuleInfo;
ModuleInfo.kernels().insert(ModuleInfo.kernels().end(), KernelInfo);
// Load all input kernels from their respective modules into a single
// LLVM IR module.
llvm::Expected<std::unique_ptr<llvm::Module>> ModOrError =
translation::KernelTranslator::loadKernels(*JITCtx.getLLVMContext(),
ModuleInfo.kernels());
if (auto Error = ModOrError.takeError()) {
return errorToFusionResult(std::move(Error), "Failed to load kernels");
}
std::unique_ptr<llvm::Module> NewMod = std::move(*ModOrError);
if (!fusion::FusionPipeline::runMaterializerPasses(
*NewMod, SpecConstBlob.to<llvm::ArrayRef>()) ||
!NewMod->getFunction(KernelName)) {
return JITResult{"Materializer passes should not fail"};
}

SYCLKernelInfo &MaterializerKernelInfo = *ModuleInfo.getKernelFor(KernelName);
if (auto Error = translation::KernelTranslator::translateKernel(
MaterializerKernelInfo, *NewMod, JITCtx, TargetFormat, TargetCPU,
TargetFeatures)) {
return errorToFusionResult(std::move(Error),
"Translation to output format failed");
}

return JITResult{MaterializerKernelInfo};
}

extern "C" JITResult fuseKernels(View<SYCLKernelInfo> KernelInformation,
const char *FusedKernelName,
View<ParameterIdentity> Identities,
BarrierFlags BarriersFlags,
View<ParameterInternalization> Internalization,
View<jit_compiler::JITConstant> Constants) {

std::vector<std::string> KernelsToFuse;
llvm::transform(KernelInformation, std::back_inserter(KernelsToFuse),
Expand All @@ -93,8 +140,7 @@ fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
}

if (!isTargetFormatSupported(TargetFormat)) {
return FusionResult(
"Fusion output target format not supported by this build");
return JITResult("Fusion output target format not supported by this build");
}

auto &JITCtx = JITContext::getInstance();
Expand All @@ -117,7 +163,7 @@ fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
// before returning the kernel info to the runtime.
CachedKernel->NDR = FusedNDR->getNDR();
}
return FusionResult{*CachedKernel, /*Cached*/ true};
return JITResult{*CachedKernel, /*Cached*/ true};
}
helper::printDebugMessage(
"Compiling new kernel, no suitable cached kernel found");
Expand Down Expand Up @@ -165,13 +211,13 @@ fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
BarriersFlags);

if (!NewMod->getFunction(FusedKernelName)) {
return FusionResult{"Kernel fusion failed"};
return JITResult{"Kernel fusion failed"};
}

// Get the updated kernel info for the fused kernel and add the information to
// the existing KernelInfo.
if (!NewModInfo->hasKernelFor(FusedKernelName)) {
return FusionResult{"No KernelInfo for fused kernel"};
return JITResult{"No KernelInfo for fused kernel"};
}

SYCLKernelInfo &FusedKernelInfo = *NewModInfo->getKernelFor(FusedKernelName);
Expand All @@ -188,7 +234,7 @@ fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
JITCtx.addCacheEntry(CacheKey, FusedKernelInfo);
}

return FusionResult{FusedKernelInfo};
return JITResult{FusedKernelInfo};
}

extern "C" void resetJITConfiguration() { ConfigHelper::reset(); }
Expand Down
46 changes: 46 additions & 0 deletions sycl-fusion/jit-compiler/lib/fusion/FusionPipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "helper/ConfigHelper.h"
#include "internalization/Internalization.h"
#include "kernel-fusion/SYCLKernelFusion.h"
#include "kernel-fusion/SYCLSpecConstMaterializer.h"
#include "kernel-info/SYCLKernelInfo.h"
#include "syclcp/SYCLCP.h"

Expand Down Expand Up @@ -141,3 +142,48 @@ FusionPipeline::runFusionPasses(Module &Mod, SYCLModuleInfo &InputInfo,

return std::make_unique<SYCLModuleInfo>(std::move(*NewModInfo.ModuleInfo));
}

bool FusionPipeline::runMaterializerPasses(
llvm::Module &Mod, llvm::ArrayRef<unsigned char> SpecConstData) {
PassBuilder PB;
LoopAnalysisManager LAM;
FunctionAnalysisManager FAM;
CGSCCAnalysisManager CGAM;
ModuleAnalysisManager MAM;
PB.registerModuleAnalyses(MAM);
PB.registerCGSCCAnalyses(CGAM);
PB.registerFunctionAnalyses(FAM);
PB.registerLoopAnalyses(LAM);
PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);

ModulePassManager MPM;
// Register inserter and materializer passes.
{
FunctionPassManager FPM;
MPM.addPass(SYCLSpecConstDataInserter{SpecConstData});
FPM.addPass(SYCLSpecConstMaterializer{});
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
}
// Add generic optimizations,
{
FunctionPassManager FPM;
MPM.addPass(AlwaysInlinerPass{});
FPM.addPass(SROAPass{SROAOptions::ModifyCFG});
FPM.addPass(SCCPPass{});
FPM.addPass(ADCEPass{});
FPM.addPass(EarlyCSEPass{/*UseMemorySSA*/ true});
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
}
// followed by unrolling.
{
FunctionPassManager FPM;
FPM.addPass(createFunctionToLoopPassAdaptor(IndVarSimplifyPass{}));
LoopUnrollOptions UnrollOptions;
FPM.addPass(LoopUnrollPass{UnrollOptions});
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
}

MPM.run(Mod, MAM);

return true;
}
7 changes: 7 additions & 0 deletions sycl-fusion/jit-compiler/lib/fusion/FusionPipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ class FusionPipeline {
static std::unique_ptr<SYCLModuleInfo>
runFusionPasses(llvm::Module &Mod, SYCLModuleInfo &InputInfo,
BarrierFlags BarriersFlags);

///
/// Run the necessary passes in a custom pass pipeline to perform
/// materialization of kernel specialization constants.
static bool
runMaterializerPasses(llvm::Module &Mod,
llvm::ArrayRef<unsigned char> SpecConstData);
};
} // namespace fusion
} // namespace jit_compiler
Expand Down
Loading
Loading