@@ -3245,13 +3245,8 @@ static ManagedValue createAutoDiffThunk(SILGenFunction &SGF,
3245
3245
CanAnyFunctionType outputSubstType) {
3246
3246
// Applies a thunk to all the components by extracting them, applying thunks
3247
3247
// to all of them, and then putting them back together.
3248
-
3249
3248
auto sourceType = fn.getType ().castTo <SILFunctionType>();
3250
3249
3251
- // We're never going to pass `fn` into anything that consumes it, so get its
3252
- // value without disabling cleanup.
3253
- auto fnValue = fn.getValue ();
3254
-
3255
3250
auto withoutDifferentiablePattern = [](AbstractionPattern pattern)
3256
3251
-> AbstractionPattern {
3257
3252
auto patternType = cast<AnyFunctionType>(pattern.getType ());
@@ -3269,10 +3264,13 @@ static ManagedValue createAutoDiffThunk(SILGenFunction &SGF,
3269
3264
auto outputOrigTypeNotDiff = withoutDifferentiablePattern (outputOrigType);
3270
3265
auto &expectedTLNotDiff = SGF.getTypeLowering (outputOrigTypeNotDiff,
3271
3266
outputSubstTypeNotDiff);
3272
- SILValue original = SGF.B .createAutoDiffFunctionExtractOriginal (loc, fnValue);
3273
- auto managedOriginal = original->getType ().isTrivial (SGF.F )
3274
- ? ManagedValue::forTrivialObjectRValue (original)
3275
- : ManagedValue::forBorrowedObjectRValue (original);
3267
+ // `autodiff_function_extract` is consuming; copy `fn` before passing as
3268
+ // operand.
3269
+ auto copiedFnValue = fn.copy (SGF, loc);
3270
+ auto *original = SGF.B .createAutoDiffFunctionExtractOriginal (
3271
+ loc, copiedFnValue.forward (SGF));
3272
+ auto managedOriginal = SGF.emitManagedRValueWithCleanup (original);
3273
+
3276
3274
ManagedValue originalThunk = createThunk (
3277
3275
SGF, loc, managedOriginal, inputOrigTypeNotDiff, inputSubstTypeNotDiff,
3278
3276
outputOrigTypeNotDiff, outputSubstTypeNotDiff, expectedTLNotDiff);
@@ -3309,12 +3307,12 @@ static ManagedValue createAutoDiffThunk(SILGenFunction &SGF,
3309
3307
kind);
3310
3308
auto &assocFnExpectedTL = SGF.getTypeLowering (assocFnOutputOrigType,
3311
3309
assocFnOutputSubstType);
3312
- auto assocFn = SGF. B . createAutoDiffFunctionExtract (
3313
- loc, kind,
3314
- /* differentiationOrder */ 1 , fnValue );
3315
- auto managedAssocFn = assocFn-> getType (). isTrivial ( SGF.F )
3316
- ? ManagedValue::forTrivialObjectRValue (assocFn)
3317
- : ManagedValue::forBorrowedObjectRValue (assocFn);
3310
+ // `autodiff_function_extract` is consuming; copy `fn` before passing as
3311
+ // operand.
3312
+ auto copiedFnValue = fn. copy (SGF, loc );
3313
+ auto *assocFn = SGF.B . createAutoDiffFunctionExtract (
3314
+ loc, kind, /* differentiationOrder */ 1 , copiedFnValue. forward (SGF));
3315
+ auto managedAssocFn = SGF. emitManagedRValueWithCleanup (assocFn);
3318
3316
return createThunk (SGF, loc, managedAssocFn, assocFnInputOrigType,
3319
3317
assocFnInputSubstType, assocFnOutputOrigType,
3320
3318
assocFnOutputSubstType, assocFnExpectedTL);
0 commit comments