32
32
33
33
using namespace swift ;
34
34
35
- // Return the protocol requirement with the specified name.
36
- // TODO: Move function to shared place for use with other derived conformances.
35
+ // / Return the protocol requirement with the specified name.
36
+ // / TODO: Move function to shared place for use with other derived conformances.
37
37
static ValueDecl *getProtocolRequirement (ProtocolDecl *proto, Identifier name) {
38
38
auto lookup = proto->lookupDirect (name);
39
39
// Erase declarations that are not protocol requirements.
@@ -46,8 +46,8 @@ static ValueDecl *getProtocolRequirement(ProtocolDecl *proto, Identifier name) {
46
46
return lookup.front ();
47
47
}
48
48
49
- // Get the stored properties of a nominal type that are relevant for
50
- // differentiation, except the ones tagged `@noDerivative`.
49
+ // / Get the stored properties of a nominal type that are relevant for
50
+ // / differentiation, except the ones tagged `@noDerivative`.
51
51
static void
52
52
getStoredPropertiesForDifferentiation (NominalTypeDecl *nominal,
53
53
DeclContext *DC,
@@ -71,8 +71,8 @@ getStoredPropertiesForDifferentiation(NominalTypeDecl *nominal,
71
71
}
72
72
}
73
73
74
- // Convert the given `ValueDecl` to a `StructDecl` if it is a `StructDecl` or a
75
- // `TypeDecl` with an underlying struct type. Otherwise, return `nullptr`.
74
+ // / Convert the given `ValueDecl` to a `StructDecl` if it is a `StructDecl` or a
75
+ // / `TypeDecl` with an underlying struct type. Otherwise, return `nullptr`.
76
76
static StructDecl *convertToStructDecl (ValueDecl *v) {
77
77
if (auto *structDecl = dyn_cast<StructDecl>(v))
78
78
return structDecl;
@@ -83,10 +83,10 @@ static StructDecl *convertToStructDecl(ValueDecl *v) {
83
83
typeDecl->getDeclaredInterfaceType ()->getAnyNominal ());
84
84
}
85
85
86
- // Get the `Differentiable` protocol `TangentVector` associated type for the
87
- // given `VarDecl`.
88
- // TODO: Generalize and move function to shared place for use with other derived
89
- // conformances.
86
+ // / Get the `Differentiable` protocol `TangentVector` associated type for the
87
+ // / given `VarDecl`.
88
+ // / TODO: Generalize and move function to shared place for use with other derived
89
+ // / conformances.
90
90
static Type getTangentVectorType (VarDecl *decl, DeclContext *DC) {
91
91
auto &C = decl->getASTContext ();
92
92
auto *diffableProto = C.getProtocol (KnownProtocolKind::Differentiable);
@@ -196,7 +196,34 @@ bool DerivedConformance::canDeriveDifferentiable(NominalTypeDecl *nominal,
196
196
});
197
197
}
198
198
199
- // Synthesize body for a `Differentiable` method requirement.
199
+ // / Determine if a EuclideanDifferentiable requirement can be derived for a type.
200
+ // /
201
+ // / \returns True if the requirement can be derived.
202
+ bool DerivedConformance::canDeriveEuclideanDifferentiable (
203
+ NominalTypeDecl *nominal, DeclContext *DC) {
204
+ if (!canDeriveDifferentiable (nominal, DC))
205
+ return false ;
206
+ auto &C = nominal->getASTContext ();
207
+ auto *lazyResolver = C.getLazyResolver ();
208
+ auto *addArithProto = C.getProtocol (KnownProtocolKind::AdditiveArithmetic);
209
+ // Return true if all differentiation stored properties conform to
210
+ // `AdditiveArithmetic` and their `TangentVector` equals themselves.
211
+ SmallVector<VarDecl *, 16 > diffProperties;
212
+ getStoredPropertiesForDifferentiation (nominal, DC, diffProperties);
213
+ return llvm::all_of (diffProperties, [&](VarDecl *member) {
214
+ if (!member->hasInterfaceType ())
215
+ lazyResolver->resolveDeclSignature (member);
216
+ if (!member->hasInterfaceType ())
217
+ return false ;
218
+ auto varType = DC->mapTypeIntoContext (member->getValueInterfaceType ());
219
+ if (!TypeChecker::conformsToProtocol (varType, addArithProto, DC, None))
220
+ return false ;
221
+ auto memberAssocType = getTangentVectorType (member, DC);
222
+ return member->getType ()->isEqual (memberAssocType);
223
+ });
224
+ }
225
+
226
+ // / Synthesize body for a `Differentiable` method requirement.
200
227
static std::pair<BraceStmt *, bool >
201
228
deriveBodyDifferentiable_method (AbstractFunctionDecl *funcDecl,
202
229
Identifier methodName,
@@ -283,15 +310,15 @@ deriveBodyDifferentiable_method(AbstractFunctionDecl *funcDecl,
283
310
return std::pair<BraceStmt *, bool >(braceStmt, false );
284
311
}
285
312
286
- // Synthesize body for `move(along:)`.
313
+ // / Synthesize body for `move(along:)`.
287
314
static std::pair<BraceStmt *, bool >
288
315
deriveBodyDifferentiable_move (AbstractFunctionDecl *funcDecl, void *) {
289
316
auto &C = funcDecl->getASTContext ();
290
317
return deriveBodyDifferentiable_method (funcDecl, C.Id_move ,
291
318
C.getIdentifier (" along" ));
292
319
}
293
320
294
- // Synthesize function declaration for a `Differentiable` method requirement.
321
+ // / Synthesize function declaration for a `Differentiable` method requirement.
295
322
static ValueDecl *deriveDifferentiable_method (
296
323
DerivedConformance &derived, Identifier methodName, Identifier argumentName,
297
324
Identifier parameterName, Type parameterType, Type returnType,
@@ -329,7 +356,7 @@ static ValueDecl *deriveDifferentiable_method(
329
356
return funcDecl;
330
357
}
331
358
332
- // Synthesize the `move(along:)` function declaration.
359
+ // / Synthesize the `move(along:)` function declaration.
333
360
static ValueDecl *deriveDifferentiable_move (DerivedConformance &derived) {
334
361
auto &C = derived.TC .Context ;
335
362
auto *parentDC = derived.getConformanceContext ();
@@ -343,8 +370,92 @@ static ValueDecl *deriveDifferentiable_move(DerivedConformance &derived) {
343
370
{deriveBodyDifferentiable_move, nullptr });
344
371
}
345
372
346
- // Return associated `TangentVector` struct for a nominal type, if it exists.
347
- // If not, synthesize the struct.
373
+ // / Synthesize the `vectorView` property declaration.
374
+ static ValueDecl *deriveEuclideanDifferentiable_vectorView (
375
+ DerivedConformance &derived) {
376
+ auto &C = derived.TC .Context ;
377
+ auto *parentDC = derived.getConformanceContext ();
378
+
379
+ auto *tangentDecl = getTangentVectorStructDecl (parentDC);
380
+ auto tangentType = tangentDecl->getDeclaredInterfaceType ();
381
+ auto tangentContextualType = parentDC->mapTypeIntoContext (tangentType);
382
+
383
+ VarDecl *vectorViewDecl;
384
+ PatternBindingDecl *pbDecl;
385
+ std::tie (vectorViewDecl, pbDecl) = derived.declareDerivedProperty (
386
+ C.Id_vectorView , tangentType, tangentContextualType, /* isStatic*/ false ,
387
+ /* isFinal*/ true );
388
+
389
+ struct GetterSynthesizerContext {
390
+ StructDecl *tangentDecl;
391
+ Type tangentContextualType;
392
+ };
393
+
394
+ auto getterSynthesizer = [](AbstractFunctionDecl *getterDecl, void *ctx)
395
+ -> std::pair<BraceStmt *, bool > {
396
+ auto *context = reinterpret_cast <GetterSynthesizerContext *>(ctx);
397
+ assert (context && " Invalid context" );
398
+ auto *parentDC = getterDecl->getParent ();
399
+ auto *nominal = parentDC->getSelfNominalTypeDecl ();
400
+ auto &C = nominal->getASTContext ();
401
+ SmallVector<VarDecl *, 8 > diffProperties;
402
+ getStoredPropertiesForDifferentiation (nominal, nominal->getDeclContext (),
403
+ diffProperties);
404
+
405
+ // Create a reference to the memberwise initializer: `TangentVector.init`.
406
+ auto *memberwiseInitDecl =
407
+ context->tangentDecl ->getEffectiveMemberwiseInitializer ();
408
+ assert (memberwiseInitDecl && " Memberwise initializer must exist" );
409
+ // `TangentVector`
410
+ auto *tangentTypeExpr =
411
+ TypeExpr::createImplicit (context->tangentContextualType , C);
412
+ // `TangentVector.init`
413
+ auto *initDRE = new (C) DeclRefExpr (memberwiseInitDecl, DeclNameLoc (),
414
+ /* Implicit*/ true );
415
+ initDRE->setFunctionRefKind (FunctionRefKind::SingleApply);
416
+ auto *initExpr = new (C) ConstructorRefCallExpr (initDRE, tangentTypeExpr);
417
+ initExpr->setThrows (false );
418
+ initExpr->setImplicit ();
419
+
420
+ // Create a call:
421
+ // TangentVector.init(
422
+ // <property_name_1...>: self.<property_name_1>,
423
+ // <property_name_2...>: self.<property_name_2>,
424
+ // ...
425
+ // )
426
+ SmallVector<Identifier, 8 > argLabels;
427
+ SmallVector<Expr *, 8 > memberRefs;
428
+ auto *selfDRE = new (C) DeclRefExpr (getterDecl->getImplicitSelfDecl (),
429
+ DeclNameLoc (),
430
+ /* Implicit*/ true );
431
+ for (auto *member : diffProperties) {
432
+ argLabels.push_back (member->getName ());
433
+ memberRefs.push_back (
434
+ new (C) MemberRefExpr (selfDRE, SourceLoc (), member, DeclNameLoc (),
435
+ /* Implicit*/ true ));
436
+ }
437
+ assert (memberRefs.size () == argLabels.size ());
438
+ CallExpr *callExpr =
439
+ CallExpr::createImplicit (C, initExpr, memberRefs, argLabels);
440
+
441
+ // Create a return statement: `return TangentVector.init(...)`.
442
+ ASTNode retStmt =
443
+ new (C) ReturnStmt (SourceLoc (), callExpr, /* implicit*/ true );
444
+ auto *braceStmt = BraceStmt::create (C, SourceLoc (), retStmt, SourceLoc (),
445
+ /* implicit*/ true );
446
+ return std::make_pair (braceStmt, false );
447
+ };
448
+ auto *getterDecl = derived.addGetterToReadOnlyDerivedProperty (
449
+ vectorViewDecl, tangentContextualType);
450
+ getterDecl->setBodySynthesizer (
451
+ getterSynthesizer, /* context*/ C.AllocateObjectCopy (
452
+ GetterSynthesizerContext{tangentDecl, tangentContextualType}));
453
+ derived.addMembersToConformanceContext ({vectorViewDecl, pbDecl});
454
+ return vectorViewDecl;
455
+ }
456
+
457
+ // / Return associated `TangentVector` struct for a nominal type, if it exists.
458
+ // / If not, synthesize the struct.
348
459
static StructDecl *
349
460
getOrSynthesizeTangentVectorStruct (DerivedConformance &derived, Identifier id) {
350
461
auto &TC = derived.TC ;
@@ -362,8 +473,7 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
362
473
return structDecl;
363
474
}
364
475
365
- // Otherwise, synthesize a new struct. The struct must conform to
366
- // `Differentiable`.
476
+ // Otherwise, synthesize a new struct.
367
477
auto *diffableProto = C.getProtocol (KnownProtocolKind::Differentiable);
368
478
auto diffableType = TypeLoc::withoutLoc (diffableProto->getDeclaredType ());
369
479
auto *addArithProto = C.getProtocol (KnownProtocolKind::AdditiveArithmetic);
@@ -378,9 +488,9 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
378
488
auto *kpIterableProto = C.getProtocol (KnownProtocolKind::KeyPathIterable);
379
489
auto kpIterableType = TypeLoc::withoutLoc (kpIterableProto->getDeclaredType ());
380
490
381
- SmallVector<TypeLoc, 4 > inherited{diffableType};
382
- // `TangentVector` must conform to ` AdditiveArithmetic`.
383
- inherited. push_back ( addArithType) ;
491
+ // By definition, `TangentVector` must conform to `Differentiable` and
492
+ // `AdditiveArithmetic`.
493
+ SmallVector<TypeLoc, 4 > inherited{diffableType, addArithType} ;
384
494
385
495
// Cache original members and their associated types for later use.
386
496
SmallVector<VarDecl *, 8 > diffProperties;
@@ -551,8 +661,8 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
551
661
return structDecl;
552
662
}
553
663
554
- // Add a typealias declaration with the given name and underlying target
555
- // struct type to the given source nominal declaration context.
664
+ // / Add a typealias declaration with the given name and underlying target
665
+ // / struct type to the given source nominal declaration context.
556
666
static void addAssociatedTypeAliasDecl (Identifier name,
557
667
DeclContext *sourceDC,
558
668
StructDecl *target,
@@ -585,11 +695,11 @@ static void addAssociatedTypeAliasDecl(Identifier name,
585
695
C.addSynthesizedDecl (aliasDecl);
586
696
};
587
697
588
- // Diagnose stored properties in the nominal that do not have an explicit
589
- // `@noDerivative` attribute, but either:
590
- // - Do not conform to `Differentiable`.
591
- // - Are a `let` stored property.
592
- // Emit a warning and a fixit so that users will make the attribute explicit.
698
+ // / Diagnose stored properties in the nominal that do not have an explicit
699
+ // / `@noDerivative` attribute, but either:
700
+ // / - Do not conform to `Differentiable`.
701
+ // / - Are a `let` stored property.
702
+ // / Emit a warning and a fixit so that users will make the attribute explicit.
593
703
static void checkAndDiagnoseImplicitNoDerivative (TypeChecker &TC,
594
704
NominalTypeDecl *nominal,
595
705
DeclContext* DC) {
@@ -637,7 +747,7 @@ static void checkAndDiagnoseImplicitNoDerivative(TypeChecker &TC,
637
747
}
638
748
}
639
749
640
- // Get or synthesize `TangentVector` struct type.
750
+ // / Get or synthesize `TangentVector` struct type.
641
751
static Type
642
752
getOrSynthesizeTangentVectorStructType (DerivedConformance &derived) {
643
753
auto &TC = derived.TC ;
@@ -669,7 +779,7 @@ getOrSynthesizeTangentVectorStructType(DerivedConformance &derived) {
669
779
tangentStruct->getDeclaredInterfaceType ());
670
780
}
671
781
672
- // Synthesize the `TangentVector` struct type.
782
+ // / Synthesize the `TangentVector` struct type.
673
783
static Type
674
784
deriveDifferentiable_TangentVectorStruct (DerivedConformance &derived) {
675
785
auto &TC = derived.TC ;
@@ -756,3 +866,18 @@ Type DerivedConformance::deriveDifferentiable(AssociatedTypeDecl *requirement) {
756
866
TC.diagnose (requirement->getLoc (), diag::broken_differentiable_requirement);
757
867
return nullptr ;
758
868
}
869
+
870
+ // / Derive a EuclideanDifferentiable requirement for a nominal type.
871
+ // /
872
+ // / \returns the derived member, which will also be added to the type.
873
+ ValueDecl *DerivedConformance::deriveEuclideanDifferentiable (
874
+ ValueDecl *requirement) {
875
+ // Diagnose conformances in disallowed contexts.
876
+ if (checkAndDiagnoseDisallowedContext (requirement))
877
+ return nullptr ;
878
+ if (requirement->getFullName () == TC.Context .Id_vectorView )
879
+ return deriveEuclideanDifferentiable_vectorView (*this );
880
+ TC.diagnose (requirement->getLoc (),
881
+ diag::broken_euclidean_differentiable_requirement);
882
+ return nullptr ;
883
+ }
0 commit comments