@@ -441,6 +441,8 @@ Type TypeChecker::typeCheckParameterDefault(Expr *&defaultValue,
441
441
isAutoClosure ? CTP_AutoclosureDefaultParameter : CTP_DefaultParameter,
442
442
paramType, /* isDiscarded=*/ false );
443
443
444
+ auto paramInterfaceTy = paramType->mapTypeOutOfContext ();
445
+
444
446
{
445
447
// Buffer all of the diagnostics produced by \c typeCheckExpression
446
448
// since in some cases we need to try type-checking again with a
@@ -459,6 +461,11 @@ Type TypeChecker::typeCheckParameterDefault(Expr *&defaultValue,
459
461
if (!ctx.TypeCheckerOpts .EnableTypeInferenceFromDefaultArguments )
460
462
return Type ();
461
463
464
+ // Parameter type doesn't have any generic parameters mentioned
465
+ // in it, so there is nothing to infer.
466
+ if (!paramInterfaceTy->hasTypeParameter ())
467
+ return Type ();
468
+
462
469
// Ignore any diagnostics emitted by the original type-check.
463
470
diagnostics.abort ();
464
471
}
@@ -475,40 +482,76 @@ Type TypeChecker::typeCheckParameterDefault(Expr *&defaultValue,
475
482
// If both of aforementioned conditions are true, let's attempt
476
483
// to open generic parameter and infer the type of this default
477
484
// expression.
478
- auto interfaceType = paramType->mapTypeOutOfContext ();
479
- if (!interfaceType->isTypeParameter ())
480
- return Type ();
485
+ OpenedTypeMap genericParameters;
486
+
487
+ ConstraintSystemOptions options;
488
+ options |= ConstraintSystemFlags::AllowFixes;
489
+
490
+ ConstraintSystem cs (DC, options);
491
+
492
+ auto *locator = cs.getConstraintLocator (
493
+ defaultValue, LocatorPathElt::ContextualType (
494
+ defaultExprTarget.getExprContextualTypePurpose ()));
481
495
482
- auto containsType = [&](Type type, Type contained) {
483
- return type.findIf (
484
- [&contained](Type nested) { return nested->isEqual (contained); });
496
+ auto getCanonicalGenericParamTy = [](GenericTypeParamType *GP) {
497
+ return cast<GenericTypeParamType>(GP->getCanonicalType ());
485
498
};
486
499
487
- // Anchor of this default expression.
500
+ // Find and open all of the generic parameters used by the parameter
501
+ // and replace them with type variables.
502
+ auto contextualTy = paramInterfaceTy.transform ([&](Type type) -> Type {
503
+ assert (!type->is <UnboundGenericType>());
504
+
505
+ if (auto *GP = type->getAs <GenericTypeParamType>()) {
506
+ return cs.openGenericParameter (DC->getParent (), GP, genericParameters,
507
+ locator);
508
+ }
509
+ return type;
510
+ });
511
+
512
+ auto containsTypes = [&](Type type, OpenedTypeMap &toFind) {
513
+ return type.findIf ([&](Type nested) {
514
+ if (auto *GP = nested->getAs <GenericTypeParamType>())
515
+ return toFind.count (getCanonicalGenericParamTy (GP)) > 0 ;
516
+ return false ;
517
+ });
518
+ };
519
+
520
+ auto containsGenericParamsExcluding = [&](Type type,
521
+ OpenedTypeMap &exclusions) -> bool {
522
+ return type.findIf ([&](Type type) {
523
+ if (auto *GP = type->getAs <GenericTypeParamType>())
524
+ return !exclusions.count (getCanonicalGenericParamTy (GP));
525
+ return false ;
526
+ });
527
+ };
528
+
529
+ // Anchor of this default expression i.e. function, subscript
530
+ // or enum case.
488
531
auto *anchor = cast<ValueDecl>(DC->getParent ()->getAsDecl ());
489
532
490
- // Check whether generic parameter is only mentioned once in
533
+ // Check whether generic parameters are only mentioned once in
491
534
// the anchor's signature.
492
535
{
493
536
auto anchorTy = anchor->getInterfaceType ()->castTo <GenericFunctionType>();
494
537
495
- // Reject if generic parameter could be inferred from result type.
496
- if (containsType (anchorTy->getResult (), interfaceType )) {
538
+ // Reject if generic parameters could be inferred from result type.
539
+ if (containsTypes (anchorTy->getResult (), genericParameters )) {
497
540
ctx.Diags .diagnose (
498
541
defaultValue->getLoc (),
499
542
diag::cannot_default_generic_parameter_inferrable_from_result,
500
- interfaceType );
543
+ paramInterfaceTy );
501
544
return Type ();
502
545
}
503
546
504
- // Reject if generic parameter is used in multiple different positions
547
+ // Reject if generic parameters are used in multiple different positions
505
548
// in the parameter list.
506
549
507
550
llvm::SmallVector<unsigned , 2 > affectedParams;
508
551
for (unsigned i : indices (anchorTy->getParams ())) {
509
552
const auto ¶m = anchorTy->getParams ()[i];
510
553
511
- if (containsType (param.getPlainType (), interfaceType ))
554
+ if (containsTypes (param.getPlainType (), genericParameters ))
512
555
affectedParams.push_back (i);
513
556
}
514
557
@@ -524,27 +567,14 @@ Type TypeChecker::typeCheckParameterDefault(Expr *&defaultValue,
524
567
defaultValue->getLoc (),
525
568
diag::
526
569
cannot_default_generic_parameter_inferrable_from_another_parameter,
527
- interfaceType , params.str ());
570
+ paramInterfaceTy , params.str ());
528
571
return Type ();
529
572
}
530
573
}
531
574
532
575
auto signature = DC->getGenericSignatureOfContext ();
533
576
assert (signature && " generic parameter without signature?" );
534
577
535
- ConstraintSystemOptions options;
536
- options |= ConstraintSystemFlags::AllowFixes;
537
-
538
- ConstraintSystem cs (DC, options);
539
-
540
- auto *locator = cs.getConstraintLocator (
541
- defaultValue, LocatorPathElt::ContextualType (
542
- defaultExprTarget.getExprContextualTypePurpose ()));
543
-
544
- // A replacement for generic parameter type to associate any generic
545
- // requirements with.
546
- auto *contextualTy = cs.createTypeVariable (locator, /* flags=*/ 0 );
547
-
548
578
auto *requirementBaseLocator = cs.getConstraintLocator (
549
579
locator, LocatorPathElt::OpenedGeneric (signature));
550
580
@@ -553,76 +583,84 @@ Type TypeChecker::typeCheckParameterDefault(Expr *&defaultValue,
553
583
// a dependent member type), that means it could be inferred through
554
584
// them e.g. `T: X.Y` or `T == U`.
555
585
{
556
- auto isViable = [](Type type) {
557
- return !(type->hasTypeParameter () && type->hasDependentMember ());
558
- };
559
-
560
586
auto recordRequirement = [&](unsigned index, Requirement requirement,
561
587
ConstraintLocator *locator) {
562
588
cs.openGenericRequirement (DC->getParent (), index, requirement,
563
589
/* skipSelfProtocolConstraint=*/ false , locator,
564
- [](Type type) -> Type { return type; });
590
+ [&](Type type) -> Type {
591
+ return cs.openType (type, genericParameters);
592
+ });
593
+ };
594
+
595
+ auto diagnoseInvalidRequirement = [&](Requirement requirement) {
596
+ SmallString<32 > reqBuf;
597
+ llvm::raw_svector_ostream req (reqBuf);
598
+
599
+ requirement.print (req, PrintOptions ());
600
+
601
+ ctx.Diags .diagnose (
602
+ defaultValue->getLoc (),
603
+ diag::cannot_default_generic_parameter_invalid_requirement,
604
+ paramInterfaceTy, req.str ());
565
605
};
566
606
567
607
auto requirements = signature.getRequirements ();
568
608
for (unsigned reqIdx = 0 ; reqIdx != requirements.size (); ++reqIdx) {
569
609
auto &requirement = requirements[reqIdx];
570
610
571
611
switch (requirement.getKind ()) {
572
- case RequirementKind::Conformance: {
573
- if (!requirement.getFirstType ()->isEqual (interfaceType))
574
- continue ;
575
-
576
- recordRequirement (reqIdx,
577
- {RequirementKind::Conformance, contextualTy,
578
- requirement.getSecondType ()},
579
- requirementBaseLocator);
580
- break ;
581
- }
612
+ case RequirementKind::SameType: {
613
+ auto lhsTy = requirement.getFirstType ();
614
+ auto rhsTy = requirement.getSecondType ();
582
615
583
- case RequirementKind::Superclass: {
584
- auto subclassTy = requirement.getFirstType ();
585
- auto superclassTy = requirement.getSecondType ();
616
+ // Unrelated requirement.
617
+ if (!containsTypes (lhsTy, genericParameters) &&
618
+ !containsTypes (rhsTy, genericParameters))
619
+ continue ;
586
620
587
- if (subclassTy->isEqual (interfaceType) && isViable (superclassTy)) {
588
- recordRequirement (
589
- reqIdx, {RequirementKind::Superclass, contextualTy, superclassTy},
590
- requirementBaseLocator);
621
+ // Allow a subset of generic same-type requirements that only mention
622
+ // "in scope" generic parameters e.g. `T.X == Int` or `T == U.Z`
623
+ if (!containsGenericParamsExcluding (lhsTy, genericParameters) &&
624
+ !containsGenericParamsExcluding (rhsTy, genericParameters)) {
625
+ recordRequirement (reqIdx, requirement, requirementBaseLocator);
626
+ continue ;
591
627
}
592
628
593
- break ;
594
- }
595
-
596
- case RequirementKind::SameType: {
597
- // If there is a same-type constraint that involves our parameter
598
- // type, fail the type-check since the type could be inferred
599
- // through other positions.
600
- if (containsType (requirement.getFirstType (), interfaceType) ||
601
- containsType (requirement.getSecondType (), interfaceType)) {
602
- SmallString<32 > reqBuf;
603
- llvm::raw_svector_ostream req (reqBuf);
604
-
605
- requirement.print (req, PrintOptions ());
606
-
607
- ctx.Diags .diagnose (
608
- defaultValue->getLoc (),
609
- diag::
610
- cannot_default_generic_parameter_inferrable_through_same_type,
611
- interfaceType, req.str ());
629
+ // If there is a same-type constraint that involves out of scope
630
+ // generic parameters mixed with in-scope ones, fail the type-check
631
+ // since the type could be inferred through other positions.
632
+ {
633
+ diagnoseInvalidRequirement (requirement);
612
634
return Type ();
613
635
}
614
-
615
- continue ;
616
636
}
617
637
638
+ case RequirementKind::Conformance:
639
+ case RequirementKind::Superclass:
618
640
case RequirementKind::Layout:
619
- if (!requirement.getFirstType ()->isEqual (interfaceType))
641
+ auto adheringTy = requirement.getFirstType ();
642
+
643
+ // Unrelated requirement.
644
+ if (!containsTypes (adheringTy, genericParameters))
620
645
continue ;
621
646
622
- recordRequirement (reqIdx,
623
- {RequirementKind::Layout, contextualTy,
624
- requirement.getLayoutConstraint ()},
625
- requirementBaseLocator);
647
+ // If adhering type has a mix or in- and out-of-scope parameters
648
+ // mentioned we need to diagnose.
649
+ if (containsGenericParamsExcluding (adheringTy, genericParameters)) {
650
+ diagnoseInvalidRequirement (requirement);
651
+ return Type ();
652
+ }
653
+
654
+ if (requirement.getKind () == RequirementKind::Superclass) {
655
+ auto superclassTy = requirement.getSecondType ();
656
+
657
+ if (containsGenericParamsExcluding (superclassTy, genericParameters)) {
658
+ diagnoseInvalidRequirement (requirement);
659
+ return Type ();
660
+ }
661
+ }
662
+
663
+ recordRequirement (reqIdx, requirement, requirementBaseLocator);
626
664
break ;
627
665
}
628
666
}
0 commit comments