@@ -896,16 +896,23 @@ namespace {
896
896
DifferentiableFunctionExtractee::Original,
897
897
TC.getTypeLowering (origFnTy, getResilienceExpansion ())
898
898
});
899
- for (auto kind : {AutoDiffAssociatedFunctionKind::JVP,
900
- AutoDiffAssociatedFunctionKind::VJP}) {
899
+ for (AutoDiffAssociatedFunctionKind kind :
900
+ {AutoDiffAssociatedFunctionKind::JVP,
901
+ AutoDiffAssociatedFunctionKind::VJP}) {
901
902
auto assocFnTy = origFnTy->getAutoDiffAssociatedFunctionType (
902
903
paramIndices, 0 , kind, TC,
903
904
LookUpConformanceInModule (&TC.M ));
904
905
auto silTy = SILType::getPrimitiveObjectType (assocFnTy);
906
+ auto extractee = DifferentiableFunctionExtractee (kind);
907
+
908
+ // A bug caused by implicit conversions caused us to get the wrong
909
+ // extractee, so assert that we have the right extractee to prevent
910
+ // reoccurrence of the bug.
911
+ assert (extractee.getExtracteeAsAssociatedFunction () ==
912
+ Optional<AutoDiffAssociatedFunctionKind>(kind));
913
+
905
914
children.push_back (Child{
906
- DifferentiableFunctionExtractee (kind),
907
- TC.getTypeLowering (silTy, getResilienceExpansion ())
908
- });
915
+ extractee, TC.getTypeLowering (silTy, getResilienceExpansion ())});
909
916
}
910
917
assert (children.size () == 3 );
911
918
}
0 commit comments