Skip to content

Commit 955b594

Browse files
committed
Revamp inout class argument tests.
Mutating `inout` class argument via `store`: correct derivatives. Mutating `inout` class argument via `modify` accessor: incorrect derivatives. Tracked at TF-1176.
1 parent a0d197e commit 955b594

File tree

2 files changed

+52
-11
lines changed

2 files changed

+52
-11
lines changed

test/AutoDiff/downstream/activity_analysis.swift

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,6 @@ func testArrayUninitializedIntrinsicFunctionResult(_ x: Float, _ y: Float) -> [F
210210
// CHECK: [ACTIVE] %18 = apply %17(%0, %1, %16) : $@convention(method) (Float, Float, @thin Float.Type) -> Float
211211

212212
// TF-975: Test nested array literals.
213-
// FIXME(TF-975): Some values are incorrectly not marked as active: `%0`, `%1`, etc.
214213
@differentiable
215214
func testArrayUninitializedIntrinsicNested(_ x: Float, _ y: Float) -> [Float] {
216215
let array = [x, y]
@@ -251,7 +250,6 @@ func testArrayUninitializedIntrinsicNested(_ x: Float, _ y: Float) -> [Float] {
251250
// CHECK: [NONE] %36 = apply %35<Float>(%29, %34, %30) : $@convention(method) <τ_0_0> (Int, @guaranteed Array<τ_0_0>) -> @out τ_0_0
252251

253252
// TF-978: Test array literal initialized with `apply` indirect results.
254-
// FIXME(TF-978): Some values are incorrectly not marked as active: `%0`, `%1`, etc.
255253
struct Wrapper<T: Differentiable>: Differentiable {
256254
var value: T
257255
}
@@ -585,6 +583,31 @@ class C: Differentiable {
585583
// CHECK: [ACTIVE] %8 = apply %7(%0, %6, %4) : $@convention(method) (Float, Float, @thin Float.Type) -> Float
586584
}
587585

586+
// TF-1176: Test class property `modify` accessor.
587+
@differentiable
588+
func testClassModifyAccessor(_ c: inout C) {
589+
c.float *= c.float
590+
}
591+
592+
// FIXME(TF-1176): Some values are incorrectly not marked as active: `%16`, etc.
593+
// CHECK-LABEL: [AD] Activity info for ${{.*}}testClassModifyAccessor{{.*}} at (source=0 parameters=(0))
594+
// CHECK: [ACTIVE] %0 = argument of bb0 : $*C
595+
// CHECK: [NONE] %2 = metatype $@thin Float.Type
596+
// CHECK: [ACTIVE] %3 = begin_access [read] [static] %0 : $*C
597+
// CHECK: [VARIED] %4 = load [copy] %3 : $*C
598+
// CHECK: [ACTIVE] %6 = begin_access [read] [static] %0 : $*C
599+
// CHECK: [VARIED] %7 = load [copy] %6 : $*C
600+
// CHECK: [VARIED] %9 = begin_borrow %7 : $C
601+
// CHECK: [VARIED] %10 = class_method %9 : $C, #C.float!getter.1 : (C) -> () -> Float, $@convention(method) (@guaranteed C) -> Float
602+
// CHECK: [VARIED] %11 = apply %10(%9) : $@convention(method) (@guaranteed C) -> Float
603+
// CHECK: [VARIED] %14 = begin_borrow %4 : $C
604+
// CHECK: [VARIED] %15 = class_method %14 : $C, #C.float!modify.1 : (C) -> () -> (), $@yield_once @convention(method) (@guaranteed C) -> @yields @inout Float
605+
// CHECK: [VARIED] (**%16**, %17) = begin_apply %15(%14) : $@yield_once @convention(method) (@guaranteed C) -> @yields @inout Float
606+
// CHECK: [VARIED] (%16, **%17**) = begin_apply %15(%14) : $@yield_once @convention(method) (@guaranteed C) -> @yields @inout Float
607+
// CHECK: [NONE] // function_ref static Float.*= infix(_:_:)
608+
// CHECK: [NONE] %19 = apply %18(%16, %11, %2) : $@convention(method) (@inout Float, Float, @thin Float.Type) -> ()
609+
// CHECK: [NONE] %23 = tuple ()
610+
588611
//===----------------------------------------------------------------------===//
589612
// Enum differentiation
590613
//===----------------------------------------------------------------------===//

test/AutoDiff/downstream/inout_parameters.swift

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -137,17 +137,35 @@ InoutParametersTests.test("InoutClassParameter") {
137137
}
138138
}
139139

140-
// Semantically, an empty function with an `inout` parameter is an identity
141-
// function.
142-
func inoutIdentity(_ c: inout Class) {}
140+
do {
141+
func squaredViaMutation(_ c: inout Class) {
142+
c = Class(c.x * c.x)
143+
}
144+
func squared(_ x: Float) -> Float {
145+
var c = Class(x)
146+
squaredViaMutation(&c)
147+
return c.x
148+
}
149+
expectEqual((100, 20), valueWithGradient(at: 10, in: squared))
150+
expectEqual(200, pullback(at: 10, in: squared)(10))
151+
}
143152

144-
func identity(_ x: Float) -> Float {
145-
var c = Class(x)
146-
inoutIdentity(&c)
147-
return c.x
153+
do {
154+
func squaredViaModifyAccessor(_ c: inout Class) {
155+
// The line below calls `Class.x.modify`.
156+
c.x *= c.x
157+
}
158+
func squared(_ x: Float) -> Float {
159+
var c = Class(x)
160+
squaredViaModifyAccessor(&c)
161+
return c.x
162+
}
163+
// FIXME(TF-1080): Fix incorrect class property `modify` accessor derivative values.
164+
// expectEqual((100, 20), valueWithGradient(at: 10, in: squared))
165+
// expectEqual(200, pullback(at: 10, in: squared)(10))
166+
expectEqual((100, 1), valueWithGradient(at: 10, in: squared))
167+
expectEqual(10, pullback(at: 10, in: squared)(10))
148168
}
149-
expectEqual(1, gradient(at: 10, in: identity))
150-
expectEqual(10, pullback(at: 10, in: identity)(10))
151169
}
152170

153171
runAllTests()

0 commit comments

Comments
 (0)