Skip to content

[SYCL][RTC] Preliminary support for ESIMD kernels #16222

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions sycl-jit/common/include/Kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,10 @@ struct RTCBundleInfo {
RTCBundleBinaryInfo BinaryInfo;
FrozenSymbolTable SymbolTable;
FrozenPropertyRegistry Properties;

RTCBundleInfo() = default;
RTCBundleInfo(RTCBundleInfo &&) = default;
RTCBundleInfo &operator=(RTCBundleInfo &&) = default;
};

} // namespace jit_compiler
Expand Down
2 changes: 2 additions & 0 deletions sycl-jit/jit-compiler/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ add_llvm_library(sycl-jit
lib/fusion/JITContext.cpp
lib/fusion/ModuleHelper.cpp
lib/rtc/DeviceCompilation.cpp
lib/rtc/ESIMD.cpp
lib/helper/ConfigHelper.cpp

SHARED
Expand All @@ -32,6 +33,7 @@ add_llvm_library(sycl-jit
TargetParser
MC
SYCLLowerIR
GenXIntrinsics
${LLVM_TARGETS_TO_BUILD}

LINK_LIBS
Expand Down
9 changes: 5 additions & 4 deletions sycl-jit/jit-compiler/lib/KernelFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,12 +261,13 @@ compileSYCL(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
return errorTo<RTCResult>(std::move(Error), "Device linking failed");
}

auto BundleInfoOrError = performPostLink(*Module, UserArgList);
if (!BundleInfoOrError) {
return errorTo<RTCResult>(BundleInfoOrError.takeError(),
auto PostLinkResultOrError = performPostLink(std::move(Module), UserArgList);
if (!PostLinkResultOrError) {
return errorTo<RTCResult>(PostLinkResultOrError.takeError(),
"Post-link phase failed");
}
auto BundleInfo = std::move(*BundleInfoOrError);
RTCBundleInfo BundleInfo;
std::tie(BundleInfo, Module) = std::move(*PostLinkResultOrError);

auto BinaryInfoOrError =
translation::KernelTranslator::translateBundleToSPIRV(
Expand Down
88 changes: 64 additions & 24 deletions sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//

#include "DeviceCompilation.h"
#include "ESIMD.h"

#include <clang/Basic/DiagnosticDriver.h>
#include <clang/Basic/Version.h>
Expand All @@ -27,6 +28,8 @@
#include <llvm/IRReader/IRReader.h>
#include <llvm/Linker/Linker.h>
#include <llvm/SYCLLowerIR/ComputeModuleRuntimeInfo.h>
#include <llvm/SYCLLowerIR/ESIMD/LowerESIMD.h>
#include <llvm/SYCLLowerIR/LowerInvokeSimd.h>
#include <llvm/SYCLLowerIR/ModuleSplitter.h>
#include <llvm/SYCLLowerIR/SYCLJointMatrixTransform.h>
#include <llvm/Support/PropertySetIO.h>
Expand Down Expand Up @@ -432,42 +435,84 @@ template <class PassClass> static bool runModulePass(llvm::Module &M) {
return !Res.areAllPreserved();
}

Expected<RTCBundleInfo> jit_compiler::performPostLink(
llvm::Module &Module, [[maybe_unused]] const InputArgList &UserArgList) {
llvm::Expected<PostLinkResult> jit_compiler::performPostLink(
std::unique_ptr<llvm::Module> Module,
[[maybe_unused]] const llvm::opt::InputArgList &UserArgList) {
// This is a simplified version of `processInputModule` in
// `llvm/tools/sycl-post-link.cpp`. Assertions/TODOs point to functionality
// left out of the algorithm for now.

assert(!Module.getGlobalVariable("llvm.used") &&
!Module.getGlobalVariable("llvm.compiler.used"));
// TODO: SplitMode can be controlled by the user.
const auto SplitMode = SPLIT_NONE;

// TODO: EmitOnlyKernelsAsEntryPoints is controlled by
// `shouldEmitOnlyKernelsAsEntryPoints` in
// `clang/lib/Driver/ToolChains/Clang.cpp`.
const bool EmitOnlyKernelsAsEntryPoints = true;

// TODO: The optlevel passed to `sycl-post-link` is determined by
// `getSYCLPostLinkOptimizationLevel` in
// `clang/lib/Driver/ToolChains/Clang.cpp`.
const bool PerformOpts = true;

// Propagate ESIMD attribute to wrapper functions to prevent spurious splits
// and kernel link errors.
runModulePass<SYCLFixupESIMDKernelWrapperMDPass>(*Module);

assert(!Module->getGlobalVariable("llvm.used") &&
!Module->getGlobalVariable("llvm.compiler.used"));
// Otherwise: Port over the `removeSYCLKernelsConstRefArray` and
// `removeDeviceGlobalFromCompilerUsed` methods.

assert(!isModuleUsingAsan(Module));
assert(!isModuleUsingAsan(*Module));
// Otherwise: Need to instrument each image scope device globals if the module
// has been instrumented by sanitizer pass.

// Transform Joint Matrix builtin calls to align them with SPIR-V friendly
// LLVM IR specification.
runModulePass<SYCLJointMatrixTransformPass>(Module);
runModulePass<SYCLJointMatrixTransformPass>(*Module);

// Do invoke_simd processing before splitting because this:
// - saves processing time (the pass is run once, even though on larger IR)
// - doing it before SYCL/ESIMD splitting is required for correctness
if (runModulePass<SYCLLowerInvokeSimdPass>(*Module)) {
return createStringError("`invoke_simd` calls detected");
}

// TODO: Implement actual device code splitting. We're just using the splitter
// to obtain additional information about the module for now.
// TODO: EmitOnlyKernelsAsEntryPoints is controlled by
// `shouldEmitOnlyKernelsAsEntryPoints` in
// `clang/lib/Driver/ToolChains/Clang.cpp`.

std::unique_ptr<ModuleSplitterBase> Splitter = getDeviceCodeSplitter(
ModuleDesc{std::unique_ptr<llvm::Module>{&Module}}, SPLIT_NONE,
/*IROutputOnly=*/false,
/*EmitOnlyKernelsAsEntryPoints=*/true);
assert(Splitter->remainingSplits() == 1);
ModuleDesc{std::move(Module)}, SplitMode,
/*IROutputOnly=*/false, EmitOnlyKernelsAsEntryPoints);
assert(Splitter->hasMoreSplits());
if (Splitter->remainingSplits() > 1) {
return createStringError("Device code requires splitting");
}

// TODO: Call `verifyNoCrossModuleDeviceGlobalUsage` if device globals shall
// be processed.

assert(Splitter->hasMoreSplits());
ModuleDesc MDesc = Splitter->nextSplit();
assert(&Module == &MDesc.getModule());

// TODO: Call `MDesc.fixupLinkageOfDirectInvokeSimdTargets()` when
// `invoke_simd` is supported.

SmallVector<ModuleDesc, 2> ESIMDSplits =
splitByESIMD(std::move(MDesc), EmitOnlyKernelsAsEntryPoints);
assert(!ESIMDSplits.empty());
if (ESIMDSplits.size() > 1) {
return createStringError("Mixing SYCL and ESIMD code is unsupported");
}
MDesc = std::move(ESIMDSplits.front());

if (MDesc.isESIMD()) {
// `sycl-post-link` has a `-lower-esimd` option, but there's no clang driver
// option to influence it. Rather, the driver sets it unconditionally in the
// multi-file output mode, which we are mimicking here.
lowerEsimdConstructs(MDesc, PerformOpts);
}

MDesc.saveSplitInformationAsMetadata();

RTCBundleInfo BundleInfo;
Expand Down Expand Up @@ -504,10 +549,7 @@ Expected<RTCBundleInfo> jit_compiler::performPostLink(
}
};

// Regain ownership of the module.
MDesc.releaseModulePtr().release();

return std::move(BundleInfo);
return PostLinkResult{std::move(BundleInfo), MDesc.releaseModulePtr()};
}

Expected<InputArgList>
Expand Down Expand Up @@ -569,11 +611,9 @@ jit_compiler::parseUserArgs(View<const char *> UserArgs) {
return createStringError("Device code splitting is not yet supported");
}

if (AL.hasArg(OPT_fsycl_device_code_split_esimd,
OPT_fno_sycl_device_code_split_esimd)) {
// TODO: There are more ESIMD-related options.
return createStringError(
"Runtime compilation of ESIMD kernels is not yet supported");
if (!AL.hasFlag(OPT_fsycl_device_code_split_esimd,
OPT_fno_sycl_device_code_split_esimd, true)) {
return createStringError("ESIMD device code split cannot be deactivated");
}

if (AL.hasFlag(OPT_fsycl_dead_args_optimization,
Expand Down
5 changes: 3 additions & 2 deletions sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ llvm::Error linkDeviceLibraries(llvm::Module &Module,
const llvm::opt::InputArgList &UserArgList,
std::string &BuildLog);

llvm::Expected<RTCBundleInfo>
performPostLink(llvm::Module &Module,
using PostLinkResult = std::pair<RTCBundleInfo, std::unique_ptr<llvm::Module>>;
llvm::Expected<PostLinkResult>
performPostLink(std::unique_ptr<llvm::Module> Module,
const llvm::opt::InputArgList &UserArgList);

llvm::Expected<llvm::opt::InputArgList>
Expand Down
77 changes: 77 additions & 0 deletions sycl-jit/jit-compiler/lib/rtc/ESIMD.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
//===------------- ESIMD.cpp - Driver for ESIMD lowering ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "ESIMD.h"

#include "llvm/Analysis/CGSCCPassManager.h"
#include "llvm/Analysis/LoopAnalysisManager.h"
#include "llvm/GenXIntrinsics/GenXSPIRVWriterAdaptor.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/SYCLLowerIR/ESIMD/LowerESIMD.h"
#include "llvm/Transforms/InstCombine/InstCombine.h"
#include "llvm/Transforms/Scalar/DCE.h"
#include "llvm/Transforms/Scalar/EarlyCSE.h"
#include "llvm/Transforms/Scalar/SROA.h"

using namespace llvm;

using string_vector = std::vector<std::string>;

// Lowers the ESIMD constructs of `MD` in place. This is only safe to run once
// the ESIMD code has been separated from the regular SYCL code.
//
// `PerformOpts` additionally schedules SROA/EarlyCSE/InstCombine/DCE cleanup
// passes around the lowering passes.
void jit_compiler::lowerEsimdConstructs(module_split::ModuleDesc &MD,
                                        bool PerformOpts) {
  LoopAnalysisManager LoopAM;
  CGSCCAnalysisManager CGSCCAM;
  FunctionAnalysisManager FuncAM;
  ModuleAnalysisManager ModAM;

  PassBuilder Builder;
  Builder.registerModuleAnalyses(ModAM);
  Builder.registerCGSCCAnalyses(CGSCCAM);
  Builder.registerFunctionAnalyses(FuncAM);
  Builder.registerLoopAnalyses(LoopAM);
  Builder.crossRegisterProxies(LoopAM, FuncAM, CGSCCAM, ModAM);

  // Function-level pipeline: load/store lowering, optionally followed by two
  // identical rounds of cleanup passes.
  FunctionPassManager LoweringFPM;
  LoweringFPM.addPass(ESIMDLowerLoadStorePass{});
  if (PerformOpts) {
    // TODO: maybe remove some passes of the second round that don't affect
    // code quality
    for (int Round = 0; Round != 2; ++Round) {
      LoweringFPM.addPass(SROAPass(SROAOptions::ModifyCFG));
      LoweringFPM.addPass(EarlyCSEPass(true));
      LoweringFPM.addPass(InstCombinePass{});
      LoweringFPM.addPass(DCEPass{});
    }
  }

  // Module-level pipeline, assembled in execution order.
  ModulePassManager Pipeline;
  Pipeline.addPass(SYCLLowerESIMDPass(/*ModuleContainsScalar=*/false));
  if (PerformOpts) {
    FunctionPassManager EarlySROA;
    EarlySROA.addPass(SROAPass(SROAOptions::ModifyCFG));
    Pipeline.addPass(createModuleToFunctionPassAdaptor(std::move(EarlySROA)));
  }
  Pipeline.addPass(ESIMDOptimizeVecArgCallConvPass{});
  Pipeline.addPass(ESIMDLowerSLMReservationCalls{});
  Pipeline.addPass(createModuleToFunctionPassAdaptor(std::move(LoweringFPM)));
  Pipeline.addPass(
      GenXSPIRVWriterAdaptor(/*RewriteTypes=*/true,
                             /*RewriteSingleElementVectorsIn*/ false));

  // The GenXSPIRVWriterAdaptor pass replaces some functions with "rewritten"
  // versions, so the entry point names are recorded up front and the entry
  // point table is rebuilt from them once the pipeline has run.
  // TODO Change entry point search to analysis?
  std::vector<std::string> EntryPointNames;
  MD.saveEntryPointNames(EntryPointNames);
  Pipeline.run(MD.getModule(), ModAM);
  MD.rebuildEntryPoints(EntryPointNames);
}
23 changes: 23 additions & 0 deletions sycl-jit/jit-compiler/lib/rtc/ESIMD.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
//===-------------- ESIMD.h - Driver for ESIMD lowering -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef SYCL_JIT_COMPILER_RTC_ESIMD_H
#define SYCL_JIT_COMPILER_RTC_ESIMD_H

#include "llvm/SYCLLowerIR/ModuleSplitter.h"

namespace jit_compiler {

// Runs a pass pipeline to lower ESIMD constructs on the given split module,
// which must only contain ESIMD entrypoints. This is a copy of the similar
// function in `sycl-post-link`.
//
// `MD` is modified in place: the pipeline runs on `MD.getModule()`, and the
// entry points of `MD` are rebuilt from their saved names afterwards.
// If `PerformOpts` is true, additional SROA, EarlyCSE, InstCombine and DCE
// passes are scheduled alongside the lowering passes.
void lowerEsimdConstructs(llvm::module_split::ModuleDesc &MD, bool PerformOpts);

} // namespace jit_compiler

#endif // SYCL_JIT_COMPILER_RTC_ESIMD_H
Loading
Loading