Skip to content

Commit 57404f0

Browse files
authored
Merge pull request #36226 from DougGregor/implicit-concurrent-value
Infer ConcurrentValue conformances for structs and enums.
2 parents c5ba171 + 2f2c0ba commit 57404f0

27 files changed

+505
-97
lines changed

include/swift/AST/DeclContext.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,9 @@ enum class ConformanceLookupKind : unsigned {
159159
OnlyExplicit,
160160
/// All conformances except for inherited ones.
161161
NonInherited,
162+
/// All conformances except structurally-derived conformances, of which
163+
/// ConcurrentValue is the only one.
164+
NonStructural,
162165
};
163166

164167
/// Describes a diagnostic for a conflict between two protocol

include/swift/AST/TypeCheckRequests.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2895,6 +2895,26 @@ class SynthesizeMainFunctionRequest
28952895
bool isCached() const { return true; }
28962896
};
28972897

2898+
/// Retrieve the implicit conformance for the given nominal type to
2899+
/// the ConcurrentValue protocol.
2900+
class GetImplicitConcurrentValueRequest :
2901+
public SimpleRequest<GetImplicitConcurrentValueRequest,
2902+
NormalProtocolConformance *(NominalTypeDecl *),
2903+
RequestFlags::Cached> {
2904+
public:
2905+
using SimpleRequest::SimpleRequest;
2906+
2907+
private:
2908+
friend SimpleRequest;
2909+
2910+
NormalProtocolConformance *evaluate(
2911+
Evaluator &evaluator, NominalTypeDecl *nominal) const;
2912+
2913+
public:
2914+
// Caching
2915+
bool isCached() const { return true; }
2916+
};
2917+
28982918
void simple_display(llvm::raw_ostream &out, Type value);
28992919
void simple_display(llvm::raw_ostream &out, const TypeRepr *TyR);
29002920
void simple_display(llvm::raw_ostream &out, ImplicitMemberAction action);

include/swift/AST/TypeCheckerTypeIDZone.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,3 +316,6 @@ SWIFT_REQUEST(TypeChecker, SimpleDidSetRequest,
316316
bool(AccessorDecl *), Cached, NoLocationInfo)
317317
SWIFT_REQUEST(TypeChecker, SynthesizeMainFunctionRequest,
318318
FuncDecl *(Decl *), Cached, NoLocationInfo)
319+
SWIFT_REQUEST(TypeChecker, GetImplicitConcurrentValueRequest,
320+
NormalProtocolConformance *(NominalTypeDecl *),
321+
Cached, NoLocationInfo)

include/swift/Basic/LangOptions.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,9 @@ namespace swift {
247247
/// Enable experimental flow-sensitive concurrent captures.
248248
bool EnableExperimentalFlowSensitiveConcurrentCaptures = false;
249249

250+
/// Enable inference of ConcurrentValue conformances for public types.
251+
bool EnableInferPublicConcurrentValue = false;
252+
250253
/// Enable experimental derivation of `Codable` for enums.
251254
bool EnableExperimentalEnumCodableDerivation = false;
252255

include/swift/Option/FrontendOptions.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,12 @@ def batch_scan_input_file
188188
def import_prescan : Flag<["-"], "import-prescan">,
189189
HelpText<"When performing a dependency scan, only dentify all imports of the main Swift module sources">;
190190

191+
def enable_infer_public_concurrent_value : Flag<["-"], "enable-infer-public-concurrent-value">,
192+
HelpText<"Enable inference of ConcurrentValue conformances for public structs and enums">;
193+
194+
def disable_infer_public_concurrent_value : Flag<["-"], "disable-infer-public-concurrent-value">,
195+
HelpText<"Disable inference of ConcurrentValue conformances for public structs and enums">;
196+
191197
} // end let Flags = [FrontendOption, NoDriverOption]
192198

193199
def debug_crash_Group : OptionGroup<"<automatic crashing options>">;

lib/AST/Module.cpp

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -967,10 +967,21 @@ ModuleDecl::lookupExistentialConformance(Type type, ProtocolDecl *protocol) {
967967

968968
ProtocolConformanceRef ModuleDecl::lookupConformance(Type type,
969969
ProtocolDecl *protocol) {
970+
// If we are recursively checking for implicit conformance of a nominal
971+
// type to ConcurrentValue, fail without evaluating this request. This
972+
// squashes cycles.
973+
LookupConformanceInModuleRequest request{{this, type, protocol}};
974+
if (protocol->isSpecificProtocol(KnownProtocolKind::ConcurrentValue)) {
975+
if (auto nominal = type->getAnyNominal()) {
976+
GetImplicitConcurrentValueRequest icvRequest{nominal};
977+
if (getASTContext().evaluator.hasActiveRequest(icvRequest) ||
978+
getASTContext().evaluator.hasActiveRequest(request))
979+
return ProtocolConformanceRef::forInvalid();
980+
}
981+
}
982+
970983
return evaluateOrDefault(
971-
getASTContext().evaluator,
972-
LookupConformanceInModuleRequest{{this, type, protocol}},
973-
ProtocolConformanceRef::forInvalid());
984+
getASTContext().evaluator, request, ProtocolConformanceRef::forInvalid());
974985
}
975986

