Skip to content

[SYCL][Clang] Permit virtual functions in SYCL #7255

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 22 commits into from
Nov 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions clang/include/clang/Basic/LangOptions.def
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,8 @@ LANGOPT(
LANGOPT(SYCLDisableRangeRounding, 1, 0, "Disable parallel for range rounding")
LANGOPT(SYCLEnableIntHeaderDiags, 1, 0, "Enable diagnostics that require the "
"SYCL integration header")
// Set by the cc1 flag -fsycl-allow-virtual-functions. When enabled, Sema no
// longer reports calls to virtual member functions in SYCL device code.
LANGOPT(SYCLAllowVirtualFunctions, 1, 0,
        "Allow virtual function calls in code for SYCL device")

LANGOPT(HIPUseNewLaunchAPI, 1, 0, "Use new kernel launching API for HIP")

Expand Down
3 changes: 3 additions & 0 deletions clang/include/clang/Driver/Options.td
Original file line number Diff line number Diff line change
Expand Up @@ -6584,6 +6584,9 @@ def fsycl_use_main_file_name : Flag<["-"], "fsycl-use-main-file-name">,
HelpText<"Tells compiler that -main-file-name contains an absolute path and "
"file specified there should be used for checksum calculation.">,
MarshallingInfoFlag<CodeGenOpts<"SYCLUseMainFileName">>;
// Frontend-only flag (this def sits inside the CC1Option/NoDriverOption `let`
// block) that is marshalled into LangOptions::SYCLAllowVirtualFunctions.
def fsycl_allow_virtual_functions : Flag<["-"], "fsycl-allow-virtual-functions">,
  HelpText<"Allow virtual function calls in code for SYCL device">,
  MarshallingInfoFlag<LangOpts<"SYCLAllowVirtualFunctions">>;

} // let Flags = [CC1Option, NoDriverOption]

Expand Down
37 changes: 31 additions & 6 deletions clang/lib/CodeGen/ItaniumCXXABI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3565,8 +3565,33 @@ void ItaniumRTTIBuilder::BuildVTablePointer(const Type *Ty) {
// Check if the alias exists. If it doesn't, then get or create the global.
if (CGM.getItaniumVTableContext().isRelativeLayout())
VTable = CGM.getModule().getNamedAlias(VTableName);
if (!VTable)
VTable = CGM.getModule().getOrInsertGlobal(VTableName, CGM.Int8PtrTy);

// To generate valid device code, global pointers should be in the global
// address space in SYCL.
bool GenTyInfoGVWithGlobalAS =
CGM.getLangOpts().SYCLIsDevice &&
CGM.getLangOpts().SYCLAllowVirtualFunctions &&
(VTableName == ClassTypeInfo || VTableName == SIClassTypeInfo);
auto VTableTy =
GenTyInfoGVWithGlobalAS
? CGM.Int8Ty->getPointerTo(
CGM.getContext().getTargetAddressSpace(LangAS::sycl_global))
: CGM.Int8PtrTy;
if (!VTable) {
if (GenTyInfoGVWithGlobalAS) {
VTable = CGM.getModule().getOrInsertGlobal(VTableName, VTableTy, [&] {
return new llvm::GlobalVariable(
CGM.getModule(), VTableTy, /*isConstant=*/false,
llvm::GlobalVariable::ExternalLinkage, /*Initializer=*/nullptr,
VTableName, /*InsertBefore=*/nullptr,
llvm::GlobalValue::ThreadLocalMode::NotThreadLocal,
llvm::Optional<unsigned>(
CGM.getContext().getTargetAddressSpace(LangAS::sycl_global)));
});
} else {
VTable = CGM.getModule().getOrInsertGlobal(VTableName, VTableTy);
}
}

CGM.setDSOLocal(cast<llvm::GlobalValue>(VTable->stripPointerCasts()));

Expand All @@ -3578,15 +3603,15 @@ void ItaniumRTTIBuilder::BuildVTablePointer(const Type *Ty) {
// The vtable address point is 8 bytes after its start:
// 4 for the offset to top + 4 for the relative offset to rtti.
llvm::Constant *Eight = llvm::ConstantInt::get(CGM.Int32Ty, 8);
VTable = llvm::ConstantExpr::getBitCast(VTable, CGM.Int8PtrTy);
VTable = llvm::ConstantExpr::getBitCast(VTable, VTableTy);
VTable =
llvm::ConstantExpr::getInBoundsGetElementPtr(CGM.Int8Ty, VTable, Eight);
} else {
llvm::Constant *Two = llvm::ConstantInt::get(PtrDiffTy, 2);
VTable = llvm::ConstantExpr::getInBoundsGetElementPtr(CGM.Int8PtrTy, VTable,
Two);
VTable =
llvm::ConstantExpr::getInBoundsGetElementPtr(VTableTy, VTable, Two);
}
VTable = llvm::ConstantExpr::getBitCast(VTable, CGM.Int8PtrTy);
VTable = llvm::ConstantExpr::getBitCast(VTable, VTableTy);

Fields.push_back(VTable);
}
Expand Down
3 changes: 2 additions & 1 deletion clang/lib/Sema/SemaSYCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,8 @@ class DiagDeviceFunction : public RecursiveASTVisitor<DiagDeviceFunction> {
}

