@@ -196,6 +196,33 @@ bool DerivedConformance::canDeriveDifferentiable(NominalTypeDecl *nominal,
196
196
});
197
197
}
198
198
199
+ // / Determine if a EuclideanDifferentiable requirement can be derived for a type.
200
+ // /
201
+ // / \returns True if the requirement can be derived.
202
+ bool DerivedConformance::canDeriveEuclideanDifferentiable (
203
+ NominalTypeDecl *nominal, DeclContext *DC) {
204
+ if (!canDeriveDifferentiable (nominal, DC))
205
+ return false ;
206
+ auto &C = nominal->getASTContext ();
207
+ auto *lazyResolver = C.getLazyResolver ();
208
+ auto *addArithProto = C.getProtocol (KnownProtocolKind::AdditiveArithmetic);
209
+ // Return true if all differentiation stored properties conform to
210
+ // `AdditiveArithmetic` and their `TangentVector` equals themselves.
211
+ SmallVector<VarDecl *, 16 > diffProperties;
212
+ getStoredPropertiesForDifferentiation (nominal, DC, diffProperties);
213
+ return llvm::all_of (diffProperties, [&](VarDecl *member) {
214
+ if (!member->hasInterfaceType ())
215
+ lazyResolver->resolveDeclSignature (member);
216
+ if (!member->hasInterfaceType ())
217
+ return false ;
218
+ auto varType = DC->mapTypeIntoContext (member->getValueInterfaceType ());
219
+ if (!TypeChecker::conformsToProtocol (varType, addArithProto, DC, None))
220
+ return false ;
221
+ auto memberAssocType = getTangentVectorType (member, DC);
222
+ return member->getType ()->isEqual (memberAssocType);
223
+ });
224
+ }
225
+
199
226
// Synthesize body for a `Differentiable` method requirement.
200
227
static std::pair<BraceStmt *, bool >
201
228
deriveBodyDifferentiable_method (AbstractFunctionDecl *funcDecl,
@@ -343,6 +370,90 @@ static ValueDecl *deriveDifferentiable_move(DerivedConformance &derived) {
343
370
{deriveBodyDifferentiable_move, nullptr });
344
371
}
345
372
373
+ // Synthesize the `vectorView` property declaration.
374
+ static ValueDecl *deriveEuclideanDifferentiable_vectorView (
375
+ DerivedConformance &derived) {
376
+ auto &C = derived.TC .Context ;
377
+ auto *parentDC = derived.getConformanceContext ();
378
+
379
+ auto *tangentDecl = getTangentVectorStructDecl (parentDC);
380
+ auto tangentType = tangentDecl->getDeclaredInterfaceType ();
381
+ auto tangentContextualType = parentDC->mapTypeIntoContext (tangentType);
382
+
383
+ VarDecl *vectorViewDecl;
384
+ PatternBindingDecl *pbDecl;
385
+ std::tie (vectorViewDecl, pbDecl) = derived.declareDerivedProperty (
386
+ C.Id_vectorView , tangentType, tangentContextualType, /* isStatic*/ false ,
387
+ /* isFinal*/ true );
388
+
389
+ struct GetterSynthesizerContext {
390
+ StructDecl *tangentDecl;
391
+ Type tangentContextualType;
392
+ };
393
+
394
+ auto getterSynthesizer = [](AbstractFunctionDecl *getterDecl, void *ctx)
395
+ -> std::pair<BraceStmt *, bool > {
396
+ auto *context = reinterpret_cast <GetterSynthesizerContext *>(ctx);
397
+ assert (context && " Invalid context" );
398
+ auto *parentDC = getterDecl->getParent ();
399
+ auto *nominal = parentDC->getSelfNominalTypeDecl ();
400
+ auto &C = nominal->getASTContext ();
401
+ SmallVector<VarDecl *, 8 > diffProperties;
402
+ getStoredPropertiesForDifferentiation (nominal, nominal->getDeclContext (),
403
+ diffProperties);
404
+
405
+ // Create a reference to the memberwise initializer: `TangentVector.init`.
406
+ auto *memberwiseInitDecl =
407
+ context->tangentDecl ->getEffectiveMemberwiseInitializer ();
408
+ assert (memberwiseInitDecl && " Memberwise initializer must exist" );
409
+ // `TangentVector`
410
+ auto *tangentTypeExpr =
411
+ TypeExpr::createImplicit (context->tangentContextualType , C);
412
+ // `TangentVector.init`
413
+ auto *initDRE = new (C) DeclRefExpr (memberwiseInitDecl, DeclNameLoc (),
414
+ /* Implicit*/ true );
415
+ initDRE->setFunctionRefKind (FunctionRefKind::SingleApply);
416
+ auto *initExpr = new (C) ConstructorRefCallExpr (initDRE, tangentTypeExpr);
417
+ initExpr->setThrows (false );
418
+ initExpr->setImplicit ();
419
+
420
+ // Create a call:
421
+ // TangentVector.init(
422
+ // <property_name_1...>: self.<property_name_1>,
423
+ // <property_name_2...>: self.<property_name_2>,
424
+ // ...
425
+ // )
426
+ SmallVector<Identifier, 8 > argLabels;
427
+ SmallVector<Expr *, 8 > memberRefs;
428
+ auto *selfDRE = new (C) DeclRefExpr (getterDecl->getImplicitSelfDecl (),
429
+ DeclNameLoc (),
430
+ /* Implicit*/ true );
431
+ for (auto *member : diffProperties) {
432
+ argLabels.push_back (member->getName ());
433
+ memberRefs.push_back (
434
+ new (C) MemberRefExpr (selfDRE, SourceLoc (), member, DeclNameLoc (),
435
+ /* Implicit*/ true ));
436
+ }
437
+ assert (memberRefs.size () == argLabels.size ());
438
+ CallExpr *callExpr =
439
+ CallExpr::createImplicit (C, initExpr, memberRefs, argLabels);
440
+
441
+ // Create a return statement: `return TangentVector.init(...)`.
442
+ ASTNode retStmt =
443
+ new (C) ReturnStmt (SourceLoc (), callExpr, /* implicit*/ true );
444
+ auto *braceStmt = BraceStmt::create (C, SourceLoc (), retStmt, SourceLoc (),
445
+ /* implicit*/ true );
446
+ return std::make_tuple (braceStmt, false );
447
+ };
448
+ auto *getterDecl = derived.addGetterToReadOnlyDerivedProperty (
449
+ vectorViewDecl, tangentContextualType);
450
+ getterDecl->setBodySynthesizer (
451
+ getterSynthesizer, /* context*/ C.AllocateObjectCopy (
452
+ GetterSynthesizerContext{tangentDecl, tangentContextualType}));
453
+ derived.addMembersToConformanceContext ({vectorViewDecl, pbDecl});
454
+ return vectorViewDecl;
455
+ }
456
+
346
457
// Return associated `TangentVector` struct for a nominal type, if it exists.
347
458
// If not, synthesize the struct.
348
459
static StructDecl *
@@ -362,8 +473,7 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
362
473
return structDecl;
363
474
}
364
475
365
- // Otherwise, synthesize a new struct. The struct must conform to
366
- // `Differentiable`.
476
+ // Otherwise, synthesize a new struct.
367
477
auto *diffableProto = C.getProtocol (KnownProtocolKind::Differentiable);
368
478
auto diffableType = TypeLoc::withoutLoc (diffableProto->getDeclaredType ());
369
479
auto *addArithProto = C.getProtocol (KnownProtocolKind::AdditiveArithmetic);
@@ -378,9 +488,9 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
378
488
auto *kpIterableProto = C.getProtocol (KnownProtocolKind::KeyPathIterable);
379
489
auto kpIterableType = TypeLoc::withoutLoc (kpIterableProto->getDeclaredType ());
380
490
381
- SmallVector<TypeLoc, 4 > inherited{diffableType};
382
- // `TangentVector` must conform to `AdditiveArithmetic`.
383
- inherited. push_back ( addArithType) ;
491
+ // By definition, `TangentVector` must conform to `EuclideanDifferentiable`
492
+ // and `AdditiveArithmetic`.
493
+ SmallVector<TypeLoc, 4 > inherited{diffableType, addArithType} ;
384
494
385
495
// Cache original members and their associated types for later use.
386
496
SmallVector<VarDecl *, 8 > diffProperties;
@@ -756,3 +866,18 @@ Type DerivedConformance::deriveDifferentiable(AssociatedTypeDecl *requirement) {
756
866
TC.diagnose (requirement->getLoc (), diag::broken_differentiable_requirement);
757
867
return nullptr ;
758
868
}
869
+
870
+ // / Derive a EuclideanDifferentiable requirement for a nominal type.
871
+ // /
872
+ // / \returns the derived member, which will also be added to the type.
873
+ ValueDecl *DerivedConformance::deriveEuclideanDifferentiable (
874
+ ValueDecl *requirement) {
875
+ // Diagnose conformances in disallowed contexts.
876
+ if (checkAndDiagnoseDisallowedContext (requirement))
877
+ return nullptr ;
878
+ if (requirement->getFullName () == TC.Context .Id_vectorView )
879
+ return deriveEuclideanDifferentiable_vectorView (*this );
880
+ TC.diagnose (requirement->getLoc (),
881
+ diag::broken_euclidean_differentiable_requirement);
882
+ return nullptr ;
883
+ }
0 commit comments