Skip to content

Commit 4918248

Browse files
committed
Added tuple result tests, extracting tuple elements as semantic result types.
1 parent 709155a commit 4918248

File tree

4 files changed

+76
-2
lines changed

4 files changed

+76
-2
lines changed

lib/AST/AutoDiff.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,8 +196,16 @@ void autodiff::getFunctionSemanticResultTypes(
196196
functionType->getResult()->getAs<AnyFunctionType>()) {
197197
formalResultType = resultFunctionType->getResult();
198198
}
199-
if (!formalResultType->isEqual(ctx.TheEmptyTupleType))
200-
result.push_back({remap(formalResultType), /*isInout*/ false});
199+
if (!formalResultType->isEqual(ctx.TheEmptyTupleType)) {
200+
// Separate tuple elements into individual results.
201+
if (formalResultType->is<TupleType>()) {
202+
for (auto elt : formalResultType->castTo<TupleType>()->getElements()) {
203+
result.push_back({remap(elt.getType()), /*isInout*/ false});
204+
}
205+
} else {
206+
result.push_back({remap(formalResultType), /*isInout*/ false});
207+
}
208+
}
201209

202210
// Collect `inout` parameters as semantic results.
203211
for (auto param : functionType->getParams())

test/AutoDiff/Sema/derivative_attr_type_checking.swift

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -902,6 +902,22 @@ extension InoutParameters {
902902
) { fatalError() }
903903
}
904904

905+
// Test tuple results.
906+
907+
extension InoutParameters {
908+
func tupleResults(_ x: Float) -> (Float, Float) { (x, x) }
909+
@derivative(of: tupleResults, wrt: x)
910+
func vjpTupleResults(_ x: Float) -> (
911+
value: (Float, Float), pullback: (Float, Float) -> Float
912+
) { fatalError() }
913+
914+
func tupleResultsInt(_ x: Float) -> (Int, Float) { (1, x) }
915+
@derivative(of: tupleResultsInt, wrt: x)
916+
func vjpTupleResults(_ x: Float) -> (
917+
value: (Int, Float), pullback: (Float) -> Float
918+
) { fatalError() }
919+
}
920+
905921
// Test original/derivative function `inout` parameter mismatches.
906922

907923
extension InoutParameters {

test/AutoDiff/Sema/differentiable_attr_type_checking.swift

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,19 @@ extension InoutParameters {
696696
mutating func mutatingMethod(_ other: Self) -> Self {}
697697
}
698698

699+
// Test tuple results.
700+
701+
extension InoutParameters {
702+
@differentiable(reverse)
703+
static func tupleResults(_ x: Self) -> (Self, Self) {}
704+
705+
@differentiable(reverse)
706+
static func tupleResultsInt(_ x: Self) -> (Int, Self) {}
707+
708+
@differentiable(reverse)
709+
static func tupleResultsInt2(_ x: Self) -> (Self, Int) {}
710+
}
711+
699712
// Test accessors: `set`, `_read`, `_modify`.
700713

701714
struct Accessors: Differentiable {

test/AutoDiff/validation-test/simple_math.swift

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,43 @@ SimpleMathTests.test("MultipleResultsWithCustomPullback") {
147147
expectEqual((10, 5), gradient(at: 5, 10, of: multiply_swapCustom))
148148
}
149149

150+
// Test functions returning tuples.
151+
@differentiable(reverse)
152+
func swapTuple(_ x: Float, _ y: Float) -> (Float, Float) {
153+
return (y, x)
154+
}
155+
156+
@differentiable(reverse)
157+
func swapTupleCustom(_ x: Float, _ y: Float) -> (Float, Float) {
158+
return (y, x)
159+
}
160+
@derivative(of: swapTupleCustom)
161+
func vjpSwapTupleCustom(_ x: Float, _ y: Float) -> (
162+
value: (Float, Float), pullback: (Float, Float) -> (Float, Float)
163+
) {
164+
return (swapTupleCustom(x, y), {v1, v2 in
165+
return (v2, v1)
166+
})
167+
}
168+
169+
SimpleMathTests.test("ReturningTuples") {
170+
func multiply_swapTuple(_ x: Float, _ y: Float) -> Float {
171+
let result = swapTuple(x, y)
172+
return result.0 * result.1
173+
}
174+
175+
expectEqual((4, 3), gradient(at: 3, 4, of: multiply_swapTuple))
176+
expectEqual((10, 5), gradient(at: 5, 10, of: multiply_swapTuple))
177+
178+
func multiply_swapTupleCustom(_ x: Float, _ y: Float) -> Float {
179+
let result = swapTupleCustom(x, y)
180+
return result.0 * result.1
181+
}
182+
183+
expectEqual((4, 3), gradient(at: 3, 4, of: multiply_swapTupleCustom))
184+
expectEqual((10, 5), gradient(at: 5, 10, of: multiply_swapTupleCustom))
185+
}
186+
150187
SimpleMathTests.test("CaptureLocal") {
151188
let z: Float = 10
152189
func foo(_ x: Float) -> Float {

0 commit comments

Comments
 (0)