@@ -878,10 +878,9 @@ class ADContext {
878
878
// / Saved for deletion during cleanup.
879
879
SmallVector<SILFunction *, 32 > generatedFunctions;
880
880
881
- // / List of derivative function references, generated via
882
- // / `emitDerivativeFunctionReference`.
881
+ // / List of references to generated functions.
883
882
// / Saved for deletion during cleanup.
884
- SmallVector<SILValue, 32 > generatedDerivativeFunctionReferences ;
883
+ SmallVector<SILValue, 32 > generatedFunctionReferences ;
885
884
886
885
// / The AdditiveArithmetic protocol in the standard library.
887
886
ProtocolDecl *additiveArithmeticProtocol =
@@ -933,8 +932,8 @@ class ADContext {
933
932
return generatedFunctions;
934
933
}
935
934
936
- SmallVector<SILValue, 32 > &getGeneratedDerivativeFunctionReferences () {
937
- return generatedDerivativeFunctionReferences ;
935
+ SmallVector<SILValue, 32 > &getGeneratedFunctionReferences () {
936
+ return generatedFunctionReferences ;
938
937
}
939
938
940
939
ProtocolDecl *getAdditiveArithmeticProtocol () const {
@@ -969,11 +968,11 @@ class ADContext {
969
968
original->removeDifferentiableAttr (attr);
970
969
}
971
970
// Delete all references to generated functions.
972
- for (auto derivativeFn : generatedDerivativeFunctionReferences ) {
973
- if (auto *fnRef =
974
- peerThroughFunctionConversions<FunctionRefInst>(derivativeFn )) {
975
- fnRef ->replaceAllUsesWithUndef ();
976
- fnRef ->eraseFromParent ();
971
+ for (auto fnRef : generatedFunctionReferences ) {
972
+ if (auto *fnRefInst =
973
+ peerThroughFunctionConversions<FunctionRefInst>(fnRef )) {
974
+ fnRefInst ->replaceAllUsesWithUndef ();
975
+ fnRefInst ->eraseFromParent ();
977
976
}
978
977
}
979
978
// Delete all generated functions.
@@ -1226,6 +1225,10 @@ ADContext::emitNondifferentiabilityError(SILValue value,
1226
1225
getADDebugStream () << " With invoker:\n " << invoker << ' \n ' ;
1227
1226
});
1228
1227
auto valueLoc = value.getLoc ().getSourceLoc ();
1228
+ // If instruction does not have a valid location, use the function location
1229
+ // as a fallback. Improves diagnostics in some cases.
1230
+ if (valueLoc.isInvalid ())
1231
+ valueLoc = value->getFunction ()->getLocation ().getSourceLoc ();
1229
1232
return emitNondifferentiabilityError (valueLoc, invoker, diag,
1230
1233
std::forward<U>(args)...);
1231
1234
}
@@ -8566,13 +8569,15 @@ SILValue ADContext::promoteToDifferentiableFunction(
8566
8569
thunkBuilder.createReturn (loc, dfi);
8567
8570
retInst->eraseFromParent ();
8568
8571
8572
+ getGeneratedFunctions ().push_back (newThunk);
8569
8573
getDifferentiableFunctionInsts ().push_back (dfi);
8570
8574
if (processDifferentiableFunctionInst (dfi))
8571
8575
return nullptr ;
8572
8576
}
8573
8577
8574
8578
// Apply the new curry thunk.
8575
8579
auto *newThunkRef = builder.createFunctionRef (loc, newThunk);
8580
+ getGeneratedFunctionReferences ().push_back (newThunkRef);
8576
8581
SmallVector<SILValue, 8 > newArgs;
8577
8582
SmallVector<SILValue, 8 > newArgsToDestroy;
8578
8583
SmallVector<AllocStackInst *, 1 > newBuffersToDealloc;
@@ -8608,7 +8613,7 @@ SILValue ADContext::promoteToDifferentiableFunction(
8608
8613
return nullptr ;
8609
8614
8610
8615
auto derivativeFn = derivativeFnAndIndices->first ;
8611
- getGeneratedDerivativeFunctionReferences ().push_back (derivativeFn);
8616
+ getGeneratedFunctionReferences ().push_back (derivativeFn);
8612
8617
8613
8618
// If desired indices are a subset of actual indices, create a "subset
8614
8619
// indices thunk" and destroy the emitted derivative function reference.
0 commit comments