Skip to content

Commit 0268fc6

Browse files
committed
Make all necessary operators differentiable.
1 parent 83e7f75 commit 0268fc6

File tree

2 files changed

+186
-62
lines changed

2 files changed

+186
-62
lines changed

stdlib/public/core/SIMDVector.swift

Lines changed: 167 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -213,9 +213,6 @@ extension SIMD {
213213
/// in this vector. Because of this, the index is always in-range and no trap
214214
/// can occur.
215215
@_alwaysEmitIntoClient
216-
// @differentiable(vjp: _vjpSubscript(index:)
217-
// where Self : SIMDDifferentiable,
218-
// Scalar : SIMDDifferentiable)
219216
public subscript<Index>(index: SIMD2<Index>) -> SIMD2<Scalar>
220217
where Index: FixedWidthInteger {
221218
var result = SIMD2<Scalar>()
@@ -785,23 +782,34 @@ extension SIMD where Scalar: FixedWidthInteger {
785782
extension SIMD where Scalar: FloatingPoint {
786783
@_transparent
787784
@differentiable(vjp: _vjpAdd(lhs:rhs:)
788-
where Self : SIMDDifferentiable)
785+
where Self : Differentiable,
786+
Self.CotangentVector : SIMD,
787+
Scalar : BinaryFloatingPoint,
788+
Self.CotangentVector.Scalar: BinaryFloatingPoint)
789789
public static func +(lhs: Self, rhs: Self) -> Self {
790790
var result = Self()
791791
for i in result.indices { result[i] = lhs[i] + rhs[i] }
792792
return result
793793
}
794794

795795
@_transparent
796-
@differentiable(vjp: _vjpSubtract(lhs:rhs:)
797-
where Self : SIMDDifferentiable)
796+
@differentiable(vjp: _vjpAdd(lhs:rhs:)
797+
where Self: Differentiable,
798+
Self.CotangentVector: SIMD,
799+
Scalar : BinaryFloatingPoint,
800+
Self.CotangentVector.Scalar: BinaryFloatingPoint)
798801
public static func -(lhs: Self, rhs: Self) -> Self {
799802
var result = Self()
800803
for i in result.indices { result[i] = lhs[i] - rhs[i] }
801804
return result
802805
}
803806

804807
@_transparent
808+
@differentiable(vjp: _vjpMultiply(lhs:rhs:)
809+
where Self: Differentiable,
810+
Self.CotangentVector: SIMD,
811+
Scalar : BinaryFloatingPoint,
812+
Self.CotangentVector == Self)
805813
public static func *(lhs: Self, rhs: Self) -> Self {
806814
var result = Self()
807815
for i in result.indices { result[i] = lhs[i] * rhs[i] }
@@ -851,8 +859,11 @@ extension SIMD where Scalar: FloatingPoint {
851859
/// Returns the sum of the scalars in the vector.
852860
@_alwaysEmitIntoClient
853861
@differentiable(vjp: _vjpSum
854-
where Self : SIMDDifferentiable,
855-
Scalar : SIMDDifferentiable)
862+
where Self : Differentiable,
863+
Self.CotangentVector : SIMD,
864+
Scalar : BinaryFloatingPoint & Differentiable,
865+
Scalar.CotangentVector : BinaryFloatingPoint,
866+
Self.CotangentVector == Self)
856867
public func sum() -> Scalar {
857868
// Implementation note: this eventually be defined to lower to either
858869
// llvm.experimental.vector.reduce.fadd or an explicit tree-sum. Open-
@@ -1166,59 +1177,95 @@ extension SIMD where Scalar: FixedWidthInteger {
11661177

11671178
extension SIMD where Scalar: FloatingPoint {
11681179

1169-
@_transparent // ????
1180+
@_transparent
11701181
public static prefix func -(rhs: Self) -> Self {
11711182
return 0 - rhs
11721183
}
11731184

11741185
@_transparent
11751186
@differentiable(vjp: _vjpAdd(lhs:rhs:)
1176-
where Scalar : SIMDDifferentiable,
1177-
Self : SIMDDifferentiable)
1187+
where Self: Differentiable,
1188+
Self.CotangentVector: SIMD,
1189+
Scalar : Differentiable & BinaryFloatingPoint,
1190+
Scalar.CotangentVector: BinaryFloatingPoint,
1191+
Self.CotangentVector.Scalar == Scalar.CotangentVector)
11781192
public static func +(lhs: Scalar, rhs: Self) -> Self {
11791193
return Self(repeating: lhs) + rhs
11801194
}
11811195

11821196
@_transparent
11831197
@differentiable(vjp: _vjpAdd(lhs:rhs:)
1184-
where Scalar : SIMDDifferentiable,
1185-
Self : SIMDDifferentiable)
1198+
where Self: Differentiable,
1199+
Self.CotangentVector: SIMD,
1200+
Scalar : Differentiable & BinaryFloatingPoint,
1201+
Scalar.CotangentVector: BinaryFloatingPoint,
1202+
Self.CotangentVector.Scalar == Scalar.CotangentVector)
11861203
public static func -(lhs: Scalar, rhs: Self) -> Self {
11871204
return Self(repeating: lhs) - rhs
11881205
}
11891206

11901207
@_transparent
1208+
@differentiable(vjp: _vjpMultiply(lhs:rhs:)
1209+
where Self : Differentiable,
1210+
Self.CotangentVector : SIMD,
1211+
Scalar : BinaryFloatingPoint & Differentiable,
1212+
Self.CotangentVector == Self,
1213+
Scalar.CotangentVector == Scalar)
11911214
public static func *(lhs: Scalar, rhs: Self) -> Self {
11921215
return Self(repeating: lhs) * rhs
11931216
}
11941217

11951218
@_transparent
1219+
@differentiable(vjp: _vjpDivide(lhs:rhs:)
1220+
where Self : Differentiable,
1221+
Self.CotangentVector : SIMD,
1222+
Scalar : BinaryFloatingPoint & Differentiable,
1223+
Self.CotangentVector == Self,
1224+
Scalar.CotangentVector == Scalar)
11961225
public static func /(lhs: Scalar, rhs: Self) -> Self {
11971226
return Self(repeating: lhs) / rhs
11981227
}
11991228

12001229
@_transparent
12011230
@differentiable(vjp: _vjpAdd(lhs:rhs:)
1202-
where Scalar : SIMDDifferentiable,
1203-
Self : SIMDDifferentiable)
1231+
where Self: Differentiable,
1232+
Self.CotangentVector: SIMD,
1233+
Scalar : Differentiable & BinaryFloatingPoint,
1234+
Scalar.CotangentVector: BinaryFloatingPoint,
1235+
Self.CotangentVector.Scalar == Scalar.CotangentVector)
12041236
public static func +(lhs: Self, rhs: Scalar) -> Self {
12051237
return lhs + Self(repeating: rhs)
12061238
}
12071239

12081240
@_transparent
1209-
@differentiable(vjp: _vjpSubtract(lhs:rhs:)
1210-
where Scalar : SIMDDifferentiable,
1211-
Self : SIMDDifferentiable)
1241+
@differentiable(vjp: _vjpAdd(lhs:rhs:)
1242+
where Self: Differentiable,
1243+
Self.CotangentVector: SIMD,
1244+
Scalar : Differentiable & BinaryFloatingPoint,
1245+
Scalar.CotangentVector: BinaryFloatingPoint,
1246+
Self.CotangentVector.Scalar == Scalar.CotangentVector)
12121247
public static func -(lhs: Self, rhs: Scalar) -> Self {
12131248
return lhs - Self(repeating: rhs)
12141249
}
12151250

12161251
@_transparent
1252+
@differentiable(vjp: _vjpMultiply(lhs:rhs:)
1253+
where Self : Differentiable,
1254+
Self.CotangentVector : SIMD,
1255+
Scalar : BinaryFloatingPoint & Differentiable,
1256+
Self.CotangentVector == Self,
1257+
Scalar.CotangentVector == Scalar)
12171258
public static func *(lhs: Self, rhs: Scalar) -> Self {
12181259
return lhs * Self(repeating: rhs)
12191260
}
12201261

12211262
@_transparent
1263+
@differentiable(vjp: _vjpDivide(lhs:rhs:)
1264+
where Self : Differentiable,
1265+
Self.CotangentVector : SIMD,
1266+
Scalar : BinaryFloatingPoint & Differentiable,
1267+
Self.CotangentVector == Self,
1268+
Scalar.CotangentVector == Scalar)
12221269
public static func /(lhs: Self, rhs: Scalar) -> Self {
12231270
return lhs / Self(repeating: rhs)
12241271
}
@@ -1415,18 +1462,18 @@ where T: SIMD, T.Scalar: FloatingPoint {
14151462
return result
14161463
}
14171464

1418-
public protocol SIMDDifferentiable : Differentiable
1419-
where Self == Self.TangentVector,
1420-
Self == Self.CotangentVector,
1421-
Self == Self.AllDifferentiableVariables {}
1422-
14231465
extension SIMD
1424-
where Self : SIMDDifferentiable,
1425-
Scalar : FloatingPoint {
1466+
where Self: Differentiable,
1467+
CotangentVector: SIMD,
1468+
Scalar : BinaryFloatingPoint,
1469+
/* Required in order to use unary negation operator due to following error:
1470+
>Self.CotangentVector.Scalar' does not conform to protocol 'FloatingPoint'
1471+
*/
1472+
CotangentVector.Scalar: BinaryFloatingPoint {
14261473
@inlinable
14271474
static func _vjpAdd(
14281475
lhs: Self, rhs: Self
1429-
) -> (Self, (Self) -> (Self, Self)) {
1476+
) -> (Self, (CotangentVector) -> (CotangentVector, CotangentVector)) {
14301477
return (lhs + rhs, { v in
14311478
return (v, v)
14321479
})
@@ -1435,73 +1482,133 @@ extension SIMD
14351482
@inlinable
14361483
static func _vjpSubtract(
14371484
lhs: Self, rhs: Self
1438-
) -> (Self, (Self) -> (Self, Self)) {
1439-
return (lhs - rhs, { v in
1485+
) -> (Self, (CotangentVector) -> (CotangentVector, CotangentVector)) {
1486+
return (lhs - rhs, { (v: CotangentVector) in
14401487
return (v, -v)
14411488
})
14421489
}
14431490
}
14441491

14451492
extension SIMD
1446-
where Self : SIMDDifferentiable,
1447-
Scalar : FloatingPoint & SIMDDifferentiable {
1493+
where Self: Differentiable,
1494+
CotangentVector: SIMD,
1495+
// error: generic parameter 'Self' could not be inferred: return (lhs * rhs,...
1496+
Scalar : BinaryFloatingPoint,
1497+
// binary operator '*' cannot be applied to operands of type 'Self.CotangentVector' and 'Self'
1498+
Self.CotangentVector == Self {
1499+
@inlinable
1500+
static func _vjpMultiply(
1501+
lhs: Self, rhs: Self
1502+
) -> (Self, (CotangentVector) -> (CotangentVector, CotangentVector)) {
1503+
return (lhs * rhs, { (v: CotangentVector) in
1504+
return (v * rhs, v * lhs)
1505+
})
1506+
}
1507+
1508+
@inlinable
1509+
static func _vjpDivide(
1510+
lhs: Self, rhs: Self
1511+
) -> (Self, (CotangentVector) -> (CotangentVector, CotangentVector)) {
1512+
return (lhs / rhs, { (v: CotangentVector) in
1513+
(v / rhs, -lhs / (rhs * rhs) * v)
1514+
})
1515+
}
1516+
}
1517+
1518+
extension SIMD
1519+
where Self : Differentiable,
1520+
CotangentVector : SIMD,
1521+
Scalar : BinaryFloatingPoint & Differentiable,
1522+
Scalar.CotangentVector: BinaryFloatingPoint,
1523+
CotangentVector.Scalar == Scalar.CotangentVector {
14481524
@inlinable
14491525
static func _vjpAdd(
14501526
lhs: Scalar, rhs: Self
1451-
) -> (Self, (Self) -> (Scalar, Self)) {
1452-
return (lhs + rhs, { v in
1527+
) -> (Self, (CotangentVector) -> (Scalar.CotangentVector, CotangentVector)) {
1528+
return (lhs + rhs, { (v: CotangentVector) in
14531529
return (v.sum(), v)
14541530
})
14551531
}
1456-
1532+
14571533
@inlinable
14581534
static func _vjpSubtract(
14591535
lhs: Scalar, rhs: Self
1460-
) -> (Self, (Self) -> (Scalar, Self)) {
1461-
return (lhs + rhs, { v in
1536+
) -> (Self, (CotangentVector) -> (Scalar.CotangentVector, CotangentVector)) {
1537+
return (lhs + rhs, { (v: CotangentVector) in
14621538
return (v.sum(), -v)
14631539
})
14641540
}
1465-
1541+
14661542
@inlinable
14671543
static func _vjpAdd(
14681544
lhs: Self, rhs: Scalar
1469-
) -> (Self, (Self) -> (Self, Scalar)) {
1470-
return (lhs + rhs, { v in
1545+
) -> (Self, (CotangentVector) -> (CotangentVector, Scalar.CotangentVector)) {
1546+
return (lhs + rhs, { (v: CotangentVector) in
14711547
return (v, v.sum())
14721548
})
14731549
}
1474-
1550+
14751551
@inlinable
14761552
static func _vjpSubtract(
14771553
lhs: Self, rhs: Scalar
1478-
) -> (Self, (Self) -> (Self, Scalar)) {
1479-
return (lhs + rhs, { v in
1554+
) -> (Self, (CotangentVector) -> (CotangentVector, Scalar.CotangentVector)) {
1555+
return (lhs + rhs, { (v: CotangentVector) in
14801556
return (v, -v.sum())
14811557
})
14821558
}
14831559
}
14841560

1485-
14861561
extension SIMD
1487-
where Self: SIMDDifferentiable,
1488-
Scalar: SIMDDifferentiable & FloatingPoint {
1489-
public static func _vjpSum() -> (Scalar, (Scalar) -> Self) {
1490-
return (self.sum(), { v in SIMD(repeating: v) })
1562+
where Self : Differentiable,
1563+
CotangentVector : SIMD,
1564+
Scalar : BinaryFloatingPoint & Differentiable,
1565+
Self.CotangentVector == Self,
1566+
Scalar.CotangentVector == Scalar {
1567+
@inlinable
1568+
static func _vjpMultiply(
1569+
lhs: Self, rhs: Scalar
1570+
) -> (Self, (CotangentVector) -> (CotangentVector, Scalar.CotangentVector)) {
1571+
return (lhs * rhs, { (v: CotangentVector) in
1572+
return (v * rhs, (v * lhs).sum())
1573+
})
1574+
}
1575+
1576+
@inlinable
1577+
static func _vjpDivide(
1578+
lhs: Self, rhs: Scalar
1579+
) -> (Self, (CotangentVector) -> (CotangentVector, Scalar.CotangentVector)) {
1580+
return (lhs / rhs, { (v: CotangentVector) in
1581+
(-lhs / (rhs * rhs) * v, (v / rhs).sum())
1582+
})
1583+
}
1584+
1585+
@inlinable
1586+
static func _vjpMultiply(
1587+
lhs: Scalar, rhs: Self
1588+
) -> (Self, (CotangentVector) -> (Scalar.CotangentVector, CotangentVector)) {
1589+
return (lhs * rhs, { (v: CotangentVector) in
1590+
return ((v * lhs).sum(), v * rhs)
1591+
})
1592+
}
1593+
1594+
@inlinable
1595+
static func _vjpDivide(
1596+
lhs: Scalar, rhs: Self
1597+
) -> (Self, (CotangentVector) -> (Scalar.CotangentVector, CotangentVector)) {
1598+
return (lhs / rhs, { (v: CotangentVector) in
1599+
((v / rhs).sum(), -lhs / (rhs * rhs) * v)
1600+
})
14911601
}
14921602
}
14931603

1494-
//extension SIMD
1495-
// where Self : SIMDDifferentiable,
1496-
// Scalar : SIMDDifferentiable & SIMDScalar {
1497-
// public func _vjpSubscript<Index>(index: SIMD2<Index>) ->
1498-
// (SIMD2<Scalar>, (SIMD2<Self.Scalar>) -> (CotangentVector, SIMD2<Index>)) where Index: FixedWidthInteger & SIMDScalar
1499-
// {
1500-
// func pullback(_ v: SIMD2<Self.Scalar>) -> (CotangentVector, SIMD2<Index>) {
1501-
// var gradientOut = SIMD(repeating: 0)
1502-
// gradientOut[index] = gradientIn
1503-
// return (CotangentVector(gradientOut), index)
1504-
// }
1505-
// return (self[index], pullback)
1506-
// }
1507-
//}
1604+
extension SIMD
1605+
where Self : Differentiable,
1606+
CotangentVector : SIMD,
1607+
Scalar : BinaryFloatingPoint & Differentiable,
1608+
Scalar.CotangentVector : BinaryFloatingPoint,
1609+
CotangentVector == Self {
1610+
@usableFromInline
1611+
func _vjpSum() -> (Scalar, (Scalar.CotangentVector) -> CotangentVector) {
1612+
return (sum(), { (v: Scalar.CotangentVector) in Self.init(repeating: Scalar(v)) })
1613+
}
1614+
}

0 commit comments

Comments
 (0)