Skip to content

Commit 56347d1

Browse files
authored
Merge pull request #37297 from bnbarham/cherry-rdar73973459
[Refactoring] Only unwrap optionals if the handler has an optional error
2 parents 378b1be + a6d4358 commit 56347d1

File tree

4 files changed

+267
-113
lines changed

4 files changed

+267
-113
lines changed

lib/IDE/Refactoring.cpp

Lines changed: 104 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -4293,6 +4293,10 @@ struct AsyncHandlerDesc {
42934293
return getSuccessParamAsyncReturnType(Ty)->isVoid();
42944294
});
42954295
}
4296+
4297+
bool shouldUnwrap(swift::Type Ty) const {
4298+
return HasError && Ty->isOptional();
4299+
}
42964300
};
42974301

42984302
enum class ConditionType { INVALID, NIL, NOT_NIL };
@@ -4557,18 +4561,12 @@ struct CallbackClassifier {
45574561
static void classifyInto(ClassifiedBlocks &Blocks,
45584562
llvm::DenseSet<SwitchStmt *> &HandledSwitches,
45594563
DiagnosticEngine &DiagEngine,
4560-
ArrayRef<const ParamDecl *> SuccessParams,
4564+
llvm::DenseSet<const Decl *> UnwrapParams,
45614565
const ParamDecl *ErrParam, HandlerType ResultType,
45624566
ArrayRef<ASTNode> Body) {
45634567
assert(!Body.empty() && "Cannot classify empty body");
4564-
4565-
auto ParamsSet = llvm::DenseSet<const Decl *>(SuccessParams.begin(),
4566-
SuccessParams.end());
4567-
if (ErrParam)
4568-
ParamsSet.insert(ErrParam);
4569-
45704568
CallbackClassifier Classifier(Blocks, HandledSwitches, DiagEngine,
4571-
ParamsSet, ErrParam,
4569+
UnwrapParams, ErrParam,
45724570
ResultType == HandlerType::RESULT);
45734571
Classifier.classifyNodes(Body);
45744572
}
@@ -4578,19 +4576,19 @@ struct CallbackClassifier {
45784576
llvm::DenseSet<SwitchStmt *> &HandledSwitches;
45794577
DiagnosticEngine &DiagEngine;
45804578
ClassifiedBlock *CurrentBlock;
4581-
llvm::DenseSet<const Decl *> ParamsSet;
4579+
llvm::DenseSet<const Decl *> UnwrapParams;
45824580
const ParamDecl *ErrParam;
45834581
bool IsResultParam;
45844582

45854583
CallbackClassifier(ClassifiedBlocks &Blocks,
45864584
llvm::DenseSet<SwitchStmt *> &HandledSwitches,
45874585
DiagnosticEngine &DiagEngine,
4588-
llvm::DenseSet<const Decl *> ParamsSet,
4586+
llvm::DenseSet<const Decl *> UnwrapParams,
45894587
const ParamDecl *ErrParam, bool IsResultParam)
45904588
: Blocks(Blocks), HandledSwitches(HandledSwitches),
45914589
DiagEngine(DiagEngine), CurrentBlock(&Blocks.SuccessBlock),
4592-
ParamsSet(ParamsSet), ErrParam(ErrParam), IsResultParam(IsResultParam) {
4593-
}
4590+
UnwrapParams(UnwrapParams), ErrParam(ErrParam),
4591+
IsResultParam(IsResultParam) {}
45944592

45954593
void classifyNodes(ArrayRef<ASTNode> Nodes) {
45964594
for (auto I = Nodes.begin(), E = Nodes.end(); I < E; ++I) {
@@ -4622,7 +4620,7 @@ struct CallbackClassifier {
46224620
ArrayRef<ASTNode> ThenNodes, Stmt *ElseStmt) {
46234621
llvm::DenseMap<const Decl *, CallbackCondition> CallbackConditions;
46244622
bool UnhandledConditions =
4625-
!CallbackCondition::all(Condition, ParamsSet, CallbackConditions);
4623+
!CallbackCondition::all(Condition, UnwrapParams, CallbackConditions);
46264624
CallbackCondition ErrCondition = CallbackConditions.lookup(ErrParam);
46274625

46284626
if (UnhandledConditions) {
@@ -4950,7 +4948,7 @@ class AsyncConverter : private SourceEntityWalker {
49504948
getUnderlyingFunc(CE->getFn()), StartNode.dyn_cast<Expr *>() == CE);
49514949
if (HandlerDesc.isValid())
49524950
return addCustom(CE->getSourceRange(),
4953-
[&]() { addAsyncAlternativeCall(CE, HandlerDesc); });
4951+
[&]() { addHoistedCallback(CE, HandlerDesc); });
49544952
}
49554953
}
49564954

@@ -5153,8 +5151,8 @@ class AsyncConverter : private SourceEntityWalker {
51535151
}
51545152
}
51555153

5156-
void addAsyncAlternativeCall(const CallExpr *CE,
5157-
const AsyncHandlerDesc &HandlerDesc) {
5154+
void addHoistedCallback(const CallExpr *CE,
5155+
const AsyncHandlerDesc &HandlerDesc) {
51585156
auto ArgList = callArgs(CE);
51595157
if ((size_t)HandlerDesc.Index >= ArgList.ref().size()) {
51605158
DiagEngine.diagnose(CE->getStartLoc(), diag::missing_callback_arg);
@@ -5167,53 +5165,63 @@ class AsyncConverter : private SourceEntityWalker {
51675165
return;
51685166
}
51695167

5170-
ParameterList *CallbackParams = Callback->getParameters();
5168+
ArrayRef<const ParamDecl *> CallbackParams =
5169+
Callback->getParameters()->getArray();
51715170
ArrayRef<ASTNode> CallbackBody = Callback->getBody()->getElements();
5172-
if (HandlerDesc.params().size() != CallbackParams->size()) {
5171+
if (HandlerDesc.params().size() != CallbackParams.size()) {
51735172
DiagEngine.diagnose(CE->getStartLoc(), diag::mismatched_callback_args);
51745173
return;
51755174
}
51765175

51775176
// Note that the `ErrParam` may be a Result (in which case it's also the
51785177
// only element in `SuccessParams`)
5179-
ArrayRef<const ParamDecl *> SuccessParams = CallbackParams->getArray();
5178+
ArrayRef<const ParamDecl *> SuccessParams = CallbackParams;
51805179
const ParamDecl *ErrParam = nullptr;
5181-
if (HandlerDesc.HasError) {
5180+
if (HandlerDesc.Type == HandlerType::RESULT) {
51825181
ErrParam = SuccessParams.back();
5183-
if (HandlerDesc.Type == HandlerType::PARAMS)
5184-
SuccessParams = SuccessParams.drop_back();
5182+
} else if (HandlerDesc.HasError) {
5183+
assert(HandlerDesc.Type == HandlerType::PARAMS);
5184+
ErrParam = SuccessParams.back();
5185+
SuccessParams = SuccessParams.drop_back();
51855186
}
5186-
ArrayRef<const ParamDecl *> ErrParams;
5187-
if (ErrParam)
5188-
ErrParams = llvm::makeArrayRef(ErrParam);
51895187

51905188
ClassifiedBlocks Blocks;
51915189
if (!HandlerDesc.HasError) {
51925190
Blocks.SuccessBlock.addAllNodes(CallbackBody);
51935191
} else if (!CallbackBody.empty()) {
5192+
llvm::DenseSet<const Decl *> UnwrapParams;
5193+
for (auto *Param : SuccessParams) {
5194+
if (HandlerDesc.shouldUnwrap(Param->getType()))
5195+
UnwrapParams.insert(Param);
5196+
}
5197+
if (ErrParam)
5198+
UnwrapParams.insert(ErrParam);
51945199
CallbackClassifier::classifyInto(Blocks, HandledSwitches, DiagEngine,
5195-
SuccessParams, ErrParam,
5196-
HandlerDesc.Type, CallbackBody);
5197-
if (DiagEngine.hadAnyError()) {
5198-
// Can only fallback when the results are params, in which case only
5199-
// the names are used (defaulted to the names of the params if none)
5200-
if (HandlerDesc.Type != HandlerType::PARAMS)
5201-
return;
5202-
DiagEngine.resetHadAnyError();
5200+
UnwrapParams, ErrParam, HandlerDesc.Type,
5201+
CallbackBody);
5202+
}
52035203

5204-
setNames(ClassifiedBlock(), CallbackParams->getArray());
5204+
if (DiagEngine.hadAnyError()) {
5205+
// Can only fallback when the results are params, in which case only
5206+
// the names are used (defaulted to the names of the params if none)
5207+
if (HandlerDesc.Type != HandlerType::PARAMS)
5208+
return;
5209+
DiagEngine.resetHadAnyError();
52055210

5206-
addFallbackVars(CallbackParams->getArray(), Blocks);
5207-
addDo();
5208-
addAwaitCall(CE, ArgList.ref(), Blocks.SuccessBlock, SuccessParams,
5209-
HandlerDesc, /*AddDeclarations=*/!HandlerDesc.HasError);
5210-
addFallbackCatch(ErrParam);
5211-
OS << "\n";
5212-
convertNodes(CallbackBody);
5211+
// Don't do any unwrapping or placeholder replacement since all params
5212+
// are still valid in the fallback case
5213+
prepareNames(ClassifiedBlock(), CallbackParams);
52135214

5214-
clearParams(CallbackParams->getArray());
5215-
return;
5216-
}
5215+
addFallbackVars(CallbackParams, Blocks);
5216+
addDo();
5217+
addAwaitCall(CE, ArgList.ref(), Blocks.SuccessBlock, SuccessParams,
5218+
HandlerDesc, /*AddDeclarations=*/!HandlerDesc.HasError);
5219+
addFallbackCatch(ErrParam);
5220+
OS << "\n";
5221+
convertNodes(CallbackBody);
5222+
5223+
clearNames(CallbackParams);
5224+
return;
52175225
}
52185226

52195227
bool RequireDo = !Blocks.ErrorBlock.nodes().empty();
@@ -5237,25 +5245,27 @@ class AsyncConverter : private SourceEntityWalker {
52375245
addDo();
52385246
}
52395247

5240-
setNames(Blocks.SuccessBlock, SuccessParams);
5248+
prepareNames(Blocks.SuccessBlock, SuccessParams);
5249+
preparePlaceholdersAndUnwraps(HandlerDesc, SuccessParams, ErrParam,
5250+
/*Success=*/true);
5251+
52415252
addAwaitCall(CE, ArgList.ref(), Blocks.SuccessBlock, SuccessParams,
52425253
HandlerDesc, /*AddDeclarations=*/true);
5243-
5244-
prepareNamesForBody(HandlerDesc, SuccessParams, ErrParams);
52455254
convertNodes(Blocks.SuccessBlock.nodes());
5255+
clearNames(SuccessParams);
52465256

52475257
if (RequireDo) {
5248-
clearParams(SuccessParams);
52495258
// Always use the ErrParam name if none is bound
5250-
setNames(Blocks.ErrorBlock, ErrParams,
5251-
HandlerDesc.Type != HandlerType::RESULT);
5252-
addCatch(ErrParam);
5259+
prepareNames(Blocks.ErrorBlock, llvm::makeArrayRef(ErrParam),
5260+
HandlerDesc.Type != HandlerType::RESULT);
5261+
preparePlaceholdersAndUnwraps(HandlerDesc, SuccessParams, ErrParam,
5262+
/*Success=*/false);
52535263

5254-
prepareNamesForBody(HandlerDesc, ErrParams, SuccessParams);
5255-
addCatchBody(ErrParam, Blocks.ErrorBlock);
5264+
addCatch(ErrParam);
5265+
convertNodes(Blocks.ErrorBlock.nodes());
5266+
OS << "\n" << tok::r_brace;
5267+
clearNames(llvm::makeArrayRef(ErrParam));
52565268
}
5257-
5258-
clearParams(CallbackParams->getArray());
52595269
}
52605270

52615271
void addAwaitCall(const CallExpr *CE, ArrayRef<Expr *> Args,
@@ -5326,44 +5336,56 @@ class AsyncConverter : private SourceEntityWalker {
53265336
OS << tok::l_brace;
53275337
}
53285338

5329-
void addCatchBody(const ParamDecl *ErrParam,
5330-
const ClassifiedBlock &ErrorBlock) {
5331-
convertNodes(ErrorBlock.nodes());
5332-
OS << "\n" << tok::r_brace;
5333-
}
5334-
5335-
void prepareNamesForBody(const AsyncHandlerDesc &HandlerDesc,
5336-
ArrayRef<const ParamDecl *> CurrentParams,
5337-
ArrayRef<const ParamDecl *> OtherParams) {
5339+
void preparePlaceholdersAndUnwraps(AsyncHandlerDesc HandlerDesc,
5340+
ArrayRef<const ParamDecl *> SuccessParams,
5341+
const ParamDecl *ErrParam, bool Success) {
53385342
switch (HandlerDesc.Type) {
53395343
case HandlerType::PARAMS:
5340-
for (auto *Param : CurrentParams) {
5341-
auto Ty = Param->getType();
5342-
if (Ty->getOptionalObjectType()) {
5343-
Unwraps.insert(Param);
5344-
Placeholders.insert(Param);
5344+
if (!Success) {
5345+
if (ErrParam) {
5346+
if (HandlerDesc.shouldUnwrap(ErrParam->getType())) {
5347+
Placeholders.insert(ErrParam);
5348+
Unwraps.insert(ErrParam);
5349+
}
5350+
// Can't use success params in the error body
5351+
Placeholders.insert(SuccessParams.begin(), SuccessParams.end());
5352+
}
5353+
} else {
5354+
for (auto *SuccessParam : SuccessParams) {
5355+
auto Ty = SuccessParam->getType();
5356+
if (HandlerDesc.shouldUnwrap(Ty)) {
5357+
// Either unwrap or replace with a placeholder if there's some other
5358+
// reference
5359+
Unwraps.insert(SuccessParam);
5360+
Placeholders.insert(SuccessParam);
5361+
}
5362+
5363+
// Void parameters get omitted where possible, so turn any reference
5364+
// into a placeholder, as its usage is unlikely what the user wants.
5365+
if (HandlerDesc.getSuccessParamAsyncReturnType(Ty)->isVoid())
5366+
Placeholders.insert(SuccessParam);
53455367
}
5346-
// Void parameters get omitted where possible, so turn any reference
5347-
// into a placeholder, as its usage is unlikely what the user wants.
5348-
if (HandlerDesc.getSuccessParamAsyncReturnType(Ty)->isVoid())
5349-
Placeholders.insert(Param);
5368+
// Can't use the error param in the success body
5369+
if (ErrParam)
5370+
Placeholders.insert(ErrParam);
53505371
}
5351-
// Use of the other params is invalid within the current body
5352-
Placeholders.insert(OtherParams.begin(), OtherParams.end());
53535372
break;
53545373
case HandlerType::RESULT:
5355-
// Any uses of the result parameter in the current body (that
5356-
// isn't replaced) are invalid, so replace them with a placeholder
5357-
Placeholders.insert(CurrentParams.begin(), CurrentParams.end());
5374+
// Any uses of the result parameter in the current body (that aren't
5375+
// replaced) are invalid, so replace them with a placeholder.
5376+
assert(SuccessParams.size() == 1 && SuccessParams[0] == ErrParam);
5377+
Placeholders.insert(ErrParam);
53585378
break;
53595379
default:
53605380
llvm_unreachable("Unhandled handler type");
53615381
}
53625382
}
53635383

5364-
// TODO: Check for clashes with existing names
5365-
void setNames(const ClassifiedBlock &Block,
5366-
ArrayRef<const ParamDecl *> Params, bool AddIfMissing = true) {
5384+
// TODO: Check for clashes with existing names and add all decls, not just
5385+
// params
5386+
void prepareNames(const ClassifiedBlock &Block,
5387+
ArrayRef<const ParamDecl *> Params,
5388+
bool AddIfMissing = true) {
53675389
for (auto *Param : Params) {
53685390
StringRef Name = Block.boundName(Param);
53695391
if (!Name.empty()) {
@@ -5392,7 +5414,7 @@ class AsyncConverter : private SourceEntityWalker {
53925414
return StringRef(Res->second);
53935415
}
53945416

5395-
void clearParams(ArrayRef<const ParamDecl *> Params) {
5417+
void clearNames(ArrayRef<const ParamDecl *> Params) {
53965418
for (auto *Param : Params) {
53975419
Unwraps.erase(Param);
53985420
Placeholders.erase(Param);

0 commit comments

Comments
 (0)