Skip to content

[AutoDiff] Add cloned curry thunks to generated function list. #27720

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 17, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions lib/SILOptimizer/Mandatory/Differentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -878,10 +878,9 @@ class ADContext {
/// Saved for deletion during cleanup.
SmallVector<SILFunction *, 32> generatedFunctions;

/// List of derivative function references, generated via
/// `emitDerivativeFunctionReference`.
/// List of references to generated functions.
/// Saved for deletion during cleanup.
SmallVector<SILValue, 32> generatedDerivativeFunctionReferences;
SmallVector<SILValue, 32> generatedFunctionReferences;

/// The AdditiveArithmetic protocol in the standard library.
ProtocolDecl *additiveArithmeticProtocol =
Expand Down Expand Up @@ -933,8 +932,8 @@ class ADContext {
return generatedFunctions;
}

SmallVector<SILValue, 32> &getGeneratedDerivativeFunctionReferences() {
return generatedDerivativeFunctionReferences;
SmallVector<SILValue, 32> &getGeneratedFunctionReferences() {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: function references to generated subset parameter thunks (for linear map and derivative functions) are not tracked. This is not a problem because all references to such thunks are in other generated functions, so cleanup already deletes all such references.

I decided not to track references to subset parameter thunks for now, since there's no problem.

return generatedFunctionReferences;
}

ProtocolDecl *getAdditiveArithmeticProtocol() const {
Expand Down Expand Up @@ -969,11 +968,11 @@ class ADContext {
original->removeDifferentiableAttr(attr);
}
// Delete all references to generated functions.
for (auto derivativeFn : generatedDerivativeFunctionReferences) {
if (auto *fnRef =
peerThroughFunctionConversions<FunctionRefInst>(derivativeFn)) {
fnRef->replaceAllUsesWithUndef();
fnRef->eraseFromParent();
for (auto fnRef : generatedFunctionReferences) {
if (auto *fnRefInst =
peerThroughFunctionConversions<FunctionRefInst>(fnRef)) {
fnRefInst->replaceAllUsesWithUndef();
fnRefInst->eraseFromParent();
}
}
// Delete all generated functions.
Expand Down Expand Up @@ -1226,6 +1225,10 @@ ADContext::emitNondifferentiabilityError(SILValue value,
getADDebugStream() << "With invoker:\n" << invoker << '\n';
});
auto valueLoc = value.getLoc().getSourceLoc();
// If instruction does not have a valid location, use the function location
// as a fallback. Improves diagnostics in some cases.
if (valueLoc.isInvalid())
valueLoc = value->getFunction()->getLocation().getSourceLoc();
return emitNondifferentiabilityError(valueLoc, invoker, diag,
std::forward<U>(args)...);
}
Expand Down Expand Up @@ -8566,13 +8569,15 @@ SILValue ADContext::promoteToDifferentiableFunction(
thunkBuilder.createReturn(loc, dfi);
retInst->eraseFromParent();

getGeneratedFunctions().push_back(newThunk);
getDifferentiableFunctionInsts().push_back(dfi);
if (processDifferentiableFunctionInst(dfi))
return nullptr;
}

// Apply the new curry thunk.
auto *newThunkRef = builder.createFunctionRef(loc, newThunk);
getGeneratedFunctionReferences().push_back(newThunkRef);
SmallVector<SILValue, 8> newArgs;
SmallVector<SILValue, 8> newArgsToDestroy;
SmallVector<AllocStackInst *, 1> newBuffersToDealloc;
Expand Down Expand Up @@ -8608,7 +8613,7 @@ SILValue ADContext::promoteToDifferentiableFunction(
return nullptr;

auto derivativeFn = derivativeFnAndIndices->first;
getGeneratedDerivativeFunctionReferences().push_back(derivativeFn);
getGeneratedFunctionReferences().push_back(derivativeFn);

// If desired indices are a subset of actual indices, create a "subset
// indices thunk" and destroy the emitted derivative function reference.
Expand Down
5 changes: 5 additions & 0 deletions test/AutoDiff/differentiation_transform_diagnostics.swift
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,11 @@ struct TF_675 : Differentiable {
// expected-error @+1 {{function is not differentiable}}
let _: @differentiable (Float) -> Float = TF_675().method

// TF-918: test parameter subset thunk + partially-applied original function.
// expected-error @+2 {{function is not differentiable}}
// expected-note @+1 {{cannot convert a direct method reference to a '@differentiable' function; use an explicit closure instead}}
_ = gradient(at: Float(1), Float(2), in: (+) as @differentiable (Float, @nondiff Float) -> Float)

//===----------------------------------------------------------------------===//
// Conversion to `@differentiable(linear)` (not yet supported)
//===----------------------------------------------------------------------===//
Expand Down