Skip to content

Commit 44a4727

Browse files
committed
[ConstraintSystem] Add KeyPathComponentTypes map
Avoids digging through the long list of type vars to find the ones related to key path components when dumping an expression.
1 parent 713a2ab commit 44a4727

File tree

4 files changed

+66
-22
lines changed

4 files changed

+66
-22
lines changed

lib/Sema/CSApply.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1698,9 +1698,12 @@ namespace {
16981698
}
16991699

17001700
assert(componentExpr);
1701-
componentExpr->setType(simplifyType(cs.getType(anchor)));
1701+
Type ty = simplifyType(cs.getType(anchor));
1702+
componentExpr->setType(ty);
17021703
cs.cacheType(componentExpr);
17031704

1705+
cs.setType(keyPath, 0, ty);
1706+
17041707
keyPath->setParsedPath(componentExpr);
17051708
keyPath->resolveComponents(ctx, {component});
17061709
return keyPath;
@@ -4358,6 +4361,13 @@ namespace {
43584361
Type baseTy = keyPathTy->getGenericArgs()[0];
43594362
Type leafTy = keyPathTy->getGenericArgs()[1];
43604363

4364+
// Updates the constraint system with the type of the last resolved
4365+
// component. We do it this way because we sometimes insert new
4366+
// components.
4367+
auto updateCSWithResolvedComponent = [&]() {
4368+
cs.setType(E, resolvedComponents.size() - 1, baseTy);
4369+
};
4370+
43614371
for (unsigned i : indices(E->getComponents())) {
43624372
auto &origComponent = E->getMutableComponents()[i];
43634373

@@ -4434,6 +4444,7 @@ namespace {
44344444
resolvedComponents.push_back(component);
44354445

44364446
if (shouldForceUnwrapResult(foundDecl->choice, locator)) {
4447+
updateCSWithResolvedComponent();
44374448
auto objectTy = getObjectType(baseTy);
44384449
auto loc = origComponent.getLoc();
44394450
component = KeyPathExpr::Component::forOptionalForce(objectTy, loc);
@@ -4461,6 +4472,7 @@ namespace {
44614472
resolvedComponents.push_back(component);
44624473

44634474
if (shouldForceUnwrapResult(foundDecl->choice, locator)) {
4475+
updateCSWithResolvedComponent();
44644476
auto objectTy = getObjectType(baseTy);
44654477
auto loc = origComponent.getLoc();
44664478
component = KeyPathExpr::Component::forOptionalForce(objectTy, loc);
@@ -4512,6 +4524,9 @@ namespace {
45124524
case KeyPathExpr::Component::Kind::TupleElement:
45134525
llvm_unreachable("already resolved");
45144526
}
4527+
4528+
// By now, "baseTy" is the result type of this component.
4529+
updateCSWithResolvedComponent();
45154530
}
45164531

45174532
// Wrap a non-optional result if there was chaining involved.
@@ -4524,6 +4539,7 @@ namespace {
45244539
auto component = KeyPathExpr::Component::forOptionalWrap(leafTy);
45254540
resolvedComponents.push_back(component);
45264541
baseTy = leafTy;
4542+
updateCSWithResolvedComponent();
45274543
}
45284544
E->resolveComponents(cs.getASTContext(), resolvedComponents);
45294545

lib/Sema/CSGen.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3054,6 +3054,10 @@ namespace {
30543054
case KeyPathExpr::Component::Kind::Identity:
30553055
continue;
30563056
}
3057+
3058+
// By now, `base` is the result type of this component. Set it in the
3059+
// constraint system so we can find it later.
3060+
CS.setType(E, i, base);
30573061
}
30583062

30593063
// If there was an optional chaining component, the end result must be

lib/Sema/ConstraintSystem.h

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,6 +1065,8 @@ class ConstraintSystem {
10651065
llvm::DenseMap<const Expr *, TypeBase *> ExprTypes;
10661066
llvm::DenseMap<const TypeLoc *, TypeBase *> TypeLocTypes;
10671067
llvm::DenseMap<const ParamDecl *, TypeBase *> ParamTypes;
1068+
llvm::DenseMap<std::pair<const KeyPathExpr *, unsigned>, TypeBase *>
1069+
KeyPathComponentTypes;
10681070

10691071
/// Maps closure parameters to type variables.
10701072
llvm::DenseMap<const ParamDecl *, TypeVariableType *>
@@ -1446,6 +1448,11 @@ class ConstraintSystem {
14461448
if (expr->getType())
14471449
CS.cacheType(expr);
14481450

1451+
if (auto kp = dyn_cast<KeyPathExpr>(expr))
1452+
for (auto i : indices(kp->getComponents()))
1453+
if (kp->getComponents()[i].getComponentType())
1454+
CS.cacheType(kp, i);
1455+
14491456
return expr;
14501457
}
14511458

@@ -1477,6 +1484,15 @@ class ConstraintSystem {
14771484
"Should not write type variable into expression!");
14781485
expr->setType(CS.getType(expr));
14791486

1487+
if (auto kp = dyn_cast<KeyPathExpr>(expr)) {
1488+
for (auto i : indices(kp->getComponents())) {
1489+
Type componentType;
1490+
if (CS.hasType(kp, i))
1491+
componentType = CS.getType(kp, i);
1492+
kp->getMutableComponents()[i].setComponentType(componentType);
1493+
}
1494+
}
1495+
14801496
return expr;
14811497
}
14821498

@@ -1757,6 +1773,12 @@ class ConstraintSystem {
17571773
ParamTypes[P] = T.getPointer();
17581774
}
17591775

1776+
void setType(KeyPathExpr *KP, unsigned I, Type T) {
1777+
assert(KP && "Expected non-null key path parameter!");
1778+
assert(T && "Expected non-null type!");
1779+
KeyPathComponentTypes[std::make_pair(KP, I)] = T.getPointer();
1780+
}
1781+
17601782
/// Check to see if we have a type for an expression.
17611783
bool hasType(const Expr *E) const {
17621784
assert(E != nullptr && "Expected non-null expression!");
@@ -1772,6 +1794,12 @@ class ConstraintSystem {
17721794
return ParamTypes.find(P) != ParamTypes.end();
17731795
}
17741796

1797+
bool hasType(const KeyPathExpr *KP, unsigned I) const {
1798+
assert(KP && "Expected non-null key path parameter!");
1799+
return KeyPathComponentTypes.find(std::make_pair(KP, I))
1800+
!= KeyPathComponentTypes.end();
1801+
}
1802+
17751803
/// Get the type for an expression.
17761804
Type getType(const Expr *E) const {
17771805
assert(hasType(E) && "Expected type to have been set!");
@@ -1800,6 +1828,11 @@ class ConstraintSystem {
18001828
return wantInterfaceType ? D->getInterfaceType() : D->getType();
18011829
}
18021830

1831+
Type getType(const KeyPathExpr *KP, unsigned I) const {
1832+
assert(hasType(KP, I) && "Expected type to have been set!");
1833+
return KeyPathComponentTypes.find(std::make_pair(KP, I))->second;
1834+
}
1835+
18031836
/// Cache the type of the expression argument and return that same
18041837
/// argument.
18051838
template <typename T>
@@ -1809,6 +1842,15 @@ class ConstraintSystem {
18091842
return E;
18101843
}
18111844

1845+
/// Cache the type of the expression argument and return that same
1846+
/// argument.
1847+
KeyPathExpr *cacheType(KeyPathExpr *E, unsigned I) {
1848+
auto componentTy = E->getComponents()[I].getComponentType();
1849+
assert(componentTy && "Expected a type!");
1850+
setType(E, I, componentTy);
1851+
return E;
1852+
}
1853+
18121854
void setContextualType(Expr *E, TypeLoc T, ContextualTypePurpose purpose) {
18131855
assert(E != nullptr && "Expected non-null expression!");
18141856
contextualTypeNode = E;

lib/Sema/TypeCheckConstraints.cpp

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3408,27 +3408,9 @@ void ConstraintSystem::print(raw_ostream &out, Expr *E) {
34083408
return Type();
34093409
};
34103410
auto getTypeOfKeyPathComponent =
3411-
[&](const KeyPathExpr *E, unsigned index) -> Type {
3412-
// CSGen attaches the type var for each key path component to the
3413-
// "key path component #n -> function result" locator. This is the natural
3414-
// place subscript result types will end up, and it's otherwise unused by
3415-
// other components.
3416-
auto componentLocator = getConstraintLocator(E,
3417-
ConstraintLocator::PathElement::getKeyPathComponent(index));
3418-
auto resultLocator = getConstraintLocator(componentLocator,
3419-
ConstraintLocator::FunctionResult);
3420-
3421-
// Find the first type variable with this locator and return it.
3422-
for (auto &typeVar : TypeVariables) {
3423-
if ((*typeVar)->getLocator() != resultLocator)
3424-
continue;
3425-
3426-
if (auto fixedTy = (*typeVar)->getFixedType(getSavedBindings()))
3427-
return fixedTy;
3428-
3429-
return typeVar;
3430-
}
3431-
3411+
[&](const KeyPathExpr *KP, unsigned I) -> Type {
3412+
if (hasType(KP, I))
3413+
return getType(KP, I);
34323414
return Type();
34333415
};
34343416

0 commit comments

Comments
 (0)