Skip to content

Commit 5ef2c7c

Browse files
authored
[SYCL][RTC] Query kernels by source code name (#17032)
This PR adds full support for the `registered_kernel_names` property to query kernels by their souce code name (and instantiate template kernels), leveraging the new `[[__sycl_detail__::__registered_kernels__(...)]]` attribute added in #16485 and #16821. Also, `kernel_bundle::ext_oneapi_get_raw_kernel_name` is implemented following the draft spec #11985. --------- Signed-off-by: Julian Oppermann <[email protected]>
1 parent 4c9b19b commit 5ef2c7c

16 files changed

+170
-73
lines changed

sycl/include/sycl/kernel_bundle.hpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,11 @@ class __SYCL_EXPORT kernel_bundle_plain {
235235
return ext_oneapi_get_kernel(detail::string_view{name});
236236
}
237237

238+
std::string ext_oneapi_get_raw_kernel_name(const std::string &name) {
239+
return std::string{
240+
ext_oneapi_get_raw_kernel_name(detail::string_view{name}).c_str()};
241+
}
242+
238243
protected:
239244
// \returns a kernel object which represents the kernel identified by
240245
// kernel_id passed
@@ -263,6 +268,7 @@ class __SYCL_EXPORT kernel_bundle_plain {
263268
private:
264269
bool ext_oneapi_has_kernel(detail::string_view name);
265270
kernel ext_oneapi_get_kernel(detail::string_view name);
271+
detail::string ext_oneapi_get_raw_kernel_name(detail::string_view name);
266272
};
267273

268274
} // namespace detail
@@ -483,6 +489,16 @@ class kernel_bundle : public detail::kernel_bundle_plain,
483489
return detail::kernel_bundle_plain::ext_oneapi_get_kernel(name);
484490
}
485491

492+
/////////////////////////
493+
// ext_oneapi_get_raw_kernel_name
494+
// kernel_bundle must be created from source, throws if not present
495+
/////////////////////////
496+
template <bundle_state _State = State,
497+
typename = std::enable_if_t<_State == bundle_state::executable>>
498+
std::string ext_oneapi_get_raw_kernel_name(const std::string &name) {
499+
return detail::kernel_bundle_plain::ext_oneapi_get_raw_kernel_name(name);
500+
}
501+
486502
private:
487503
kernel_bundle(detail::KernelBundleImplPtr Impl)
488504
: kernel_bundle_plain(std::move(Impl)) {}

sycl/source/detail/compiler.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@
6868
#define __SYCL_PROPERTY_SET_SYCL_VIRTUAL_FUNCTIONS "SYCL/virtual functions"
6969
/// PropertySetRegistry::SYCL_IMPLICIT_LOCAL_ARG defined in PropertySetIO.h
7070
#define __SYCL_PROPERTY_SET_SYCL_IMPLICIT_LOCAL_ARG "SYCL/implicit local arg"
71+
/// PropertySetRegistry::SYCL_REGISTERED_KERNELS defined in PropertySetIO.h
72+
#define __SYCL_PROPERTY_SET_SYCL_REGISTERED_KERNELS "SYCL/registered kernels"
7173

7274
/// Program metadata tags recognized by the PI backends. For kernels the tag
7375
/// must appear after the kernel name.

sycl/source/detail/device_binary_image.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ void RTDeviceBinaryImage::init(sycl_device_binary Bin) {
195195
DeviceRequirements.init(Bin, __SYCL_PROPERTY_SET_SYCL_DEVICE_REQUIREMENTS);
196196
HostPipes.init(Bin, __SYCL_PROPERTY_SET_SYCL_HOST_PIPES);
197197
VirtualFunctions.init(Bin, __SYCL_PROPERTY_SET_SYCL_VIRTUAL_FUNCTIONS);
198+
RegisteredKernels.init(Bin, __SYCL_PROPERTY_SET_SYCL_REGISTERED_KERNELS);
198199

199200
ImageId = ImageCounter++;
200201
}

