Skip to content

Commit b90391e

Browse files
authored
[SYCL][Clang] Permit virtual functions in SYCL (#7255)
To enable virtual functions support in SYCL we need to change virtual table variables address space, so llvm-spirv translator is able to generate valid SPIR-V modules.
1 parent 082cde6 commit b90391e

File tree

5 files changed

+71
-7
lines changed

5 files changed

+71
-7
lines changed

clang/include/clang/Basic/LangOptions.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,8 @@ LANGOPT(
293293
LANGOPT(SYCLDisableRangeRounding, 1, 0, "Disable parallel for range rounding")
294294
LANGOPT(SYCLEnableIntHeaderDiags, 1, 0, "Enable diagnostics that require the "
295295
"SYCL integration header")
296+
LANGOPT(SYCLAllowVirtualFunctions, 1, 0,
297+
"Allow virtual functions calls in code for SYCL device")
296298

297299
LANGOPT(HIPUseNewLaunchAPI, 1, 0, "Use new kernel launching API for HIP")
298300

clang/include/clang/Driver/Options.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6614,6 +6614,9 @@ def fsycl_use_main_file_name : Flag<["-"], "fsycl-use-main-file-name">,
66146614
HelpText<"Tells compiler that -main-file-name contains an absolute path and "
66156615
"file specified there should be used for checksum calculation.">,
66166616
MarshallingInfoFlag<CodeGenOpts<"SYCLUseMainFileName">>;
6617+
def fsycl_allow_virtual_functions : Flag<["-"], "fsycl-allow-virtual-functions">,
6618+
HelpText<"Allow virtual functions calls in code for SYCL device">,
6619+
MarshallingInfoFlag<LangOpts<"SYCLAllowVirtualFunctions">>;
66176620

66186621
} // let Flags = [CC1Option, NoDriverOption]
66196622

clang/lib/CodeGen/ItaniumCXXABI.cpp

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3572,8 +3572,33 @@ void ItaniumRTTIBuilder::BuildVTablePointer(const Type *Ty) {
35723572
// Check if the alias exists. If it doesn't, then get or create the global.
35733573
if (CGM.getItaniumVTableContext().isRelativeLayout())
35743574
VTable = CGM.getModule().getNamedAlias(VTableName);
3575-
if (!VTable)
3576-
VTable = CGM.getModule().getOrInsertGlobal(VTableName, CGM.Int8PtrTy);
3575+
3576+
// To generate valid device code global pointers should have global address
3577+
// space in SYCL.
3578+
bool GenTyInfoGVWithGlobalAS =
3579+
CGM.getLangOpts().SYCLIsDevice &&
3580+
CGM.getLangOpts().SYCLAllowVirtualFunctions &&
3581+
(VTableName == ClassTypeInfo || VTableName == SIClassTypeInfo);
3582+
auto VTableTy =
3583+
GenTyInfoGVWithGlobalAS
3584+
? CGM.Int8Ty->getPointerTo(
3585+
CGM.getContext().getTargetAddressSpace(LangAS::sycl_global))
3586+
: CGM.Int8PtrTy;
3587+
if (!VTable) {
3588+
if (GenTyInfoGVWithGlobalAS) {
3589+
VTable = CGM.getModule().getOrInsertGlobal(VTableName, VTableTy, [&] {
3590+
return new llvm::GlobalVariable(
3591+
CGM.getModule(), VTableTy, /*isConstant=*/false,
3592+
llvm::GlobalVariable::ExternalLinkage, /*Initializer=*/nullptr,
3593+
VTableName, /*InsertBefore=*/nullptr,
3594+
llvm::GlobalValue::ThreadLocalMode::NotThreadLocal,
3595+
llvm::Optional<unsigned>(
3596+
CGM.getContext().getTargetAddressSpace(LangAS::sycl_global)));
3597+
});
3598+
} else {
3599+
VTable = CGM.getModule().getOrInsertGlobal(VTableName, VTableTy);
3600+
}
3601+
}
35773602

35783603
CGM.setDSOLocal(cast<llvm::GlobalValue>(VTable->stripPointerCasts()));
35793604

@@ -3585,15 +3610,15 @@ void ItaniumRTTIBuilder::BuildVTablePointer(const Type *Ty) {
35853610
// The vtable address point is 8 bytes after its start:
35863611
// 4 for the offset to top + 4 for the relative offset to rtti.
35873612
llvm::Constant *Eight = llvm::ConstantInt::get(CGM.Int32Ty, 8);
3588-
VTable = llvm::ConstantExpr::getBitCast(VTable, CGM.Int8PtrTy);
3613+
VTable = llvm::ConstantExpr::getBitCast(VTable, VTableTy);
35893614
VTable =
35903615
llvm::ConstantExpr::getInBoundsGetElementPtr(CGM.Int8Ty, VTable, Eight);
35913616
} else {
35923617
llvm::Constant *Two = llvm::ConstantInt::get(PtrDiffTy, 2);
3593-
VTable = llvm::ConstantExpr::getInBoundsGetElementPtr(CGM.Int8PtrTy, VTable,
3594-
Two);
3618+
VTable =
3619+
llvm::ConstantExpr::getInBoundsGetElementPtr(VTableTy, VTable, Two);
35953620
}
3596-
VTable = llvm::ConstantExpr::getBitCast(VTable, CGM.Int8PtrTy);
3621+
VTable = llvm::ConstantExpr::getBitCast(VTable, VTableTy);
35973622

35983623
Fields.push_back(VTable);
35993624
}

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -579,7 +579,8 @@ class DiagDeviceFunction : public RecursiveASTVisitor<DiagDeviceFunction> {
579579
}
580580

581581
if (const CXXMethodDecl *Method = dyn_cast<CXXMethodDecl>(Callee))
582-
if (Method->isVirtual())
582+
if (Method->isVirtual() &&
583+
!SemaRef.getLangOpts().SYCLAllowVirtualFunctions)
583584
SemaRef.Diag(e->getExprLoc(), diag::err_sycl_restrict)
584585
<< Sema::KernelCallVirtualFunction;
585586

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// This test checks that the FE generates global variables corresponding to the
2+
// virtual table in the global address space (addrspace(1)) when
3+
// -fsycl-allow-virtual-functions is passed.
4+
5+
// RUN: %clang_cc1 -triple spir64 -fsycl-allow-virtual-functions -fsycl-is-device -emit-llvm %s -o - | FileCheck %s --check-prefixes CHECK,CHECK-PTR
6+
// RUN: %clang_cc1 -triple spir64 -fsycl-allow-virtual-functions -fsycl-is-device -fexperimental-relative-c++-abi-vtables -emit-llvm %s -o - | FileCheck %s --check-prefixes CHECK,CHECK-REL
7+
8+
// CHECK: @_ZTVN10__cxxabiv120__si_class_type_infoE = external addrspace(1) global ptr addrspace(1)
9+
// CHECK: @_ZTVN10__cxxabiv117__class_type_infoE = external addrspace(1) global ptr addrspace(1)
10+
// CHECK-PTR: @_ZTI4Base = linkonce_odr constant { ptr addrspace(1), ptr } { ptr addrspace(1) getelementptr inbounds (ptr addrspace(1), ptr addrspace(1) @_ZTVN10__cxxabiv117__class_type_infoE, i64 2)
11+
// CHECK-PTR: @_ZTI8Derived1 = linkonce_odr constant { ptr addrspace(1), ptr, ptr } { ptr addrspace(1) getelementptr inbounds (ptr addrspace(1), ptr addrspace(1) @_ZTVN10__cxxabiv120__si_class_type_infoE, i64 2)
12+
// CHECK-REL: @_ZTI4Base = linkonce_odr constant { ptr addrspace(1), ptr } { ptr addrspace(1) getelementptr inbounds (i8, ptr addrspace(1) @_ZTVN10__cxxabiv117__class_type_infoE, i32 8)
13+
// CHECK-REL: @_ZTI8Derived1 = linkonce_odr constant { ptr addrspace(1), ptr, ptr } { ptr addrspace(1) getelementptr inbounds (i8, ptr addrspace(1) @_ZTVN10__cxxabiv120__si_class_type_infoE, i32 8)
14+
15+
SYCL_EXTERNAL bool rand();
16+
17+
class Base {
18+
public:
19+
virtual void display() {}
20+
};
21+
22+
class Derived1 : public Base {
23+
public:
24+
void display() {}
25+
};
26+
27+
SYCL_EXTERNAL void test() {
28+
Derived1 d1;
29+
Base *b = nullptr;
30+
if (rand())
31+
b = &d1;
32+
b->display();
33+
}

0 commit comments

Comments
 (0)