Skip to content

Commit 2ab36be

Browse files
committed
Support multiple device images per RTC bundle.
Signed-off-by: Julian Oppermann <[email protected]>
1 parent 4a2e36e commit 2ab36be

File tree

10 files changed

+203
-139
lines changed

10 files changed

+203
-139
lines changed

sycl-jit/common/include/Kernel.h

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ struct InMemoryFile {
359359
const char *Contents;
360360
};
361361

362-
using RTCBundleBinaryInfo = SYCLKernelBinaryInfo;
362+
using RTCDevImgBinaryInfo = SYCLKernelBinaryInfo;
363363
using FrozenSymbolTable = DynArray<sycl::detail::string>;
364364

365365
// Note: `FrozenPropertyValue` and `FrozenPropertySet` constructors take
@@ -399,16 +399,18 @@ struct FrozenPropertySet {
399399

400400
using FrozenPropertyRegistry = DynArray<FrozenPropertySet>;
401401

402-
struct RTCBundleInfo {
403-
RTCBundleBinaryInfo BinaryInfo;
402+
struct RTCDevImgInfo {
403+
RTCDevImgBinaryInfo BinaryInfo;
404404
FrozenSymbolTable SymbolTable;
405405
FrozenPropertyRegistry Properties;
406406

407-
RTCBundleInfo() = default;
408-
RTCBundleInfo(RTCBundleInfo &&) = default;
409-
RTCBundleInfo &operator=(RTCBundleInfo &&) = default;
407+
RTCDevImgInfo() = default;
408+
RTCDevImgInfo(RTCDevImgInfo &&) = default;
409+
RTCDevImgInfo &operator=(RTCDevImgInfo &&) = default;
410410
};
411411

412+
using RTCBundleInfo = DynArray<RTCDevImgInfo>;
413+
412414
} // namespace jit_compiler
413415

414416
#endif // SYCL_FUSION_COMMON_KERNEL_H

sycl-jit/jit-compiler/lib/KernelFusion.cpp

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -266,17 +266,18 @@ compileSYCL(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
266266
return errorTo<RTCResult>(PostLinkResultOrError.takeError(),
267267
"Post-link phase failed");
268268
}
269-
RTCBundleInfo BundleInfo;
270-
std::tie(BundleInfo, Module) = std::move(*PostLinkResultOrError);
271-
272-
auto BinaryInfoOrError =
273-
translation::KernelTranslator::translateBundleToSPIRV(
274-
*Module, JITContext::getInstance());
275-
if (!BinaryInfoOrError) {
276-
return errorTo<RTCResult>(BinaryInfoOrError.takeError(),
277-
"SPIR-V translation failed");
269+
auto [BundleInfo, Modules] = std::move(*PostLinkResultOrError);
270+
271+
for (auto [DevImgInfo, Module] : llvm::zip_equal(BundleInfo, Modules)) {
272+
auto BinaryInfoOrError =
273+
translation::KernelTranslator::translateDevImgToSPIRV(
274+
*Module, JITContext::getInstance());
275+
if (!BinaryInfoOrError) {
276+
return errorTo<RTCResult>(BinaryInfoOrError.takeError(),
277+
"SPIR-V translation failed");
278+
}
279+
DevImgInfo.BinaryInfo = std::move(*BinaryInfoOrError);
278280
}
279-
BundleInfo.BinaryInfo = std::move(*BinaryInfoOrError);
280281

281282
return RTCResult{std::move(BundleInfo), BuildLog.c_str()};
282283
}

sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.cpp

Lines changed: 92 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -232,8 +232,6 @@ Expected<std::unique_ptr<llvm::Module>> jit_compiler::compileDeviceCode(
232232
DerivedArgList DAL{UserArgList};
233233
const auto &OptTable = getDriverOptTable();
234234
DAL.AddFlagArg(nullptr, OptTable.getOption(OPT_fsycl_device_only));
235-
DAL.AddFlagArg(nullptr,
236-
OptTable.getOption(OPT_fno_sycl_dead_args_optimization));
237235
DAL.AddJoinedArg(
238236
nullptr, OptTable.getOption(OPT_resource_dir_EQ),
239237
(DPCPPRoot + "/lib/clang/" + Twine(CLANG_VERSION_MAJOR)).str());
@@ -435,15 +433,35 @@ template <class PassClass> static bool runModulePass(llvm::Module &M) {
435433
return !Res.areAllPreserved();
436434
}
437435

438-
llvm::Expected<PostLinkResult> jit_compiler::performPostLink(
439-
std::unique_ptr<llvm::Module> Module,
440-
[[maybe_unused]] const llvm::opt::InputArgList &UserArgList) {
436+
static IRSplitMode getDeviceCodeSplitMode(const InputArgList &UserArgList) {
437+
// This is the (combined) logic from
438+
// `get[NonTriple|Triple]BasedSYCLPostLinkOpts` in
439+
// `clang/lib/Driver/ToolChains/Clang.cpp`: Default is auto mode, but the user
440+
// can override it by specifying the `-fsycl-device-code-split=` option. The
441+
// no-argument variant `-fsycl-device-code-split` is ignored.
442+
if (auto *Arg = UserArgList.getLastArg(OPT_fsycl_device_code_split_EQ)) {
443+
StringRef ArgVal{Arg->getValue()};
444+
if (ArgVal == "per_kernel") {
445+
return SPLIT_PER_KERNEL;
446+
}
447+
if (ArgVal == "per_source") {
448+
return SPLIT_PER_TU;
449+
}
450+
if (ArgVal == "off") {
451+
return SPLIT_NONE;
452+
}
453+
}
454+
return SPLIT_AUTO;
455+
}
456+
457+
Expected<PostLinkResult>
458+
jit_compiler::performPostLink(std::unique_ptr<llvm::Module> Module,
459+
const InputArgList &UserArgList) {
441460
// This is a simplified version of `processInputModule` in
442461
// `llvm/tools/sycl-post-link.cpp`. Assertions/TODOs point to functionality
443462
// left out of the algorithm for now.
444463

445-
// TODO: SplitMode can be controlled by the user.
446-
const auto SplitMode = SPLIT_NONE;
464+
const auto SplitMode = getDeviceCodeSplitMode(UserArgList);
447465

448466
// TODO: EmitOnlyKernelsAsEntryPoints is controlled by
449467
// `shouldEmitOnlyKernelsAsEntryPoints` in
@@ -486,70 +504,83 @@ llvm::Expected<PostLinkResult> jit_compiler::performPostLink(
486504
ModuleDesc{std::move(Module)}, SplitMode,
487505
/*IROutputOnly=*/false, EmitOnlyKernelsAsEntryPoints);
488506
assert(Splitter->hasMoreSplits());
489-
if (Splitter->remainingSplits() > 1) {
490-
return createStringError("Device code requires splitting");
491-
}
492507

493508
// TODO: Call `verifyNoCrossModuleDeviceGlobalUsage` if device globals shall
494509
// be processed.
495510

496-
ModuleDesc MDesc = Splitter->nextSplit();
511+
// TODO: This allocation assumes that there are no further splits required,
512+
// i.e. due to mixed SYCL/ESIMD modules.
513+
RTCBundleInfo BundleInfo{Splitter->remainingSplits()};
514+
SmallVector<std::unique_ptr<llvm::Module>> Modules;
497515

498-
// TODO: Call `MDesc.fixupLinkageOfDirectInvokeSimdTargets()` when
499-
// `invoke_simd` is supported.
516+
auto *DevImgInfoIt = BundleInfo.begin();
517+
while (Splitter->hasMoreSplits()) {
518+
assert(DevImgInfoIt != BundleInfo.end());
500519

501-
SmallVector<ModuleDesc, 2> ESIMDSplits =
502-
splitByESIMD(std::move(MDesc), EmitOnlyKernelsAsEntryPoints);
503-
assert(!ESIMDSplits.empty());
504-
if (ESIMDSplits.size() > 1) {
505-
return createStringError("Mixing SYCL and ESIMD code is unsupported");
506-
}
507-
MDesc = std::move(ESIMDSplits.front());
520+
ModuleDesc MDesc = Splitter->nextSplit();
521+
RTCDevImgInfo &DevImgInfo = *DevImgInfoIt++;
508522

509-
if (MDesc.isESIMD()) {
510-
// `sycl-post-link` has a `-lower-esimd` option, but there's no clang driver
511-
// option to influence it. Rather, the driver sets it unconditionally in the
512-
// multi-file output mode, which we are mimicking here.
513-
lowerEsimdConstructs(MDesc, PerformOpts);
514-
}
523+
// TODO: Call `MDesc.fixupLinkageOfDirectInvokeSimdTargets()` when
524+
// `invoke_simd` is supported.
515525

516-
MDesc.saveSplitInformationAsMetadata();
517-
518-
RTCBundleInfo BundleInfo;
519-
BundleInfo.SymbolTable = FrozenSymbolTable{MDesc.entries().size()};
520-
transform(MDesc.entries(), BundleInfo.SymbolTable.begin(),
521-
[](Function *F) { return F->getName(); });
522-
523-
// TODO: Determine what is requested.
524-
GlobalBinImageProps PropReq{
525-
/*EmitKernelParamInfo=*/true, /*EmitProgramMetadata=*/true,
526-
/*EmitExportedSymbols=*/true, /*EmitImportedSymbols=*/true,
527-
/*DeviceGlobals=*/false};
528-
PropertySetRegistry Properties =
529-
computeModuleProperties(MDesc.getModule(), MDesc.entries(), PropReq);
530-
// TODO: Manually add `compile_target` property as in
531-
// `saveModuleProperties`?
532-
const auto &PropertySets = Properties.getPropSets();
533-
534-
BundleInfo.Properties = FrozenPropertyRegistry{PropertySets.size()};
535-
for (auto &&[KV, FrozenPropSet] : zip(PropertySets, BundleInfo.Properties)) {
536-
const auto &PropertySetName = KV.first;
537-
const auto &PropertySet = KV.second;
538-
FrozenPropSet =
539-
FrozenPropertySet{PropertySetName.str(), PropertySet.size()};
540-
for (auto &&[KV2, FrozenProp] : zip(PropertySet, FrozenPropSet.Values)) {
541-
const auto &PropertyName = KV2.first;
542-
const auto &PropertyValue = KV2.second;
543-
FrozenProp = PropertyValue.getType() == PropertyValue::Type::UINT32
544-
? FrozenPropertyValue{PropertyName.str(),
545-
PropertyValue.asUint32()}
546-
: FrozenPropertyValue{
547-
PropertyName.str(), PropertyValue.asRawByteArray(),
548-
PropertyValue.getRawByteArraySize()};
526+
SmallVector<ModuleDesc, 2> ESIMDSplits =
527+
splitByESIMD(std::move(MDesc), EmitOnlyKernelsAsEntryPoints);
528+
assert(!ESIMDSplits.empty());
529+
if (ESIMDSplits.size() > 1) {
530+
return createStringError("Mixing SYCL and ESIMD code is unsupported");
549531
}
550-
};
532+
MDesc = std::move(ESIMDSplits.front());
533+
534+
if (MDesc.isESIMD()) {
535+
// `sycl-post-link` has a `-lower-esimd` option, but there's no clang
536+
// driver option to influence it. Rather, the driver sets it
537+
// unconditionally in the multi-file output mode, which we are mimicking
538+
// here.
539+
lowerEsimdConstructs(MDesc, PerformOpts);
540+
}
541+
542+
MDesc.saveSplitInformationAsMetadata();
543+
544+
DevImgInfo.SymbolTable = FrozenSymbolTable{MDesc.entries().size()};
545+
transform(MDesc.entries(), DevImgInfo.SymbolTable.begin(),
546+
[](Function *F) { return F->getName(); });
547+
548+
// TODO: Determine what is requested.
549+
GlobalBinImageProps PropReq{
550+
/*EmitKernelParamInfo=*/true, /*EmitProgramMetadata=*/true,
551+
/*EmitExportedSymbols=*/true, /*EmitImportedSymbols=*/true,
552+
/*DeviceGlobals=*/false};
553+
PropertySetRegistry Properties =
554+
computeModuleProperties(MDesc.getModule(), MDesc.entries(), PropReq);
555+
// TODO: Manually add `compile_target` property as in
556+
// `saveModuleProperties`?
557+
const auto &PropertySets = Properties.getPropSets();
558+
559+
DevImgInfo.Properties = FrozenPropertyRegistry{PropertySets.size()};
560+
for (auto [KV, FrozenPropSet] :
561+
zip_equal(PropertySets, DevImgInfo.Properties)) {
562+
const auto &PropertySetName = KV.first;
563+
const auto &PropertySet = KV.second;
564+
FrozenPropSet =
565+
FrozenPropertySet{PropertySetName.str(), PropertySet.size()};
566+
for (auto [KV2, FrozenProp] :
567+
zip_equal(PropertySet, FrozenPropSet.Values)) {
568+
const auto &PropertyName = KV2.first;
569+
const auto &PropertyValue = KV2.second;
570+
FrozenProp =
571+
PropertyValue.getType() == PropertyValue::Type::UINT32
572+
? FrozenPropertyValue{PropertyName.str(),
573+
PropertyValue.asUint32()}
574+
: FrozenPropertyValue{PropertyName.str(),
575+
PropertyValue.asRawByteArray(),
576+
PropertyValue.getRawByteArraySize()};
577+
}
578+
};
579+
580+
Modules.push_back(MDesc.releaseModulePtr());
581+
}
551582

552-
return PostLinkResult{std::move(BundleInfo), MDesc.releaseModulePtr()};
583+
return PostLinkResult{std::move(BundleInfo), std::move(Modules)};
553584
}
554585

555586
Expected<InputArgList>
@@ -606,21 +637,10 @@ jit_compiler::parseUserArgs(View<const char *> UserArgs) {
606637
}
607638
}
608639

609-
if (auto DCSMode = AL.getLastArgValue(OPT_fsycl_device_code_split_EQ, "none");
610-
DCSMode != "none" && DCSMode != "auto") {
611-
return createStringError("Device code splitting is not yet supported");
612-
}
613-
614640
if (!AL.hasFlag(OPT_fsycl_device_code_split_esimd,
615641
OPT_fno_sycl_device_code_split_esimd, true)) {
616642
return createStringError("ESIMD device code split cannot be deactivated");
617643
}
618644

619-
if (AL.hasFlag(OPT_fsycl_dead_args_optimization,
620-
OPT_fno_sycl_dead_args_optimization, false)) {
621-
return createStringError(
622-
"Dead argument optimization must be disabled for runtime compilation");
623-
}
624-
625645
return std::move(AL);
626646
}

sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "Kernel.h"
1313
#include "View.h"
1414

15+
#include <llvm/ADT/SmallVector.h>
1516
#include <llvm/IR/Module.h>
1617
#include <llvm/Option/ArgList.h>
1718
#include <llvm/Support/Error.h>
@@ -30,7 +31,8 @@ llvm::Error linkDeviceLibraries(llvm::Module &Module,
3031
const llvm::opt::InputArgList &UserArgList,
3132
std::string &BuildLog);
3233

33-
using PostLinkResult = std::pair<RTCBundleInfo, std::unique_ptr<llvm::Module>>;
34+
using PostLinkResult =
35+
std::pair<RTCBundleInfo, llvm::SmallVector<std::unique_ptr<llvm::Module>>>;
3436
llvm::Expected<PostLinkResult>
3537
performPostLink(std::unique_ptr<llvm::Module> Module,
3638
const llvm::opt::InputArgList &UserArgList);

sycl-jit/jit-compiler/lib/translation/KernelTranslation.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -222,18 +222,18 @@ llvm::Error KernelTranslator::translateKernel(SYCLKernelInfo &Kernel,
222222
return Error::success();
223223
}
224224

225-
llvm::Expected<RTCBundleBinaryInfo>
226-
KernelTranslator::translateBundleToSPIRV(llvm::Module &Mod,
225+
llvm::Expected<RTCDevImgBinaryInfo>
226+
KernelTranslator::translateDevImgToSPIRV(llvm::Module &Mod,
227227
JITContext &JITCtx) {
228228
llvm::Expected<KernelBinary *> BinaryOrError = translateToSPIRV(Mod, JITCtx);
229229
if (auto Error = BinaryOrError.takeError()) {
230230
return Error;
231231
}
232232
KernelBinary *Binary = *BinaryOrError;
233-
RTCBundleBinaryInfo BBI{BinaryFormat::SPIRV,
234-
Mod.getDataLayout().getPointerSizeInBits(),
235-
Binary->address(), Binary->size()};
236-
return BBI;
233+
RTCDevImgBinaryInfo DIBI{BinaryFormat::SPIRV,
234+
Mod.getDataLayout().getPointerSizeInBits(),
235+
Binary->address(), Binary->size()};
236+
return DIBI;
237237
}
238238

239239
llvm::Expected<KernelBinary *>

sycl-jit/jit-compiler/lib/translation/KernelTranslation.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ class KernelTranslator {
2727
static llvm::Error translateKernel(SYCLKernelInfo &Kernel, llvm::Module &Mod,
2828
JITContext &JITCtx, BinaryFormat Format);
2929

30-
static llvm::Expected<RTCBundleBinaryInfo>
31-
translateBundleToSPIRV(llvm::Module &Mod, JITContext &JITCtx);
30+
static llvm::Expected<RTCDevImgBinaryInfo>
31+
translateDevImgToSPIRV(llvm::Module &Mod, JITContext &JITCtx);
3232

3333
private:
3434
///

0 commit comments

Comments
 (0)