@@ -3172,6 +3172,250 @@ bool RefactoringActionMemberwiseInitLocalRefactoring::performChange() {
3172
3172
return false ;
3173
3173
}
3174
3174
3175
+ class AddEquatableContext {
3176
+
3177
+ // / Declaration context
3178
+ DeclContext *DC;
3179
+
3180
+ // / Adopter type
3181
+ Type Adopter;
3182
+
3183
+ // / Start location of declaration context brace
3184
+ SourceLoc StartLoc;
3185
+
3186
+ // / Array of all inherited protocols' locations
3187
+ ArrayRef<TypeLoc> ProtocolsLocations;
3188
+
3189
+ // / Array of all conformed protocols
3190
+ SmallVector<swift::ProtocolDecl *, 2 > Protocols;
3191
+
3192
+ // / Start location of declaration,
3193
+ // / a place to write protocol name
3194
+ SourceLoc ProtInsertStartLoc;
3195
+
3196
+ // / Stored properties of extending adopter
3197
+ ArrayRef<VarDecl *> StoredProperties;
3198
+
3199
+ // / Range of internal members in declaration
3200
+ DeclRange Range;
3201
+
3202
+ bool conformsToEquatableProtocol () {
3203
+ for (ProtocolDecl *Protocol : Protocols) {
3204
+ if (Protocol->getKnownProtocolKind () == KnownProtocolKind::Equatable) {
3205
+ return true ;
3206
+ }
3207
+ }
3208
+ return false ;
3209
+ }
3210
+
3211
+ bool isRequirementValid () {
3212
+ auto Reqs = getProtocolRequirements ();
3213
+ if (Reqs.empty ()) {
3214
+ return false ;
3215
+ }
3216
+ auto Req = dyn_cast<FuncDecl>(Reqs[0 ]);
3217
+ return Req && Req->getParameters ()->size () == 2 ;
3218
+ }
3219
+
3220
+ bool isPropertiesListValid () {
3221
+ return !getUserAccessibleProperties ().empty ();
3222
+ }
3223
+
3224
+ void printFunctionBody (ASTPrinter &Printer, StringRef ExtraIndent,
3225
+ ParameterList *Params);
3226
+
3227
+ std::vector<ValueDecl *> getProtocolRequirements ();
3228
+
3229
+ std::vector<VarDecl *> getUserAccessibleProperties ();
3230
+
3231
+ public:
3232
+
3233
+ AddEquatableContext (NominalTypeDecl *Decl) : DC(Decl),
3234
+ Adopter (Decl->getDeclaredType ()), StartLoc(Decl->getBraces ().Start),
3235
+ ProtocolsLocations(Decl->getInherited ()),
3236
+ Protocols(Decl->getAllProtocols ()), ProtInsertStartLoc(Decl->getNameLoc ()),
3237
+ StoredProperties(Decl->getStoredProperties ()), Range(Decl->getMembers ()) {};
3238
+
3239
+ AddEquatableContext (ExtensionDecl *Decl) : DC(Decl),
3240
+ Adopter(Decl->getExtendedType ()), StartLoc(Decl->getBraces ().Start),
3241
+ ProtocolsLocations(Decl->getInherited ()),
3242
+ Protocols(Decl->getExtendedNominal ()->getAllProtocols()),
3243
+ ProtInsertStartLoc(Decl->getExtendedTypeRepr ()->getEndLoc()),
3244
+ StoredProperties(Decl->getExtendedNominal ()->getStoredProperties()), Range(Decl->getMembers ()) {};
3245
+
3246
+ AddEquatableContext () : DC(nullptr ), Adopter(), ProtocolsLocations(),
3247
+ Protocols(), StoredProperties(), Range(nullptr , nullptr ) {};
3248
+
3249
+ static AddEquatableContext getDeclarationContextFromInfo (ResolvedCursorInfo Info);
3250
+
3251
+ std::string getInsertionTextForProtocol ();
3252
+
3253
+ std::string getInsertionTextForFunction (SourceManager &SM);
3254
+
3255
+ bool isValid () {
3256
+ // FIXME: Allow to generate explicit == method for declarations which already have
3257
+ // compiler-generated == method
3258
+ return StartLoc.isValid () && ProtInsertStartLoc.isValid () &&
3259
+ !conformsToEquatableProtocol () && isPropertiesListValid () &&
3260
+ isRequirementValid ();
3261
+ }
3262
+
3263
+ SourceLoc getStartLocForProtocolDecl () {
3264
+ if (ProtocolsLocations.empty ()) {
3265
+ return ProtInsertStartLoc;
3266
+ }
3267
+ return ProtocolsLocations.back ().getSourceRange ().Start ;
3268
+ }
3269
+
3270
+ bool isMembersRangeEmpty () {
3271
+ return Range.empty ();
3272
+ }
3273
+
3274
+ SourceLoc getInsertStartLoc ();
3275
+ };
3276
+
3277
+ SourceLoc AddEquatableContext::
3278
+ getInsertStartLoc () {
3279
+ SourceLoc MaxLoc = StartLoc;
3280
+ for (auto Mem : Range) {
3281
+ if (Mem->getEndLoc ().getOpaquePointerValue () >
3282
+ MaxLoc.getOpaquePointerValue ()) {
3283
+ MaxLoc = Mem->getEndLoc ();
3284
+ }
3285
+ }
3286
+ return MaxLoc;
3287
+ }
3288
+
3289
+ std::string AddEquatableContext::
3290
+ getInsertionTextForProtocol () {
3291
+ StringRef ProtocolName = getProtocolName (KnownProtocolKind::Equatable);
3292
+ std::string Buffer;
3293
+ llvm::raw_string_ostream OS (Buffer);
3294
+ if (ProtocolsLocations.empty ()) {
3295
+ OS << " : " << ProtocolName;
3296
+ return Buffer;
3297
+ }
3298
+ OS << " , " << ProtocolName;
3299
+ return Buffer;
3300
+ }
3301
+
3302
+ std::string AddEquatableContext::
3303
+ getInsertionTextForFunction (SourceManager &SM) {
3304
+ auto Reqs = getProtocolRequirements ();
3305
+ auto Req = dyn_cast<FuncDecl>(Reqs[0 ]);
3306
+ auto Params = Req->getParameters ();
3307
+ StringRef ExtraIndent;
3308
+ StringRef CurrentIndent =
3309
+ Lexer::getIndentationForLine (SM, getInsertStartLoc (), &ExtraIndent);
3310
+ std::string Indent;
3311
+ if (isMembersRangeEmpty ()) {
3312
+ Indent = (CurrentIndent + ExtraIndent).str ();
3313
+ } else {
3314
+ Indent = CurrentIndent.str ();
3315
+ }
3316
+ PrintOptions Options = PrintOptions::printVerbose ();
3317
+ Options.PrintDocumentationComments = false ;
3318
+ Options.setBaseType (Adopter);
3319
+ Options.FunctionBody = [&](const ValueDecl *VD, ASTPrinter &Printer) {
3320
+ Printer << " {" ;
3321
+ Printer.printNewline ();
3322
+ printFunctionBody (Printer, ExtraIndent, Params);
3323
+ Printer.printNewline ();
3324
+ Printer << " }" ;
3325
+ };
3326
+ std::string Buffer;
3327
+ llvm::raw_string_ostream OS (Buffer);
3328
+ ExtraIndentStreamPrinter Printer (OS, Indent);
3329
+ Printer.printNewline ();
3330
+ if (!isMembersRangeEmpty ()) {
3331
+ Printer.printNewline ();
3332
+ }
3333
+ Reqs[0 ]->print (Printer, Options);
3334
+ return Buffer;
3335
+ }
3336
+
3337
+ std::vector<VarDecl *> AddEquatableContext::
3338
+ getUserAccessibleProperties () {
3339
+ std::vector<VarDecl *> PublicProperties;
3340
+ for (VarDecl *Decl : StoredProperties) {
3341
+ if (Decl->Decl ::isUserAccessible ()) {
3342
+ PublicProperties.push_back (Decl);
3343
+ }
3344
+ }
3345
+ return PublicProperties;
3346
+ }
3347
+
3348
+ std::vector<ValueDecl *> AddEquatableContext::
3349
+ getProtocolRequirements () {
3350
+ std::vector<ValueDecl *> Collection;
3351
+ auto Proto = DC->getASTContext ().getProtocol (KnownProtocolKind::Equatable);
3352
+ for (auto Member : Proto->getMembers ()) {
3353
+ auto Req = dyn_cast<ValueDecl>(Member);
3354
+ if (!Req || Req->isInvalid () || !Req->isProtocolRequirement ()) {
3355
+ continue ;
3356
+ }
3357
+ Collection.push_back (Req);
3358
+ }
3359
+ return Collection;
3360
+ }
3361
+
3362
+ AddEquatableContext AddEquatableContext::
3363
+ getDeclarationContextFromInfo (ResolvedCursorInfo Info) {
3364
+ if (Info.isInvalid ()) {
3365
+ return AddEquatableContext ();
3366
+ }
3367
+ if (!Info.IsRef ) {
3368
+ if (auto *NomDecl = dyn_cast<NominalTypeDecl>(Info.ValueD )) {
3369
+ return AddEquatableContext (NomDecl);
3370
+ }
3371
+ } else if (auto *ExtDecl = Info.ExtTyRef ) {
3372
+ return AddEquatableContext (ExtDecl);
3373
+ }
3374
+ return AddEquatableContext ();
3375
+ }
3376
+
3377
+ void AddEquatableContext::
3378
+ printFunctionBody (ASTPrinter &Printer, StringRef ExtraIndent, ParameterList *Params) {
3379
+ llvm::SmallString<128 > Return;
3380
+ llvm::raw_svector_ostream SS (Return);
3381
+ SS << tok::kw_return;
3382
+ StringRef Space = " " ;
3383
+ StringRef AdditionalSpace = " " ;
3384
+ StringRef Point = " ." ;
3385
+ StringRef Join = " == " ;
3386
+ StringRef And = " &&" ;
3387
+ auto Props = getUserAccessibleProperties ();
3388
+ auto FParam = Params->get (0 )->getName ();
3389
+ auto SParam = Params->get (1 )->getName ();
3390
+ auto Prop = Props[0 ]->getName ();
3391
+ Printer << ExtraIndent << Return << Space
3392
+ << FParam << Point << Prop << Join << SParam << Point << Prop;
3393
+ if (Props.size () > 1 ) {
3394
+ std::for_each (Props.begin () + 1 , Props.end (), [&](VarDecl *VD){
3395
+ auto Name = VD->getName ();
3396
+ Printer << And;
3397
+ Printer.printNewline ();
3398
+ Printer << ExtraIndent << AdditionalSpace << FParam << Point
3399
+ << Name << Join << SParam << Point << Name;
3400
+ });
3401
+ }
3402
+ }
3403
+
3404
+ bool RefactoringActionAddEquatableConformance::
3405
+ isApplicable (ResolvedCursorInfo Tok, DiagnosticEngine &Diag) {
3406
+ return AddEquatableContext::getDeclarationContextFromInfo (Tok).isValid ();
3407
+ }
3408
+
3409
+ bool RefactoringActionAddEquatableConformance::
3410
+ performChange () {
3411
+ auto Context = AddEquatableContext::getDeclarationContextFromInfo (CursorInfo);
3412
+ EditConsumer.insertAfter (SM, Context.getStartLocForProtocolDecl (),
3413
+ Context.getInsertionTextForProtocol ());
3414
+ EditConsumer.insertAfter (SM, Context.getInsertStartLoc (),
3415
+ Context.getInsertionTextForFunction (SM));
3416
+ return false ;
3417
+ }
3418
+
3175
3419
static CharSourceRange
3176
3420
findSourceRangeToWrapInCatch (ResolvedCursorInfo CursorInfo,
3177
3421
SourceFile *TheFile,
0 commit comments