Skip to content

[Refactoring] Only unwrap optionals if the handler has an optional error #37235

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 104 additions & 82 deletions lib/IDE/Refactoring.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4284,6 +4284,10 @@ struct AsyncHandlerDesc {
return getSuccessParamAsyncReturnType(Ty)->isVoid();
});
}

bool shouldUnwrap(swift::Type Ty) const {
return HasError && Ty->isOptional();
}
};

enum class ConditionType { INVALID, NIL, NOT_NIL };
Expand Down Expand Up @@ -4549,18 +4553,12 @@ struct CallbackClassifier {
static void classifyInto(ClassifiedBlocks &Blocks,
llvm::DenseSet<SwitchStmt *> &HandledSwitches,
DiagnosticEngine &DiagEngine,
ArrayRef<const ParamDecl *> SuccessParams,
llvm::DenseSet<const Decl *> UnwrapParams,
const ParamDecl *ErrParam, HandlerType ResultType,
ArrayRef<ASTNode> Body) {
assert(!Body.empty() && "Cannot classify empty body");

auto ParamsSet = llvm::DenseSet<const Decl *>(SuccessParams.begin(),
SuccessParams.end());
if (ErrParam)
ParamsSet.insert(ErrParam);

CallbackClassifier Classifier(Blocks, HandledSwitches, DiagEngine,
ParamsSet, ErrParam,
UnwrapParams, ErrParam,
ResultType == HandlerType::RESULT);
Classifier.classifyNodes(Body);
}
Expand All @@ -4570,19 +4568,19 @@ struct CallbackClassifier {
llvm::DenseSet<SwitchStmt *> &HandledSwitches;
DiagnosticEngine &DiagEngine;
ClassifiedBlock *CurrentBlock;
llvm::DenseSet<const Decl *> ParamsSet;
llvm::DenseSet<const Decl *> UnwrapParams;
const ParamDecl *ErrParam;
bool IsResultParam;

CallbackClassifier(ClassifiedBlocks &Blocks,
llvm::DenseSet<SwitchStmt *> &HandledSwitches,
DiagnosticEngine &DiagEngine,
llvm::DenseSet<const Decl *> ParamsSet,
llvm::DenseSet<const Decl *> UnwrapParams,
const ParamDecl *ErrParam, bool IsResultParam)
: Blocks(Blocks), HandledSwitches(HandledSwitches),
DiagEngine(DiagEngine), CurrentBlock(&Blocks.SuccessBlock),
ParamsSet(ParamsSet), ErrParam(ErrParam), IsResultParam(IsResultParam) {
}
UnwrapParams(UnwrapParams), ErrParam(ErrParam),
IsResultParam(IsResultParam) {}

void classifyNodes(ArrayRef<ASTNode> Nodes) {
for (auto I = Nodes.begin(), E = Nodes.end(); I < E; ++I) {
Expand Down Expand Up @@ -4614,7 +4612,7 @@ struct CallbackClassifier {
ArrayRef<ASTNode> ThenNodes, Stmt *ElseStmt) {
llvm::DenseMap<const Decl *, CallbackCondition> CallbackConditions;
bool UnhandledConditions =
!CallbackCondition::all(Condition, ParamsSet, CallbackConditions);
!CallbackCondition::all(Condition, UnwrapParams, CallbackConditions);
CallbackCondition ErrCondition = CallbackConditions.lookup(ErrParam);

if (UnhandledConditions) {
Expand Down Expand Up @@ -4942,7 +4940,7 @@ class AsyncConverter : private SourceEntityWalker {
getUnderlyingFunc(CE->getFn()), StartNode.dyn_cast<Expr *>() == CE);
if (HandlerDesc.isValid())
return addCustom(CE->getSourceRange(),
[&]() { addAsyncAlternativeCall(CE, HandlerDesc); });
[&]() { addHoistedCallback(CE, HandlerDesc); });
}
}

Expand Down Expand Up @@ -5145,8 +5143,8 @@ class AsyncConverter : private SourceEntityWalker {
}
}

void addAsyncAlternativeCall(const CallExpr *CE,
const AsyncHandlerDesc &HandlerDesc) {
void addHoistedCallback(const CallExpr *CE,
const AsyncHandlerDesc &HandlerDesc) {
auto ArgList = callArgs(CE);
if ((size_t)HandlerDesc.Index >= ArgList.ref().size()) {
DiagEngine.diagnose(CE->getStartLoc(), diag::missing_callback_arg);
Expand All @@ -5159,53 +5157,63 @@ class AsyncConverter : private SourceEntityWalker {
return;
}

ParameterList *CallbackParams = Callback->getParameters();
ArrayRef<const ParamDecl *> CallbackParams =
Callback->getParameters()->getArray();
ArrayRef<ASTNode> CallbackBody = Callback->getBody()->getElements();
if (HandlerDesc.params().size() != CallbackParams->size()) {
if (HandlerDesc.params().size() != CallbackParams.size()) {
DiagEngine.diagnose(CE->getStartLoc(), diag::mismatched_callback_args);
return;
}

// Note that the `ErrParam` may be a Result (in which case it's also the
// only element in `SuccessParams`)
ArrayRef<const ParamDecl *> SuccessParams = CallbackParams->getArray();
ArrayRef<const ParamDecl *> SuccessParams = CallbackParams;
const ParamDecl *ErrParam = nullptr;
if (HandlerDesc.HasError) {
if (HandlerDesc.Type == HandlerType::RESULT) {
ErrParam = SuccessParams.back();
if (HandlerDesc.Type == HandlerType::PARAMS)
SuccessParams = SuccessParams.drop_back();
} else if (HandlerDesc.HasError) {
assert(HandlerDesc.Type == HandlerType::PARAMS);
ErrParam = SuccessParams.back();
SuccessParams = SuccessParams.drop_back();
}
ArrayRef<const ParamDecl *> ErrParams;
if (ErrParam)
ErrParams = llvm::makeArrayRef(ErrParam);

ClassifiedBlocks Blocks;
if (!HandlerDesc.HasError) {
Blocks.SuccessBlock.addAllNodes(CallbackBody);
} else if (!CallbackBody.empty()) {
llvm::DenseSet<const Decl *> UnwrapParams;
for (auto *Param : SuccessParams) {
if (HandlerDesc.shouldUnwrap(Param->getType()))
UnwrapParams.insert(Param);
}
if (ErrParam)
UnwrapParams.insert(ErrParam);
CallbackClassifier::classifyInto(Blocks, HandledSwitches, DiagEngine,
SuccessParams, ErrParam,
HandlerDesc.Type, CallbackBody);
if (DiagEngine.hadAnyError()) {
// Can only fallback when the results are params, in which case only
// the names are used (defaulted to the names of the params if none)
if (HandlerDesc.Type != HandlerType::PARAMS)
return;
DiagEngine.resetHadAnyError();
UnwrapParams, ErrParam, HandlerDesc.Type,
CallbackBody);
}

setNames(ClassifiedBlock(), CallbackParams->getArray());
if (DiagEngine.hadAnyError()) {
// Can only fallback when the results are params, in which case only
// the names are used (defaulted to the names of the params if none)
if (HandlerDesc.Type != HandlerType::PARAMS)
return;
DiagEngine.resetHadAnyError();

addFallbackVars(CallbackParams->getArray(), Blocks);
addDo();
addAwaitCall(CE, ArgList.ref(), Blocks.SuccessBlock, SuccessParams,
HandlerDesc, /*AddDeclarations=*/!HandlerDesc.HasError);
addFallbackCatch(ErrParam);
OS << "\n";
convertNodes(CallbackBody);
// Don't do any unwrapping or placeholder replacement since all params
// are still valid in the fallback case
prepareNames(ClassifiedBlock(), CallbackParams);

clearParams(CallbackParams->getArray());
return;
}
addFallbackVars(CallbackParams, Blocks);
addDo();
addAwaitCall(CE, ArgList.ref(), Blocks.SuccessBlock, SuccessParams,
HandlerDesc, /*AddDeclarations=*/!HandlerDesc.HasError);
addFallbackCatch(ErrParam);
OS << "\n";
convertNodes(CallbackBody);

clearNames(CallbackParams);
return;
}

bool RequireDo = !Blocks.ErrorBlock.nodes().empty();
Expand All @@ -5229,25 +5237,27 @@ class AsyncConverter : private SourceEntityWalker {
addDo();
}

setNames(Blocks.SuccessBlock, SuccessParams);
prepareNames(Blocks.SuccessBlock, SuccessParams);
preparePlaceholdersAndUnwraps(HandlerDesc, SuccessParams, ErrParam,
/*Success=*/true);

addAwaitCall(CE, ArgList.ref(), Blocks.SuccessBlock, SuccessParams,
HandlerDesc, /*AddDeclarations=*/true);

prepareNamesForBody(HandlerDesc, SuccessParams, ErrParams);
convertNodes(Blocks.SuccessBlock.nodes());
clearNames(SuccessParams);

if (RequireDo) {
clearParams(SuccessParams);
// Always use the ErrParam name if none is bound
setNames(Blocks.ErrorBlock, ErrParams,
HandlerDesc.Type != HandlerType::RESULT);
addCatch(ErrParam);
prepareNames(Blocks.ErrorBlock, llvm::makeArrayRef(ErrParam),
HandlerDesc.Type != HandlerType::RESULT);
preparePlaceholdersAndUnwraps(HandlerDesc, SuccessParams, ErrParam,
/*Success=*/false);

prepareNamesForBody(HandlerDesc, ErrParams, SuccessParams);
addCatchBody(ErrParam, Blocks.ErrorBlock);
addCatch(ErrParam);
convertNodes(Blocks.ErrorBlock.nodes());
OS << "\n" << tok::r_brace;
clearNames(llvm::makeArrayRef(ErrParam));
}

clearParams(CallbackParams->getArray());
}

void addAwaitCall(const CallExpr *CE, ArrayRef<Expr *> Args,
Expand Down Expand Up @@ -5318,44 +5328,56 @@ class AsyncConverter : private SourceEntityWalker {
OS << tok::l_brace;
}

void addCatchBody(const ParamDecl *ErrParam,
const ClassifiedBlock &ErrorBlock) {
convertNodes(ErrorBlock.nodes());
OS << "\n" << tok::r_brace;
}

void prepareNamesForBody(const AsyncHandlerDesc &HandlerDesc,
ArrayRef<const ParamDecl *> CurrentParams,
ArrayRef<const ParamDecl *> OtherParams) {
void preparePlaceholdersAndUnwraps(AsyncHandlerDesc HandlerDesc,
ArrayRef<const ParamDecl *> SuccessParams,
const ParamDecl *ErrParam, bool Success) {
switch (HandlerDesc.Type) {
case HandlerType::PARAMS:
for (auto *Param : CurrentParams) {
auto Ty = Param->getType();
if (Ty->getOptionalObjectType()) {
Unwraps.insert(Param);
Placeholders.insert(Param);
if (!Success) {
if (ErrParam) {
if (HandlerDesc.shouldUnwrap(ErrParam->getType())) {
Placeholders.insert(ErrParam);
Unwraps.insert(ErrParam);
}
// Can't use success params in the error body
Placeholders.insert(SuccessParams.begin(), SuccessParams.end());
}
} else {
for (auto *SuccessParam : SuccessParams) {
auto Ty = SuccessParam->getType();
if (HandlerDesc.shouldUnwrap(Ty)) {
// Either unwrap or replace with a placeholder if there's some other
// reference
Unwraps.insert(SuccessParam);
Placeholders.insert(SuccessParam);
}

// Void parameters get omitted where possible, so turn any reference
// into a placeholder, as its usage is unlikely what the user wants.
if (HandlerDesc.getSuccessParamAsyncReturnType(Ty)->isVoid())
Placeholders.insert(SuccessParam);
}
// Void parameters get omitted where possible, so turn any reference
// into a placeholder, as its usage is unlikely what the user wants.
if (HandlerDesc.getSuccessParamAsyncReturnType(Ty)->isVoid())
Placeholders.insert(Param);
// Can't use the error param in the success body
if (ErrParam)
Placeholders.insert(ErrParam);
}
// Use of the other params is invalid within the current body
Placeholders.insert(OtherParams.begin(), OtherParams.end());
break;
case HandlerType::RESULT:
// Any uses of the result parameter in the current body (that
// isn't replaced) are invalid, so replace them with a placeholder
Placeholders.insert(CurrentParams.begin(), CurrentParams.end());
// Any uses of the result parameter in the current body (that aren't
// replaced) are invalid, so replace them with a placeholder.
assert(SuccessParams.size() == 1 && SuccessParams[0] == ErrParam);
Placeholders.insert(ErrParam);
break;
default:
llvm_unreachable("Unhandled handler type");
}
}

// TODO: Check for clashes with existing names
void setNames(const ClassifiedBlock &Block,
ArrayRef<const ParamDecl *> Params, bool AddIfMissing = true) {
// TODO: Check for clashes with existing names and add all decls, not just
// params
void prepareNames(const ClassifiedBlock &Block,
ArrayRef<const ParamDecl *> Params,
bool AddIfMissing = true) {
for (auto *Param : Params) {
StringRef Name = Block.boundName(Param);
if (!Name.empty()) {
Expand Down Expand Up @@ -5384,7 +5406,7 @@ class AsyncConverter : private SourceEntityWalker {
return StringRef(Res->second);
}

void clearParams(ArrayRef<const ParamDecl *> Params) {
void clearNames(ArrayRef<const ParamDecl *> Params) {
for (auto *Param : Params) {
Unwraps.erase(Param);
Placeholders.erase(Param);
Expand Down
Loading