Skip to content

Commit 72070b7

Browse files
committed
[Async Refactoring] Bind known bool param in fallback
Generalize the logic to handle different BlockKinds, and add binding logic that lets us assign `true` or `false` to the given bool success param in the fallback case.
1 parent 00977af commit 72070b7

File tree

2 files changed

+127
-49
lines changed

2 files changed

+127
-49
lines changed

lib/IDE/Refactoring.cpp

Lines changed: 103 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -4916,6 +4916,11 @@ class ClassifiedBlock {
49164916
}
49174917
};
49184918

4919+
/// The type of block rewritten code may be placed in.
4920+
enum class BlockKind {
4921+
SUCCESS, ERROR, FALLBACK
4922+
};
4923+
49194924
/// A completion handler function parameter that is known to be a Bool flag
49204925
/// indicating success or failure.
49214926
struct KnownBoolFlagParam {
@@ -4937,7 +4942,7 @@ class ClosureCallbackParams final {
49374942
: HandlerDesc(HandlerDesc),
49384943
AllParams(Closure->getParameters()->getArray()) {
49394944
assert(AllParams.size() == HandlerDesc.params().size());
4940-
assert(!(HandlerDesc.Type == HandlerType::RESULT && AllParams.size() != 1));
4945+
assert(HandlerDesc.Type != HandlerType::RESULT || AllParams.size() == 1);
49414946

49424947
SuccessParams.insert(AllParams.begin(), AllParams.end());
49434948
if (HandlerDesc.HasError && HandlerDesc.Type == HandlerType::PARAMS)
@@ -4976,29 +4981,41 @@ class ClosureCallbackParams final {
49764981
return HandlerDesc.shouldUnwrap(Param->getType());
49774982
}
49784983

4984+
/// Whether \p Param is the known Bool parameter that indicates success or
4985+
/// failure.
4986+
bool isKnownBoolFlagParam(const ParamDecl *Param) const {
4987+
if (auto BoolFlag = getKnownBoolFlagParam())
4988+
return BoolFlag->Param == Param;
4989+
return false;
4990+
}
4991+
49794992
/// Whether \p Param is a closure parameter that has a binding available in
4980-
/// the async variant of the call, either as a thrown error, or a success
4981-
/// return value.
4982-
bool hasBinding(const ParamDecl *Param) const {
4983-
if (!hasParam(Param))
4984-
return false;
4985-
if (auto BoolFlag = getKnownBoolFlagParam()) {
4986-
if (Param == BoolFlag->Param)
4993+
/// the async variant of the call for a particular \p Block.
4994+
bool hasBinding(const ParamDecl *Param, BlockKind Block) const {
4995+
switch (Block) {
4996+
case BlockKind::SUCCESS:
4997+
// Known bool flags get dropped from the imported async variant.
4998+
if (isKnownBoolFlagParam(Param))
49874999
return false;
5000+
5001+
return isSuccessParam(Param);
5002+
case BlockKind::ERROR:
5003+
return Param == ErrParam;
5004+
case BlockKind::FALLBACK:
5005+
// We generally want to bind everything in the fallback case.
5006+
return hasParam(Param);
49885007
}
4989-
return true;
5008+
llvm_unreachable("Unhandled case in switch");
49905009
}
49915010

4992-
/// Retrieve the success parameters that have a binding in a call to the
4993-
/// async variant.
4994-
ArrayRef<const ParamDecl *>
4995-
getSuccessParamsToBind(SmallVectorImpl<const ParamDecl *> &Scratch) {
4996-
assert(Scratch.empty());
4997-
for (auto *Param : SuccessParams) {
4998-
if (hasBinding(Param))
4999-
Scratch.push_back(Param);
5011+
/// Retrieve the parameters to bind in a given \p Block.
5012+
TinyPtrVector<const ParamDecl *> getParamsToBind(BlockKind Block) {
5013+
TinyPtrVector<const ParamDecl *> Result;
5014+
for (auto *Param : AllParams) {
5015+
if (hasBinding(Param, Block))
5016+
Result.push_back(Param);
50005017
}
5001-
return Scratch;
5018+
return Result;
50025019
}
50035020

50045021
/// If there is a known Bool flag parameter indicating success or failure,
@@ -5248,7 +5265,7 @@ struct CallbackClassifier {
52485265
// Check to see if we have a known bool flag parameter that indicates
52495266
// success or failure.
52505267
if (auto KnownBoolFlag = Params.getKnownBoolFlagParam()) {
5251-
if (KnownBoolFlag->Param != Cond.Subject)
5268+
if (KnownBoolFlag->Param != SubjectParam)
52525269
return None;
52535270

52545271
// The path may need to be flipped depending on whether the flag indicates
@@ -6722,10 +6739,21 @@ class AsyncConverter : private SourceEntityWalker {
67226739
}
67236740

67246741
void addFallbackVars(ArrayRef<const ParamDecl *> FallbackParams,
6725-
ClassifiedBlocks &Blocks) {
6726-
for (auto Param : FallbackParams) {
6727-
OS << tok::kw_var << " " << newNameFor(Param) << ": ";
6742+
const ClosureCallbackParams &AllParams) {
6743+
for (auto *Param : FallbackParams) {
67286744
auto Ty = Param->getType();
6745+
auto ParamName = newNameFor(Param);
6746+
6747+
// If this is the known bool success param, we can use 'let' and type it
6748+
// as non-optional, as it gets bound in both blocks.
6749+
if (AllParams.isKnownBoolFlagParam(Param)) {
6750+
OS << tok::kw_let << " " << ParamName << ": ";
6751+
Ty->print(OS);
6752+
OS << "\n";
6753+
continue;
6754+
}
6755+
6756+
OS << tok::kw_var << " " << ParamName << ": ";
67296757
Ty->print(OS);
67306758
if (!Ty->getOptionalObjectType())
67316759
OS << "?";
@@ -7144,6 +7172,30 @@ class AsyncConverter : private SourceEntityWalker {
71447172
DiagEngine.diagnose(CE->getStartLoc(), diag::missing_callback_arg);
71457173
}
71467174

7175+
/// Add a binding to a known bool flag that indicates success or failure.
7176+
void addBoolFlagParamBindingIfNeeded(Optional<KnownBoolFlagParam> Flag,
7177+
BlockKind Block) {
7178+
if (!Flag)
7179+
return;
7180+
// Figure out the polarity of the binding based on the block we're in and
7181+
// whether the flag indicates success.
7182+
auto Polarity = true;
7183+
switch (Block) {
7184+
case BlockKind::SUCCESS:
7185+
break;
7186+
case BlockKind::ERROR:
7187+
Polarity = !Polarity;
7188+
break;
7189+
case BlockKind::FALLBACK:
7190+
llvm_unreachable("Not a valid place to bind");
7191+
}
7192+
if (!Flag->IsSuccessFlag)
7193+
Polarity = !Polarity;
7194+
7195+
OS << newNameFor(Flag->Param) << " " << tok::equal << " ";
7196+
OS << (Polarity ? tok::kw_true : tok::kw_false) << "\n";
7197+
}
7198+
71477199
/// Add a call to the async alternative of \p CE and convert the \p Callback
71487200
/// to be executed after the async call. \p HandlerDesc describes the
71497201
/// completion handler in the function that's called by \p CE and \p ArgList
@@ -7165,8 +7217,7 @@ class AsyncConverter : private SourceEntityWalker {
71657217
DiagEngine, CallbackBody);
71667218
}
71677219

7168-
SmallVector<const ParamDecl *, 4> Scratch;
7169-
auto SuccessBindings = CallbackParams.getSuccessParamsToBind(Scratch);
7220+
auto SuccessBindings = CallbackParams.getParamsToBind(BlockKind::SUCCESS);
71707221
auto *ErrParam = CallbackParams.getErrParam();
71717222
if (DiagEngine.hadAnyError()) {
71727223
// For now, only fallback when the results are params with an error param,
@@ -7180,18 +7231,21 @@ class AsyncConverter : private SourceEntityWalker {
71807231
// assignments to the names in the outer scope.
71817232
InlinePatternsToPrint InlinePatterns;
71827233

7183-
SmallVector<const ParamDecl *, 4> AllBindings;
7184-
AllBindings.append(SuccessBindings.begin(), SuccessBindings.end());
7185-
AllBindings.push_back(ErrParam);
7234+
auto AllBindings = CallbackParams.getParamsToBind(BlockKind::FALLBACK);
71867235

71877236
prepareNames(ClassifiedBlock(), AllBindings, InlinePatterns);
71887237
preparePlaceholdersAndUnwraps(HandlerDesc, CallbackParams,
7189-
PlaceholderMode::FALLBACK);
7190-
addFallbackVars(AllBindings, Blocks);
7238+
BlockKind::FALLBACK);
7239+
addFallbackVars(AllBindings, CallbackParams);
71917240
addDo();
71927241
addAwaitCall(CE, Blocks.SuccessBlock, SuccessBindings, InlinePatterns,
71937242
HandlerDesc, /*AddDeclarations*/ false);
7194-
addFallbackCatch(ErrParam);
7243+
OS << "\n";
7244+
7245+
// If we have a known Bool success param, we need to bind it.
7246+
addBoolFlagParamBindingIfNeeded(CallbackParams.getKnownBoolFlagParam(),
7247+
BlockKind::SUCCESS);
7248+
addFallbackCatch(CallbackParams);
71957249
OS << "\n";
71967250
convertNodes(NodesToPrint::inBraceStmt(CallbackBody));
71977251

@@ -7242,7 +7296,7 @@ class AsyncConverter : private SourceEntityWalker {
72427296

72437297
prepareNames(Blocks.SuccessBlock, SuccessBindings, InlinePatterns);
72447298
preparePlaceholdersAndUnwraps(HandlerDesc, CallbackParams,
7245-
PlaceholderMode::SUCCESS_BLOCK);
7299+
BlockKind::SUCCESS);
72467300

72477301
addAwaitCall(CE, Blocks.SuccessBlock, SuccessBindings, InlinePatterns,
72487302
HandlerDesc, /*AddDeclarations=*/true);
@@ -7259,7 +7313,7 @@ class AsyncConverter : private SourceEntityWalker {
72597313
ErrInlinePatterns,
72607314
/*AddIfMissing=*/HandlerDesc.Type != HandlerType::RESULT);
72617315
preparePlaceholdersAndUnwraps(HandlerDesc, CallbackParams,
7262-
PlaceholderMode::ERROR_BLOCK);
7316+
BlockKind::ERROR);
72637317

72647318
addCatch(ErrOrResultParam);
72657319
convertNodes(Blocks.ErrorBlock.nodesToPrint());
@@ -7529,12 +7583,17 @@ class AsyncConverter : private SourceEntityWalker {
75297583
OS << tok::r_paren;
75307584
}
75317585

7532-
void addFallbackCatch(const ParamDecl *ErrParam) {
7586+
void addFallbackCatch(const ClosureCallbackParams &Params) {
7587+
auto *ErrParam = Params.getErrParam();
7588+
assert(ErrParam);
75337589
auto ErrName = newNameFor(ErrParam);
7534-
OS << "\n"
7535-
<< tok::r_brace << " " << tok::kw_catch << " " << tok::l_brace << "\n"
7536-
<< ErrName << " = error\n"
7537-
<< tok::r_brace;
7590+
OS << tok::r_brace << " " << tok::kw_catch << " " << tok::l_brace << "\n"
7591+
<< ErrName << " = error\n";
7592+
7593+
// If we have a known Bool success param, we need to bind it.
7594+
addBoolFlagParamBindingIfNeeded(Params.getKnownBoolFlagParam(),
7595+
BlockKind::ERROR);
7596+
OS << tok::r_brace;
75387597
}
75397598

75407599
void addCatch(const ParamDecl *ErrParam) {
@@ -7546,31 +7605,27 @@ class AsyncConverter : private SourceEntityWalker {
75467605
OS << tok::l_brace;
75477606
}
75487607

7549-
enum class PlaceholderMode {
7550-
SUCCESS_BLOCK, ERROR_BLOCK, FALLBACK
7551-
};
7552-
75537608
void preparePlaceholdersAndUnwraps(AsyncHandlerDesc HandlerDesc,
75547609
const ClosureCallbackParams &Params,
7555-
PlaceholderMode Mode) {
7610+
BlockKind Block) {
75567611
// Params that have been dropped always need placeholdering.
75577612
for (auto *Param : Params.getAllParams()) {
7558-
if (!Params.hasBinding(Param))
7613+
if (!Params.hasBinding(Param, Block))
75597614
Placeholders.insert(Param);
75607615
}
75617616
// For the fallback case, no other params need placeholdering, as they are
75627617
// all freely accessible in the fallback case.
7563-
if (Mode == PlaceholderMode::FALLBACK)
7618+
if (Block == BlockKind::FALLBACK)
75647619
return;
75657620

75667621
switch (HandlerDesc.Type) {
75677622
case HandlerType::PARAMS: {
75687623
auto *ErrParam = Params.getErrParam();
75697624
auto SuccessParams = Params.getSuccessParams();
7570-
switch (Mode) {
7571-
case PlaceholderMode::FALLBACK:
7625+
switch (Block) {
7626+
case BlockKind::FALLBACK:
75727627
llvm_unreachable("Already handled");
7573-
case PlaceholderMode::ERROR_BLOCK:
7628+
case BlockKind::ERROR:
75747629
if (ErrParam) {
75757630
if (HandlerDesc.shouldUnwrap(ErrParam->getType())) {
75767631
Placeholders.insert(ErrParam);
@@ -7580,7 +7635,7 @@ class AsyncConverter : private SourceEntityWalker {
75807635
Placeholders.insert(SuccessParams.begin(), SuccessParams.end());
75817636
}
75827637
break;
7583-
case PlaceholderMode::SUCCESS_BLOCK:
7638+
case BlockKind::SUCCESS:
75847639
for (auto *SuccessParam : SuccessParams) {
75857640
auto Ty = SuccessParam->getType();
75867641
if (HandlerDesc.shouldUnwrap(Ty)) {

test/refactoring/ConvertAsync/convert_bool.swift

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -466,14 +466,37 @@ func testConvertBool() async throws {
466466
print("much success", unrelated, str)
467467
}
468468
// OBJC-BOOL-WITH-ERR-FALLBACK: var str: String? = nil
469+
// OBJC-BOOL-WITH-ERR-FALLBACK-NEXT: let success: Bool
469470
// OBJC-BOOL-WITH-ERR-FALLBACK-NEXT: var unrelated: Bool? = nil
470471
// OBJC-BOOL-WITH-ERR-FALLBACK-NEXT: var err: Error? = nil
471472
// OBJC-BOOL-WITH-ERR-FALLBACK-NEXT: do {
472473
// OBJC-BOOL-WITH-ERR-FALLBACK-NEXT: (str, unrelated) = try await ClassWithHandlerMethods.firstBoolFlagSuccess("")
474+
// OBJC-BOOL-WITH-ERR-FALLBACK-NEXT: success = true
473475
// OBJC-BOOL-WITH-ERR-FALLBACK-NEXT: } catch {
474476
// OBJC-BOOL-WITH-ERR-FALLBACK-NEXT: err = error
477+
// OBJC-BOOL-WITH-ERR-FALLBACK-NEXT: success = false
475478
// OBJC-BOOL-WITH-ERR-FALLBACK-NEXT: }
476479
// OBJC-BOOL-WITH-ERR-FALLBACK-EMPTY:
477-
// OBJC-BOOL-WITH-ERR-FALLBACK-NEXT: guard <#success#> && <#success#> == .random() else { fatalError() }
480+
// OBJC-BOOL-WITH-ERR-FALLBACK-NEXT: guard success && success == .random() else { fatalError() }
478481
// OBJC-BOOL-WITH-ERR-FALLBACK-NEXT: print("much success", unrelated, str)
482+
483+
// RUN: %refactor-check-compiles -convert-call-to-async-alternative -dump-text -source-filename %s -pos=%(line+1):3 -I %S/Inputs -I %t -target %target-triple %clang-importer-sdk-nosource | %FileCheck -check-prefix=OBJC-BOOL-WITH-ERR-FALLBACK2 %s
484+
ClassWithHandlerMethods.secondBoolFlagFailure("") { str, unrelated, failure, err in
485+
guard !failure && failure == .random() else { fatalError() }
486+
print("much fails", unrelated, str)
487+
}
488+
// OBJC-BOOL-WITH-ERR-FALLBACK2: var str: String? = nil
489+
// OBJC-BOOL-WITH-ERR-FALLBACK2-NEXT: var unrelated: Bool? = nil
490+
// OBJC-BOOL-WITH-ERR-FALLBACK2-NEXT: let failure: Bool
491+
// OBJC-BOOL-WITH-ERR-FALLBACK2-NEXT: var err: Error? = nil
492+
// OBJC-BOOL-WITH-ERR-FALLBACK2-NEXT: do {
493+
// OBJC-BOOL-WITH-ERR-FALLBACK2-NEXT: (str, unrelated) = try await ClassWithHandlerMethods.secondBoolFlagFailure("")
494+
// OBJC-BOOL-WITH-ERR-FALLBACK2-NEXT: failure = false
495+
// OBJC-BOOL-WITH-ERR-FALLBACK2-NEXT: } catch {
496+
// OBJC-BOOL-WITH-ERR-FALLBACK2-NEXT: err = error
497+
// OBJC-BOOL-WITH-ERR-FALLBACK2-NEXT: failure = true
498+
// OBJC-BOOL-WITH-ERR-FALLBACK2-NEXT: }
499+
// OBJC-BOOL-WITH-ERR-FALLBACK2-EMPTY:
500+
// OBJC-BOOL-WITH-ERR-FALLBACK2-NEXT: guard !failure && failure == .random() else { fatalError() }
501+
// OBJC-BOOL-WITH-ERR-FALLBACK2-NEXT: print("much fails", unrelated, str)
479502
}

0 commit comments

Comments
 (0)