@@ -396,6 +396,32 @@ static CanGenericSignature buildDifferentiableGenericSignature(CanGenericSignatu
396
396
GenericSignature ()).getCanonicalSignature ();
397
397
}
398
398
399
+ // / Given an original type, computes its tangent type for the purpose of
400
+ // / building a linear map using this type. When the original type is an
401
+ // / archetype or contains a type parameter, appends a new generic parameter and
402
+ // / a corresponding replacement type to the given containers.
403
+ static CanType getAutoDiffTangentTypeForLinearMap (
404
+ Type originalType,
405
+ LookupConformanceFn lookupConformance,
406
+ SmallVectorImpl<GenericTypeParamType *> &substGenericParams,
407
+ SmallVectorImpl<Type> &substReplacements,
408
+ ASTContext &context
409
+ ) {
410
+ auto maybeTanType = originalType->getAutoDiffTangentSpace (lookupConformance);
411
+ assert (maybeTanType && " Type does not have a tangent space?" );
412
+ auto tanType = maybeTanType->getCanonicalType ();
413
+ // If concrete, the tangent type is concrete.
414
+ if (!tanType->hasArchetype () && !tanType->hasTypeParameter ())
415
+ return tanType;
416
+ // Otherwise, the tangent type is a new generic parameter substituted for the
417
+ // tangent type.
418
+ auto gpIndex = substGenericParams.size ();
419
+ auto gpType = CanGenericTypeParamType::get (0 , gpIndex, context);
420
+ substGenericParams.push_back (gpType);
421
+ substReplacements.push_back (tanType);
422
+ return gpType;
423
+ }
424
+
399
425
// / Returns the differential type for the given original function type,
400
426
// / parameter indices, and result index.
401
427
static CanSILFunctionType getAutoDiffDifferentialType (
@@ -471,45 +497,32 @@ static CanSILFunctionType getAutoDiffDifferentialType(
471
497
getDifferentiabilityParameters (originalFnTy, parameterIndices, diffParams);
472
498
SmallVector<SILParameterInfo, 8 > differentialParams;
473
499
for (auto ¶m : diffParams) {
474
- auto paramTan =
475
- param.getInterfaceType ()->getAutoDiffTangentSpace (lookupConformance);
476
- assert (paramTan && " Parameter type does not have a tangent space?" );
477
- auto paramTanType = paramTan->getCanonicalType ();
478
- auto paramConv = getTangentParameterConvention (paramTanType,
479
- param.getConvention ());
480
- if (!paramTanType->hasArchetype () && !paramTanType->hasTypeParameter ()) {
481
- differentialParams.push_back (
482
- {paramTan->getCanonicalType (), paramConv});
483
- } else {
484
- auto gpIndex = substGenericParams.size ();
485
- auto gpType = CanGenericTypeParamType::get (0 , gpIndex, ctx);
486
- substGenericParams.push_back (gpType);
487
- substReplacements.push_back (paramTanType);
488
- differentialParams.push_back ({gpType, paramConv});
489
- }
500
+ auto paramTanType = getAutoDiffTangentTypeForLinearMap (
501
+ param.getInterfaceType (), lookupConformance,
502
+ substGenericParams, substReplacements, ctx);
503
+ auto paramConv = getTangentParameterConvention (
504
+ // FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
505
+ param.getInterfaceType ()
506
+ ->getAutoDiffTangentSpace (lookupConformance)
507
+ ->getCanonicalType (),
508
+ param.getConvention ());
509
+ differentialParams.push_back ({paramTanType, paramConv});
490
510
}
491
511
SmallVector<SILResultInfo, 1 > differentialResults;
492
512
for (auto resultIndex : resultIndices->getIndices ()) {
493
513
// Handle formal original result.
494
514
if (resultIndex < originalFnTy->getNumResults ()) {
495
515
auto &result = originalResults[resultIndex];
496
- auto resultTan =
497
- result.getInterfaceType ()->getAutoDiffTangentSpace (lookupConformance);
498
- assert (resultTan && " Result type does not have a tangent space?" );
499
- auto resultTanType = resultTan->getCanonicalType ();
500
- auto resultConv =
501
- getTangentResultConvention (resultTanType, result.getConvention ());
502
- if (!resultTanType->hasArchetype () &&
503
- !resultTanType->hasTypeParameter ()) {
504
- differentialResults.push_back (
505
- {resultTan->getCanonicalType (), resultConv});
506
- } else {
507
- auto gpIndex = substGenericParams.size ();
508
- auto gpType = CanGenericTypeParamType::get (0 , gpIndex, ctx);
509
- substGenericParams.push_back (gpType);
510
- substReplacements.push_back (resultTanType);
511
- differentialResults.push_back ({gpType, resultConv});
512
- }
516
+ auto resultTanType = getAutoDiffTangentTypeForLinearMap (
517
+ result.getInterfaceType (), lookupConformance,
518
+ substGenericParams, substReplacements, ctx);
519
+ auto resultConv = getTangentResultConvention (
520
+ // FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
521
+ result.getInterfaceType ()
522
+ ->getAutoDiffTangentSpace (lookupConformance)
523
+ ->getCanonicalType (),
524
+ result.getConvention ());
525
+ differentialResults.push_back ({resultTanType, resultConv});
513
526
continue ;
514
527
}
515
528
// Handle original `inout` parameter.
@@ -524,11 +537,11 @@ static CanSILFunctionType getAutoDiffDifferentialType(
524
537
if (parameterIndices->contains (paramIndex))
525
538
continue ;
526
539
auto inoutParam = originalFnTy->getParameters ()[paramIndex];
527
- auto paramTan = inoutParam. getInterfaceType ()-> getAutoDiffTangentSpace (
528
- lookupConformance);
529
- assert (paramTan && " Parameter type does not have a tangent space? " );
540
+ auto inoutParamTanType = getAutoDiffTangentTypeForLinearMap (
541
+ inoutParam. getInterfaceType (), lookupConformance,
542
+ substGenericParams, substReplacements, ctx );
530
543
differentialResults.push_back (
531
- {paramTan-> getCanonicalType () , ResultConvention::Indirect});
544
+ {inoutParamTanType , ResultConvention::Indirect});
532
545
}
533
546
534
547
SubstitutionMap substitutions;
@@ -635,23 +648,16 @@ static CanSILFunctionType getAutoDiffPullbackType(
635
648
// Handle formal original result.
636
649
if (resultIndex < originalFnTy->getNumResults ()) {
637
650
auto &origRes = originalResults[resultIndex];
638
- auto resultTan = origRes.getInterfaceType ()->getAutoDiffTangentSpace (
639
- lookupConformance);
640
- assert (resultTan && " Result type does not have a tangent space?" );
641
- auto resultTanType = resultTan->getCanonicalType ();
642
- auto paramTanConvention = getTangentParameterConventionForOriginalResult (
643
- resultTanType, origRes.getConvention ());
644
- if (!resultTanType->hasArchetype () &&
645
- !resultTanType->hasTypeParameter ()) {
646
- auto resultTanType = resultTan->getCanonicalType ();
647
- pullbackParams.push_back ({resultTanType, paramTanConvention});
648
- } else {
649
- auto gpIndex = substGenericParams.size ();
650
- auto gpType = CanGenericTypeParamType::get (0 , gpIndex, ctx);
651
- substGenericParams.push_back (gpType);
652
- substReplacements.push_back (resultTanType);
653
- pullbackParams.push_back ({gpType, paramTanConvention});
654
- }
651
+ auto resultTanType = getAutoDiffTangentTypeForLinearMap (
652
+ origRes.getInterfaceType (), lookupConformance,
653
+ substGenericParams, substReplacements, ctx);
654
+ auto paramConv = getTangentParameterConventionForOriginalResult (
655
+ // FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
656
+ origRes.getInterfaceType ()
657
+ ->getAutoDiffTangentSpace (lookupConformance)
658
+ ->getCanonicalType (),
659
+ origRes.getConvention ());
660
+ pullbackParams.push_back ({resultTanType, paramConv});
655
661
continue ;
656
662
}
657
663
// Handle original `inout` parameter.
@@ -661,28 +667,18 @@ static CanSILFunctionType getAutoDiffPullbackType(
661
667
auto paramIndex =
662
668
std::distance (originalFnTy->getParameters ().begin (), &*inoutParamIt);
663
669
auto inoutParam = originalFnTy->getParameters ()[paramIndex];
664
- auto paramTan = inoutParam.getInterfaceType ()->getAutoDiffTangentSpace (
665
- lookupConformance);
666
- assert (paramTan && " Parameter type does not have a tangent space?" );
667
670
// The pullback parameter convention depends on whether the original `inout`
668
671
// paramater is a differentiability parameter.
669
672
// - If yes, the pullback parameter convention is `@inout`.
670
673
// - If no, the pullback parameter convention is `@in_guaranteed`.
674
+ auto inoutParamTanType = getAutoDiffTangentTypeForLinearMap (
675
+ inoutParam.getInterfaceType (), lookupConformance,
676
+ substGenericParams, substReplacements, ctx);
671
677
bool isWrtInoutParameter = parameterIndices->contains (paramIndex);
672
678
auto paramTanConvention = isWrtInoutParameter
673
- ? inoutParam.getConvention ()
674
- : ParameterConvention::Indirect_In_Guaranteed;
675
- auto paramTanType = paramTan->getCanonicalType ();
676
- if (!paramTanType->hasArchetype () && !paramTanType->hasTypeParameter ()) {
677
- pullbackParams.push_back (
678
- SILParameterInfo (paramTanType, paramTanConvention));
679
- } else {
680
- auto gpIndex = substGenericParams.size ();
681
- auto gpType = CanGenericTypeParamType::get (0 , gpIndex, ctx);
682
- substGenericParams.push_back (gpType);
683
- substReplacements.push_back (paramTanType);
684
- pullbackParams.push_back ({gpType, paramTanConvention});
685
- }
679
+ ? inoutParam.getConvention ()
680
+ : ParameterConvention::Indirect_In_Guaranteed;
681
+ pullbackParams.push_back ({inoutParamTanType, paramTanConvention});
686
682
}
687
683
688
684
// Collect pullback results.
@@ -694,21 +690,16 @@ static CanSILFunctionType getAutoDiffPullbackType(
694
690
// and always appear as pullback parameters.
695
691
if (param.isIndirectInOut ())
696
692
continue ;
697
- auto paramTan =
698
- param.getInterfaceType ()->getAutoDiffTangentSpace (lookupConformance);
699
- assert (paramTan && " Parameter type does not have a tangent space?" );
700
- auto paramTanType = paramTan->getCanonicalType ();
693
+ auto paramTanType = getAutoDiffTangentTypeForLinearMap (
694
+ param.getInterfaceType (), lookupConformance,
695
+ substGenericParams, substReplacements, ctx);
701
696
auto resultTanConvention = getTangentResultConventionForOriginalParameter (
702
- paramTanType, param.getConvention ());
703
- if (!paramTanType->hasArchetype () && !paramTanType->hasTypeParameter ()) {
704
- pullbackResults.push_back ({paramTanType, resultTanConvention});
705
- } else {
706
- auto gpIndex = substGenericParams.size ();
707
- auto gpType = CanGenericTypeParamType::get (0 , gpIndex, ctx);
708
- substGenericParams.push_back (gpType);
709
- substReplacements.push_back (paramTanType);
710
- pullbackResults.push_back ({gpType, resultTanConvention});
711
- }
697
+ // FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
698
+ param.getInterfaceType ()
699
+ ->getAutoDiffTangentSpace (lookupConformance)
700
+ ->getCanonicalType (),
701
+ param.getConvention ());
702
+ pullbackResults.push_back ({paramTanType, resultTanConvention});
712
703
}
713
704
SubstitutionMap substitutions;
714
705
if (!substGenericParams.empty ()) {
0 commit comments