Skip to content

Commit 07bcf5a

Browse files
authored
Merge pull request #26995 from CodaFi/a-sign-from-on-high
Requestify Inferring Generic Requirements
2 parents 99c6521 + df66f9f commit 07bcf5a

15 files changed

+287
-242
lines changed

include/swift/AST/ASTTypeIDZone.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ SWIFT_TYPEID_NAMED(VarDecl *, VarDecl)
1919
SWIFT_TYPEID_NAMED(ValueDecl *, ValueDecl)
2020
SWIFT_TYPEID_NAMED(ProtocolDecl *, ProtocolDecl)
2121
SWIFT_TYPEID_NAMED(Decl *, Decl)
22+
SWIFT_TYPEID_NAMED(ModuleDecl *, ModuleDecl)
2223
SWIFT_TYPEID(Type)
2324
SWIFT_TYPEID(TypePair)
2425
SWIFT_TYPEID(PropertyWrapperBackingPropertyInfo)

include/swift/AST/ASTTypeIDs.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class Decl;
2626
class GenericSignature;
2727
class GenericTypeParamType;
2828
class IterableDeclContext;
29+
class ModuleDecl;
2930
class NominalTypeDecl;
3031
class OperatorDecl;
3132
struct PropertyWrapperBackingPropertyInfo;

include/swift/AST/Decl.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1428,7 +1428,7 @@ class GenericParamList final :
14281428
/// to the given DeclContext.
14291429
GenericParamList *clone(DeclContext *dc) const;
14301430

1431-
void print(raw_ostream &OS);
1431+
void print(raw_ostream &OS) const;
14321432
void dump();
14331433
};
14341434

@@ -7261,6 +7261,9 @@ inline void simple_display(llvm::raw_ostream &out,
72617261
simple_display(out, static_cast<const Decl *>(decl));
72627262
}
72637263

7264+
/// Display GenericParamList.
7265+
void simple_display(llvm::raw_ostream &out, const GenericParamList *GPL);
7266+
72647267
/// Extract the source location from the given declaration.
72657268
SourceLoc extractNearestSourceLoc(const Decl *decl);
72667269

include/swift/AST/GenericSignatureBuilder.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -615,7 +615,6 @@ class GenericSignatureBuilder {
615615
/// because the type \c Dictionary<K,V> cannot be formed without it.
616616
void inferRequirements(ModuleDecl &module,
617617
Type type,
618-
const TypeRepr *typeRepr,
619618
FloatingRequirementSource source);
620619

621620
/// Infer requirements from the given pattern, recursively.

include/swift/AST/TypeCheckRequests.h

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1103,6 +1103,43 @@ class AbstractGenericSignatureRequest :
11031103
}
11041104
};
11051105

1106+
class InferredGenericSignatureRequest :
1107+
public SimpleRequest<InferredGenericSignatureRequest,
1108+
GenericSignature *(ModuleDecl *,
1109+
GenericSignature *,
1110+
SmallVector<GenericParamList *, 2>,
1111+
SmallVector<Requirement, 2>,
1112+
SmallVector<TypeLoc, 2>,
1113+
bool),
1114+
CacheKind::Cached> {
1115+
public:
1116+
using SimpleRequest::SimpleRequest;
1117+
1118+
private:
1119+
friend SimpleRequest;
1120+
1121+
// Evaluation.
1122+
llvm::Expected<GenericSignature *>
1123+
evaluate(Evaluator &evaluator,
1124+
ModuleDecl *module,
1125+
GenericSignature *baseSignature,
1126+
SmallVector<GenericParamList *, 2> addedParameters,
1127+
SmallVector<Requirement, 2> addedRequirements,
1128+
SmallVector<TypeLoc, 2> inferenceSources,
1129+
bool allowConcreteGenericParams) const;
1130+
1131+
public:
1132+
// Separate caching.
1133+
bool isCached() const;
1134+
1135+
/// Inferred generic signature requests don't have source-location info.
1136+
SourceLoc getNearestLoc() const {
1137+
return SourceLoc();
1138+
}
1139+
};
1140+
1141+
void simple_display(llvm::raw_ostream &out, const TypeLoc source);
1142+
11061143
class ExtendedTypeRequest
11071144
: public SimpleRequest<ExtendedTypeRequest,
11081145
Type(ExtensionDecl *),
@@ -1149,6 +1186,7 @@ inline bool AnyValue::Holder<Type>::equals(const HolderBase &other) const {
11491186
}
11501187

