@@ -260,11 +260,13 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
260
260
if (isDiffParamIndex (valueAndIndex.index ()))
261
261
diffParams.push_back (valueAndIndex.value ());
262
262
263
- // Get the canonical derivative function generic signature.
263
+ // Get the "constrained" derivative function generic signature.
264
264
if (!derivativeFnGenSig)
265
265
derivativeFnGenSig = getSubstGenericSignature ();
266
- derivativeFnGenSig = autodiff::getConstrainedDerivativeGenericSignature (
267
- this , parameterIndices, derivativeFnGenSig).getCanonicalSignature ();
266
+ derivativeFnGenSig =
267
+ autodiff::getConstrainedDerivativeGenericSignature (
268
+ this , parameterIndices, derivativeFnGenSig, lookupConformance)
269
+ .getCanonicalSignature ();
268
270
269
271
// Given a type, returns its formal SIL parameter info.
270
272
auto getTangentParameterInfoForOriginalResult =
@@ -401,6 +403,92 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
401
403
ctx, getWitnessMethodConformanceOrInvalid ());
402
404
}
403
405
406
+ CanSILFunctionType SILFunctionType::getAutoDiffTransposeFunctionType (
407
+ IndexSubset *parameterIndices, Lowering::TypeConverter &TC,
408
+ LookupConformanceFn lookupConformance,
409
+ CanGenericSignature transposeFnGenSig) {
410
+ // Get the "constrained" transpose function generic signature.
411
+ if (!transposeFnGenSig)
412
+ transposeFnGenSig = getSubstGenericSignature ();
413
+ transposeFnGenSig = autodiff::getConstrainedDerivativeGenericSignature (
414
+ this , parameterIndices, transposeFnGenSig,
415
+ lookupConformance, /* isLinear*/ true )
416
+ .getCanonicalSignature ();
417
+
418
+ // Given a type, returns its formal SIL parameter info.
419
+ auto getParameterInfoForOriginalResult =
420
+ [&](const SILResultInfo &result) -> SILParameterInfo {
421
+ AbstractionPattern pattern (transposeFnGenSig, result.getInterfaceType ());
422
+ auto &tl = TC.getTypeLowering (pattern, result.getInterfaceType (),
423
+ TypeExpansionContext::minimal ());
424
+ ParameterConvention newConv;
425
+ switch (result.getConvention ()) {
426
+ case ResultConvention::Owned:
427
+ case ResultConvention::Autoreleased:
428
+ newConv = tl.isTrivial () ? ParameterConvention::Direct_Unowned
429
+ : ParameterConvention::Direct_Guaranteed;
430
+ break ;
431
+ case ResultConvention::Unowned:
432
+ case ResultConvention::UnownedInnerPointer:
433
+ newConv = ParameterConvention::Direct_Unowned;
434
+ break ;
435
+ case ResultConvention::Indirect:
436
+ newConv = ParameterConvention::Indirect_In_Guaranteed;
437
+ break ;
438
+ }
439
+ return {result.getInterfaceType ()->getCanonicalType (transposeFnGenSig),
440
+ newConv};
441
+ };
442
+
443
+ // Given a type, returns its formal SIL result info.
444
+ auto getResultInfoForOriginalParameter =
445
+ [&](const SILParameterInfo ¶m) -> SILResultInfo {
446
+ AbstractionPattern pattern (transposeFnGenSig, param.getInterfaceType ());
447
+ auto &tl = TC.getTypeLowering (pattern, param.getInterfaceType (),
448
+ TypeExpansionContext::minimal ());
449
+ ResultConvention newConv;
450
+ switch (param.getConvention ()) {
451
+ case ParameterConvention::Direct_Owned:
452
+ case ParameterConvention::Direct_Guaranteed:
453
+ case ParameterConvention::Direct_Unowned:
454
+ newConv =
455
+ tl.isTrivial () ? ResultConvention::Unowned : ResultConvention::Owned;
456
+ break ;
457
+ case ParameterConvention::Indirect_In:
458
+ case ParameterConvention::Indirect_Inout:
459
+ case ParameterConvention::Indirect_In_Constant:
460
+ case ParameterConvention::Indirect_In_Guaranteed:
461
+ case ParameterConvention::Indirect_InoutAliasable:
462
+ newConv = ResultConvention::Indirect;
463
+ break ;
464
+ }
465
+ return {param.getInterfaceType ()->getCanonicalType (transposeFnGenSig),
466
+ newConv};
467
+ };
468
+
469
+ SmallVector<SILParameterInfo, 4 > newParameters;
470
+ SmallVector<SILResultInfo, 4 > newResults;
471
+ for (auto param : llvm::enumerate (getParameters ())) {
472
+ if (parameterIndices->contains (param.index ()))
473
+ newResults.push_back (getResultInfoForOriginalParameter (param.value ()));
474
+ else
475
+ newParameters.push_back (param.value ());
476
+ }
477
+ for (auto &res : getResults ())
478
+ newParameters.push_back (getParameterInfoForOriginalResult (res));
479
+ // Transpose function type has a generic signature only if the original
480
+ // function type does, and if `transposeFnGenSig` does not have all concrete
481
+ // generic parameters.
482
+ CanGenericSignature canGenSig;
483
+ if (getSubstGenericSignature () && transposeFnGenSig &&
484
+ !transposeFnGenSig->areAllParamsConcrete ())
485
+ canGenSig = transposeFnGenSig;
486
+ return SILFunctionType::get (
487
+ canGenSig, getExtInfo (), getCoroutineKind (), getCalleeConvention (),
488
+ newParameters, getYields (), newResults, getOptionalErrorResult (),
489
+ getSubstitutions (), isGenericSignatureImplied (), getASTContext ());
490
+ }
491
+
404
492
static CanType getKnownType (Optional<CanType> &cacheSlot, ASTContext &C,
405
493
StringRef moduleName, StringRef typeName) {
406
494
if (!cacheSlot) {
0 commit comments