Skip to content

Commit 8a17167

Browse files
authored
[SYCL][RTC] Preliminary support for ESIMD kernels (#16222)
Adds support for compiling source strings that contain ESIMD kernels *only*, hence require no device splitting. I'm using a simplified version of the driver logic in `sycl-post-link`. The pass pipeline construction helper `lowerEsimdConstructs` is also copied from `sycl-post-link`. --------- Signed-off-by: Julian Oppermann <[email protected]>
1 parent e9cbb87 commit 8a17167

File tree

8 files changed

+273
-33
lines changed

8 files changed

+273
-33
lines changed

sycl-jit/common/include/Kernel.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,10 @@ struct RTCBundleInfo {
403403
RTCBundleBinaryInfo BinaryInfo;
404404
FrozenSymbolTable SymbolTable;
405405
FrozenPropertyRegistry Properties;
406+
407+
RTCBundleInfo() = default;
408+
RTCBundleInfo(RTCBundleInfo &&) = default;
409+
RTCBundleInfo &operator=(RTCBundleInfo &&) = default;
406410
};
407411

408412
} // namespace jit_compiler

sycl-jit/jit-compiler/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ add_llvm_library(sycl-jit
88
lib/fusion/JITContext.cpp
99
lib/fusion/ModuleHelper.cpp
1010
lib/rtc/DeviceCompilation.cpp
11+
lib/rtc/ESIMD.cpp
1112
lib/helper/ConfigHelper.cpp
1213

1314
SHARED
@@ -32,6 +33,7 @@ add_llvm_library(sycl-jit
3233
TargetParser
3334
MC
3435
SYCLLowerIR
36+
GenXIntrinsics
3537
${LLVM_TARGETS_TO_BUILD}
3638

3739
LINK_LIBS

sycl-jit/jit-compiler/lib/KernelFusion.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -261,12 +261,13 @@ compileSYCL(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
261261
return errorTo<RTCResult>(std::move(Error), "Device linking failed");
262262
}
263263

264-
auto BundleInfoOrError = performPostLink(*Module, UserArgList);
265-
if (!BundleInfoOrError) {
266-
return errorTo<RTCResult>(BundleInfoOrError.takeError(),
264+
auto PostLinkResultOrError = performPostLink(std::move(Module), UserArgList);
265+
if (!PostLinkResultOrError) {
266+
return errorTo<RTCResult>(PostLinkResultOrError.takeError(),
267267
"Post-link phase failed");
268268
}
269-
auto BundleInfo = std::move(*BundleInfoOrError);
269+
RTCBundleInfo BundleInfo;
270+
std::tie(BundleInfo, Module) = std::move(*PostLinkResultOrError);
270271

271272
auto BinaryInfoOrError =
272273
translation::KernelTranslator::translateBundleToSPIRV(

sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.cpp

Lines changed: 64 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "DeviceCompilation.h"
10+
#include "ESIMD.h"
1011

1112
#include <clang/Basic/DiagnosticDriver.h>
1213
#include <clang/Basic/Version.h>
@@ -27,6 +28,8 @@
2728
#include <llvm/IRReader/IRReader.h>
2829
#include <llvm/Linker/Linker.h>
2930
#include <llvm/SYCLLowerIR/ComputeModuleRuntimeInfo.h>
31+
#include <llvm/SYCLLowerIR/ESIMD/LowerESIMD.h>
32+
#include <llvm/SYCLLowerIR/LowerInvokeSimd.h>
3033
#include <llvm/SYCLLowerIR/ModuleSplitter.h>
3134
#include <llvm/SYCLLowerIR/SYCLJointMatrixTransform.h>
3235
#include <llvm/Support/PropertySetIO.h>
@@ -432,42 +435,84 @@ template <class PassClass> static bool runModulePass(llvm::Module &M) {
432435
return !Res.areAllPreserved();
433436
}
434437

435-
Expected<RTCBundleInfo> jit_compiler::performPostLink(
436-
llvm::Module &Module, [[maybe_unused]] const InputArgList &UserArgList) {
438+
llvm::Expected<PostLinkResult> jit_compiler::performPostLink(
439+
std::unique_ptr<llvm::Module> Module,
440+
[[maybe_unused]] const llvm::opt::InputArgList &UserArgList) {
437441
// This is a simplified version of `processInputModule` in
438442
// `llvm/tools/sycl-post-link.cpp`. Assertions/TODOs point to functionality
439443
// left out of the algorithm for now.
440444

441-
assert(!Module.getGlobalVariable("llvm.used") &&
442-
!Module.getGlobalVariable("llvm.compiler.used"));
445+
// TODO: SplitMode can be controlled by the user.
446+
const auto SplitMode = SPLIT_NONE;
447+
448+
// TODO: EmitOnlyKernelsAsEntryPoints is controlled by
449+
// `shouldEmitOnlyKernelsAsEntryPoints` in
450+
// `clang/lib/Driver/ToolChains/Clang.cpp`.
451+
const bool EmitOnlyKernelsAsEntryPoints = true;
452+
453+
// TODO: The optlevel passed to `sycl-post-link` is determined by
454+
// `getSYCLPostLinkOptimizationLevel` in
455+
// `clang/lib/Driver/ToolChains/Clang.cpp`.
456+
const bool PerformOpts = true;
457+
458+
// Propagate ESIMD attribute to wrapper functions to prevent spurious splits
459+
// and kernel link errors.
460+
runModulePass<SYCLFixupESIMDKernelWrapperMDPass>(*Module);
461+
462+
assert(!Module->getGlobalVariable("llvm.used") &&
463+
!Module->getGlobalVariable("llvm.compiler.used"));
443464
// Otherwise: Port over the `removeSYCLKernelsConstRefArray` and
444465
// `removeDeviceGlobalFromCompilerUsed` methods.
445466

446-
assert(!isModuleUsingAsan(Module));
467+
assert(!isModuleUsingAsan(*Module));
447468
// Otherwise: Need to instrument each image scope device globals if the module
448469
// has been instrumented by sanitizer pass.
449470

450471
// Transform Joint Matrix builtin calls to align them with SPIR-V friendly
451472
// LLVM IR specification.
452-
runModulePass<SYCLJointMatrixTransformPass>(Module);
473+
runModulePass<SYCLJointMatrixTransformPass>(*Module);
474+
475+
// Do invoke_simd processing before splitting because this:
476+
// - saves processing time (the pass is run once, even though on larger IR)
477+
// - doing it before SYCL/ESIMD splitting is required for correctness
478+
if (runModulePass<SYCLLowerInvokeSimdPass>(*Module)) {
479+
return createStringError("`invoke_simd` calls detected");
480+
}
453481

454482
// TODO: Implement actual device code splitting. We're just using the splitter
455483
// to obtain additional information about the module for now.
456-
// TODO: EmitOnlyKernelsAsEntryPoints is controlled by
457-
// `shouldEmitOnlyKernelsAsEntryPoints` in
458-
// `clang/lib/Driver/ToolChains/Clang.cpp`.
484+
459485
std::unique_ptr<ModuleSplitterBase> Splitter = getDeviceCodeSplitter(
460-
ModuleDesc{std::unique_ptr<llvm::Module>{&Module}}, SPLIT_NONE,
461-
/*IROutputOnly=*/false,
462-
/*EmitOnlyKernelsAsEntryPoints=*/true);
463-
assert(Splitter->remainingSplits() == 1);
486+
ModuleDesc{std::move(Module)}, SplitMode,
487+
/*IROutputOnly=*/false, EmitOnlyKernelsAsEntryPoints);
488+
assert(Splitter->hasMoreSplits());
489+
if (Splitter->remainingSplits() > 1) {
490+
return createStringError("Device code requires splitting");
491+
}
464492

465493
// TODO: Call `verifyNoCrossModuleDeviceGlobalUsage` if device globals shall
466494
// be processed.
467495

468-
assert(Splitter->hasMoreSplits());
469496
ModuleDesc MDesc = Splitter->nextSplit();
470-
assert(&Module == &MDesc.getModule());
497+
498+
// TODO: Call `MDesc.fixupLinkageOfDirectInvokeSimdTargets()` when
499+
// `invoke_simd` is supported.
500+
501+
SmallVector<ModuleDesc, 2> ESIMDSplits =
502+
splitByESIMD(std::move(MDesc), EmitOnlyKernelsAsEntryPoints);
503+
assert(!ESIMDSplits.empty());
504+
if (ESIMDSplits.size() > 1) {
505+
return createStringError("Mixing SYCL and ESIMD code is unsupported");
506+
}
507+
MDesc = std::move(ESIMDSplits.front());
508+
509+
if (MDesc.isESIMD()) {
510+
// `sycl-post-link` has a `-lower-esimd` option, but there's no clang driver
511+
// option to influence it. Rather, the driver sets it unconditionally in the
512+
// multi-file output mode, which we are mimicking here.
513+
lowerEsimdConstructs(MDesc, PerformOpts);
514+
}
515+
471516
MDesc.saveSplitInformationAsMetadata();
472517

473518
RTCBundleInfo BundleInfo;
@@ -504,10 +549,7 @@ Expected<RTCBundleInfo> jit_compiler::performPostLink(
504549
}
505550
};
506551

507-
// Regain ownership of the module.
508-
MDesc.releaseModulePtr().release();
509-
510-
return std::move(BundleInfo);
552+
return PostLinkResult{std::move(BundleInfo), MDesc.releaseModulePtr()};
511553
}
512554

513555
Expected<InputArgList>
@@ -569,11 +611,9 @@ jit_compiler::parseUserArgs(View<const char *> UserArgs) {
569611
return createStringError("Device code splitting is not yet supported");
570612
}
571613

572-
if (AL.hasArg(OPT_fsycl_device_code_split_esimd,
573-
OPT_fno_sycl_device_code_split_esimd)) {
574-
// TODO: There are more ESIMD-related options.
575-
return createStringError(
576-
"Runtime compilation of ESIMD kernels is not yet supported");
614+
if (!AL.hasFlag(OPT_fsycl_device_code_split_esimd,
615+
OPT_fno_sycl_device_code_split_esimd, true)) {
616+
return createStringError("ESIMD device code split cannot be deactivated");
577617
}
578618

579619
if (AL.hasFlag(OPT_fsycl_dead_args_optimization,

sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,9 @@ llvm::Error linkDeviceLibraries(llvm::Module &Module,
3030
const llvm::opt::InputArgList &UserArgList,
3131
std::string &BuildLog);
3232

33-
llvm::Expected<RTCBundleInfo>
34-
performPostLink(llvm::Module &Module,
33+
using PostLinkResult = std::pair<RTCBundleInfo, std::unique_ptr<llvm::Module>>;
34+
llvm::Expected<PostLinkResult>
35+
performPostLink(std::unique_ptr<llvm::Module> Module,
3536
const llvm::opt::InputArgList &UserArgList);
3637

3738
llvm::Expected<llvm::opt::InputArgList>
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
//===------------- ESIMD.cpp - Driver for ESIMD lowering ------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "ESIMD.h"
10+
11+
#include "llvm/Analysis/CGSCCPassManager.h"
12+
#include "llvm/Analysis/LoopAnalysisManager.h"
13+
#include "llvm/GenXIntrinsics/GenXSPIRVWriterAdaptor.h"
14+
#include "llvm/IR/Module.h"
15+
#include "llvm/IR/PassManager.h"
16+
#include "llvm/Passes/PassBuilder.h"
17+
#include "llvm/SYCLLowerIR/ESIMD/LowerESIMD.h"
18+
#include "llvm/Transforms/InstCombine/InstCombine.h"
19+
#include "llvm/Transforms/Scalar/DCE.h"
20+
#include "llvm/Transforms/Scalar/EarlyCSE.h"
21+
#include "llvm/Transforms/Scalar/SROA.h"
22+
23+
using namespace llvm;
24+
25+
using string_vector = std::vector<std::string>;
26+
27+
// When ESIMD code was separated from the regular SYCL code,
28+
// we can safely process ESIMD part.
29+
void jit_compiler::lowerEsimdConstructs(module_split::ModuleDesc &MD,
30+
bool PerformOpts) {
31+
LoopAnalysisManager LAM;
32+
CGSCCAnalysisManager CGAM;
33+
FunctionAnalysisManager FAM;
34+
ModuleAnalysisManager MAM;
35+
36+
PassBuilder PB;
37+
PB.registerModuleAnalyses(MAM);
38+
PB.registerCGSCCAnalyses(CGAM);
39+
PB.registerFunctionAnalyses(FAM);
40+
PB.registerLoopAnalyses(LAM);
41+
PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
42+
43+
ModulePassManager MPM;
44+
MPM.addPass(SYCLLowerESIMDPass(/*ModuleContainsScalar=*/false));
45+
46+
if (PerformOpts) {
47+
FunctionPassManager FPM;
48+
FPM.addPass(SROAPass(SROAOptions::ModifyCFG));
49+
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
50+
}
51+
MPM.addPass(ESIMDOptimizeVecArgCallConvPass{});
52+
FunctionPassManager MainFPM;
53+
MainFPM.addPass(ESIMDLowerLoadStorePass{});
54+
55+
if (PerformOpts) {
56+
MainFPM.addPass(SROAPass(SROAOptions::ModifyCFG));
57+
MainFPM.addPass(EarlyCSEPass(true));
58+
MainFPM.addPass(InstCombinePass{});
59+
MainFPM.addPass(DCEPass{});
60+
// TODO: maybe remove some passes below that don't affect code quality
61+
MainFPM.addPass(SROAPass(SROAOptions::ModifyCFG));
62+
MainFPM.addPass(EarlyCSEPass(true));
63+
MainFPM.addPass(InstCombinePass{});
64+
MainFPM.addPass(DCEPass{});
65+
}
66+
MPM.addPass(ESIMDLowerSLMReservationCalls{});
67+
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(MainFPM)));
68+
MPM.addPass(GenXSPIRVWriterAdaptor(/*RewriteTypes=*/true,
69+
/*RewriteSingleElementVectorsIn*/ false));
70+
// GenXSPIRVWriterAdaptor pass replaced some functions with "rewritten"
71+
// versions so the entry point table must be rebuilt.
72+
// TODO Change entry point search to analysis?
73+
std::vector<std::string> Names;
74+
MD.saveEntryPointNames(Names);
75+
MPM.run(MD.getModule(), MAM);
76+
MD.rebuildEntryPoints(Names);
77+
}

sycl-jit/jit-compiler/lib/rtc/ESIMD.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
//===-------------- ESIMD.h - Driver for ESIMD lowering -------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef SYCL_JIT_COMPILER_RTC_ESIMD_H
10+
#define SYCL_JIT_COMPILER_RTC_ESIMD_H
11+
12+
#include "llvm/SYCLLowerIR/ModuleSplitter.h"
13+
14+
namespace jit_compiler {
15+
16+
// Runs a pass pipeline to lower ESIMD constructs on the given split model,
17+
// which must only contain ESIMD entrypoints. This is a copy of the similar
18+
// function in `sycl-post-link`.
19+
void lowerEsimdConstructs(llvm::module_split::ModuleDesc &MD, bool PerformOpts);
20+
21+
} // namespace jit_compiler
22+
23+
#endif // SYCL_JIT_COMPILER_RTC_ESIMD_H

0 commit comments

Comments
 (0)