11511188
void simple_display(llvm::raw_ostream &out, Type value);
1189+
void simple_display(llvm::raw_ostream &out, const TypeRepr *TyR);
11521190

11531191
#define SWIFT_TYPEID_ZONE TypeChecker
11541192
#define SWIFT_TYPEID_HEADER "swift/AST/TypeCheckerTypeIDZone.def"

include/swift/AST/TypeCheckerTypeIDZone.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ SWIFT_REQUEST(TypeChecker, ExistentialTypeSupportedRequest)
2929
SWIFT_REQUEST(TypeChecker, ExtendedTypeRequest)
3030
SWIFT_REQUEST(TypeChecker, FunctionBuilderTypeRequest)
3131
SWIFT_REQUEST(TypeChecker, FunctionOperatorRequest)
32+
SWIFT_REQUEST(TypeChecker, InferredGenericSignatureRequest)
3233
SWIFT_REQUEST(TypeChecker, InheritedTypeRequest)
3334
SWIFT_REQUEST(TypeChecker, InitKindRequest)
3435
SWIFT_REQUEST(TypeChecker, IsAccessorTransparentRequest)

include/swift/AST/TypeLoc.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,22 @@ struct TypeLoc {
6666
void setType(Type Ty);
6767

6868
TypeLoc clone(ASTContext &ctx) const;
69+
70+
friend llvm::hash_code hash_value(const TypeLoc &owner) {
71+
return hash_combine(llvm::hash_value(owner.Ty.getPointer()),
72+
llvm::hash_value(owner.TyR));
73+
}
74+
75+
friend bool operator==(const TypeLoc &lhs,
76+
const TypeLoc &rhs) {
77+
return lhs.Ty.getPointer() == rhs.Ty.getPointer()
78+
&& lhs.TyR == rhs.TyR;
79+
}
80+
81+
friend bool operator!=(const TypeLoc &lhs,
82+
const TypeLoc &rhs) {
83+
return !(lhs == rhs);
84+
}
6985
};
7086

7187
} // end namespace llvm

lib/AST/ASTDumper.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ void RequirementRepr::print(ASTPrinter &out) const {
161161
printImpl(out, /*AsWritten=*/true);
162162
}
163163

