Skip to content

Commit ad15f21

Browse files
committed
[SR-7293] Refactoring action to add Equatable Conformance
1 parent 18455e6 commit ad15f21

File tree

7 files changed

+309
-0
lines changed

7 files changed

+309
-0
lines changed

include/swift/IDE/RefactoringKinds.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ CURSOR_REFACTORING(TrailingClosure, "Convert To Trailing Closure", trailingclosu
5252

5353
CURSOR_REFACTORING(MemberwiseInitLocalRefactoring, "Generate Memberwise Initializer", memberwise.init.local.refactoring)
5454

55+
CURSOR_REFACTORING(AddEquatableConformance, "Add Equatable Conformance", add.equatable.conformance)
56+
5557
RANGE_REFACTORING(ExtractExpr, "Extract Expression", extract.expr)
5658

5759
RANGE_REFACTORING(ExtractFunction, "Extract Method", extract.function)

lib/IDE/Refactoring.cpp

Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3144,6 +3144,252 @@ bool RefactoringActionMemberwiseInitLocalRefactoring::performChange() {
31443144
return false;
31453145
}
31463146

3147+
class AddEquatableContext {
3148+
3149+
/// Declaration context
3150+
DeclContext *DC;
3151+
3152+
/// Adopter type
3153+
Type Adopter;
3154+
3155+
/// Start location of declaration context brace
3156+
SourceLoc StartLoc;
3157+
3158+
/// Array of all inherited protocols' locations
3159+
ArrayRef<TypeLoc> ProtocolsLocations;
3160+
3161+
/// Array of all conformed protocols
3162+
SmallVector<swift::ProtocolDecl *, 2> Protocols;
3163+
3164+
/// Start location of declaration,
3165+
/// a place to write protocol name
3166+
SourceLoc ProtInsertStartLoc;
3167+
3168+
/// Stored properties of extending adopter
3169+
ArrayRef<VarDecl *> StoredProperties;
3170+
3171+
/// Range of internal members in declaration
3172+
DeclRange Range;
3173+
3174+
bool conformsToEquatableProtocol() {
3175+
for (ProtocolDecl *Protocol : Protocols) {
3176+
if (Protocol->getKnownProtocolKind() == KnownProtocolKind::Equatable) {
3177+
return true;
3178+
}
3179+
}
3180+
return false;
3181+
}
3182+
3183+
bool isRequirementValid() {
3184+
auto Reqs = getProtocolRequirements();
3185+
if (Reqs.empty()) {
3186+
return false;
3187+
}
3188+
auto Req = dyn_cast<FuncDecl>(Reqs[0]);
3189+
auto Params = Req->getParameters();
3190+
if (!Req || Params->size() != 2) {
3191+
return false;
3192+
}
3193+
return true;
3194+
}
3195+
3196+
bool isPropertiesListValid() {
3197+
return !getPublicProperties().empty();
3198+
}
3199+
3200+
void printFunctionBody(ASTPrinter &Printer, StringRef ExtraIndent,
3201+
ParameterList *Params);
3202+
3203+
std::vector<ValueDecl *> getProtocolRequirements();
3204+
3205+
std::vector<VarDecl *> getPublicProperties();
3206+
3207+
public:
3208+
3209+
AddEquatableContext(NominalTypeDecl *Decl) : DC(Decl),
3210+
Adopter(Decl->getDeclaredType()), StartLoc(Decl->getBraces().Start),
3211+
ProtocolsLocations(Decl->getInherited()),
3212+
Protocols(Decl->getAllProtocols()), ProtInsertStartLoc(Decl->getNameLoc()),
3213+
StoredProperties(Decl->getStoredProperties()), Range(Decl->getMembers()) {};
3214+
3215+
AddEquatableContext(ExtensionDecl *Decl) : DC(Decl),
3216+
Adopter(Decl->getExtendedType()), StartLoc(Decl->getBraces().Start),
3217+
ProtocolsLocations(Decl->getInherited()),
3218+
Protocols(Decl->getExtendedNominal()->getAllProtocols()),
3219+
ProtInsertStartLoc(Decl->getExtendedTypeRepr()->getEndLoc()),
3220+
StoredProperties(Decl->getExtendedNominal()->getStoredProperties()), Range(Decl->getMembers()) {};
3221+
3222+
AddEquatableContext() : DC(nullptr), Adopter(), ProtocolsLocations(),
3223+
Protocols(), StoredProperties(), Range(nullptr, nullptr) {};
3224+
3225+
static AddEquatableContext getDeclarationContextFromInfo(ResolvedCursorInfo Info);
3226+
3227+
std::string getDeclForProtocol();
3228+
3229+
std::string getDeclForFunction(SourceManager &SM);
3230+
3231+
bool isValid() {
3232+
return StartLoc.isValid() && ProtInsertStartLoc.isValid() &&
3233+
!conformsToEquatableProtocol() && isPropertiesListValid() &&
3234+
isRequirementValid();
3235+
}
3236+
3237+
SourceLoc getStartLocForProtocolDecl() {
3238+
if (ProtocolsLocations.empty()) {
3239+
return ProtInsertStartLoc;
3240+
}
3241+
return ProtocolsLocations.back().getSourceRange().Start;
3242+
}
3243+
3244+
bool isMembersRangeEmpty() {
3245+
return Range.empty();
3246+
}
3247+
3248+
SourceLoc getInsertStartLoc();
3249+
};
3250+
3251+
SourceLoc AddEquatableContext::
3252+
getInsertStartLoc() {
3253+
SourceLoc MaxLoc = StartLoc;
3254+
for (auto Mem : Range) {
3255+
if (Mem->getEndLoc().getOpaquePointerValue() >
3256+
MaxLoc.getOpaquePointerValue()) {
3257+
MaxLoc = Mem->getEndLoc();
3258+
}
3259+
}
3260+
return MaxLoc;
3261+
}
3262+
3263+
std::string AddEquatableContext::
3264+
getDeclForProtocol() {
3265+
StringRef ProtocolName = getProtocolName(KnownProtocolKind::Equatable);
3266+
std::string Buffer;
3267+
llvm::raw_string_ostream OS(Buffer);
3268+
if (ProtocolsLocations.empty()) {
3269+
OS << ": " << ProtocolName;
3270+
return Buffer;
3271+
}
3272+
OS << ", " << ProtocolName;
3273+
return Buffer;
3274+
}
3275+
3276+
std::string AddEquatableContext::
3277+
getDeclForFunction(SourceManager &SM) {
3278+
auto Reqs = getProtocolRequirements();
3279+
auto Req = dyn_cast<FuncDecl>(Reqs[0]);
3280+
auto Params = Req->getParameters();
3281+
StringRef ExtraIndent;
3282+
StringRef CurrentIndent =
3283+
Lexer::getIndentationForLine(SM, getInsertStartLoc(), &ExtraIndent);
3284+
std::string Indent;
3285+
if (isMembersRangeEmpty()) {
3286+
Indent = (CurrentIndent + ExtraIndent).str();
3287+
} else {
3288+
Indent = CurrentIndent.str();
3289+
}
3290+
PrintOptions Options = PrintOptions::printVerbose();
3291+
Options.PrintDocumentationComments = false;
3292+
Options.setBaseType(Adopter);
3293+
Options.FunctionBody = [&](const ValueDecl *VD, ASTPrinter &Printer) {
3294+
Printer << " {";
3295+
Printer.printNewline();
3296+
printFunctionBody(Printer, ExtraIndent, Params);
3297+
Printer.printNewline();
3298+
Printer << "}";
3299+
};
3300+
std::string Buffer;
3301+
llvm::raw_string_ostream OS(Buffer);
3302+
ExtraIndentStreamPrinter Printer(OS, Indent);
3303+
Printer.printNewline();
3304+
if (!isMembersRangeEmpty()) {
3305+
Printer.printNewline();
3306+
}
3307+
Reqs[0]->print(Printer, Options);
3308+
return Buffer;
3309+
}
3310+
3311+
std::vector<VarDecl *> AddEquatableContext::
3312+
getPublicProperties() {
3313+
std::vector<VarDecl *> PublicProperties;
3314+
for (VarDecl *Decl : StoredProperties) {
3315+
if (!Decl->hasPrivateAccessor()) {
3316+
PublicProperties.push_back(Decl);
3317+
}
3318+
}
3319+
return PublicProperties;
3320+
}
3321+
3322+
std::vector<ValueDecl *> AddEquatableContext::
3323+
getProtocolRequirements() {
3324+
std::vector<ValueDecl *> Collection;
3325+
auto Proto = DC->getASTContext().getProtocol(KnownProtocolKind::Equatable);
3326+
for (auto Member : Proto->getMembers()) {
3327+
auto Req = dyn_cast<ValueDecl>(Member);
3328+
if (!Req || Req->isInvalid() || !Req->isProtocolRequirement()) {
3329+
continue;
3330+
}
3331+
Collection.push_back(Req);
3332+
}
3333+
return Collection;
3334+
}
3335+
3336+
AddEquatableContext AddEquatableContext::
3337+
getDeclarationContextFromInfo(ResolvedCursorInfo Info) {
3338+
if (Info.isInvalid()) {
3339+
return AddEquatableContext();
3340+
}
3341+
if (!Info.IsRef) {
3342+
if (auto *NomDecl = dyn_cast<NominalTypeDecl>(Info.ValueD)) {
3343+
return AddEquatableContext(NomDecl);
3344+
}
3345+
} else if (auto *ExtDecl = Info.ExtTyRef) {
3346+
return AddEquatableContext(ExtDecl);
3347+
}
3348+
return AddEquatableContext();
3349+
}
3350+
3351+
void AddEquatableContext::
3352+
printFunctionBody(ASTPrinter &Printer, StringRef ExtraIndent, ParameterList *Params) {
3353+
llvm::SmallString<128> Return;
3354+
llvm::raw_svector_ostream SS(Return);
3355+
SS << tok::kw_return;
3356+
StringRef Space = " ";
3357+
StringRef AdditionalSpace = " ";
3358+
StringRef Point = ".";
3359+
StringRef Join = " == ";
3360+
StringRef And = " &&";
3361+
auto Props = getPublicProperties();
3362+
auto FParam = Params->get(0)->getName();
3363+
auto SParam = Params->get(1)->getName();
3364+
auto Prop = Props[0]->getName();
3365+
Printer << ExtraIndent << Return << Space
3366+
<< FParam << Point << Prop << Join << SParam << Point << Prop;
3367+
if (Props.size() > 1) {
3368+
std::for_each(Props.begin() + 1, Props.end(), [&](VarDecl *VD){
3369+
auto Name = VD->getName();
3370+
Printer << And;
3371+
Printer.printNewline();
3372+
Printer << ExtraIndent << AdditionalSpace << FParam << Point
3373+
<< Name << Join << SParam << Point << Name;
3374+
});
3375+
}
3376+
}
3377+
3378+
bool RefactoringActionAddEquatableConformance::
3379+
isApplicable(ResolvedCursorInfo Tok, DiagnosticEngine &Diag) {
3380+
return AddEquatableContext::getDeclarationContextFromInfo(Tok).isValid();
3381+
}
3382+
3383+
bool RefactoringActionAddEquatableConformance::
3384+
performChange() {
3385+
auto Context = AddEquatableContext::getDeclarationContextFromInfo(CursorInfo);
3386+
EditConsumer.insertAfter(SM, Context.getStartLocForProtocolDecl(),
3387+
Context.getDeclForProtocol());
3388+
EditConsumer.insertAfter(SM, Context.getInsertStartLoc(),
3389+
Context.getDeclForFunction(SM));
3390+
return false;
3391+
}
3392+
31473393
static CharSourceRange
31483394
findSourceRangeToWrapInCatch(ResolvedCursorInfo CursorInfo,
31493395
SourceFile *TheFile,
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
class TestAddEquatable: Equatable {
2+
var property = "test"
3+
private var prop = "test2"
4+
let pr = "test3"
5+
6+
static func == (lhs: TestAddEquatable, rhs: TestAddEquatable) -> Bool {
7+
return lhs.property == rhs.property &&
8+
lhs.pr == rhs.pr
9+
}
10+
}
11+
12+
extension TestAddEquatable {
13+
func test() -> Bool {
14+
return true
15+
}
16+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
class TestAddEquatable {
2+
var property = "test"
3+
private var prop = "test2"
4+
let pr = "test3"
5+
}
6+
7+
extension TestAddEquatable: Equatable {
8+
func test() -> Bool {
9+
return true
10+
}
11+
12+
static func == (lhs: TestAddEquatable, rhs: TestAddEquatable) -> Bool {
13+
return lhs.property == rhs.property &&
14+
lhs.pr == rhs.pr
15+
}
16+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
class TestAddEquatable {
2+
var property = "test"
3+
private var prop = "test2"
4+
let pr = "test3"
5+
}
6+
7+
extension TestAddEquatable {
8+
func test() -> Bool {
9+
return true
10+
}
11+
}
12+
13+
// RUN: rm -rf %t.result && mkdir -p %t.result
14+
15+
// RUN: %refactor -add-equatable-conformance -source-filename %s -pos=1:16 > %t.result/first.swift
16+
// RUN: diff -u %S/Outputs/basic/first.swift.expected %t.result/first.swift
17+
18+
// RUN: %refactor -add-equatable-conformance -source-filename %s -pos=7:13 > %t.result/second.swift
19+
// RUN: diff -u %S/Outputs/basic/second.swift.expected %t.result/second.swift

test/refactoring/RefactoringKind/basic.swift

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,12 @@ struct S {
296296
}
297297
}
298298

299+
class TestAddEquatable {
300+
var property = "test"
301+
private var prop = "test2"
302+
let pr = "test3"
303+
}
304+
299305
// RUN: %refactor -source-filename %s -pos=2:1 -end-pos=5:13 | %FileCheck %s -check-prefix=CHECK1
300306
// RUN: %refactor -source-filename %s -pos=3:1 -end-pos=5:13 | %FileCheck %s -check-prefix=CHECK1
301307
// RUN: %refactor -source-filename %s -pos=4:1 -end-pos=5:13 | %FileCheck %s -check-prefix=CHECK1
@@ -397,6 +403,8 @@ struct S {
397403
// RUN: %refactor -source-filename %s -pos=291:3 -end-pos=291:18 | %FileCheck %s -check-prefix=CHECK-IS-NOT-CONVERT-TO-COMPUTED-PROPERTY
398404
// RUN: %refactor -source-filename %s -pos=292:3 -end-pos=296:4 | %FileCheck %s -check-prefix=CHECK-IS-NOT-CONVERT-TO-COMPUTED-PROPERTY
399405

406+
// RUN: %refactor -source-filename %s -pos=299:16 | %FileCheck %s -check-prefix=CHECK-ADD-EQUATABLE-CONFORMANCE
407+
400408
// CHECK1: Action begins
401409
// CHECK1-NEXT: Extract Method
402410
// CHECK1-NEXT: Action ends
@@ -454,3 +462,4 @@ struct S {
454462
// CHECK-IS-NOT-CONVERT-TO-COMPUTED-PROPERTY-NOT: Convert To Computed Property
455463
// CHECK-IS-NOT-CONVERT-TO-COMPUTED-PROPERTY: Action ends
456464

465+
// CHECK-ADD-EQUATABLE-CONFORMANCE: Add Equatable Conformance

tools/swift-refactor/swift-refactor.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ Action(llvm::cl::desc("kind:"), llvm::cl::init(RefactoringKind::None),
7272
clEnumValN(RefactoringKind::ReplaceBodiesWithFatalError,
7373
"replace-bodies-with-fatalError", "Perform trailing closure refactoring"),
7474
clEnumValN(RefactoringKind::MemberwiseInitLocalRefactoring, "memberwise-init", "Generate member wise initializer"),
75+
clEnumValN(RefactoringKind::AddEquatableConformance, "add-equatable-conformance", "Add Equatable conformance"),
7576
clEnumValN(RefactoringKind::ConvertToComputedProperty,
7677
"convert-to-computed-property", "Convert from field initialization to computed property"),
7778
clEnumValN(RefactoringKind::ConvertToSwitchStmt, "convert-to-switch-stmt", "Perform convert to switch statement")));

0 commit comments

Comments
 (0)