Skip to content

Commit 5d36507

Browse files
committed
[Refactoring] Add Codable refactoring action
Inserts the synthesized implementation. As part of this, fix some ASTPrinter bugs. rdar://87904700
1 parent 7ebdb8e commit 5d36507

19 files changed

+558
-3
lines changed

include/swift/AST/Decl.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -858,6 +858,8 @@ class alignas(1 << DeclAlignInBits) Decl : public ASTAllocated<Decl> {
858858
void print(raw_ostream &OS) const;
859859
void print(raw_ostream &OS, const PrintOptions &Opts) const;
860860

861+
void printInherited(ASTPrinter &Printer, const PrintOptions &Options) const;
862+
861863
/// Pretty-print the given declaration.
862864
///
863865
/// \param Printer ASTPrinter object.

include/swift/AST/PrintOptions.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,10 @@ struct PrintOptions {
479479
/// Whether to print inheritance lists for types.
480480
bool PrintInherited = true;
481481

482+
/// Whether to print a space before the `:` of an inheritance list in a type
483+
/// decl.
484+
bool PrintSpaceBeforeInheritance = true;
485+
482486
/// Whether to print feature checks for compatibility with older Swift
483487
/// compilers that might parse the result.
484488
bool PrintCompatibilityFeatureChecks = false;

include/swift/IDE/RefactoringKinds.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ CURSOR_REFACTORING(MemberwiseInitLocalRefactoring, "Generate Memberwise Initiali
5454

5555
CURSOR_REFACTORING(AddEquatableConformance, "Add Equatable Conformance", add.equatable.conformance)
5656

57+
CURSOR_REFACTORING(AddExplicitCodableImplementation, "Add Explicit Codable Implementation", add.explicit-codable-implementation)
58+
5759
CURSOR_REFACTORING(ConvertCallToAsyncAlternative, "Convert Call to Async Alternative", convert.call-to-async)
5860

5961
CURSOR_REFACTORING(ConvertToAsync, "Convert Function to Async", convert.func-to-async)

lib/AST/ASTPrinter.cpp

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -861,6 +861,7 @@ class PrintAST : public ASTVisitor<PrintAST> {
861861
Decl *attachingTo);
862862
void printWhereClauseFromRequirementSignature(ProtocolDecl *proto,
863863
Decl *attachingTo);
864+
void printInherited(const Decl *decl);
864865

865866
void printGenericSignature(GenericSignature genericSig,
866867
unsigned flags);
@@ -889,7 +890,6 @@ class PrintAST : public ASTVisitor<PrintAST> {
889890
bool openBracket = true, bool closeBracket = true);
890891
void printGenericDeclGenericParams(GenericContext *decl);
891892
void printDeclGenericRequirements(GenericContext *decl);
892-
void printInherited(const Decl *decl);
893893
void printBodyIfNecessary(const AbstractFunctionDecl *decl);
894894

895895
void printEnumElement(EnumElementDecl *elt);
@@ -2372,7 +2372,10 @@ void PrintAST::printInherited(const Decl *decl) {
23722372
if (TypesToPrint.empty())
23732373
return;
23742374

2375-
Printer << " : ";
2375+
if (Options.PrintSpaceBeforeInheritance) {
2376+
Printer << " ";
2377+
}
2378+
Printer << ": ";
23762379

23772380
interleave(TypesToPrint, [&](InheritedEntry inherited) {
23782381
if (inherited.isUnchecked)
@@ -4551,6 +4554,17 @@ void PrintAST::visitBraceStmt(BraceStmt *stmt) {
45514554
}
45524555

45534556
void PrintAST::visitReturnStmt(ReturnStmt *stmt) {
4557+
if (!stmt->hasResult()) {
4558+
if (auto *FD = dyn_cast<AbstractFunctionDecl>(Current)) {
4559+
if (FD->hasBody()) {
4560+
auto *Body = FD->getBody();
4561+
if (Body->getLastElement().dyn_cast<Stmt *>() == stmt) {
4562+
// Don't print empty return.
4563+
return;
4564+
}
4565+
}
4566+
}
4567+
}
45544568
Printer << tok::kw_return;
45554569
if (stmt->hasResult()) {
45564570
Printer << " ";
@@ -4740,6 +4754,11 @@ bool Decl::print(ASTPrinter &Printer, const PrintOptions &Opts) const {
47404754
return printer.visit(const_cast<Decl *>(this));
47414755
}
47424756

4757+
void Decl::printInherited(ASTPrinter &Printer, const PrintOptions &Opts) const {
4758+
PrintAST printer(Printer, Opts);
4759+
printer.printInherited(this);
4760+
}
4761+
47434762
bool Decl::shouldPrintInContext(const PrintOptions &PO) const {
47444763
// Skip getters/setters. They are part of the variable or subscript.
47454764
if (isa<AccessorDecl>(this))

lib/IDE/Refactoring.cpp

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3526,6 +3526,177 @@ performChange() {
35263526
return false;
35273527
}
35283528

3529+
class AddCodableContext {
3530+
3531+
/// Declaration context
3532+
DeclContext *DC;
3533+
3534+
/// Start location of declaration context brace
3535+
SourceLoc StartLoc;
3536+
3537+
/// Array of all conformed protocols
3538+
SmallVector<swift::ProtocolDecl *, 2> Protocols;
3539+
3540+
/// Range of internal members in declaration
3541+
DeclRange Range;
3542+
3543+
bool conformsToCodableProtocol() {
3544+
for (ProtocolDecl *Protocol : Protocols) {
3545+
if (Protocol->getKnownProtocolKind() == KnownProtocolKind::Encodable ||
3546+
Protocol->getKnownProtocolKind() == KnownProtocolKind::Decodable) {
3547+
return true;
3548+
}
3549+
}
3550+
return false;
3551+
}
3552+
3553+
public:
3554+
AddCodableContext(NominalTypeDecl *Decl)
3555+
: DC(Decl), StartLoc(Decl->getBraces().Start),
3556+
Protocols(Decl->getAllProtocols()), Range(Decl->getMembers()){};
3557+
3558+
AddCodableContext(ExtensionDecl *Decl)
3559+
: DC(Decl), StartLoc(Decl->getBraces().Start),
3560+
Protocols(Decl->getExtendedNominal()->getAllProtocols()),
3561+
Range(Decl->getMembers()){};
3562+
3563+
AddCodableContext() : DC(nullptr), Protocols(), Range(nullptr, nullptr){};
3564+
3565+
static AddCodableContext
3566+
getDeclarationContextFromInfo(const ResolvedCursorInfo &Info);
3567+
3568+
void printInsertionText(const ResolvedCursorInfo &CursorInfo,
3569+
SourceManager &SM, llvm::raw_ostream &OS);
3570+
3571+
bool isValid() { return StartLoc.isValid() && conformsToCodableProtocol(); }
3572+
3573+
SourceLoc getInsertStartLoc();
3574+
};
3575+
3576+
SourceLoc AddCodableContext::getInsertStartLoc() {
3577+
SourceLoc MaxLoc = StartLoc;
3578+
for (auto Mem : Range) {
3579+
if (Mem->getEndLoc().getOpaquePointerValue() >
3580+
MaxLoc.getOpaquePointerValue()) {
3581+
MaxLoc = Mem->getEndLoc();
3582+
}
3583+
}
3584+
return MaxLoc;
3585+
}
3586+
3587+
/// Walks an AST and prints the synthesized Codable implementation.
3588+
class SynthesizedCodablePrinter : public ASTWalker {
3589+
private:
3590+
ASTPrinter &Printer;
3591+
3592+
public:
3593+
SynthesizedCodablePrinter(ASTPrinter &Printer) : Printer(Printer) {}
3594+
3595+
bool walkToDeclPre(Decl *D) override {
3596+
auto *VD = dyn_cast<ValueDecl>(D);
3597+
if (!VD)
3598+
return false;
3599+
3600+
if (!VD->isSynthesized()) {
3601+
return true;
3602+
}
3603+
SmallString<32> Scratch;
3604+
auto name = VD->getName().getString(Scratch);
3605+
// Print all synthesized enums,
3606+
// since Codable can synthesize multiple enums (for associated values).
3607+
auto shouldPrint =
3608+
isa<EnumDecl>(VD) || name == "init(from:)" || name == "encode(to:)";
3609+
if (!shouldPrint) {
3610+
// Some other synthesized decl that we don't want to print.
3611+
return false;
3612+
}
3613+
3614+
Printer.printNewline();
3615+
3616+
if (auto enumDecl = dyn_cast<EnumDecl>(D)) {
3617+
// Manually print enum here, since we don't want to print synthesized
3618+
// functions.
3619+
Printer << "enum " << enumDecl->getNameStr();
3620+
PrintOptions Options;
3621+
Options.PrintSpaceBeforeInheritance = false;
3622+
enumDecl->printInherited(Printer, Options);
3623+
Printer << " {";
3624+
for (Decl *EC : enumDecl->getAllElements()) {
3625+
Printer.printNewline();
3626+
Printer << " ";
3627+
EC->print(Printer, Options);
3628+
}
3629+
Printer.printNewline();
3630+
Printer << "}";
3631+
return false;
3632+
}
3633+
3634+
PrintOptions Options;
3635+
Options.SynthesizeSugarOnTypes = true;
3636+
Options.FunctionDefinitions = true;
3637+
Options.VarInitializers = true;
3638+
Options.PrintExprs = true;
3639+
Options.TypeDefinitions = true;
3640+
Options.ExcludeAttrList.push_back(DAK_HasInitialValue);
3641+
3642+
Printer.printNewline();
3643+
D->print(Printer, Options);
3644+
3645+
return false;
3646+
}
3647+
};
3648+
3649+
void AddCodableContext::printInsertionText(const ResolvedCursorInfo &CursorInfo,
3650+
SourceManager &SM,
3651+
llvm::raw_ostream &OS) {
3652+
StringRef ExtraIndent;
3653+
StringRef CurrentIndent =
3654+
Lexer::getIndentationForLine(SM, getInsertStartLoc(), &ExtraIndent);
3655+
std::string Indent;
3656+
if (getInsertStartLoc() == StartLoc) {
3657+
Indent = (CurrentIndent + ExtraIndent).str();
3658+
} else {
3659+
Indent = CurrentIndent.str();
3660+
}
3661+
3662+
ExtraIndentStreamPrinter Printer(OS, Indent);
3663+
Printer.printNewline();
3664+
SynthesizedCodablePrinter Walker(Printer);
3665+
DC->getAsDecl()->walk(Walker);
3666+
}
3667+
3668+
AddCodableContext AddCodableContext::getDeclarationContextFromInfo(
3669+
const ResolvedCursorInfo &Info) {
3670+
if (Info.isInvalid()) {
3671+
return AddCodableContext();
3672+
}
3673+
if (!Info.IsRef) {
3674+
if (auto *NomDecl = dyn_cast<NominalTypeDecl>(Info.ValueD)) {
3675+
return AddCodableContext(NomDecl);
3676+
}
3677+
}
3678+
// TODO: support extensions
3679+
// (would need to get synthesized nodes from the main decl,
3680+
// and only if it's in the same file?)
3681+
return AddCodableContext();
3682+
}
3683+
3684+
bool RefactoringActionAddExplicitCodableImplementation::isApplicable(
3685+
const ResolvedCursorInfo &Tok, DiagnosticEngine &Diag) {
3686+
return AddCodableContext::getDeclarationContextFromInfo(Tok).isValid();
3687+
}
3688+
3689+
bool RefactoringActionAddExplicitCodableImplementation::performChange() {
3690+
auto Context = AddCodableContext::getDeclarationContextFromInfo(CursorInfo);
3691+
3692+
SmallString<64> Buffer;
3693+
llvm::raw_svector_ostream OS(Buffer);
3694+
Context.printInsertionText(CursorInfo, SM, OS);
3695+
3696+
EditConsumer.insertAfter(SM, Context.getInsertStartLoc(), OS.str());
3697+
return false;
3698+
}
3699+
35293700
static CharSourceRange
35303701
findSourceRangeToWrapInCatch(const ResolvedCursorInfo &CursorInfo,
35313702
SourceFile *TheFile,

lib/Sema/DerivedConformanceCodable.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1669,7 +1669,7 @@ deriveBodyDecodable_enum_init(AbstractFunctionDecl *initDecl, void *) {
16691669
auto *nestedContainerDecl = createKeyedContainer(
16701670
C, funcDC, C.getKeyedDecodingContainerDecl(),
16711671
caseCodingKeys->getDeclaredInterfaceType(),
1672-
VarDecl::Introducer::Var, C.Id_nestedContainer);
1672+
VarDecl::Introducer::Let, C.Id_nestedContainer);
16731673

16741674
auto *nestedContainerCall = createNestedContainerKeyedByForKeyCall(
16751675
C, funcDC, containerExpr, caseCodingKeys, codingKeyCase);
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
2+
struct User: Codable {
3+
let firstName: String
4+
let lastName: String?
5+
6+
enum CodingKeys: CodingKey {
7+
case firstName
8+
case lastName
9+
}
10+
11+
init(from decoder: Decoder) throws {
12+
let container: KeyedDecodingContainer<User.CodingKeys> = try decoder.container(keyedBy: User.CodingKeys.self)
13+
14+
self.firstName = try container.decode(String.self, forKey: User.CodingKeys.firstName)
15+
self.lastName = try container.decodeIfPresent(String.self, forKey: User.CodingKeys.lastName)
16+
17+
}
18+
19+
func encode(to encoder: Encoder) throws {
20+
var container = encoder.container(keyedBy: User.CodingKeys.self)
21+
22+
try container.encode(self.firstName, forKey: User.CodingKeys.firstName)
23+
try container.encodeIfPresent(self.lastName, forKey: User.CodingKeys.lastName)
24+
}
25+
}
26+
27+
struct User_D: Decodable {
28+
let firstName: String
29+
let lastName: String?
30+
}
31+
32+
struct User_E: Encodable {
33+
let firstName: String
34+
let lastName: String?
35+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
2+
struct User: Codable {
3+
let firstName: String
4+
let lastName: String?
5+
}
6+
7+
struct User_D: Decodable {
8+
let firstName: String
9+
let lastName: String?
10+
11+
enum CodingKeys: CodingKey {
12+
case firstName
13+
case lastName
14+
}
15+
16+
init(from decoder: Decoder) throws {
17+
let container: KeyedDecodingContainer<User_D.CodingKeys> = try decoder.container(keyedBy: User_D.CodingKeys.self)
18+
19+
self.firstName = try container.decode(String.self, forKey: User_D.CodingKeys.firstName)
20+
self.lastName = try container.decodeIfPresent(String.self, forKey: User_D.CodingKeys.lastName)
21+
22+
}
23+
}
24+
25+
struct User_E: Encodable {
26+
let firstName: String
27+
let lastName: String?
28+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
2+
struct User: Codable {
3+
let firstName: String
4+
let lastName: String?
5+
}
6+
7+
struct User_D: Decodable {
8+
let firstName: String
9+
let lastName: String?
10+
}
11+
12+
struct User_E: Encodable {
13+
let firstName: String
14+
let lastName: String?
15+
16+
enum CodingKeys: CodingKey {
17+
case firstName
18+
case lastName
19+
}
20+
21+
func encode(to encoder: Encoder) throws {
22+
var container: KeyedEncodingContainer<User_E.CodingKeys> = encoder.container(keyedBy: User_E.CodingKeys.self)
23+
24+
try container.encode(self.firstName, forKey: User_E.CodingKeys.firstName)
25+
try container.encodeIfPresent(self.lastName, forKey: User_E.CodingKeys.lastName)
26+
}
27+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
2+
struct User: Codable {
3+
4+
enum CodingKeys: CodingKey {
5+
}
6+
7+
init(from decoder: Decoder) throws {
8+
9+
}
10+
11+
func encode(to encoder: Encoder) throws {
12+
var container = encoder.container(keyedBy: User.CodingKeys.self)
13+
14+
}
15+
}

0 commit comments

Comments
 (0)