Skip to content

[cxx-interop] Add support for calling members of base classes on UFOs. #42506

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 19 additions & 19 deletions lib/ClangImporter/ClangImporter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4382,6 +4382,11 @@ TinyPtrVector<ValueDecl *> CXXNamespaceMemberLookup::evaluate(
DeclRefExpr *getInteropStaticCastDeclRefExpr(ASTContext &ctx,
const clang::Module *owningModule,
Type base, Type derived) {
if (base->isForeignReferenceType() && derived->isForeignReferenceType()) {
base = base->wrapInPointer(PTK_UnsafePointer);
derived = derived->wrapInPointer(PTK_UnsafePointer);
}

// Lookup our static cast helper function.
// TODO: change this to stdlib or something.
auto wrapperModule =
Expand Down Expand Up @@ -4420,8 +4425,8 @@ DeclRefExpr *getInteropStaticCastDeclRefExpr(ASTContext &ctx,
// %3 = %2!
// return %3.pointee
MemberRefExpr *getInOutSelfInteropStaticCast(FuncDecl *funcDecl,
StructDecl *baseStruct,
StructDecl *derivedStruct) {
NominalTypeDecl *baseStruct,
NominalTypeDecl *derivedStruct) {
auto &ctx = funcDecl->getASTContext();

auto inoutSelf = [&ctx](FuncDecl *funcDecl) {
Expand Down Expand Up @@ -4485,18 +4490,13 @@ MemberRefExpr *getInOutSelfInteropStaticCast(FuncDecl *funcDecl,
->getResult());
casted->setThrows(false);

// Now force unwrap the casted pointer.
auto unwrapped = new (ctx) ForceValueExpr(casted, SourceLoc());
unwrapped->setType(baseStruct->getSelfInterfaceType()->wrapInPointer(
PTK_UnsafeMutablePointer));

SubstitutionMap pointeeSubst = SubstitutionMap::get(
ctx.getUnsafeMutablePointerDecl()->getGenericSignature(),
{baseStruct->getSelfInterfaceType()}, {});
VarDecl *pointeePropertyDecl =
ctx.getPointerPointeePropertyDecl(PTK_UnsafeMutablePointer);
auto pointeePropertyRefExpr = new (ctx) MemberRefExpr(
unwrapped, SourceLoc(),
casted, SourceLoc(),
ConcreteDeclRef(pointeePropertyDecl, pointeeSubst), DeclNameLoc(),
/*implicit=*/true);
pointeePropertyRefExpr->setType(
Expand All @@ -4521,9 +4521,9 @@ synthesizeBaseClassMethodBody(AbstractFunctionDecl *afd, void *context) {

auto funcDecl = cast<FuncDecl>(afd);
auto derivedStruct =
cast<StructDecl>(funcDecl->getDeclContext()->getAsDecl());
cast<NominalTypeDecl>(funcDecl->getDeclContext()->getAsDecl());
auto baseMember = static_cast<FuncDecl *>(context);
auto baseStruct = cast<StructDecl>(baseMember->getDeclContext()->getAsDecl());
auto baseStruct = cast<NominalTypeDecl>(baseMember->getDeclContext()->getAsDecl());
auto baseType = baseStruct->getDeclaredType();

SmallVector<Expr *, 8> forwardingParams;
Expand Down Expand Up @@ -4591,10 +4591,10 @@ synthesizeBaseClassFieldGetterBody(AbstractFunctionDecl *afd, void *context) {

AccessorDecl *getterDecl = cast<AccessorDecl>(afd);
AbstractStorageDecl *baseClassVar = static_cast<AbstractStorageDecl *>(context);
StructDecl *baseStruct =
cast<StructDecl>(baseClassVar->getDeclContext()->getAsDecl());
StructDecl *derivedStruct =
cast<StructDecl>(getterDecl->getDeclContext()->getAsDecl());
NominalTypeDecl *baseStruct =
cast<NominalTypeDecl>(baseClassVar->getDeclContext()->getAsDecl());
NominalTypeDecl *derivedStruct =
cast<NominalTypeDecl>(getterDecl->getDeclContext()->getAsDecl());

auto selfDecl = getterDecl->getImplicitSelfDecl();
auto selfExpr = new (ctx) DeclRefExpr(selfDecl, DeclNameLoc(),
Expand Down Expand Up @@ -4654,10 +4654,10 @@ synthesizeBaseClassFieldSetterBody(AbstractFunctionDecl *afd, void *context) {
AbstractStorageDecl *baseClassVar = static_cast<AbstractStorageDecl *>(context);
ASTContext &ctx = setterDecl->getASTContext();

StructDecl *baseStruct =
cast<StructDecl>(baseClassVar->getDeclContext()->getAsDecl());
StructDecl *derivedStruct =
cast<StructDecl>(setterDecl->getDeclContext()->getAsDecl());
NominalTypeDecl *baseStruct =
cast<NominalTypeDecl>(baseClassVar->getDeclContext()->getAsDecl());
NominalTypeDecl *derivedStruct =
cast<NominalTypeDecl>(setterDecl->getDeclContext()->getAsDecl());

auto *pointeePropertyRefExpr =
getInOutSelfInteropStaticCast(setterDecl, baseStruct, derivedStruct);
Expand Down Expand Up @@ -4894,7 +4894,7 @@ TinyPtrVector<ValueDecl *> ClangRecordMemberLookup::evaluate(
if (cast<ValueDecl>(import)->getName() == name)
continue;

auto baseResults = cast<StructDecl>(import)->lookupDirect(name);
auto baseResults = cast<NominalTypeDecl>(import)->lookupDirect(name);
for (auto foundInBase : baseResults) {
if (auto newDecl = cloneBaseMemberDecl(foundInBase, recordDecl)) {
result.push_back(newDecl);
Expand Down
19 changes: 17 additions & 2 deletions lib/ClangImporter/ImportType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include "clang/AST/ASTContext.h"
#include "clang/AST/DeclCXX.h"
#include "clang/AST/DeclObjCCommon.h"
#include "clang/AST/DeclTemplate.h"
#include "clang/AST/TypeVisitor.h"
#include "clang/Basic/Builtins.h"
#include "clang/Lex/Preprocessor.h"
Expand Down Expand Up @@ -2077,6 +2078,17 @@ ImportedType ClangImporter::Implementation::importFunctionReturnType(
OptionalityOfReturn = OTK_ImplicitlyUnwrappedOptional;
}

// Specialized templates need to match the args/result exactly (i.e.,
// ptr -> ptr not ptr -> Optional<ptr>).
if (clangDecl->getReturnType()->isPointerType() &&
clangDecl->getPrimaryTemplate() &&
clangDecl
->getPrimaryTemplate()
->getAsFunction()
->getReturnType()
->isTemplateTypeParmType())
OptionalityOfReturn = OTK_None;

// Import the result type.
return importType(clangDecl->getReturnType(),
(isAuditedResult ? ImportTypeKind::AuditedResult
Expand Down Expand Up @@ -2240,9 +2252,12 @@ ParameterList *ClangImporter::Implementation::importFunctionParameterList(
continue;
}

bool knownNonNull = !nonNullArgs.empty() && nonNullArgs[index];
// Specialized templates need to match the args/result exactly.
knownNonNull |= clangDecl->isFunctionTemplateSpecialization();

// Check nullability of the parameter.
OptionalTypeKind OptionalityOfParam =
getParamOptionality(param, !nonNullArgs.empty() && nonNullArgs[index]);
OptionalTypeKind OptionalityOfParam = getParamOptionality(param, knownNonNull);

ImportTypeKind importKind = ImportTypeKind::Parameter;
if (param->hasAttr<clang::CFReturnsRetainedAttr>())
Expand Down
5 changes: 4 additions & 1 deletion lib/SILGen/SILGenApply.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1072,7 +1072,10 @@ class SILGenApply : public Lowering::ExprVisitor<SILGenApply> {
isObjCDirect = objcDecl->isDirectMethod();
}

if (isObjCDirect) {
// Methods on unsafe foreign objects are always called directly.
bool isUFO = isa_and_nonnull<ClassDecl>(afd->getDeclContext()) &&
cast<ClassDecl>(afd->getDeclContext())->isForeignReferenceType();
if (isObjCDirect || isUFO) {
setCallee(Callee::forDirect(SGF, constant, subs, e));
} else {
setCallee(Callee::forClassMethod(SGF, constant, subs, e));
Expand Down
22 changes: 22 additions & 0 deletions test/Interop/Cxx/foreign-reference/Inputs/pod.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

#include "visibility.h"

template <class From, class To>
To __swift_interopStaticCast(From from) { return from; }

inline void *operator new(size_t, void *p) { return p; }

SWIFT_BEGIN_NULLABILITY_ANNOTATIONS
Expand Down Expand Up @@ -117,6 +120,25 @@ void mutateIt(BigType &x) {
}
BigType passThroughByValue(BigType x) { return x; }

struct __attribute__((swift_attr("import_as_ref"))) BaseRef {
int a = 1;
int b = 2;

int test() const { return b - a; }
int test() { return b - a; }

static BaseRef *create() { return new (malloc(sizeof(BaseRef))) BaseRef(); }
};

struct __attribute__((swift_attr("import_as_ref"))) DerivedRef : BaseRef {
int c = 1;

int testDerived() const { return test() + c; }
int testDerived() { return test() + c; }

static DerivedRef *create() { return new (malloc(sizeof(DerivedRef))) DerivedRef(); }
};

SWIFT_END_NULLABILITY_ANNOTATIONS

#endif // TEST_INTEROP_CXX_FOREIGN_REFERENCE_INPUTS_POD_H
14 changes: 14 additions & 0 deletions test/Interop/Cxx/foreign-reference/pod.swift
Original file line number Diff line number Diff line change
Expand Up @@ -148,4 +148,18 @@ PODTestSuite.test("BigType") {
expectEqual(x.test(), 1)
}

PODTestSuite.test("DerivedRef") {
var x = DerivedRef.create()
expectEqual(x.test(), 1)
expectEqual(x.testMutating(), 1)
expectEqual(x.testDerived(), 2)
expectEqual(x.testDerivedMutating(), 2)
}

PODTestSuite.test("BaseRef") {
var x = BaseRef.create()
expectEqual(x.test(), 1)
expectEqual(x.testMutating(), 1)
}

runAllTests()