Skip to content

Commit 7e0faeb

Browse files
authored
[AutoDiff] Simplify logic for checking unmet generic requirements. (#27783)
The `SubstitutionMap` passed to `diagnoseUnsatisfiedRequirements` should default to the original function's forwarding substitution map. Delete the static `getSubstitutionMap` function, whose purpose is not clear and is not needed.
1 parent f745587 commit 7e0faeb

File tree

1 file changed

+10
-20
lines changed

1 file changed

+10
-20
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2536,25 +2536,6 @@ reapplyFunctionConversion(
25362536
llvm_unreachable("Unhandled function conversion instruction");
25372537
}
25382538

2539-
static SubstitutionMap getSubstitutionMap(
2540-
SILValue value, SubstitutionMap substMap = SubstitutionMap()) {
2541-
if (auto *thinToThick = dyn_cast<ThinToThickFunctionInst>(value))
2542-
return getSubstitutionMap(thinToThick->getOperand(), substMap);
2543-
if (auto *convertFn = dyn_cast<ConvertFunctionInst>(value))
2544-
return getSubstitutionMap(convertFn->getOperand(), substMap);
2545-
if (auto *partialApply = dyn_cast<PartialApplyInst>(value)) {
2546-
auto appliedSubstMap = partialApply->getSubstitutionMap();
2547-
// TODO: Combine argument `substMap` with `appliedSubstMap`.
2548-
return getSubstitutionMap(partialApply->getCallee(), appliedSubstMap);
2549-
}
2550-
if (auto *apply = dyn_cast<ApplyInst>(value)) {
2551-
auto appliedSubstMap = apply->getSubstitutionMap();
2552-
// TODO: Combine argument `substMap` with `appliedSubstMap`.
2553-
return getSubstitutionMap(apply->getCallee(), appliedSubstMap);
2554-
}
2555-
return substMap;
2556-
}
2557-
25582539
/// Emits a reference to a derivative function of `original`, differentiated
25592540
/// with respect to a superset of `desiredIndices`. Returns the `SILValue` for
25602541
/// the derivative function and the actual indices that the derivative function
@@ -2617,7 +2598,6 @@ emitDerivativeFunctionReference(
26172598
peerThroughFunctionConversions<FunctionRefInst>(original)) {
26182599
auto loc = originalFRI->getLoc();
26192600
auto *originalFn = originalFRI->getReferencedFunctionOrNull();
2620-
auto substMap = getSubstitutionMap(original);
26212601
// Attempt to look up a `[differentiable]` attribute that minimally
26222602
// satisfies the specified indices.
26232603
// TODO(TF-482): Change `lookUpMinimalDifferentiableAttr` to additionally
@@ -2676,6 +2656,16 @@ emitDerivativeFunctionReference(
26762656
assert(minimalAttr);
26772657
// TODO(TF-482): Move generic requirement checking logic to
26782658
// `lookUpMinimalDifferentiableAttr`.
2659+
// Get the substitution map for checking unmet generic requirements.
2660+
// By default, use the forwarding substitution map of the original function.
2661+
// If the original callee is a `partial_apply` or `apply` instruction, use
2662+
// its substitution map instead.
2663+
auto substMap = original->getFunction()->getForwardingSubstitutionMap();
2664+
if (auto *pai = dyn_cast<PartialApplyInst>(original)) {
2665+
substMap = pai->getSubstitutionMap();
2666+
} else if (auto *ai = dyn_cast<ApplyInst>(original)) {
2667+
substMap = ai->getSubstitutionMap();
2668+
}
26792669
if (diagnoseUnsatisfiedRequirements(
26802670
context, minimalAttr->getDerivativeGenericSignature(), originalFn,
26812671
substMap, invoker, original.getLoc().getSourceLoc()))

0 commit comments

Comments
 (0)