Skip to content

Commit 89078cc

Browse files
authored
[AutoDiff] Add tests. (#28150)
Add the following: - Crasher tests. - Activity analysis tests. - Control flow differentiation known incorrect derivative tests. Crasher tests are in `test/AutoDiff/compiler_crashers`. They each reference a JIRA issue and start with `RUN: not --crash`. They should be moved to `test/AutoDiff/compiler_crashers_fixed` when they are fixed. The following issues are newly tracked: TF-429, TF-756, TF-781, TF-881, TF-923, TF-928, TF-945. Some of these issues are in-progress.
1 parent 473d2ee commit 89078cc

8 files changed

+281
-4
lines changed

test/AutoDiff/activity_analysis.swift

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// RUN: %target-swift-emit-sil -Xllvm -debug-only=differentiation 2>&1 %s | %FileCheck %s
2+
3+
// Check that `@noDerivative` struct projections have "NONE" activity.
4+
5+
struct HasNoDerivativeProperty: Differentiable {
6+
var x: Float
7+
@noDerivative var y: Float
8+
}
9+
@differentiable
10+
func testNoDerivativeStructProjection(_ s: HasNoDerivativeProperty) -> Float {
11+
var tmp = s
12+
tmp.y = tmp.x
13+
return tmp.x
14+
}
15+
16+
// CHECK-LABEL: [AD] Activity info for ${{.*}}testNoDerivativeStructProjection{{.*}} at (source=0 parameters=(0))
17+
// CHECK: [ACTIVE] %0 = argument of bb0 : $HasNoDerivativeProperty
18+
// CHECK: [ACTIVE] %2 = alloc_stack $HasNoDerivativeProperty, var, name "tmp"
19+
// CHECK: [ACTIVE] %4 = begin_access [read] [static] %2 : $*HasNoDerivativeProperty
20+
// CHECK: [ACTIVE] %5 = struct_element_addr %4 : $*HasNoDerivativeProperty, #HasNoDerivativeProperty.x
21+
// CHECK: [VARIED] %6 = load [trivial] %5 : $*Float
22+
// CHECK: [ACTIVE] %8 = begin_access [modify] [static] %2 : $*HasNoDerivativeProperty
23+
// CHECK: [NONE] %9 = struct_element_addr %8 : $*HasNoDerivativeProperty, #HasNoDerivativeProperty.y
24+
// CHECK: [ACTIVE] %12 = begin_access [read] [static] %2 : $*HasNoDerivativeProperty
25+
// CHECK: [ACTIVE] %13 = struct_element_addr %12 : $*HasNoDerivativeProperty, #HasNoDerivativeProperty.x
26+
// CHECK: [ACTIVE] %14 = load [trivial] %13 : $*Float
27+
28+
// TF-781: check activity analysis for active local address + nested conditionals.
29+
30+
@differentiable(wrt: x)
31+
func TF_781_function(_ x: Float, _ y: Float) -> Float {
32+
var result = y
33+
if true {
34+
if true {
35+
result = result * x
36+
}
37+
}
38+
return result
39+
}
40+
41+
// FIXME(TF-781): `%4 = alloc_stack $Float` is active, so all `begin_access`
42+
// users (and the results of their users, recursively) should also be active.
43+
44+
// CHECK-LABEL: [AD] Activity info for ${{.*}}TF_781_function{{.*}} at (source=0 parameters=(0))
45+
// CHECK: [ACTIVE] %0 = argument of bb0 : $Float
46+
// CHECK: [USEFUL] %1 = argument of bb0 : $Float
47+
// CHECK: [ACTIVE] %4 = alloc_stack $Float, var, name "result"
48+
// CHECK: [USEFUL] %19 = begin_access [read] [static] %4 : $*Float
49+
// CHECK: [USEFUL] %20 = load [trivial] %19 : $*Float
50+
// CHECK: [ACTIVE] %23 = apply %22(%20, %0, %18) : $@convention(method) (Float, Float, @thin Float.Type) -> Float
51+
// CHECK: [ACTIVE] %24 = begin_access [modify] [static] %4 : $*Float
52+
// CHECK: [USEFUL] %31 = begin_access [read] [static] %4 : $*Float
53+
// CHECK: [USEFUL] %32 = load [trivial] %31 : $*Float
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// RUN: not --crash %target-swift-emit-sil -enable-library-evolution %s
2+
// REQUIRES: asserts
3+
4+
// TF-429: Differentiation transform does not support
5+
// `-enable-library-evolution` because it assumes that differential/pullback
6+
// structs are always loadable, i.e. have object value category.
7+
8+
// Function must be public to trigger library evolution crash.
9+
@differentiable
10+
public func TF_429(_ x: Float) -> Float { x }
11+
12+
// Assertion failed: (mainPullbackStruct->getType() == pbStructLoweredType), function run, file /Users/danielzheng/swift-merge/swift/lib/SILOptimizer/Mandatory/Differentiation.cpp, line 6279.
13+
// Stack dump:
14+
// ...
15+
// 1. Swift version 5.1.1-dev (Swift c3cdcba346)
16+
// 2. While running pass #17 SILModuleTransform "Differentiation".
17+
// ...
18+
// 7 swiftc 0x0000000101620642 (anonymous namespace)::PullbackEmitter::run() + 3122
19+
// 8 swiftc 0x00000001015cb1e8 (anonymous namespace)::VJPEmitter::run() + 1224
20+
// 9 swiftc 0x00000001015c3348 (anonymous namespace)::ADContext::processDifferentiableAttribute(swift::SILFunction*, swift::SILDifferentiableAttr*, (anonymous namespace)::DifferentiationInvoker) + 4536
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// RUN: not --crash %target-swift-emit-ir -primary-file %s
2+
// REQUIRES: asserts
3+
4+
// TF-756: IRGen crash for `witness_method` instruction generated by the
5+
// differentiation transform.
6+
7+
struct Tensor<Scalar> {}
8+
extension Tensor: Differentiable where Scalar == Float {}
9+
10+
extension Tensor where Scalar == Float {
11+
// Arbitrary `@differentiable` operation with >1 parameter, so that index
12+
// subset thunk may be generated.
13+
@differentiable(vjp: _vjpAdd)
14+
static func + (_ lhs: Tensor, _ rhs: Tensor) -> Tensor {
15+
return lhs
16+
}
17+
18+
static func _vjpAdd(lhs: Tensor, rhs: Tensor)
19+
-> (Tensor, (TangentVector) -> (TangentVector, TangentVector)) {
20+
return (lhs + rhs, { v in (v, v) })
21+
}
22+
}
23+
24+
@differentiable
25+
func TF_756(input: Tensor<Float>) -> Tensor<Float> {
26+
let other = Tensor<Float>()
27+
return other + input
28+
}
29+
30+
// Assertion failed: (!type->hasArchetype() && !type->hasTypeParameter()), function getAddrOfTypeMetadataAccessFunction, file /Users/danielzheng/swift-merge/swift/lib/IRGen/GenDecl.cpp, line 3352.
31+
// Stack dump:
32+
// ...
33+
// 1. Swift version 5.1.1-dev (Swift 3943c1e36b)
34+
// 2. While emitting IR SIL function "@AD__$s4main6TensorVAASfRszlE13TangentVectorVySf_GA2FIegyyd_A2FIegyd_TR_src_0_wrt_1_differential_index_subset_thunk".
35+
// for expression at [/Users/danielzheng/swift-merge/swift/test/AutoDiff/compiler_crashers/tf756-irgen-witness-method-archetype.swift:27:16 - line:27:16] RangeText=""
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// RUN: not --crash %target-swift-emit-silgen %s
2+
// REQUIRES: asserts
3+
4+
// TF-881: User-defined Swift derivative functions cannot capture local values.
5+
// Captured local values become extra SIL function arguments, breaking the
6+
// expected derivative function type logic.
7+
//
8+
// In the short term, we should diagnose these cases to prevent crashes.
9+
// In the long term, we should investigate supporting these cases.
10+
11+
do {
12+
let capturedValue: Int = 3
13+
14+
func original(_ x: Float) -> Float { x }
15+
16+
@differentiating(original)
17+
func vjp(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
18+
// Reference a local variable.
19+
// This causes the top-level SIL function @vjp to have extra arguments.
20+
_ = capturedValue
21+
return (x, { $0 })
22+
}
23+
}
24+
25+
// SIL verification failed: apply doesn't have right number of arguments for function: site.getNumArguments() == substConv.getNumSILArguments()
26+
// Verifying instruction:
27+
// %0 = argument of bb0 : $Float // user: %2
28+
// // function_ref vjp #1 (_:) in
29+
// %1 = function_ref @$s4main3vjpL_ySf5value_S2fc8pullbacktSfF : $@convention(thin) (Float, Int) -> (Float, @owned @callee_guaranteed (Float) -> Float) // user: %2
30+
// -> %2 = apply %1(%0) : $@convention(thin) (Float, Int) -> (Float, @owned @callee_guaranteed (Float) -> Float) // user: %3
31+
// return %2 : $(Float, @callee_guaranteed (Float) -> Float) // id: %3
32+
// In function:
33+
// // AD__$s4main8originalL_yS2fF__vjp_src_0_wrt_0
34+
// sil hidden [always_inline] [ossa] @AD__$s4main8originalL_yS2fF__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
35+
// // %0 // user: %2
36+
// bb0(%0 : $Float):
37+
// // function_ref vjp #1 (_:) in
38+
// %1 = function_ref @$s4main3vjpL_ySf5value_S2fc8pullbacktSfF : $@convention(thin) (Float, Int) -> (Float, @owned @callee_guaranteed (Float) -> Float) // user: %2
39+
// %2 = apply %1(%0) : $@convention(thin) (Float, Int) -> (Float, @owned @callee_guaranteed (Float) -> Float) // user: %3
40+
// return %2 : $(Float, @callee_guaranteed (Float) -> Float) // id: %3
41+
// } // end sil function 'AD__$s4main8originalL_yS2fF__vjp_src_0_wrt_0'
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// RUN: not --crash %target-swift-emit-sil %s
2+
// REQUIRES: asserts
3+
4+
// TF-923: Ownership verification error in pullback function generated by the
5+
// differentiation transform.
6+
7+
struct Tensor<Scalar> {
8+
class Box {
9+
init() {}
10+
}
11+
var box: Box = Box()
12+
}
13+
extension Tensor: Equatable where Scalar: Equatable {
14+
static func ==(_: Self, _: Self) -> Bool { fatalError() }
15+
}
16+
extension Tensor: AdditiveArithmetic where Scalar: AdditiveArithmetic {
17+
static var zero: Self { fatalError() }
18+
static func +(_: Self, _: Self) -> Self { fatalError() }
19+
static func -(_: Self, _: Self) -> Self { fatalError() }
20+
}
21+
extension Tensor: Differentiable where Scalar: Differentiable & AdditiveArithmetic {
22+
typealias TangentVector = Self
23+
}
24+
25+
struct Tuple<T: Differentiable & AdditiveArithmetic>: Differentiable {
26+
var first: Tensor<T>
27+
@noDerivative var second: Tensor<T>
28+
}
29+
30+
@differentiable(wrt: (input))
31+
func TF_923<T>(_ input: Tensor<T>, _ bool: Bool) -> Tuple<T> {
32+
let x = bool ? input : input
33+
return Tuple(first: x, second: x)
34+
}
35+
36+
// Function: 'AD__$s4main6TF_923yAA5TupleVyxGAA6TensorVyxG_Sbts18AdditiveArithmeticRzs14DifferentiableRzlF__pullback_src_0_wrt_0'
37+
// Found use after free due to unvisited non lifetime ending uses?!
38+
// Value: %22 = load [take] %10 : $*Tensor<τ_0_0> // users: %88, %60, %47
39+
// Remaining Users:
40+
// User: %60 = copy_value %22 : $Tensor<τ_0_0> // user: %65
41+
// User: %88 = copy_value %22 : $Tensor<τ_0_0> // user: %93
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// RUN: not --crash %target-swift-emit-sil %s
2+
// REQUIRES: asserts
3+
4+
// TF-928: Ownership verification error in pullback function generated by the
5+
// differentiation transform.
6+
7+
struct Tracked<T> {
8+
class Box {
9+
init() {}
10+
}
11+
var box: Box = Box()
12+
}
13+
extension Tracked : Equatable where T : Equatable {
14+
static func ==(_: Self, _: Self) -> Bool { fatalError() }
15+
}
16+
extension Tracked : AdditiveArithmetic where T : AdditiveArithmetic {
17+
static var zero: Self { fatalError() }
18+
static func +(_: Self, _: Self) -> Self { fatalError() }
19+
static func -(_: Self, _: Self) -> Self { fatalError() }
20+
}
21+
extension Tracked : Differentiable where T : Differentiable, T == T.TangentVector {
22+
typealias TangentVector = Tracked<T.TangentVector>
23+
}
24+
25+
func TF_928(
26+
_ lossFunction: @differentiable (Tracked<Float>, Tracked<Float>) -> Tracked<Float>,
27+
_ x: Tracked<Float>
28+
) {
29+
_ = pullback(at: x) { x in lossFunction(x, Tracked<Float>()) }
30+
}
31+
32+
// Function: 'AD__$s4main6TF_928yyAA7TrackedVySfGAE_AEtXF_AEtFA2EcfU___pullback_src_0_wrt_0'
33+
// Error! Found a leaked owned value that was never consumed.
34+
// Value: (%5, **%6**) = destructure_tuple %3 : $(Tracked<Float>, Tracked<Float>)
35+
// Stack dump:
36+
// While verifying SIL function "@AD__$s4main6TF_928yyAA7TrackedVySfGAE_AEtXF_AEtFA2EcfU___pullback_src_0_wrt_0".

test/AutoDiff/control_flow.swift

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,48 @@ ControlFlowTests.test("NestedConditionals") {
355355
expectEqual((0, -1), gradient(at: 4, 21, in: nested3))
356356
expectEqual((1, 1), gradient(at: 4, 5, in: nested3))
357357
expectEqual((0, -1), gradient(at: -3, -2, in: nested3))
358+
359+
// TF-781: nested if derivative correctness.
360+
do {
361+
struct TF_781: Differentiable {
362+
var w: Float = 3
363+
364+
@differentiable(wrt: self) // wrt only self is important
365+
func callAsFunction(_ input: Float) -> Float {
366+
var x = input
367+
if true {
368+
if true {
369+
// Function application below should make `self` have non-zero
370+
// derivative.
371+
x = x * w
372+
}
373+
}
374+
return x
375+
}
376+
}
377+
let x: Float = 10
378+
// FIXME(TF-781): Fix zero gradients (related to activity analysis).
379+
// expectEqual(TF_781.TangentVector(w: x), gradient(at: TF_781()) { $0(x) })
380+
expectEqual(TF_781.TangentVector(w: 0), gradient(at: TF_781()) { $0(x) })
381+
}
382+
383+
// Non-method version of TF-781.
384+
do {
385+
@differentiable(wrt: x)
386+
func TF_781(_ x: Float, _ y: Float) -> Float {
387+
var result = y
388+
if true {
389+
if true {
390+
result = result * x
391+
}
392+
}
393+
return result
394+
}
395+
let x: Float = 10
396+
// FIXME(TF-781): Fix zero gradients (related to activity analysis).
397+
// expectEqual(x, gradient(at: 3) { TF_781($0, x) })
398+
expectEqual(0, gradient(at: 3) { TF_781($0, x) })
399+
}
358400
}
359401

360402
ControlFlowTests.test("Recursion") {
@@ -397,7 +439,7 @@ ControlFlowTests.test("Recursion") {
397439
}
398440
return y
399441
}
400-
// FIXME: Fix zero gradients (related to activity analysis).
442+
// FIXME(TF-933): Fix zero gradients (related to activity analysis).
401443
// See `factorial_var1` for the working version.
402444
/*
403445
expectEqual(0, gradient(at: 1, in: factorial_var2))
@@ -537,7 +579,7 @@ ControlFlowTests.test("Loops") {
537579
}
538580
return result
539581
}
540-
// TODO(TF-933): Fix incorrect derivatives when `var result` is not initially
582+
// FIXME(TF-933): Fix incorrect derivatives when `var result` is not initially
541583
// assigned to `x`.
542584
// expectEqual((4, 4), valueWithGradient(at: 2, in: for_loop_nonactive_initial_value))
543585
// expectEqual((9, 6), valueWithGradient(at: 3, in: for_loop_nonactive_initial_value))
@@ -565,7 +607,7 @@ ControlFlowTests.test("Loops") {
565607
}
566608
return result
567609
}
568-
// TODO(TF-933): Fix incorrect derivatives when `var result` is not initially
610+
// FIXME(TF-933): Fix incorrect derivatives when `var result` is not initially
569611
// assigned to `x`.
570612
// expectEqual((4, 4), valueWithGradient(at: 2, in: while_loop_nonactive_initial_value))
571613
// expectEqual((9, 6), valueWithGradient(at: 3, in: while_loop_nonactive_initial_value))
@@ -597,7 +639,7 @@ ControlFlowTests.test("Loops") {
597639
} while i < 2
598640
return result
599641
}
600-
// TODO(TF-584, TF-933): Fix incorrect derivatives when `var result` is not
642+
// FIXME(TF-584, TF-933): Fix incorrect derivatives when `var result` is not
601643
// initially assigned to `x`.
602644
// expectEqual((4, 4), valueWithGradient(at: 2, in: repeat_while_loop_nonactive_initial_value))
603645
// expectEqual((9, 6), valueWithGradient(at: 3, in: repeat_while_loop_nonactive_initial_value))

test/AutoDiff/simple_math.swift

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,15 @@ SimpleMathTests.test("SideEffects") {
9797
expectEqual(4 * 27, gradient(at: 3, in: fourthPower))
9898
}
9999

100+
SimpleMathTests.test("Tuple") {
101+
// TF-945: Nested tuple projections.
102+
func nested(_ x: Float) -> Float {
103+
var tuple = (1, 1, ((x, 1), 1))
104+
return tuple.2.0.0
105+
}
106+
expectEqual(1, gradient(at: 3, in: nested))
107+
}
108+
100109
SimpleMathTests.test("TupleSideEffects") {
101110
func foo(_ x: Float) -> Float {
102111
var tuple = (x, x)

0 commit comments

Comments
 (0)