@@ -408,6 +408,32 @@ static CanGenericSignature buildDifferentiableGenericSignature(CanGenericSignatu
408
408
return buildGenericSignature (ctx, sig, {}, reqs).getCanonicalSignature ();
409
409
}
410
410
411
+ // / Given an original type, computes its tangent type for the purpose of
412
+ // / building a linear map using this type. When the original type is an
413
+ // / archetype or contains a type parameter, appends a new generic parameter and
414
+ // / a corresponding replacement type to the given containers.
415
+ static CanType getAutoDiffTangentTypeForLinearMap (
416
+ Type originalType,
417
+ LookupConformanceFn lookupConformance,
418
+ SmallVectorImpl<GenericTypeParamType *> &substGenericParams,
419
+ SmallVectorImpl<Type> &substReplacements,
420
+ ASTContext &context
421
+ ) {
422
+ auto maybeTanType = originalType->getAutoDiffTangentSpace (lookupConformance);
423
+ assert (maybeTanType && " Type does not have a tangent space?" );
424
+ auto tanType = maybeTanType->getCanonicalType ();
425
+ // If concrete, the tangent type is concrete.
426
+ if (!tanType->hasArchetype () && !tanType->hasTypeParameter ())
427
+ return tanType;
428
+ // Otherwise, the tangent type is a new generic parameter substituted for the
429
+ // tangent type.
430
+ auto gpIndex = substGenericParams.size ();
431
+ auto gpType = CanGenericTypeParamType::get (0 , gpIndex, context);
432
+ substGenericParams.push_back (gpType);
433
+ substReplacements.push_back (tanType);
434
+ return gpType;
435
+ }
436
+
411
437
// / Returns the differential type for the given original function type,
412
438
// / parameter indices, and result index.
413
439
static CanSILFunctionType getAutoDiffDifferentialType (
@@ -484,45 +510,32 @@ static CanSILFunctionType getAutoDiffDifferentialType(
484
510
getDifferentiabilityParameters (originalFnTy, parameterIndices, diffParams);
485
511
SmallVector<SILParameterInfo, 8 > differentialParams;
486
512
for (auto ¶m : diffParams) {
487
- auto paramTan =
488
- param.getInterfaceType ()->getAutoDiffTangentSpace (lookupConformance);
489
- assert (paramTan && " Parameter type does not have a tangent space?" );
490
- auto paramTanType = paramTan->getCanonicalType ();
491
- auto paramConv = getTangentParameterConvention (paramTanType,
492
- param.getConvention ());
493
- if (!paramTanType->hasArchetype () && !paramTanType->hasTypeParameter ()) {
494
- differentialParams.push_back (
495
- {paramTan->getCanonicalType (), paramConv});
496
- } else {
497
- auto gpIndex = substGenericParams.size ();
498
- auto gpType = CanGenericTypeParamType::get (0 , gpIndex, ctx);
499
- substGenericParams.push_back (gpType);
500
- substReplacements.push_back (paramTanType);
501
- differentialParams.push_back ({gpType, paramConv});
502
- }
513
+ auto paramTanType = getAutoDiffTangentTypeForLinearMap (
514
+ param.getInterfaceType (), lookupConformance,
515
+ substGenericParams, substReplacements, ctx);
516
+ auto paramConv = getTangentParameterConvention (
517
+ // FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
518
+ param.getInterfaceType ()
519
+ ->getAutoDiffTangentSpace (lookupConformance)
520
+ ->getCanonicalType (),
521
+ param.getConvention ());
522
+ differentialParams.push_back ({paramTanType, paramConv});
503
523
}
504
524
SmallVector<SILResultInfo, 1 > differentialResults;
505
525
for (auto resultIndex : resultIndices->getIndices ()) {
506
526
// Handle formal original result.
507
527
if (resultIndex < originalFnTy->getNumResults ()) {
508
528
auto &result = originalResults[resultIndex];
509
- auto resultTan =
510
- result.getInterfaceType ()->getAutoDiffTangentSpace (lookupConformance);
511
- assert (resultTan && " Result type does not have a tangent space?" );
512
- auto resultTanType = resultTan->getCanonicalType ();
513
- auto resultConv =
514
- getTangentResultConvention (resultTanType, result.getConvention ());
515
- if (!resultTanType->hasArchetype () &&
516
- !resultTanType->hasTypeParameter ()) {
517
- differentialResults.push_back (
518
- {resultTan->getCanonicalType (), resultConv});
519
- } else {
520
- auto gpIndex = substGenericParams.size ();
521
- auto gpType = CanGenericTypeParamType::get (0 , gpIndex, ctx);
522
- substGenericParams.push_back (gpType);
523
- substReplacements.push_back (resultTanType);
524
- differentialResults.push_back ({gpType, resultConv});
525
- }
529
+ auto resultTanType = getAutoDiffTangentTypeForLinearMap (
530
+ result.getInterfaceType (), lookupConformance,
531
+ substGenericParams, substReplacements, ctx);
532
+ auto resultConv = getTangentResultConvention (
533
+ // FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
534
+ result.getInterfaceType ()
535
+ ->getAutoDiffTangentSpace (lookupConformance)
536
+ ->getCanonicalType (),
537
+ result.getConvention ());
538
+ differentialResults.push_back ({resultTanType, resultConv});
526
539
continue ;
527
540
}
528
541
// Handle original `inout` parameter.
@@ -537,11 +550,11 @@ static CanSILFunctionType getAutoDiffDifferentialType(
537
550
if (parameterIndices->contains (paramIndex))
538
551
continue ;
539
552
auto inoutParam = originalFnTy->getParameters ()[paramIndex];
540
- auto paramTan = inoutParam. getInterfaceType ()-> getAutoDiffTangentSpace (
541
- lookupConformance);
542
- assert (paramTan && " Parameter type does not have a tangent space? " );
553
+ auto inoutParamTanType = getAutoDiffTangentTypeForLinearMap (
554
+ inoutParam. getInterfaceType (), lookupConformance,
555
+ substGenericParams, substReplacements, ctx );
543
556
differentialResults.push_back (
544
- {paramTan-> getCanonicalType () , ResultConvention::Indirect});
557
+ {inoutParamTanType , ResultConvention::Indirect});
545
558
}
546
559
547
560
SubstitutionMap substitutions;
@@ -648,23 +661,16 @@ static CanSILFunctionType getAutoDiffPullbackType(
648
661
// Handle formal original result.
649
662
if (resultIndex < originalFnTy->getNumResults ()) {
650
663
auto &origRes = originalResults[resultIndex];
651
- auto resultTan = origRes.getInterfaceType ()->getAutoDiffTangentSpace (
652
- lookupConformance);
653
- assert (resultTan && " Result type does not have a tangent space?" );
654
- auto resultTanType = resultTan->getCanonicalType ();
655
- auto paramTanConvention = getTangentParameterConventionForOriginalResult (
656
- resultTanType, origRes.getConvention ());
657
- if (!resultTanType->hasArchetype () &&
658
- !resultTanType->hasTypeParameter ()) {
659
- auto resultTanType = resultTan->getCanonicalType ();
660
- pullbackParams.push_back ({resultTanType, paramTanConvention});
661
- } else {
662
- auto gpIndex = substGenericParams.size ();
663
- auto gpType = CanGenericTypeParamType::get (0 , gpIndex, ctx);
664
- substGenericParams.push_back (gpType);
665
- substReplacements.push_back (resultTanType);
666
- pullbackParams.push_back ({gpType, paramTanConvention});
667
- }
664
+ auto resultTanType = getAutoDiffTangentTypeForLinearMap (
665
+ origRes.getInterfaceType (), lookupConformance,
666
+ substGenericParams, substReplacements, ctx);
667
+ auto paramConv = getTangentParameterConventionForOriginalResult (
668
+ // FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
669
+ origRes.getInterfaceType ()
670
+ ->getAutoDiffTangentSpace (lookupConformance)
671
+ ->getCanonicalType (),
672
+ origRes.getConvention ());
673
+ pullbackParams.push_back ({resultTanType, paramConv});
668
674
continue ;
669
675
}
670
676
// Handle original `inout` parameter.
@@ -674,28 +680,18 @@ static CanSILFunctionType getAutoDiffPullbackType(
674
680
auto paramIndex =
675
681
std::distance (originalFnTy->getParameters ().begin (), &*inoutParamIt);
676
682
auto inoutParam = originalFnTy->getParameters ()[paramIndex];
677
- auto paramTan = inoutParam.getInterfaceType ()->getAutoDiffTangentSpace (
678
- lookupConformance);
679
- assert (paramTan && " Parameter type does not have a tangent space?" );
680
683
// The pullback parameter convention depends on whether the original `inout`
681
684
// paramater is a differentiability parameter.
682
685
// - If yes, the pullback parameter convention is `@inout`.
683
686
// - If no, the pullback parameter convention is `@in_guaranteed`.
687
+ auto inoutParamTanType = getAutoDiffTangentTypeForLinearMap (
688
+ inoutParam.getInterfaceType (), lookupConformance,
689
+ substGenericParams, substReplacements, ctx);
684
690
bool isWrtInoutParameter = parameterIndices->contains (paramIndex);
685
691
auto paramTanConvention = isWrtInoutParameter
686
- ? inoutParam.getConvention ()
687
- : ParameterConvention::Indirect_In_Guaranteed;
688
- auto paramTanType = paramTan->getCanonicalType ();
689
- if (!paramTanType->hasArchetype () && !paramTanType->hasTypeParameter ()) {
690
- pullbackParams.push_back (
691
- SILParameterInfo (paramTanType, paramTanConvention));
692
- } else {
693
- auto gpIndex = substGenericParams.size ();
694
- auto gpType = CanGenericTypeParamType::get (0 , gpIndex, ctx);
695
- substGenericParams.push_back (gpType);
696
- substReplacements.push_back (paramTanType);
697
- pullbackParams.push_back ({gpType, paramTanConvention});
698
- }
692
+ ? inoutParam.getConvention ()
693
+ : ParameterConvention::Indirect_In_Guaranteed;
694
+ pullbackParams.push_back ({inoutParamTanType, paramTanConvention});
699
695
}
700
696
701
697
// Collect pullback results.
@@ -707,21 +703,16 @@ static CanSILFunctionType getAutoDiffPullbackType(
707
703
// and always appear as pullback parameters.
708
704
if (param.isIndirectInOut ())
709
705
continue ;
710
- auto paramTan =
711
- param.getInterfaceType ()->getAutoDiffTangentSpace (lookupConformance);
712
- assert (paramTan && " Parameter type does not have a tangent space?" );
713
- auto paramTanType = paramTan->getCanonicalType ();
706
+ auto paramTanType = getAutoDiffTangentTypeForLinearMap (
707
+ param.getInterfaceType (), lookupConformance,
708
+ substGenericParams, substReplacements, ctx);
714
709
auto resultTanConvention = getTangentResultConventionForOriginalParameter (
715
- paramTanType, param.getConvention ());
716
- if (!paramTanType->hasArchetype () && !paramTanType->hasTypeParameter ()) {
717
- pullbackResults.push_back ({paramTanType, resultTanConvention});
718
- } else {
719
- auto gpIndex = substGenericParams.size ();
720
- auto gpType = CanGenericTypeParamType::get (0 , gpIndex, ctx);
721
- substGenericParams.push_back (gpType);
722
- substReplacements.push_back (paramTanType);
723
- pullbackResults.push_back ({gpType, resultTanConvention});
724
- }
710
+ // FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
711
+ param.getInterfaceType ()
712
+ ->getAutoDiffTangentSpace (lookupConformance)
713
+ ->getCanonicalType (),
714
+ param.getConvention ());
715
+ pullbackResults.push_back ({paramTanType, resultTanConvention});
725
716
}
726
717
SubstitutionMap substitutions;
727
718
if (!substGenericParams.empty ()) {
0 commit comments