Skip to content

[SR-7293] Refactoring action to add Equatable Conformance #29847

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/swift/IDE/RefactoringKinds.def
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ CURSOR_REFACTORING(TrailingClosure, "Convert To Trailing Closure", trailingclosu

CURSOR_REFACTORING(MemberwiseInitLocalRefactoring, "Generate Memberwise Initializer", memberwise.init.local.refactoring)

CURSOR_REFACTORING(AddEquatableConformance, "Add Equatable Conformance", add.equatable.conformance)

RANGE_REFACTORING(ExtractExpr, "Extract Expression", extract.expr)

RANGE_REFACTORING(ExtractFunction, "Extract Method", extract.function)
Expand Down
244 changes: 244 additions & 0 deletions lib/IDE/Refactoring.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3172,6 +3172,250 @@ bool RefactoringActionMemberwiseInitLocalRefactoring::performChange() {
return false;
}

class AddEquatableContext {

/// Declaration context
DeclContext *DC;

/// Adopter type
Type Adopter;

/// Start location of declaration context brace
SourceLoc StartLoc;

/// Array of all inherited protocols' locations
ArrayRef<TypeLoc> ProtocolsLocations;

/// Array of all conformed protocols
SmallVector<swift::ProtocolDecl *, 2> Protocols;

/// Start location of declaration,
/// a place to write protocol name
SourceLoc ProtInsertStartLoc;

/// Stored properties of extending adopter
ArrayRef<VarDecl *> StoredProperties;

/// Range of internal members in declaration
DeclRange Range;

bool conformsToEquatableProtocol() {
for (ProtocolDecl *Protocol : Protocols) {
if (Protocol->getKnownProtocolKind() == KnownProtocolKind::Equatable) {
return true;
}
}
return false;
}

bool isRequirementValid() {
auto Reqs = getProtocolRequirements();
if (Reqs.empty()) {
return false;
}
auto Req = dyn_cast<FuncDecl>(Reqs[0]);
return Req && Req->getParameters()->size() == 2;
}

bool isPropertiesListValid() {
return !getUserAccessibleProperties().empty();
}

void printFunctionBody(ASTPrinter &Printer, StringRef ExtraIndent,
ParameterList *Params);

std::vector<ValueDecl *> getProtocolRequirements();

std::vector<VarDecl *> getUserAccessibleProperties();

public:

AddEquatableContext(NominalTypeDecl *Decl) : DC(Decl),
Adopter(Decl->getDeclaredType()), StartLoc(Decl->getBraces().Start),
ProtocolsLocations(Decl->getInherited()),
Protocols(Decl->getAllProtocols()), ProtInsertStartLoc(Decl->getNameLoc()),
StoredProperties(Decl->getStoredProperties()), Range(Decl->getMembers()) {};

AddEquatableContext(ExtensionDecl *Decl) : DC(Decl),
Adopter(Decl->getExtendedType()), StartLoc(Decl->getBraces().Start),
ProtocolsLocations(Decl->getInherited()),
Protocols(Decl->getExtendedNominal()->getAllProtocols()),
ProtInsertStartLoc(Decl->getExtendedTypeRepr()->getEndLoc()),
StoredProperties(Decl->getExtendedNominal()->getStoredProperties()), Range(Decl->getMembers()) {};

AddEquatableContext() : DC(nullptr), Adopter(), ProtocolsLocations(),
Protocols(), StoredProperties(), Range(nullptr, nullptr) {};

static AddEquatableContext getDeclarationContextFromInfo(ResolvedCursorInfo Info);

std::string getInsertionTextForProtocol();

std::string getInsertionTextForFunction(SourceManager &SM);

bool isValid() {
// FIXME: Allow to generate explicit == method for declarations which already have
// compiler-generated == method
return StartLoc.isValid() && ProtInsertStartLoc.isValid() &&
!conformsToEquatableProtocol() && isPropertiesListValid() &&
isRequirementValid();
}

SourceLoc getStartLocForProtocolDecl() {
if (ProtocolsLocations.empty()) {
return ProtInsertStartLoc;
}
return ProtocolsLocations.back().getSourceRange().Start;
}

bool isMembersRangeEmpty() {
return Range.empty();
}

SourceLoc getInsertStartLoc();
};

SourceLoc AddEquatableContext::
getInsertStartLoc() {
SourceLoc MaxLoc = StartLoc;
for (auto Mem : Range) {
if (Mem->getEndLoc().getOpaquePointerValue() >
MaxLoc.getOpaquePointerValue()) {
MaxLoc = Mem->getEndLoc();
}
}
return MaxLoc;
}

std::string AddEquatableContext::
getInsertionTextForProtocol() {
StringRef ProtocolName = getProtocolName(KnownProtocolKind::Equatable);
std::string Buffer;
llvm::raw_string_ostream OS(Buffer);
if (ProtocolsLocations.empty()) {
OS << ": " << ProtocolName;
return Buffer;
}
OS << ", " << ProtocolName;
return Buffer;
}

std::string AddEquatableContext::
getInsertionTextForFunction(SourceManager &SM) {
auto Reqs = getProtocolRequirements();
auto Req = dyn_cast<FuncDecl>(Reqs[0]);
auto Params = Req->getParameters();
StringRef ExtraIndent;
StringRef CurrentIndent =
Lexer::getIndentationForLine(SM, getInsertStartLoc(), &ExtraIndent);
std::string Indent;
if (isMembersRangeEmpty()) {
Indent = (CurrentIndent + ExtraIndent).str();
} else {
Indent = CurrentIndent.str();
}
PrintOptions Options = PrintOptions::printVerbose();
Options.PrintDocumentationComments = false;
Options.setBaseType(Adopter);
Options.FunctionBody = [&](const ValueDecl *VD, ASTPrinter &Printer) {
Printer << " {";
Printer.printNewline();
printFunctionBody(Printer, ExtraIndent, Params);
Printer.printNewline();
Printer << "}";
};
std::string Buffer;
llvm::raw_string_ostream OS(Buffer);
ExtraIndentStreamPrinter Printer(OS, Indent);
Printer.printNewline();
if (!isMembersRangeEmpty()) {
Printer.printNewline();
}
Reqs[0]->print(Printer, Options);
return Buffer;
}

std::vector<VarDecl *> AddEquatableContext::
getUserAccessibleProperties() {
std::vector<VarDecl *> PublicProperties;
for (VarDecl *Decl : StoredProperties) {
if (Decl->Decl::isUserAccessible()) {
PublicProperties.push_back(Decl);
}
}
return PublicProperties;
}

std::vector<ValueDecl *> AddEquatableContext::
getProtocolRequirements() {
std::vector<ValueDecl *> Collection;
auto Proto = DC->getASTContext().getProtocol(KnownProtocolKind::Equatable);
for (auto Member : Proto->getMembers()) {
auto Req = dyn_cast<ValueDecl>(Member);
if (!Req || Req->isInvalid() || !Req->isProtocolRequirement()) {
continue;
}
Collection.push_back(Req);
}
return Collection;
}

AddEquatableContext AddEquatableContext::
getDeclarationContextFromInfo(ResolvedCursorInfo Info) {
if (Info.isInvalid()) {
return AddEquatableContext();
}
if (!Info.IsRef) {
if (auto *NomDecl = dyn_cast<NominalTypeDecl>(Info.ValueD)) {
return AddEquatableContext(NomDecl);
}
} else if (auto *ExtDecl = Info.ExtTyRef) {
return AddEquatableContext(ExtDecl);
}
return AddEquatableContext();
}

void AddEquatableContext::
printFunctionBody(ASTPrinter &Printer, StringRef ExtraIndent, ParameterList *Params) {
llvm::SmallString<128> Return;
llvm::raw_svector_ostream SS(Return);
SS << tok::kw_return;
StringRef Space = " ";
StringRef AdditionalSpace = " ";
StringRef Point = ".";
StringRef Join = " == ";
StringRef And = " &&";
auto Props = getUserAccessibleProperties();
auto FParam = Params->get(0)->getName();
auto SParam = Params->get(1)->getName();
auto Prop = Props[0]->getName();
Printer << ExtraIndent << Return << Space
<< FParam << Point << Prop << Join << SParam << Point << Prop;
if (Props.size() > 1) {
std::for_each(Props.begin() + 1, Props.end(), [&](VarDecl *VD){
auto Name = VD->getName();
Printer << And;
Printer.printNewline();
Printer << ExtraIndent << AdditionalSpace << FParam << Point
<< Name << Join << SParam << Point << Name;
});
}
}

bool RefactoringActionAddEquatableConformance::
isApplicable(ResolvedCursorInfo Tok, DiagnosticEngine &Diag) {
return AddEquatableContext::getDeclarationContextFromInfo(Tok).isValid();
}

bool RefactoringActionAddEquatableConformance::
performChange() {
auto Context = AddEquatableContext::getDeclarationContextFromInfo(CursorInfo);
EditConsumer.insertAfter(SM, Context.getStartLocForProtocolDecl(),
Context.getInsertionTextForProtocol());
EditConsumer.insertAfter(SM, Context.getInsertStartLoc(),
Context.getInsertionTextForFunction(SM));
return false;
}

static CharSourceRange
findSourceRangeToWrapInCatch(ResolvedCursorInfo CursorInfo,
SourceFile *TheFile,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
class TestAddEquatable: Equatable {
var property = "test"
private var prop = "test2"
let pr = "test3"

static func == (lhs: TestAddEquatable, rhs: TestAddEquatable) -> Bool {
return lhs.property == rhs.property &&
lhs.prop == rhs.prop &&
lhs.pr == rhs.pr
}
}

extension TestAddEquatable {
func test() -> Bool {
return true
}
}

extension TestAddEquatable {
}




Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
class TestAddEquatable {
var property = "test"
private var prop = "test2"
let pr = "test3"
}

extension TestAddEquatable: Equatable {
func test() -> Bool {
return true
}

static func == (lhs: TestAddEquatable, rhs: TestAddEquatable) -> Bool {
return lhs.property == rhs.property &&
lhs.prop == rhs.prop &&
lhs.pr == rhs.pr
}
}

extension TestAddEquatable {
}




Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
class TestAddEquatable {
var property = "test"
private var prop = "test2"
let pr = "test3"
}

extension TestAddEquatable {
func test() -> Bool {
return true
}
}

extension TestAddEquatable: Equatable {
static func == (lhs: TestAddEquatable, rhs: TestAddEquatable) -> Bool {
return lhs.property == rhs.property &&
lhs.prop == rhs.prop &&
lhs.pr == rhs.pr
}
}




25 changes: 25 additions & 0 deletions test/refactoring/AddEquatableConformance/basic.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
class TestAddEquatable {
var property = "test"
private var prop = "test2"
let pr = "test3"
}

extension TestAddEquatable {
func test() -> Bool {
return true
}
}

extension TestAddEquatable {
}

// RUN: rm -rf %t.result && mkdir -p %t.result

// RUN: %refactor -add-equatable-conformance -source-filename %s -pos=1:16 > %t.result/first.swift
// RUN: diff -u %S/Outputs/basic/first.swift.expected %t.result/first.swift

// RUN: %refactor -add-equatable-conformance -source-filename %s -pos=7:13 > %t.result/second.swift
// RUN: diff -u %S/Outputs/basic/second.swift.expected %t.result/second.swift

// RUN: %refactor -add-equatable-conformance -source-filename %s -pos=13:13 > %t.result/third.swift
// RUN: diff -u %S/Outputs/basic/third.swift.expected %t.result/third.swift
Loading