@@ -408,32 +408,6 @@ static CanGenericSignature buildDifferentiableGenericSignature(CanGenericSignatu
408
408
return buildGenericSignature (ctx, sig, {}, reqs).getCanonicalSignature ();
409
409
}
410
410
411
- // / Given an original type, computes its tangent type for the purpose of
412
- // / building a linear map using this type. When the original type is an
413
- // / archetype or contains a type parameter, appends a new generic parameter and
414
- // / a corresponding replacement type to the given containers.
415
- static CanType getAutoDiffTangentTypeForLinearMap (
416
- Type originalType,
417
- LookupConformanceFn lookupConformance,
418
- SmallVectorImpl<GenericTypeParamType *> &substGenericParams,
419
- SmallVectorImpl<Type> &substReplacements,
420
- ASTContext &context
421
- ) {
422
- auto maybeTanType = originalType->getAutoDiffTangentSpace (lookupConformance);
423
- assert (maybeTanType && " Type does not have a tangent space?" );
424
- auto tanType = maybeTanType->getCanonicalType ();
425
- // If concrete, the tangent type is concrete.
426
- if (!tanType->hasArchetype () && !tanType->hasTypeParameter ())
427
- return tanType;
428
- // Otherwise, the tangent type is a new generic parameter substituted for the
429
- // tangent type.
430
- auto gpIndex = substGenericParams.size ();
431
- auto gpType = CanGenericTypeParamType::get (0 , gpIndex, context);
432
- substGenericParams.push_back (gpType);
433
- substReplacements.push_back (tanType);
434
- return gpType;
435
- }
436
-
437
411
// / Returns the differential type for the given original function type,
438
412
// / parameter indices, and result index.
439
413
static CanSILFunctionType getAutoDiffDifferentialType (
@@ -510,32 +484,45 @@ static CanSILFunctionType getAutoDiffDifferentialType(
510
484
getDifferentiabilityParameters (originalFnTy, parameterIndices, diffParams);
511
485
SmallVector<SILParameterInfo, 8 > differentialParams;
512
486
for (auto ¶m : diffParams) {
513
- auto paramTanType = getAutoDiffTangentTypeForLinearMap (
514
- param.getInterfaceType (), lookupConformance,
515
- substGenericParams, substReplacements, ctx);
516
- auto paramConv = getTangentParameterConvention (
517
- // FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
518
- param.getInterfaceType ()
519
- ->getAutoDiffTangentSpace (lookupConformance)
520
- ->getCanonicalType (),
521
- param.getConvention ());
522
- differentialParams.push_back ({paramTanType, paramConv});
487
+ auto paramTan =
488
+ param.getInterfaceType ()->getAutoDiffTangentSpace (lookupConformance);
489
+ assert (paramTan && " Parameter type does not have a tangent space?" );
490
+ auto paramTanType = paramTan->getCanonicalType ();
491
+ auto paramConv = getTangentParameterConvention (paramTanType,
492
+ param.getConvention ());
493
+ if (!paramTanType->hasArchetype () && !paramTanType->hasTypeParameter ()) {
494
+ differentialParams.push_back (
495
+ {paramTan->getCanonicalType (), paramConv});
496
+ } else {
497
+ auto gpIndex = substGenericParams.size ();
498
+ auto gpType = CanGenericTypeParamType::get (0 , gpIndex, ctx);
499
+ substGenericParams.push_back (gpType);
500
+ substReplacements.push_back (paramTanType);
501
+ differentialParams.push_back ({gpType, paramConv});
502
+ }
523
503
}
524
504
SmallVector<SILResultInfo, 1 > differentialResults;
525
505
for (auto resultIndex : resultIndices->getIndices ()) {
526
506
// Handle formal original result.
527
507
if (resultIndex < originalFnTy->getNumResults ()) {
528
508
auto &result = originalResults[resultIndex];
529
- auto resultTanType = getAutoDiffTangentTypeForLinearMap (
530
- result.getInterfaceType (), lookupConformance,
531
- substGenericParams, substReplacements, ctx);
532
- auto resultConv = getTangentResultConvention (
533
- // FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
534
- result.getInterfaceType ()
535
- ->getAutoDiffTangentSpace (lookupConformance)
536
- ->getCanonicalType (),
537
- result.getConvention ());
538
- differentialResults.push_back ({resultTanType, resultConv});
509
+ auto resultTan =
510
+ result.getInterfaceType ()->getAutoDiffTangentSpace (lookupConformance);
511
+ assert (resultTan && " Result type does not have a tangent space?" );
512
+ auto resultTanType = resultTan->getCanonicalType ();
513
+ auto resultConv =
514
+ getTangentResultConvention (resultTanType, result.getConvention ());
515
+ if (!resultTanType->hasArchetype () &&
516
+ !resultTanType->hasTypeParameter ()) {
517
+ differentialResults.push_back (
518
+ {resultTan->getCanonicalType (), resultConv});
519
+ } else {
520
+ auto gpIndex = substGenericParams.size ();
521
+ auto gpType = CanGenericTypeParamType::get (0 , gpIndex, ctx);
522
+ substGenericParams.push_back (gpType);
523
+ substReplacements.push_back (resultTanType);
524
+ differentialResults.push_back ({gpType, resultConv});
525
+ }
539
526
continue ;
540
527
}
541
528
// Handle original `inout` parameter.
@@ -550,11 +537,11 @@ static CanSILFunctionType getAutoDiffDifferentialType(
550
537
if (parameterIndices->contains (paramIndex))
551
538
continue ;
552
539
auto inoutParam = originalFnTy->getParameters ()[paramIndex];
553
- auto inoutParamTanType = getAutoDiffTangentTypeForLinearMap (
554
- inoutParam. getInterfaceType (), lookupConformance,
555
- substGenericParams, substReplacements, ctx );
540
+ auto paramTan = inoutParam. getInterfaceType ()-> getAutoDiffTangentSpace (
541
+ lookupConformance);
542
+ assert (paramTan && " Parameter type does not have a tangent space? " );
556
543
differentialResults.push_back (
557
- {inoutParamTanType , ResultConvention::Indirect});
544
+ {paramTan-> getCanonicalType () , ResultConvention::Indirect});
558
545
}
559
546
560
547
SubstitutionMap substitutions;
@@ -661,16 +648,23 @@ static CanSILFunctionType getAutoDiffPullbackType(
661
648
// Handle formal original result.
662
649
if (resultIndex < originalFnTy->getNumResults ()) {
663
650
auto &origRes = originalResults[resultIndex];
664
- auto resultTanType = getAutoDiffTangentTypeForLinearMap (
665
- origRes.getInterfaceType (), lookupConformance,
666
- substGenericParams, substReplacements, ctx);
667
- auto paramConv = getTangentParameterConventionForOriginalResult (
668
- // FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
669
- origRes.getInterfaceType ()
670
- ->getAutoDiffTangentSpace (lookupConformance)
671
- ->getCanonicalType (),
672
- origRes.getConvention ());
673
- pullbackParams.push_back ({resultTanType, paramConv});
651
+ auto resultTan = origRes.getInterfaceType ()->getAutoDiffTangentSpace (
652
+ lookupConformance);
653
+ assert (resultTan && " Result type does not have a tangent space?" );
654
+ auto resultTanType = resultTan->getCanonicalType ();
655
+ auto paramTanConvention = getTangentParameterConventionForOriginalResult (
656
+ resultTanType, origRes.getConvention ());
657
+ if (!resultTanType->hasArchetype () &&
658
+ !resultTanType->hasTypeParameter ()) {
659
+ auto resultTanType = resultTan->getCanonicalType ();
660
+ pullbackParams.push_back ({resultTanType, paramTanConvention});
661
+ } else {
662
+ auto gpIndex = substGenericParams.size ();
663
+ auto gpType = CanGenericTypeParamType::get (0 , gpIndex, ctx);
664
+ substGenericParams.push_back (gpType);
665
+ substReplacements.push_back (resultTanType);
666
+ pullbackParams.push_back ({gpType, paramTanConvention});
667
+ }
674
668
continue ;
675
669
}
676
670
// Handle original `inout` parameter.
@@ -680,18 +674,28 @@ static CanSILFunctionType getAutoDiffPullbackType(
680
674
auto paramIndex =
681
675
std::distance (originalFnTy->getParameters ().begin (), &*inoutParamIt);
682
676
auto inoutParam = originalFnTy->getParameters ()[paramIndex];
677
+ auto paramTan = inoutParam.getInterfaceType ()->getAutoDiffTangentSpace (
678
+ lookupConformance);
679
+ assert (paramTan && " Parameter type does not have a tangent space?" );
683
680
// The pullback parameter convention depends on whether the original `inout`
684
681
// paramater is a differentiability parameter.
685
682
// - If yes, the pullback parameter convention is `@inout`.
686
683
// - If no, the pullback parameter convention is `@in_guaranteed`.
687
- auto inoutParamTanType = getAutoDiffTangentTypeForLinearMap (
688
- inoutParam.getInterfaceType (), lookupConformance,
689
- substGenericParams, substReplacements, ctx);
690
684
bool isWrtInoutParameter = parameterIndices->contains (paramIndex);
691
685
auto paramTanConvention = isWrtInoutParameter
692
- ? inoutParam.getConvention ()
693
- : ParameterConvention::Indirect_In_Guaranteed;
694
- pullbackParams.push_back ({inoutParamTanType, paramTanConvention});
686
+ ? inoutParam.getConvention ()
687
+ : ParameterConvention::Indirect_In_Guaranteed;
688
+ auto paramTanType = paramTan->getCanonicalType ();
689
+ if (!paramTanType->hasArchetype () && !paramTanType->hasTypeParameter ()) {
690
+ pullbackParams.push_back (
691
+ SILParameterInfo (paramTanType, paramTanConvention));
692
+ } else {
693
+ auto gpIndex = substGenericParams.size ();
694
+ auto gpType = CanGenericTypeParamType::get (0 , gpIndex, ctx);
695
+ substGenericParams.push_back (gpType);
696
+ substReplacements.push_back (paramTanType);
697
+ pullbackParams.push_back ({gpType, paramTanConvention});
698
+ }
695
699
}
696
700
697
701
// Collect pullback results.
@@ -703,16 +707,21 @@ static CanSILFunctionType getAutoDiffPullbackType(
703
707
// and always appear as pullback parameters.
704
708
if (param.isIndirectInOut ())
705
709
continue ;
706
- auto paramTanType = getAutoDiffTangentTypeForLinearMap (
707
- param.getInterfaceType (), lookupConformance,
708
- substGenericParams, substReplacements, ctx);
710
+ auto paramTan =
711
+ param.getInterfaceType ()->getAutoDiffTangentSpace (lookupConformance);
712
+ assert (paramTan && " Parameter type does not have a tangent space?" );
713
+ auto paramTanType = paramTan->getCanonicalType ();
709
714
auto resultTanConvention = getTangentResultConventionForOriginalParameter (
710
- // FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
711
- param.getInterfaceType ()
712
- ->getAutoDiffTangentSpace (lookupConformance)
713
- ->getCanonicalType (),
714
- param.getConvention ());
715
- pullbackResults.push_back ({paramTanType, resultTanConvention});
715
+ paramTanType, param.getConvention ());
716
+ if (!paramTanType->hasArchetype () && !paramTanType->hasTypeParameter ()) {
717
+ pullbackResults.push_back ({paramTanType, resultTanConvention});
718
+ } else {
719
+ auto gpIndex = substGenericParams.size ();
720
+ auto gpType = CanGenericTypeParamType::get (0 , gpIndex, ctx);
721
+ substGenericParams.push_back (gpType);
722
+ substReplacements.push_back (paramTanType);
723
+ pullbackResults.push_back ({gpType, resultTanConvention});
724
+ }
716
725
}
717
726
SubstitutionMap substitutions;
718
727
if (!substGenericParams.empty ()) {
0 commit comments