Skip to content

Commit 3b315d1

Browse files
authored
Merge pull request #4155 from jrose-apple/PrintAsObjC-circularity
[PrintAsObjC] Handle circularities introduced by ObjC generics.
2 parents 4b41a7f + 8282160 commit 3b315d1

File tree

4 files changed

+541
-22
lines changed

4 files changed

+541
-22
lines changed

lib/PrintAsObjC/PrintAsObjC.cpp

Lines changed: 163 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,8 @@ static bool isMistakableForInit(ObjCSelector selector) {
118118

119119

120120
namespace {
121+
using DelayedMemberSet = llvm::SmallSetVector<const ValueDecl *, 32>;
122+
121123
class ObjCPrinter : private DeclVisitor<ObjCPrinter>,
122124
private TypeVisitor<ObjCPrinter, void,
123125
Optional<OptionalTypeKind>>
@@ -134,6 +136,7 @@ class ObjCPrinter : private DeclVisitor<ObjCPrinter>,
134136
raw_ostream &os;
135137

136138
SmallVector<const FunctionType *, 4> openFunctionTypes;
139+
const DelayedMemberSet &delayedMembers;
137140

138141
Accessibility minRequiredAccess;
139142
bool protocolMembersOptional = false;
@@ -144,14 +147,31 @@ class ObjCPrinter : private DeclVisitor<ObjCPrinter>,
144147
friend TypeVisitor<ObjCPrinter>;
145148

146149
public:
147-
explicit ObjCPrinter(Module &mod, raw_ostream &out, Accessibility access)
148-
: M(mod), os(out), minRequiredAccess(access) {}
150+
explicit ObjCPrinter(Module &mod, raw_ostream &out,
151+
DelayedMemberSet &delayed, Accessibility access)
152+
: M(mod), os(out), delayedMembers(delayed), minRequiredAccess(access) {}
149153

150154
void print(const Decl *D) {
151155
PrettyStackTraceDecl trace("printing", D);
152156
visit(const_cast<Decl *>(D));
153157
}
154158

159+
void printAdHocCategory(iterator_range<const ValueDecl * const *> members) {
160+
assert(members.begin() != members.end());
161+
162+
const DeclContext *origDC = (*members.begin())->getDeclContext();
163+
auto *baseClass = dyn_cast<ClassDecl>(origDC);
164+
if (!baseClass) {
165+
Type extendedTy = cast<ExtensionDecl>(origDC)->getExtendedType();
166+
baseClass = extendedTy->getClassOrBoundGenericClass();
167+
}
168+
169+
os << "@interface " << getNameForObjC(baseClass)
170+
<< " (SWIFT_EXTENSION(" << origDC->getParentModule()->getName() << "))\n";
171+
printMembers</*allowDelayed*/true>(members);
172+
os << "@end\n\n";
173+
}
174+
155175
bool shouldInclude(const ValueDecl *VD, bool checkParent = true) {
156176
if (!(VD->isObjC() || VD->getAttrs().hasAttribute<CDeclAttr>()))
157177
return false;
@@ -198,19 +218,24 @@ class ObjCPrinter : private DeclVisitor<ObjCPrinter>,
198218
}
199219

200220
/// Prints the members of a class, extension, or protocol.
201-
void printMembers(DeclRange members) {
202-
for (auto member : members) {
221+
template <bool AllowDelayed = false, typename R>
222+
void printMembers(R &&members) {
223+
for (const Decl *member : members) {
203224
auto VD = dyn_cast<ValueDecl>(member);
204225
if (!VD || !shouldInclude(VD) || isa<TypeDecl>(VD))
205226
continue;
206227
if (auto FD = dyn_cast<FuncDecl>(VD))
207228
if (FD->isAccessor())
208229
continue;
230+
if (!AllowDelayed && delayedMembers.count(VD)) {
231+
os << "// '" << VD->getFullName() << "' below\n";
232+
continue;
233+
}
209234
if (VD->getAttrs().hasAttribute<OptionalAttr>() != protocolMembersOptional) {
210235
protocolMembersOptional = VD->getAttrs().hasAttribute<OptionalAttr>();
211236
os << (protocolMembersOptional ? "@optional\n" : "@required\n");
212237
}
213-
visit(VD);
238+
visit(const_cast<ValueDecl*>(VD));
214239
}
215240
}
216241

@@ -1509,6 +1534,7 @@ class ReferencedTypeFinder : private TypeVisitor<ReferencedTypeFinder> {
15091534
friend TypeVisitor;
15101535

15111536
llvm::function_ref<void(ReferencedTypeFinder &, const TypeDecl *)> Callback;
1537+
bool IsWithinConstrainedObjCGeneric = false;
15121538

15131539
ReferencedTypeFinder(decltype(Callback) callback) : Callback(callback) {}
15141540

@@ -1568,18 +1594,48 @@ class ReferencedTypeFinder : private TypeVisitor<ReferencedTypeFinder> {
15681594
visit(inout->getObjectType());
15691595
}
15701596

1597+
/// Returns true if \p archetype has any constraints other than being
1598+
/// class-bound ("conforms to" AnyObject).
1599+
static bool isConstrainedArchetype(const ArchetypeType *archetype) {
1600+
if (archetype->getSuperclass())
1601+
return true;
1602+
1603+
if (archetype->getConformsTo().size() > 1)
1604+
return true;
1605+
if (archetype->getConformsTo().size() == 0)
1606+
return false;
1607+
1608+
const ProtocolDecl *proto = archetype->getConformsTo().front();
1609+
if (auto knownKind = proto->getKnownProtocolKind())
1610+
return knownKind.getValue() != KnownProtocolKind::AnyObject;
1611+
return true;
1612+
}
1613+
15711614
void visitBoundGenericType(BoundGenericType *boundGeneric) {
1572-
for (auto argTy : boundGeneric->getGenericArgs())
1615+
bool isObjCGeneric = boundGeneric->getDecl()->hasClangNode();
1616+
1617+
for_each(boundGeneric->getGenericArgs(),
1618+
boundGeneric->getDecl()->getGenericParams()->getPrimaryArchetypes(),
1619+
[&](Type argTy, const ArchetypeType *archetype) {
1620+
if (isObjCGeneric && isConstrainedArchetype(archetype))
1621+
IsWithinConstrainedObjCGeneric = true;
15731622
visit(argTy);
1574-
// Ignore the base type; that can't be exposed to Objective-C. Every
1575-
// bound generic type we care about gets mapped to a particular construct
1576-
// in Objective-C we care about. (For example, Optional<NSFoo> is mapped to
1577-
// NSFoo *.)
1623+
IsWithinConstrainedObjCGeneric = false;
1624+
});
1625+
1626+
// Ignore the base type; that either can't be exposed to Objective-C or
1627+
// was an Objective-C type to begin with. Every bound generic Swift type we
1628+
// care about gets mapped to a particular construct in Objective-C.
1629+
// (For example, Optional<NSFoo> is mapped to NSFoo *.)
15781630
}
15791631

15801632
public:
15811633
using TypeVisitor::visit;
15821634

1635+
bool isWithinConstrainedObjCGeneric() const {
1636+
return IsWithinConstrainedObjCGeneric;
1637+
}
1638+
15831639
static void walk(Type ty, decltype(Callback) callback) {
15841640
ReferencedTypeFinder(callback).visit(ty);
15851641
}
@@ -1603,13 +1659,14 @@ struct PointerLikeComparator {
16031659

16041660
class ModuleWriter {
16051661
enum class EmissionState {
1606-
DefinitionRequested = 0,
1607-
DefinitionInProgress,
1662+
NotYetDefined = 0,
1663+
DefinitionRequested,
16081664
Defined
16091665
};
16101666

16111667
llvm::DenseMap<const TypeDecl *, std::pair<EmissionState, bool>> seenTypes;
16121668
std::vector<const Decl *> declsToWrite;
1669+
DelayedMemberSet delayedMembers;
16131670

16141671
using ImportModuleTy = PointerUnion<Module*, const clang::Module*>;
16151672
SmallSetVector<ImportModuleTy, 8,
@@ -1623,7 +1680,7 @@ class ModuleWriter {
16231680
ObjCPrinter printer;
16241681
public:
16251682
ModuleWriter(Module &mod, StringRef header, Accessibility access)
1626-
: M(mod), bridgingHeader(header), printer(M, os, access) {}
1683+
: M(mod), bridgingHeader(header), printer(M, os, delayedMembers, access) {}
16271684

16281685
/// Returns true if we added the decl's module to the import set, false if
16291686
/// the decl is a local decl.
@@ -1660,6 +1717,19 @@ class ModuleWriter {
16601717
return true;
16611718
}
16621719

1720+
bool hasBeenRequested(const TypeDecl *D) const {
1721+
return seenTypes.lookup(D).first >= EmissionState::DefinitionRequested;
1722+
}
1723+
1724+
bool tryRequire(const TypeDecl *D) {
1725+
if (addImport(D)) {
1726+
seenTypes[D] = { EmissionState::Defined, true };
1727+
return true;
1728+
}
1729+
auto &state = seenTypes[D];
1730+
return state.first == EmissionState::Defined;
1731+
}
1732+
16631733
bool require(const TypeDecl *D) {
16641734
if (addImport(D)) {
16651735
seenTypes[D] = { EmissionState::Defined, true };
@@ -1668,11 +1738,11 @@ class ModuleWriter {
16681738

16691739
auto &state = seenTypes[D];
16701740
switch (state.first) {
1741+
case EmissionState::NotYetDefined:
16711742
case EmissionState::DefinitionRequested:
1743+
state.first = EmissionState::DefinitionRequested;
16721744
declsToWrite.push_back(D);
16731745
return false;
1674-
case EmissionState::DefinitionInProgress:
1675-
llvm_unreachable("circular requirements");
16761746
case EmissionState::Defined:
16771747
return true;
16781748
}
@@ -1717,7 +1787,17 @@ class ModuleWriter {
17171787
});
17181788
}
17191789

1720-
void forwardDeclareMemberTypes(DeclRange members) {
1790+
bool forwardDeclareMemberTypes(DeclRange members, const Decl *container) {
1791+
switch (container->getKind()) {
1792+
case DeclKind::Class:
1793+
case DeclKind::Protocol:
1794+
case DeclKind::Extension:
1795+
break;
1796+
default:
1797+
llvm_unreachable("unexpected container kind");
1798+
}
1799+
1800+
bool hadAnyDelayedMembers = false;
17211801
SmallVector<ValueDecl *, 4> nestedTypes;
17221802
for (auto member : members) {
17231803
auto VD = dyn_cast<ValueDecl>(member);
@@ -1738,9 +1818,44 @@ class ModuleWriter {
17381818
continue;
17391819
}
17401820

1821+
bool needsToBeIndividuallyDelayed = false;
17411822
ReferencedTypeFinder::walk(VD->getType(),
1742-
[this](ReferencedTypeFinder &finder,
1743-
const TypeDecl *TD) {
1823+
[&](ReferencedTypeFinder &finder,
1824+
const TypeDecl *TD) {
1825+
if (TD == container)
1826+
return;
1827+
1828+
if (finder.isWithinConstrainedObjCGeneric()) {
1829+
// We can delay individual members of classes; do so if necessary.
1830+
if (isa<ClassDecl>(container)) {
1831+
if (!tryRequire(TD)) {
1832+
needsToBeIndividuallyDelayed = true;
1833+
hadAnyDelayedMembers = true;
1834+
}
1835+
return;
1836+
}
1837+
1838+
// Extensions can always be delayed wholesale.
1839+
if (isa<ExtensionDecl>(container)) {
1840+
if (!require(TD))
1841+
hadAnyDelayedMembers = true;
1842+
return;
1843+
}
1844+
1845+
// Protocols should be delayed wholesale unless we might have a cycle.
1846+
auto *proto = cast<ProtocolDecl>(container);
1847+
if (!hasBeenRequested(proto) || !hasBeenRequested(TD)) {
1848+
if (!require(TD))
1849+
hadAnyDelayedMembers = true;
1850+
return;
1851+
}
1852+
1853+
// Otherwise, we have a cyclic dependency. Give up and continue with
1854+
// regular forward-declarations even though this will lead to an
1855+
// error; there's nothing we can do here.
1856+
// FIXME: It would be nice to diagnose this.
1857+
}
1858+
17441859
if (auto CD = dyn_cast<ClassDecl>(TD)) {
17451860
if (!forwardDeclare(CD)) {
17461861
(void)addImport(CD);
@@ -1758,13 +1873,18 @@ class ModuleWriter {
17581873
else
17591874
assert(false && "unknown local type decl");
17601875
});
1876+
1877+
if (needsToBeIndividuallyDelayed) {
1878+
assert(isa<ClassDecl>(container));
1879+
delayedMembers.insert(VD);
1880+
}
17611881
}
17621882

17631883
declsToWrite.insert(declsToWrite.end()-1, nestedTypes.rbegin(),
17641884
nestedTypes.rend());
17651885

17661886
// Separate forward declarations from the class itself.
1767-
os << '\n';
1887+
return !hadAnyDelayedMembers;
17681888
}
17691889

17701890
bool writeClass(const ClassDecl *CD) {
@@ -1789,8 +1909,9 @@ class ModuleWriter {
17891909
if (!allRequirementsSatisfied)
17901910
return false;
17911911

1912+
(void)forwardDeclareMemberTypes(CD->getMembers(), CD);
17921913
seenTypes[CD] = { EmissionState::Defined, true };
1793-
forwardDeclareMemberTypes(CD->getMembers());
1914+
os << '\n';
17941915
printer.print(CD);
17951916
return true;
17961917
}
@@ -1824,8 +1945,11 @@ class ModuleWriter {
18241945
if (!allRequirementsSatisfied)
18251946
return false;
18261947

1948+
if (!forwardDeclareMemberTypes(PD->getMembers(), PD))
1949+
return false;
1950+
18271951
seenTypes[PD] = { EmissionState::Defined, true };
1828-
forwardDeclareMemberTypes(PD->getMembers());
1952+
os << '\n';
18291953
printer.print(PD);
18301954
return true;
18311955
}
@@ -1842,7 +1966,13 @@ class ModuleWriter {
18421966
if (!allRequirementsSatisfied)
18431967
return false;
18441968

1845-
forwardDeclareMemberTypes(ED->getMembers());
1969+
// This isn't rolled up into the previous set of requirements because
1970+
// it /also/ prints forward declarations, and the header is a little
1971+
// prettier if those are as close as possible to the necessary extension.
1972+
if (!forwardDeclareMemberTypes(ED->getMembers(), ED))
1973+
return false;
1974+
1975+
os << '\n';
18461976
printer.print(ED);
18471977
return true;
18481978
}
@@ -2186,6 +2316,17 @@ class ModuleWriter {
21862316
}
21872317
}
21882318

2319+
if (!delayedMembers.empty()) {
2320+
auto groupBegin = delayedMembers.begin();
2321+
for (auto i = groupBegin, e = delayedMembers.end(); i != e; ++i) {
2322+
if ((*i)->getDeclContext() != (*groupBegin)->getDeclContext()) {
2323+
printer.printAdHocCategory(make_range(groupBegin, i));
2324+
groupBegin = i;
2325+
}
2326+
}
2327+
printer.printAdHocCategory(make_range(groupBegin, delayedMembers.end()));
2328+
}
2329+
21892330
writePrologue(out);
21902331
writeImports(out);
21912332
out <<

test/PrintAsObjC/Inputs/circularity.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// This file is meant to be used with the mock SDK, not the real one.
2+
#import <Foundation.h>
3+
4+
#define SWIFT_NAME(x) __attribute__((swift_name(#x)))
5+
6+
@protocol Proto
7+
@end
8+
9+
@interface ProtoImpl : NSObject <Proto>
10+
@end
11+
12+
@interface Parent : NSObject
13+
@end
14+
15+
@interface Unconstrained<T> : NSObject
16+
@end
17+
18+
@interface NeedsProto<T: id <Proto>> : NSObject
19+
@end
20+
21+
@interface NeedsParent<T: Parent *> : NSObject
22+
@end

0 commit comments

Comments
 (0)