Skip to content

Sema: Clean up and optimize interface type opening #77683

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 10 additions & 15 deletions include/swift/Sema/ConstraintSystem.h
Original file line number Diff line number Diff line change
Expand Up @@ -1198,9 +1198,6 @@ struct Score {
/// variable.
using OpenedType = std::pair<GenericTypeParamType *, TypeVariableType *>;

using OpenedTypeMap =
llvm::DenseMap<GenericTypeParamType *, TypeVariableType *>;

/// Describes the information about a case label item that needs to be tracked
/// within the constraint system.
struct CaseLabelItemInfo {
Expand Down Expand Up @@ -4236,7 +4233,7 @@ class ConstraintSystem {
/// corresponding opened type variables.
///
/// \returns The opened type, or \c type if there are no archetypes in it.
Type openType(Type type, OpenedTypeMap &replacements,
Type openType(Type type, ArrayRef<OpenedType> replacements,
ConstraintLocatorBuilder locator);

/// "Open" an opaque archetype type, similar to \c openType.
Expand All @@ -4247,7 +4244,7 @@ class ConstraintSystem {
/// opening its pattern and shape types and connecting them to the
/// aforementioned variable via special constraints.
Type openPackExpansionType(PackExpansionType *expansion,
OpenedTypeMap &replacements,
ArrayRef<OpenedType> replacements,
ConstraintLocatorBuilder locator);

/// Update OpenedPackExpansionTypes and record a change in the trail.
Expand Down Expand Up @@ -4280,28 +4277,26 @@ class ConstraintSystem {
/// \returns The opened type, or \c type if there are no archetypes in it.
FunctionType *openFunctionType(AnyFunctionType *funcType,
ConstraintLocatorBuilder locator,
OpenedTypeMap &replacements,
SmallVectorImpl<OpenedType> &replacements,
DeclContext *outerDC);

/// Open the generic parameter list and its requirements,
/// creating type variables for each of the type parameters.
void openGeneric(DeclContext *outerDC,
GenericSignature signature,
ConstraintLocatorBuilder locator,
OpenedTypeMap &replacements);
SmallVectorImpl<OpenedType> &replacements);

/// Open the generic parameter list creating type variables for each of the
/// type parameters.
void openGenericParameters(DeclContext *outerDC,
GenericSignature signature,
OpenedTypeMap &replacements,
SmallVectorImpl<OpenedType> &replacements,
ConstraintLocatorBuilder locator);

/// Open a generic parameter into a type variable and record
/// it in \c replacements.
TypeVariableType *openGenericParameter(DeclContext *outerDC,
GenericTypeParamType *parameter,
OpenedTypeMap &replacements,
TypeVariableType *openGenericParameter(GenericTypeParamType *parameter,
ConstraintLocatorBuilder locator);

/// Given generic signature open its generic requirements,
Expand All @@ -4328,7 +4323,7 @@ class ConstraintSystem {
/// Record the set of opened types for the given locator.
void recordOpenedTypes(
ConstraintLocatorBuilder locator,
const OpenedTypeMap &replacements,
SmallVectorImpl<OpenedType> &replacements,
bool fixmeAllowDuplicates=false);

/// Check whether the given type conforms to the given protocol and if
Expand All @@ -4340,7 +4335,7 @@ class ConstraintSystem {
FunctionType *adjustFunctionTypeForConcurrency(
FunctionType *fnType, Type baseType, ValueDecl *decl, DeclContext *dc,
unsigned numApplies, bool isMainDispatchQueue,
OpenedTypeMap &replacements, ConstraintLocatorBuilder locator);
ArrayRef<OpenedType> replacements, ConstraintLocatorBuilder locator);

/// Retrieve the type of a reference to the given value declaration.
///
Expand Down Expand Up @@ -4380,7 +4375,7 @@ class ConstraintSystem {
Type getMemberReferenceTypeFromOpenedType(
Type &openedType, Type baseObjTy, ValueDecl *value, DeclContext *outerDC,
ConstraintLocator *locator, bool hasAppliedSelf, bool isDynamicLookup,
OpenedTypeMap &replacements);
ArrayRef<OpenedType> replacements);

/// Retrieve the type of a reference to the given value declaration,
/// as a member with a base of the given type.
Expand All @@ -4396,7 +4391,7 @@ class ConstraintSystem {
DeclReferenceType getTypeOfMemberReference(
Type baseTy, ValueDecl *decl, DeclContext *useDC, bool isDynamicLookup,
FunctionRefKind functionRefKind, ConstraintLocator *locator,
OpenedTypeMap *replacements = nullptr);
SmallVectorImpl<OpenedType> *replacements = nullptr);

/// Retrieve a list of generic parameter types solver has "opened" (replaced
/// with a type variable) at the given location.
Expand Down
6 changes: 3 additions & 3 deletions lib/Sema/CSApply.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ Solution::computeSubstitutions(NullablePtr<ValueDecl> decl,
if (openedTypes == OpenedTypes.end())
return SubstitutionMap();

TypeSubstitutionMap subs;
SmallVector<Type, 4> replacementTypes;
for (const auto &opened : openedTypes->second) {
auto type = getFixedType(opened.second);
if (opened.first->isParameterPack()) {
Expand All @@ -115,7 +115,7 @@ Solution::computeSubstitutions(NullablePtr<ValueDecl> decl,
} else if (!type->is<PackType>())
type = PackType::getSingletonPackExpansion(type);
}
subs[opened.first] = type;
replacementTypes.push_back(type);
}

auto lookupConformanceFn =
Expand Down Expand Up @@ -145,7 +145,7 @@ Solution::computeSubstitutions(NullablePtr<ValueDecl> decl,
};

return SubstitutionMap::get(sig,
QueryTypeSubstitutionMap{subs},
replacementTypes,
lookupConformanceFn);
}

Expand Down
18 changes: 10 additions & 8 deletions lib/Sema/CSRanking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -425,15 +425,18 @@ static bool isProtocolExtensionAsSpecializedAs(DeclContext *dc1,
// Form a constraint system where we've opened up all of the requirements of
// the second protocol extension.
ConstraintSystem cs(dc1, std::nullopt);
OpenedTypeMap replacements;
SmallVector<OpenedType, 4> replacements;
cs.openGeneric(dc2, sig2, ConstraintLocatorBuilder(nullptr), replacements);

// Bind the 'Self' type from the first extension to the type parameter from
// opening 'Self' of the second extension.
Type selfType1 = sig1.getGenericParams()[0];
Type selfType2 = sig2.getGenericParams()[0];
ASSERT(selfType1->isEqual(selfType2));
ASSERT(replacements[0].first->isEqual(selfType2));

cs.addConstraint(ConstraintKind::Bind,
replacements[cast<GenericTypeParamType>(selfType2->getCanonicalType())],
replacements[0].second,
dc1->mapTypeIntoContext(selfType1),
nullptr);

Expand Down Expand Up @@ -578,7 +581,7 @@ bool CompareDeclSpecializationRequest::evaluate(

auto openType = [&](ConstraintSystem &cs, DeclContext *innerDC,
DeclContext *outerDC, Type type,
OpenedTypeMap &replacements,
SmallVectorImpl<OpenedType> &replacements,
ConstraintLocator *locator) -> Type {
if (auto *funcType = type->getAs<AnyFunctionType>()) {
return cs.openFunctionType(funcType, locator, replacements, outerDC);
Expand All @@ -596,12 +599,11 @@ bool CompareDeclSpecializationRequest::evaluate(
// FIXME: Locator when anchored on a declaration.
// Get the type of a reference to the second declaration.

OpenedTypeMap unused, replacements;
auto openedType2 = openType(cs, innerDC1, outerDC2, type2, unused, locator);
auto openedType1 =
openType(cs, innerDC2, outerDC1, type1, replacements, locator);
SmallVector<OpenedType, 4> unused, replacements;
auto openedType2 = openType(cs, innerDC2, outerDC2, type2, unused, locator);
auto openedType1 = openType(cs, innerDC1, outerDC1, type1, replacements, locator);

for (const auto &replacement : replacements) {
for (auto replacement : replacements) {
if (auto mapped = innerDC1->mapTypeIntoContext(replacement.first)) {
cs.addConstraint(ConstraintKind::Bind, replacement.second, mapped,
locator);
Expand Down
7 changes: 3 additions & 4 deletions lib/Sema/CSSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11534,10 +11534,9 @@ static Type getOpenedResultBuilderTypeFor(ConstraintSystem &cs,
// Find the opened type for this callee and substitute in the type
// parameters.
auto substitutions = cs.getOpenedTypes(calleeLocator);
if (!substitutions.empty()) {
OpenedTypeMap replacements(substitutions.begin(), substitutions.end());
builderType = cs.openType(builderType, replacements, locator);
}
if (!substitutions.empty())
builderType = cs.openType(builderType, substitutions, locator);

assert(!builderType->hasTypeParameter());
}
return builderType;
Expand Down
53 changes: 29 additions & 24 deletions lib/Sema/TypeCheckConstraints.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ Type TypeChecker::typeCheckParameterDefault(Expr *&defaultValue,
// If both of aforementioned conditions are true, let's attempt
// to open generic parameter and infer the type of this default
// expression.
OpenedTypeMap genericParameters;
SmallVector<OpenedType, 4> genericParameters;

ConstraintSystemOptions options;
options |= ConstraintSystemFlags::AllowFixes;
Expand All @@ -584,8 +584,13 @@ Type TypeChecker::typeCheckParameterDefault(Expr *&defaultValue,
defaultValue, LocatorPathElt::ContextualType(
defaultExprTarget.getExprContextualTypePurpose()));

auto getCanonicalGenericParamTy = [](GenericTypeParamType *GP) {
return cast<GenericTypeParamType>(GP->getCanonicalType());
auto findParam = [&](GenericTypeParamType *GP) -> TypeVariableType * {
for (auto pair : genericParameters) {
if (pair.first->isEqual(GP))
return pair.second;
}

return nullptr;
};

// Find and open all of the generic parameters used by the parameter
Expand All @@ -594,29 +599,29 @@ Type TypeChecker::typeCheckParameterDefault(Expr *&defaultValue,
assert(!type->is<UnboundGenericType>());

if (auto *GP = type->getAs<GenericTypeParamType>()) {
auto openedVar = genericParameters.find(getCanonicalGenericParamTy(GP));
if (openedVar != genericParameters.end()) {
return openedVar->second;
}
return cs.openGenericParameter(DC->getParent(), GP, genericParameters,
locator);
if (auto *typeVar = findParam(GP))
return typeVar;

auto *typeVar = cs.openGenericParameter(GP, locator);
genericParameters.emplace_back(GP, typeVar);

return typeVar;
}
return std::nullopt;
});

auto containsTypes = [&](Type type, OpenedTypeMap &toFind) {
return type.findIf([&](Type nested) {
auto containsTypes = [&](Type type) {
return type.findIf([&](Type nested) -> bool {
if (auto *GP = nested->getAs<GenericTypeParamType>())
return toFind.count(getCanonicalGenericParamTy(GP)) > 0;
return findParam(GP);
return false;
});
};

auto containsGenericParamsExcluding = [&](Type type,
OpenedTypeMap &exclusions) -> bool {
return type.findIf([&](Type type) {
auto containsGenericParamsExcluding = [&](Type type) -> bool {
return type.findIf([&](Type type) -> bool {
if (auto *GP = type->getAs<GenericTypeParamType>())
return !exclusions.count(getCanonicalGenericParamTy(GP));
return !findParam(GP);
return false;
});
};
Expand All @@ -637,7 +642,7 @@ Type TypeChecker::typeCheckParameterDefault(Expr *&defaultValue,
for (unsigned i : indices(anchorTy->getParams())) {
const auto &param = anchorTy->getParams()[i];

if (containsTypes(param.getPlainType(), genericParameters))
if (containsTypes(param.getPlainType()))
affectedParams.push_back(i);
}

Expand Down Expand Up @@ -704,8 +709,8 @@ Type TypeChecker::typeCheckParameterDefault(Expr *&defaultValue,
auto rhsTy = requirement.getSecondType();

// Unrelated requirement.
if (!containsTypes(lhsTy, genericParameters) &&
!containsTypes(rhsTy, genericParameters))
if (!containsTypes(lhsTy) &&
!containsTypes(rhsTy))
continue;

// If both sides are dependent members, that's okay because types
Expand All @@ -716,8 +721,8 @@ Type TypeChecker::typeCheckParameterDefault(Expr *&defaultValue,

// Allow a subset of generic same-type requirements that only mention
// "in scope" generic parameters e.g. `T.X == Int` or `T == U.Z`
if (!containsGenericParamsExcluding(lhsTy, genericParameters) &&
!containsGenericParamsExcluding(rhsTy, genericParameters)) {
if (!containsGenericParamsExcluding(lhsTy) &&
!containsGenericParamsExcluding(rhsTy)) {
recordRequirement(reqIdx, requirement, requirementBaseLocator);
continue;
}
Expand All @@ -737,20 +742,20 @@ Type TypeChecker::typeCheckParameterDefault(Expr *&defaultValue,
auto adheringTy = requirement.getFirstType();

// Unrelated requirement.
if (!containsTypes(adheringTy, genericParameters))
if (!containsTypes(adheringTy))
continue;

// If adhering type has a mix or in- and out-of-scope parameters
// mentioned we need to diagnose.
if (containsGenericParamsExcluding(adheringTy, genericParameters)) {
if (containsGenericParamsExcluding(adheringTy)) {
diagnoseInvalidRequirement(requirement);
return Type();
}

if (requirement.getKind() == RequirementKind::Superclass) {
auto superclassTy = requirement.getSecondType();

if (containsGenericParamsExcluding(superclassTy, genericParameters)) {
if (containsGenericParamsExcluding(superclassTy)) {
diagnoseInvalidRequirement(requirement);
return Type();
}
Expand Down
2 changes: 1 addition & 1 deletion lib/Sema/TypeCheckProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1197,7 +1197,7 @@ swift::matchWitness(WitnessChecker::RequirementEnvironmentCache &reqEnvCache,
// Open up the type of the requirement.
reqLocator =
cs->getConstraintLocator(req, ConstraintLocator::ProtocolRequirement);
OpenedTypeMap reqReplacements;
SmallVector<OpenedType, 4> reqReplacements;
reqType = cs->getTypeOfMemberReference(selfTy, req, dc,
/*isDynamicResult=*/false,
FunctionRefKind::DoubleApply,
Expand Down
Loading