Skip to content

[AutoDiff] Simplify logic for checking unmet generic requirements. #27783

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 19, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 10 additions & 20 deletions lib/SILOptimizer/Mandatory/Differentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2536,25 +2536,6 @@ reapplyFunctionConversion(
llvm_unreachable("Unhandled function conversion instruction");
}

static SubstitutionMap getSubstitutionMap(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

More context: I added this getSubstitutionMap utility as a hack in #24845, I don't think the old behavior was necessary. The new simpler behavior seems sufficient for now.

SILValue value, SubstitutionMap substMap = SubstitutionMap()) {
if (auto *thinToThick = dyn_cast<ThinToThickFunctionInst>(value))
return getSubstitutionMap(thinToThick->getOperand(), substMap);
if (auto *convertFn = dyn_cast<ConvertFunctionInst>(value))
return getSubstitutionMap(convertFn->getOperand(), substMap);
if (auto *partialApply = dyn_cast<PartialApplyInst>(value)) {
auto appliedSubstMap = partialApply->getSubstitutionMap();
// TODO: Combine argument `substMap` with `appliedSubstMap`.
return getSubstitutionMap(partialApply->getCallee(), appliedSubstMap);
}
if (auto *apply = dyn_cast<ApplyInst>(value)) {
auto appliedSubstMap = apply->getSubstitutionMap();
// TODO: Combine argument `substMap` with `appliedSubstMap`.
return getSubstitutionMap(apply->getCallee(), appliedSubstMap);
}
return substMap;
}

/// Emits a reference to a derivative function of `original`, differentiated
/// with respect to a superset of `desiredIndices`. Returns the `SILValue` for
/// the derivative function and the actual indices that the derivative function
Expand Down Expand Up @@ -2617,7 +2598,6 @@ emitDerivativeFunctionReference(
peerThroughFunctionConversions<FunctionRefInst>(original)) {
auto loc = originalFRI->getLoc();
auto *originalFn = originalFRI->getReferencedFunctionOrNull();
auto substMap = getSubstitutionMap(original);
// Attempt to look up a `[differentiable]` attribute that minimally
// satisfies the specified indices.
// TODO(TF-482): Change `lookUpMinimalDifferentiableAttr` to additionally
Expand Down Expand Up @@ -2676,6 +2656,16 @@ emitDerivativeFunctionReference(
assert(minimalAttr);
// TODO(TF-482): Move generic requirement checking logic to
// `lookUpMinimalDifferentiableAttr`.
// Get the substitution map for checking unmet generic requirements.
// By default, use the forwarding substitution map of the original function.
// If the original callee is a `partial_apply` or `apply` instruction, use
// its substitution map instead.
auto substMap = original->getFunction()->getForwardingSubstitutionMap();
if (auto *pai = dyn_cast<PartialApplyInst>(original)) {
substMap = pai->getSubstitutionMap();
} else if (auto *ai = dyn_cast<ApplyInst>(original)) {
substMap = ai->getSubstitutionMap();
}
if (diagnoseUnsatisfiedRequirements(
context, minimalAttr->getDerivativeGenericSignature(), originalFn,
substMap, invoker, original.getLoc().getSourceLoc()))
Expand Down