Skip to content

Commit ab85807

Browse files
authored
Merge pull request #41136 from louisdh/refactoring-codable
[Refactoring] Add Codable refactoring action
2 parents 98be00b + 3f3e643 commit ab85807

20 files changed

+573
-6
lines changed

include/swift/AST/Decl.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -861,6 +861,8 @@ class alignas(1 << DeclAlignInBits) Decl : public ASTAllocated<Decl> {
861861
void print(raw_ostream &OS) const;
862862
void print(raw_ostream &OS, const PrintOptions &Opts) const;
863863

864+
void printInherited(ASTPrinter &Printer, const PrintOptions &Options) const;
865+
864866
/// Pretty-print the given declaration.
865867
///
866868
/// \param Printer ASTPrinter object.

include/swift/AST/PrintOptions.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,10 @@ struct PrintOptions {
486486
/// Whether to print inheritance lists for types.
487487
bool PrintInherited = true;
488488

489+
/// Whether to print a space before the `:` of an inheritance list in a type
490+
/// decl.
491+
bool PrintSpaceBeforeInheritance = true;
492+
489493
/// Whether to print feature checks for compatibility with older Swift
490494
/// compilers that might parse the result.
491495
bool PrintCompatibilityFeatureChecks = false;

include/swift/IDE/RefactoringKinds.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ CURSOR_REFACTORING(MemberwiseInitLocalRefactoring, "Generate Memberwise Initiali
5454

5555
CURSOR_REFACTORING(AddEquatableConformance, "Add Equatable Conformance", add.equatable.conformance)
5656

57+
CURSOR_REFACTORING(AddExplicitCodableImplementation, "Add Explicit Codable Implementation", add.explicit-codable-implementation)
58+
5759
CURSOR_REFACTORING(ConvertCallToAsyncAlternative, "Convert Call to Async Alternative", convert.call-to-async)
5860

5961
CURSOR_REFACTORING(ConvertToAsync, "Convert Function to Async", convert.func-to-async)

lib/AST/ASTPrinter.cpp

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -863,6 +863,7 @@ class PrintAST : public ASTVisitor<PrintAST> {
863863
Decl *attachingTo);
864864
void printWhereClauseFromRequirementSignature(ProtocolDecl *proto,
865865
Decl *attachingTo);
866+
void printInherited(const Decl *decl);
866867

867868
void printGenericSignature(GenericSignature genericSig,
868869
unsigned flags);
@@ -891,7 +892,6 @@ class PrintAST : public ASTVisitor<PrintAST> {
891892
bool openBracket = true, bool closeBracket = true);
892893
void printGenericDeclGenericParams(GenericContext *decl);
893894
void printDeclGenericRequirements(GenericContext *decl);
894-
void printInherited(const Decl *decl);
895895
void printBodyIfNecessary(const AbstractFunctionDecl *decl);
896896

897897
void printEnumElement(EnumElementDecl *elt);
@@ -2416,7 +2416,10 @@ void PrintAST::printInherited(const Decl *decl) {
24162416
if (TypesToPrint.empty())
24172417
return;
24182418

2419-
Printer << " : ";
2419+
if (Options.PrintSpaceBeforeInheritance) {
2420+
Printer << " ";
2421+
}
2422+
Printer << ": ";
24202423

24212424
interleave(TypesToPrint, [&](InheritedEntry inherited) {
24222425
if (inherited.isUnchecked)
@@ -4195,7 +4198,12 @@ void PrintAST::visitLoadExpr(LoadExpr *expr) {
41954198
}
41964199

41974200
void PrintAST::visitTypeExpr(TypeExpr *expr) {
4198-
printType(expr->getType());
4201+
if (auto metaType = expr->getType()->castTo<AnyMetatypeType>()) {
4202+
// Don't print `.Type` for an expr.
4203+
printType(metaType->getInstanceType());
4204+
} else {
4205+
printType(expr->getType());
4206+
}
41994207
}
42004208

42014209
void PrintAST::visitArrayExpr(ArrayExpr *expr) {
@@ -4278,6 +4286,8 @@ void PrintAST::visitBinaryExpr(BinaryExpr *expr) {
42784286
Printer << " ";
42794287
if (auto operatorRef = expr->getFn()->getMemberOperatorRef()) {
42804288
Printer << operatorRef->getDecl()->getBaseName();
4289+
} else if (auto *operatorRef = dyn_cast<DeclRefExpr>(expr->getFn())) {
4290+
Printer << operatorRef->getDecl()->getBaseName();
42814291
}
42824292
Printer << " ";
42834293
visit(expr->getRHS());
@@ -4588,6 +4598,16 @@ void PrintAST::visitBraceStmt(BraceStmt *stmt) {
45884598
}
45894599

45904600
void PrintAST::visitReturnStmt(ReturnStmt *stmt) {
4601+
if (!stmt->hasResult()) {
4602+
if (auto *FD = dyn_cast<AbstractFunctionDecl>(Current)) {
4603+
if (auto *Body = FD->getBody()) {
4604+
if (Body->getLastElement().dyn_cast<Stmt *>() == stmt) {
4605+
// Don't print empty return.
4606+
return;
4607+
}
4608+
}
4609+
}
4610+
}
45914611
Printer << tok::kw_return;
45924612
if (stmt->hasResult()) {
45934613
Printer << " ";
@@ -4777,6 +4797,11 @@ bool Decl::print(ASTPrinter &Printer, const PrintOptions &Opts) const {
47774797
return printer.visit(const_cast<Decl *>(this));
47784798
}
47794799

4800+
void Decl::printInherited(ASTPrinter &Printer, const PrintOptions &Opts) const {
4801+
PrintAST printer(Printer, Opts);
4802+
printer.printInherited(this);
4803+
}
4804+
47804805
bool Decl::shouldPrintInContext(const PrintOptions &PO) const {
47814806
// Skip getters/setters. They are part of the variable or subscript.
47824807
if (isa<AccessorDecl>(this))

lib/IDE/Refactoring.cpp

Lines changed: 174 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3354,7 +3354,8 @@ class AddEquatableContext {
33543354
AddEquatableContext() : DC(nullptr), Adopter(), ProtocolsLocations(),
33553355
Protocols(), StoredProperties(), Range(nullptr, nullptr) {};
33563356

3357-
static AddEquatableContext getDeclarationContextFromInfo(ResolvedCursorInfo Info);
3357+
static AddEquatableContext
3358+
getDeclarationContextFromInfo(const ResolvedCursorInfo &Info);
33583359

33593360
std::string getInsertionTextForProtocol();
33603361

@@ -3468,7 +3469,7 @@ getProtocolRequirements() {
34683469
}
34693470

34703471
AddEquatableContext AddEquatableContext::
3471-
getDeclarationContextFromInfo(ResolvedCursorInfo Info) {
3472+
getDeclarationContextFromInfo(const ResolvedCursorInfo &Info) {
34723473
if (Info.isInvalid()) {
34733474
return AddEquatableContext();
34743475
}
@@ -3526,6 +3527,177 @@ performChange() {
35263527
return false;
35273528
}
35283529

3530+
class AddCodableContext {
3531+
3532+
/// Declaration context
3533+
DeclContext *DC;
3534+
3535+
/// Start location of declaration context brace
3536+
SourceLoc StartLoc;
3537+
3538+
/// Array of all conformed protocols
3539+
SmallVector<swift::ProtocolDecl *, 2> Protocols;
3540+
3541+
/// Range of internal members in declaration
3542+
DeclRange Range;
3543+
3544+
bool conformsToCodableProtocol() {
3545+
for (ProtocolDecl *Protocol : Protocols) {
3546+
if (Protocol->getKnownProtocolKind() == KnownProtocolKind::Encodable ||
3547+
Protocol->getKnownProtocolKind() == KnownProtocolKind::Decodable) {
3548+
return true;
3549+
}
3550+
}
3551+
return false;
3552+
}
3553+
3554+
public:
3555+
AddCodableContext(NominalTypeDecl *Decl)
3556+
: DC(Decl), StartLoc(Decl->getBraces().Start),
3557+
Protocols(Decl->getAllProtocols()), Range(Decl->getMembers()){};
3558+
3559+
AddCodableContext(ExtensionDecl *Decl)
3560+
: DC(Decl), StartLoc(Decl->getBraces().Start),
3561+
Protocols(Decl->getExtendedNominal()->getAllProtocols()),
3562+
Range(Decl->getMembers()){};
3563+
3564+
AddCodableContext() : DC(nullptr), Protocols(), Range(nullptr, nullptr){};
3565+
3566+
static AddCodableContext
3567+
getDeclarationContextFromInfo(const ResolvedCursorInfo &Info);
3568+
3569+
void printInsertionText(const ResolvedCursorInfo &CursorInfo,
3570+
SourceManager &SM, llvm::raw_ostream &OS);
3571+
3572+
bool isValid() { return StartLoc.isValid() && conformsToCodableProtocol(); }
3573+
3574+
SourceLoc getInsertStartLoc();
3575+
};
3576+
3577+
SourceLoc AddCodableContext::getInsertStartLoc() {
3578+
SourceLoc MaxLoc = StartLoc;
3579+
for (auto Mem : Range) {
3580+
if (Mem->getEndLoc().getOpaquePointerValue() >
3581+
MaxLoc.getOpaquePointerValue()) {
3582+
MaxLoc = Mem->getEndLoc();
3583+
}
3584+
}
3585+
return MaxLoc;
3586+
}
3587+
3588+
/// Walks an AST and prints the synthesized Codable implementation.
3589+
class SynthesizedCodablePrinter : public ASTWalker {
3590+
private:
3591+
ASTPrinter &Printer;
3592+
3593+
public:
3594+
SynthesizedCodablePrinter(ASTPrinter &Printer) : Printer(Printer) {}
3595+
3596+
bool walkToDeclPre(Decl *D) override {
3597+
auto *VD = dyn_cast<ValueDecl>(D);
3598+
if (!VD)
3599+
return false;
3600+
3601+
if (!VD->isSynthesized()) {
3602+
return true;
3603+
}
3604+
SmallString<32> Scratch;
3605+
auto name = VD->getName().getString(Scratch);
3606+
// Print all synthesized enums,
3607+
// since Codable can synthesize multiple enums (for associated values).
3608+
auto shouldPrint =
3609+
isa<EnumDecl>(VD) || name == "init(from:)" || name == "encode(to:)";
3610+
if (!shouldPrint) {
3611+
// Some other synthesized decl that we don't want to print.
3612+
return false;
3613+
}
3614+
3615+
Printer.printNewline();
3616+
3617+
if (auto enumDecl = dyn_cast<EnumDecl>(D)) {
3618+
// Manually print enum here, since we don't want to print synthesized
3619+
// functions.
3620+
Printer << "enum " << enumDecl->getNameStr();
3621+
PrintOptions Options;
3622+
Options.PrintSpaceBeforeInheritance = false;
3623+
enumDecl->printInherited(Printer, Options);
3624+
Printer << " {";
3625+
for (Decl *EC : enumDecl->getAllElements()) {
3626+
Printer.printNewline();
3627+
Printer << " ";
3628+
EC->print(Printer, Options);
3629+
}
3630+
Printer.printNewline();
3631+
Printer << "}";
3632+
return false;
3633+
}
3634+
3635+
PrintOptions Options;
3636+
Options.SynthesizeSugarOnTypes = true;
3637+
Options.FunctionDefinitions = true;
3638+
Options.VarInitializers = true;
3639+
Options.PrintExprs = true;
3640+
Options.TypeDefinitions = true;
3641+
Options.ExcludeAttrList.push_back(DAK_HasInitialValue);
3642+
3643+
Printer.printNewline();
3644+
D->print(Printer, Options);
3645+
3646+
return false;
3647+
}
3648+
};
3649+
3650+
void AddCodableContext::printInsertionText(const ResolvedCursorInfo &CursorInfo,
3651+
SourceManager &SM,
3652+
llvm::raw_ostream &OS) {
3653+
StringRef ExtraIndent;
3654+
StringRef CurrentIndent =
3655+
Lexer::getIndentationForLine(SM, getInsertStartLoc(), &ExtraIndent);
3656+
std::string Indent;
3657+
if (getInsertStartLoc() == StartLoc) {
3658+
Indent = (CurrentIndent + ExtraIndent).str();
3659+
} else {
3660+
Indent = CurrentIndent.str();
3661+
}
3662+
3663+
ExtraIndentStreamPrinter Printer(OS, Indent);
3664+
Printer.printNewline();
3665+
SynthesizedCodablePrinter Walker(Printer);
3666+
DC->getAsDecl()->walk(Walker);
3667+
}
3668+
3669+
AddCodableContext AddCodableContext::getDeclarationContextFromInfo(
3670+
const ResolvedCursorInfo &Info) {
3671+
if (Info.isInvalid()) {
3672+
return AddCodableContext();
3673+
}
3674+
if (!Info.IsRef) {
3675+
if (auto *NomDecl = dyn_cast<NominalTypeDecl>(Info.ValueD)) {
3676+
return AddCodableContext(NomDecl);
3677+
}
3678+
}
3679+
// TODO: support extensions
3680+
// (would need to get synthesized nodes from the main decl,
3681+
// and only if it's in the same file?)
3682+
return AddCodableContext();
3683+
}
3684+
3685+
bool RefactoringActionAddExplicitCodableImplementation::isApplicable(
3686+
const ResolvedCursorInfo &Tok, DiagnosticEngine &Diag) {
3687+
return AddCodableContext::getDeclarationContextFromInfo(Tok).isValid();
3688+
}
3689+
3690+
bool RefactoringActionAddExplicitCodableImplementation::performChange() {
3691+
auto Context = AddCodableContext::getDeclarationContextFromInfo(CursorInfo);
3692+
3693+
SmallString<64> Buffer;
3694+
llvm::raw_svector_ostream OS(Buffer);
3695+
Context.printInsertionText(CursorInfo, SM, OS);
3696+
3697+
EditConsumer.insertAfter(SM, Context.getInsertStartLoc(), OS.str());
3698+
return false;
3699+
}
3700+
35293701
static CharSourceRange
35303702
findSourceRangeToWrapInCatch(const ResolvedCursorInfo &CursorInfo,
35313703
SourceFile *TheFile,

lib/Sema/DerivedConformanceCodable.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1669,7 +1669,7 @@ deriveBodyDecodable_enum_init(AbstractFunctionDecl *initDecl, void *) {
16691669
auto *nestedContainerDecl = createKeyedContainer(
16701670
C, funcDC, C.getKeyedDecodingContainerDecl(),
16711671
caseCodingKeys->getDeclaredInterfaceType(),
1672-
VarDecl::Introducer::Var, C.Id_nestedContainer);
1672+
VarDecl::Introducer::Let, C.Id_nestedContainer);
16731673

16741674
auto *nestedContainerCall = createNestedContainerKeyedByForKeyCall(
16751675
C, funcDC, containerExpr, caseCodingKeys, codingKeyCase);

test/expr/print/callexpr.swift

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,8 @@ test(a: 0, b: 0, c: 0)
88
// CHECK: test(a: 0)
99
// CHECK: test(a: 0, b: 0)
1010
// CHECK: test(a: 0, b: 0, c: 0)
11+
12+
class Constants { static var myConst = 0 }
13+
func test(foo: Int) { }
14+
// CHECK-LABEL: test(foo: Constants.myConst)
15+
test(foo: Constants.myConst)
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
2+
struct User: Codable {
3+
let firstName: String
4+
let lastName: String?
5+
6+
enum CodingKeys: CodingKey {
7+
case firstName
8+
case lastName
9+
}
10+
11+
init(from decoder: Decoder) throws {
12+
let container: KeyedDecodingContainer<User.CodingKeys> = try decoder.container(keyedBy: User.CodingKeys.self)
13+
14+
self.firstName = try container.decode(String.self, forKey: User.CodingKeys.firstName)
15+
self.lastName = try container.decodeIfPresent(String.self, forKey: User.CodingKeys.lastName)
16+
17+
}
18+
19+
func encode(to encoder: Encoder) throws {
20+
var container = encoder.container(keyedBy: User.CodingKeys.self)
21+
22+
try container.encode(self.firstName, forKey: User.CodingKeys.firstName)
23+
try container.encodeIfPresent(self.lastName, forKey: User.CodingKeys.lastName)
24+
}
25+
}
26+
27+
struct User_D: Decodable {
28+
let firstName: String
29+
let lastName: String?
30+
}
31+
32+
struct User_E: Encodable {
33+
let firstName: String
34+
let lastName: String?
35+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
2+
struct User: Codable {
3+
let firstName: String
4+
let lastName: String?
5+
}
6+
7+
struct User_D: Decodable {
8+
let firstName: String
9+
let lastName: String?
10+
11+
enum CodingKeys: CodingKey {
12+
case firstName
13+
case lastName
14+
}
15+
16+
init(from decoder: Decoder) throws {
17+
let container: KeyedDecodingContainer<User_D.CodingKeys> = try decoder.container(keyedBy: User_D.CodingKeys.self)
18+
19+
self.firstName = try container.decode(String.self, forKey: User_D.CodingKeys.firstName)
20+
self.lastName = try container.decodeIfPresent(String.self, forKey: User_D.CodingKeys.lastName)
21+
22+
}
23+
}
24+
25+
struct User_E: Encodable {
26+
let firstName: String
27+
let lastName: String?
28+
}

0 commit comments

Comments
 (0)