sycl/source/detail/device_binary_image.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,9 @@ class RTDeviceBinaryImage {
232232
const PropertyRange &getHostPipes() const { return HostPipes; }
233233
const PropertyRange &getVirtualFunctions() const { return VirtualFunctions; }
234234
const PropertyRange &getImplicitLocalArg() const { return ImplicitLocalArg; }
235+
const PropertyRange &getRegisteredKernels() const {
236+
return RegisteredKernels;
237+
}
235238

236239
std::uintptr_t getImageID() const {
237240
assert(Bin && "Image ID is not available without a binary image.");
@@ -258,6 +261,7 @@ class RTDeviceBinaryImage {
258261
RTDeviceBinaryImage::PropertyRange HostPipes;
259262
RTDeviceBinaryImage::PropertyRange VirtualFunctions;
260263
RTDeviceBinaryImage::PropertyRange ImplicitLocalArg;
264+
RTDeviceBinaryImage::PropertyRange RegisteredKernels;
261265

262266
std::vector<ur_program_metadata_t> ProgramMetadataUR;
263267

sycl/source/detail/jit_compiler.cpp

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1259,29 +1259,16 @@ std::vector<uint8_t> jit_compiler::encodeReqdWorkGroupSize(
12591259
std::pair<sycl_device_binaries, std::string> jit_compiler::compileSYCL(
12601260
const std::string &CompilationID, const std::string &SYCLSource,
12611261
const std::vector<std::pair<std::string, std::string>> &IncludePairs,
1262-
const std::vector<std::string> &UserArgs, std::string *LogPtr,
1263-
const std::vector<std::string> &RegisteredKernelNames) {
1262+
const std::vector<std::string> &UserArgs, std::string *LogPtr) {
12641263
auto appendToLog = [LogPtr](const char *Msg) {
12651264
if (LogPtr) {
12661265
LogPtr->append(Msg);
12671266
}
12681267
};
12691268

1270-
// RegisteredKernelNames may contain template specializations, so we just put
1271-
// them in main() which ensures they are instantiated.
1272-
std::ostringstream ss;
1273-
ss << SYCLSource << '\n';
1274-
ss << "int main() {\n";
1275-
for (const std::string &KernelName : RegisteredKernelNames) {
1276-
ss << " (void)" << KernelName << ";\n";
1277-
}
1278-
ss << " return 0;\n}\n" << std::endl;
1279-
1280-
std::string FinalSource = ss.str();
1281-
12821269
std::string SYCLFileName = CompilationID + ".cpp";
12831270
::jit_compiler::InMemoryFile SourceFile{SYCLFileName.c_str(),
1284-
FinalSource.c_str()};
1271+
SYCLSource.c_str()};
12851272

12861273
std::vector<::jit_compiler::InMemoryFile> IncludeFilesView;
12871274
IncludeFilesView.reserve(IncludePairs.size());

sycl/source/detail/jit_compiler.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,7 @@ class jit_compiler {
5252
std::pair<sycl_device_binaries, std::string> compileSYCL(
5353
const std::string &CompilationID, const std::string &SYCLSource,
5454
const std::vector<std::pair<std::string, std::string>> &IncludePairs,
55-
const std::vector<std::string> &UserArgs, std::string *LogPtr,
56-
const std::vector<std::string> &RegisteredKernelNames);
55+
const std::vector<std::string> &UserArgs, std::string *LogPtr);
5756

5857
void destroyDeviceBinaries(sycl_device_binaries Binaries);
5958

sycl/source/detail/kernel_bundle_impl.hpp

Lines changed: 101 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -378,11 +378,13 @@ class kernel_bundle_impl {
378378

379379
// oneapi_ext_kernel_compiler
380380
// program manager integration, only for sycl_jit language
381-
kernel_bundle_impl(context Ctx, std::vector<device> Devs,
382-
const std::vector<kernel_id> &KernelIDs,
383-
std::vector<std::string> KNames,
384-
sycl_device_binaries Binaries, std::string Pfx,
385-
syclex::source_language Lang)
381+
kernel_bundle_impl(
382+
context Ctx, std::vector<device> Devs,
383+
const std::vector<kernel_id> &KernelIDs,
384+
std::vector<std::string> &&KernelNames,
385+
std::unordered_map<std::string, std::string> &&MangledKernelNames,
386+
sycl_device_binaries Binaries, std::string &&Prefix,
387+
syclex::source_language Lang)
386388
: kernel_bundle_impl(std::move(Ctx), std::move(Devs), KernelIDs,
387389
bundle_state::executable) {
388390
assert(Lang == syclex::source_language::sycl_jit);
@@ -392,9 +394,10 @@ class kernel_bundle_impl {
392394
// loaded via the program manager have `kernel_id`s, they can't be looked up
393395
// from the (unprefixed) kernel name.
394396
MIsInterop = true;
395-
MKernelNames = std::move(KNames);
397+
MKernelNames = std::move(KernelNames);
398+
MMangledKernelNames = std::move(MangledKernelNames);
396399
MDeviceBinaries = Binaries;
397-
MPrefix = std::move(Pfx);
400+
MPrefix = std::move(Prefix);
398401
MLanguage = Lang;
399402
}
400403

@@ -499,27 +502,70 @@ class kernel_bundle_impl {
499502
if (MLanguage == syclex::source_language::sycl_jit) {
500503
// Build device images via the program manager.
501504
const std::string &SourceStr = std::get<std::string>(MSource);
505+
std::ostringstream SourceExt;
506+
if (!RegisteredKernelNames.empty()) {
507+
SourceExt << SourceStr << '\n';
508+
509+
auto EmitEntry =
510+
[&SourceExt](const std::string &Name) -> std::ostringstream & {
511+
SourceExt << " {\"" << Name << "\", " << Name << "}";
512+
return SourceExt;
513+
};
514+
515+
SourceExt << "[[__sycl_detail__::__registered_kernels__(\n";
516+
for (auto It = RegisteredKernelNames.begin(),
517+
SecondToLast = RegisteredKernelNames.end() - 1;
518+
It != SecondToLast; ++It) {
519+
EmitEntry(*It) << ",\n";
520+
}
521+
EmitEntry(RegisteredKernelNames.back()) << "\n";
522+
SourceExt << ")]];\n";
523+
}
524+
502525
auto [Binaries, Prefix] = syclex::detail::SYCL_JIT_to_SPIRV(
503-
SourceStr, MIncludePairs, BuildOptions, LogPtr,
504-
RegisteredKernelNames);
526+
RegisteredKernelNames.empty() ? SourceStr : SourceExt.str(),
527+
MIncludePairs, BuildOptions, LogPtr);
505528

506529
auto &PM = detail::ProgramManager::getInstance();
507530
PM.addImages(Binaries);
508531

509532
std::vector<kernel_id> KernelIDs;
510533
std::vector<std::string> KernelNames;
534+
std::unordered_map<std::string, std::string> MangledKernelNames;
511535
for (const auto &KernelID : PM.getAllSYCLKernelIDs()) {
512536
std::string_view KernelName{KernelID.get_name()};
513537
if (KernelName.find(Prefix) == 0) {
514538
KernelIDs.push_back(KernelID);
515539
KernelName.remove_prefix(Prefix.length());
516540
KernelNames.emplace_back(KernelName);
541+
static constexpr std::string_view SYCLKernelMarker{"__sycl_kernel_"};
542+
if (KernelName.find(SYCLKernelMarker) == 0) {
543+
// extern "C" declaration, implicitly register kernel without the
544+
// marker.
545+
std::string_view KernelNameWithoutMarker{KernelName};
546+
KernelNameWithoutMarker.remove_prefix(SYCLKernelMarker.length());
547+
MangledKernelNames.emplace(KernelNameWithoutMarker, KernelName);
548+
}
517549
}
518550
}
519551

520-
return std::make_shared<kernel_bundle_impl>(MContext, MDevices, KernelIDs,
521-
KernelNames, Binaries, Prefix,
522-
MLanguage);
552+
// Apply frontend information.
553+
for (const auto *RawImg : PM.getRawDeviceImages(KernelIDs)) {
554+
for (const sycl_device_binary_property &RKProp :
555+
RawImg->getRegisteredKernels()) {
556+
557+
auto BA = DeviceBinaryProperty(RKProp).asByteArray();
558+
auto MangledNameLen = BA.consume<uint64_t>() / 8 /*bits in a byte*/;
559+
std::string_view MangledName{
560+
reinterpret_cast<const char *>(BA.begin()), MangledNameLen};
561+
MangledKernelNames.emplace(RKProp->Name, MangledName);
562+
}
563+
}
564+
565+
return std::make_shared<kernel_bundle_impl>(
566+
MContext, MDevices, KernelIDs, std::move(KernelNames),
567+
std::move(MangledKernelNames), Binaries, std::move(Prefix),
568+
MLanguage);
523569
}
524570

525571
ur_program_handle_t UrProgram = nullptr;
@@ -625,21 +671,27 @@ class kernel_bundle_impl {
625671
KernelNames, MLanguage);
626672
}
627673

628-
std::string adjust_kernel_name(const std::string &Name,
629-
syclex::source_language Lang) {
630-
// Once name demangling support is in, we won't need this.
631-
if (Lang != syclex::source_language::sycl &&
632-
Lang != syclex::source_language::sycl_jit)
633-
return Name;
674+
std::string adjust_kernel_name(const std::string &Name) {
675+
if (MLanguage == syclex::source_language::sycl_jit) {
676+
auto It = MMangledKernelNames.find(Name);
677+
return It == MMangledKernelNames.end() ? Name : It->second;
678+
}
634679

635-
bool isMangled = Name.find("__sycl_kernel_") != std::string::npos;
636-
return isMangled ? Name : "__sycl_kernel_" + Name;
680+
if (MLanguage == syclex::source_language::sycl) {
681+
bool isMangled = Name.find("__sycl_kernel_") != std::string::npos;
682+
return isMangled ? Name : "__sycl_kernel_" + Name;
683+
}
684+
685+
return Name;
686+
}
687+
688+
bool is_kernel_name(const std::string &Name) {
689+
return std::find(MKernelNames.begin(), MKernelNames.end(), Name) !=
690+
MKernelNames.end();
637691
}
638692

639693
bool ext_oneapi_has_kernel(const std::string &Name) {
640-
auto it = std::find(MKernelNames.begin(), MKernelNames.end(),
641-
adjust_kernel_name(Name, MLanguage));
642-
return it != MKernelNames.end();
694+
return is_kernel_name(adjust_kernel_name(Name));
643695
}
644696

645697
kernel
@@ -649,13 +701,12 @@ class kernel_bundle_impl {
649701
throw sycl::exception(make_error_code(errc::invalid),
650702
"'ext_oneapi_get_kernel' is only available in "
651703
"kernel_bundles successfully built from "
652-
"kernel_bundle<bundle_state:ext_oneapi_source>.");
704+
"kernel_bundle<bundle_state::ext_oneapi_source>.");
653705

654-
std::string AdjustedName = adjust_kernel_name(Name, MLanguage);
655-
if (!ext_oneapi_has_kernel(Name))
706+
std::string AdjustedName = adjust_kernel_name(Name);
707+
if (!is_kernel_name(AdjustedName))
656708
throw sycl::exception(make_error_code(errc::invalid),
657-
"kernel '" + AdjustedName +
658-
"' not found in kernel_bundle");
709+
"kernel '" + Name + "' not found in kernel_bundle");
659710

660711
if (MLanguage == syclex::source_language::sycl_jit) {
661712
auto &PM = ProgramManager::getInstance();
@@ -697,6 +748,22 @@ class kernel_bundle_impl {
697748
return detail::createSyclObjFromImpl<kernel>(KernelImpl);
698749
}
699750

751+
std::string ext_oneapi_get_raw_kernel_name(const std::string &Name) {
752+
if (MKernelNames.empty())
753+
throw sycl::exception(
754+
make_error_code(errc::invalid),
755+
"'ext_oneapi_get_raw_kernel_name' is only available in "
756+
"kernel_bundles successfully built from "
757+
"kernel_bundle<bundle_state::ext_oneapi_source>.");
758+
759+
std::string AdjustedName = adjust_kernel_name(Name);
760+
if (!is_kernel_name(AdjustedName))
761+
throw sycl::exception(make_error_code(errc::invalid),
762+
"kernel '" + Name + "' not found in kernel_bundle");
763+
764+
return AdjustedName;
765+
}
766+
700767
bool empty() const noexcept { return MDeviceImages.empty(); }
701768

702769
backend get_backend() const noexcept {
@@ -872,12 +939,11 @@ class kernel_bundle_impl {
872939
}
873940

874941
bool is_specialization_constant_set(const char *SpecName) const noexcept {
875-
bool SetInDevImg =
876-
std::any_of(begin(), end(),
877-
[SpecName](const device_image_plain &DeviceImage) {
878-
return getSyclObjImpl(DeviceImage)
879-
->is_specialization_constant_set(SpecName);
880-
});
942+
bool SetInDevImg = std::any_of(
943+
begin(), end(), [SpecName](const device_image_plain &DeviceImage) {
944+
return getSyclObjImpl(DeviceImage)
945+
->is_specialization_constant_set(SpecName);
946+
});
881947
return SetInDevImg || MSpecConstValues.count(std::string{SpecName}) != 0;
882948
}
883949

@@ -968,6 +1034,7 @@ class kernel_bundle_impl {
9681034
const std::variant<std::string, std::vector<std::byte>> MSource;
9691035
// only kernel_bundles created from source have KernelNames member.
9701036
std::vector<std::string> MKernelNames;
1037+
std::unordered_map<std::string, std::string> MMangledKernelNames;
9711038
sycl_device_binaries MDeviceBinaries = nullptr;
9721039
std::string MPrefix;
9731040
include_pairs_t MIncludePairs;

sycl/source/detail/kernel_compiler/kernel_compiler_sycl.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -303,18 +303,16 @@ bool SYCL_JIT_Compilation_Available() {
303303
#endif
304304
}
305305

306-
std::pair<sycl_device_binaries, std::string> SYCL_JIT_to_SPIRV(
307-
[[maybe_unused]] const std::string &SYCLSource,
308-
[[maybe_unused]] const include_pairs_t &IncludePairs,
309-
[[maybe_unused]] const std::vector<std::string> &UserArgs,
310-
[[maybe_unused]] std::string *LogPtr,
311-
[[maybe_unused]] const std::vector<std::string> &RegisteredKernelNames) {
306+
std::pair<sycl_device_binaries, std::string>
307+
SYCL_JIT_to_SPIRV([[maybe_unused]] const std::string &SYCLSource,
308+
[[maybe_unused]] const include_pairs_t &IncludePairs,
309+
[[maybe_unused]] const std::vector<std::string> &UserArgs,
310+
[[maybe_unused]] std::string *LogPtr) {
312311
#if SYCL_EXT_JIT_ENABLE
313312
static std::atomic_uintptr_t CompilationCounter;
314313
std::string CompilationID = "rtc_" + std::to_string(CompilationCounter++);
315314
return sycl::detail::jit_compiler::get_instance().compileSYCL(
316-
CompilationID, SYCLSource, IncludePairs, UserArgs, LogPtr,
317-
RegisteredKernelNames);
315+
CompilationID, SYCLSource, IncludePairs, UserArgs, LogPtr);
318316
#else
319317
throw sycl::exception(sycl::errc::build,
320318
"kernel_compiler via sycl-jit is not available");

sycl/source/detail/kernel_compiler/kernel_compiler_sycl.hpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,9 @@ std::string userArgsAsString(const std::vector<std::string> &UserArguments);
4040
//
4141
// Returns a pointer to the image (owned by the `jit_compiler` class), and the
4242
// bundle-specific prefix used for loading the kernels.
43-
std::pair<sycl_device_binaries, std::string>
44-
SYCL_JIT_to_SPIRV(const std::string &Source,
45-
const include_pairs_t &IncludePairs,
46-
const std::vector<std::string> &UserArgs, std::string *LogPtr,
47-
const std::vector<std::string> &RegisteredKernelNames);
43+
std::pair<sycl_device_binaries, std::string> SYCL_JIT_to_SPIRV(
44+
const std::string &Source, const include_pairs_t &IncludePairs,
45+
const std::vector<std::string> &UserArgs, std::string *LogPtr);
4846

4947
void SYCL_JIT_destroy(sycl_device_binaries Binaries);
5048

sycl/source/kernel_bundle.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,11 @@ kernel kernel_bundle_plain::ext_oneapi_get_kernel(detail::string_view name) {
137137
return impl->ext_oneapi_get_kernel(name.data(), impl);
138138
}
139139

140+
detail::string
141+
kernel_bundle_plain::ext_oneapi_get_raw_kernel_name(detail::string_view name) {
142+
return detail::string{impl->ext_oneapi_get_raw_kernel_name(name.data())};
143+
}
144+
140145
//////////////////////////////////
141146
///// sycl::detail free functions
142147
//////////////////////////////////

sycl/test-e2e/KernelCompiler/kernel_compiler_opencl.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,10 @@ void test_build_and_run() {
139139
assert(hasHerKernel && "her_kernel should exist, but doesn't");
140140
assert(!notExistKernel && "non-existing kernel should NOT exist, but does?");
141141

142+
assert(
143+
kbExe2.ext_oneapi_get_raw_kernel_name("my_kernel") == "my_kernel" &&
144+
"source code name and compiler-generated name should match, but don't");
145+
142146
sycl::kernel my_kernel = kbExe2.ext_oneapi_get_kernel("my_kernel");
143147
sycl::kernel her_kernel = kbExe2.ext_oneapi_get_kernel("her_kernel");
144148

0 commit comments

Comments
 (0)