164-
void GenericParamList::print(llvm::raw_ostream &OS) {
164+
void GenericParamList::print(llvm::raw_ostream &OS) const {
165165
OS << '<';
166166
interleave(*this,
167167
[&](const GenericTypeParamDecl *P) {

lib/AST/Decl.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7676,6 +7676,12 @@ void swift::simple_display(llvm::raw_ostream &out, const ValueDecl *decl) {
76767676
else out << "(null)";
76777677
}
76787678

7679+
void swift::simple_display(llvm::raw_ostream &out, const GenericParamList *GPL) {
7680+
if (GPL) GPL->print(out);
7681+
else out << "(null)";
7682+
}
7683+
7684+
76797685
StringRef swift::getAccessorLabel(AccessorKind kind) {
76807686
switch (kind) {
76817687
#define SINGLETON_ACCESSOR(ID, KEYWORD) \

lib/AST/GenericSignatureBuilder.cpp

Lines changed: 102 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5158,7 +5158,7 @@ ConstraintResult GenericSignatureBuilder::addInheritedRequirements(
51585158

51595159
auto visitType = [&](Type inheritedType, const TypeRepr *typeRepr) {
51605160
if (inferForModule) {
5161-
inferRequirements(*inferForModule, inheritedType, typeRepr,
5161+
inferRequirements(*inferForModule, inheritedType,
51625162
getFloatingSource(typeRepr, /*forInferred=*/true));
51635163
}
51645164

@@ -5200,11 +5200,9 @@ GenericSignatureBuilder::addRequirement(const Requirement &req,
52005200

52015201
if (inferForModule) {
52025202
inferRequirements(*inferForModule, firstType,
5203-
RequirementRepr::getFirstTypeRepr(reqRepr),
52045203
source.asInferred(
52055204
RequirementRepr::getFirstTypeRepr(reqRepr)));
52065205
inferRequirements(*inferForModule, secondType,
5207-
RequirementRepr::getSecondTypeRepr(reqRepr),
52085206
source.asInferred(
52095207
RequirementRepr::getSecondTypeRepr(reqRepr)));
52105208
}
@@ -5217,7 +5215,6 @@ GenericSignatureBuilder::addRequirement(const Requirement &req,
52175215
case RequirementKind::Layout: {
52185216
if (inferForModule) {
52195217
inferRequirements(*inferForModule, firstType,
5220-
RequirementRepr::getFirstTypeRepr(reqRepr),
52215218
source.asInferred(
52225219
RequirementRepr::getFirstTypeRepr(reqRepr)));
52235220
}
@@ -5231,11 +5228,9 @@ GenericSignatureBuilder::addRequirement(const Requirement &req,
52315228

52325229
if (inferForModule) {
52335230
inferRequirements(*inferForModule, firstType,
5234-
RequirementRepr::getFirstTypeRepr(reqRepr),
52355231
source.asInferred(
52365232
RequirementRepr::getFirstTypeRepr(reqRepr)));
52375233
inferRequirements(*inferForModule, secondType,
5238-
RequirementRepr::getSecondTypeRepr(reqRepr),
52395234
source.asInferred(
52405235
RequirementRepr::getSecondTypeRepr(reqRepr)));
52415236
}
@@ -5322,7 +5317,6 @@ class GenericSignatureBuilder::InferRequirementsWalker : public TypeWalker {
53225317
void GenericSignatureBuilder::inferRequirements(
53235318
ModuleDecl &module,
53245319
Type type,
5325-
const TypeRepr *typeRepr,
53265320
FloatingRequirementSource source) {
53275321
if (!type)
53285322
return;
@@ -5337,7 +5331,6 @@ void GenericSignatureBuilder::inferRequirements(
53375331
ParameterList *params) {
53385332
for (auto P : *params) {
53395333
inferRequirements(module, P->getTypeLoc().getType(),
5340-
P->getTypeLoc().getTypeRepr(),
53415334
FloatingRequirementSource::forInferred(
53425335
P->getTypeLoc().getTypeRepr()));
53435336
}
@@ -7545,6 +7538,10 @@ bool AbstractGenericSignatureRequest::isCached() const {
75457538
return true;
75467539
}
75477540

7541+
bool InferredGenericSignatureRequest::isCached() const {
7542+
return true;
7543+
}
7544+
75487545
/// Check whether the inputs to the \c AbstractGenericSignatureRequest are
75497546
/// all canonical.
75507547
static bool isCanonicalRequest(GenericSignature *baseSignature,
@@ -7678,3 +7675,100 @@ AbstractGenericSignatureRequest::evaluate(
76787675
return std::move(builder).computeGenericSignature(
76797676
SourceLoc(), /*allowConcreteGenericParams=*/true);
76807677
}
7678+
7679+
llvm::Expected<GenericSignature *>
7680+
InferredGenericSignatureRequest::evaluate(
7681+
Evaluator &evaluator, ModuleDecl *parentModule,
7682+
GenericSignature *parentSig,
7683+
SmallVector<GenericParamList *, 2> gpLists,
7684+
SmallVector<Requirement, 2> addedRequirements,
7685+
SmallVector<TypeLoc, 2> inferenceSources,
7686+
bool allowConcreteGenericParams) const {
7687+
7688+
GenericSignatureBuilder builder(parentModule->getASTContext());
7689+
7690+
// If there is a parent context, add the generic parameters and requirements
7691+
// from that context.
7692+
builder.addGenericSignature(parentSig);
7693+
7694+
// The generic parameter lists MUST appear from innermost to outermost.
7695+
// We walk them backwards to order outer requirements before
7696+
// inner requirements.
7697+
for (auto &genericParams : llvm::reverse(gpLists)) {
7698+
assert(genericParams->size() > 0 &&
7699+
"Parsed an empty generic parameter list?");
7700+
7701+
// Determine where and how to perform name lookup.
7702+
DeclContext *lookupDC = genericParams->begin()[0]->getDeclContext();
7703+
7704+
// First, add the generic parameters to the generic signature builder.
7705+
// Do this before checking the inheritance clause, since it may
7706+
// itself be dependent on one of these parameters.
7707+
for (auto param : *genericParams)
7708+
builder.addGenericParameter(param);
7709+
7710+
// Add the requirements for each of the generic parameters to the builder.
7711+
// Now, check the inheritance clauses of each parameter.
7712+
for (auto param : *genericParams)
7713+
builder.addGenericParameterRequirements(param);
7714+
7715+
// Add the requirements clause to the builder.
7716+
7717+
WhereClauseOwner owner(lookupDC, genericParams);
7718+
using FloatingRequirementSource =
7719+
GenericSignatureBuilder::FloatingRequirementSource;
7720+
RequirementRequest::visitRequirements(owner, TypeResolutionStage::Structural,
7721+
[&](const Requirement &req, RequirementRepr *reqRepr) {
7722+
auto source = FloatingRequirementSource::forExplicit(reqRepr);
7723+
7724+
// If we're extending a protocol and adding a redundant requirement,
7725+
// for example, `extension Foo where Self: Foo`, then emit a
7726+
// diagnostic.
7727+
7728+
if (auto decl = owner.dc->getAsDecl()) {
7729+
if (auto extDecl = dyn_cast<ExtensionDecl>(decl)) {
7730+
auto extType = extDecl->getDeclaredInterfaceType();
7731+
auto extSelfType = extDecl->getSelfInterfaceType();
7732+
auto reqLHSType = req.getFirstType();
7733+
auto reqRHSType = req.getSecondType();
7734+
7735+
if (extType->isExistentialType() &&
7736+
reqLHSType->isEqual(extSelfType) &&
7737+
reqRHSType->isEqual(extType)) {
7738+
7739+
auto &ctx = extDecl->getASTContext();
7740+
ctx.Diags.diagnose(extDecl->getLoc(),
7741+
diag::protocol_extension_redundant_requirement,
7742+
extType->getString(),
7743+
extSelfType->getString(),
7744+
reqRHSType->getString());
7745+
}
7746+
}
7747+
}
7748+
7749+
builder.addRequirement(req, reqRepr, source, nullptr,
7750+
lookupDC->getParentModule());
7751+
return false;
7752+
});
7753+
}
7754+
7755+
/// Perform any remaining requirement inference.
7756+
for (auto sourcePair : inferenceSources) {
7757+
auto source =
7758+
FloatingRequirementSource::forInferred(sourcePair.getTypeRepr());
7759+
7760+
builder.inferRequirements(*parentModule,
7761+
sourcePair.getType(),
7762+
source);
7763+
}
7764+
7765+
// Finish by adding any remaining requirements.
7766+
auto source =
7767+
FloatingRequirementSource::forInferred(nullptr);
7768+
7769+
for (const auto &req : addedRequirements)
7770+
builder.addRequirement(req, source, parentModule);
7771+
7772+
return std::move(builder).computeGenericSignature(
7773+
SourceLoc(), allowConcreteGenericParams);
7774+
}

lib/AST/TypeCheckRequests.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,21 @@ void swift::simple_display(llvm::raw_ostream &out, Type type) {
6666
out << "null";
6767
}
6868

69+
void swift::simple_display(llvm::raw_ostream &out, const TypeRepr *TyR) {
70+
if (TyR)
71+
TyR->print(out);
72+
else
73+
out << "null";
74+
}
75+
76+
void swift::simple_display(llvm::raw_ostream &out, const TypeLoc source) {
77+
out << "(";
78+
simple_display(out, source.getType());
79+
out << ", ";
80+
simple_display(out, source.getTypeRepr());
81+
out << ")";
82+
}
83+
6984
//----------------------------------------------------------------------------//
7085
// Inherited type computation.
7186
//----------------------------------------------------------------------------//

0 commit comments

Comments
 (0)