Skip to content

[PrintAsObjC] Handle circularities introduced by ObjC generics. #4155

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 9, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 163 additions & 22 deletions lib/PrintAsObjC/PrintAsObjC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ static bool isMistakableForInit(ObjCSelector selector) {


namespace {
using DelayedMemberSet = llvm::SmallSetVector<const ValueDecl *, 32>;

class ObjCPrinter : private DeclVisitor<ObjCPrinter>,
private TypeVisitor<ObjCPrinter, void,
Optional<OptionalTypeKind>>
Expand All @@ -134,6 +136,7 @@ class ObjCPrinter : private DeclVisitor<ObjCPrinter>,
raw_ostream &os;

SmallVector<const FunctionType *, 4> openFunctionTypes;
const DelayedMemberSet &delayedMembers;

Accessibility minRequiredAccess;
bool protocolMembersOptional = false;
Expand All @@ -144,14 +147,31 @@ class ObjCPrinter : private DeclVisitor<ObjCPrinter>,
friend TypeVisitor<ObjCPrinter>;

public:
explicit ObjCPrinter(Module &mod, raw_ostream &out, Accessibility access)
: M(mod), os(out), minRequiredAccess(access) {}
explicit ObjCPrinter(Module &mod, raw_ostream &out,
DelayedMemberSet &delayed, Accessibility access)
: M(mod), os(out), delayedMembers(delayed), minRequiredAccess(access) {}

void print(const Decl *D) {
PrettyStackTraceDecl trace("printing", D);
visit(const_cast<Decl *>(D));
}

void printAdHocCategory(iterator_range<const ValueDecl * const *> members) {
assert(members.begin() != members.end());

const DeclContext *origDC = (*members.begin())->getDeclContext();
auto *baseClass = dyn_cast<ClassDecl>(origDC);
if (!baseClass) {
Type extendedTy = cast<ExtensionDecl>(origDC)->getExtendedType();
baseClass = extendedTy->getClassOrBoundGenericClass();
}

os << "@interface " << getNameForObjC(baseClass)
<< " (SWIFT_EXTENSION(" << origDC->getParentModule()->getName() << "))\n";
printMembers</*allowDelayed*/true>(members);
os << "@end\n\n";
}

bool shouldInclude(const ValueDecl *VD, bool checkParent = true) {
if (!(VD->isObjC() || VD->getAttrs().hasAttribute<CDeclAttr>()))
return false;
Expand Down Expand Up @@ -198,19 +218,24 @@ class ObjCPrinter : private DeclVisitor<ObjCPrinter>,
}

/// Prints the members of a class, extension, or protocol.
void printMembers(DeclRange members) {
for (auto member : members) {
template <bool AllowDelayed = false, typename R>
void printMembers(R &&members) {
for (const Decl *member : members) {
auto VD = dyn_cast<ValueDecl>(member);
if (!VD || !shouldInclude(VD) || isa<TypeDecl>(VD))
continue;
if (auto FD = dyn_cast<FuncDecl>(VD))
if (FD->isAccessor())
continue;
if (!AllowDelayed && delayedMembers.count(VD)) {
os << "// '" << VD->getFullName() << "' below\n";
continue;
}
if (VD->getAttrs().hasAttribute<OptionalAttr>() != protocolMembersOptional) {
protocolMembersOptional = VD->getAttrs().hasAttribute<OptionalAttr>();
os << (protocolMembersOptional ? "@optional\n" : "@required\n");
}
visit(VD);
visit(const_cast<ValueDecl*>(VD));
}
}

Expand Down Expand Up @@ -1509,6 +1534,7 @@ class ReferencedTypeFinder : private TypeVisitor<ReferencedTypeFinder> {
friend TypeVisitor;

llvm::function_ref<void(ReferencedTypeFinder &, const TypeDecl *)> Callback;
bool IsWithinConstrainedObjCGeneric = false;

ReferencedTypeFinder(decltype(Callback) callback) : Callback(callback) {}

Expand Down Expand Up @@ -1568,18 +1594,48 @@ class ReferencedTypeFinder : private TypeVisitor<ReferencedTypeFinder> {
visit(inout->getObjectType());
}

/// Returns true if \p archetype has any constraints other than being
/// class-bound ("conforms to" AnyObject).
static bool isConstrainedArchetype(const ArchetypeType *archetype) {
if (archetype->getSuperclass())
return true;

if (archetype->getConformsTo().size() > 1)
return true;
if (archetype->getConformsTo().size() == 0)
return false;

const ProtocolDecl *proto = archetype->getConformsTo().front();
if (auto knownKind = proto->getKnownProtocolKind())
return knownKind.getValue() != KnownProtocolKind::AnyObject;
return true;
}

void visitBoundGenericType(BoundGenericType *boundGeneric) {
for (auto argTy : boundGeneric->getGenericArgs())
bool isObjCGeneric = boundGeneric->getDecl()->hasClangNode();

for_each(boundGeneric->getGenericArgs(),
boundGeneric->getDecl()->getGenericParams()->getPrimaryArchetypes(),
[&](Type argTy, const ArchetypeType *archetype) {
if (isObjCGeneric && isConstrainedArchetype(archetype))
IsWithinConstrainedObjCGeneric = true;
visit(argTy);
// Ignore the base type; that can't be exposed to Objective-C. Every
// bound generic type we care about gets mapped to a particular construct
// in Objective-C we care about. (For example, Optional<NSFoo> is mapped to
// NSFoo *.)
IsWithinConstrainedObjCGeneric = false;
});

// Ignore the base type; that either can't be exposed to Objective-C or
// was an Objective-C type to begin with. Every bound generic Swift type we
// care about gets mapped to a particular construct in Objective-C.
// (For example, Optional<NSFoo> is mapped to NSFoo *.)
}

public:
using TypeVisitor::visit;

bool isWithinConstrainedObjCGeneric() const {
return IsWithinConstrainedObjCGeneric;
}

static void walk(Type ty, decltype(Callback) callback) {
ReferencedTypeFinder(callback).visit(ty);
}
Expand All @@ -1603,13 +1659,14 @@ struct PointerLikeComparator {

class ModuleWriter {
enum class EmissionState {
DefinitionRequested = 0,
DefinitionInProgress,
NotYetDefined = 0,
DefinitionRequested,
Defined
};

llvm::DenseMap<const TypeDecl *, std::pair<EmissionState, bool>> seenTypes;
std::vector<const Decl *> declsToWrite;
DelayedMemberSet delayedMembers;

using ImportModuleTy = PointerUnion<Module*, const clang::Module*>;
SmallSetVector<ImportModuleTy, 8,
Expand All @@ -1623,7 +1680,7 @@ class ModuleWriter {
ObjCPrinter printer;
public:
ModuleWriter(Module &mod, StringRef header, Accessibility access)
: M(mod), bridgingHeader(header), printer(M, os, access) {}
: M(mod), bridgingHeader(header), printer(M, os, delayedMembers, access) {}

/// Returns true if we added the decl's module to the import set, false if
/// the decl is a local decl.
Expand Down Expand Up @@ -1660,6 +1717,19 @@ class ModuleWriter {
return true;
}

bool hasBeenRequested(const TypeDecl *D) const {
return seenTypes.lookup(D).first >= EmissionState::DefinitionRequested;
}

bool tryRequire(const TypeDecl *D) {
if (addImport(D)) {
seenTypes[D] = { EmissionState::Defined, true };
return true;
}
auto &state = seenTypes[D];
return state.first == EmissionState::Defined;
}

bool require(const TypeDecl *D) {
if (addImport(D)) {
seenTypes[D] = { EmissionState::Defined, true };
Expand All @@ -1668,11 +1738,11 @@ class ModuleWriter {

auto &state = seenTypes[D];
switch (state.first) {
case EmissionState::NotYetDefined:
case EmissionState::DefinitionRequested:
state.first = EmissionState::DefinitionRequested;
declsToWrite.push_back(D);
return false;
case EmissionState::DefinitionInProgress:
llvm_unreachable("circular requirements");
case EmissionState::Defined:
return true;
}
Expand Down Expand Up @@ -1717,7 +1787,17 @@ class ModuleWriter {
});
}

void forwardDeclareMemberTypes(DeclRange members) {
bool forwardDeclareMemberTypes(DeclRange members, const Decl *container) {
switch (container->getKind()) {
case DeclKind::Class:
case DeclKind::Protocol:
case DeclKind::Extension:
break;
default:
llvm_unreachable("unexpected container kind");
}

bool hadAnyDelayedMembers = false;
SmallVector<ValueDecl *, 4> nestedTypes;
for (auto member : members) {
auto VD = dyn_cast<ValueDecl>(member);
Expand All @@ -1738,9 +1818,44 @@ class ModuleWriter {
continue;
}

bool needsToBeIndividuallyDelayed = false;
ReferencedTypeFinder::walk(VD->getType(),
[this](ReferencedTypeFinder &finder,
const TypeDecl *TD) {
[&](ReferencedTypeFinder &finder,
const TypeDecl *TD) {
if (TD == container)
return;

if (finder.isWithinConstrainedObjCGeneric()) {
// We can delay individual members of classes; do so if necessary.
if (isa<ClassDecl>(container)) {
if (!tryRequire(TD)) {
needsToBeIndividuallyDelayed = true;
hadAnyDelayedMembers = true;
}
return;
}

// Extensions can always be delayed wholesale.
if (isa<ExtensionDecl>(container)) {
if (!require(TD))
hadAnyDelayedMembers = true;
return;
}

// Protocols should be delayed wholesale unless we might have a cycle.
auto *proto = cast<ProtocolDecl>(container);
if (!hasBeenRequested(proto) || !hasBeenRequested(TD)) {
if (!require(TD))
hadAnyDelayedMembers = true;
return;
}

// Otherwise, we have a cyclic dependency. Give up and continue with
// regular forward-declarations even though this will lead to an
// error; there's nothing we can do here.
// FIXME: It would be nice to diagnose this.
}

if (auto CD = dyn_cast<ClassDecl>(TD)) {
if (!forwardDeclare(CD)) {
(void)addImport(CD);
Expand All @@ -1758,13 +1873,18 @@ class ModuleWriter {
else
assert(false && "unknown local type decl");
});

if (needsToBeIndividuallyDelayed) {
assert(isa<ClassDecl>(container));
delayedMembers.insert(VD);
}
}

declsToWrite.insert(declsToWrite.end()-1, nestedTypes.rbegin(),
nestedTypes.rend());

// Separate forward declarations from the class itself.
os << '\n';
return !hadAnyDelayedMembers;
}

bool writeClass(const ClassDecl *CD) {
Expand All @@ -1789,8 +1909,9 @@ class ModuleWriter {
if (!allRequirementsSatisfied)
return false;

(void)forwardDeclareMemberTypes(CD->getMembers(), CD);
seenTypes[CD] = { EmissionState::Defined, true };
forwardDeclareMemberTypes(CD->getMembers());
os << '\n';
printer.print(CD);
return true;
}
Expand Down Expand Up @@ -1824,8 +1945,11 @@ class ModuleWriter {
if (!allRequirementsSatisfied)
return false;

if (!forwardDeclareMemberTypes(PD->getMembers(), PD))
return false;

seenTypes[PD] = { EmissionState::Defined, true };
forwardDeclareMemberTypes(PD->getMembers());
os << '\n';
printer.print(PD);
return true;
}
Expand All @@ -1842,7 +1966,13 @@ class ModuleWriter {
if (!allRequirementsSatisfied)
return false;

forwardDeclareMemberTypes(ED->getMembers());
// This isn't rolled up into the previous set of requirements because
// it /also/ prints forward declarations, and the header is a little
// prettier if those are as close as possible to the necessary extension.
if (!forwardDeclareMemberTypes(ED->getMembers(), ED))
return false;

os << '\n';
printer.print(ED);
return true;
}
Expand Down Expand Up @@ -2186,6 +2316,17 @@ class ModuleWriter {
}
}

if (!delayedMembers.empty()) {
auto groupBegin = delayedMembers.begin();
for (auto i = groupBegin, e = delayedMembers.end(); i != e; ++i) {
if ((*i)->getDeclContext() != (*groupBegin)->getDeclContext()) {
printer.printAdHocCategory(make_range(groupBegin, i));
groupBegin = i;
}
}
printer.printAdHocCategory(make_range(groupBegin, delayedMembers.end()));
}

writePrologue(out);
writeImports(out);
out <<
Expand Down
22 changes: 22 additions & 0 deletions test/PrintAsObjC/Inputs/circularity.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// This file is meant to be used with the mock SDK, not the real one.
#import <Foundation.h>

#define SWIFT_NAME(x) __attribute__((swift_name(#x)))

@protocol Proto
@end

@interface ProtoImpl : NSObject <Proto>
@end

@interface Parent : NSObject
@end

@interface Unconstrained<T> : NSObject
@end

@interface NeedsProto<T: id <Proto>> : NSObject
@end

@interface NeedsParent<T: Parent *> : NSObject
@end
Loading