if (const CXXMethodDecl *Method = dyn_cast<CXXMethodDecl>(Callee))
if (Method->isVirtual())
if (Method->isVirtual() &&
!SemaRef.getLangOpts().SYCLAllowVirtualFunctions)
SemaRef.Diag(e->getExprLoc(), diag::err_sycl_restrict)
<< Sema::KernelCallVirtualFunction;

Expand Down
33 changes: 33 additions & 0 deletions clang/test/CodeGenSYCL/simple-sycl-virtual-function.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// This test checks that the FE generates global variables corresponding to the
// virtual table in the global address space (addrspace(1)) when
// -fsycl-allow-virtual-functions is passed.

// RUN: %clang_cc1 -triple spir64 -fsycl-allow-virtual-functions -fsycl-is-device -emit-llvm %s -o - | FileCheck %s --check-prefixes CHECK,CHECK-PTR
// RUN: %clang_cc1 -triple spir64 -fsycl-allow-virtual-functions -fsycl-is-device -fexperimental-relative-c++-abi-vtables -emit-llvm %s -o - | FileCheck %s --check-prefixes CHECK,CHECK-REL

// CHECK: @_ZTVN10__cxxabiv120__si_class_type_infoE = external addrspace(1) global ptr addrspace(1)
// CHECK: @_ZTVN10__cxxabiv117__class_type_infoE = external addrspace(1) global ptr addrspace(1)
// CHECK-PTR: @_ZTI4Base = linkonce_odr constant { ptr addrspace(1), ptr } { ptr addrspace(1) getelementptr inbounds (ptr addrspace(1), ptr addrspace(1) @_ZTVN10__cxxabiv117__class_type_infoE, i64 2)
// CHECK-PTR: @_ZTI8Derived1 = linkonce_odr constant { ptr addrspace(1), ptr, ptr } { ptr addrspace(1) getelementptr inbounds (ptr addrspace(1), ptr addrspace(1) @_ZTVN10__cxxabiv120__si_class_type_infoE, i64 2)
// CHECK-REL: @_ZTI4Base = linkonce_odr constant { ptr addrspace(1), ptr } { ptr addrspace(1) getelementptr inbounds (i8, ptr addrspace(1) @_ZTVN10__cxxabiv117__class_type_infoE, i32 8)
// CHECK-REL: @_ZTI8Derived1 = linkonce_odr constant { ptr addrspace(1), ptr, ptr } { ptr addrspace(1) getelementptr inbounds (i8, ptr addrspace(1) @_ZTVN10__cxxabiv120__si_class_type_infoE, i32 8)

// Declared here but not defined in this file. test() branches on its result
// to decide whether the Base pointer is pointed at a Derived1 object.
SYCL_EXTERNAL bool rand();

// Polymorphic base class for this test. Its type info object is the
// @_ZTI4Base global matched by the CHECK-PTR/CHECK-REL lines above.
class Base {
public:
// The only virtual member; test() calls it through a Base pointer.
virtual void display() {}
};

class Derived1 : public Base {
public:
void display() {}
};

// SYCL_EXTERNAL function that calls display() through a Base pointer.
SYCL_EXTERNAL void test() {
Derived1 d1;
Base *b = nullptr;
// Whether b points at d1 depends on the result of rand(), which is only
// declared in this file.
if (rand())
b = &d1;
// Virtual call through the base-class pointer.
// NOTE(review): b is still nullptr here when rand() returns false.
// Presumably acceptable because the RUN lines only emit LLVM IR and never
// execute the code -- confirm.
b->display();
}