Skip to content

Commit 3b8b632

Browse files
authored
[AutoDiff upstream] Add @differentiable protocol requirement tests. (#32939)
1 parent 954fc46 commit 3b8b632

File tree

2 files changed

+220
-0
lines changed

2 files changed

+220
-0
lines changed

test/AutoDiff/Sema/differentiable_attr_type_checking.swift

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,23 @@ public protocol DoubleDifferentiableDistribution: DifferentiableDistribution
544544
func logProbability(of value: Value) -> Float
545545
}
546546

547+
// Test failure to satisfy protocol requirement's `@differentiable` attribute.
548+
549+
public protocol HasRequirement {
550+
@differentiable
551+
// expected-note @+1 {{protocol requires function 'requirement' with type '<T> (T, T) -> T'; do you want to add a stub?}}
552+
func requirement<T: Differentiable>(_ x: T, _ y: T) -> T
553+
}
554+
555+
// expected-error @+1 {{type 'AttemptsToSatisfyRequirement' does not conform to protocol 'HasRequirement'}}
556+
public struct AttemptsToSatisfyRequirement: HasRequirement {
557+
// This `@differentiable` attribute does not satisfy the requirement because
558+
// it is mroe constrained than the requirement's `@differentiable` attribute.
559+
@differentiable(where T: CustomStringConvertible)
560+
// expected-note @+1 {{candidate is missing attribute '@differentiable(wrt: (x, y))'}}
561+
public func requirement<T: Differentiable>(_ x: T, _ y: T) -> T { x }
562+
}
563+
547564
// Test protocol requirement `@differentiable` attribute unsupported features.
548565

