Skip to content

Commit b4bfbf7

Browse files
committed
[Async Refactoring] Bind known bool param in fallback
Generalize the logic to handle different BlockKinds, and add binding logic that lets us assign `true` or `false` to the given bool success param in the fallback case.
1 parent b5459b4 commit b4bfbf7

File tree

2 files changed

+127
-49
lines changed

2 files changed

+127
-49
lines changed

lib/IDE/Refactoring.cpp

Lines changed: 103 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -4971,6 +4971,11 @@ class ClassifiedBlock {
49714971
}
49724972
};
49734973

4974+
/// The type of block rewritten code may be placed in.
4975+
enum class BlockKind {
4976+
SUCCESS, ERROR, FALLBACK
4977+
};
4978+
49744979
/// A completion handler function parameter that is known to be a Bool flag
49754980
/// indicating success or failure.
49764981
struct KnownBoolFlagParam {
@@ -4992,7 +4997,7 @@ class ClosureCallbackParams final {
49924997
: HandlerDesc(HandlerDesc),
49934998
AllParams(Closure->getParameters()->getArray()) {
49944999
assert(AllParams.size() == HandlerDesc.params().size());
4995-
assert(!(HandlerDesc.Type == HandlerType::RESULT && AllParams.size() != 1));
5000+
assert(HandlerDesc.Type != HandlerType::RESULT || AllParams.size() == 1);
49965001

49975002
SuccessParams.insert(AllParams.begin(), AllParams.end());
49985003
if (HandlerDesc.HasError && HandlerDesc.Type == HandlerType::PARAMS)
@@ -5031,29 +5036,41 @@ class ClosureCallbackParams final {
50315036
return HandlerDesc.shouldUnwrap(Param->getType());
50325037
}
50335038

5039+
/// Whether \p Param is the known Bool parameter that indicates success or
5040+
/// failure.
5041+
bool isKnownBoolFlagParam(const ParamDecl *Param) const {
5042+
if (auto BoolFlag = getKnownBoolFlagParam())
5043+
return BoolFlag->Param == Param;
5044+
return false;
5045+
}
5046+
50345047
/// Whether \p Param is a closure parameter that has a binding available in
5035-
/// the async variant of the call, either as a thrown error, or a success
5036-
/// return value.
5037-
bool hasBinding(const ParamDecl *Param) const {
5038-
if (!hasParam(Param))
5039-
return false;
5040-
if (auto BoolFlag = getKnownBoolFlagParam()) {
5041-
if (Param == BoolFlag->Param)
5048+
/// the async variant of the call for a particular \p Block.
5049+
bool hasBinding(const ParamDecl *Param, BlockKind Block) const {
5050+
switch (Block) {
5051+
case BlockKind::SUCCESS:
5052+
// Known bool flags get dropped from the imported async variant.
5053+
if (isKnownBoolFlagParam(Param))
50425054
return false;
5055+
5056+
return isSuccessParam(Param);
5057+
case BlockKind::ERROR:
5058+
return Param == ErrParam;
5059+
case BlockKind::FALLBACK:
5060+
// We generally want to bind everything in the fallback case.
5061+
return hasParam(Param);
50435062
}
5044-
return true;
5063+
llvm_unreachable("Unhandled case in switch");
50455064
}
50465065

5047-
/// Retrieve the success parameters that have a binding in a call to the
5048-
/// async variant.
5049-
ArrayRef<const ParamDecl *>
5050-
getSuccessParamsToBind(SmallVectorImpl<const ParamDecl *> &Scratch) {
5051-
assert(Scratch.empty());
5052-
for (auto *Param : SuccessParams) {
5053-
if (hasBinding(Param))
5054-
Scratch.push_back(Param);
5066+
/// Retrieve the parameters to bind in a given \p Block.
5067+
TinyPtrVector<const ParamDecl *> getParamsToBind(BlockKind Block) {
5068+
TinyPtrVector<const ParamDecl *> Result;
5069+
for (auto *Param : AllParams) {
5070+
if (hasBinding(Param, Block))
5071+
Result.push_back(Param);
50555072
}
5056-
return Scratch;
5073+
return Result;
50575074
}
50585075

50595076
/// If there is a known Bool flag parameter indicating success or failure,
@@ -5303,7 +5320,7 @@ struct CallbackClassifier {
53035320
// Check to see if we have a known bool flag parameter that indicates
53045321
// success or failure.
53055322
if (auto KnownBoolFlag = Params.getKnownBoolFlagParam()) {
5306-
if (KnownBoolFlag->Param != Cond.Subject)
5323+
if (KnownBoolFlag->Param != SubjectParam)
53075324
return None;
53085325

53095326
// The path may need to be flipped depending on whether the flag indicates
@@ -6785,10 +6802,21 @@ class AsyncConverter : private SourceEntityWalker {
67856802
}
67866803

67876804
void addFallbackVars(ArrayRef<const ParamDecl *> FallbackParams,
6788-
ClassifiedBlocks &Blocks) {
6789-
for (auto Param : FallbackParams) {
6790-
OS << tok::kw_var << " " << newNameFor(Param) << ": ";
6805+
const ClosureCallbackParams &AllParams) {
6806+
for (auto *Param : FallbackParams) {
67916807
auto Ty = Param->getType();
6808+
auto ParamName = newNameFor(Param);
6809+
6810+
// If this is the known bool success param, we can use 'let' and type it
6811+
// as non-optional, as it gets bound in both blocks.
6812+
if (AllParams.isKnownBoolFlagParam(Param)) {
6813+
OS << tok::kw_let << " " << ParamName << ": ";
6814+
Ty->print(OS);
6815+
OS << "\n";
6816+
continue;
6817+
}
6818+
6819+
OS << tok::kw_var << " " << ParamName << ": ";
67926820
Ty->print(OS);
67936821
if (!Ty->getOptionalObjectType())
67946822
OS << "?";
@@ -7207,6 +7235,30 @@ class AsyncConverter : private SourceEntityWalker {
72077235
DiagEngine.diagnose(CE->getStartLoc(), diag::missing_callback_arg);
72087236
}
72097237

7238+
/// Add a binding to a known bool flag that indicates success or failure.
7239+
void addBoolFlagParamBindingIfNeeded(Optional<KnownBoolFlagParam> Flag,
7240+
BlockKind Block) {
7241+
if (!Flag)
7242+
return;
7243+
// Figure out the polarity of the binding based on the block we're in and
7244+
// whether the flag indicates success.
7245+
auto Polarity = true;
7246+
switch (Block) {
7247+
case BlockKind::SUCCESS:
7248+
break;
7249+
case BlockKind::ERROR:
7250+
Polarity = !Polarity;
7251+
break;
7252+
case BlockKind::FALLBACK:
7253+
llvm_unreachable("Not a valid place to bind");
7254+
}
7255+
if (!Flag->IsSuccessFlag)
7256+
Polarity = !Polarity;
7257+
7258+
OS << newNameFor(Flag->Param) << " " << tok::equal << " ";
7259+
OS << (Polarity ? tok::kw_true : tok::kw_false) << "\n";
7260+
}
7261+
72107262
/// Add a call to the async alternative of \p CE and convert the \p Callback
72117263
/// to be executed after the async call. \p HandlerDesc describes the
72127264
/// completion handler in the function that's called by \p CE and \p ArgList
@@ -7229,8 +7281,7 @@ class AsyncConverter : private SourceEntityWalker {
72297281
DiagEngine, CallbackBody);
72307282
}
72317283

7232-
SmallVector<const ParamDecl *, 4> Scratch;
7233-
auto SuccessBindings = CallbackParams.getSuccessParamsToBind(Scratch);
7284+
auto SuccessBindings = CallbackParams.getParamsToBind(BlockKind::SUCCESS);
72347285
auto *ErrParam = CallbackParams.getErrParam();
72357286
if (DiagEngine.hadAnyError()) {
72367287
// For now, only fallback when the results are params with an error param,
@@ -7244,18 +7295,21 @@ class AsyncConverter : private SourceEntityWalker {
72447295
// assignments to the names in the outer scope.
72457296
InlinePatternsToPrint InlinePatterns;
72467297

7247-
SmallVector<const ParamDecl *, 4> AllBindings;
7248-
AllBindings.append(SuccessBindings.begin(), SuccessBindings.end());
7249-
AllBindings.push_back(ErrParam);
7298+
auto AllBindings = CallbackParams.getParamsToBind(BlockKind::FALLBACK);
72507299

72517300
prepareNames(ClassifiedBlock(), AllBindings, InlinePatterns);
72527301
preparePlaceholdersAndUnwraps(HandlerDesc, CallbackParams,
7253-
PlaceholderMode::FALLBACK);
7254-
addFallbackVars(AllBindings, Blocks);
7302+
BlockKind::FALLBACK);
7303+
addFallbackVars(AllBindings, CallbackParams);
72557304
addDo();
72567305
addAwaitCall(CE, ArgList.ref(), Blocks.SuccessBlock, SuccessBindings,
72577306
InlinePatterns, HandlerDesc, /*AddDeclarations*/ false);
7258-
addFallbackCatch(ErrParam);
7307+
OS << "\n";
7308+
7309+
// If we have a known Bool success param, we need to bind it.
7310+
addBoolFlagParamBindingIfNeeded(CallbackParams.getKnownBoolFlagParam(),
7311+
BlockKind::SUCCESS);
7312+
addFallbackCatch(CallbackParams);
72597313
OS << "\n";
72607314
convertNodes(NodesToPrint::inBraceStmt(CallbackBody));
72617315

@@ -7305,7 +7359,7 @@ class AsyncConverter : private SourceEntityWalker {
73057359

73067360
prepareNames(Blocks.SuccessBlock, SuccessBindings, InlinePatterns);
73077361
preparePlaceholdersAndUnwraps(HandlerDesc, CallbackParams,
7308-
PlaceholderMode::SUCCESS_BLOCK);
7362+
BlockKind::SUCCESS);
73097363

73107364
addAwaitCall(CE, ArgList.ref(), Blocks.SuccessBlock, SuccessBindings,
73117365
InlinePatterns, HandlerDesc, /*AddDeclarations=*/true);
@@ -7322,7 +7376,7 @@ class AsyncConverter : private SourceEntityWalker {
73227376
ErrInlinePatterns,
73237377
/*AddIfMissing=*/HandlerDesc.Type != HandlerType::RESULT);
73247378
preparePlaceholdersAndUnwraps(HandlerDesc, CallbackParams,
7325-
PlaceholderMode::ERROR_BLOCK);
7379+
BlockKind::ERROR);
73267380

73277381
addCatch(ErrOrResultParam);
73287382
convertNodes(Blocks.ErrorBlock.nodesToPrint());
@@ -7590,12 +7644,17 @@ class AsyncConverter : private SourceEntityWalker {
75907644
OS << tok::r_paren;
75917645
}
75927646

7593-
void addFallbackCatch(const ParamDecl *ErrParam) {
7647+
void addFallbackCatch(const ClosureCallbackParams &Params) {
7648+
auto *ErrParam = Params.getErrParam();
7649+
assert(ErrParam);
75947650
auto ErrName = newNameFor(ErrParam);
7595-
OS << "\n"
7596-
<< tok::r_brace << " " << tok::kw_catch << " " << tok::l_brace << "\n"
7597-
<< ErrName << " = error\n"
7598-
<< tok::r_brace;
7651+
OS << tok::r_brace << " " << tok::kw_catch << " " << tok::l_brace << "\n"
7652+
<< ErrName << " = error\n";
7653+
7654+
// If we have a known Bool success param, we need to bind it.
7655+
addBoolFlagParamBindingIfNeeded(Params.getKnownBoolFlagParam(),
7656+
BlockKind::ERROR);
7657+
OS << tok::r_brace;
75997658
}
76007659

76017660
void addCatch(const ParamDecl *ErrParam) {
@@ -7607,31 +7666,27 @@ class AsyncConverter : private SourceEntityWalker {
76077666
OS << tok::l_brace;
76087667
}
76097668

7610-
enum class PlaceholderMode {
7611-
SUCCESS_BLOCK, ERROR_BLOCK, FALLBACK
7612-
};
7613-
76147669
void preparePlaceholdersAndUnwraps(AsyncHandlerDesc HandlerDesc,
76157670
const ClosureCallbackParams &Params,
7616-
PlaceholderMode Mode) {
7671+
BlockKind Block) {
76177672
// Params that have been dropped always need placeholdering.
76187673
for (auto *Param : Params.getAllParams()) {
7619-
if (!Params.hasBinding(Param))
7674+
if (!Params.hasBinding(Param, Block))
76207675
Placeholders.insert(Param);
76217676
}
76227677
// For the fallback case, no other params need placeholdering, as they are
76237678
// all freely accessible in the fallback case.
7624-
if (Mode == PlaceholderMode::FALLBACK)
7679+
if (Block == BlockKind::FALLBACK)
76257680
return;
76267681

76277682
switch (HandlerDesc.Type) {
76287683
case HandlerType::PARAMS: {
76297684
auto *ErrParam = Params.getErrParam();
76307685
auto SuccessParams = Params.getSuccessParams();
7631-
switch (Mode) {
7632-
case PlaceholderMode::FALLBACK:
7686+
switch (Block) {
7687+
case BlockKind::FALLBACK:
76337688
llvm_unreachable("Already handled");
7634-
case PlaceholderMode::ERROR_BLOCK:
7689+
case BlockKind::ERROR:
76357690
if (ErrParam) {
76367691
if (HandlerDesc.shouldUnwrap(ErrParam->getType())) {
76377692
Placeholders.insert(ErrParam);
@@ -7641,7 +7696,7 @@ class AsyncConverter : private SourceEntityWalker {
76417696
Placeholders.insert(SuccessParams.begin(), SuccessParams.end());
76427697
}
76437698
break;
7644-
case PlaceholderMode::SUCCESS_BLOCK:
7699+
case BlockKind::SUCCESS:
76457700
for (auto *SuccessParam : SuccessParams) {
76467701
auto Ty = SuccessParam->getType();
76477702
if (HandlerDesc.shouldUnwrap(Ty)) {

test/refactoring/ConvertAsync/convert_bool.swift

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -466,14 +466,37 @@ func testConvertBool() async throws {
466466
print("much success", unrelated, str)
467467
}
468468
// OBJC-BOOL-WITH-ERR-FALLBACK: var str: String? = nil
469+
// OBJC-BOOL-WITH-ERR-FALLBACK-NEXT: let success: Bool
469470
// OBJC-BOOL-WITH-ERR-FALLBACK-NEXT: var unrelated: Bool? = nil
470471
// OBJC-BOOL-WITH-ERR-FALLBACK-NEXT: var err: Error? = nil
471472
// OBJC-BOOL-WITH-ERR-FALLBACK-NEXT: do {
472473
// OBJC-BOOL-WITH-ERR-FALLBACK-NEXT: (str, unrelated) = try await ClassWithHandlerMethods.firstBoolFlagSuccess("")
474+
// OBJC-BOOL-WITH-ERR-FALLBACK-NEXT: success = true
473475
// OBJC-BOOL-WITH-ERR-FALLBACK-NEXT: } catch {
474476
// OBJC-BOOL-WITH-ERR-FALLBACK-NEXT: err = error
477+
// OBJC-BOOL-WITH-ERR-FALLBACK-NEXT: success = false
475478
// OBJC-BOOL-WITH-ERR-FALLBACK-NEXT: }
476479
// OBJC-BOOL-WITH-ERR-FALLBACK-EMPTY:
477-
// OBJC-BOOL-WITH-ERR-FALLBACK-NEXT: guard <#success#> && <#success#> == .random() else { fatalError() }
480+
// OBJC-BOOL-WITH-ERR-FALLBACK-NEXT: guard success && success == .random() else { fatalError() }
478481
// OBJC-BOOL-WITH-ERR-FALLBACK-NEXT: print("much success", unrelated, str)
482+
483+
// RUN: %refactor-check-compiles -convert-call-to-async-alternative -dump-text -source-filename %s -pos=%(line+1):3 -I %S/Inputs -I %t -target %target-triple %clang-importer-sdk-nosource | %FileCheck -check-prefix=OBJC-BOOL-WITH-ERR-FALLBACK2 %s
484+
ClassWithHandlerMethods.secondBoolFlagFailure("") { str, unrelated, failure, err in
485+
guard !failure && failure == .random() else { fatalError() }
486+
print("much fails", unrelated, str)
487+
}
488+
// OBJC-BOOL-WITH-ERR-FALLBACK2: var str: String? = nil
489+
// OBJC-BOOL-WITH-ERR-FALLBACK2-NEXT: var unrelated: Bool? = nil
490+
// OBJC-BOOL-WITH-ERR-FALLBACK2-NEXT: let failure: Bool
491+
// OBJC-BOOL-WITH-ERR-FALLBACK2-NEXT: var err: Error? = nil
492+
// OBJC-BOOL-WITH-ERR-FALLBACK2-NEXT: do {
493+
// OBJC-BOOL-WITH-ERR-FALLBACK2-NEXT: (str, unrelated) = try await ClassWithHandlerMethods.secondBoolFlagFailure("")
494+
// OBJC-BOOL-WITH-ERR-FALLBACK2-NEXT: failure = false
495+
// OBJC-BOOL-WITH-ERR-FALLBACK2-NEXT: } catch {
496+
// OBJC-BOOL-WITH-ERR-FALLBACK2-NEXT: err = error
497+
// OBJC-BOOL-WITH-ERR-FALLBACK2-NEXT: failure = true
498+
// OBJC-BOOL-WITH-ERR-FALLBACK2-NEXT: }
499+
// OBJC-BOOL-WITH-ERR-FALLBACK2-EMPTY:
500+
// OBJC-BOOL-WITH-ERR-FALLBACK2-NEXT: guard !failure && failure == .random() else { fatalError() }
501+
// OBJC-BOOL-WITH-ERR-FALLBACK2-NEXT: print("much fails", unrelated, str)
479502
}

0 commit comments

Comments
 (0)