Skip to content

Commit b27ee0e

Browse files
committed
Merge remote-tracking branch 'origin/main' into rebranch
2 parents ef34932 + 365d56f commit b27ee0e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+10250
-8991
lines changed
Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,267 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2014 - 2023 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "RefactoringActions.h"
14+
#include "Utils.h"
15+
#include "swift/AST/ParameterList.h"
16+
#include "swift/AST/TypeRepr.h"
17+
18+
using namespace swift::refactoring;
19+
20+
namespace {
21+
class AddEquatableContext {
22+
23+
/// Declaration context
24+
DeclContext *DC;
25+
26+
/// Adopter type
27+
Type Adopter;
28+
29+
/// Start location of declaration context brace
30+
SourceLoc StartLoc;
31+
32+
/// Array of all inherited protocols' locations
33+
ArrayRef<InheritedEntry> ProtocolsLocations;
34+
35+
/// Array of all conformed protocols
36+
SmallVector<swift::ProtocolDecl *, 2> Protocols;
37+
38+
/// Start location of declaration,
39+
/// a place to write protocol name
40+
SourceLoc ProtInsertStartLoc;
41+
42+
/// Stored properties of extending adopter
43+
ArrayRef<VarDecl *> StoredProperties;
44+
45+
/// Range of internal members in declaration
46+
DeclRange Range;
47+
48+
bool conformsToEquatableProtocol() {
49+
for (ProtocolDecl *Protocol : Protocols) {
50+
if (Protocol->getKnownProtocolKind() == KnownProtocolKind::Equatable) {
51+
return true;
52+
}
53+
}
54+
return false;
55+
}
56+
57+
bool isRequirementValid() {
58+
auto Reqs = getProtocolRequirements();
59+
if (Reqs.empty()) {
60+
return false;
61+
}
62+
auto Req = dyn_cast<FuncDecl>(Reqs[0]);
63+
return Req && Req->getParameters()->size() == 2;
64+
}
65+
66+
bool isPropertiesListValid() {
67+
return !getUserAccessibleProperties().empty();
68+
}
69+
70+
void printFunctionBody(ASTPrinter &Printer, StringRef ExtraIndent,
71+
ParameterList *Params);
72+
73+
std::vector<ValueDecl *> getProtocolRequirements();
74+
75+
std::vector<VarDecl *> getUserAccessibleProperties();
76+
77+
public:
78+
AddEquatableContext(NominalTypeDecl *Decl)
79+
: DC(Decl), Adopter(Decl->getDeclaredType()),
80+
StartLoc(Decl->getBraces().Start),
81+
ProtocolsLocations(Decl->getInherited().getEntries()),
82+
Protocols(getAllProtocols(Decl)),
83+
ProtInsertStartLoc(Decl->getNameLoc()),
84+
StoredProperties(Decl->getStoredProperties()),
85+
Range(Decl->getMembers()){};
86+
87+
AddEquatableContext(ExtensionDecl *Decl)
88+
: DC(Decl), Adopter(Decl->getExtendedType()),
89+
StartLoc(Decl->getBraces().Start),
90+
ProtocolsLocations(Decl->getInherited().getEntries()),
91+
Protocols(getAllProtocols(Decl->getExtendedNominal())),
92+
ProtInsertStartLoc(Decl->getExtendedTypeRepr()->getEndLoc()),
93+
StoredProperties(Decl->getExtendedNominal()->getStoredProperties()),
94+
Range(Decl->getMembers()){};
95+
96+
AddEquatableContext()
97+
: DC(nullptr), Adopter(), ProtocolsLocations(), Protocols(),
98+
StoredProperties(), Range(nullptr, nullptr){};
99+
100+
static AddEquatableContext
101+
getDeclarationContextFromInfo(ResolvedCursorInfoPtr Info);
102+
103+
std::string getInsertionTextForProtocol();
104+
105+
std::string getInsertionTextForFunction(SourceManager &SM);
106+
107+
bool isValid() {
108+
// FIXME: Allow to generate explicit == method for declarations which
109+
// already have compiler-generated == method
110+
return StartLoc.isValid() && ProtInsertStartLoc.isValid() &&
111+
!conformsToEquatableProtocol() && isPropertiesListValid() &&
112+
isRequirementValid();
113+
}
114+
115+
SourceLoc getStartLocForProtocolDecl() {
116+
if (ProtocolsLocations.empty()) {
117+
return ProtInsertStartLoc;
118+
}
119+
return ProtocolsLocations.back().getSourceRange().Start;
120+
}
121+
122+
bool isMembersRangeEmpty() { return Range.empty(); }
123+
124+
SourceLoc getInsertStartLoc();
125+
};
126+
127+
SourceLoc AddEquatableContext::getInsertStartLoc() {
128+
SourceLoc MaxLoc = StartLoc;
129+
for (auto Mem : Range) {
130+
if (Mem->getEndLoc().getOpaquePointerValue() >
131+
MaxLoc.getOpaquePointerValue()) {
132+
MaxLoc = Mem->getEndLoc();
133+
}
134+
}
135+
return MaxLoc;
136+
}
137+
138+
std::string AddEquatableContext::getInsertionTextForProtocol() {
139+
StringRef ProtocolName = getProtocolName(KnownProtocolKind::Equatable);
140+
std::string Buffer;
141+
llvm::raw_string_ostream OS(Buffer);
142+
if (ProtocolsLocations.empty()) {
143+
OS << ": " << ProtocolName;
144+
return Buffer;
145+
}
146+
OS << ", " << ProtocolName;
147+
return Buffer;
148+
}
149+
150+
std::string
151+
AddEquatableContext::getInsertionTextForFunction(SourceManager &SM) {
152+
auto Reqs = getProtocolRequirements();
153+
auto Req = dyn_cast<FuncDecl>(Reqs[0]);
154+
auto Params = Req->getParameters();
155+
StringRef ExtraIndent;
156+
StringRef CurrentIndent =
157+
Lexer::getIndentationForLine(SM, getInsertStartLoc(), &ExtraIndent);
158+
std::string Indent;
159+
if (isMembersRangeEmpty()) {
160+
Indent = (CurrentIndent + ExtraIndent).str();
161+
} else {
162+
Indent = CurrentIndent.str();
163+
}
164+
PrintOptions Options = PrintOptions::printVerbose();
165+
Options.PrintDocumentationComments = false;
166+
Options.setBaseType(Adopter);
167+
Options.FunctionBody = [&](const ValueDecl *VD, ASTPrinter &Printer) {
168+
Printer << " {";
169+
Printer.printNewline();
170+
printFunctionBody(Printer, ExtraIndent, Params);
171+
Printer.printNewline();
172+
Printer << "}";
173+
};
174+
std::string Buffer;
175+
llvm::raw_string_ostream OS(Buffer);
176+
ExtraIndentStreamPrinter Printer(OS, Indent);
177+
Printer.printNewline();
178+
if (!isMembersRangeEmpty()) {
179+
Printer.printNewline();
180+
}
181+
Reqs[0]->print(Printer, Options);
182+
return Buffer;
183+
}
184+
185+
std::vector<VarDecl *> AddEquatableContext::getUserAccessibleProperties() {
186+
std::vector<VarDecl *> PublicProperties;
187+
for (VarDecl *Decl : StoredProperties) {
188+
if (Decl->Decl::isUserAccessible()) {
189+
PublicProperties.push_back(Decl);
190+
}
191+
}
192+
return PublicProperties;
193+
}
194+
195+
std::vector<ValueDecl *> AddEquatableContext::getProtocolRequirements() {
196+
std::vector<ValueDecl *> Collection;
197+
auto Proto = DC->getASTContext().getProtocol(KnownProtocolKind::Equatable);
198+
for (auto Member : Proto->getMembers()) {
199+
auto Req = dyn_cast<ValueDecl>(Member);
200+
if (!Req || Req->isInvalid() || !Req->isProtocolRequirement()) {
201+
continue;
202+
}
203+
Collection.push_back(Req);
204+
}
205+
return Collection;
206+
}
207+
208+
AddEquatableContext
209+
AddEquatableContext::getDeclarationContextFromInfo(ResolvedCursorInfoPtr Info) {
210+
auto ValueRefInfo = dyn_cast<ResolvedValueRefCursorInfo>(Info);
211+
if (!ValueRefInfo) {
212+
return AddEquatableContext();
213+
}
214+
if (!ValueRefInfo->isRef()) {
215+
if (auto *NomDecl = dyn_cast<NominalTypeDecl>(ValueRefInfo->getValueD())) {
216+
return AddEquatableContext(NomDecl);
217+
}
218+
} else if (auto *ExtDecl = ValueRefInfo->getExtTyRef()) {
219+
if (ExtDecl->getExtendedNominal()) {
220+
return AddEquatableContext(ExtDecl);
221+
}
222+
}
223+
return AddEquatableContext();
224+
}
225+
226+
void AddEquatableContext::printFunctionBody(ASTPrinter &Printer,
227+
StringRef ExtraIndent,
228+
ParameterList *Params) {
229+
SmallString<128> Return;
230+
llvm::raw_svector_ostream SS(Return);
231+
SS << tok::kw_return;
232+
StringRef Space = " ";
233+
StringRef AdditionalSpace = " ";
234+
StringRef Point = ".";
235+
StringRef Join = " == ";
236+
StringRef And = " &&";
237+
auto Props = getUserAccessibleProperties();
238+
auto FParam = Params->get(0)->getName();
239+
auto SParam = Params->get(1)->getName();
240+
auto Prop = Props[0]->getName();
241+
Printer << ExtraIndent << Return << Space << FParam << Point << Prop << Join
242+
<< SParam << Point << Prop;
243+
if (Props.size() > 1) {
244+
std::for_each(Props.begin() + 1, Props.end(), [&](VarDecl *VD) {
245+
auto Name = VD->getName();
246+
Printer << And;
247+
Printer.printNewline();
248+
Printer << ExtraIndent << AdditionalSpace << FParam << Point << Name
249+
<< Join << SParam << Point << Name;
250+
});
251+
}
252+
}
253+
} // namespace
254+
255+
bool RefactoringActionAddEquatableConformance::isApplicable(
256+
ResolvedCursorInfoPtr Tok, DiagnosticEngine &Diag) {
257+
return AddEquatableContext::getDeclarationContextFromInfo(Tok).isValid();
258+
}
259+
260+
bool RefactoringActionAddEquatableConformance::performChange() {
261+
auto Context = AddEquatableContext::getDeclarationContextFromInfo(CursorInfo);
262+
EditConsumer.insertAfter(SM, Context.getStartLocForProtocolDecl(),
263+
Context.getInsertionTextForProtocol());
264+
EditConsumer.insertAfter(SM, Context.getInsertStartLoc(),
265+
Context.getInsertionTextForFunction(SM));
266+
return false;
267+
}

0 commit comments

Comments
 (0)