549566
protocol ProtocolRequirementUnsupported: Differentiable {
Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
// RUN: %target-run-simple-swift
2+
3+
import StdlibUnittest
4+
import DifferentiationUnittest
5+
6+
// Test end-to-end differentiation of `@differentiable` protocol requirements.
7+
8+
var ProtocolRequirementAutodiffTests = TestSuite("ProtocolRequirementDifferentiation")
9+
10+
// MARK: - Method requirements.
11+
12+
protocol DiffReq: Differentiable {
13+
@differentiable(wrt: (self, x))
14+
func f(_ x: Tracked<Float>) -> Tracked<Float>
15+
}
16+
17+
extension DiffReq where TangentVector: AdditiveArithmetic {
18+
@inline(never) // Prevent specialization, to test all witness code.
19+
func gradF(at x: Tracked<Float>) -> (Self.TangentVector, Tracked<Float>) {
20+
return (valueWithPullback(at: self, x) { s, x in s.f(x) }).1(1)
21+
}
22+
}
23+
24+
struct Quadratic: DiffReq, AdditiveArithmetic {
25+
typealias TangentVector = Quadratic
26+
27+
@differentiable
28+
let a: Tracked<Float>
29+
30+
@differentiable
31+
let b: Tracked<Float>
32+
33+
@differentiable
34+
let c: Tracked<Float>
35+
36+
init(_ a: Tracked<Float>, _ b: Tracked<Float>, _ c: Tracked<Float>) {
37+
self.a = a
38+
self.b = b
39+
self.c = c
40+
}
41+
42+
@differentiable(wrt: (self, x))
43+
func f(_ x: Tracked<Float>) -> Tracked<Float> {
44+
return a * x * x + b * x + c
45+
}
46+
}
47+
48+
ProtocolRequirementAutodiffTests.testWithLeakChecking("func") {
49+
expectEqual((Quadratic(0, 0, 1), 12), Quadratic(11, 12, 13).gradF(at: 0))
50+
expectEqual((Quadratic(1, 1, 1), 2 * 11 + 12),
51+
Quadratic(11, 12, 13).gradF(at: 1))
52+
expectEqual((Quadratic(4, 2, 1), 2 * 11 * 2 + 12),
53+
Quadratic(11, 12, 13).gradF(at: 2))
54+
}
55+
56+
// MARK: - Constructor, accessor, and subscript requirements.
57+
58+
protocol FunctionsOfX: Differentiable {
59+
@differentiable
60+
init(x: Tracked<Float>)
61+
62+
@differentiable
63+
var x: Tracked<Float> { get }
64+
65+
@differentiable
66+
var y: Tracked<Float> { get }
67+
68+
@differentiable
69+
var z: Tracked<Float> { get }
70+
71+
@differentiable
72+
subscript() -> Tracked<Float> { get }
73+
}
74+
75+
struct TestFunctionsOfX: FunctionsOfX {
76+
@differentiable
77+
init(x: Tracked<Float>) {
78+
self.x = x
79+
self.y = x * x
80+
}
81+
82+
/// x = x
83+
var x: Tracked<Float>
84+
85+
/// y = x * x
86+
var y: Tracked<Float>
87+
88+
/// z = x * x + x
89+
var z: Tracked<Float> {
90+
return y + x
91+
}
92+
93+
@differentiable
94+
subscript() -> Tracked<Float> {
95+
return z
96+
}
97+
}
98+
99+
@inline(never) // Prevent specialization, to test all witness code.
100+
func derivatives<F: FunctionsOfX>(at x: Tracked<Float>, in: F.Type)
101+
-> (Tracked<Float>, Tracked<Float>, Tracked<Float>, Tracked<Float>)
102+
{
103+
let dxdx = gradient(at: x) { x in F(x: x).x }
104+
let dydx = gradient(at: x) { x in F(x: x).y }
105+
let dzdx = gradient(at: x) { x in F(x: x).z }
106+
let dsubscriptdx = gradient(at: x) { x in F(x: x)[] }
107+
return (dxdx, dydx, dzdx, dsubscriptdx)
108+
}
109+
110+
ProtocolRequirementAutodiffTests.testWithLeakChecking("constructor, accessor, subscript") {
111+
expectEqual(
112+
(1.0, 4.0, 5.0, 5.0),
113+
derivatives(at: 2.0, in: TestFunctionsOfX.self))
114+
}
115+
116+
// MARK: - Test witness method SIL type computation.
117+
118+
protocol P: Differentiable {
119+
@differentiable(wrt: (x, y))
120+
func foo(_ x: Tracked<Float>, _ y: Double) -> Tracked<Float>
121+
}
122+
struct S: P {
123+
@differentiable(wrt: (x, y))
124+
func foo(_ x: Tracked<Float>, _ y: Double) -> Tracked<Float> {
125+
return x
126+
}
127+
}
128+
129+
// MARK: - Overridding protocol method adding `@differentiable` attribute.
130+
131+
public protocol Distribution {
132+
associatedtype Value
133+
func logProbability(of value: Value) -> Tracked<Float>
134+
}
135+
136+
public protocol DifferentiableDistribution: Differentiable, Distribution {
137+
@differentiable(wrt: self)
138+
func logProbability(of value: Value) -> Tracked<Float>
139+
}
140+
141+
struct Foo: DifferentiableDistribution {
142+
@differentiable(wrt: self)
143+
func logProbability(of value: Tracked<Float>) -> Tracked<Float> {
144+
.zero
145+
}
146+
}
147+
148+
@differentiable
149+
func blah<T: DifferentiableDistribution>(_ x: T) -> Tracked<Float>
150+
where T.Value: AdditiveArithmetic {
151+
x.logProbability(of: .zero)
152+
}
153+
154+
// Adding a more general `@differentiable` attribute.
155+
156+
public protocol DoubleDifferentiableDistribution: DifferentiableDistribution
157+
where Value: Differentiable {
158+
@differentiable(wrt: self)
159+
@differentiable(wrt: (self, value))
160+
func logProbability(of value: Value) -> Tracked<Float>
161+
}
162+
163+
@differentiable
164+
func blah2<T: DoubleDifferentiableDistribution>(_ x: T, _ value: T.Value) -> Tracked<Float>
165+
where T.Value: AdditiveArithmetic {
166+
x.logProbability(of: value)
167+
}
168+
169+
// Satisfying the requirement with more wrt parameter indices than are necessary.
170+
171+
protocol DifferentiableFoo {
172+
associatedtype T: Differentiable
173+
@differentiable(wrt: x)
174+
func foo(_ x: T) -> Tracked<Float>
175+
}
176+
177+
protocol MoreDifferentiableFoo: Differentiable, DifferentiableFoo {
178+
@differentiable(wrt: (self, x))
179+
func foo(_ x: T) -> Tracked<Float>
180+
}
181+
182+
struct MoreDifferentiableFooStruct: MoreDifferentiableFoo {
183+
@differentiable(wrt: (self, x))
184+
func foo(_ x: Tracked<Float>) -> Tracked<Float> {
185+
x
186+
}
187+
}
188+
189+
// Satisfying the requirement with a less-constrained derivative than is necessary.
190+
191+
protocol ExtraDerivativeConstraint {}
192+
193+
protocol HasExtraConstrainedDerivative {
194+
@differentiable
195+
func requirement<T: Differentiable & ExtraDerivativeConstraint>(_ x: T) -> T
196+
}
197+
198+
struct SatisfiesDerivativeWithLessConstraint: HasExtraConstrainedDerivative {
199+
@differentiable
200+
func requirement<T: Differentiable>(_ x: T) -> T { x }
201+
}
202+
203+
runAllTests()

0 commit comments

Comments
 (0)