26
26
#include " swift/AST/Module.h"
27
27
#include " swift/AST/ModuleLoader.h"
28
28
#include " swift/AST/ProtocolConformance.h"
29
+ #include " swift/AST/TypeCheckRequests.h"
29
30
#include " swift/ClangImporter/ClangImporter.h"
30
31
#include " swift/SIL/SILModule.h"
31
32
#include " swift/SIL/SILType.h"
@@ -360,6 +361,41 @@ getSemanticResults(SILFunctionType *functionType, IndexSubset *parameterIndices,
360
361
IndexSubset::get (C, parameterIndices->getCapacity (), inoutParamIndices);
361
362
}
362
363
364
+ static CanGenericSignature buildDifferentiableGenericSignature (CanGenericSignature sig,
365
+ CanType tanType) {
366
+ if (!sig)
367
+ return sig;
368
+
369
+ llvm::DenseSet<CanType> types;
370
+
371
+ auto &ctx = tanType->getASTContext ();
372
+
373
+ (void ) tanType.findIf ([&](Type t) -> bool {
374
+ if (auto *dmt = t->getAs <DependentMemberType>()) {
375
+ if (dmt->getName () == ctx.Id_TangentVector )
376
+ types.insert (dmt->getBase ()->getCanonicalType ());
377
+ }
378
+
379
+ return false ;
380
+ });
381
+
382
+ SmallVector<Requirement, 2 > reqs;
383
+ auto *proto = ctx.getProtocol (KnownProtocolKind::Differentiable);
384
+ assert (proto != nullptr );
385
+
386
+ for (auto type : types) {
387
+ if (!sig->requiresProtocol (type, proto)) {
388
+ reqs.push_back (Requirement (RequirementKind::Conformance, type,
389
+ proto->getDeclaredInterfaceType ()));
390
+ }
391
+ }
392
+
393
+ return evaluateOrDefault (
394
+ ctx.evaluator ,
395
+ AbstractGenericSignatureRequest{sig.getPointer (), {}, reqs},
396
+ GenericSignature ()).getCanonicalSignature ();
397
+ }
398
+
363
399
// / Returns the differential type for the given original function type,
364
400
// / parameter indices, and result index.
365
401
static CanSILFunctionType getAutoDiffDifferentialType (
@@ -371,10 +407,11 @@ static CanSILFunctionType getAutoDiffDifferentialType(
371
407
auto getTangentParameterConvention =
372
408
[&](CanType tanType,
373
409
ParameterConvention origParamConv) -> ParameterConvention {
374
- tanType =
375
- tanType->getCanonicalType (originalFnTy->getSubstGenericSignature ());
376
- AbstractionPattern pattern (originalFnTy->getSubstGenericSignature (),
377
- tanType);
410
+ auto sig = buildDifferentiableGenericSignature (
411
+ originalFnTy->getSubstGenericSignature (), tanType);
412
+
413
+ tanType = tanType->getCanonicalType (sig);
414
+ AbstractionPattern pattern (sig, tanType);
378
415
auto &tl =
379
416
TC.getTypeLowering (pattern, tanType, TypeExpansionContext::minimal ());
380
417
// When the tangent type is address only, we must ensure that the tangent
@@ -398,10 +435,11 @@ static CanSILFunctionType getAutoDiffDifferentialType(
398
435
auto getTangentResultConvention =
399
436
[&](CanType tanType,
400
437
ResultConvention origResConv) -> ResultConvention {
401
- tanType =
402
- tanType->getCanonicalType (originalFnTy->getSubstGenericSignature ());
403
- AbstractionPattern pattern (originalFnTy->getSubstGenericSignature (),
404
- tanType);
438
+ auto sig = buildDifferentiableGenericSignature (
439
+ originalFnTy->getSubstGenericSignature (), tanType);
440
+
441
+ tanType = tanType->getCanonicalType (sig);
442
+ AbstractionPattern pattern (sig, tanType);
405
443
auto &tl =
406
444
TC.getTypeLowering (pattern, tanType, TypeExpansionContext::minimal ());
407
445
// When the tangent type is address only, we must ensure that the tangent
@@ -530,10 +568,11 @@ static CanSILFunctionType getAutoDiffPullbackType(
530
568
auto getTangentParameterConventionForOriginalResult =
531
569
[&](CanType tanType,
532
570
ResultConvention origResConv) -> ParameterConvention {
533
- tanType =
534
- tanType->getCanonicalType (originalFnTy->getSubstGenericSignature ());
535
- AbstractionPattern pattern (originalFnTy->getSubstGenericSignature (),
536
- tanType);
571
+ auto sig = buildDifferentiableGenericSignature (
572
+ originalFnTy->getSubstGenericSignature (), tanType);
573
+
574
+ tanType = tanType->getCanonicalType (sig);
575
+ AbstractionPattern pattern (sig, tanType);
537
576
auto &tl =
538
577
TC.getTypeLowering (pattern, tanType, TypeExpansionContext::minimal ());
539
578
ParameterConvention conv;
@@ -560,10 +599,11 @@ static CanSILFunctionType getAutoDiffPullbackType(
560
599
auto getTangentResultConventionForOriginalParameter =
561
600
[&](CanType tanType,
562
601
ParameterConvention origParamConv) -> ResultConvention {
563
- tanType =
564
- tanType->getCanonicalType (originalFnTy->getSubstGenericSignature ());
565
- AbstractionPattern pattern (originalFnTy->getSubstGenericSignature (),
566
- tanType);
602
+ auto sig = buildDifferentiableGenericSignature (
603
+ originalFnTy->getSubstGenericSignature (), tanType);
604
+
605
+ tanType = tanType->getCanonicalType (sig);
606
+ AbstractionPattern pattern (sig, tanType);
567
607
auto &tl =
568
608
TC.getTypeLowering (pattern, tanType, TypeExpansionContext::minimal ());
569
609
ResultConvention conv;
0 commit comments