Skip to content

Commit 91c4e26

Browse files
committed
Add suggested tests.
1 parent 24f9fc1 commit 91c4e26

File tree

2 files changed

+49
-2
lines changed

2 files changed

+49
-2
lines changed

test/AutoDiff/downstream/differentiation_transform_diagnostics.swift

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,28 @@ func activeInoutArgMutatingMethodTuple(_ nonactive: inout Mut, _ x: Mut) {
450450
nonactive = result.0
451451
}
452452

453+
func twoInoutParameters(_ x: inout Float, _ y: inout Float) {}
454+
// expected-error @+1 {{function is not differentiable}}
455+
@differentiable
456+
// expected-note @+1 {{when differentiating this function definition}}
457+
func testTwoInoutParameters(_ x: Float, _ y: Float) -> Float {
458+
var x = x
459+
var y = y
460+
// expected-note @+1 {{cannot differentiate through multiple results}}
461+
twoInoutParameters(&x, &y)
462+
return x
463+
}
464+
465+
func inoutParameterAndFormalResult(_ x: inout Float) -> Float { x }
466+
// expected-error @+1 {{function is not differentiable}}
467+
@differentiable
468+
// expected-note @+1 {{when differentiating this function definition}}
469+
func testInoutParameterAndFormalResult(_ x: Float) -> Float {
470+
var x = x
471+
// expected-note @+1 {{cannot differentiate through multiple results}}
472+
return inoutParameterAndFormalResult(&x)
473+
}
474+
453475
//===----------------------------------------------------------------------===//
454476
// Non-varied results
455477
//===----------------------------------------------------------------------===//

test/AutoDiff/downstream/inout_parameters.swift

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ InoutParametersTests.test("Float./=") {
5151
expectEqual((10, 10), pullback(at: 4, 5, in: mutatingDivideWrapper)(10))
5252
}
5353

54+
// Simplest possible `inout` parameter differentiation.
5455
InoutParametersTests.test("InoutIdentity") {
5556
// Semantically, an empty function with an `inout` parameter is an identity
5657
// function.
@@ -61,8 +62,8 @@ InoutParametersTests.test("InoutIdentity") {
6162
inoutIdentity(&result)
6263
return result
6364
}
64-
expectEqual(1, gradient(at: 1, in: identity))
65-
expectEqual(10, pullback(at: 1, in: identity)(10))
65+
expectEqual(1, gradient(at: 10, in: identity))
66+
expectEqual(10, pullback(at: 10, in: identity)(10))
6667
}
6768

6869
extension Float {
@@ -125,4 +126,28 @@ InoutParametersTests.test("SetAccessor") {
125126
expectEqual(8, gradient(at: 4, in: squared))
126127
}
127128

129+
// Test differentiation wrt `inout` parameters that have a class type.
130+
InoutParametersTests.test("InoutClassParameter") {
131+
class Class: Differentiable {
132+
@differentiable
133+
var x: Float
134+
135+
init(_ x: Float) {
136+
self.x = x
137+
}
138+
}
139+
140+
// Semantically, an empty function with an `inout` parameter is an identity
141+
// function.
142+
func inoutIdentity(_ c: inout Class) {}
143+
144+
func identity(_ x: Float) -> Float {
145+
var c = Class(x)
146+
inoutIdentity(&c)
147+
return c.x
148+
}
149+
expectEqual(1, gradient(at: 10, in: identity))
150+
expectEqual(10, pullback(at: 10, in: identity)(10))
151+
}
152+
128153
runAllTests()

0 commit comments

Comments
 (0)