Skip to content

Commit 5e00fdd

Browse files
authored
Fix AutoDiff tests. (#24564)
* Fix AutoDiff tests. - Fix `createAutoDiffThunk` to pass ownership verification. - The `@differentiable` function operand must be copied if non-trivial. - Fix `AutoDiff/builtin_differential_operators.swift` so that it actually runs. * Revert changes to `emitBuiltinAutoDiffApplyAssociatedFunction`. Re-add `createDestroyAddr`. Fix indentation.
1 parent 79518be commit 5e00fdd

File tree

3 files changed

+26
-24
lines changed

3 files changed

+26
-24
lines changed

lib/SILGen/SILGenBuiltin.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1078,12 +1078,11 @@ static ManagedValue emitBuiltinAutoDiffApplyAssociatedFunction(
10781078
continue;
10791079

10801080
if (argumentValue->getType().isObject()) {
1081-
SGF.B.createDestroyValue(loc, argumentValue);
1081+
SGF.B.emitDestroyValueOperation(loc, argumentValue);
10821082
continue;
10831083
}
10841084

1085-
if (false)
1086-
SGF.B.createDestroyAddr(loc, argumentValue);
1085+
SGF.B.createDestroyAddr(loc, argumentValue);
10871086
}
10881087
};
10891088

@@ -1096,7 +1095,8 @@ static ManagedValue emitBuiltinAutoDiffApplyAssociatedFunction(
10961095
auto curryLevelArgVals = ArrayRef<SILValue>(origFnArgVals).slice(
10971096
currentParameter, curryLevel->getNumParameters());
10981097
auto applyResult = SGF.B.createApply(
1099-
loc, assocFn, SubstitutionMap(), curryLevelArgVals, /*isNonThrowing*/ false);
1098+
loc, assocFn, SubstitutionMap(), curryLevelArgVals,
1099+
/*isNonThrowing*/ false);
11001100
currentParameter += curryLevel->getNumParameters();
11011101

11021102
if (assocFnNeedsDestroy)
@@ -1125,8 +1125,8 @@ static ManagedValue emitBuiltinAutoDiffApplyAssociatedFunction(
11251125
currentParameter);
11261126
for (auto origFnArgVal : curryLevelArgVals)
11271127
applyArgs.push_back(origFnArgVal);
1128-
auto differential = SGF.B.createApply(loc, assocFn, SubstitutionMap(), applyArgs,
1129-
/*isNonThrowing*/ false);
1128+
auto differential = SGF.B.createApply(
1129+
loc, assocFn, SubstitutionMap(), applyArgs, /*isNonThrowing*/ false);
11301130

11311131
if (assocFnNeedsDestroy)
11321132
SGF.B.createDestroyValue(loc, assocFn);
@@ -1143,8 +1143,9 @@ static ManagedValue emitBuiltinAutoDiffApplyAssociatedFunction(
11431143
// Apply the last curry level, in the case where it only has direct results.
11441144
auto curryLevelArgVals = ArrayRef<SILValue>(origFnArgVals).slice(
11451145
currentParameter);
1146-
auto resultTuple = SGF.B.createApply(loc, assocFn, SubstitutionMap(), curryLevelArgVals,
1147-
/*isNonThrowing*/ false);
1146+
auto resultTuple = SGF.B.createApply(
1147+
loc, assocFn, SubstitutionMap(), curryLevelArgVals,
1148+
/*isNonThrowing*/ false);
11481149

11491150
if (assocFnNeedsDestroy)
11501151
SGF.B.createDestroyValue(loc, assocFn);

lib/SILGen/SILGenPoly.cpp

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3245,13 +3245,8 @@ static ManagedValue createAutoDiffThunk(SILGenFunction &SGF,
32453245
CanAnyFunctionType outputSubstType) {
32463246
// Applies a thunk to all the components by extracting them, applying thunks
32473247
// to all of them, and then putting them back together.
3248-
32493248
auto sourceType = fn.getType().castTo<SILFunctionType>();
32503249

3251-
// We're never going to pass `fn` into anything that consumes it, so get its
3252-
// value without disabling cleanup.
3253-
auto fnValue = fn.getValue();
3254-
32553250
auto withoutDifferentiablePattern = [](AbstractionPattern pattern)
32563251
-> AbstractionPattern {
32573252
auto patternType = cast<AnyFunctionType>(pattern.getType());
@@ -3269,10 +3264,13 @@ static ManagedValue createAutoDiffThunk(SILGenFunction &SGF,
32693264
auto outputOrigTypeNotDiff = withoutDifferentiablePattern(outputOrigType);
32703265
auto &expectedTLNotDiff = SGF.getTypeLowering(outputOrigTypeNotDiff,
32713266
outputSubstTypeNotDiff);
3272-
SILValue original = SGF.B.createAutoDiffFunctionExtractOriginal(loc, fnValue);
3273-
auto managedOriginal = original->getType().isTrivial(SGF.F)
3274-
? ManagedValue::forTrivialObjectRValue(original)
3275-
: ManagedValue::forBorrowedObjectRValue(original);
3267+
// `autodiff_function_extract` is consuming; copy `fn` before passing as
3268+
// operand.
3269+
auto copiedFnValue = fn.copy(SGF, loc);
3270+
auto *original = SGF.B.createAutoDiffFunctionExtractOriginal(
3271+
loc, copiedFnValue.forward(SGF));
3272+
auto managedOriginal = SGF.emitManagedRValueWithCleanup(original);
3273+
32763274
ManagedValue originalThunk = createThunk(
32773275
SGF, loc, managedOriginal, inputOrigTypeNotDiff, inputSubstTypeNotDiff,
32783276
outputOrigTypeNotDiff, outputSubstTypeNotDiff, expectedTLNotDiff);
@@ -3309,12 +3307,12 @@ static ManagedValue createAutoDiffThunk(SILGenFunction &SGF,
33093307
kind);
33103308
auto &assocFnExpectedTL = SGF.getTypeLowering(assocFnOutputOrigType,
33113309
assocFnOutputSubstType);
3312-
auto assocFn = SGF.B.createAutoDiffFunctionExtract(
3313-
loc, kind,
3314-
/*differentiationOrder*/ 1, fnValue);
3315-
auto managedAssocFn = assocFn->getType().isTrivial(SGF.F)
3316-
? ManagedValue::forTrivialObjectRValue(assocFn)
3317-
: ManagedValue::forBorrowedObjectRValue(assocFn);
3310+
// `autodiff_function_extract` is consuming; copy `fn` before passing as
3311+
// operand.
3312+
auto copiedFnValue = fn.copy(SGF, loc);
3313+
auto *assocFn = SGF.B.createAutoDiffFunctionExtract(
3314+
loc, kind, /*differentiationOrder*/ 1, copiedFnValue.forward(SGF));
3315+
auto managedAssocFn = SGF.emitManagedRValueWithCleanup(assocFn);
33183316
return createThunk(SGF, loc, managedAssocFn, assocFnInputOrigType,
33193317
assocFnInputSubstType, assocFnOutputOrigType,
33203318
assocFnOutputSubstType, assocFnExpectedTL);

test/AutoDiff/builtin_differential_operators.swift

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
// RUN: %target-run
1+
// RUN: %empty-directory(%t)
2+
// RUN: %target-build-swift %s -parse-stdlib -o %t/Builtins
3+
// RUN: %target-codesign %t/Builtins
4+
// RUN: %target-run %t/Builtins
25
// REQUIRES: executable_test
36

47
import Swift

0 commit comments

Comments
 (0)