Skip to content

Clean up GenericFunctionType substitution #80301

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Mar 27, 2025
1 change: 0 additions & 1 deletion include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -4146,7 +4146,6 @@ class GenericFunctionType final
/// function type and return the resulting non-generic type.
FunctionType *substGenericArgs(SubstitutionMap subs,
SubstOptions options = std::nullopt);
FunctionType *substGenericArgs(llvm::function_ref<Type(Type)> substFn) const;

void Profile(llvm::FoldingSetNodeID &ID) {
std::optional<ExtInfo> info = std::nullopt;
Expand Down
2 changes: 1 addition & 1 deletion lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5754,7 +5754,7 @@ ProtocolConformanceRef ProtocolConformanceRef::forAbstract(
properties |= conformingType->getRecursiveProperties();
auto arena = getArena(properties);

// Profile the substitution map.
// Form the folding set key.
llvm::FoldingSetNodeID id;
AbstractConformance::Profile(id, conformingType, proto);

Expand Down
14 changes: 6 additions & 8 deletions lib/AST/ASTMangler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -828,14 +828,12 @@ std::string ASTMangler::mangleAutoDiffGeneratedDeclaration(
// Since we don't have a distinct mangling for sugared generic
// parameter types, we must desugar them here.
static Type getTypeForDWARFMangling(Type t) {
return t.subst(
[](SubstitutableType *t) -> Type {
if (t->isRootParameterPack()) {
return PackType::getSingletonPackExpansion(t->getCanonicalType());
}
return t->getCanonicalType();
},
MakeAbstractConformanceForGenericType());
return t.transformRec(
[](TypeBase *t) -> std::optional<Type> {
if (isa<GenericTypeParamType>(t))
return t->getCanonicalType();
return std::nullopt;
});
}

std::string ASTMangler::mangleTypeForDebugger(Type Ty, GenericSignature sig) {
Expand Down
7 changes: 4 additions & 3 deletions lib/AST/AbstractConformance.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
//
//===----------------------------------------------------------------------===//
//
// This file defines the AbstractConformance class, which stores an
// abstract conformance that is stashed in a ProtocolConformanceRef.
// This file defines the AbstractConformance class, which represents
// the conformance of a type parameter or archetype to a protocol.
// These are usually stashed inside a ProtocolConformanceRef.
//
//===----------------------------------------------------------------------===//
#ifndef SWIFT_AST_ABSTRACT_CONFORMANCE_H
Expand All @@ -36,7 +37,7 @@ class AbstractConformance final : public llvm::FoldingSetNode {
Profile(id, getType(), getProtocol());
}

/// Profile the substitution map storage, for use with LLVM's FoldingSet.
/// Profile the storage for this conformance, for use with LLVM's FoldingSet.
static void Profile(llvm::FoldingSetNodeID &id,
Type conformingType,
ProtocolDecl *requirement) {
Expand Down
102 changes: 7 additions & 95 deletions lib/AST/TypeSubstitution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,29 +61,11 @@ Type QuerySubstitutionMap::operator()(SubstitutableType *type) const {
FunctionType *
GenericFunctionType::substGenericArgs(SubstitutionMap subs,
SubstOptions options) {
return substGenericArgs(
[=](Type t) { return t.subst(subs, options); });
}

FunctionType *GenericFunctionType::substGenericArgs(
llvm::function_ref<Type(Type)> substFn) const {
llvm::SmallVector<AnyFunctionType::Param, 4> params;
params.reserve(getNumParams());

llvm::transform(getParams(), std::back_inserter(params),
[&](const AnyFunctionType::Param &param) {
return param.withType(substFn(param.getPlainType()));
});

auto resultTy = substFn(getResult());

Type thrownError = getThrownError();
if (thrownError)
thrownError = substFn(thrownError);

// Build the resulting (non-generic) function type.
return FunctionType::get(params, resultTy,
getExtInfo().withThrows(isThrowing(), thrownError));
// FIXME: Before dropping the signature, we should assert that
// subs.getGenericSignature() is equal to this function type's
// generic signature.
Type fnType = FunctionType::get(getParams(), getResult(), getExtInfo());
return fnType.subst(subs, options)->castTo<FunctionType>();
}

CanFunctionType
Expand Down Expand Up @@ -170,73 +152,6 @@ operator()(CanType dependentType, Type conformingReplacementType,
conformingReplacementType, conformedProtocol);
}

static Type substGenericFunctionType(GenericFunctionType *genericFnType,
InFlightSubstitution &IFS) {
// Substitute into the function type (without generic signature).
auto *bareFnType = FunctionType::get(genericFnType->getParams(),
genericFnType->getResult(),
genericFnType->getExtInfo());
Type result = Type(bareFnType).subst(IFS);
if (!result || result->is<ErrorType>()) return result;

auto *fnType = result->castTo<FunctionType>();
// Substitute generic parameters.
bool anySemanticChanges = false;
SmallVector<GenericTypeParamType *, 2> genericParams;
for (auto param : genericFnType->getGenericParams()) {
Type paramTy = Type(param).subst(IFS);
if (!paramTy)
return Type();

if (auto newParam = paramTy->getAs<GenericTypeParamType>()) {
if (!newParam->isEqual(param))
anySemanticChanges = true;

genericParams.push_back(newParam);
} else {
anySemanticChanges = true;
}
}

// If no generic parameters remain, this is a non-generic function type.
if (genericParams.empty())
return result;

// Transform requirements.
SmallVector<Requirement, 2> requirements;
for (const auto &req : genericFnType->getRequirements()) {
// Substitute into the requirement.
auto substReqt = req.subst(IFS);

// Did anything change?
if (!anySemanticChanges &&
(!req.getFirstType()->isEqual(substReqt.getFirstType()) ||
(req.getKind() != RequirementKind::Layout &&
!req.getSecondType()->isEqual(substReqt.getSecondType())))) {
anySemanticChanges = true;
}

requirements.push_back(substReqt);
}

GenericSignature genericSig;
if (anySemanticChanges) {
// If there were semantic changes, we need to build a new generic
// signature.
ASTContext &ctx = genericFnType->getASTContext();
genericSig = buildGenericSignature(ctx, GenericSignature(),
genericParams, requirements,
/*allowInverses=*/false);
} else {
// Use the mapped generic signature.
genericSig = GenericSignature::get(genericParams, requirements);
}

// Produce the new generic function type.
return GenericFunctionType::get(genericSig, fnType->getParams(),
fnType->getResult(), fnType->getExtInfo());
}

InFlightSubstitution::InFlightSubstitution(TypeSubstitutionFn substType,
LookupConformanceFn lookupConformance,
SubstOptions options)
Expand Down Expand Up @@ -558,11 +473,8 @@ Type Type::subst(TypeSubstitutionFn substitutions,
}

Type Type::subst(InFlightSubstitution &IFS) const {
// Handle substitutions into generic function types.
// FIXME: This should be banned.
if (auto genericFnType = getPointer()->getAs<GenericFunctionType>()) {
return substGenericFunctionType(genericFnType, IFS);
}
ASSERT(!getPointer()->getAs<GenericFunctionType>() &&
"Perhaps you want GenericFunctionType::substGenericArgs() instead");

if (IFS.isInvariant(*this))
return *this;
Expand Down
8 changes: 4 additions & 4 deletions lib/SILGen/SILGenExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4265,16 +4265,16 @@ static void lowerKeyPathMemberIndexTypes(
if (auto subscript = dyn_cast<SubscriptDecl>(decl)) {
auto subscriptSubstTy = subscript->getInterfaceType();
auto sig = subscript->getGenericSignature();
if (sig) {
subscriptSubstTy = subscriptSubstTy.subst(memberSubs);
if (auto *subscriptGenericTy = subscriptSubstTy->getAs<GenericFunctionType>()) {
subscriptSubstTy = subscriptGenericTy->substGenericArgs(memberSubs);
}
needsGenericContext |= subscriptSubstTy->hasArchetype();
processIndicesOrParameters(subscript->getIndices(), &sig);
} else if (auto method = dyn_cast<AbstractFunctionDecl>(decl)) {
auto methodSubstTy = method->getInterfaceType();
auto sig = method->getGenericSignature();
if (sig) {
methodSubstTy = methodSubstTy.subst(memberSubs);
if (auto *methodGenericTy = methodSubstTy->getAs<GenericFunctionType>()) {
methodSubstTy = methodGenericTy->substGenericArgs(memberSubs);
}
needsGenericContext |= methodSubstTy->hasArchetype();
processIndicesOrParameters(method->getParameters(), &sig);
Expand Down
3 changes: 1 addition & 2 deletions lib/SILGen/SILGenType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -946,8 +946,7 @@ static SILFunction *emitSelfConformanceWitness(SILGenModule &SGM,
openedConf);

// Substitute to get the formal substituted type of the thunk.
auto reqtSubstTy =
cast<AnyFunctionType>(reqtOrigTy.subst(reqtSubs)->getCanonicalType());
auto reqtSubstTy = reqtOrigTy.substGenericArgs(reqtSubs);

// Substitute into the requirement type to get the type of the thunk.
auto witnessSILFnType = requirementInfo.SILFnType->substGenericArgs(
Expand Down
4 changes: 3 additions & 1 deletion lib/Sema/CodeSynthesis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -666,7 +666,9 @@ synthesizeDesignatedInitOverride(AbstractFunctionDecl *fn, void *context) {
.subst(subs);
ConcreteDeclRef ctorRef(superclassCtor, subs);

auto type = superclassCtor->getInitializerInterfaceType().subst(subs);
auto type = superclassCtor->getInitializerInterfaceType();
if (auto *genericFnType = type->getAs<GenericFunctionType>())
type = genericFnType->substGenericArgs(subs);
auto *ctorRefExpr =
new (ctx) OtherConstructorDeclRefExpr(ctorRef, DeclNameLoc(),
IsImplicit, type);
Expand Down
14 changes: 5 additions & 9 deletions lib/Sema/DerivedConformanceDistributedActor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -714,13 +714,9 @@ deriveBodyDistributedActor_unownedExecutor(AbstractFunctionDecl *getter, void *)
auto substitutions = SubstitutionMap::get(
buildRemoteExecutorDecl->getGenericSignature(),
[&](SubstitutableType *dependentType) {
if (auto gp = dyn_cast<GenericTypeParamType>(dependentType)) {
if (gp->getDepth() == 0 && gp->getIndex() == 0) {
return getter->getImplicitSelfDecl()->getTypeInContext();
}
}

return Type();
auto gp = cast<GenericTypeParamType>(dependentType);
ASSERT(gp->getDepth() == 0 && gp->getIndex() == 0);
return getter->getImplicitSelfDecl()->getTypeInContext();
},
LookUpConformanceInModule()
);
Expand All @@ -730,8 +726,8 @@ deriveBodyDistributedActor_unownedExecutor(AbstractFunctionDecl *getter, void *)
DeclNameLoc(),/*implicit=*/true,
AccessSemantics::Ordinary);
buildRemoteExecutorExpr->setType(
buildRemoteExecutorDecl->getInterfaceType()
.subst(substitutions)
buildRemoteExecutorDecl->getInterfaceType()->castTo<GenericFunctionType>()
->substGenericArgs(substitutions)
);

Expr *selfForBuildRemoteExecutor = DerivedConformance::createSelfDeclRef(getter);
Expand Down
15 changes: 6 additions & 9 deletions lib/Sema/DerivedConformanceEquatableHashable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -829,23 +829,20 @@ deriveBodyHashable_hashValue(AbstractFunctionDecl *hashValueDecl, void *) {
auto substitutions = SubstitutionMap::get(
hashFunc->getGenericSignature(),
[&](SubstitutableType *dependentType) {
if (auto gp = dyn_cast<GenericTypeParamType>(dependentType)) {
if (gp->getDepth() == 0 && gp->getIndex() == 0)
return selfType;
}

return Type(dependentType);
auto gp = cast<GenericTypeParamType>(dependentType);
ASSERT(gp->getDepth() == 0 && gp->getIndex() == 0);
return selfType;
},
LookUpConformanceInModule());
ConcreteDeclRef hashFuncRef(hashFunc, substitutions);

Type hashFuncType = hashFunc->getInterfaceType().subst(substitutions);
auto *hashFuncType = hashFunc->getInterfaceType()->castTo<GenericFunctionType>()
->substGenericArgs(substitutions);
auto hashExpr = new (C) DeclRefExpr(hashFuncRef, DeclNameLoc(),
/*implicit*/ true,
AccessSemantics::Ordinary,
hashFuncType);
Type hashFuncResultType =
hashFuncType->castTo<AnyFunctionType>()->getResult();
Type hashFuncResultType = hashFuncType->getResult();
auto *argList = ArgumentList::forImplicitSingle(C, C.Id_for, selfRef);
auto *callExpr = CallExpr::createImplicit(C, hashExpr, argList);
callExpr->setType(hashFuncResultType);
Expand Down
8 changes: 6 additions & 2 deletions lib/Sema/TypeCheckEffects.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1731,8 +1731,12 @@ class ApplyClassifier {
}

// Use the most significant result from the arguments.
auto *fnSubstType = fnInterfaceType.subst(fnRef.getSubstitutions())
->getAs<AnyFunctionType>();
FunctionType *fnSubstType = nullptr;
if (auto *fnGenericType = fnInterfaceType->getAs<GenericFunctionType>())
fnSubstType = fnGenericType->substGenericArgs(fnRef.getSubstitutions());
else
fnSubstType = fnInterfaceType->getAs<FunctionType>();

if (!fnSubstType) {
result.merge(Classification::forInvalidCode());
return;
Expand Down
22 changes: 16 additions & 6 deletions lib/Sema/TypeCheckProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3856,15 +3856,22 @@ filterProtocolRequirements(

auto OverloadTy = Req->getOverloadSignatureType();
if (OverloadTy) {
OverloadTy =
OverloadTy.subst(getProtocolSubstitutionMap(Req))->getCanonicalType();
auto Subs = getProtocolSubstitutionMap(Req);
// FIXME: This is wrong if the overload has its own generic parameters
if (auto GenericFnTy = dyn_cast<GenericFunctionType>(OverloadTy))
OverloadTy = GenericFnTy.substGenericArgs(Subs);
else
OverloadTy = OverloadTy.subst(Subs)->getCanonicalType();
}
if (llvm::any_of(DeclsByName[Req->getName()], [&](ValueDecl *OtherReq) {
auto OtherOverloadTy = OtherReq->getOverloadSignatureType();
if (OtherOverloadTy) {
OtherOverloadTy =
OtherOverloadTy.subst(getProtocolSubstitutionMap(OtherReq))
->getCanonicalType();
auto Subs = getProtocolSubstitutionMap(OtherReq);
// FIXME: This is wrong if the overload has its own generic parameters
if (auto GenericFnTy = dyn_cast<GenericFunctionType>(OtherOverloadTy))
OtherOverloadTy = GenericFnTy.substGenericArgs(Subs);
else
OtherOverloadTy = OtherOverloadTy.subst(Subs)->getCanonicalType();
}
return conflicting(Req->getASTContext(), Req->getOverloadSignature(),
OverloadTy, OtherReq->getOverloadSignature(),
Expand Down Expand Up @@ -7401,7 +7408,10 @@ bool swift::forEachConformance(
if (forEachConformance(subs, body, visitedConformances))
return true;

type = type.subst(subs);
if (auto *genericFnType = type->getAs<GenericFunctionType>())
type = genericFnType->substGenericArgs(subs);
else
type = type.subst(subs);
}

if (forEachConformance(type, body, visitedConformances))
Expand Down
8 changes: 6 additions & 2 deletions lib/Sema/TypeCheckStorage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1910,7 +1910,9 @@ synthesizeObservedSetterBody(AccessorDecl *Set, TargetImpl target,

auto callObserver = [&](AccessorDecl *observer, VarDecl *arg) {
ConcreteDeclRef ref(observer, subs);
auto type = observer->getInterfaceType().subst(subs);
auto type = observer->getInterfaceType();
if (auto *genericFnType = type->getAs<GenericFunctionType>())
type = genericFnType->substGenericArgs(subs);
Expr *Callee = new (Ctx) DeclRefExpr(ref, DeclNameLoc(), /*imp*/true);
Callee->setType(type);

Expand Down Expand Up @@ -2101,7 +2103,9 @@ synthesizeModifyCoroutineBodyWithSimpleDidSet(AccessorDecl *accessor,

auto callDidSet = [&]() {
ConcreteDeclRef ref(DidSet, subs);
auto type = DidSet->getInterfaceType().subst(subs);
auto type = DidSet->getInterfaceType();
if (auto *genericFnType = type->getAs<GenericFunctionType>())
type = genericFnType->substGenericArgs(subs);
Expr *Callee = new (ctx) DeclRefExpr(ref, DeclNameLoc(), /*imp*/ true);
Callee->setType(type);

Expand Down
8 changes: 6 additions & 2 deletions lib/Sema/TypeCheckUnsafe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,12 @@ bool swift::enumerateUnsafeUses(ConcreteDeclRef declRef,
auto subs = declRef.getSubstitutions();
{
auto type = decl->getInterfaceType();
if (subs)
type = type.subst(subs);
if (subs) {
if (auto *genericFnType = type->getAs<GenericFunctionType>())
type = genericFnType->substGenericArgs(subs);
else
type = type.subst(subs);
}

bool shouldReturnTrue = false;
diagnoseUnsafeType(ctx, loc, type, [&](Type unsafeType) {
Expand Down
Loading