976987
ProtocolConformanceRef
@@ -1035,8 +1046,20 @@ LookupConformanceInModuleRequest::evaluate(
10351046

10361047
// Find the (unspecialized) conformance.
10371048
SmallVector<ProtocolConformance *, 2> conformances;
1038-
if (!nominal->lookupConformance(mod, protocol, conformances))
1039-
return ProtocolConformanceRef::forInvalid();
1049+
if (!nominal->lookupConformance(mod, protocol, conformances)) {
1050+
if (!protocol->isSpecificProtocol(KnownProtocolKind::ConcurrentValue))
1051+
return ProtocolConformanceRef::forInvalid();
1052+
1053+
// Try to infer ConcurrentValue conformance.
1054+
GetImplicitConcurrentValueRequest cvRequest{nominal};
1055+
if (auto conformance = evaluateOrDefault(
1056+
ctx.evaluator, cvRequest, nullptr)) {
1057+
conformances.clear();
1058+
conformances.push_back(conformance);
1059+
} else {
1060+
return ProtocolConformanceRef::forInvalid();
1061+
}
1062+
}
10401063

10411064
// FIXME: Ambiguity resolution.
10421065
auto conformance = conformances.front();

lib/AST/ProtocolConformance.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1326,6 +1326,41 @@ IterableDeclContext::getLocalProtocols(ConformanceLookupKind lookupKind) const {
13261326
return result;
13271327
}
13281328

1329+
/// Find a synthesized ConcurrentValue conformance in this declaration context,
1330+
/// if there is one.
1331+
static ProtocolConformance *findSynthesizedConcurrentValueConformance(
1332+
const DeclContext *dc) {
1333+
auto nominal = dc->getSelfNominalTypeDecl();
1334+
if (!nominal)
1335+
return nullptr;
1336+
1337+
if (isa<ProtocolDecl>(nominal))
1338+
return nullptr;
1339+
1340+
if (dc->getParentModule() != nominal->getParentModule())
1341+
return nullptr;
1342+
1343+
auto cvProto = nominal->getASTContext().getProtocol(
1344+
KnownProtocolKind::ConcurrentValue);
1345+
if (!cvProto)
1346+
return nullptr;
1347+
1348+
auto conformance = dc->getParentModule()->lookupConformance(
1349+
nominal->getDeclaredInterfaceType(), cvProto);
1350+
if (!conformance || !conformance.isConcrete())
1351+
return nullptr;
1352+
1353+
auto concrete = conformance.getConcrete();
1354+
if (concrete->getDeclContext() != dc)
1355+
return nullptr;
1356+
1357+
auto normal = concrete->getRootNormalConformance();
1358+
if (!normal || normal->getSourceKind() != ConformanceEntryKind::Synthesized)
1359+
return nullptr;
1360+
1361+
return normal;
1362+
}
1363+
13291364
std::vector<ProtocolConformance *>
13301365
LookupAllConformancesInContextRequest::evaluate(
13311366
Evaluator &eval, const IterableDeclContext *IDC) const {
@@ -1394,10 +1429,28 @@ IterableDeclContext::getLocalConformances(ConformanceLookupKind lookupKind)
13941429
}
13951430

13961431
case ConformanceLookupKind::All:
1432+
case ConformanceLookupKind::NonStructural:
13971433
return true;
13981434
}
13991435
});
14001436

1437+
// If we want to add structural conformances, do so now.
1438+
switch (lookupKind) {
1439+
case ConformanceLookupKind::All:
1440+
case ConformanceLookupKind::NonInherited: {
1441+
// Look for a ConcurrentValue conformance globally. If it is synthesized
1442+
// and matches this declaration context, use it.
1443+
auto dc = getAsGenericContext();
1444+
if (auto conformance = findSynthesizedConcurrentValueConformance(dc))
1445+
result.push_back(conformance);
1446+
break;
1447+
}
1448+
1449+
case ConformanceLookupKind::NonStructural:
1450+
case ConformanceLookupKind::OnlyExplicit:
1451+
break;
1452+
}
1453+
14011454
return result;
14021455
}
14031456

lib/Frontend/CompilerInvocation.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,10 @@ static bool ParseLangArgs(LangOptions &Opts, ArgList &Args,
385385

386386
Opts.EnableExperimentalConcurrency |=
387387
Args.hasArg(OPT_enable_experimental_concurrency);
388+
Opts.EnableInferPublicConcurrentValue |=
389+
Args.hasFlag(OPT_enable_infer_public_concurrent_value,
390+
OPT_disable_infer_public_concurrent_value,
391+
false);
388392
Opts.EnableExperimentalFlowSensitiveConcurrentCaptures |=
389393
Args.hasArg(OPT_enable_experimental_flow_sensitive_concurrent_captures);
390394

lib/Frontend/ModuleInterfaceSupport.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -590,7 +590,10 @@ class InheritedProtocolCollector {
590590
DeclAttributes::print(printer, printOptions, attrs);
591591

592592
printer << "extension ";
593-
nominal->getDeclaredType().print(printer, printOptions);
593+
PrintOptions typePrintOptions = printOptions;
594+
typePrintOptions.FullyQualifiedTypes = false;
595+
typePrintOptions.FullyQualifiedTypesIfAmbiguous = false;
596+
nominal->getDeclaredType().print(printer, typePrintOptions);
594597
printer << " : ";
595598

596599
proto->getDeclaredInterfaceType()->print(printer, printOptions);

lib/Sema/DerivedConformanceEquatableHashable.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -942,7 +942,8 @@ getHashableConformance(const Decl *parentDecl) {
942942
ASTContext &C = parentDecl->getASTContext();
943943
const auto IDC = cast<IterableDeclContext>(parentDecl);
944944
auto hashableProto = C.getProtocol(KnownProtocolKind::Hashable);
945-
for (auto conformance: IDC->getLocalConformances()) {
945+
for (auto conformance: IDC->getLocalConformances(
946+
ConformanceLookupKind::NonStructural)) {
946947
if (conformance->getProtocol() == hashableProto) {
947948
return conformance;
948949
}

0 commit comments

Comments
 (0)