Skip to content

Commit 09a4e16

Browse files
authored
Merge pull request #80301 from slavapestov/subst-generic-function-type
Clean up GenericFunctionType substitution
2 parents 119fbac + 1859533 commit 09a4e16

15 files changed

+97
-147
lines changed

include/swift/AST/Types.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4150,7 +4150,6 @@ class GenericFunctionType final
41504150
/// function type and return the resulting non-generic type.
41514151
FunctionType *substGenericArgs(SubstitutionMap subs,
41524152
SubstOptions options = std::nullopt);
4153-
FunctionType *substGenericArgs(llvm::function_ref<Type(Type)> substFn) const;
41544153

41554154
void Profile(llvm::FoldingSetNodeID &ID) {
41564155
std::optional<ExtInfo> info = std::nullopt;

lib/AST/ASTContext.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5754,7 +5754,7 @@ ProtocolConformanceRef ProtocolConformanceRef::forAbstract(
57545754
properties |= conformingType->getRecursiveProperties();
57555755
auto arena = getArena(properties);
57565756

5757-
// Profile the substitution map.
5757+
// Form the folding set key.
57585758
llvm::FoldingSetNodeID id;
57595759
AbstractConformance::Profile(id, conformingType, proto);
57605760

lib/AST/ASTMangler.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -828,14 +828,12 @@ std::string ASTMangler::mangleAutoDiffGeneratedDeclaration(
828828
// Since we don't have a distinct mangling for sugared generic
829829
// parameter types, we must desugar them here.
830830
static Type getTypeForDWARFMangling(Type t) {
831-
return t.subst(
832-
[](SubstitutableType *t) -> Type {
833-
if (t->isRootParameterPack()) {
834-
return PackType::getSingletonPackExpansion(t->getCanonicalType());
835-
}
836-
return t->getCanonicalType();
837-
},
838-
MakeAbstractConformanceForGenericType());
831+
return t.transformRec(
832+
[](TypeBase *t) -> std::optional<Type> {
833+
if (isa<GenericTypeParamType>(t))
834+
return t->getCanonicalType();
835+
return std::nullopt;
836+
});
839837
}
840838

841839
std::string ASTMangler::mangleTypeForDebugger(Type Ty, GenericSignature sig) {

lib/AST/AbstractConformance.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@
1010
//
1111
//===----------------------------------------------------------------------===//
1212
//
13-
// This file defines the AbstractConformance class, which stores an
14-
// abstract conformance that is stashed in a ProtocolConformanceRef.
13+
// This file defines the AbstractConformance class, which represents
14+
// the conformance of a type parameter or archetype to a protocol.
15+
// These are usually stashed inside a ProtocolConformanceRef.
1516
//
1617
//===----------------------------------------------------------------------===//
1718
#ifndef SWIFT_AST_ABSTRACT_CONFORMANCE_H
@@ -36,7 +37,7 @@ class AbstractConformance final : public llvm::FoldingSetNode {
3637
Profile(id, getType(), getProtocol());
3738
}
3839

39-
/// Profile the substitution map storage, for use with LLVM's FoldingSet.
40+
/// Profile the storage for this conformance, for use with LLVM's FoldingSet.
4041
static void Profile(llvm::FoldingSetNodeID &id,
4142
Type conformingType,
4243
ProtocolDecl *requirement) {

lib/AST/TypeSubstitution.cpp

Lines changed: 7 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -61,29 +61,11 @@ Type QuerySubstitutionMap::operator()(SubstitutableType *type) const {
6161
FunctionType *
6262
GenericFunctionType::substGenericArgs(SubstitutionMap subs,
6363
SubstOptions options) {
64-
return substGenericArgs(
65-
[=](Type t) { return t.subst(subs, options); });
66-
}
67-
68-
FunctionType *GenericFunctionType::substGenericArgs(
69-
llvm::function_ref<Type(Type)> substFn) const {
70-
llvm::SmallVector<AnyFunctionType::Param, 4> params;
71-
params.reserve(getNumParams());
72-
73-
llvm::transform(getParams(), std::back_inserter(params),
74-
[&](const AnyFunctionType::Param &param) {
75-
return param.withType(substFn(param.getPlainType()));
76-
});
77-
78-
auto resultTy = substFn(getResult());
79-
80-
Type thrownError = getThrownError();
81-
if (thrownError)
82-
thrownError = substFn(thrownError);
83-
84-
// Build the resulting (non-generic) function type.
85-
return FunctionType::get(params, resultTy,
86-
getExtInfo().withThrows(isThrowing(), thrownError));
64+
// FIXME: Before dropping the signature, we should assert that
65+
// subs.getGenericSignature() is equal to this function type's
66+
// generic signature.
67+
Type fnType = FunctionType::get(getParams(), getResult(), getExtInfo());
68+
return fnType.subst(subs, options)->castTo<FunctionType>();
8769
}
8870

8971
CanFunctionType
@@ -170,73 +152,6 @@ operator()(CanType dependentType, Type conformingReplacementType,
170152
conformingReplacementType, conformedProtocol);
171153
}
172154

173-
static Type substGenericFunctionType(GenericFunctionType *genericFnType,
174-
InFlightSubstitution &IFS) {
175-
// Substitute into the function type (without generic signature).
176-
auto *bareFnType = FunctionType::get(genericFnType->getParams(),
177-
genericFnType->getResult(),
178-
genericFnType->getExtInfo());
179-
Type result = Type(bareFnType).subst(IFS);
180-
if (!result || result->is<ErrorType>()) return result;
181-
182-
auto *fnType = result->castTo<FunctionType>();
183-
// Substitute generic parameters.
184-
bool anySemanticChanges = false;
185-
SmallVector<GenericTypeParamType *, 2> genericParams;
186-
for (auto param : genericFnType->getGenericParams()) {
187-
Type paramTy = Type(param).subst(IFS);
188-
if (!paramTy)
189-
return Type();
190-
191-
if (auto newParam = paramTy->getAs<GenericTypeParamType>()) {
192-
if (!newParam->isEqual(param))
193-
anySemanticChanges = true;
194-
195-
genericParams.push_back(newParam);
196-
} else {
197-
anySemanticChanges = true;
198-
}
199-
}
200-
201-
// If no generic parameters remain, this is a non-generic function type.
202-
if (genericParams.empty())
203-
return result;
204-
205-
// Transform requirements.
206-
SmallVector<Requirement, 2> requirements;
207-
for (const auto &req : genericFnType->getRequirements()) {
208-
// Substitute into the requirement.
209-
auto substReqt = req.subst(IFS);
210-
211-
// Did anything change?
212-
if (!anySemanticChanges &&
213-
(!req.getFirstType()->isEqual(substReqt.getFirstType()) ||
214-
(req.getKind() != RequirementKind::Layout &&
215-
!req.getSecondType()->isEqual(substReqt.getSecondType())))) {
216-
anySemanticChanges = true;
217-
}
218-
219-
requirements.push_back(substReqt);
220-
}
221-
222-
GenericSignature genericSig;
223-
if (anySemanticChanges) {
224-
// If there were semantic changes, we need to build a new generic
225-
// signature.
226-
ASTContext &ctx = genericFnType->getASTContext();
227-
genericSig = buildGenericSignature(ctx, GenericSignature(),
228-
genericParams, requirements,
229-
/*allowInverses=*/false);
230-
} else {
231-
// Use the mapped generic signature.
232-
genericSig = GenericSignature::get(genericParams, requirements);
233-
}
234-
235-
// Produce the new generic function type.
236-
return GenericFunctionType::get(genericSig, fnType->getParams(),
237-
fnType->getResult(), fnType->getExtInfo());
238-
}
239-
240155
InFlightSubstitution::InFlightSubstitution(TypeSubstitutionFn substType,
241156
LookupConformanceFn lookupConformance,
242157
SubstOptions options)
@@ -558,11 +473,8 @@ Type Type::subst(TypeSubstitutionFn substitutions,
558473
}
559474

560475
Type Type::subst(InFlightSubstitution &IFS) const {
561-
// Handle substitutions into generic function types.
562-
// FIXME: This should be banned.
563-
if (auto genericFnType = getPointer()->getAs<GenericFunctionType>()) {
564-
return substGenericFunctionType(genericFnType, IFS);
565-
}
476+
ASSERT(!getPointer()->getAs<GenericFunctionType>() &&
477+
"Perhaps you want GenericFunctionType::substGenericArgs() instead");
566478

567479
if (IFS.isInvariant(*this))
568480
return *this;

lib/SILGen/SILGenExpr.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4265,16 +4265,16 @@ static void lowerKeyPathMemberIndexTypes(
42654265
if (auto subscript = dyn_cast<SubscriptDecl>(decl)) {
42664266
auto subscriptSubstTy = subscript->getInterfaceType();
42674267
auto sig = subscript->getGenericSignature();
4268-
if (sig) {
4269-
subscriptSubstTy = subscriptSubstTy.subst(memberSubs);
4268+
if (auto *subscriptGenericTy = subscriptSubstTy->getAs<GenericFunctionType>()) {
4269+
subscriptSubstTy = subscriptGenericTy->substGenericArgs(memberSubs);
42704270
}
42714271
needsGenericContext |= subscriptSubstTy->hasArchetype();
42724272
processIndicesOrParameters(subscript->getIndices(), &sig);
42734273
} else if (auto method = dyn_cast<AbstractFunctionDecl>(decl)) {
42744274
auto methodSubstTy = method->getInterfaceType();
42754275
auto sig = method->getGenericSignature();
4276-
if (sig) {
4277-
methodSubstTy = methodSubstTy.subst(memberSubs);
4276+
if (auto *methodGenericTy = methodSubstTy->getAs<GenericFunctionType>()) {
4277+
methodSubstTy = methodGenericTy->substGenericArgs(memberSubs);
42784278
}
42794279
needsGenericContext |= methodSubstTy->hasArchetype();
42804280
processIndicesOrParameters(method->getParameters(), &sig);

lib/SILGen/SILGenType.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -946,8 +946,7 @@ static SILFunction *emitSelfConformanceWitness(SILGenModule &SGM,
946946
openedConf);
947947

948948
// Substitute to get the formal substituted type of the thunk.
949-
auto reqtSubstTy =
950-
cast<AnyFunctionType>(reqtOrigTy.subst(reqtSubs)->getCanonicalType());
949+
auto reqtSubstTy = reqtOrigTy.substGenericArgs(reqtSubs);
951950

952951
// Substitute into the requirement type to get the type of the thunk.
953952
auto witnessSILFnType = requirementInfo.SILFnType->substGenericArgs(

lib/Sema/CodeSynthesis.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -666,7 +666,9 @@ synthesizeDesignatedInitOverride(AbstractFunctionDecl *fn, void *context) {
666666
.subst(subs);
667667
ConcreteDeclRef ctorRef(superclassCtor, subs);
668668

669-
auto type = superclassCtor->getInitializerInterfaceType().subst(subs);
669+
auto type = superclassCtor->getInitializerInterfaceType();
670+
if (auto *genericFnType = type->getAs<GenericFunctionType>())
671+
type = genericFnType->substGenericArgs(subs);
670672
auto *ctorRefExpr =
671673
new (ctx) OtherConstructorDeclRefExpr(ctorRef, DeclNameLoc(),
672674
IsImplicit, type);

lib/Sema/DerivedConformanceDistributedActor.cpp

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -714,13 +714,9 @@ deriveBodyDistributedActor_unownedExecutor(AbstractFunctionDecl *getter, void *)
714714
auto substitutions = SubstitutionMap::get(
715715
buildRemoteExecutorDecl->getGenericSignature(),
716716
[&](SubstitutableType *dependentType) {
717-
if (auto gp = dyn_cast<GenericTypeParamType>(dependentType)) {
718-
if (gp->getDepth() == 0 && gp->getIndex() == 0) {
719-
return getter->getImplicitSelfDecl()->getTypeInContext();
720-
}
721-
}
722-
723-
return Type();
717+
auto gp = cast<GenericTypeParamType>(dependentType);
718+
ASSERT(gp->getDepth() == 0 && gp->getIndex() == 0);
719+
return getter->getImplicitSelfDecl()->getTypeInContext();
724720
},
725721
LookUpConformanceInModule()
726722
);
@@ -730,8 +726,8 @@ deriveBodyDistributedActor_unownedExecutor(AbstractFunctionDecl *getter, void *)
730726
DeclNameLoc(),/*implicit=*/true,
731727
AccessSemantics::Ordinary);
732728
buildRemoteExecutorExpr->setType(
733-
buildRemoteExecutorDecl->getInterfaceType()
734-
.subst(substitutions)
729+
buildRemoteExecutorDecl->getInterfaceType()->castTo<GenericFunctionType>()
730+
->substGenericArgs(substitutions)
735731
);
736732

737733
Expr *selfForBuildRemoteExecutor = DerivedConformance::createSelfDeclRef(getter);

lib/Sema/DerivedConformanceEquatableHashable.cpp

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -829,23 +829,20 @@ deriveBodyHashable_hashValue(AbstractFunctionDecl *hashValueDecl, void *) {
829829
auto substitutions = SubstitutionMap::get(
830830
hashFunc->getGenericSignature(),
831831
[&](SubstitutableType *dependentType) {
832-
if (auto gp = dyn_cast<GenericTypeParamType>(dependentType)) {
833-
if (gp->getDepth() == 0 && gp->getIndex() == 0)
834-
return selfType;
835-
}
836-
837-
return Type(dependentType);
832+
auto gp = cast<GenericTypeParamType>(dependentType);
833+
ASSERT(gp->getDepth() == 0 && gp->getIndex() == 0);
834+
return selfType;
838835
},
839836
LookUpConformanceInModule());
840837
ConcreteDeclRef hashFuncRef(hashFunc, substitutions);
841838

842-
Type hashFuncType = hashFunc->getInterfaceType().subst(substitutions);
839+
auto *hashFuncType = hashFunc->getInterfaceType()->castTo<GenericFunctionType>()
840+
->substGenericArgs(substitutions);
843841
auto hashExpr = new (C) DeclRefExpr(hashFuncRef, DeclNameLoc(),
844842
/*implicit*/ true,
845843
AccessSemantics::Ordinary,
846844
hashFuncType);
847-
Type hashFuncResultType =
848-
hashFuncType->castTo<AnyFunctionType>()->getResult();
845+
Type hashFuncResultType = hashFuncType->getResult();
849846
auto *argList = ArgumentList::forImplicitSingle(C, C.Id_for, selfRef);
850847
auto *callExpr = CallExpr::createImplicit(C, hashExpr, argList);
851848
callExpr->setType(hashFuncResultType);

lib/Sema/TypeCheckEffects.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1731,8 +1731,12 @@ class ApplyClassifier {
17311731
}
17321732

17331733
// Use the most significant result from the arguments.
1734-
auto *fnSubstType = fnInterfaceType.subst(fnRef.getSubstitutions())
1735-
->getAs<AnyFunctionType>();
1734+
FunctionType *fnSubstType = nullptr;
1735+
if (auto *fnGenericType = fnInterfaceType->getAs<GenericFunctionType>())
1736+
fnSubstType = fnGenericType->substGenericArgs(fnRef.getSubstitutions());
1737+
else
1738+
fnSubstType = fnInterfaceType->getAs<FunctionType>();
1739+
17361740
if (!fnSubstType) {
17371741
result.merge(Classification::forInvalidCode());
17381742
return;

lib/Sema/TypeCheckProtocol.cpp

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3856,15 +3856,22 @@ filterProtocolRequirements(
38563856

38573857
auto OverloadTy = Req->getOverloadSignatureType();
38583858
if (OverloadTy) {
3859-
OverloadTy =
3860-
OverloadTy.subst(getProtocolSubstitutionMap(Req))->getCanonicalType();
3859+
auto Subs = getProtocolSubstitutionMap(Req);
3860+
// FIXME: This is wrong if the overload has its own generic parameters
3861+
if (auto GenericFnTy = dyn_cast<GenericFunctionType>(OverloadTy))
3862+
OverloadTy = GenericFnTy.substGenericArgs(Subs);
3863+
else
3864+
OverloadTy = OverloadTy.subst(Subs)->getCanonicalType();
38613865
}
38623866
if (llvm::any_of(DeclsByName[Req->getName()], [&](ValueDecl *OtherReq) {
38633867
auto OtherOverloadTy = OtherReq->getOverloadSignatureType();
38643868
if (OtherOverloadTy) {
3865-
OtherOverloadTy =
3866-
OtherOverloadTy.subst(getProtocolSubstitutionMap(OtherReq))
3867-
->getCanonicalType();
3869+
auto Subs = getProtocolSubstitutionMap(OtherReq);
3870+
// FIXME: This is wrong if the overload has its own generic parameters
3871+
if (auto GenericFnTy = dyn_cast<GenericFunctionType>(OtherOverloadTy))
3872+
OtherOverloadTy = GenericFnTy.substGenericArgs(Subs);
3873+
else
3874+
OtherOverloadTy = OtherOverloadTy.subst(Subs)->getCanonicalType();
38683875
}
38693876
return conflicting(Req->getASTContext(), Req->getOverloadSignature(),
38703877
OverloadTy, OtherReq->getOverloadSignature(),
@@ -7397,7 +7404,10 @@ bool swift::forEachConformance(
73977404
if (forEachConformance(subs, body, visitedConformances))
73987405
return true;
73997406

7400-
type = type.subst(subs);
7407+
if (auto *genericFnType = type->getAs<GenericFunctionType>())
7408+
type = genericFnType->substGenericArgs(subs);
7409+
else
7410+
type = type.subst(subs);
74017411
}
74027412

74037413
if (forEachConformance(type, body, visitedConformances))

lib/Sema/TypeCheckStorage.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1910,7 +1910,9 @@ synthesizeObservedSetterBody(AccessorDecl *Set, TargetImpl target,
19101910

19111911
auto callObserver = [&](AccessorDecl *observer, VarDecl *arg) {
19121912
ConcreteDeclRef ref(observer, subs);
1913-
auto type = observer->getInterfaceType().subst(subs);
1913+
auto type = observer->getInterfaceType();
1914+
if (auto *genericFnType = type->getAs<GenericFunctionType>())
1915+
type = genericFnType->substGenericArgs(subs);
19141916
Expr *Callee = new (Ctx) DeclRefExpr(ref, DeclNameLoc(), /*imp*/true);
19151917
Callee->setType(type);
19161918

@@ -2101,7 +2103,9 @@ synthesizeModifyCoroutineBodyWithSimpleDidSet(AccessorDecl *accessor,
21012103

21022104
auto callDidSet = [&]() {
21032105
ConcreteDeclRef ref(DidSet, subs);
2104-
auto type = DidSet->getInterfaceType().subst(subs);
2106+
auto type = DidSet->getInterfaceType();
2107+
if (auto *genericFnType = type->getAs<GenericFunctionType>())
2108+
type = genericFnType->substGenericArgs(subs);
21052109
Expr *Callee = new (ctx) DeclRefExpr(ref, DeclNameLoc(), /*imp*/ true);
21062110
Callee->setType(type);
21072111

lib/Sema/TypeCheckUnsafe.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,8 +208,12 @@ bool swift::enumerateUnsafeUses(ConcreteDeclRef declRef,
208208
auto subs = declRef.getSubstitutions();
209209
{
210210
auto type = decl->getInterfaceType();
211-
if (subs)
212-
type = type.subst(subs);
211+
if (subs) {
212+
if (auto *genericFnType = type->getAs<GenericFunctionType>())
213+
type = genericFnType->substGenericArgs(subs);
214+
else
215+
type = type.subst(subs);
216+
}
213217

214218
bool shouldReturnTrue = false;
215219
diagnoseUnsafeType(ctx, loc, type, [&](Type unsafeType) {

0 commit comments

Comments
 (0)