@@ -3144,6 +3144,252 @@ bool RefactoringActionMemberwiseInitLocalRefactoring::performChange() {
3144
3144
return false ;
3145
3145
}
3146
3146
3147
+ class AddEquatableContext {
3148
+
3149
+ // / Declaration context
3150
+ DeclContext *DC;
3151
+
3152
+ // / Adopter type
3153
+ Type Adopter;
3154
+
3155
+ // / Start location of declaration context brace
3156
+ SourceLoc StartLoc;
3157
+
3158
+ // / Array of all inherited protocols' locations
3159
+ ArrayRef<TypeLoc> ProtocolsLocations;
3160
+
3161
+ // / Array of all conformed protocols
3162
+ SmallVector<swift::ProtocolDecl *, 2 > Protocols;
3163
+
3164
+ // / Start location of declaration,
3165
+ // / a place to write protocol name
3166
+ SourceLoc ProtInsertStartLoc;
3167
+
3168
+ // / Stored properties of extending adopter
3169
+ ArrayRef<VarDecl *> StoredProperties;
3170
+
3171
+ // / Range of internal members in declaration
3172
+ DeclRange Range;
3173
+
3174
+ bool conformsToEquatableProtocol () {
3175
+ for (ProtocolDecl *Protocol : Protocols) {
3176
+ if (Protocol->getKnownProtocolKind () == KnownProtocolKind::Equatable) {
3177
+ return true ;
3178
+ }
3179
+ }
3180
+ return false ;
3181
+ }
3182
+
3183
+ bool isRequirementValid () {
3184
+ auto Reqs = getProtocolRequirements ();
3185
+ if (Reqs.empty ()) {
3186
+ return false ;
3187
+ }
3188
+ auto Req = dyn_cast<FuncDecl>(Reqs[0 ]);
3189
+ auto Params = Req->getParameters ();
3190
+ if (!Req || Params->size () != 2 ) {
3191
+ return false ;
3192
+ }
3193
+ return true ;
3194
+ }
3195
+
3196
+ bool isPropertiesListValid () {
3197
+ return !getPublicProperties ().empty ();
3198
+ }
3199
+
3200
+ void printFunctionBody (ASTPrinter &Printer, StringRef ExtraIndent,
3201
+ ParameterList *Params);
3202
+
3203
+ std::vector<ValueDecl *> getProtocolRequirements ();
3204
+
3205
+ std::vector<VarDecl *> getPublicProperties ();
3206
+
3207
+ public:
3208
+
3209
+ AddEquatableContext (NominalTypeDecl *Decl) : DC(Decl),
3210
+ Adopter (Decl->getDeclaredType ()), StartLoc(Decl->getBraces ().Start),
3211
+ ProtocolsLocations(Decl->getInherited ()),
3212
+ Protocols(Decl->getAllProtocols ()), ProtInsertStartLoc(Decl->getNameLoc ()),
3213
+ StoredProperties(Decl->getStoredProperties ()), Range(Decl->getMembers ()) {};
3214
+
3215
+ AddEquatableContext (ExtensionDecl *Decl) : DC(Decl),
3216
+ Adopter(Decl->getExtendedType ()), StartLoc(Decl->getBraces ().Start),
3217
+ ProtocolsLocations(Decl->getInherited ()),
3218
+ Protocols(Decl->getExtendedNominal ()->getAllProtocols()),
3219
+ ProtInsertStartLoc(Decl->getExtendedTypeRepr ()->getEndLoc()),
3220
+ StoredProperties(Decl->getExtendedNominal ()->getStoredProperties()), Range(Decl->getMembers ()) {};
3221
+
3222
+ AddEquatableContext () : DC(nullptr ), Adopter(), ProtocolsLocations(),
3223
+ Protocols(), StoredProperties(), Range(nullptr , nullptr ) {};
3224
+
3225
+ static AddEquatableContext getDeclarationContextFromInfo (ResolvedCursorInfo Info);
3226
+
3227
+ std::string getDeclForProtocol ();
3228
+
3229
+ std::string getDeclForFunction (SourceManager &SM);
3230
+
3231
+ bool isValid () {
3232
+ return StartLoc.isValid () && ProtInsertStartLoc.isValid () &&
3233
+ !conformsToEquatableProtocol () && isPropertiesListValid () &&
3234
+ isRequirementValid ();
3235
+ }
3236
+
3237
+ SourceLoc getStartLocForProtocolDecl () {
3238
+ if (ProtocolsLocations.empty ()) {
3239
+ return ProtInsertStartLoc;
3240
+ }
3241
+ return ProtocolsLocations.back ().getSourceRange ().Start ;
3242
+ }
3243
+
3244
+ bool isMembersRangeEmpty () {
3245
+ return Range.empty ();
3246
+ }
3247
+
3248
+ SourceLoc getInsertStartLoc ();
3249
+ };
3250
+
3251
+ SourceLoc AddEquatableContext::
3252
+ getInsertStartLoc () {
3253
+ SourceLoc MaxLoc = StartLoc;
3254
+ for (auto Mem : Range) {
3255
+ if (Mem->getEndLoc ().getOpaquePointerValue () >
3256
+ MaxLoc.getOpaquePointerValue ()) {
3257
+ MaxLoc = Mem->getEndLoc ();
3258
+ }
3259
+ }
3260
+ return MaxLoc;
3261
+ }
3262
+
3263
+ std::string AddEquatableContext::
3264
+ getDeclForProtocol () {
3265
+ StringRef ProtocolName = getProtocolName (KnownProtocolKind::Equatable);
3266
+ std::string Buffer;
3267
+ llvm::raw_string_ostream OS (Buffer);
3268
+ if (ProtocolsLocations.empty ()) {
3269
+ OS << " : " << ProtocolName;
3270
+ return Buffer;
3271
+ }
3272
+ OS << " , " << ProtocolName;
3273
+ return Buffer;
3274
+ }
3275
+
3276
+ std::string AddEquatableContext::
3277
+ getDeclForFunction (SourceManager &SM) {
3278
+ auto Reqs = getProtocolRequirements ();
3279
+ auto Req = dyn_cast<FuncDecl>(Reqs[0 ]);
3280
+ auto Params = Req->getParameters ();
3281
+ StringRef ExtraIndent;
3282
+ StringRef CurrentIndent =
3283
+ Lexer::getIndentationForLine (SM, getInsertStartLoc (), &ExtraIndent);
3284
+ std::string Indent;
3285
+ if (isMembersRangeEmpty ()) {
3286
+ Indent = (CurrentIndent + ExtraIndent).str ();
3287
+ } else {
3288
+ Indent = CurrentIndent.str ();
3289
+ }
3290
+ PrintOptions Options = PrintOptions::printVerbose ();
3291
+ Options.PrintDocumentationComments = false ;
3292
+ Options.setBaseType (Adopter);
3293
+ Options.FunctionBody = [&](const ValueDecl *VD, ASTPrinter &Printer) {
3294
+ Printer << " {" ;
3295
+ Printer.printNewline ();
3296
+ printFunctionBody (Printer, ExtraIndent, Params);
3297
+ Printer.printNewline ();
3298
+ Printer << " }" ;
3299
+ };
3300
+ std::string Buffer;
3301
+ llvm::raw_string_ostream OS (Buffer);
3302
+ ExtraIndentStreamPrinter Printer (OS, Indent);
3303
+ Printer.printNewline ();
3304
+ if (!isMembersRangeEmpty ()) {
3305
+ Printer.printNewline ();
3306
+ }
3307
+ Reqs[0 ]->print (Printer, Options);
3308
+ return Buffer;
3309
+ }
3310
+
3311
+ std::vector<VarDecl *> AddEquatableContext::
3312
+ getPublicProperties () {
3313
+ std::vector<VarDecl *> PublicProperties;
3314
+ for (VarDecl *Decl : StoredProperties) {
3315
+ if (!Decl->hasPrivateAccessor ()) {
3316
+ PublicProperties.push_back (Decl);
3317
+ }
3318
+ }
3319
+ return PublicProperties;
3320
+ }
3321
+
3322
+ std::vector<ValueDecl *> AddEquatableContext::
3323
+ getProtocolRequirements () {
3324
+ std::vector<ValueDecl *> Collection;
3325
+ auto Proto = DC->getASTContext ().getProtocol (KnownProtocolKind::Equatable);
3326
+ for (auto Member : Proto->getMembers ()) {
3327
+ auto Req = dyn_cast<ValueDecl>(Member);
3328
+ if (!Req || Req->isInvalid () || !Req->isProtocolRequirement ()) {
3329
+ continue ;
3330
+ }
3331
+ Collection.push_back (Req);
3332
+ }
3333
+ return Collection;
3334
+ }
3335
+
3336
+ AddEquatableContext AddEquatableContext::
3337
+ getDeclarationContextFromInfo (ResolvedCursorInfo Info) {
3338
+ if (Info.isInvalid ()) {
3339
+ return AddEquatableContext ();
3340
+ }
3341
+ if (!Info.IsRef ) {
3342
+ if (auto *NomDecl = dyn_cast<NominalTypeDecl>(Info.ValueD )) {
3343
+ return AddEquatableContext (NomDecl);
3344
+ }
3345
+ } else if (auto *ExtDecl = Info.ExtTyRef ) {
3346
+ return AddEquatableContext (ExtDecl);
3347
+ }
3348
+ return AddEquatableContext ();
3349
+ }
3350
+
3351
+ void AddEquatableContext::
3352
+ printFunctionBody (ASTPrinter &Printer, StringRef ExtraIndent, ParameterList *Params) {
3353
+ llvm::SmallString<128 > Return;
3354
+ llvm::raw_svector_ostream SS (Return);
3355
+ SS << tok::kw_return;
3356
+ StringRef Space = " " ;
3357
+ StringRef AdditionalSpace = " " ;
3358
+ StringRef Point = " ." ;
3359
+ StringRef Join = " == " ;
3360
+ StringRef And = " &&" ;
3361
+ auto Props = getPublicProperties ();
3362
+ auto FParam = Params->get (0 )->getName ();
3363
+ auto SParam = Params->get (1 )->getName ();
3364
+ auto Prop = Props[0 ]->getName ();
3365
+ Printer << ExtraIndent << Return << Space
3366
+ << FParam << Point << Prop << Join << SParam << Point << Prop;
3367
+ if (Props.size () > 1 ) {
3368
+ std::for_each (Props.begin () + 1 , Props.end (), [&](VarDecl *VD){
3369
+ auto Name = VD->getName ();
3370
+ Printer << And;
3371
+ Printer.printNewline ();
3372
+ Printer << ExtraIndent << AdditionalSpace << FParam << Point
3373
+ << Name << Join << SParam << Point << Name;
3374
+ });
3375
+ }
3376
+ }
3377
+
3378
+ bool RefactoringActionAddEquatableConformance::
3379
+ isApplicable (ResolvedCursorInfo Tok, DiagnosticEngine &Diag) {
3380
+ return AddEquatableContext::getDeclarationContextFromInfo (Tok).isValid ();
3381
+ }
3382
+
3383
+ bool RefactoringActionAddEquatableConformance::
3384
+ performChange () {
3385
+ auto Context = AddEquatableContext::getDeclarationContextFromInfo (CursorInfo);
3386
+ EditConsumer.insertAfter (SM, Context.getStartLocForProtocolDecl (),
3387
+ Context.getDeclForProtocol ());
3388
+ EditConsumer.insertAfter (SM, Context.getInsertStartLoc (),
3389
+ Context.getDeclForFunction (SM));
3390
+ return false ;
3391
+ }
3392
+
3147
3393
static CharSourceRange
3148
3394
findSourceRangeToWrapInCatch (ResolvedCursorInfo CursorInfo,
3149
3395
SourceFile *TheFile,
0 commit comments