@@ -2536,25 +2536,6 @@ reapplyFunctionConversion(
2536
2536
llvm_unreachable (" Unhandled function conversion instruction" );
2537
2537
}
2538
2538
2539
- static SubstitutionMap getSubstitutionMap (
2540
- SILValue value, SubstitutionMap substMap = SubstitutionMap()) {
2541
- if (auto *thinToThick = dyn_cast<ThinToThickFunctionInst>(value))
2542
- return getSubstitutionMap (thinToThick->getOperand (), substMap);
2543
- if (auto *convertFn = dyn_cast<ConvertFunctionInst>(value))
2544
- return getSubstitutionMap (convertFn->getOperand (), substMap);
2545
- if (auto *partialApply = dyn_cast<PartialApplyInst>(value)) {
2546
- auto appliedSubstMap = partialApply->getSubstitutionMap ();
2547
- // TODO: Combine argument `substMap` with `appliedSubstMap`.
2548
- return getSubstitutionMap (partialApply->getCallee (), appliedSubstMap);
2549
- }
2550
- if (auto *apply = dyn_cast<ApplyInst>(value)) {
2551
- auto appliedSubstMap = apply->getSubstitutionMap ();
2552
- // TODO: Combine argument `substMap` with `appliedSubstMap`.
2553
- return getSubstitutionMap (apply->getCallee (), appliedSubstMap);
2554
- }
2555
- return substMap;
2556
- }
2557
-
2558
2539
// / Emits a reference to a derivative function of `original`, differentiated
2559
2540
// / with respect to a superset of `desiredIndices`. Returns the `SILValue` for
2560
2541
// / the derivative function and the actual indices that the derivative function
@@ -2617,7 +2598,6 @@ emitDerivativeFunctionReference(
2617
2598
peerThroughFunctionConversions<FunctionRefInst>(original)) {
2618
2599
auto loc = originalFRI->getLoc ();
2619
2600
auto *originalFn = originalFRI->getReferencedFunctionOrNull ();
2620
- auto substMap = getSubstitutionMap (original);
2621
2601
// Attempt to look up a `[differentiable]` attribute that minimally
2622
2602
// satisfies the specified indices.
2623
2603
// TODO(TF-482): Change `lookUpMinimalDifferentiableAttr` to additionally
@@ -2676,6 +2656,16 @@ emitDerivativeFunctionReference(
2676
2656
assert (minimalAttr);
2677
2657
// TODO(TF-482): Move generic requirement checking logic to
2678
2658
// `lookUpMinimalDifferentiableAttr`.
2659
+ // Get the substitution map for checking unmet generic requirements.
2660
+ // By default, use the forwarding substitution map of the original function.
2661
+ // If the original callee is a `partial_apply` or `apply` instruction, use
2662
+ // its substitution map instead.
2663
+ auto substMap = original->getFunction ()->getForwardingSubstitutionMap ();
2664
+ if (auto *pai = dyn_cast<PartialApplyInst>(original)) {
2665
+ substMap = pai->getSubstitutionMap ();
2666
+ } else if (auto *ai = dyn_cast<ApplyInst>(original)) {
2667
+ substMap = ai->getSubstitutionMap ();
2668
+ }
2679
2669
if (diagnoseUnsatisfiedRequirements (
2680
2670
context, minimalAttr->getDerivativeGenericSignature (), originalFn,
2681
2671
substMap, invoker, original.getLoc ().getSourceLoc ()))
0 commit comments