Skip to content

Commit 88c5f65

Browse files
committed
Replace SIMD use of @differentiable(jvp:vjp:) with @Derivative(of:)
1 parent bba00c6 commit 88c5f65

File tree

2 files changed

+32
-119
lines changed

2 files changed

+32
-119
lines changed

stdlib/public/core/SIMDVector.swift

Lines changed: 30 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -90,13 +90,6 @@ extension SIMD {
9090

9191
/// A vector with the specified value in all lanes.
9292
@_transparent
93-
// SWIFT_ENABLE_TENSORFLOW
94-
@differentiable(vjp: _vjpInit(repeating:)
95-
where Self : Differentiable,
96-
Self.TangentVector : SIMD,
97-
Scalar : BinaryFloatingPoint & Differentiable,
98-
Self.TangentVector == Self,
99-
Scalar.TangentVector == Scalar)
10093
public init(repeating value: Scalar) {
10194
self.init()
10295
for i in indices { self[i] = value }
@@ -788,51 +781,27 @@ extension SIMD where Scalar: FixedWidthInteger {
788781
// be replaced with @_semantics to lower directly to vector IR nodes.
789782
extension SIMD where Scalar : FloatingPoint {
790783
@_transparent
791-
// SWIFT_ENABLE_TENSORFLOW
792-
@differentiable(vjp: _vjpAdd(lhs:rhs:)
793-
where Self : Differentiable,
794-
Self.TangentVector : SIMD,
795-
Scalar : BinaryFloatingPoint,
796-
Self.TangentVector.Scalar : BinaryFloatingPoint)
797784
public static func +(lhs: Self, rhs: Self) -> Self {
798785
var result = Self()
799786
for i in result.indices { result[i] = lhs[i] + rhs[i] }
800787
return result
801788
}
802789

803790
@_transparent
804-
// SWIFT_ENABLE_TENSORFLOW
805-
@differentiable(vjp: _vjpSubtract(lhs:rhs:)
806-
where Self : Differentiable,
807-
Self.TangentVector : SIMD,
808-
Scalar : BinaryFloatingPoint,
809-
Self.TangentVector.Scalar : BinaryFloatingPoint)
810791
public static func -(lhs: Self, rhs: Self) -> Self {
811792
var result = Self()
812793
for i in result.indices { result[i] = lhs[i] - rhs[i] }
813794
return result
814795
}
815796

816797
@_transparent
817-
// SWIFT_ENABLE_TENSORFLOW
818-
@differentiable(vjp: _vjpMultiply(lhs:rhs:)
819-
where Self : Differentiable,
820-
Self.TangentVector : SIMD,
821-
Scalar : BinaryFloatingPoint,
822-
Self.TangentVector == Self)
823798
public static func *(lhs: Self, rhs: Self) -> Self {
824799
var result = Self()
825800
for i in result.indices { result[i] = lhs[i] * rhs[i] }
826801
return result
827802
}
828803

829804
@_transparent
830-
// SWIFT_ENABLE_TENSORFLOW
831-
@differentiable(vjp: _vjpDivide(lhs:rhs:)
832-
where Self : Differentiable,
833-
Self.TangentVector : SIMD,
834-
Scalar : BinaryFloatingPoint,
835-
Self.TangentVector == Self)
836805
public static func /(lhs: Self, rhs: Self) -> Self {
837806
var result = Self()
838807
for i in result.indices { result[i] = lhs[i] / rhs[i] }
@@ -877,12 +846,6 @@ extension SIMD where Scalar : FloatingPoint {
877846
// FIXME: TF-545 we want the sum() func to be marked as
878847
// `@_alwaysEmitIntoClient` like before when we define the VJP
879848
@inlinable
880-
@differentiable(vjp: _vjpSum
881-
where Self : Differentiable,
882-
Self.TangentVector : SIMD,
883-
Scalar : BinaryFloatingPoint & Differentiable,
884-
Scalar.TangentVector : BinaryFloatingPoint,
885-
Self.TangentVector == Self)
886849
public func sum() -> Scalar {
887850
// Implementation note: this eventually be defined to lower to either
888851
// llvm.experimental.vector.reduce.fadd or an explicit tree-sum. Open-
@@ -1197,108 +1160,46 @@ extension SIMD where Scalar: FixedWidthInteger {
11971160
extension SIMD where Scalar: FloatingPoint {
11981161

11991162
@_transparent
1200-
// SWIFT_ENABLE_TENSORFLOW
1201-
@differentiable(vjp: _vjpNegate(rhs:)
1202-
where Self : Differentiable,
1203-
Self.TangentVector : SIMD,
1204-
Scalar : BinaryFloatingPoint,
1205-
Self.TangentVector.Scalar : BinaryFloatingPoint)
12061163
public static prefix func -(rhs: Self) -> Self {
12071164
return 0 - rhs
12081165
}
12091166

12101167
@_transparent
1211-
// SWIFT_ENABLE_TENSORFLOW
1212-
@differentiable(vjp: _vjpAdd(lhs:rhs:)
1213-
where Self : Differentiable,
1214-
Self.TangentVector : SIMD,
1215-
Scalar : Differentiable & BinaryFloatingPoint,
1216-
Scalar.TangentVector : BinaryFloatingPoint,
1217-
Self.TangentVector.Scalar == Scalar.TangentVector)
12181168
public static func +(lhs: Scalar, rhs: Self) -> Self {
12191169
return Self(repeating: lhs) + rhs
12201170
}
12211171

12221172
@_transparent
1223-
// SWIFT_ENABLE_TENSORFLOW
1224-
@differentiable(vjp: _vjpSubtract(lhs:rhs:)
1225-
where Self : Differentiable,
1226-
Self.TangentVector : SIMD,
1227-
Scalar : Differentiable & BinaryFloatingPoint,
1228-
Scalar.TangentVector : BinaryFloatingPoint,
1229-
Self.TangentVector.Scalar == Scalar.TangentVector)
12301173
public static func -(lhs: Scalar, rhs: Self) -> Self {
12311174
return Self(repeating: lhs) - rhs
12321175
}
12331176

12341177
@_transparent
1235-
// SWIFT_ENABLE_TENSORFLOW
1236-
@differentiable(vjp: _vjpMultiply(lhs:rhs:)
1237-
where Self : Differentiable,
1238-
Self.TangentVector : SIMD,
1239-
Scalar : BinaryFloatingPoint & Differentiable,
1240-
Self.TangentVector == Self,
1241-
Scalar.TangentVector == Scalar)
12421178
public static func *(lhs: Scalar, rhs: Self) -> Self {
12431179
return Self(repeating: lhs) * rhs
12441180
}
12451181

12461182
@_transparent
1247-
// SWIFT_ENABLE_TENSORFLOW
1248-
@differentiable(vjp: _vjpDivide(lhs:rhs:)
1249-
where Self : Differentiable,
1250-
Self.TangentVector : SIMD,
1251-
Scalar : BinaryFloatingPoint & Differentiable,
1252-
Self.TangentVector == Self,
1253-
Scalar.TangentVector == Scalar)
12541183
public static func /(lhs: Scalar, rhs: Self) -> Self {
12551184
return Self(repeating: lhs) / rhs
12561185
}
12571186

12581187
@_transparent
1259-
// SWIFT_ENABLE_TENSORFLOW
1260-
@differentiable(vjp: _vjpAdd(lhs:rhs:)
1261-
where Self : Differentiable,
1262-
Self.TangentVector : SIMD,
1263-
Scalar : Differentiable & BinaryFloatingPoint,
1264-
Scalar.TangentVector : BinaryFloatingPoint,
1265-
Self.TangentVector.Scalar == Scalar.TangentVector)
12661188
public static func +(lhs: Self, rhs: Scalar) -> Self {
12671189
return lhs + Self(repeating: rhs)
12681190
}
12691191

12701192
@_transparent
1271-
// SWIFT_ENABLE_TENSORFLOW
1272-
@differentiable(vjp: _vjpSubtract(lhs:rhs:)
1273-
where Self : Differentiable,
1274-
Self.TangentVector : SIMD,
1275-
Scalar : Differentiable & BinaryFloatingPoint,
1276-
Scalar.TangentVector : BinaryFloatingPoint,
1277-
Self.TangentVector.Scalar == Scalar.TangentVector)
12781193
public static func -(lhs: Self, rhs: Scalar) -> Self {
12791194
return lhs - Self(repeating: rhs)
12801195
}
12811196

12821197
@_transparent
1283-
// SWIFT_ENABLE_TENSORFLOW
1284-
@differentiable(vjp: _vjpMultiply(lhs:rhs:)
1285-
where Self : Differentiable,
1286-
Self.TangentVector : SIMD,
1287-
Scalar : BinaryFloatingPoint & Differentiable,
1288-
Self.TangentVector == Self,
1289-
Scalar.TangentVector == Scalar)
12901198
public static func *(lhs: Self, rhs: Scalar) -> Self {
12911199
return lhs * Self(repeating: rhs)
12921200
}
12931201

12941202
@_transparent
1295-
// SWIFT_ENABLE_TENSORFLOW
1296-
@differentiable(vjp: _vjpDivide(lhs:rhs:)
1297-
where Self : Differentiable,
1298-
Self.TangentVector : SIMD,
1299-
Scalar : BinaryFloatingPoint & Differentiable,
1300-
Self.TangentVector == Self,
1301-
Scalar.TangentVector == Scalar)
13021203
public static func /(lhs: Self, rhs: Scalar) -> Self {
13031204
return lhs / Self(repeating: rhs)
13041205
}
@@ -1520,24 +1421,27 @@ extension SIMD
15201421
Scalar : BinaryFloatingPoint,
15211422
TangentVector.Scalar : BinaryFloatingPoint {
15221423
@inlinable
1424+
@derivative(of: +)
15231425
static func _vjpAdd(lhs: Self, rhs: Self)
1524-
-> (Self, (TangentVector) -> (TangentVector, TangentVector)) {
1426+
-> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
15251427
return (lhs + rhs, { v in
15261428
return (v, v)
15271429
})
15281430
}
15291431

15301432
@inlinable
1433+
@derivative(of: -)
15311434
static func _vjpSubtract(lhs: Self, rhs: Self)
1532-
-> (Self, (TangentVector) -> (TangentVector, TangentVector)) {
1435+
-> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
15331436
return (lhs - rhs, { v in
15341437
return (v, -v)
15351438
})
15361439
}
15371440

15381441
@inlinable
1442+
@derivative(of: -)
15391443
static func _vjpNegate(rhs: Self)
1540-
-> (Self, (TangentVector) -> (TangentVector)) {
1444+
-> (value: Self, pullback: (TangentVector) -> (TangentVector)) {
15411445
return (-rhs, { v in
15421446
return -v
15431447
})
@@ -1550,16 +1454,18 @@ extension SIMD
15501454
Scalar : BinaryFloatingPoint,
15511455
Self.TangentVector == Self {
15521456
@inlinable
1457+
@derivative(of: *)
15531458
static func _vjpMultiply(lhs: Self, rhs: Self)
1554-
-> (Self, (TangentVector) -> (TangentVector, TangentVector)) {
1459+
-> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
15551460
return (lhs * rhs, { v in
15561461
return (v * rhs, v * lhs)
15571462
})
15581463
}
15591464

15601465
@inlinable
1466+
@derivative(of: /)
15611467
static func _vjpDivide(lhs: Self, rhs: Self)
1562-
-> (Self, (TangentVector) -> (TangentVector, TangentVector)) {
1468+
-> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
15631469
return (lhs / rhs, { v in
15641470
(v / rhs, -lhs / (rhs * rhs) * v)
15651471
})
@@ -1573,32 +1479,36 @@ extension SIMD
15731479
Scalar.TangentVector : BinaryFloatingPoint,
15741480
TangentVector.Scalar == Scalar.TangentVector {
15751481
@inlinable
1482+
@derivative(of: +)
15761483
static func _vjpAdd(lhs: Scalar, rhs: Self)
1577-
-> (Self, (TangentVector) -> (Scalar.TangentVector, TangentVector)) {
1484+
-> (value: Self, pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector)) {
15781485
return (lhs + rhs, { v in
15791486
return (v.sum(), v)
15801487
})
15811488
}
15821489

15831490
@inlinable
1491+
@derivative(of: -)
15841492
static func _vjpSubtract(lhs: Scalar, rhs: Self)
1585-
-> (Self, (TangentVector) -> (Scalar.TangentVector, TangentVector)) {
1493+
-> (value: Self, pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector)) {
15861494
return (lhs - rhs, { v in
15871495
return (v.sum(), -v)
15881496
})
15891497
}
15901498

15911499
@inlinable
1500+
@derivative(of: +)
15921501
static func _vjpAdd(lhs: Self, rhs: Scalar)
1593-
-> (Self, (TangentVector) -> (TangentVector, Scalar.TangentVector)) {
1502+
-> (value: Self, pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector)) {
15941503
return (lhs + rhs, { v in
15951504
return (v, v.sum())
15961505
})
15971506
}
15981507

15991508
@inlinable
1509+
@derivative(of: -)
16001510
static func _vjpSubtract(lhs: Self, rhs: Scalar)
1601-
-> (Self, (TangentVector) -> (TangentVector, Scalar.TangentVector)) {
1511+
-> (value: Self, pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector)) {
16021512
return (lhs - rhs, { v in
16031513
return (v, -v.sum())
16041514
})
@@ -1612,32 +1522,36 @@ extension SIMD
16121522
Self.TangentVector == Self,
16131523
Scalar.TangentVector == Scalar {
16141524
@inlinable
1525+
@derivative(of: *)
16151526
static func _vjpMultiply(lhs: Self, rhs: Scalar)
1616-
-> (Self, (TangentVector) -> (TangentVector, Scalar.TangentVector)) {
1527+
-> (value: Self, pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector)) {
16171528
return (lhs * rhs, { v in
16181529
return (v * rhs, (v * lhs).sum())
16191530
})
16201531
}
16211532

16221533
@inlinable
1534+
@derivative(of: /)
16231535
static func _vjpDivide(lhs: Self, rhs: Scalar)
1624-
-> (Self, (TangentVector) -> (TangentVector, Scalar.TangentVector)) {
1536+
-> (value: Self, pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector)) {
16251537
return (lhs / rhs, { v in
16261538
(v / rhs, (-lhs / (rhs * rhs) * v).sum())
16271539
})
16281540
}
16291541

16301542
@inlinable
1543+
@derivative(of: *)
16311544
static func _vjpMultiply(lhs: Scalar, rhs: Self)
1632-
-> (Self, (TangentVector) -> (Scalar.TangentVector, TangentVector)) {
1545+
-> (value: Self, pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector)) {
16331546
return (lhs * rhs, { v in
16341547
return ((v * rhs).sum(), v * lhs)
16351548
})
16361549
}
16371550

16381551
@inlinable
1552+
@derivative(of: /)
16391553
static func _vjpDivide(lhs: Scalar, rhs: Self)
1640-
-> (Self, (TangentVector) -> (Scalar.TangentVector, TangentVector)) {
1554+
-> (value: Self, pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector)) {
16411555
return (lhs / rhs, { v in
16421556
((v / rhs).sum(), -lhs / (rhs * rhs) * v)
16431557
})
@@ -1651,7 +1565,8 @@ extension SIMD
16511565
Scalar.TangentVector : BinaryFloatingPoint,
16521566
TangentVector == Self {
16531567
@inlinable
1654-
func _vjpSum() -> (Scalar, (Scalar.TangentVector) -> TangentVector) {
1568+
@derivative(of: sum)
1569+
func _vjpSum() -> (value: Scalar, pullback: (Scalar.TangentVector) -> TangentVector) {
16551570
return (sum(), { v in Self(repeating: Scalar(v)) })
16561571
}
16571572
}
@@ -1663,8 +1578,9 @@ extension SIMD
16631578
Self.TangentVector == Self,
16641579
Scalar.TangentVector == Scalar {
16651580
@usableFromInline
1581+
@derivative(of: init(repeating:))
16661582
static func _vjpInit(repeating value: Scalar)
1667-
-> (Self, (TangentVector) -> Scalar.TangentVector) {
1583+
-> (value: Self, pullback: (TangentVector) -> Scalar.TangentVector) {
16681584
return (Self(repeating: value), { v in v.sum() })
16691585
}
16701586
}

stdlib/public/core/SIMDVectorTypes.swift.gyb

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,6 @@ public struct SIMD${n}<Scalar>: SIMD where Scalar: SIMDScalar {
4545
}
4646

4747
/// Accesses the scalar at the specified position.
48-
// SWIFT_ENABLE_TENSORFLOW
49-
@differentiable(vjp: _vjpSubscript
50-
where Scalar : EuclideanDifferentiable & BinaryFloatingPoint,
51-
Scalar.TangentVector : BinaryFloatingPoint)
5248
public subscript(index: Int) -> Scalar {
5349
@_transparent get {
5450
_precondition(indices.contains(index))
@@ -207,8 +203,9 @@ extension SIMD${n}
207203
where Scalar : EuclideanDifferentiable & BinaryFloatingPoint,
208204
Scalar.TangentVector : BinaryFloatingPoint {
209205
@usableFromInline
206+
@derivative(of: subscript(index:))
210207
internal func _vjpSubscript(index: Int)
211-
-> (Scalar, (Scalar.TangentVector) -> TangentVector) {
208+
-> (value: Scalar, pullback: (Scalar.TangentVector) -> TangentVector) {
212209
return (self[index], { v in
213210
var zeros = Self.zero
214211
zeros[index] = Scalar(v)

0 commit comments

Comments
 (0)