@@ -2508,7 +2508,14 @@ void swift::performAbstractFuncDeclDiagnostics(TypeChecker &TC,
2508
2508
2509
2509
// / Diagnose C style for loops.
2510
2510
2511
- static Expr *endConditionValueForConvertingCStyleForLoop (const ForStmt *FS, VarDecl *loopVar) {
2511
+ enum OperatorKind {
2512
+ Greater,
2513
+ Smaller,
2514
+ NEqual,
2515
+ };
2516
+
2517
+ static Expr *endConditionValueForConvertingCStyleForLoop (const ForStmt *FS,
2518
+ VarDecl *loopVar, OperatorKind &OpKind) {
2512
2519
auto *Cond = FS->getCond ().getPtrOrNull ();
2513
2520
if (!Cond)
2514
2521
return nullptr ;
@@ -2527,8 +2534,15 @@ static Expr *endConditionValueForConvertingCStyleForLoop(const ForStmt *FS, VarD
2527
2534
2528
2535
// Verify that the condition is a simple != or < comparison to the loop variable.
2529
2536
auto comparisonOpName = binaryFuncExpr->getDecl ()->getNameStr ();
2530
- if (comparisonOpName != " !=" && comparisonOpName != " <" )
2537
+ if (comparisonOpName == " !=" )
2538
+ OpKind = OperatorKind::NEqual;
2539
+ else if (comparisonOpName == " <" )
2540
+ OpKind = OperatorKind::Smaller;
2541
+ else if (comparisonOpName == " >" )
2542
+ OpKind = OperatorKind::Greater;
2543
+ else
2531
2544
return nullptr ;
2545
+
2532
2546
auto args = binaryExpr->getArg ()->getElements ();
2533
2547
auto loadExpr = dyn_cast<LoadExpr>(args[0 ]);
2534
2548
if (!loadExpr)
@@ -2541,7 +2555,9 @@ static Expr *endConditionValueForConvertingCStyleForLoop(const ForStmt *FS, VarD
2541
2555
return args[1 ];
2542
2556
}
2543
2557
2544
- static bool unaryIncrementForConvertingCStyleForLoop (const ForStmt *FS, VarDecl *loopVar) {
2558
+ static bool unaryOperatorCheckForConvertingCStyleForLoop (const ForStmt *FS,
2559
+ VarDecl *loopVar,
2560
+ StringRef OpContent) {
2545
2561
auto *Increment = FS->getIncrement ().getPtrOrNull ();
2546
2562
if (!Increment)
2547
2563
return false ;
@@ -2552,19 +2568,33 @@ static bool unaryIncrementForConvertingCStyleForLoop(const ForStmt *FS, VarDecl
2552
2568
return false ;
2553
2569
auto inoutExpr = dyn_cast<InOutExpr>(unaryExpr->getArg ());
2554
2570
if (!inoutExpr)
2555
- return false ;
2571
+ return false ;
2556
2572
auto incrementDeclRefExpr = dyn_cast<DeclRefExpr>(inoutExpr->getSubExpr ());
2557
2573
if (!incrementDeclRefExpr)
2558
2574
return false ;
2559
2575
auto unaryFuncExpr = dyn_cast<DeclRefExpr>(unaryExpr->getFn ());
2560
2576
if (!unaryFuncExpr)
2561
2577
return false ;
2562
- if (unaryFuncExpr->getDecl ()->getNameStr () != " ++ " )
2578
+ if (unaryFuncExpr->getDecl ()->getNameStr () != OpContent )
2563
2579
return false ;
2564
- return incrementDeclRefExpr->getDecl () == loopVar;
2580
+ return incrementDeclRefExpr->getDecl () == loopVar;
2581
+ }
2582
+
2583
+
2584
+ static bool unaryIncrementForConvertingCStyleForLoop (const ForStmt *FS,
2585
+ VarDecl *loopVar) {
2586
+ return unaryOperatorCheckForConvertingCStyleForLoop (FS, loopVar, " ++" );
2587
+ }
2588
+
2589
+ static bool unaryDecrementForConvertingCStyleForLoop (const ForStmt *FS,
2590
+ VarDecl *loopVar) {
2591
+ return unaryOperatorCheckForConvertingCStyleForLoop (FS, loopVar, " --" );
2565
2592
}
2566
2593
2567
- static bool plusEqualOneIncrementForConvertingCStyleForLoop (TypeChecker &TC, const ForStmt *FS, VarDecl *loopVar) {
2594
+ static bool binaryOperatorCheckForConvertingCStyleForLoop (TypeChecker &TC,
2595
+ const ForStmt *FS,
2596
+ VarDecl *loopVar,
2597
+ StringRef OpContent) {
2568
2598
auto *Increment = FS->getIncrement ().getPtrOrNull ();
2569
2599
if (!Increment)
2570
2600
return false ;
@@ -2574,7 +2604,7 @@ static bool plusEqualOneIncrementForConvertingCStyleForLoop(TypeChecker &TC, con
2574
2604
auto binaryFuncExpr = dyn_cast<DeclRefExpr>(binaryExpr->getFn ());
2575
2605
if (!binaryFuncExpr)
2576
2606
return false ;
2577
- if (binaryFuncExpr->getDecl ()->getNameStr () != " += " )
2607
+ if (binaryFuncExpr->getDecl ()->getNameStr () != OpContent )
2578
2608
return false ;
2579
2609
auto argTupleExpr = dyn_cast<TupleExpr>(binaryExpr->getArg ());
2580
2610
if (!argTupleExpr)
@@ -2595,6 +2625,19 @@ static bool plusEqualOneIncrementForConvertingCStyleForLoop(TypeChecker &TC, con
2595
2625
if (!declRefExpr)
2596
2626
return false ;
2597
2627
return declRefExpr->getDecl () == loopVar;
2628
+
2629
+ }
2630
+
2631
+ static bool plusEqualOneIncrementForConvertingCStyleForLoop (TypeChecker &TC,
2632
+ const ForStmt *FS,
2633
+ VarDecl *loopVar) {
2634
+ return binaryOperatorCheckForConvertingCStyleForLoop (TC, FS, loopVar, " +=" );
2635
+ }
2636
+
2637
+ static bool minusEqualOneDecrementForConvertingCStyleForLoop (TypeChecker &TC,
2638
+ const ForStmt *FS,
2639
+ VarDecl *loopVar) {
2640
+ return binaryOperatorCheckForConvertingCStyleForLoop (TC, FS, loopVar, " -=" );
2598
2641
}
2599
2642
2600
2643
static void checkCStyleForLoop (TypeChecker &TC, const ForStmt *FS) {
@@ -2616,13 +2659,18 @@ static void checkCStyleForLoop(TypeChecker &TC, const ForStmt *FS) {
2616
2659
2617
2660
VarDecl *loopVar = dyn_cast<VarDecl>(initializers[1 ]);
2618
2661
Expr *startValue = loopVarDecl->getInit (0 );
2619
- Expr *endValue = endConditionValueForConvertingCStyleForLoop (FS, loopVar);
2662
+ OperatorKind OpKind;
2663
+ Expr *endValue = endConditionValueForConvertingCStyleForLoop (FS, loopVar, OpKind);
2620
2664
bool strideByOne = unaryIncrementForConvertingCStyleForLoop (FS, loopVar) ||
2621
2665
plusEqualOneIncrementForConvertingCStyleForLoop (TC, FS, loopVar);
2666
+ bool strideBackByOne = unaryDecrementForConvertingCStyleForLoop (FS, loopVar) ||
2667
+ minusEqualOneDecrementForConvertingCStyleForLoop (TC, FS, loopVar);
2622
2668
2623
- if (!loopVar || !startValue || !endValue || !strideByOne)
2669
+ if (!loopVar || !startValue || !endValue || ( !strideByOne && !strideBackByOne) )
2624
2670
return ;
2625
-
2671
+
2672
+ assert (strideBackByOne != strideByOne && " cannot be both increment and decrement." );
2673
+
2626
2674
// Verify that the loop variable is invariant inside the body.
2627
2675
VarDeclUsageChecker checker (TC, loopVar);
2628
2676
checker.suppressDiagnostics ();
@@ -2639,13 +2687,29 @@ static void checkCStyleForLoop(TypeChecker &TC, const ForStmt *FS) {
2639
2687
SourceLoc endOfIncrementLoc =
2640
2688
Lexer::getLocForEndOfToken (TC.Context .SourceMgr ,
2641
2689
FS->getIncrement ().getPtrOrNull ()->getEndLoc ());
2642
-
2643
- diagnostic
2644
- .fixItRemoveChars (loopVarDecl->getLoc (), loopVar->getLoc ())
2645
- .fixItReplaceChars (loopPatternEnd, startValue->getStartLoc (), " in " )
2646
- .fixItReplaceChars (FS->getFirstSemicolonLoc (), endValue->getStartLoc (),
2647
- " ..< " )
2648
- .fixItRemoveChars (FS->getSecondSemicolonLoc (), endOfIncrementLoc);
2690
+
2691
+ if (strideByOne && OpKind != OperatorKind::Greater) {
2692
+ diagnostic
2693
+ .fixItRemoveChars (loopVarDecl->getLoc (), loopVar->getLoc ())
2694
+ .fixItReplaceChars (loopPatternEnd, startValue->getStartLoc (), " in " )
2695
+ .fixItReplaceChars (FS->getFirstSemicolonLoc (), endValue->getStartLoc (),
2696
+ " ..< " )
2697
+ .fixItRemoveChars (FS->getSecondSemicolonLoc (), endOfIncrementLoc);
2698
+ return ;
2699
+ } else if (strideBackByOne && OpKind != OperatorKind::Smaller) {
2700
+ SourceLoc startValueEnd = Lexer::getLocForEndOfToken (TC.Context .SourceMgr ,
2701
+ startValue->getEndLoc ());
2702
+
2703
+ StringRef endValueStr = CharSourceRange (TC.Context .SourceMgr , endValue->getStartLoc (),
2704
+ Lexer::getLocForEndOfToken (TC.Context .SourceMgr , endValue->getEndLoc ())).str ();
2705
+
2706
+ diagnostic
2707
+ .fixItRemoveChars (loopVarDecl->getLoc (), loopVar->getLoc ())
2708
+ .fixItReplaceChars (loopPatternEnd, startValue->getStartLoc (), " in " )
2709
+ .fixItInsert (startValue->getStartLoc (), (llvm::Twine (" ((" ) + endValueStr + " + 1)..." ).str ())
2710
+ .fixItInsert (startValueEnd, " ).reversed()" )
2711
+ .fixItRemoveChars (FS->getFirstSemicolonLoc (), endOfIncrementLoc);
2712
+ }
2649
2713
}
2650
2714
2651
2715
0 commit comments