Skip to content

Commit 041dfb9

Browse files
authored
Merge pull request swiftlang#29847 from tkachukandrew/add-equatable-conformance
2 parents b2a742c + d0ac023 commit 041dfb9

File tree

8 files changed

+399
-0
lines changed

8 files changed

+399
-0
lines changed

include/swift/IDE/RefactoringKinds.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ CURSOR_REFACTORING(TrailingClosure, "Convert To Trailing Closure", trailingclosu
5252

5353
CURSOR_REFACTORING(MemberwiseInitLocalRefactoring, "Generate Memberwise Initializer", memberwise.init.local.refactoring)
5454

55+
CURSOR_REFACTORING(AddEquatableConformance, "Add Equatable Conformance", add.equatable.conformance)
56+
5557
RANGE_REFACTORING(ExtractExpr, "Extract Expression", extract.expr)
5658

5759
RANGE_REFACTORING(ExtractFunction, "Extract Method", extract.function)

lib/IDE/Refactoring.cpp

Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3172,6 +3172,250 @@ bool RefactoringActionMemberwiseInitLocalRefactoring::performChange() {
31723172
return false;
31733173
}
31743174

3175+
class AddEquatableContext {
3176+
3177+
/// Declaration context
3178+
DeclContext *DC;
3179+
3180+
/// Adopter type
3181+
Type Adopter;
3182+
3183+
/// Start location of declaration context brace
3184+
SourceLoc StartLoc;
3185+
3186+
/// Array of all inherited protocols' locations
3187+
ArrayRef<TypeLoc> ProtocolsLocations;
3188+
3189+
/// Array of all conformed protocols
3190+
SmallVector<swift::ProtocolDecl *, 2> Protocols;
3191+
3192+
/// Start location of declaration,
3193+
/// a place to write protocol name
3194+
SourceLoc ProtInsertStartLoc;
3195+
3196+
/// Stored properties of extending adopter
3197+
ArrayRef<VarDecl *> StoredProperties;
3198+
3199+
/// Range of internal members in declaration
3200+
DeclRange Range;
3201+
3202+
bool conformsToEquatableProtocol() {
3203+
for (ProtocolDecl *Protocol : Protocols) {
3204+
if (Protocol->getKnownProtocolKind() == KnownProtocolKind::Equatable) {
3205+
return true;
3206+
}
3207+
}
3208+
return false;
3209+
}
3210+
3211+
bool isRequirementValid() {
3212+
auto Reqs = getProtocolRequirements();
3213+
if (Reqs.empty()) {
3214+
return false;
3215+
}
3216+
auto Req = dyn_cast<FuncDecl>(Reqs[0]);
3217+
return Req && Req->getParameters()->size() == 2;
3218+
}
3219+
3220+
bool isPropertiesListValid() {
3221+
return !getUserAccessibleProperties().empty();
3222+
}
3223+
3224+
void printFunctionBody(ASTPrinter &Printer, StringRef ExtraIndent,
3225+
ParameterList *Params);
3226+
3227+
std::vector<ValueDecl *> getProtocolRequirements();
3228+
3229+
std::vector<VarDecl *> getUserAccessibleProperties();
3230+
3231+
public:
3232+
3233+
AddEquatableContext(NominalTypeDecl *Decl) : DC(Decl),
3234+
Adopter(Decl->getDeclaredType()), StartLoc(Decl->getBraces().Start),
3235+
ProtocolsLocations(Decl->getInherited()),
3236+
Protocols(Decl->getAllProtocols()), ProtInsertStartLoc(Decl->getNameLoc()),
3237+
StoredProperties(Decl->getStoredProperties()), Range(Decl->getMembers()) {};
3238+
3239+
AddEquatableContext(ExtensionDecl *Decl) : DC(Decl),
3240+
Adopter(Decl->getExtendedType()), StartLoc(Decl->getBraces().Start),
3241+
ProtocolsLocations(Decl->getInherited()),
3242+
Protocols(Decl->getExtendedNominal()->getAllProtocols()),
3243+
ProtInsertStartLoc(Decl->getExtendedTypeRepr()->getEndLoc()),
3244+
StoredProperties(Decl->getExtendedNominal()->getStoredProperties()), Range(Decl->getMembers()) {};
3245+
3246+
AddEquatableContext() : DC(nullptr), Adopter(), ProtocolsLocations(),
3247+
Protocols(), StoredProperties(), Range(nullptr, nullptr) {};
3248+
3249+
static AddEquatableContext getDeclarationContextFromInfo(ResolvedCursorInfo Info);
3250+
3251+
std::string getInsertionTextForProtocol();
3252+
3253+
std::string getInsertionTextForFunction(SourceManager &SM);
3254+
3255+
bool isValid() {
3256+
// FIXME: Allow to generate explicit == method for declarations which already have
3257+
// compiler-generated == method
3258+
return StartLoc.isValid() && ProtInsertStartLoc.isValid() &&
3259+
!conformsToEquatableProtocol() && isPropertiesListValid() &&
3260+
isRequirementValid();
3261+
}
3262+
3263+
SourceLoc getStartLocForProtocolDecl() {
3264+
if (ProtocolsLocations.empty()) {
3265+
return ProtInsertStartLoc;
3266+
}
3267+
return ProtocolsLocations.back().getSourceRange().Start;
3268+
}
3269+
3270+
bool isMembersRangeEmpty() {
3271+
return Range.empty();
3272+
}
3273+
3274+
SourceLoc getInsertStartLoc();
3275+
};
3276+
3277+
SourceLoc AddEquatableContext::
3278+
getInsertStartLoc() {
3279+
SourceLoc MaxLoc = StartLoc;
3280+
for (auto Mem : Range) {
3281+
if (Mem->getEndLoc().getOpaquePointerValue() >
3282+
MaxLoc.getOpaquePointerValue()) {
3283+
MaxLoc = Mem->getEndLoc();
3284+
}
3285+
}
3286+
return MaxLoc;
3287+
}
3288+
3289+
std::string AddEquatableContext::
3290+
getInsertionTextForProtocol() {
3291+
StringRef ProtocolName = getProtocolName(KnownProtocolKind::Equatable);
3292+
std::string Buffer;
3293+
llvm::raw_string_ostream OS(Buffer);
3294+
if (ProtocolsLocations.empty()) {
3295+
OS << ": " << ProtocolName;
3296+
return Buffer;
3297+
}
3298+
OS << ", " << ProtocolName;
3299+
return Buffer;
3300+
}
3301+
3302+
std::string AddEquatableContext::
3303+
getInsertionTextForFunction(SourceManager &SM) {
3304+
auto Reqs = getProtocolRequirements();
3305+
auto Req = dyn_cast<FuncDecl>(Reqs[0]);
3306+
auto Params = Req->getParameters();
3307+
StringRef ExtraIndent;
3308+
StringRef CurrentIndent =
3309+
Lexer::getIndentationForLine(SM, getInsertStartLoc(), &ExtraIndent);
3310+
std::string Indent;
3311+
if (isMembersRangeEmpty()) {
3312+
Indent = (CurrentIndent + ExtraIndent).str();
3313+
} else {
3314+
Indent = CurrentIndent.str();
3315+
}
3316+
PrintOptions Options = PrintOptions::printVerbose();
3317+
Options.PrintDocumentationComments = false;
3318+
Options.setBaseType(Adopter);
3319+
Options.FunctionBody = [&](const ValueDecl *VD, ASTPrinter &Printer) {
3320+
Printer << " {";
3321+
Printer.printNewline();
3322+
printFunctionBody(Printer, ExtraIndent, Params);
3323+
Printer.printNewline();
3324+
Printer << "}";
3325+
};
3326+
std::string Buffer;
3327+
llvm::raw_string_ostream OS(Buffer);
3328+
ExtraIndentStreamPrinter Printer(OS, Indent);
3329+
Printer.printNewline();
3330+
if (!isMembersRangeEmpty()) {
3331+
Printer.printNewline();
3332+
}
3333+
Reqs[0]->print(Printer, Options);
3334+
return Buffer;
3335+
}
3336+
3337+
std::vector<VarDecl *> AddEquatableContext::
3338+
getUserAccessibleProperties() {
3339+
std::vector<VarDecl *> PublicProperties;
3340+
for (VarDecl *Decl : StoredProperties) {
3341+
if (Decl->Decl::isUserAccessible()) {
3342+
PublicProperties.push_back(Decl);
3343+
}
3344+
}
3345+
return PublicProperties;
3346+
}
3347+
3348+
std::vector<ValueDecl *> AddEquatableContext::
3349+
getProtocolRequirements() {
3350+
std::vector<ValueDecl *> Collection;
3351+
auto Proto = DC->getASTContext().getProtocol(KnownProtocolKind::Equatable);
3352+
for (auto Member : Proto->getMembers()) {
3353+
auto Req = dyn_cast<ValueDecl>(Member);
3354+
if (!Req || Req->isInvalid() || !Req->isProtocolRequirement()) {
3355+
continue;
3356+
}
3357+
Collection.push_back(Req);
3358+
}
3359+
return Collection;
3360+
}
3361+
3362+
AddEquatableContext AddEquatableContext::
3363+
getDeclarationContextFromInfo(ResolvedCursorInfo Info) {
3364+
if (Info.isInvalid()) {
3365+
return AddEquatableContext();
3366+
}
3367+
if (!Info.IsRef) {
3368+
if (auto *NomDecl = dyn_cast<NominalTypeDecl>(Info.ValueD)) {
3369+
return AddEquatableContext(NomDecl);
3370+
}
3371+
} else if (auto *ExtDecl = Info.ExtTyRef) {
3372+
return AddEquatableContext(ExtDecl);
3373+
}
3374+
return AddEquatableContext();
3375+
}
3376+
3377+
void AddEquatableContext::
3378+
printFunctionBody(ASTPrinter &Printer, StringRef ExtraIndent, ParameterList *Params) {
3379+
llvm::SmallString<128> Return;
3380+
llvm::raw_svector_ostream SS(Return);
3381+
SS << tok::kw_return;
3382+
StringRef Space = " ";
3383+
StringRef AdditionalSpace = " ";
3384+
StringRef Point = ".";
3385+
StringRef Join = " == ";
3386+
StringRef And = " &&";
3387+
auto Props = getUserAccessibleProperties();
3388+
auto FParam = Params->get(0)->getName();
3389+
auto SParam = Params->get(1)->getName();
3390+
auto Prop = Props[0]->getName();
3391+
Printer << ExtraIndent << Return << Space
3392+
<< FParam << Point << Prop << Join << SParam << Point << Prop;
3393+
if (Props.size() > 1) {
3394+
std::for_each(Props.begin() + 1, Props.end(), [&](VarDecl *VD){
3395+
auto Name = VD->getName();
3396+
Printer << And;
3397+
Printer.printNewline();
3398+
Printer << ExtraIndent << AdditionalSpace << FParam << Point
3399+
<< Name << Join << SParam << Point << Name;
3400+
});
3401+
}
3402+
}
3403+
3404+
bool RefactoringActionAddEquatableConformance::
3405+
isApplicable(ResolvedCursorInfo Tok, DiagnosticEngine &Diag) {
3406+
return AddEquatableContext::getDeclarationContextFromInfo(Tok).isValid();
3407+
}
3408+
3409+
bool RefactoringActionAddEquatableConformance::
3410+
performChange() {
3411+
auto Context = AddEquatableContext::getDeclarationContextFromInfo(CursorInfo);
3412+
EditConsumer.insertAfter(SM, Context.getStartLocForProtocolDecl(),
3413+
Context.getInsertionTextForProtocol());
3414+
EditConsumer.insertAfter(SM, Context.getInsertStartLoc(),
3415+
Context.getInsertionTextForFunction(SM));
3416+
return false;
3417+
}
3418+
31753419
static CharSourceRange
31763420
findSourceRangeToWrapInCatch(ResolvedCursorInfo CursorInfo,
31773421
SourceFile *TheFile,
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
class TestAddEquatable: Equatable {
2+
var property = "test"
3+
private var prop = "test2"
4+
let pr = "test3"
5+
6+
static func == (lhs: TestAddEquatable, rhs: TestAddEquatable) -> Bool {
7+
return lhs.property == rhs.property &&
8+
lhs.prop == rhs.prop &&
9+
lhs.pr == rhs.pr
10+
}
11+
}
12+
13+
extension TestAddEquatable {
14+
func test() -> Bool {
15+
return true
16+
}
17+
}
18+
19+
extension TestAddEquatable {
20+
}
21+
22+
23+
24+
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
class TestAddEquatable {
2+
var property = "test"
3+
private var prop = "test2"
4+
let pr = "test3"
5+
}
6+
7+
extension TestAddEquatable: Equatable {
8+
func test() -> Bool {
9+
return true
10+
}
11+
12+
static func == (lhs: TestAddEquatable, rhs: TestAddEquatable) -> Bool {
13+
return lhs.property == rhs.property &&
14+
lhs.prop == rhs.prop &&
15+
lhs.pr == rhs.pr
16+
}
17+
}
18+
19+
extension TestAddEquatable {
20+
}
21+
22+
23+
24+
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
class TestAddEquatable {
2+
var property = "test"
3+
private var prop = "test2"
4+
let pr = "test3"
5+
}
6+
7+
extension TestAddEquatable {
8+
func test() -> Bool {
9+
return true
10+
}
11+
}
12+
13+
extension TestAddEquatable: Equatable {
14+
static func == (lhs: TestAddEquatable, rhs: TestAddEquatable) -> Bool {
15+
return lhs.property == rhs.property &&
16+
lhs.prop == rhs.prop &&
17+
lhs.pr == rhs.pr
18+
}
19+
}
20+
21+
22+
23+
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
class TestAddEquatable {
2+
var property = "test"
3+
private var prop = "test2"
4+
let pr = "test3"
5+
}
6+
7+
extension TestAddEquatable {
8+
func test() -> Bool {
9+
return true
10+
}
11+
}
12+
13+
extension TestAddEquatable {
14+
}
15+
16+
// RUN: rm -rf %t.result && mkdir -p %t.result
17+
18+
// RUN: %refactor -add-equatable-conformance -source-filename %s -pos=1:16 > %t.result/first.swift
19+
// RUN: diff -u %S/Outputs/basic/first.swift.expected %t.result/first.swift
20+
21+
// RUN: %refactor -add-equatable-conformance -source-filename %s -pos=7:13 > %t.result/second.swift
22+
// RUN: diff -u %S/Outputs/basic/second.swift.expected %t.result/second.swift
23+
24+
// RUN: %refactor -add-equatable-conformance -source-filename %s -pos=13:13 > %t.result/third.swift
25+
// RUN: diff -u %S/Outputs/basic/third.swift.expected %t.result/third.swift

0 commit comments

Comments
 (0)