@@ -2006,6 +2006,9 @@ void ContextualFailure::tryFixIts(InFlightDiagnostic &diagnostic) const {
2006
2006
if (tryIntegerCastFixIts (diagnostic))
2007
2007
return ;
2008
2008
2009
+ if (tryProtocolConformanceFixIt (diagnostic))
2010
+ return ;
2011
+
2009
2012
if (tryTypeCoercionFixIt (diagnostic))
2010
2013
return ;
2011
2014
}
@@ -2430,6 +2433,101 @@ bool ContextualFailure::tryTypeCoercionFixIt(
2430
2433
return false ;
2431
2434
}
2432
2435
2436
+ bool ContextualFailure::tryProtocolConformanceFixIt (
2437
+ InFlightDiagnostic &diagnostic) const {
2438
+ auto innermostTyCtx = getDC ()->getInnermostTypeContext ();
2439
+ if (!innermostTyCtx)
2440
+ return false ;
2441
+
2442
+ auto nominal = innermostTyCtx->getSelfNominalTypeDecl ();
2443
+ if (!nominal)
2444
+ return false ;
2445
+
2446
+ // We need to get rid of optionals and parens as it's not relevant when
2447
+ // printing the diagnostic and the fix-it.
2448
+ auto unwrappedToType =
2449
+ ToType->lookThroughAllOptionalTypes ()->getWithoutParens ();
2450
+
2451
+ // If the protocol requires a class & we don't have one (maybe the context
2452
+ // is a struct), then bail out instead of offering a broken fix-it later on.
2453
+ auto requiresClass = false ;
2454
+ if (unwrappedToType->isExistentialType ()) {
2455
+ if (auto protocolTy = unwrappedToType->getAs <ProtocolType>()) {
2456
+ requiresClass = protocolTy->requiresClass ();
2457
+ } else if (auto compositionTy =
2458
+ unwrappedToType->getAs <ProtocolCompositionType>()) {
2459
+ requiresClass = compositionTy->requiresClass ();
2460
+ }
2461
+ }
2462
+ if (requiresClass && !FromType->is <ClassType>()) {
2463
+ return false ;
2464
+ }
2465
+
2466
+ // We can only offer a fix-it if we're assigning to a protocol type and
2467
+ // the type we're assigning is the same as the innermost type context.
2468
+ bool shouldOfferFixIt = nominal->getSelfTypeInContext ()->isEqual (FromType) &&
2469
+ unwrappedToType->isExistentialType ();
2470
+ if (!shouldOfferFixIt)
2471
+ return false ;
2472
+
2473
+ diagnostic.flush ();
2474
+
2475
+ // Let's build a list of protocols that the contextual type does not
2476
+ // conform to. We will start by first checking if we have a protocol
2477
+ // composition type and add all the individual types that the context
2478
+ // does not conform to.
2479
+ SmallVector<std::string, 8 > missingProtoTypeStrings;
2480
+ if (auto compositionTy = unwrappedToType->getAs <ProtocolCompositionType>()) {
2481
+ for (auto memberTy : compositionTy->getMembers ()) {
2482
+ auto protocol = memberTy->getAnyNominal ()->getSelfProtocolDecl ();
2483
+ if (!getTypeChecker ().conformsToProtocol (
2484
+ FromType, protocol, getDC (),
2485
+ ConformanceCheckFlags::InExpression)) {
2486
+ missingProtoTypeStrings.push_back (memberTy->getString ());
2487
+ }
2488
+ }
2489
+
2490
+ // If we don't conform to all of the protocols in the composition, then
2491
+ // store the composition type only. This is because we need to append
2492
+ // 'Foo & Bar' instead of 'Foo, Bar' in order to match the written type.
2493
+ if (missingProtoTypeStrings.size () == compositionTy->getMembers ().size ()) {
2494
+ missingProtoTypeStrings = {compositionTy->getString ()};
2495
+ }
2496
+ }
2497
+
2498
+ // If we didn't have a protocol composition type, it means we only have a
2499
+ // single protocol, so just use it directly. Otherwise, construct a comma
2500
+ // separated list of missing types.
2501
+ std::string protoString;
2502
+ if (missingProtoTypeStrings.empty ()) {
2503
+ protoString = unwrappedToType->getString ();
2504
+ } else if (missingProtoTypeStrings.size () == 1 ) {
2505
+ protoString = missingProtoTypeStrings.front ();
2506
+ } else {
2507
+ protoString = llvm::join (missingProtoTypeStrings, " , " );
2508
+ }
2509
+
2510
+ // Emit a diagnostic to inform the user that they need to conform to the
2511
+ // missing protocols.
2512
+ //
2513
+ // TODO: Maybe also insert the requirement stubs?
2514
+ auto conformanceDiag = emitDiagnostic (
2515
+ getAnchor ()->getLoc (), diag::assign_protocol_conformance_fix_it,
2516
+ unwrappedToType, nominal->getDescriptiveKind (), FromType);
2517
+ if (nominal->getInherited ().size () > 0 ) {
2518
+ auto lastInherited = nominal->getInherited ().back ().getLoc ();
2519
+ auto lastInheritedEndLoc =
2520
+ Lexer::getLocForEndOfToken (getASTContext ().SourceMgr , lastInherited);
2521
+ conformanceDiag.fixItInsert (lastInheritedEndLoc, " , " + protoString);
2522
+ } else {
2523
+ auto nameEndLoc = Lexer::getLocForEndOfToken (getASTContext ().SourceMgr ,
2524
+ nominal->getNameLoc ());
2525
+ conformanceDiag.fixItInsert (nameEndLoc, " : " + protoString);
2526
+ }
2527
+
2528
+ return true ;
2529
+ }
2530
+
2433
2531
void ContextualFailure::tryComputedPropertyFixIts (Expr *expr) const {
2434
2532
if (!isa<ClosureExpr>(expr))
2435
2533
return ;
0 commit comments