Skip to content

Commit 7fdc5d0

Browse files
dan-zhengrxwei
authored andcommitted
[AutoDiff] Add cloned curry thunks to generated function list. (#27720)
Add cloned curry thunks to generated function list so that they will be deleted upon clean up if any diagnostics are emitted. Resolves [TF-918](https://bugs.swift.org/browse/TF-918): crash due to invalid cloned curry thunk that was not deleted during clean up.
1 parent a1d9369 commit 7fdc5d0

File tree

2 files changed

+21
-11
lines changed

2 files changed

+21
-11
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -878,10 +878,9 @@ class ADContext {
878878
/// Saved for deletion during cleanup.
879879
SmallVector<SILFunction *, 32> generatedFunctions;
880880

881-
/// List of derivative function references, generated via
882-
/// `emitDerivativeFunctionReference`.
881+
/// List of references to generated functions.
883882
/// Saved for deletion during cleanup.
884-
SmallVector<SILValue, 32> generatedDerivativeFunctionReferences;
883+
SmallVector<SILValue, 32> generatedFunctionReferences;
885884

886885
/// The AdditiveArithmetic protocol in the standard library.
887886
ProtocolDecl *additiveArithmeticProtocol =
@@ -933,8 +932,8 @@ class ADContext {
933932
return generatedFunctions;
934933
}
935934

936-
SmallVector<SILValue, 32> &getGeneratedDerivativeFunctionReferences() {
937-
return generatedDerivativeFunctionReferences;
935+
SmallVector<SILValue, 32> &getGeneratedFunctionReferences() {
936+
return generatedFunctionReferences;
938937
}
939938

940939
ProtocolDecl *getAdditiveArithmeticProtocol() const {
@@ -969,11 +968,11 @@ class ADContext {
969968
original->removeDifferentiableAttr(attr);
970969
}
971970
// Delete all references to generated functions.
972-
for (auto derivativeFn : generatedDerivativeFunctionReferences) {
973-
if (auto *fnRef =
974-
peerThroughFunctionConversions<FunctionRefInst>(derivativeFn)) {
975-
fnRef->replaceAllUsesWithUndef();
976-
fnRef->eraseFromParent();
971+
for (auto fnRef : generatedFunctionReferences) {
972+
if (auto *fnRefInst =
973+
peerThroughFunctionConversions<FunctionRefInst>(fnRef)) {
974+
fnRefInst->replaceAllUsesWithUndef();
975+
fnRefInst->eraseFromParent();
977976
}
978977
}
979978
// Delete all generated functions.
@@ -1226,6 +1225,10 @@ ADContext::emitNondifferentiabilityError(SILValue value,
12261225
getADDebugStream() << "With invoker:\n" << invoker << '\n';
12271226
});
12281227
auto valueLoc = value.getLoc().getSourceLoc();
1228+
// If instruction does not have a valid location, use the function location
1229+
// as a fallback. Improves diagnostics in some cases.
1230+
if (valueLoc.isInvalid())
1231+
valueLoc = value->getFunction()->getLocation().getSourceLoc();
12291232
return emitNondifferentiabilityError(valueLoc, invoker, diag,
12301233
std::forward<U>(args)...);
12311234
}
@@ -8566,13 +8569,15 @@ SILValue ADContext::promoteToDifferentiableFunction(
85668569
thunkBuilder.createReturn(loc, dfi);
85678570
retInst->eraseFromParent();
85688571

8572+
getGeneratedFunctions().push_back(newThunk);
85698573
getDifferentiableFunctionInsts().push_back(dfi);
85708574
if (processDifferentiableFunctionInst(dfi))
85718575
return nullptr;
85728576
}
85738577

85748578
// Apply the new curry thunk.
85758579
auto *newThunkRef = builder.createFunctionRef(loc, newThunk);
8580+
getGeneratedFunctionReferences().push_back(newThunkRef);
85768581
SmallVector<SILValue, 8> newArgs;
85778582
SmallVector<SILValue, 8> newArgsToDestroy;
85788583
SmallVector<AllocStackInst *, 1> newBuffersToDealloc;
@@ -8608,7 +8613,7 @@ SILValue ADContext::promoteToDifferentiableFunction(
86088613
return nullptr;
86098614

86108615
auto derivativeFn = derivativeFnAndIndices->first;
8611-
getGeneratedDerivativeFunctionReferences().push_back(derivativeFn);
8616+
getGeneratedFunctionReferences().push_back(derivativeFn);
86128617

86138618
// If desired indices are a subset of actual indices, create a "subset
86148619
// indices thunk" and destroy the emitted derivative function reference.

test/AutoDiff/differentiation_transform_diagnostics.swift

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,11 @@ struct TF_675 : Differentiable {
332332
// expected-error @+1 {{function is not differentiable}}
333333
let _: @differentiable (Float) -> Float = TF_675().method
334334

335+
// TF-918: test parameter subset thunk + partially-applied original function.
336+
// expected-error @+2 {{function is not differentiable}}
337+
// expected-note @+1 {{cannot convert a direct method reference to a '@differentiable' function; use an explicit closure instead}}
338+
_ = gradient(at: Float(1), Float(2), in: (+) as @differentiable (Float, @nondiff Float) -> Float)
339+
335340
//===----------------------------------------------------------------------===//
336341
// Conversion to `@differentiable(linear)` (not yet supported)
337342
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)