47
47
#include " swift/SILOptimizer/Utils/SILOptFunctionBuilder.h"
48
48
#include " llvm/ADT/APSInt.h"
49
49
#include " llvm/ADT/BreadthFirstIterator.h"
50
+ #include " llvm/ADT/DenseMap.h"
50
51
#include " llvm/ADT/DenseSet.h"
51
52
#include " llvm/ADT/SmallSet.h"
52
53
#include " llvm/Support/CommandLine.h"
@@ -84,6 +85,9 @@ class DifferentiationTransformer {
84
85
// / Context necessary for performing the transformations.
85
86
ADContext context;
86
87
88
+ // / Cache used in getUnwrappedCurryThunkFunction.
89
+ llvm::DenseMap<AbstractFunctionDecl *, SILFunction *> afdToSILFn;
90
+
87
91
// / Promotes the given `differentiable_function` instruction to a valid
88
92
// / `@differentiable` function-typed value.
89
93
SILValue promoteToDifferentiableFunction (DifferentiableFunctionInst *inst,
@@ -96,6 +100,25 @@ class DifferentiationTransformer {
96
100
SILBuilder &builder, SILLocation loc,
97
101
DifferentiationInvoker invoker);
98
102
103
+ // / Emits a reference to a derivative function of `original`, differentiated
104
+ // / with respect to a superset of `desiredIndices`. Returns the `SILValue` for
105
+ // / the derivative function and the actual indices that the derivative
106
+ // / function is with respect to.
107
+ // /
108
+ // / Returns `None` on failure, signifying that a diagnostic has been emitted
109
+ // / using `invoker`.
110
+ std::optional<std::pair<SILValue, AutoDiffConfig>>
111
+ emitDerivativeFunctionReference (
112
+ SILBuilder &builder, const AutoDiffConfig &desiredConfig,
113
+ AutoDiffDerivativeFunctionKind kind, SILValue original,
114
+ DifferentiationInvoker invoker,
115
+ SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc);
116
+
117
+ // / If the given function corresponds to AutoClosureExpr with either
118
+ // / SingleCurryThunk or DoubleCurryThunk kind, get the SILFunction
119
+ // / corresponding to the function being wrapped in the thunk.
120
+ SILFunction *getUnwrappedCurryThunkFunction (SILFunction *originalFn);
121
+
99
122
public:
100
123
// / Construct an `DifferentiationTransformer` for the given module.
101
124
explicit DifferentiationTransformer (SILModuleTransform &transform)
@@ -453,21 +476,55 @@ static SILValue reapplyFunctionConversion(
453
476
llvm_unreachable (" Unhandled function conversion instruction" );
454
477
}
455
478
456
- // / Emits a reference to a derivative function of `original`, differentiated
457
- // / with respect to a superset of `desiredIndices`. Returns the `SILValue` for
458
- // / the derivative function and the actual indices that the derivative function
459
- // / is with respect to.
460
- // /
461
- // / Returns `None` on failure, signifying that a diagnostic has been emitted
462
- // / using `invoker`.
463
- static std::optional<std::pair<SILValue, AutoDiffConfig>>
464
- emitDerivativeFunctionReference (
465
- DifferentiationTransformer &transformer, SILBuilder &builder,
466
- const AutoDiffConfig &desiredConfig, AutoDiffDerivativeFunctionKind kind,
467
- SILValue original, DifferentiationInvoker invoker,
468
- SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc) {
469
- ADContext &context = transformer.getContext ();
479
+ SILFunction *DifferentiationTransformer::getUnwrappedCurryThunkFunction (
480
+ SILFunction *originalFn) {
481
+ auto *abstractCE = originalFn->getDeclRef ().getAbstractClosureExpr ();
482
+ if (abstractCE == nullptr )
483
+ return nullptr ;
484
+ auto *autoCE = dyn_cast<AutoClosureExpr>(abstractCE);
485
+ if (autoCE == nullptr )
486
+ return nullptr ;
487
+
488
+ auto *afd =
489
+ cast<AbstractFunctionDecl>(autoCE->getUnwrappedCurryThunkCalledValue ());
490
+
491
+ auto silFnIt = afdToSILFn.find (afd);
492
+ if (silFnIt == afdToSILFn.end ()) {
493
+ assert (afdToSILFn.empty ());
494
+
495
+ SILModule *module = getTransform ().getModule ();
496
+ for (SILFunction ¤tFunc : module ->getFunctions ()) {
497
+ if (auto *currentAFD =
498
+ currentFunc.getDeclRef ().getAbstractFunctionDecl ()) {
499
+ // Update cache only with AFDs which might be potentially wrapped by a
500
+ // curry thunk. This includes member function references and references
501
+ // to functions having external property wrapper parameters (see
502
+ // ExprRewriter::buildDeclRef). If new use cases of curry thunks appear
503
+ // in future, the assertion after the loop will be a trigger for such
504
+ // cases being unhandled here.
505
+ //
506
+ // FIXME: References to functions having external property wrapper
507
+ // parameters are not handled since we can't now construct a test case
508
+ // for that due to the crash
509
+ // https://github.com/swiftlang/swift/issues/77613
510
+ if (currentAFD->hasCurriedSelf ())
511
+ afdToSILFn.insert ({currentAFD, ¤tFunc});
512
+ }
513
+ }
470
514
515
+ silFnIt = afdToSILFn.find (afd);
516
+ assert (silFnIt != afdToSILFn.end ());
517
+ }
518
+
519
+ return silFnIt->second ;
520
+ }
521
+
522
+ std::optional<std::pair<SILValue, AutoDiffConfig>>
523
+ DifferentiationTransformer::emitDerivativeFunctionReference (
524
+ SILBuilder &builder, const AutoDiffConfig &desiredConfig,
525
+ AutoDiffDerivativeFunctionKind kind, SILValue original,
526
+ DifferentiationInvoker invoker,
527
+ SmallVectorImpl<AllocStackInst *> &newBuffersToDealloc) {
471
528
// If `original` is itself an `DifferentiableFunctionExtractInst` whose kind
472
529
// matches the given kind and desired differentiation parameter indices,
473
530
// simply extract the derivative function of its function operand, retain the
@@ -610,26 +667,36 @@ emitDerivativeFunctionReference(
610
667
DifferentiabilityKind::Reverse, desiredParameterIndices,
611
668
desiredResultIndices, derivativeConstrainedGenSig, /* jvp*/ nullptr ,
612
669
/* vjp*/ nullptr , /* isSerialized*/ false );
613
- if (transformer. canonicalizeDifferentiabilityWitness (
614
- minimalWitness, invoker, IsNotSerialized))
670
+ if (canonicalizeDifferentiabilityWitness (minimalWitness, invoker,
671
+ IsNotSerialized))
615
672
return std::nullopt;
616
673
}
617
674
assert (minimalWitness);
618
- if (original->getFunction ()->isSerialized () &&
619
- !hasPublicVisibility (minimalWitness->getLinkage ())) {
620
- enum { Inlinable = 0 , DefaultArgument = 1 };
621
- unsigned fragileKind = Inlinable;
622
- // FIXME: This is not a very robust way of determining if the function is
623
- // a default argument. Also, we have not exhaustively listed all the kinds
624
- // of fragility.
625
- if (original->getFunction ()->getLinkage () == SILLinkage::PublicNonABI)
626
- fragileKind = DefaultArgument;
627
- context.emitNondifferentiabilityError (
628
- original, invoker, diag::autodiff_private_derivative_from_fragile,
629
- fragileKind,
630
- isa_and_nonnull<AbstractClosureExpr>(
631
- originalFRI->getLoc ().getAsASTNode <Expr>()));
632
- return std::nullopt;
675
+ if (original->getFunction ()->isSerialized ()) {
676
+ // When dealing with curry thunk, look at the function being wrapped
677
+ // inside implicit closure. If it has public visibility, the corresponding
678
+ // differentiability witness also has public visibility. It should be OK
679
+ // for implicit wrapper closure and its witness to have private linkage.
680
+ SILFunction *unwrappedFn = getUnwrappedCurryThunkFunction (originalFn);
681
+ bool isWitnessPublic =
682
+ unwrappedFn == nullptr
683
+ ? hasPublicVisibility (minimalWitness->getLinkage ())
684
+ : hasPublicVisibility (unwrappedFn->getLinkage ());
685
+ if (!isWitnessPublic) {
686
+ enum { Inlinable = 0 , DefaultArgument = 1 };
687
+ unsigned fragileKind = Inlinable;
688
+ // FIXME: This is not a very robust way of determining if the function
689
+ // is a default argument. Also, we have not exhaustively listed all the
690
+ // kinds of fragility.
691
+ if (original->getFunction ()->getLinkage () == SILLinkage::PublicNonABI)
692
+ fragileKind = DefaultArgument;
693
+ context.emitNondifferentiabilityError (
694
+ original, invoker, diag::autodiff_private_derivative_from_fragile,
695
+ fragileKind,
696
+ isa_and_nonnull<AbstractClosureExpr>(
697
+ originalFRI->getLoc ().getAsASTNode <Expr>()));
698
+ return std::nullopt;
699
+ }
633
700
}
634
701
// TODO(TF-482): Move generic requirement checking logic to
635
702
// `getExactDifferentiabilityWitness` and
@@ -1121,8 +1188,8 @@ SILValue DifferentiationTransformer::promoteToDifferentiableFunction(
1121
1188
for (auto derivativeFnKind : {AutoDiffDerivativeFunctionKind::JVP,
1122
1189
AutoDiffDerivativeFunctionKind::VJP}) {
1123
1190
auto derivativeFnAndIndices = emitDerivativeFunctionReference (
1124
- * this , builder, desiredConfig, derivativeFnKind, origFnOperand,
1125
- invoker, newBuffersToDealloc);
1191
+ builder, desiredConfig, derivativeFnKind, origFnOperand, invoker ,
1192
+ newBuffersToDealloc);
1126
1193
// Show an error at the operator, highlight the argument, and show a note
1127
1194
// at the definition site of the argument.
1128
1195
if (!derivativeFnAndIndices)
0 commit comments