Skip to content

Commit 36be38e

Browse files
committed
Loosening checks that assume only one result. Adjusting tests.
1 parent 02104a6 commit 36be38e

File tree

5 files changed

+5
-12
lines changed

5 files changed

+5
-12
lines changed

lib/SILOptimizer/Differentiation/Common.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -483,10 +483,6 @@ findMinimalDerivativeConfiguration(AbstractFunctionDecl *original,
483483
SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness(
484484
SILModule &module, SILFunction *original, DifferentiabilityKind kind,
485485
IndexSubset *parameterIndices, IndexSubset *resultIndices) {
486-
// AST differentiability witnesses always have a single result.
487-
if (resultIndices->getCapacity() != 1 || !resultIndices->contains(0))
488-
return nullptr;
489-
490486
// Explicit differentiability witnesses only exist on SIL functions that come
491487
// from AST functions.
492488
auto *originalAFD = findAbstractFunctionDecl(original);

lib/SILOptimizer/Differentiation/Thunk.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -968,8 +968,8 @@ getOrCreateSubsetParametersThunkForDerivativeFunction(
968968
/*withoutActuallyEscaping*/ false);
969969
}
970970
assert(origFnType->getNumResults() +
971-
origFnType->getNumIndirectMutatingParameters() ==
972-
1);
971+
origFnType->getNumIndirectMutatingParameters() >
972+
0);
973973
if (origFnType->getNumResults() > 0 &&
974974
origFnType->getResults().front().isFormalDirect()) {
975975
auto result =

test/AutoDiff/Sema/DerivativeRegistrationCrossModule/Inputs/b.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,4 @@ func multiply_swapCustom(_ x: Float, _ y: Float) -> Float {
2525
var tuple = (x, y)
2626
swapCustom(&tuple.0, &tuple.1)
2727
return tuple.0 * tuple.1
28-
}
28+
}

test/AutoDiff/Sema/derivative_attr_type_checking.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -859,13 +859,13 @@ func vjpNoSemanticResults(_ x: Float) -> (value: Void, pullback: Void) {}
859859

860860
extension InoutParameters {
861861
func multipleSemanticResults(_ x: inout Float) -> Float { x }
862-
@derivative(of: multipleSemanticResults)
862+
@derivative(of: multipleSemanticResults, wrt: x)
863863
func vjpMultipleSemanticResults(_ x: inout Float) -> (
864864
value: Float, pullback: (Float, inout Float) -> Void
865865
) { fatalError() }
866866

867867
func inoutVoid(_ x: Float, _ void: inout Void) -> Float {}
868-
@derivative(of: inoutVoid)
868+
@derivative(of: inoutVoid, wrt: (x, void))
869869
func vjpInoutVoidParameter(_ x: Float, _ void: inout Void) -> (
870870
value: Float, pullback: (Float) -> Float
871871
) { fatalError() }

test/AutoDiff/validation-test/forward_mode_simple.swift

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1063,8 +1063,6 @@ ForwardModeTests.test("FunctionCall") {
10631063
expectEqual(3, derivative(at: 3) { x in foo(x, 4) })
10641064
}
10651065

1066-
// FIXME(TF-1038): Support differentiable functions returning tuples.
1067-
/*
10681066
ForwardModeTests.test("ResultSelection") {
10691067
func tuple(_ x: Float, _ y: Float) -> (Float, Float) {
10701068
return (x + 1, y + 2)
@@ -1083,7 +1081,6 @@ ForwardModeTests.test("ResultSelection") {
10831081
expectEqual(1, derivative(at: 3, 3, of: tupleGenericSecond))
10841082
*/
10851083
}
1086-
*/
10871084

10881085
// TODO(TF-983): Support forward-mode differentiation of multiple results.
10891086
/*

0 commit comments

Comments
 (0)