Skip to content

Commit 52ea38e

Browse files
authored
Remove remaining usages of CotangentVector and tangentVector(from:). (#24855)
1 parent 0a9181a commit 52ea38e

16 files changed

+25
-56
lines changed

lib/Sema/TypeCheckAttr.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3207,7 +3207,6 @@ void AttributeChecker::visitDifferentiatingAttr(DifferentiatingAttr *attr) {
32073207
auto funcResultElt = derivativeResultTupleType->getElement(1);
32083208
// Get derivative kind and associated function identifier.
32093209
AutoDiffAssociatedFunctionKind kind;
3210-
Identifier autoDiffAssocTyId;
32113210
if (valueResultElt.getName().str() != "value") {
32123211
TC.diagnose(attr->getLocation(),
32133212
diag::differentiating_attr_invalid_result_tuple_value_label);
@@ -3216,10 +3215,8 @@ void AttributeChecker::visitDifferentiatingAttr(DifferentiatingAttr *attr) {
32163215
}
32173216
if (funcResultElt.getName().str() == "differential") {
32183217
kind = AutoDiffAssociatedFunctionKind::JVP;
3219-
autoDiffAssocTyId = ctx.Id_TangentVector;
32203218
} else if (funcResultElt.getName().str() == "pullback") {
32213219
kind = AutoDiffAssociatedFunctionKind::VJP;
3222-
autoDiffAssocTyId = ctx.Id_TangentVector;
32233220
} else {
32243221
TC.diagnose(attr->getLocation(),
32253222
diag::differentiating_attr_invalid_result_tuple_func_label);
@@ -3376,14 +3373,14 @@ void AttributeChecker::visitDifferentiatingAttr(DifferentiatingAttr *attr) {
33763373
assert(conf &&
33773374
"Expected checked parameter to conform to `Differentiable`");
33783375
auto paramAssocType = ProtocolConformanceRef::getTypeWitnessByName(
3379-
paramType, *conf, autoDiffAssocTyId, ctx.getLazyResolver());
3376+
paramType, *conf, ctx.Id_TangentVector, ctx.getLazyResolver());
33803377
return TupleTypeElt(paramAssocType);
33813378
});
33823379

33833380
// Check differential/pullback type.
33843381
// Get vector type: the associated type of the value result type.
33853382
auto vectorTy = ProtocolConformanceRef::getTypeWitnessByName(
3386-
valueResultType, *valueResultConf, autoDiffAssocTyId,
3383+
valueResultType, *valueResultConf, ctx.Id_TangentVector,
33873384
ctx.getLazyResolver());
33883385

33893386
// Compute expected differential/pullback type.

test/AutoDiff/array.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import StdlibUnittest
44

55
var ArrayAutodiffTests = TestSuite("ArrayAutodiff")
66

7-
typealias FloatArrayGrad = Array<Float>.CotangentVector
7+
typealias FloatArrayGrad = Array<Float>.TangentVector
88

99
ArrayAutodiffTests.test("ArrayIdentity") {
1010
func arrayIdentity(_ x: [Float]) -> [Float] {
@@ -39,21 +39,21 @@ ArrayAutodiffTests.test("ArrayConcat") {
3939
}
4040

4141
expectEqual(
42-
TwoArrays.CotangentVector(
42+
TwoArrays.TangentVector(
4343
a: FloatArrayGrad([1, 1]),
4444
b: FloatArrayGrad([1, 0])),
4545
gradient(
4646
at: TwoArrays(a: [0, 0], b: [0, 0]),
4747
in: sumFirstThreeConcatted))
4848
expectEqual(
49-
TwoArrays.CotangentVector(
49+
TwoArrays.TangentVector(
5050
a: FloatArrayGrad([1, 1, 1, 0]),
5151
b: FloatArrayGrad([0, 0])),
5252
gradient(
5353
at: TwoArrays(a: [0, 0, 0, 0], b: [0, 0]),
5454
in: sumFirstThreeConcatted))
5555
expectEqual(
56-
TwoArrays.CotangentVector(
56+
TwoArrays.TangentVector(
5757
a: FloatArrayGrad([]),
5858
b: FloatArrayGrad([1, 1, 1, 0])),
5959
gradient(

test/AutoDiff/autodiff_diagnostics.swift

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ extension S : Differentiable, VectorNumeric {
3636
static func * (lhs: Float, rhs: S) -> S { return S(p: lhs * rhs.p) }
3737

3838
typealias TangentVector = S
39-
typealias CotangentVector = S
4039
}
4140

4241
// expected-error @+2 {{function is not differentiable}}

test/AutoDiff/derivative_registration.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,17 +58,17 @@ extension Wrapper {
5858

5959
@differentiating(multiply)
6060
func _vjpMultiply(_ x: Float)
61-
-> (value: Float, pullback: (Float) -> (Wrapper.CotangentVector, Float)) {
61+
-> (value: Float, pullback: (Float) -> (Wrapper.TangentVector, Float)) {
6262
return (float * x, { v in
63-
(Wrapper.CotangentVector(float: v * x), v * self.float)
63+
(Wrapper.TangentVector(float: v * x), v * self.float)
6464
})
6565
}
6666
}
6767
DerivativeRegistrationTests.test("InstanceMethod") {
6868
let x: Float = 2
6969
let wrapper = Wrapper(float: 3)
7070
let (𝛁wrapper, 𝛁x) = wrapper.gradient(at: x) { wrapper, x in wrapper.multiply(x) }
71-
expectEqual(Wrapper.CotangentVector(float: 2), 𝛁wrapper)
71+
expectEqual(Wrapper.TangentVector(float: 2), 𝛁wrapper)
7272
expectEqual(3, 𝛁x)
7373
}
7474

test/AutoDiff/derived_conformances.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ DerivedConformanceTests.test("MemberwiseInitializers") {
2828
}
2929
expectEqual(HasNoDerivativeConstant.AllDifferentiableVariables(x: 0),
3030
HasNoDerivativeConstant.AllDifferentiableVariables.zero)
31-
expectEqual(HasNoDerivativeConstant.CotangentVector(x: 0),
32-
HasNoDerivativeConstant.CotangentVector.zero)
31+
expectEqual(HasNoDerivativeConstant.TangentVector(x: 0),
32+
HasNoDerivativeConstant.TangentVector.zero)
3333
}
3434

3535
runAllTests()

test/AutoDiff/differentiable_attr_type_checking.swift

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -543,7 +543,6 @@ extension Tensor : Differentiable where Scalar : Differentiable {
543543
typealias TangentVector = Tensor
544544
typealias AllDifferentiableVariables = Tensor
545545
func moved(along direction: Tensor) -> Tensor { return self }
546-
func tangentVector(from cotangent: Tensor) -> Tensor { return cotangent }
547546
}
548547
@differentiable(where Scalar : Differentiable)
549548
func where2<Scalar : Numeric>(x: Tensor<Scalar>) -> Tensor<Scalar> {

test/AutoDiff/differentiable_requirement_cross_module.swift

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,7 @@ import differentiable_requirement_other_module
88
// The `foo` protocol requirement is `@differentiable` and has an `Empty` parameter.
99
extension Empty : Differentiable {
1010
public typealias TangentVector = Empty
11-
public typealias CotangentVector = Empty
1211
public typealias AllDifferentiableVariables = Empty
13-
14-
public func tangentVector(from cotangent: CotangentVector) -> TangentVector {
15-
return cotangent
16-
}
1712
}
1813

1914
// expected-error @+1 {{type 'Conforming' does not conform to protocol 'DifferentiableRequirement'}}

test/AutoDiff/e2e_differentiable_property.swift

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ struct TangentSpace : VectorNumeric {
1414

1515
extension TangentSpace : Differentiable {
1616
typealias TangentVector = TangentSpace
17-
typealias CotangentVector = TangentSpace
1817
}
1918

2019
struct Space {
@@ -48,13 +47,9 @@ struct Space {
4847

4948
extension Space : Differentiable {
5049
typealias TangentVector = TangentSpace
51-
typealias CotangentVector = TangentSpace
5250
func moved(along: TangentSpace) -> Space {
5351
return Space(x: x + along.dx, y: y + along.dy)
5452
}
55-
func tangentVector(from cotangent: CotangentVector) -> TangentVector {
56-
return cotangent
57-
}
5853
}
5954

6055
E2EDifferentiablePropertyTests.test("computed property") {
@@ -78,16 +73,16 @@ struct GenericMemberWrapper<T : Differentiable> : Differentiable {
7873
@differentiable(vjp: vjpX)
7974
var x: T
8075

81-
func vjpX() -> (T, (T.CotangentVector) -> GenericMemberWrapper.CotangentVector) {
82-
return (x, { CotangentVector(x: $0) })
76+
func vjpX() -> (T, (T.TangentVector) -> GenericMemberWrapper.TangentVector) {
77+
return (x, { TangentVector(x: $0) })
8378
}
8479
}
8580

8681
E2EDifferentiablePropertyTests.test("generic stored property") {
8782
let actualGrad = gradient(at: GenericMemberWrapper<Float>(x: 1)) { point in
8883
return 2 * point.x
8984
}
90-
let expectedGrad = GenericMemberWrapper<Float>.CotangentVector(x: 2)
85+
let expectedGrad = GenericMemberWrapper<Float>.TangentVector(x: 2)
9186
expectEqual(expectedGrad, actualGrad)
9287
}
9388

@@ -98,7 +93,6 @@ struct ProductSpaceSelfTangent : VectorNumeric {
9893

9994
extension ProductSpaceSelfTangent : Differentiable {
10095
typealias TangentVector = ProductSpaceSelfTangent
101-
typealias CotangentVector = ProductSpaceSelfTangent
10296
}
10397

10498
E2EDifferentiablePropertyTests.test("fieldwise product space, self tangent") {
@@ -115,7 +109,6 @@ struct ProductSpaceOtherTangentTangentSpace : VectorNumeric {
115109

116110
extension ProductSpaceOtherTangentTangentSpace : Differentiable {
117111
typealias TangentVector = ProductSpaceOtherTangentTangentSpace
118-
typealias CotangentVector = ProductSpaceOtherTangentTangentSpace
119112
}
120113

121114
@_fieldwiseDifferentiable
@@ -125,13 +118,9 @@ struct ProductSpaceOtherTangent {
125118

126119
extension ProductSpaceOtherTangent : Differentiable {
127120
typealias TangentVector = ProductSpaceOtherTangentTangentSpace
128-
typealias CotangentVector = ProductSpaceOtherTangentTangentSpace
129121
func moved(along: ProductSpaceOtherTangentTangentSpace) -> ProductSpaceOtherTangent {
130122
return ProductSpaceOtherTangent(x: x + along.x, y: y + along.y)
131123
}
132-
func tangentVector(from cotangent: CotangentVector) -> TangentVector {
133-
return cotangent
134-
}
135124
}
136125

137126
E2EDifferentiablePropertyTests.test("fieldwise product space, other tangent") {

test/AutoDiff/method.swift

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ extension Parameter {
4040

4141
extension Parameter : Differentiable, VectorNumeric {
4242
typealias TangentVector = Parameter
43-
typealias CotangentVector = Parameter
4443
typealias Scalar = Float
4544
typealias Shape = ()
4645
init(repeating repeatedValue: Float, shape: ()) {
@@ -150,7 +149,7 @@ struct DiffWrtSelf : Differentiable {
150149
return (x, { (dself, dx, dy) in dx })
151150
}
152151
func _vjpCall<T : Differentiable, U : Differentiable>(_ x: T, _ y: U)
153-
-> (T, (T.CotangentVector) -> (DiffWrtSelf.CotangentVector, T.CotangentVector, U.CotangentVector)) {
152+
-> (T, (T.TangentVector) -> (DiffWrtSelf.TangentVector, T.TangentVector, U.TangentVector)) {
154153
return (x, { (.zero, $0, .zero) })
155154
}
156155
}
@@ -166,7 +165,6 @@ struct CustomParameter : Equatable {
166165

167166
extension CustomParameter : Differentiable, VectorNumeric {
168167
typealias TangentVector = CustomParameter
169-
typealias CotangentVector = CustomParameter
170168
typealias Scalar = Float
171169
typealias Shape = ()
172170
init(repeating repeatedValue: Float, shape: ()) {

test/AutoDiff/protocol_requirement_autodiff.swift

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,14 @@ protocol DiffReq : Differentiable {
99
func f(_ x: Float) -> Float
1010
}
1111

12-
extension DiffReq where TangentVector : AdditiveArithmetic, CotangentVector : AdditiveArithmetic {
13-
func gradF(at x: Float) -> (Self.CotangentVector, Float) {
12+
extension DiffReq where TangentVector : AdditiveArithmetic {
13+
func gradF(at x: Float) -> (Self.TangentVector, Float) {
1414
return (valueWithPullback(at: x) { s, x in s.f(x) }).1(1)
1515
}
1616
}
1717

1818
struct Quadratic : DiffReq, Equatable {
1919
typealias TangentVector = Quadratic
20-
typealias CotangentVector = Quadratic
2120

2221
@differentiable(wrt: (self), vjp: vjpA)
2322
let a: Float

test/AutoDiff/refcounting.swift

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@ public struct Vector : AdditiveArithmetic, VectorNumeric, Differentiable, Equata
1111
public var y: Float
1212
public var nonTrivialStuff = NonTrivialStuff()
1313
public typealias TangentVector = Vector
14-
public typealias CotangentVector = Vector
15-
public func tangentVector(from cotangent: CotangentVector) -> TangentVector { return cotangent }
1614
public typealias Scalar = Float
1715
public static var zero: Vector { return Vector(0) }
1816
public init(_ scalar: Float) { self.x = scalar; self.y = scalar }

test/AutoDiff/separate_cotangent_type.swift

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,6 @@ struct DifferentiableSubset : Differentiable {
2323
typealias TangentVector = DifferentiableSubset.TangentVector
2424
var w: Float
2525
var b: Float
26-
func tangentVector(from cotan: TangentVector) -> TangentVector {
27-
return TangentVector(w: cotan.w, b: cotan.b)
28-
}
2926
}
3027
func moved(along v: TangentVector) -> DifferentiableSubset {
3128
return DifferentiableSubset(w: w.moved(along: v.w), b: b.moved(along: v.b), flag: flag)

test/AutoDiff/simple_math.swift

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ SimpleMathTests.test("StructMemberwiseInitializer") {
185185
let 𝛁foo = pullback(at: Float(4), in: { input -> Foo in
186186
let foo = Foo(stored: input)
187187
return foo + foo
188-
})(Foo.CotangentVector(stored: 1))
188+
})(Foo.TangentVector(stored: 1))
189189
expectEqual(2, 𝛁foo)
190190

191191
let 𝛁computed = gradient(at: Float(4)) { input -> Float in
@@ -214,7 +214,7 @@ SimpleMathTests.test("StructMemberwiseInitializer") {
214214
let 𝛁custom = pullback(at: Float(4), in: { input -> Custom in
215215
let foo = Custom(x: input)
216216
return foo + foo
217-
})(Custom.CotangentVector(x: 1))
217+
})(Custom.TangentVector(x: 1))
218218
expectEqual(2, 𝛁custom)
219219
}
220220

@@ -238,7 +238,7 @@ SimpleMathTests.test("StructConstantStoredProperty") {
238238
let model = TF_319(x: 10)
239239
return model.applied(to: input)
240240
}
241-
expectEqual(TF_319.CotangentVector(x: 6),
241+
expectEqual(TF_319.TangentVector(x: 6),
242242
gradient(at: TF_319(x: 10), in: { $0.applied(to: 3) }))
243243
expectEqual(20, gradient(at: 3, in: testStructInit))
244244
}
@@ -282,7 +282,8 @@ SimpleMathTests.test("StructSideEffects") {
282282
}
283283
}
284284
let model = Add(bias: 1)
285-
expectEqual(Add.CotangentVector(bias: 1), gradient(at: model) { m in m.applied(to: 1) })
285+
expectEqual(Add.TangentVector(bias: 1),
286+
gradient(at: model) { m in m.applied(to: 1) })
286287
}
287288

288289
SimpleMathTests.test("StructGeneric") {
@@ -295,7 +296,7 @@ SimpleMathTests.test("StructGeneric") {
295296
let 𝛁generic = pullback(at: Float(3), in: { input -> Generic<Float> in
296297
var generic = Generic(x: input, y: input, z: input)
297298
return generic
298-
})(Generic<Float>.CotangentVector(x: 1, y: 1, z: 1))
299+
})(Generic<Float>.TangentVector(x: 1, y: 1, z: 1))
299300
expectEqual(3, 𝛁generic)
300301

301302
func fifthPower(_ input: Float) -> Float {

test/AutoDiff/simple_model.swift

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ struct DenseLayer : Equatable {
2121

2222
extension DenseLayer : Differentiable, VectorNumeric {
2323
typealias TangentVector = DenseLayer
24-
typealias CotangentVector = DenseLayer
2524
typealias Scalar = Float
2625
static var zero: DenseLayer {
2726
return DenseLayer(w: 0, b: 0)
@@ -68,7 +67,6 @@ struct Model : Equatable {
6867

6968
extension Model : Differentiable, VectorNumeric {
7069
typealias TangentVector = Model
71-
typealias CotangentVector = Model
7270
typealias Scalar = Float
7371
static var zero: Model {
7472
return Model(l1: DenseLayer.zero, l2: DenseLayer.zero, l3: DenseLayer.zero)

test/AutoDiff/witness_table_silgen.swift

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ struct S : Proto, VectorNumeric {
2020
static func * (lhs: Float, rhs: S) -> S { return S(p: lhs * rhs.p) }
2121

2222
typealias TangentVector = S
23-
typealias CotangentVector = S
2423

2524
@differentiable(wrt: (self), vjp: vjpP)
2625
let p: Float

test/Serialization/differentiating_attr.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ func generic<T : Numeric>(x: T) -> T {
2626
return x
2727
}
2828
@differentiating(generic)
29-
func vjpGeneric<T>(x: T) -> (value: T, pullback: (T.CotangentVector) -> T.CotangentVector)
29+
func vjpGeneric<T>(x: T) -> (value: T, pullback: (T.TangentVector) -> T.TangentVector)
3030
where T : Numeric, T : Differentiable
3131
{
3232
return (x, { v in v })
@@ -40,7 +40,7 @@ protocol InstanceMethod : Differentiable {
4040
}
4141
extension InstanceMethod {
4242
@differentiating(foo)
43-
func vjpFoo(x: Self) -> (value: Self, pullback: (Self.CotangentVector) -> (Self.CotangentVector, Self.CotangentVector)) {
43+
func vjpFoo(x: Self) -> (value: Self, pullback: (Self.TangentVector) -> (Self.TangentVector, Self.TangentVector)) {
4444
return (x, { ($0, $0) })
4545
}
4646

0 commit comments

Comments
 (0)