Skip to content

Commit 0686208

Browse files
authored
[SYCL][NATIVECPU] Fix linker errors for WorkGroup collective functions (#15144)
This PR fixes some link-time error related to WorkGroup collective functions: * adds overloads for the `_Float16` data type and the `__spirv_GroupBroadcast` function, which leads to `undefined reference` linker errors. * Makes it so the call to `clang-offload-deps` is not emitted by the driver: the call introduced an `llvm.used` array, which contained function pointers to the kernels which prevented eliminating those functions from the module even when they are not needed anymore (due to inlining). This lead to other `undefined reference` linker errors for WorkGroup collectives.
1 parent a97e30d commit 0686208

File tree

7 files changed

+228
-70
lines changed

7 files changed

+228
-70
lines changed

clang/lib/Driver/Driver.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6850,6 +6850,10 @@ class OffloadingActionBuilder final {
68506850
/// Offload deps output is then forwarded to active device action builders so
68516851
/// they can add it to the device linker inputs.
68526852
void addDeviceLinkDependenciesFromHost(ActionList &LinkerInputs) {
6853+
if (isSYCLNativeCPU(C.getArgs())) {
6854+
// SYCL Native CPU doesn't need deps from clang-offload-deps.
6855+
return;
6856+
}
68536857
// Link image for reading dependencies from it.
68546858
auto *LA = C.MakeAction<LinkJobAction>(LinkerInputs,
68556859
types::TY_Host_Dependencies_Image);

clang/lib/Driver/ToolChains/Clang.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10810,7 +10810,7 @@ static bool shouldEmitOnlyKernelsAsEntryPoints(const ToolChain &TC,
1081010810
options::OPT_fsycl_remove_unused_external_funcs, false))
1081110811
return false;
1081210812
if (isSYCLNativeCPU(TC))
10813-
return false;
10813+
return true;
1081410814
// When supporting dynamic linking, non-kernels in a device image can be
1081510815
// called.
1081610816
if (supportDynamicLinking(TCArgs))

clang/test/Driver/sycl-native-cpu-fsycl.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@
4848
//CHECK_INVO:{{.*}}clang{{.*}}"-fsycl-is-host"{{.*}}
4949
//CHECK_INVO:{{.*}}clang{{.*}}"-x" "ir"
5050
//CHECK_INVO:{{.*}}sycl-post-link{{.*}}"-emit-program-metadata"
51-
//CHECK_INVO-NOT:{{.*}}sycl-post-link{{.*}}-emit-only-kernels-as-entry-points
5251

5352
// checks that the device and host triple is correct in the generated actions when it is set explicitly
5453
//CHECK_ACTIONS-AARCH64: +- 5: offload, "host-sycl (aarch64-unknown-linux-gnu)" {1}, "device-sycl (aarch64-unknown-linux-gnu)" {4}, c++-cpp-output

libdevice/nativecpu_utils.cpp

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ DefineGOp1(All, __mux_sub_group_all_i1)
145145

146146
#define DefineFPGOps(Name, MuxName) \
147147
DefineGOp(float, float, Name, MuxName##32) \
148+
DefineGOp(_Float16 , _Float16 , Name, MuxName##16) \
148149
DefineGOp(double, double, Name, MuxName##64)
149150

150151
DefineIntGOps(IAdd, add_i)
@@ -170,16 +171,23 @@ DefineBitwiseGroupOp(uint32_t, int32_t, i32)
170171
DefineBitwiseGroupOp(int64_t, int64_t, i64)
171172
DefineBitwiseGroupOp(uint64_t, int64_t, i64)
172173

173-
#define DefineBroadCastImpl(Type, Sfx, MuxType, IDType) \
174-
DEVICE_EXTERN_C MuxType __mux_work_group_broadcast_##Sfx( \
175-
int32_t id, MuxType val, int64_t lidx, int64_t lidy, int64_t lidz); \
176-
DEVICE_EXTERN_C MuxType __mux_sub_group_broadcast_##Sfx(MuxType val, \
177-
int32_t sg_lid); \
178-
DEVICE_EXTERNAL Type __spirv_GroupBroadcast(uint32_t g, Type v, \
179-
IDType l) { \
180-
if (__spv::Scope::Flag::Subgroup == g) \
181-
return __mux_sub_group_broadcast_##Sfx(v, l); \
182-
return Type(); /*todo: add support for other flags as they are tested*/ \
174+
#define DefineLogicalGroupOp(Type, MuxType, mux_sfx) \
175+
DefineGOp(Type, MuxType, LogicalOrKHR, logical_or_##mux_sfx) \
176+
DefineGOp(Type, MuxType, LogicalXorKHR, logical_xor_##mux_sfx) \
177+
DefineGOp(Type, MuxType, LogicalAndKHR, logical_and_##mux_sfx)
178+
179+
DefineLogicalGroupOp(bool, bool, i1)
180+
181+
#define DefineBroadCastImpl(Type, Sfx, MuxType, IDType) \
182+
DEVICE_EXTERN_C MuxType __mux_work_group_broadcast_##Sfx( \
183+
int32_t id, MuxType val, int64_t lidx, int64_t lidy, int64_t lidz); \
184+
DEVICE_EXTERN_C MuxType __mux_sub_group_broadcast_##Sfx(MuxType val, \
185+
int32_t sg_lid); \
186+
DEVICE_EXTERNAL Type __spirv_GroupBroadcast(uint32_t g, Type v, \
187+
IDType l) { \
188+
if (__spv::Scope::Flag::Subgroup == g) \
189+
return __mux_sub_group_broadcast_##Sfx(v, l); \
190+
return Type(); /*todo: add support for other flags as they are tested*/ \
183191
}
184192

185193
#define DefineBroadcastMuxType(Type, Sfx, MuxType, IDType) \

llvm/lib/SYCLNativeCPUUtils/PrepareSYCLNativeCPU.cpp

Lines changed: 85 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,12 @@
1212
//===----------------------------------------------------------------------===//
1313

1414
#include "llvm/SYCLLowerIR/PrepareSYCLNativeCPU.h"
15+
#include "llvm/ADT/SmallSet.h"
16+
#include "llvm/ADT/StringRef.h"
1517
#include "llvm/IR/BasicBlock.h"
1618
#include "llvm/IR/Constant.h"
1719
#include "llvm/IR/DebugInfoMetadata.h"
20+
#include "llvm/IR/GlobalValue.h"
1821
#include "llvm/IR/PassManager.h"
1922
#include "llvm/SYCLLowerIR/SYCLUtils.h"
2023

@@ -23,7 +26,6 @@
2326
#include "llvm/ADT/SmallVector.h"
2427
#include "llvm/IR/Attributes.h"
2528
#include "llvm/IR/CallingConv.h"
26-
#include "llvm/IR/Constants.h"
2729
#include "llvm/IR/DerivedTypes.h"
2830
#include "llvm/IR/IRBuilder.h"
2931
#include "llvm/IR/Instruction.h"
@@ -35,13 +37,13 @@
3537
#include "llvm/Support/Casting.h"
3638
#include "llvm/Support/ErrorHandling.h"
3739
#include "llvm/Transforms/Utils/Cloning.h"
40+
#include "llvm/Transforms/Utils/GlobalStatus.h"
3841
#include "llvm/Transforms/Utils/ValueMapper.h"
3942
#include <utility>
4043
#include <vector>
4144

4245
#ifdef NATIVECPU_USE_OCK
4346
#include "compiler/utils/attributes.h"
44-
#include "compiler/utils/builtin_info.h"
4547
#include "compiler/utils/metadata.h"
4648
#endif
4749

@@ -331,31 +333,85 @@ PreservedAnalyses PrepareSYCLNativeCPUPass::run(Module &M,
331333
UsedBuiltins.push_back({Glob, Entry.second});
332334
}
333335

334-
SmallVector<Function *> NewKernels;
335-
for (auto &OldF : OldKernels) {
336336
#ifdef NATIVECPU_USE_OCK
337-
auto Name = compiler::utils::getBaseFnNameOrFnName(*OldF);
338-
OldF->setName(Name);
339-
// if vectorization occurred, at this point we have a wrapper function that
340-
// runs the vectorized kernel and peels using the scalar kernel. We make it
341-
// so this wrapper steals the original kernel name.
342-
std::optional<compiler::utils::LinkMetadataResult> veczR =
343-
compiler::utils::parseVeczToOrigFnLinkMetadata(*OldF);
344-
if (veczR && veczR.value().first) {
345-
auto ScalarF = veczR.value().first;
346-
OldF->takeName(ScalarF);
347-
ScalarF->setName(OldF->getName() + "_scalar");
348-
} else if (Name != OldF->getName()) {
349-
auto RealKernel = M.getFunction(Name);
350-
if (RealKernel) {
351-
// the real kernel was not inlined in the wrapper, steal its name
352-
OldF->takeName(RealKernel);
337+
{
338+
SmallSet<Function *, 5> RemovableFuncs;
339+
SmallVector<Function *, 5> WrapperFuncs;
340+
341+
// Retrieve the wrapper functions created by the WorkItemLoop pass.
342+
for (auto &OldF : OldKernels) {
343+
std::optional<compiler::utils::LinkMetadataResult> VeczR =
344+
compiler::utils::parseVeczToOrigFnLinkMetadata(*OldF);
345+
if (VeczR && VeczR.value().first) {
346+
WrapperFuncs.push_back(OldF);
353347
} else {
354-
// the real kernel has been inlined, just use the name
355-
OldF->setName(Name);
348+
auto Name = compiler::utils::getBaseFnNameOrFnName(*OldF);
349+
if (Name != OldF->getName()) {
350+
WrapperFuncs.push_back(OldF);
351+
}
356352
}
357353
}
354+
355+
for (auto &OldF : WrapperFuncs) {
356+
// If vectorization occurred, at this point we have a wrapper function
357+
// that runs the vectorized kernel and peels using the scalar kernel. We
358+
// make it so this wrapper steals the original kernel name.
359+
std::optional<compiler::utils::LinkMetadataResult> VeczR =
360+
compiler::utils::parseVeczToOrigFnLinkMetadata(*OldF);
361+
if (VeczR && VeczR.value().first) {
362+
auto ScalarF = VeczR.value().first;
363+
OldF->takeName(ScalarF);
364+
if (ScalarF->use_empty())
365+
RemovableFuncs.insert(ScalarF);
366+
} else {
367+
// The WorkItemLoops pass created a wrapper function for the original
368+
// kernel. If we have a kernel named foo(), the wrapper will be called
369+
// foo-wrapper(), and will have the original kernel name retrieved by
370+
// getBaseFnNameOrFnName. We set the name of the wrapper function
371+
// to the original kernel name and add the original kernel to the
372+
// list of functions that can be removed from the module.
373+
auto Name = compiler::utils::getBaseFnNameOrFnName(*OldF);
374+
Function *OrigF = M.getFunction(Name);
375+
if (OrigF != nullptr) {
376+
// The original kernel is inlined by the WorkItemLoops
377+
// pass if it contained barriers or group collectives, otherwise
378+
// we don't want to (and can't) remove it.
379+
if (OrigF->use_empty())
380+
RemovableFuncs.insert(OrigF);
381+
OldF->takeName(OrigF);
382+
} else {
383+
OldF->setName(Name);
384+
}
385+
}
386+
}
387+
388+
// Find any left over SYCL_EXTERNAL function that has no more uses
389+
std::set<Function *> Kernelset(OldKernels.begin(), OldKernels.end());
390+
for (auto &F : M) {
391+
if (Kernelset.count(&F) == 0 &&
392+
F.hasFnAttribute(sycl::utils::ATTR_SYCL_MODULE_ID) && F.use_empty() &&
393+
!F.getName().starts_with("__dpcpp_nativecpu")) {
394+
// SYCL_EXTERNAL functions end up in static array of function pointers,
395+
// at this point we can remove them from the array and remove the
396+
// function if no other uses are left.
397+
RemovableFuncs.insert(&F);
398+
}
399+
}
400+
401+
// Remove unused functions. This is necessary in case they still contain
402+
// calls to group collective functions that haven't been processed by the
403+
// work item loops pass, which will lead to linker errors.
404+
llvm::erase_if(OldKernels,
405+
[&](Function *F) { return RemovableFuncs.contains(F); });
406+
407+
for (Function *F : RemovableFuncs) {
408+
F->eraseFromParent();
409+
}
410+
}
358411
#endif
412+
413+
SmallVector<Function *> NewKernels;
414+
for (auto &OldF : OldKernels) {
359415
auto *NewF =
360416
cloneFunctionAndAddParam(OldF, StatePtrType, CurrentStatePointerTLS);
361417
NewF->takeName(OldF);
@@ -416,54 +472,26 @@ PreservedAnalyses PrepareSYCLNativeCPUPass::run(Module &M,
416472
OldI->replaceAllUsesWith(NewI);
417473
OldI->eraseFromParent();
418474
}
419-
for (auto temp : ToRemove2)
420-
temp->eraseFromParent();
475+
for (auto Temp : ToRemove2)
476+
Temp->eraseFromParent();
421477

422478
// Finally, we erase the builtin from the module
423479
Glob->eraseFromParent();
424480
}
425481

426-
#ifdef NATIVECPU_USE_OCK
427-
// Define __mux_mem_barrier here using the OCK
428-
compiler::utils::BuiltinInfo BI;
429-
for (auto &F : M) {
430-
if (F.getName() == compiler::utils::MuxBuiltins::mem_barrier) {
431-
BI.defineMuxBuiltin(compiler::utils::BaseBuiltinID::eMuxBuiltinMemBarrier,
432-
M);
433-
}
434-
}
435-
// if we find calls to mux barrier now, it means that we had SYCL_EXTERNAL
436-
// functions that called __mux_work_group_barrier, which didn't get processed
437-
// by the WorkItemLoop pass. This means that the actual function call has been
438-
// inlined into the kernel, and the call to __mux_work_group_barrier has been
439-
// removed in the inlined call, but not in the original function. The original
440-
// function will not be executed (since it has been inlined) and so we can
441-
// just define __mux_work_group_barrier as a no-op to avoid linker errors.
442-
// Todo: currently we can't remove the function here even if it has no uses,
443-
// because we may still emit a declaration for it in the offload-wrapper.
444-
auto BarrierF =
445-
M.getFunction(compiler::utils::MuxBuiltins::work_group_barrier);
446-
if (BarrierF && BarrierF->isDeclaration()) {
447-
IRBuilder<> Builder(M.getContext());
448-
auto BB = BasicBlock::Create(M.getContext(), "noop", BarrierF);
449-
Builder.SetInsertPoint(BB);
450-
Builder.CreateRetVoid();
451-
}
452-
#endif
453-
454-
// removing unused builtins
482+
// Removing unused builtins
455483
SmallVector<Function *> UnusedLibBuiltins;
456484
for (auto &F : M) {
457485
if (IsUnusedBuiltinOrPrivateDef(F)) {
458486
UnusedLibBuiltins.push_back(&F);
459487
}
460488
}
461-
for (Function *f : UnusedLibBuiltins) {
462-
f->eraseFromParent();
489+
for (Function *F : UnusedLibBuiltins) {
490+
F->eraseFromParent();
463491
ModuleChanged = true;
464492
}
465-
for (auto it = M.begin(); it != M.end();) {
466-
auto Curr = it++;
493+
for (auto It = M.begin(); It != M.end();) {
494+
auto Curr = It++;
467495
Function &F = *Curr;
468496
if (F.getNumUses() == 0 && F.isDeclaration() &&
469497
F.getName().starts_with("__mux_")) {
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
// REQUIRES: native_cpu_ock
2+
3+
// Tests that no linker errors occur when group collective functions are used
4+
// in conjuction with SYCL_EXTERNAL.
5+
6+
// RUN: %clangxx -fsycl -fsycl-targets=native_cpu -DFILE1 -c -o %t1.o %s
7+
// RUN: %clangxx -fsycl -fsycl-targets=native_cpu -DFILE2 -c -o %t2.o %s
8+
// RUN: llvm-ar crv %t1.a %t1.o
9+
// RUN: %clangxx -fsycl -fsycl-targets=native_cpu %t2.o %t1.a -o %t.out
10+
// RUN: env ONEAPI_DEVICE_SELECTOR=native_cpu:cpu %t.out
11+
12+
/*
13+
test performs a lattice reduction.
14+
sycl::vec<float> is sensitive to .get_size() vs .size() in SYCL headers
15+
(ie, byte size versus vector size)
16+
*/
17+
18+
#include <sycl/detail/core.hpp>
19+
#include <sycl/group_algorithm.hpp>
20+
#include <sycl/usm.hpp>
21+
22+
using namespace sycl;
23+
24+
#define NX 32
25+
#define NZ 2
26+
#define NV 8
27+
using vecn = sycl::vec<float, NV>; // 8 floats
28+
#ifdef FILE1
29+
30+
SYCL_EXTERNAL void groupSum(vecn *r, const vecn &in, const int k,
31+
sycl::group<2> &grp, const int i) {
32+
33+
vecn tin = (i == k ? in : vecn(0));
34+
auto out = reduce_over_group(grp, tin, sycl::plus<>());
35+
if (i == k && grp.get_local_id()[1] == 0)
36+
r[k] = out;
37+
}
38+
#endif
39+
40+
#ifdef FILE2
41+
SYCL_EXTERNAL void groupSum(vecn *r, const vecn &in, const int k,
42+
sycl::group<2> &grp, const int i);
43+
void test(queue q, float *r, float *x,
44+
int n) { // r is 16 floats, x is 256 floats. n is 256
45+
46+
sycl::range<2> globalSize(NZ, NX); // 2,32
47+
sycl::range<2> localSize(1, NX); // 1,8 so 16 iterations
48+
sycl::nd_range<2> range{globalSize, localSize};
49+
50+
q.submit([&](sycl::handler &h) {
51+
h.parallel_for<>(range, [=](sycl::nd_item<2> ndi) {
52+
int i = ndi.get_global_id(1);
53+
int k = ndi.get_global_id(0);
54+
55+
auto vx = reinterpret_cast<vecn *>(x);
56+
auto vr = reinterpret_cast<vecn *>(r);
57+
58+
auto myg = ndi.get_group();
59+
60+
for (int iz = 0; iz < NZ; iz++) { // loop over Z (2)
61+
groupSum(vr, vx[k * NX + i], k, myg, iz);
62+
}
63+
});
64+
});
65+
q.wait();
66+
}
67+
68+
int main() {
69+
70+
queue q{default_selector_v};
71+
auto dev = q.get_device();
72+
std::cout << "Device: " << dev.get_info<info::device::name>() << std::endl;
73+
74+
auto ctx = q.get_context();
75+
int n = NX * NZ * NV; // 16 * 8 * 2 => 256
76+
auto *x = (float *)sycl::malloc_shared(n * sizeof(float), dev,
77+
ctx); // 256 * sizeof(float)
78+
auto *r = (float *)sycl::malloc_shared(
79+
NZ * NV * sizeof(float), dev, ctx); // 2 * 8 => 16 ( * sizeof(float) )
80+
81+
for (int i = 0; i < n; i++) {
82+
x[i] = i;
83+
}
84+
85+
q.wait();
86+
87+
test(q, r, x, n);
88+
89+
int fails = 0;
90+
for (int k = 0; k < NZ; k++) {
91+
float s[NV] = {0};
92+
for (int i = 0; i < NX; i++) {
93+
for (int j = 0; j < NV; j++) {
94+
s[j] += x[(k * NX + i) * NV + j];
95+
}
96+
}
97+
for (int j = 0; j < NV; j++) {
98+
auto d = s[j] - r[k * NV + j];
99+
if (std::abs(d) > 1e-10) {
100+
printf("partial fail ");
101+
printf("%i\t%i\t%g\t%g\n", k, j, s[j], r[k * NV + j]);
102+
fails++;
103+
} else {
104+
printf("partial pass ");
105+
printf("%i\t%i\t%g\t%g\n", k, j, s[j], r[k * NV + j]);
106+
}
107+
}
108+
}
109+
110+
if (fails == 0) {
111+
printf("test passed!\n");
112+
} else {
113+
printf("test failed!\n");
114+
}
115+
free(x, ctx);
116+
free(r, ctx);
117+
return fails;
118+
}
119+
#endif

0 commit comments

Comments
 (0)