@@ -68,10 +68,10 @@ template <typename T> static inline void debugDump(T &v) {
68
68
<< v << " \n ==== END DEBUG DUMP ====\n " );
69
69
}
70
70
71
+ // / Returns true if the module we are compiling is in an LLDB REPL.
71
72
static bool isInLLDBREPL (SILModule &module ) {
72
- llvm::StringRef module_name = module .getSwiftModule ()->getNameStr ();
73
73
// TODO(SR-9704): Use a more prinicpled way to do this check.
74
- return module_name .startswith (" __lldb_expr_" );
74
+ return module . getSwiftModule ()-> getNameStr () .startswith (" __lldb_expr_" );
75
75
}
76
76
77
77
// / Creates arguments in the entry block based on the function type.
@@ -197,31 +197,6 @@ static CanType joinElementTypesFromValues(SILValueRange &&range,
197
197
return TupleType::get (elts, ctx)->getCanonicalType ();
198
198
}
199
199
200
- // / Looks through the definition of a function value. If the source that
201
- // / produced this function value is `function_ref` and the function is visible
202
- // / (either in the same module or is serialized), returns the instruction.
203
- // / Otherwise, returns null.
204
- static FunctionRefInst *findReferenceToVisibleFunction (SILValue value) {
205
- auto *inst = value->getDefiningInstruction ();
206
- if (!inst)
207
- return nullptr ;
208
- if (auto *fri = dyn_cast<FunctionRefInst>(inst)) {
209
- auto *fn = fri->getReferencedFunction ();
210
- if (&fn->getModule () == &inst->getModule () ||
211
- fn->isSerialized () == IsSerialized)
212
- return fri;
213
- }
214
- if (auto *thinToThick = dyn_cast<ThinToThickFunctionInst>(inst))
215
- return findReferenceToVisibleFunction (thinToThick->getOperand ());
216
- if (auto *convertFn = dyn_cast<ConvertFunctionInst>(inst))
217
- return findReferenceToVisibleFunction (convertFn->getOperand ());
218
- if (auto *convertFn = dyn_cast<ConvertEscapeToNoEscapeInst>(inst))
219
- return findReferenceToVisibleFunction (convertFn->getOperand ());
220
- if (auto *partialApply = dyn_cast<PartialApplyInst>(inst))
221
- return findReferenceToVisibleFunction (partialApply->getCallee ());
222
- return nullptr ;
223
- }
224
-
225
200
// / Given an operator name, such as "+", and a protocol, returns the
226
201
// / "+" operator with type `(Self, Self) -> Self`. If the operator does not
227
202
// / exist in the protocol, returns null.
@@ -693,6 +668,15 @@ class DifferentiationTask {
693
668
SILFunction *getJVP () const { return jvp; }
694
669
SILFunction *getVJP () const { return vjp; }
695
670
671
+ SILFunction *getAssociatedFunction (AutoDiffAssociatedFunctionKind kind) {
672
+ switch (kind) {
673
+ case AutoDiffAssociatedFunctionKind::JVP:
674
+ return jvp;
675
+ case AutoDiffAssociatedFunctionKind::VJP:
676
+ return vjp;
677
+ }
678
+ }
679
+
696
680
DenseMap<ApplyInst *, NestedApplyActivity> &getNestedApplyActivities () {
697
681
return nestedApplyActivities;
698
682
}
@@ -1581,21 +1565,18 @@ reapplyFunctionConversion(SILValue newFunc, SILValue oldFunc,
1581
1565
llvm_unreachable (" Unhandled function convertion instruction" );
1582
1566
}
1583
1567
1584
- // / Looks through function conversion instructions to find an underlying witness
1585
- // / method instruction. Returns `nullptr` if `value` does not come from a
1586
- // / `witness_method` or if there are unhandled conversion instructions between
1587
- // / `value` and the `witness_method`..
1588
- static WitnessMethodInst *findWitnessMethod (SILValue value) {
1589
- if (auto *witnessMethod = dyn_cast<WitnessMethodInst>(value))
1590
- return witnessMethod;
1568
+ template <class Inst >
1569
+ static Inst *peerThroughFunctionConversions (SILValue value) {
1570
+ if (auto *inst = dyn_cast<Inst>(value))
1571
+ return inst;
1591
1572
if (auto *thinToThick = dyn_cast<ThinToThickFunctionInst>(value))
1592
- return findWitnessMethod (thinToThick->getOperand ());
1573
+ return peerThroughFunctionConversions<Inst> (thinToThick->getOperand ());
1593
1574
if (auto *convertFn = dyn_cast<ConvertFunctionInst>(value))
1594
- return findWitnessMethod (convertFn->getOperand ());
1575
+ return peerThroughFunctionConversions<Inst> (convertFn->getOperand ());
1595
1576
if (auto *convertFn = dyn_cast<ConvertEscapeToNoEscapeInst>(value))
1596
- return findWitnessMethod (convertFn->getOperand ());
1577
+ return peerThroughFunctionConversions<Inst> (convertFn->getOperand ());
1597
1578
if (auto *partialApply = dyn_cast<PartialApplyInst>(value))
1598
- return findWitnessMethod (partialApply->getCallee ());
1579
+ return peerThroughFunctionConversions<Inst> (partialApply->getCallee ());
1599
1580
return nullptr ;
1600
1581
}
1601
1582
@@ -1615,8 +1596,8 @@ static WitnessMethodInst *findWitnessMethod(SILValue value) {
1615
1596
static Optional<std::pair<SILValue, SILAutoDiffIndices>>
1616
1597
emitAssociatedFunctionReference (ADContext &context, SILBuilder &builder,
1617
1598
const DifferentiationTask *parentTask, SILAutoDiffIndices desiredIndices,
1618
- AutoDiffAssociatedFunctionKind kind,
1619
- SILValue original, DifferentiationInvoker invoker,
1599
+ AutoDiffAssociatedFunctionKind kind, SILValue original,
1600
+ DifferentiationInvoker invoker,
1620
1601
std::function<void (DifferentiationTask *)> taskCallback) {
1621
1602
1622
1603
// If `original` is itself an `AutoDiffFunctionExtractInst` whose kind matches
@@ -1646,24 +1627,21 @@ emitAssociatedFunctionReference(ADContext &context, SILBuilder &builder,
1646
1627
}
1647
1628
}
1648
1629
1649
- // TODO: Refactor this function to recursively handle function conversions,
1650
- // rather than using `findReferenceToVisibleFunction`, `findWitnessMethod`,
1651
- // and `reapplyFunctionConversion`.
1652
-
1653
- if (auto *originalFRI = findReferenceToVisibleFunction (original)) {
1630
+ // Find local function reference.
1631
+ if (auto *originalFRI =
1632
+ peerThroughFunctionConversions<FunctionRefInst>(original)) {
1654
1633
auto loc = originalFRI->getLoc ();
1655
1634
auto *originalFn = originalFRI->getReferencedFunction ();
1656
1635
auto *task =
1657
1636
context.lookUpMinimalDifferentiationTask (originalFn, desiredIndices);
1658
1637
if (!task) {
1659
1638
if (originalFn->isExternalDeclaration ()) {
1660
- // For lldb repl , we should attempt to load the function as
1639
+ // For LLDB REPL , we should attempt to load the function as
1661
1640
// this may be defined in a different cell.
1662
- if (isInLLDBREPL (*original->getModule ())) {
1641
+ if (isInLLDBREPL (*original->getModule ()))
1663
1642
original->getModule ()->loadFunction (originalFn);
1664
- }
1665
1643
// If we still don't have the definition, generate an error message.
1666
- if (! originalFn->isDefinition ()) {
1644
+ if (originalFn->isExternalDeclaration ()) {
1667
1645
context.emitNondifferentiabilityError (
1668
1646
original, parentTask,
1669
1647
diag::autodiff_external_nondifferentiable_function);
@@ -1675,22 +1653,75 @@ emitAssociatedFunctionReference(ADContext &context, SILBuilder &builder,
1675
1653
}
1676
1654
assert (task);
1677
1655
taskCallback (task);
1678
- SILFunction *assocFn = nullptr ;
1679
- switch (kind) {
1680
- case AutoDiffAssociatedFunctionKind::JVP:
1681
- assocFn = task->getJVP ();
1682
- break ;
1683
- case AutoDiffAssociatedFunctionKind::VJP:
1684
- assocFn = task->getVJP ();
1685
- break ;
1686
- }
1687
- auto *ref = builder.createFunctionRef (loc, assocFn);
1688
- auto convertedRef =
1689
- reapplyFunctionConversion (ref, originalFRI, original, builder, loc);
1656
+ auto *ref =
1657
+ builder.createFunctionRef (loc, task->getAssociatedFunction (kind));
1658
+ auto convertedRef = reapplyFunctionConversion (
1659
+ ref, originalFRI, original, builder, loc);
1690
1660
return std::make_pair (convertedRef, task->getIndices ());
1691
1661
}
1692
1662
1693
- if (auto *witnessMethod = findWitnessMethod (original)) {
1663
+ // Find global `let` closure.
1664
+ if (auto *load = peerThroughFunctionConversions<LoadInst>(original)) {
1665
+ FunctionRefInst *initialFnRef = nullptr ;
1666
+ SILValue initVal;
1667
+ if (auto *globalAddr = dyn_cast<GlobalAddrInst>(load->getOperand ())) {
1668
+ // Search for the original function used to initialize this `let`
1669
+ // constant.
1670
+ if (auto *global = globalAddr->getReferencedGlobal ()) {
1671
+ if (!global->isLet ()) {
1672
+ context.emitNondifferentiabilityError (original, parentTask,
1673
+ diag::autodiff_cannot_differentiate_global_var_closures);
1674
+ return None;
1675
+ }
1676
+ // FIXME: In LLDB REPL, "main" will not be the function we should look
1677
+ // for.
1678
+ if (auto *mainFn = global->getModule ().lookUpFunction (" main" )) {
1679
+ if (mainFn->isDefinition ())
1680
+ for (auto &inst : mainFn->front ())
1681
+ if (auto *globalAddrInMain = dyn_cast<GlobalAddrInst>(&inst))
1682
+ if (globalAddrInMain->getReferencedGlobal () == global)
1683
+ for (auto *use : globalAddrInMain->getUses ())
1684
+ if (auto *store = dyn_cast<StoreInst>(use->getUser ()))
1685
+ if (store->getDest () == globalAddrInMain)
1686
+ initialFnRef = peerThroughFunctionConversions
1687
+ <FunctionRefInst>((initVal = store->getSrc ()));
1688
+ }
1689
+ }
1690
+ }
1691
+ if (initialFnRef) {
1692
+ assert (initVal);
1693
+ auto *initialFn = initialFnRef->getReferencedFunction ();
1694
+ auto *task =
1695
+ context.lookUpMinimalDifferentiationTask (initialFn, desiredIndices);
1696
+ if (!task) {
1697
+ if (initialFn->isExternalDeclaration ()) {
1698
+ if (isInLLDBREPL (*original->getModule ()))
1699
+ original->getModule ()->loadFunction (initialFn);
1700
+ if (initialFn->isExternalDeclaration ()) {
1701
+ context.emitNondifferentiabilityError (original, parentTask,
1702
+ diag::autodiff_global_let_closure_not_differentiable);
1703
+ return None;
1704
+ }
1705
+ }
1706
+ task = context.registerDifferentiationTask (
1707
+ initialFn, desiredIndices, invoker);
1708
+ }
1709
+ auto loc = original.getLoc ();
1710
+ auto *initialVJPRef = builder.createFunctionRef (
1711
+ loc, task->getAssociatedFunction (kind));
1712
+ auto converted =
1713
+ reapplyFunctionConversion (initialVJPRef, initialFnRef, initVal,
1714
+ builder, loc);
1715
+ converted = reapplyFunctionConversion (converted, load, original,
1716
+ builder, loc);
1717
+ SILAutoDiffIndices indices (0 , desiredIndices.parameters );
1718
+ return std::make_pair (converted, indices);
1719
+ }
1720
+ }
1721
+
1722
+ // Find witness method retrieval.
1723
+ if (auto *witnessMethod =
1724
+ peerThroughFunctionConversions<WitnessMethodInst>(original)) {
1694
1725
auto loc = witnessMethod->getLoc ();
1695
1726
auto requirement = witnessMethod->getMember ();
1696
1727
auto *requirementDecl = requirement.getDecl ();
@@ -1735,6 +1766,7 @@ emitAssociatedFunctionReference(ADContext &context, SILBuilder &builder,
1735
1766
return std::make_pair (convertedRef, requirementIndices);
1736
1767
}
1737
1768
1769
+ // Emit the general opaque function error.
1738
1770
context.emitNondifferentiabilityError (original, parentTask,
1739
1771
diag::autodiff_opaque_function_not_differentiable);
1740
1772
return None;
0 commit comments