Skip to content

Infer ConcurrentValue conformances for structs and enums. #36226

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/swift/AST/Decl.h
Original file line number Diff line number Diff line change
Expand Up @@ -3004,6 +3004,7 @@ class NominalTypeDecl : public GenericTypeDecl, public IterableDeclContext {
friend class DeclContext;
friend class IterableDeclContext;
friend class DirectLookupRequest;
friend class LookupAllConformancesInContextRequest;
friend ArrayRef<ValueDecl *>
ValueDecl::getSatisfiedProtocolRequirements(bool Sorted) const;

Expand Down
5 changes: 5 additions & 0 deletions include/swift/AST/DeclContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,9 @@ enum class ConformanceLookupKind : unsigned {
OnlyExplicit,
/// All conformances except for inherited ones.
NonInherited,
/// All conformances except structurally-derived conformances, of which
/// ConcurrentValue is the only one.
NonStructural,
};

/// Describes a diagnostic for a conflict between two protocol
Expand Down Expand Up @@ -734,6 +737,8 @@ class IterableDeclContext {

static IterableDeclContext *castDeclToIterableDeclContext(const Decl *D);

friend class LookupAllConformancesInContextRequest;

/// Retrieve the \c ASTContext in which this iterable context occurs.
ASTContext &getASTContext() const;

Expand Down
20 changes: 20 additions & 0 deletions include/swift/AST/TypeCheckRequests.h
Original file line number Diff line number Diff line change
Expand Up @@ -2895,6 +2895,26 @@ class SynthesizeMainFunctionRequest
bool isCached() const { return true; }
};

/// Retrieve the implicit conformance for the given nominal type to
/// the ConcurrentValue protocol.
class GetImplicitConcurrentValueRequest :
public SimpleRequest<GetImplicitConcurrentValueRequest,
NormalProtocolConformance *(NominalTypeDecl *),
RequestFlags::Cached> {
public:
using SimpleRequest::SimpleRequest;

private:
friend SimpleRequest;

NormalProtocolConformance *evaluate(
Evaluator &evaluator, NominalTypeDecl *nominal) const;

public:
// Caching
bool isCached() const { return true; }
};

void simple_display(llvm::raw_ostream &out, Type value);
void simple_display(llvm::raw_ostream &out, const TypeRepr *TyR);
void simple_display(llvm::raw_ostream &out, ImplicitMemberAction action);
Expand Down
3 changes: 3 additions & 0 deletions include/swift/AST/TypeCheckerTypeIDZone.def
Original file line number Diff line number Diff line change
Expand Up @@ -316,3 +316,6 @@ SWIFT_REQUEST(TypeChecker, SimpleDidSetRequest,
bool(AccessorDecl *), Cached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, SynthesizeMainFunctionRequest,
FuncDecl *(Decl *), Cached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, GetImplicitConcurrentValueRequest,
NormalProtocolConformance *(NominalTypeDecl *),
Cached, NoLocationInfo)
3 changes: 3 additions & 0 deletions include/swift/Basic/LangOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,9 @@ namespace swift {
/// Enable experimental flow-sensitive concurrent captures.
bool EnableExperimentalFlowSensitiveConcurrentCaptures = false;

/// Enable inference of ConcurrentValue conformances for public types.
bool EnableInferPublicConcurrentValue = false;

/// Enable experimental derivation of `Codable` for enums.
bool EnableExperimentalEnumCodableDerivation = false;

Expand Down
6 changes: 6 additions & 0 deletions include/swift/Option/FrontendOptions.td
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,12 @@ def batch_scan_input_file
def import_prescan : Flag<["-"], "import-prescan">,
HelpText<"When performing a dependency scan, only dentify all imports of the main Swift module sources">;

def enable_infer_public_concurrent_value : Flag<["-"], "enable-infer-public-concurrent-value">,
HelpText<"Enable inference of ConcurrentValue conformances for public structs and enums">;

def disable_infer_public_concurrent_value : Flag<["-"], "disable-infer-public-concurrent-value">,
HelpText<"Disable inference of ConcurrentValue conformances for public structs and enums">;

} // end let Flags = [FrontendOption, NoDriverOption]

def debug_crash_Group : OptionGroup<"<automatic crashing options>">;
Expand Down
34 changes: 1 addition & 33 deletions lib/AST/ConformanceLookupTable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -955,9 +955,7 @@ bool ConformanceLookupTable::lookupConformance(
void ConformanceLookupTable::lookupConformances(
NominalTypeDecl *nominal,
DeclContext *dc,
ConformanceLookupKind lookupKind,
SmallVectorImpl<ProtocolDecl *> *protocols,
SmallVectorImpl<ProtocolConformance *> *conformances,
std::vector<ProtocolConformance *> *conformances,
SmallVectorImpl<ConformanceDiagnostic> *diagnostics) {
// We need to expand all implied conformances before we can find
// those conformances that pertain to this declaration context.
Expand All @@ -980,36 +978,6 @@ void ConformanceLookupTable::lookupConformances(
if (entry->isSuperseded())
return true;

// If we are to filter out this result, do so now.
switch (lookupKind) {
case ConformanceLookupKind::OnlyExplicit:
switch (entry->getKind()) {
case ConformanceEntryKind::Explicit:
case ConformanceEntryKind::Synthesized:
break;
case ConformanceEntryKind::Implied:
case ConformanceEntryKind::Inherited:
return false;
}
break;
case ConformanceLookupKind::NonInherited:
switch (entry->getKind()) {
case ConformanceEntryKind::Explicit:
case ConformanceEntryKind::Synthesized:
case ConformanceEntryKind::Implied:
break;
case ConformanceEntryKind::Inherited:
return false;
}
break;
case ConformanceLookupKind::All:
break;
}

// Record the protocol.
if (protocols)
protocols->push_back(entry->getProtocol());

// Record the conformance.
if (conformances) {
if (auto conformance = getConformance(nominal, entry))
Expand Down
4 changes: 1 addition & 3 deletions lib/AST/ConformanceLookupTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -437,9 +437,7 @@ class ConformanceLookupTable {
/// Look for all of the conformances within the given declaration context.
void lookupConformances(NominalTypeDecl *nominal,
DeclContext *dc,
ConformanceLookupKind lookupKind,
SmallVectorImpl<ProtocolDecl *> *protocols,
SmallVectorImpl<ProtocolConformance *> *conformances,
std::vector<ProtocolConformance *> *conformances,
SmallVectorImpl<ConformanceDiagnostic> *diagnostics);

/// Retrieve the complete set of protocols to which this nominal
Expand Down
33 changes: 28 additions & 5 deletions lib/AST/Module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -967,10 +967,21 @@ ModuleDecl::lookupExistentialConformance(Type type, ProtocolDecl *protocol) {

ProtocolConformanceRef ModuleDecl::lookupConformance(Type type,
ProtocolDecl *protocol) {
// If we are recursively checking for implicit conformance of a nominal
// type to ConcurrentValue, fail without evaluating this request. This
// squashes cycles.
LookupConformanceInModuleRequest request{{this, type, protocol}};
if (protocol->isSpecificProtocol(KnownProtocolKind::ConcurrentValue)) {
if (auto nominal = type->getAnyNominal()) {
GetImplicitConcurrentValueRequest icvRequest{nominal};
if (getASTContext().evaluator.hasActiveRequest(icvRequest) ||
getASTContext().evaluator.hasActiveRequest(request))
return ProtocolConformanceRef::forInvalid();
}
}

return evaluateOrDefault(
getASTContext().evaluator,
LookupConformanceInModuleRequest{{this, type, protocol}},
ProtocolConformanceRef::forInvalid());
getASTContext().evaluator, request, ProtocolConformanceRef::forInvalid());
}

ProtocolConformanceRef
Expand Down Expand Up @@ -1035,8 +1046,20 @@ LookupConformanceInModuleRequest::evaluate(

// Find the (unspecialized) conformance.
SmallVector<ProtocolConformance *, 2> conformances;
if (!nominal->lookupConformance(mod, protocol, conformances))
return ProtocolConformanceRef::forInvalid();
if (!nominal->lookupConformance(mod, protocol, conformances)) {
if (!protocol->isSpecificProtocol(KnownProtocolKind::ConcurrentValue))
return ProtocolConformanceRef::forInvalid();

// Try to infer ConcurrentValue conformance.
GetImplicitConcurrentValueRequest cvRequest{nominal};
if (auto conformance = evaluateOrDefault(
ctx.evaluator, cvRequest, nullptr)) {
conformances.clear();
conformances.push_back(conformance);
} else {
return ProtocolConformanceRef::forInvalid();
}
}

// FIXME: Ambiguity resolution.
auto conformance = conformances.front();
Expand Down
141 changes: 108 additions & 33 deletions lib/AST/ProtocolConformance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1321,59 +1321,136 @@ NominalTypeDecl::getSatisfiedProtocolRequirementsForMember(
SmallVector<ProtocolDecl *, 2>
IterableDeclContext::getLocalProtocols(ConformanceLookupKind lookupKind) const {
SmallVector<ProtocolDecl *, 2> result;
for (auto conformance : getLocalConformances(lookupKind))
result.push_back(conformance->getProtocol());
return result;
}

// Dig out the nominal type.
const auto dc = getAsGenericContext();
const auto nominal = dc->getSelfNominalTypeDecl();
if (!nominal) {
return result;
}
/// Find a synthesized ConcurrentValue conformance in this declaration context,
/// if there is one.
static ProtocolConformance *findSynthesizedConcurrentValueConformance(
const DeclContext *dc) {
auto nominal = dc->getSelfNominalTypeDecl();
if (!nominal)
return nullptr;

// Update to record all potential conformances.
nominal->prepareConformanceTable();
nominal->ConformanceTable->lookupConformances(
nominal,
const_cast<GenericContext *>(dc),
lookupKind,
&result,
nullptr,
nullptr);
if (isa<ProtocolDecl>(nominal))
return nullptr;

return result;
}
if (dc->getParentModule() != nominal->getParentModule())
return nullptr;

SmallVector<ProtocolConformance *, 2>
IterableDeclContext::getLocalConformances(ConformanceLookupKind lookupKind)
const {
SmallVector<ProtocolConformance *, 2> result;
auto cvProto = nominal->getASTContext().getProtocol(
KnownProtocolKind::ConcurrentValue);
if (!cvProto)
return nullptr;

auto conformance = dc->getParentModule()->lookupConformance(
nominal->getDeclaredInterfaceType(), cvProto);
if (!conformance || !conformance.isConcrete())
return nullptr;

auto concrete = conformance.getConcrete();
if (concrete->getDeclContext() != dc)
return nullptr;

auto normal = concrete->getRootNormalConformance();
if (!normal || normal->getSourceKind() != ConformanceEntryKind::Synthesized)
return nullptr;

return normal;
}

std::vector<ProtocolConformance *>
LookupAllConformancesInContextRequest::evaluate(
Evaluator &eval, const IterableDeclContext *IDC) const {
// Dig out the nominal type.
const auto dc = getAsGenericContext();
const auto dc = IDC->getAsGenericContext();
const auto nominal = dc->getSelfNominalTypeDecl();
if (!nominal) {
return result;
return { };
}

// Protocols only have self-conformances.
if (auto protocol = dyn_cast<ProtocolDecl>(nominal)) {
if (protocol->requiresSelfConformanceWitnessTable()) {
return SmallVector<ProtocolConformance *, 2>{
protocol->getASTContext().getSelfConformance(protocol)
};
return { protocol->getASTContext().getSelfConformance(protocol) };
}
return SmallVector<ProtocolConformance *, 2>();

return { };
}

// Update to record all potential conformances.
// Record all potential conformances.
nominal->prepareConformanceTable();
std::vector<ProtocolConformance *> conformances;
nominal->ConformanceTable->lookupConformances(
nominal,
const_cast<GenericContext *>(dc),
lookupKind,
nullptr,
&result,
&conformances,
nullptr);

return conformances;
}

SmallVector<ProtocolConformance *, 2>
IterableDeclContext::getLocalConformances(ConformanceLookupKind lookupKind)
const {
// Look up the cached set of all of the conformances.
std::vector<ProtocolConformance *> conformances =
evaluateOrDefault(
getASTContext().evaluator, LookupAllConformancesInContextRequest{this},
{ });

// Copy all of the conformances we want.
SmallVector<ProtocolConformance *, 2> result;
std::copy_if(
conformances.begin(), conformances.end(), std::back_inserter(result),
[&](ProtocolConformance *conformance) {
// If we are to filter out this result, do so now.
switch (lookupKind) {
case ConformanceLookupKind::OnlyExplicit:
switch (conformance->getSourceKind()) {
case ConformanceEntryKind::Explicit:
case ConformanceEntryKind::Synthesized:
return true;
case ConformanceEntryKind::Implied:
case ConformanceEntryKind::Inherited:
return false;
}

case ConformanceLookupKind::NonInherited:
switch (conformance->getSourceKind()) {
case ConformanceEntryKind::Explicit:
case ConformanceEntryKind::Synthesized:
case ConformanceEntryKind::Implied:
return true;
case ConformanceEntryKind::Inherited:
return false;
}

case ConformanceLookupKind::All:
case ConformanceLookupKind::NonStructural:
return true;
}
});

// If we want to add structural conformances, do so now.
switch (lookupKind) {
case ConformanceLookupKind::All:
case ConformanceLookupKind::NonInherited: {
// Look for a ConcurrentValue conformance globally. If it is synthesized
// and matches this declaration context, use it.
auto dc = getAsGenericContext();
if (auto conformance = findSynthesizedConcurrentValueConformance(dc))
result.push_back(conformance);
break;
}

case ConformanceLookupKind::NonStructural:
case ConformanceLookupKind::OnlyExplicit:
break;
}

return result;
}

Expand All @@ -1399,8 +1476,6 @@ IterableDeclContext::takeConformanceDiagnostics() const {
nominal->ConformanceTable->lookupConformances(
nominal,
const_cast<GenericContext *>(dc),
ConformanceLookupKind::All,
nullptr,
nullptr,
&result);

Expand Down
4 changes: 4 additions & 0 deletions lib/Frontend/CompilerInvocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,10 @@ static bool ParseLangArgs(LangOptions &Opts, ArgList &Args,

Opts.EnableExperimentalConcurrency |=
Args.hasArg(OPT_enable_experimental_concurrency);
Opts.EnableInferPublicConcurrentValue |=
Args.hasFlag(OPT_enable_infer_public_concurrent_value,
OPT_disable_infer_public_concurrent_value,
false);
Opts.EnableExperimentalFlowSensitiveConcurrentCaptures |=
Args.hasArg(OPT_enable_experimental_flow_sensitive_concurrent_captures);

Expand Down
5 changes: 4 additions & 1 deletion lib/Frontend/ModuleInterfaceSupport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,10 @@ class InheritedProtocolCollector {
DeclAttributes::print(printer, printOptions, attrs);

printer << "extension ";
nominal->getDeclaredType().print(printer, printOptions);
PrintOptions typePrintOptions = printOptions;
typePrintOptions.FullyQualifiedTypes = false;
typePrintOptions.FullyQualifiedTypesIfAmbiguous = false;
nominal->getDeclaredType().print(printer, typePrintOptions);
printer << " : ";

proto->getDeclaredInterfaceType()->print(printer, printOptions);
Expand Down
Loading