Skip to content

Commit e2a1862

Browse files
authored
[AutoDiff upstream] Add inout parameter differentiation tests. (#33555)
1 parent 363dc4e commit e2a1862

File tree

1 file changed

+163
-0
lines changed

1 file changed

+163
-0
lines changed

test/AutoDiff/validation-test/inout_parameters.swift

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,174 @@
11
// RUN: %target-run-simple-swift
22
// REQUIRES: executable_test
33

4+
// `inout` parameter differentiation tests.
5+
46
import DifferentiationUnittest
57
import StdlibUnittest
68

79
var InoutParameterAutoDiffTests = TestSuite("InoutParameterDifferentiation")
810

11+
// TODO(TF-1173): Move floating-point mutating operation tests to
12+
// `test/AutoDiff/stdlib/floating_point.swift.gyb` when forward-mode
13+
// differentiation supports `inout` parameter differentiation.
14+
15+
InoutParameterAutoDiffTests.test("Float.+=") {
16+
func mutatingAddWrapper(_ x: Float, _ y: Float) -> Float {
17+
var result: Float = x
18+
result += y
19+
return result
20+
}
21+
expectEqual((1, 1), gradient(at: 4, 5, in: mutatingAddWrapper))
22+
expectEqual((10, 10), pullback(at: 4, 5, in: mutatingAddWrapper)(10))
23+
}
24+
25+
InoutParameterAutoDiffTests.test("Float.-=") {
26+
func mutatingSubtractWrapper(_ x: Float, _ y: Float) -> Float {
27+
var result: Float = x
28+
result += y
29+
return result
30+
}
31+
expectEqual((1, 1), gradient(at: 4, 5, in: mutatingSubtractWrapper))
32+
expectEqual((10, 10), pullback(at: 4, 5, in: mutatingSubtractWrapper)(10))
33+
}
34+
35+
InoutParameterAutoDiffTests.test("Float.*=") {
36+
func mutatingMultiplyWrapper(_ x: Float, _ y: Float) -> Float {
37+
var result: Float = x
38+
result += y
39+
return result
40+
}
41+
expectEqual((1, 1), gradient(at: 4, 5, in: mutatingMultiplyWrapper))
42+
expectEqual((10, 10), pullback(at: 4, 5, in: mutatingMultiplyWrapper)(10))
43+
}
44+
45+
InoutParameterAutoDiffTests.test("Float./=") {
46+
func mutatingDivideWrapper(_ x: Float, _ y: Float) -> Float {
47+
var result: Float = x
48+
result += y
49+
return result
50+
}
51+
expectEqual((1, 1), gradient(at: 4, 5, in: mutatingDivideWrapper))
52+
expectEqual((10, 10), pullback(at: 4, 5, in: mutatingDivideWrapper)(10))
53+
}
54+
55+
// Simplest possible `inout` parameter differentiation.
56+
InoutParameterAutoDiffTests.test("InoutIdentity") {
57+
// Semantically, an empty function with an `inout` parameter is an identity
58+
// function.
59+
func inoutIdentity(_ x: inout Float) {}
60+
61+
func identity(_ x: Float) -> Float {
62+
var result = x
63+
inoutIdentity(&result)
64+
return result
65+
}
66+
expectEqual(1, gradient(at: 10, in: identity))
67+
expectEqual(10, pullback(at: 10, in: identity)(10))
68+
}
69+
70+
extension Float {
71+
// Custom version of `Float.*=`, implemented using `Float.*` and mutation.
72+
// Verify that its generated derivative has the same behavior as the
73+
// registered derivative for `Float.*=`.
74+
@differentiable
75+
static func multiplyAssign(_ lhs: inout Float, _ rhs: Float) {
76+
lhs = lhs * rhs
77+
}
78+
}
79+
80+
InoutParameterAutoDiffTests.test("ControlFlow") {
81+
func sum(_ array: [Float]) -> Float {
82+
var result: Float = 0
83+
for i in withoutDerivative(at: array.indices) {
84+
result += array[i]
85+
}
86+
return result
87+
}
88+
expectEqual([1, 1, 1], gradient(at: [1, 2, 3], in: sum))
89+
90+
func product(_ array: [Float]) -> Float {
91+
var result: Float = 1
92+
for i in withoutDerivative(at: array.indices) {
93+
result *= array[i]
94+
}
95+
return result
96+
}
97+
expectEqual([20, 15, 12], gradient(at: [3, 4, 5], in: product))
98+
99+
func productCustom(_ array: [Float]) -> Float {
100+
var result: Float = 1
101+
for i in withoutDerivative(at: array.indices) {
102+
Float.multiplyAssign(&result, array[i])
103+
}
104+
return result
105+
}
106+
expectEqual([20, 15, 12], gradient(at: [3, 4, 5], in: productCustom))
107+
}
108+
109+
InoutParameterAutoDiffTests.test("SetAccessor") {
110+
struct S: Differentiable {
111+
var x: Float
112+
113+
var computed: Float {
114+
get { x }
115+
set { x = newValue }
116+
}
117+
}
118+
119+
// `squared` implemented using a `set` accessor.
120+
func squared(_ x: Float) -> Float {
121+
var s = S(x: 1)
122+
s.x *= x
123+
s.computed *= x
124+
return s.x
125+
}
126+
expectEqual(6, gradient(at: 3, in: squared))
127+
expectEqual(8, gradient(at: 4, in: squared))
128+
}
129+
130+
// Test differentiation wrt `inout` parameters that have a class type.
131+
InoutParameterAutoDiffTests.test("InoutClassParameter") {
132+
class Class: Differentiable {
133+
@differentiable
134+
var x: Float
135+
136+
init(_ x: Float) {
137+
self.x = x
138+
}
139+
}
140+
141+
do {
142+
func squaredViaMutation(_ c: inout Class) {
143+
c = Class(c.x * c.x)
144+
}
145+
func squared(_ x: Float) -> Float {
146+
var c = Class(x)
147+
squaredViaMutation(&c)
148+
return c.x
149+
}
150+
expectEqual((100, 20), valueWithGradient(at: 10, in: squared))
151+
expectEqual(200, pullback(at: 10, in: squared)(10))
152+
}
153+
154+
do {
155+
func squaredViaModifyAccessor(_ c: inout Class) {
156+
// The line below calls `Class.x.modify`.
157+
c.x *= c.x
158+
}
159+
func squared(_ x: Float) -> Float {
160+
var c = Class(x)
161+
squaredViaModifyAccessor(&c)
162+
return c.x
163+
}
164+
// FIXME(TF-1080): Fix incorrect class property `modify` accessor derivative values.
165+
// expectEqual((100, 20), valueWithGradient(at: 10, in: squared))
166+
// expectEqual(200, pullback(at: 10, in: squared)(10))
167+
expectEqual((100, 1), valueWithGradient(at: 10, in: squared))
168+
expectEqual(10, pullback(at: 10, in: squared)(10))
169+
}
170+
}
171+
9172
// SR-13305: Test function with non-wrt `inout` parameter, which should be
10173
// treated as a differentiability result.
11174

0 commit comments

Comments
 (0)