Skip to content

Commit aea76ca

Browse files
committed
[TypeChecker] Extend type inference from default expressions to handle nested generic parameters
Adds support for parameter types like `[T?]` or `[(T, U?)]`, and relaxes restriction on same-type generic parameters. A same-type requirement is acceptable if it only includes in-scope generic parameters and concrete types i.e. `T.X == Int` if accepted if `T` is referenced only by a parameter default expression is being applied to.
1 parent a6f86c4 commit aea76ca

File tree

5 files changed

+341
-79
lines changed

5 files changed

+341
-79
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6254,7 +6254,12 @@ ERROR(cannot_default_generic_parameter_inferrable_from_another_parameter, none,
62546254

62556255
ERROR(cannot_default_generic_parameter_inferrable_through_same_type, none,
62566256
"cannot use default expression for inference of %0 because it "
6257-
"is inferrable through same-type requirement: %1",
6257+
"is inferrable through same-type requirement: '%1'",
6258+
(Type, StringRef))
6259+
6260+
ERROR(cannot_default_generic_parameter_invalid_requirement, none,
6261+
"cannot use default expression for inference of %0 because "
6262+
"requirement '%1' refers to other generic parameters",
62586263
(Type, StringRef))
62596264

62606265
#define UNDEFINE_DIAGNOSTIC_MACROS

lib/Sema/CSSimplify.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1632,13 +1632,13 @@ static ConstraintSystem::TypeMatchResult matchCallArguments(
16321632
if (parameterBindings[paramIdx].empty()) {
16331633
auto &ctx = cs.getASTContext();
16341634

1635-
if (paramTy->isTypeVariableOrMember() &&
1635+
if (paramTy->hasTypeVariable() &&
16361636
ctx.TypeCheckerOpts.EnableTypeInferenceFromDefaultArguments) {
16371637
auto *paramList = getParameterList(callee);
16381638
auto defaultExprType = paramList->get(paramIdx)->getTypeOfDefaultExpr();
16391639

16401640
// A caller side default.
1641-
if (!defaultExprType)
1641+
if (!defaultExprType || defaultExprType->hasError())
16421642
continue;
16431643

16441644
// If this is just a regular default type that works

lib/Sema/TypeCheckConstraints.cpp

Lines changed: 113 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,8 @@ Type TypeChecker::typeCheckParameterDefault(Expr *&defaultValue,
441441
isAutoClosure ? CTP_AutoclosureDefaultParameter : CTP_DefaultParameter,
442442
paramType, /*isDiscarded=*/false);
443443

444+
auto paramInterfaceTy = paramType->mapTypeOutOfContext();
445+
444446
{
445447
// Buffer all of the diagnostics produced by \c typeCheckExpression
446448
// since in some cases we need to try type-checking again with a
@@ -459,6 +461,11 @@ Type TypeChecker::typeCheckParameterDefault(Expr *&defaultValue,
459461
if (!ctx.TypeCheckerOpts.EnableTypeInferenceFromDefaultArguments)
460462
return Type();
461463

464+
// Parameter type doesn't have any generic parameters mentioned
465+
// in it, so there is nothing to infer.
466+
if (!paramInterfaceTy->hasTypeParameter())
467+
return Type();
468+
462469
// Ignore any diagnostics emitted by the original type-check.
463470
diagnostics.abort();
464471
}
@@ -475,40 +482,76 @@ Type TypeChecker::typeCheckParameterDefault(Expr *&defaultValue,
475482
// If both of aforementioned conditions are true, let's attempt
476483
// to open generic parameter and infer the type of this default
477484
// expression.
478-
auto interfaceType = paramType->mapTypeOutOfContext();
479-
if (!interfaceType->isTypeParameter())
480-
return Type();
485+
OpenedTypeMap genericParameters;
486+
487+
ConstraintSystemOptions options;
488+
options |= ConstraintSystemFlags::AllowFixes;
489+
490+
ConstraintSystem cs(DC, options);
491+
492+
auto *locator = cs.getConstraintLocator(
493+
defaultValue, LocatorPathElt::ContextualType(
494+
defaultExprTarget.getExprContextualTypePurpose()));
481495

482-
auto containsType = [&](Type type, Type contained) {
483-
return type.findIf(
484-
[&contained](Type nested) { return nested->isEqual(contained); });
496+
auto getCanonicalGenericParamTy = [](GenericTypeParamType *GP) {
497+
return cast<GenericTypeParamType>(GP->getCanonicalType());
485498
};
486499

487-
// Anchor of this default expression.
500+
// Find and open all of the generic parameters used by the parameter
501+
// and replace them with type variables.
502+
auto contextualTy = paramInterfaceTy.transform([&](Type type) -> Type {
503+
assert(!type->is<UnboundGenericType>());
504+
505+
if (auto *GP = type->getAs<GenericTypeParamType>()) {
506+
return cs.openGenericParameter(DC->getParent(), GP, genericParameters,
507+
locator);
508+
}
509+
return type;
510+
});
511+
512+
auto containsTypes = [&](Type type, OpenedTypeMap &toFind) {
513+
return type.findIf([&](Type nested) {
514+
if (auto *GP = nested->getAs<GenericTypeParamType>())
515+
return toFind.count(getCanonicalGenericParamTy(GP)) > 0;
516+
return false;
517+
});
518+
};
519+
520+
auto containsGenericParamsExcluding = [&](Type type,
521+
OpenedTypeMap &exclusions) -> bool {
522+
return type.findIf([&](Type type) {
523+
if (auto *GP = type->getAs<GenericTypeParamType>())
524+
return !exclusions.count(getCanonicalGenericParamTy(GP));
525+
return false;
526+
});
527+
};
528+
529+
// Anchor of this default expression i.e. function, subscript
530+
// or enum case.
488531
auto *anchor = cast<ValueDecl>(DC->getParent()->getAsDecl());
489532

490-
// Check whether generic parameter is only mentioned once in
533+
// Check whether generic parameters are only mentioned once in
491534
// the anchor's signature.
492535
{
493536
auto anchorTy = anchor->getInterfaceType()->castTo<GenericFunctionType>();
494537

495-
// Reject if generic parameter could be inferred from result type.
496-
if (containsType(anchorTy->getResult(), interfaceType)) {
538+
// Reject if generic parameters could be inferred from result type.
539+
if (containsTypes(anchorTy->getResult(), genericParameters)) {
497540
ctx.Diags.diagnose(
498541
defaultValue->getLoc(),
499542
diag::cannot_default_generic_parameter_inferrable_from_result,
500-
interfaceType);
543+
paramInterfaceTy);
501544
return Type();
502545
}
503546

504-
// Reject if generic parameter is used in multiple different positions
547+
// Reject if generic parameters are used in multiple different positions
505548
// in the parameter list.
506549

507550
llvm::SmallVector<unsigned, 2> affectedParams;
508551
for (unsigned i : indices(anchorTy->getParams())) {
509552
const auto &param = anchorTy->getParams()[i];
510553

511-
if (containsType(param.getPlainType(), interfaceType))
554+
if (containsTypes(param.getPlainType(), genericParameters))
512555
affectedParams.push_back(i);
513556
}
514557

@@ -524,27 +567,14 @@ Type TypeChecker::typeCheckParameterDefault(Expr *&defaultValue,
524567
defaultValue->getLoc(),
525568
diag::
526569
cannot_default_generic_parameter_inferrable_from_another_parameter,
527-
interfaceType, params.str());
570+
paramInterfaceTy, params.str());
528571
return Type();
529572
}
530573
}
531574

532575
auto signature = DC->getGenericSignatureOfContext();
533576
assert(signature && "generic parameter without signature?");
534577

535-
ConstraintSystemOptions options;
536-
options |= ConstraintSystemFlags::AllowFixes;
537-
538-
ConstraintSystem cs(DC, options);
539-
540-
auto *locator = cs.getConstraintLocator(
541-
defaultValue, LocatorPathElt::ContextualType(
542-
defaultExprTarget.getExprContextualTypePurpose()));
543-
544-
// A replacement for generic parameter type to associate any generic
545-
// requirements with.
546-
auto *contextualTy = cs.createTypeVariable(locator, /*flags=*/0);
547-
548578
auto *requirementBaseLocator = cs.getConstraintLocator(
549579
locator, LocatorPathElt::OpenedGeneric(signature));
550580

@@ -553,76 +583,84 @@ Type TypeChecker::typeCheckParameterDefault(Expr *&defaultValue,
553583
// a dependent member type), that means it could be inferred through
554584
// them e.g. `T: X.Y` or `T == U`.
555585
{
556-
auto isViable = [](Type type) {
557-
return !(type->hasTypeParameter() && type->hasDependentMember());
558-
};
559-
560586
auto recordRequirement = [&](unsigned index, Requirement requirement,
561587
ConstraintLocator *locator) {
562588
cs.openGenericRequirement(DC->getParent(), index, requirement,
563589
/*skipSelfProtocolConstraint=*/false, locator,
564-
[](Type type) -> Type { return type; });
590+
[&](Type type) -> Type {
591+
return cs.openType(type, genericParameters);
592+
});
593+
};
594+
595+
auto diagnoseInvalidRequirement = [&](Requirement requirement) {
596+
SmallString<32> reqBuf;
597+
llvm::raw_svector_ostream req(reqBuf);
598+
599+
requirement.print(req, PrintOptions());
600+
601+
ctx.Diags.diagnose(
602+
defaultValue->getLoc(),
603+
diag::cannot_default_generic_parameter_invalid_requirement,
604+
paramInterfaceTy, req.str());
565605
};
566606

567607
auto requirements = signature.getRequirements();
568608
for (unsigned reqIdx = 0; reqIdx != requirements.size(); ++reqIdx) {
569609
auto &requirement = requirements[reqIdx];
570610

571611
switch (requirement.getKind()) {
572-
case RequirementKind::Conformance: {
573-
if (!requirement.getFirstType()->isEqual(interfaceType))
574-
continue;
575-
576-
recordRequirement(reqIdx,
577-
{RequirementKind::Conformance, contextualTy,
578-
requirement.getSecondType()},
579-
requirementBaseLocator);
580-
break;
581-
}
612+
case RequirementKind::SameType: {
613+
auto lhsTy = requirement.getFirstType();
614+
auto rhsTy = requirement.getSecondType();
582615

583-
case RequirementKind::Superclass: {
584-
auto subclassTy = requirement.getFirstType();
585-
auto superclassTy = requirement.getSecondType();
616+
// Unrelated requirement.
617+
if (!containsTypes(lhsTy, genericParameters) &&
618+
!containsTypes(rhsTy, genericParameters))
619+
continue;
586620

587-
if (subclassTy->isEqual(interfaceType) && isViable(superclassTy)) {
588-
recordRequirement(
589-
reqIdx, {RequirementKind::Superclass, contextualTy, superclassTy},
590-
requirementBaseLocator);
621+
// Allow a subset of generic same-type requirements that only mention
622+
// "in scope" generic parameters e.g. `T.X == Int` or `T == U.Z`
623+
if (!containsGenericParamsExcluding(lhsTy, genericParameters) &&
624+
!containsGenericParamsExcluding(rhsTy, genericParameters)) {
625+
recordRequirement(reqIdx, requirement, requirementBaseLocator);
626+
continue;
591627
}
592628

593-
break;
594-
}
595-
596-
case RequirementKind::SameType: {
597-
// If there is a same-type constraint that involves our parameter
598-
// type, fail the type-check since the type could be inferred
599-
// through other positions.
600-
if (containsType(requirement.getFirstType(), interfaceType) ||
601-
containsType(requirement.getSecondType(), interfaceType)) {
602-
SmallString<32> reqBuf;
603-
llvm::raw_svector_ostream req(reqBuf);
604-
605-
requirement.print(req, PrintOptions());
606-
607-
ctx.Diags.diagnose(
608-
defaultValue->getLoc(),
609-
diag::
610-
cannot_default_generic_parameter_inferrable_through_same_type,
611-
interfaceType, req.str());
629+
// If there is a same-type constraint that involves out of scope
630+
// generic parameters mixed with in-scope ones, fail the type-check
631+
// since the type could be inferred through other positions.
632+
{
633+
diagnoseInvalidRequirement(requirement);
612634
return Type();
613635
}
614-
615-
continue;
616636
}
617637

638+
case RequirementKind::Conformance:
639+
case RequirementKind::Superclass:
618640
case RequirementKind::Layout:
619-
if (!requirement.getFirstType()->isEqual(interfaceType))
641+
auto adheringTy = requirement.getFirstType();
642+
643+
// Unrelated requirement.
644+
if (!containsTypes(adheringTy, genericParameters))
620645
continue;
621646

622-
recordRequirement(reqIdx,
623-
{RequirementKind::Layout, contextualTy,
624-
requirement.getLayoutConstraint()},
625-
requirementBaseLocator);
647+
// If adhering type has a mix or in- and out-of-scope parameters
648+
// mentioned we need to diagnose.
649+
if (containsGenericParamsExcluding(adheringTy, genericParameters)) {
650+
diagnoseInvalidRequirement(requirement);
651+
return Type();
652+
}
653+
654+
if (requirement.getKind() == RequirementKind::Superclass) {
655+
auto superclassTy = requirement.getSecondType();
656+
657+
if (containsGenericParamsExcluding(superclassTy, genericParameters)) {
658+
diagnoseInvalidRequirement(requirement);
659+
return Type();
660+
}
661+
}
662+
663+
recordRequirement(reqIdx, requirement, requirementBaseLocator);
626664
break;
627665
}
628666
}

test/Constraints/type_inference_from_default_exprs.swift

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ protocol P {
1616
}
1717

1818
func testInferFromSameType<T, U: P>(_: T = 42, _: [U]) where T == U.X {}
19-
// expected-error@-1 {{cannot use default expression for inference of 'T' because it is inferrable through same-type requirement: T == U.X}}
19+
// expected-error@-1 {{cannot use default expression for inference of 'T' because requirement 'T == U.X' refers to other generic parameters}}
2020

2121
func test1<T>(_: T = 42) {} // Ok
2222

@@ -60,6 +60,46 @@ extension S {
6060
}
6161
}
6262

63+
// In nested positions
64+
func testNested1<T>(_: [T] = [0, 1.0]) {} // Ok (T == Double)
65+
func testNested2<T>(_: T? = 42.0) {} // Ok
66+
func testNested2NoInference<T>(_: T? = nil) {} // Ok (old semantics)
67+
// expected-note@-1 {{in call to function 'testNested2NoInference'}}
68+
69+
struct D : P {
70+
typealias X = B
71+
}
72+
73+
func testNested3<T: P>(_: T = B()) where T.X == String {}
74+
func testNested4<T: P>(_: T = B()) where T.X == Int {}
75+
// expected-error@-1 {{global function 'testNested4' requires the types 'B.X' (aka 'String') and 'Int' be equivalent}}
76+
// expected-note@-2 {{where 'T.X' = 'B.X' (aka 'String')}}
77+
78+
func testNested5<T: P>(_: [T]? = [D()]) where T.X: P, T.X: AnyObject {}
79+
80+
func testNested5Invalid<T: P>(_: [T]? = [B()]) where T.X: P, T.X: AnyObject {}
81+
// expected-error@-1 {{global function 'testNested5Invalid' requires that 'B.X' (aka 'String') conform to 'P'}}
82+
// expected-error@-2 {{global function 'testNested5Invalid' requires that 'B.X' (aka 'String') be a class type}}
83+
// expected-note@-3 2 {{where 'T.X' = 'B.X' (aka 'String')}}
84+
// expected-note@-4 {{in call to function 'testNested5Invalid'}}
85+
86+
func testNested6<T: P, U>(_: (a: [T?], b: U) = (a: [D()], b: B())) where T.X == U, T.X: P, U: AnyObject { // Ok
87+
}
88+
89+
// Generic requirements
90+
91+
class GenClass<T> {}
92+
93+
func testReq1<T, U>(_: T = B(), _: U) where T: GenClass<U> {}
94+
// expected-error@-1 {{cannot use default expression for inference of 'T' because requirement 'T : GenClass<U>' refers to other generic parameters}}
95+
96+
class E : GenClass<B> {
97+
}
98+
99+
func testReq2<T, U>(_: (T, U) = (E(), B())) where T: GenClass<U>, U: AnyObject {} // Ok
100+
101+
func testReq3<T: P, U>(_: [T?] = [B()], _: U) where T.X == U {}
102+
// expected-error@-1 {{cannot use default expression for inference of '[T?]' because requirement 'U == T.X' refers to other generic parameters}}
63103

64104
func main() {
65105
test1() // Ok
@@ -78,4 +118,15 @@ func main() {
78118

79119
_ = S()[] // Ok
80120
_ = S()[B()] // Ok
121+
122+
testNested1() // Ok
123+
testNested2() // Ok
124+
testNested2NoInference() // expected-error {{generic parameter 'T' could not be inferred}}
125+
126+
testNested3() // Ok
127+
testNested5() // Ok
128+
testNested5Invalid() // expected-error {{generic parameter 'T' could not be inferred}}
129+
testNested6() // Ok
130+
131+
testReq2() // Ok
81132
}

0 commit comments

Comments
 (0)