Skip to content

Commit 398124c

Browse files
committed
[Refactoring] Only unwrap optionals if the handler has an optional error
Resolves rdar://73973459
1 parent 058613d commit 398124c

File tree

4 files changed

+249
-106
lines changed

4 files changed

+249
-106
lines changed

lib/IDE/Refactoring.cpp

Lines changed: 104 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -4284,6 +4284,10 @@ struct AsyncHandlerDesc {
42844284
return getSuccessParamAsyncReturnType(Ty)->isVoid();
42854285
});
42864286
}
4287+
4288+
bool shouldUnwrap(swift::Type Ty) const {
4289+
return HasError && Ty->isOptional();
4290+
}
42874291
};
42884292

42894293
enum class ConditionType { INVALID, NIL, NOT_NIL };
@@ -4549,18 +4553,12 @@ struct CallbackClassifier {
45494553
static void classifyInto(ClassifiedBlocks &Blocks,
45504554
llvm::DenseSet<SwitchStmt *> &HandledSwitches,
45514555
DiagnosticEngine &DiagEngine,
4552-
ArrayRef<const ParamDecl *> SuccessParams,
4556+
llvm::DenseSet<const Decl *> UnwrapParams,
45534557
const ParamDecl *ErrParam, HandlerType ResultType,
45544558
ArrayRef<ASTNode> Body) {
45554559
assert(!Body.empty() && "Cannot classify empty body");
4556-
4557-
auto ParamsSet = llvm::DenseSet<const Decl *>(SuccessParams.begin(),
4558-
SuccessParams.end());
4559-
if (ErrParam)
4560-
ParamsSet.insert(ErrParam);
4561-
45624560
CallbackClassifier Classifier(Blocks, HandledSwitches, DiagEngine,
4563-
ParamsSet, ErrParam,
4561+
UnwrapParams, ErrParam,
45644562
ResultType == HandlerType::RESULT);
45654563
Classifier.classifyNodes(Body);
45664564
}
@@ -4570,19 +4568,19 @@ struct CallbackClassifier {
45704568
llvm::DenseSet<SwitchStmt *> &HandledSwitches;
45714569
DiagnosticEngine &DiagEngine;
45724570
ClassifiedBlock *CurrentBlock;
4573-
llvm::DenseSet<const Decl *> ParamsSet;
4571+
llvm::DenseSet<const Decl *> UnwrapParams;
45744572
const ParamDecl *ErrParam;
45754573
bool IsResultParam;
45764574

45774575
CallbackClassifier(ClassifiedBlocks &Blocks,
45784576
llvm::DenseSet<SwitchStmt *> &HandledSwitches,
45794577
DiagnosticEngine &DiagEngine,
4580-
llvm::DenseSet<const Decl *> ParamsSet,
4578+
llvm::DenseSet<const Decl *> UnwrapParams,
45814579
const ParamDecl *ErrParam, bool IsResultParam)
45824580
: Blocks(Blocks), HandledSwitches(HandledSwitches),
45834581
DiagEngine(DiagEngine), CurrentBlock(&Blocks.SuccessBlock),
4584-
ParamsSet(ParamsSet), ErrParam(ErrParam), IsResultParam(IsResultParam) {
4585-
}
4582+
UnwrapParams(UnwrapParams), ErrParam(ErrParam),
4583+
IsResultParam(IsResultParam) {}
45864584

45874585
void classifyNodes(ArrayRef<ASTNode> Nodes) {
45884586
for (auto I = Nodes.begin(), E = Nodes.end(); I < E; ++I) {
@@ -4614,7 +4612,7 @@ struct CallbackClassifier {
46144612
ArrayRef<ASTNode> ThenNodes, Stmt *ElseStmt) {
46154613
llvm::DenseMap<const Decl *, CallbackCondition> CallbackConditions;
46164614
bool UnhandledConditions =
4617-
!CallbackCondition::all(Condition, ParamsSet, CallbackConditions);
4615+
!CallbackCondition::all(Condition, UnwrapParams, CallbackConditions);
46184616
CallbackCondition ErrCondition = CallbackConditions.lookup(ErrParam);
46194617

46204618
if (UnhandledConditions) {
@@ -4942,7 +4940,7 @@ class AsyncConverter : private SourceEntityWalker {
49424940
getUnderlyingFunc(CE->getFn()), StartNode.dyn_cast<Expr *>() == CE);
49434941
if (HandlerDesc.isValid())
49444942
return addCustom(CE->getSourceRange(),
4945-
[&]() { addAsyncAlternativeCall(CE, HandlerDesc); });
4943+
[&]() { addHoistedCallback(CE, HandlerDesc); });
49464944
}
49474945
}
49484946

@@ -5145,8 +5143,8 @@ class AsyncConverter : private SourceEntityWalker {
51455143
}
51465144
}
51475145

5148-
void addAsyncAlternativeCall(const CallExpr *CE,
5149-
const AsyncHandlerDesc &HandlerDesc) {
5146+
void addHoistedCallback(const CallExpr *CE,
5147+
const AsyncHandlerDesc &HandlerDesc) {
51505148
auto ArgList = callArgs(CE);
51515149
if ((size_t)HandlerDesc.Index >= ArgList.ref().size()) {
51525150
DiagEngine.diagnose(CE->getStartLoc(), diag::missing_callback_arg);
@@ -5159,53 +5157,63 @@ class AsyncConverter : private SourceEntityWalker {
51595157
return;
51605158
}
51615159

5162-
ParameterList *CallbackParams = Callback->getParameters();
5160+
ArrayRef<const ParamDecl *> CallbackParams =
5161+
Callback->getParameters()->getArray();
51635162
ArrayRef<ASTNode> CallbackBody = Callback->getBody()->getElements();
5164-
if (HandlerDesc.params().size() != CallbackParams->size()) {
5163+
if (HandlerDesc.params().size() != CallbackParams.size()) {
51655164
DiagEngine.diagnose(CE->getStartLoc(), diag::mismatched_callback_args);
51665165
return;
51675166
}
51685167

51695168
// Note that the `ErrParam` may be a Result (in which case it's also the
51705169
// only element in `SuccessParams`)
5171-
ArrayRef<const ParamDecl *> SuccessParams = CallbackParams->getArray();
5170+
ArrayRef<const ParamDecl *> SuccessParams = CallbackParams;
51725171
const ParamDecl *ErrParam = nullptr;
5173-
if (HandlerDesc.HasError) {
5172+
if (HandlerDesc.Type == HandlerType::RESULT) {
51745173
ErrParam = SuccessParams.back();
5175-
if (HandlerDesc.Type == HandlerType::PARAMS)
5176-
SuccessParams = SuccessParams.drop_back();
5174+
} else if (HandlerDesc.HasError) {
5175+
assert(HandlerDesc.Type == HandlerType::PARAMS);
5176+
ErrParam = SuccessParams.back();
5177+
SuccessParams = SuccessParams.drop_back();
51775178
}
5178-
ArrayRef<const ParamDecl *> ErrParams;
5179-
if (ErrParam)
5180-
ErrParams = llvm::makeArrayRef(ErrParam);
51815179

51825180
ClassifiedBlocks Blocks;
51835181
if (!HandlerDesc.HasError) {
51845182
Blocks.SuccessBlock.addAllNodes(CallbackBody);
51855183
} else if (!CallbackBody.empty()) {
5184+
llvm::DenseSet<const Decl *> UnwrapParams;
5185+
for (auto *Param : SuccessParams) {
5186+
if (HandlerDesc.shouldUnwrap(Param->getType()))
5187+
UnwrapParams.insert(Param);
5188+
}
5189+
if (ErrParam)
5190+
UnwrapParams.insert(ErrParam);
51865191
CallbackClassifier::classifyInto(Blocks, HandledSwitches, DiagEngine,
5187-
SuccessParams, ErrParam,
5188-
HandlerDesc.Type, CallbackBody);
5189-
if (DiagEngine.hadAnyError()) {
5190-
// Can only fallback when the results are params, in which case only
5191-
// the names are used (defaulted to the names of the params if none)
5192-
if (HandlerDesc.Type != HandlerType::PARAMS)
5193-
return;
5194-
DiagEngine.resetHadAnyError();
5192+
UnwrapParams, ErrParam, HandlerDesc.Type,
5193+
CallbackBody);
5194+
}
51955195

5196-
setNames(ClassifiedBlock(), CallbackParams->getArray());
5196+
if (DiagEngine.hadAnyError()) {
5197+
// Can only fallback when the results are params, in which case only
5198+
// the names are used (defaulted to the names of the params if none)
5199+
if (HandlerDesc.Type != HandlerType::PARAMS)
5200+
return;
5201+
DiagEngine.resetHadAnyError();
51975202

5198-
addFallbackVars(CallbackParams->getArray(), Blocks);
5199-
addDo();
5200-
addAwaitCall(CE, ArgList.ref(), Blocks.SuccessBlock, SuccessParams,
5201-
HandlerDesc, /*AddDeclarations=*/!HandlerDesc.HasError);
5202-
addFallbackCatch(ErrParam);
5203-
OS << "\n";
5204-
convertNodes(CallbackBody);
5203+
// Don't do any unwrapping or placeholder replacement since all params
5204+
// are still valid in the fallback case
5205+
prepareNames(ClassifiedBlock(), CallbackParams);
52055206

5206-
clearParams(CallbackParams->getArray());
5207-
return;
5208-
}
5207+
addFallbackVars(CallbackParams, Blocks);
5208+
addDo();
5209+
addAwaitCall(CE, ArgList.ref(), Blocks.SuccessBlock, SuccessParams,
5210+
HandlerDesc, /*AddDeclarations=*/!HandlerDesc.HasError);
5211+
addFallbackCatch(ErrParam);
5212+
OS << "\n";
5213+
convertNodes(CallbackBody);
5214+
5215+
clearNames(CallbackParams);
5216+
return;
52095217
}
52105218

52115219
bool RequireDo = !Blocks.ErrorBlock.nodes().empty();
@@ -5229,25 +5237,27 @@ class AsyncConverter : private SourceEntityWalker {
52295237
addDo();
52305238
}
52315239

5232-
setNames(Blocks.SuccessBlock, SuccessParams);
5240+
prepareNames(Blocks.SuccessBlock, SuccessParams);
5241+
preparePlaceholdersAndUnwraps(HandlerDesc, SuccessParams, ErrParam,
5242+
/*Success=*/true);
5243+
52335244
addAwaitCall(CE, ArgList.ref(), Blocks.SuccessBlock, SuccessParams,
52345245
HandlerDesc, /*AddDeclarations=*/true);
5235-
5236-
prepareNamesForBody(HandlerDesc, SuccessParams, ErrParams);
52375246
convertNodes(Blocks.SuccessBlock.nodes());
5247+
clearNames(SuccessParams);
52385248

52395249
if (RequireDo) {
5240-
clearParams(SuccessParams);
52415250
// Always use the ErrParam name if none is bound
5242-
setNames(Blocks.ErrorBlock, ErrParams,
5243-
HandlerDesc.Type != HandlerType::RESULT);
5244-
addCatch(ErrParam);
5251+
prepareNames(Blocks.ErrorBlock, llvm::makeArrayRef(ErrParam),
5252+
HandlerDesc.Type != HandlerType::RESULT);
5253+
preparePlaceholdersAndUnwraps(HandlerDesc, SuccessParams, ErrParam,
5254+
/*Success=*/false);
52455255

5246-
prepareNamesForBody(HandlerDesc, ErrParams, SuccessParams);
5247-
addCatchBody(ErrParam, Blocks.ErrorBlock);
5256+
addCatch(ErrParam);
5257+
convertNodes(Blocks.ErrorBlock.nodes());
5258+
OS << "\n" << tok::r_brace;
5259+
clearNames(llvm::makeArrayRef(ErrParam));
52485260
}
5249-
5250-
clearParams(CallbackParams->getArray());
52515261
}
52525262

52535263
void addAwaitCall(const CallExpr *CE, ArrayRef<Expr *> Args,
@@ -5318,44 +5328,56 @@ class AsyncConverter : private SourceEntityWalker {
53185328
OS << tok::l_brace;
53195329
}
53205330

5321-
void addCatchBody(const ParamDecl *ErrParam,
5322-
const ClassifiedBlock &ErrorBlock) {
5323-
convertNodes(ErrorBlock.nodes());
5324-
OS << "\n" << tok::r_brace;
5325-
}
5326-
5327-
void prepareNamesForBody(const AsyncHandlerDesc &HandlerDesc,
5328-
ArrayRef<const ParamDecl *> CurrentParams,
5329-
ArrayRef<const ParamDecl *> OtherParams) {
5331+
void preparePlaceholdersAndUnwraps(AsyncHandlerDesc HandlerDesc,
5332+
ArrayRef<const ParamDecl *> SuccessParams,
5333+
const ParamDecl *ErrParam, bool Success) {
53305334
switch (HandlerDesc.Type) {
53315335
case HandlerType::PARAMS:
5332-
for (auto *Param : CurrentParams) {
5333-
auto Ty = Param->getType();
5334-
if (Ty->getOptionalObjectType()) {
5335-
Unwraps.insert(Param);
5336-
Placeholders.insert(Param);
5336+
if (!Success) {
5337+
if (ErrParam) {
5338+
if (HandlerDesc.shouldUnwrap(ErrParam->getType())) {
5339+
Placeholders.insert(ErrParam);
5340+
Unwraps.insert(ErrParam);
5341+
}
5342+
// Can't use success params in the error body
5343+
Placeholders.insert(SuccessParams.begin(), SuccessParams.end());
5344+
}
5345+
} else {
5346+
for (auto *SuccessParam : SuccessParams) {
5347+
auto Ty = SuccessParam->getType();
5348+
if (HandlerDesc.shouldUnwrap(Ty)) {
5349+
// Either unwrap or replace with a placeholder if there's some other
5350+
// reference
5351+
Unwraps.insert(SuccessParam);
5352+
Placeholders.insert(SuccessParam);
5353+
}
5354+
5355+
// Void parameters get omitted where possible, so turn any reference
5356+
// into a placeholder, as its usage is unlikely what the user wants.
5357+
if (HandlerDesc.getSuccessParamAsyncReturnType(Ty)->isVoid())
5358+
Placeholders.insert(SuccessParam);
53375359
}
5338-
// Void parameters get omitted where possible, so turn any reference
5339-
// into a placeholder, as its usage is unlikely what the user wants.
5340-
if (HandlerDesc.getSuccessParamAsyncReturnType(Ty)->isVoid())
5341-
Placeholders.insert(Param);
5360+
// Can't use the error param in the success body
5361+
if (ErrParam)
5362+
Placeholders.insert(ErrParam);
53425363
}
5343-
// Use of the other params is invalid within the current body
5344-
Placeholders.insert(OtherParams.begin(), OtherParams.end());
53455364
break;
53465365
case HandlerType::RESULT:
5347-
// Any uses of the result parameter in the current body (that
5348-
// isn't replaced) are invalid, so replace them with a placeholder
5349-
Placeholders.insert(CurrentParams.begin(), CurrentParams.end());
5366+
// Any uses of the result parameter in the current body (that aren't
5367+
// replaced) are invalid, so replace them with a placeholder.
5368+
assert(SuccessParams.size() == 1 && SuccessParams[0] == ErrParam);
5369+
Placeholders.insert(ErrParam);
53505370
break;
53515371
default:
53525372
llvm_unreachable("Unhandled handler type");
53535373
}
53545374
}
53555375

5356-
// TODO: Check for clashes with existing names
5357-
void setNames(const ClassifiedBlock &Block,
5358-
ArrayRef<const ParamDecl *> Params, bool AddIfMissing = true) {
5376+
// TODO: Check for clashes with existing names and add all decls, not just
5377+
// params
5378+
void prepareNames(const ClassifiedBlock &Block,
5379+
ArrayRef<const ParamDecl *> Params,
5380+
bool AddIfMissing = true) {
53595381
for (auto *Param : Params) {
53605382
StringRef Name = Block.boundName(Param);
53615383
if (!Name.empty()) {
@@ -5384,7 +5406,7 @@ class AsyncConverter : private SourceEntityWalker {
53845406
return StringRef(Res->second);
53855407
}
53865408

5387-
void clearParams(ArrayRef<const ParamDecl *> Params) {
5409+
void clearNames(ArrayRef<const ParamDecl *> Params) {
53885410
for (auto *Param : Params) {
53895411
Unwraps.erase(Param);
53905412
Placeholders.erase(Param);

0 commit comments

Comments
 (0)