|
| 1 | +//===----------------------------------------------------------------------===// |
| 2 | +// |
| 3 | +// This source file is part of the Swift.org open source project |
| 4 | +// |
| 5 | +// Copyright (c) 2014 - 2023 Apple Inc. and the Swift project authors |
| 6 | +// Licensed under Apache License v2.0 with Runtime Library Exception |
| 7 | +// |
| 8 | +// See https://swift.org/LICENSE.txt for license information |
| 9 | +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors |
| 10 | +// |
| 11 | +//===----------------------------------------------------------------------===// |
| 12 | + |
| 13 | +#include "RefactoringActions.h" |
| 14 | +#include "Utils.h" |
| 15 | +#include "swift/AST/ParameterList.h" |
| 16 | +#include "swift/AST/TypeRepr.h" |
| 17 | + |
| 18 | +using namespace swift::refactoring; |
| 19 | + |
| 20 | +namespace { |
| 21 | +class AddEquatableContext { |
| 22 | + |
| 23 | + /// Declaration context |
| 24 | + DeclContext *DC; |
| 25 | + |
| 26 | + /// Adopter type |
| 27 | + Type Adopter; |
| 28 | + |
| 29 | + /// Start location of declaration context brace |
| 30 | + SourceLoc StartLoc; |
| 31 | + |
| 32 | + /// Array of all inherited protocols' locations |
| 33 | + ArrayRef<InheritedEntry> ProtocolsLocations; |
| 34 | + |
| 35 | + /// Array of all conformed protocols |
| 36 | + SmallVector<swift::ProtocolDecl *, 2> Protocols; |
| 37 | + |
| 38 | + /// Start location of declaration, |
| 39 | + /// a place to write protocol name |
| 40 | + SourceLoc ProtInsertStartLoc; |
| 41 | + |
| 42 | + /// Stored properties of extending adopter |
| 43 | + ArrayRef<VarDecl *> StoredProperties; |
| 44 | + |
| 45 | + /// Range of internal members in declaration |
| 46 | + DeclRange Range; |
| 47 | + |
| 48 | + bool conformsToEquatableProtocol() { |
| 49 | + for (ProtocolDecl *Protocol : Protocols) { |
| 50 | + if (Protocol->getKnownProtocolKind() == KnownProtocolKind::Equatable) { |
| 51 | + return true; |
| 52 | + } |
| 53 | + } |
| 54 | + return false; |
| 55 | + } |
| 56 | + |
| 57 | + bool isRequirementValid() { |
| 58 | + auto Reqs = getProtocolRequirements(); |
| 59 | + if (Reqs.empty()) { |
| 60 | + return false; |
| 61 | + } |
| 62 | + auto Req = dyn_cast<FuncDecl>(Reqs[0]); |
| 63 | + return Req && Req->getParameters()->size() == 2; |
| 64 | + } |
| 65 | + |
| 66 | + bool isPropertiesListValid() { |
| 67 | + return !getUserAccessibleProperties().empty(); |
| 68 | + } |
| 69 | + |
| 70 | + void printFunctionBody(ASTPrinter &Printer, StringRef ExtraIndent, |
| 71 | + ParameterList *Params); |
| 72 | + |
| 73 | + std::vector<ValueDecl *> getProtocolRequirements(); |
| 74 | + |
| 75 | + std::vector<VarDecl *> getUserAccessibleProperties(); |
| 76 | + |
| 77 | +public: |
| 78 | + AddEquatableContext(NominalTypeDecl *Decl) |
| 79 | + : DC(Decl), Adopter(Decl->getDeclaredType()), |
| 80 | + StartLoc(Decl->getBraces().Start), |
| 81 | + ProtocolsLocations(Decl->getInherited().getEntries()), |
| 82 | + Protocols(getAllProtocols(Decl)), |
| 83 | + ProtInsertStartLoc(Decl->getNameLoc()), |
| 84 | + StoredProperties(Decl->getStoredProperties()), |
| 85 | + Range(Decl->getMembers()){}; |
| 86 | + |
| 87 | + AddEquatableContext(ExtensionDecl *Decl) |
| 88 | + : DC(Decl), Adopter(Decl->getExtendedType()), |
| 89 | + StartLoc(Decl->getBraces().Start), |
| 90 | + ProtocolsLocations(Decl->getInherited().getEntries()), |
| 91 | + Protocols(getAllProtocols(Decl->getExtendedNominal())), |
| 92 | + ProtInsertStartLoc(Decl->getExtendedTypeRepr()->getEndLoc()), |
| 93 | + StoredProperties(Decl->getExtendedNominal()->getStoredProperties()), |
| 94 | + Range(Decl->getMembers()){}; |
| 95 | + |
| 96 | + AddEquatableContext() |
| 97 | + : DC(nullptr), Adopter(), ProtocolsLocations(), Protocols(), |
| 98 | + StoredProperties(), Range(nullptr, nullptr){}; |
| 99 | + |
| 100 | + static AddEquatableContext |
| 101 | + getDeclarationContextFromInfo(ResolvedCursorInfoPtr Info); |
| 102 | + |
| 103 | + std::string getInsertionTextForProtocol(); |
| 104 | + |
| 105 | + std::string getInsertionTextForFunction(SourceManager &SM); |
| 106 | + |
| 107 | + bool isValid() { |
| 108 | + // FIXME: Allow to generate explicit == method for declarations which |
| 109 | + // already have compiler-generated == method |
| 110 | + return StartLoc.isValid() && ProtInsertStartLoc.isValid() && |
| 111 | + !conformsToEquatableProtocol() && isPropertiesListValid() && |
| 112 | + isRequirementValid(); |
| 113 | + } |
| 114 | + |
| 115 | + SourceLoc getStartLocForProtocolDecl() { |
| 116 | + if (ProtocolsLocations.empty()) { |
| 117 | + return ProtInsertStartLoc; |
| 118 | + } |
| 119 | + return ProtocolsLocations.back().getSourceRange().Start; |
| 120 | + } |
| 121 | + |
| 122 | + bool isMembersRangeEmpty() { return Range.empty(); } |
| 123 | + |
| 124 | + SourceLoc getInsertStartLoc(); |
| 125 | +}; |
| 126 | + |
| 127 | +SourceLoc AddEquatableContext::getInsertStartLoc() { |
| 128 | + SourceLoc MaxLoc = StartLoc; |
| 129 | + for (auto Mem : Range) { |
| 130 | + if (Mem->getEndLoc().getOpaquePointerValue() > |
| 131 | + MaxLoc.getOpaquePointerValue()) { |
| 132 | + MaxLoc = Mem->getEndLoc(); |
| 133 | + } |
| 134 | + } |
| 135 | + return MaxLoc; |
| 136 | +} |
| 137 | + |
| 138 | +std::string AddEquatableContext::getInsertionTextForProtocol() { |
| 139 | + StringRef ProtocolName = getProtocolName(KnownProtocolKind::Equatable); |
| 140 | + std::string Buffer; |
| 141 | + llvm::raw_string_ostream OS(Buffer); |
| 142 | + if (ProtocolsLocations.empty()) { |
| 143 | + OS << ": " << ProtocolName; |
| 144 | + return Buffer; |
| 145 | + } |
| 146 | + OS << ", " << ProtocolName; |
| 147 | + return Buffer; |
| 148 | +} |
| 149 | + |
| 150 | +std::string |
| 151 | +AddEquatableContext::getInsertionTextForFunction(SourceManager &SM) { |
| 152 | + auto Reqs = getProtocolRequirements(); |
| 153 | + auto Req = dyn_cast<FuncDecl>(Reqs[0]); |
| 154 | + auto Params = Req->getParameters(); |
| 155 | + StringRef ExtraIndent; |
| 156 | + StringRef CurrentIndent = |
| 157 | + Lexer::getIndentationForLine(SM, getInsertStartLoc(), &ExtraIndent); |
| 158 | + std::string Indent; |
| 159 | + if (isMembersRangeEmpty()) { |
| 160 | + Indent = (CurrentIndent + ExtraIndent).str(); |
| 161 | + } else { |
| 162 | + Indent = CurrentIndent.str(); |
| 163 | + } |
| 164 | + PrintOptions Options = PrintOptions::printVerbose(); |
| 165 | + Options.PrintDocumentationComments = false; |
| 166 | + Options.setBaseType(Adopter); |
| 167 | + Options.FunctionBody = [&](const ValueDecl *VD, ASTPrinter &Printer) { |
| 168 | + Printer << " {"; |
| 169 | + Printer.printNewline(); |
| 170 | + printFunctionBody(Printer, ExtraIndent, Params); |
| 171 | + Printer.printNewline(); |
| 172 | + Printer << "}"; |
| 173 | + }; |
| 174 | + std::string Buffer; |
| 175 | + llvm::raw_string_ostream OS(Buffer); |
| 176 | + ExtraIndentStreamPrinter Printer(OS, Indent); |
| 177 | + Printer.printNewline(); |
| 178 | + if (!isMembersRangeEmpty()) { |
| 179 | + Printer.printNewline(); |
| 180 | + } |
| 181 | + Reqs[0]->print(Printer, Options); |
| 182 | + return Buffer; |
| 183 | +} |
| 184 | + |
| 185 | +std::vector<VarDecl *> AddEquatableContext::getUserAccessibleProperties() { |
| 186 | + std::vector<VarDecl *> PublicProperties; |
| 187 | + for (VarDecl *Decl : StoredProperties) { |
| 188 | + if (Decl->Decl::isUserAccessible()) { |
| 189 | + PublicProperties.push_back(Decl); |
| 190 | + } |
| 191 | + } |
| 192 | + return PublicProperties; |
| 193 | +} |
| 194 | + |
| 195 | +std::vector<ValueDecl *> AddEquatableContext::getProtocolRequirements() { |
| 196 | + std::vector<ValueDecl *> Collection; |
| 197 | + auto Proto = DC->getASTContext().getProtocol(KnownProtocolKind::Equatable); |
| 198 | + for (auto Member : Proto->getMembers()) { |
| 199 | + auto Req = dyn_cast<ValueDecl>(Member); |
| 200 | + if (!Req || Req->isInvalid() || !Req->isProtocolRequirement()) { |
| 201 | + continue; |
| 202 | + } |
| 203 | + Collection.push_back(Req); |
| 204 | + } |
| 205 | + return Collection; |
| 206 | +} |
| 207 | + |
| 208 | +AddEquatableContext |
| 209 | +AddEquatableContext::getDeclarationContextFromInfo(ResolvedCursorInfoPtr Info) { |
| 210 | + auto ValueRefInfo = dyn_cast<ResolvedValueRefCursorInfo>(Info); |
| 211 | + if (!ValueRefInfo) { |
| 212 | + return AddEquatableContext(); |
| 213 | + } |
| 214 | + if (!ValueRefInfo->isRef()) { |
| 215 | + if (auto *NomDecl = dyn_cast<NominalTypeDecl>(ValueRefInfo->getValueD())) { |
| 216 | + return AddEquatableContext(NomDecl); |
| 217 | + } |
| 218 | + } else if (auto *ExtDecl = ValueRefInfo->getExtTyRef()) { |
| 219 | + if (ExtDecl->getExtendedNominal()) { |
| 220 | + return AddEquatableContext(ExtDecl); |
| 221 | + } |
| 222 | + } |
| 223 | + return AddEquatableContext(); |
| 224 | +} |
| 225 | + |
| 226 | +void AddEquatableContext::printFunctionBody(ASTPrinter &Printer, |
| 227 | + StringRef ExtraIndent, |
| 228 | + ParameterList *Params) { |
| 229 | + SmallString<128> Return; |
| 230 | + llvm::raw_svector_ostream SS(Return); |
| 231 | + SS << tok::kw_return; |
| 232 | + StringRef Space = " "; |
| 233 | + StringRef AdditionalSpace = " "; |
| 234 | + StringRef Point = "."; |
| 235 | + StringRef Join = " == "; |
| 236 | + StringRef And = " &&"; |
| 237 | + auto Props = getUserAccessibleProperties(); |
| 238 | + auto FParam = Params->get(0)->getName(); |
| 239 | + auto SParam = Params->get(1)->getName(); |
| 240 | + auto Prop = Props[0]->getName(); |
| 241 | + Printer << ExtraIndent << Return << Space << FParam << Point << Prop << Join |
| 242 | + << SParam << Point << Prop; |
| 243 | + if (Props.size() > 1) { |
| 244 | + std::for_each(Props.begin() + 1, Props.end(), [&](VarDecl *VD) { |
| 245 | + auto Name = VD->getName(); |
| 246 | + Printer << And; |
| 247 | + Printer.printNewline(); |
| 248 | + Printer << ExtraIndent << AdditionalSpace << FParam << Point << Name |
| 249 | + << Join << SParam << Point << Name; |
| 250 | + }); |
| 251 | + } |
| 252 | +} |
| 253 | +} // namespace |
| 254 | + |
| 255 | +bool RefactoringActionAddEquatableConformance::isApplicable( |
| 256 | + ResolvedCursorInfoPtr Tok, DiagnosticEngine &Diag) { |
| 257 | + return AddEquatableContext::getDeclarationContextFromInfo(Tok).isValid(); |
| 258 | +} |
| 259 | + |
| 260 | +bool RefactoringActionAddEquatableConformance::performChange() { |
| 261 | + auto Context = AddEquatableContext::getDeclarationContextFromInfo(CursorInfo); |
| 262 | + EditConsumer.insertAfter(SM, Context.getStartLocForProtocolDecl(), |
| 263 | + Context.getInsertionTextForProtocol()); |
| 264 | + EditConsumer.insertAfter(SM, Context.getInsertStartLoc(), |
| 265 | + Context.getInsertionTextForFunction(SM)); |
| 266 | + return false; |
| 267 | +} |
0 commit comments