Commit 3417a08

[TF] TF-496: Conditionally conform SIMD types to Differentiable (#24786)

- Add protocol and struct gyb file conformances to make SIMD types Differentiable.
- Add tests.

TODO:
- Figure out how to add back the default protocol implementations of `-=`/`+=`, which currently clash with the `AdditiveArithmetic` default implementations.
- Had to remove `@_alwaysEmitIntoClient` from `sum()` due to a bug (filed as TF-545).
- Remove some of the redundant custom VJPs, which are blocked by a bug (filed as TF-547).
1 parent c1211a3 commit 3417a08
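
For orientation, a minimal usage sketch of what the new conditional conformances enable. This is not part of the commit; it assumes a Swift for TensorFlow toolchain where `gradient(at:in:)` is available (newer toolchains may additionally need `import _Differentiation`).

// Hypothetical usage sketch, not part of this commit.
let x = SIMD4<Float>(1, 2, 3, 4)

// With SIMD4<Float> conditionally conforming to Differentiable, the custom
// VJPs registered in this diff drive differentiation of element-wise ops:
// d/dx (x * x).sum() == 2 * x.
let grad = gradient(at: x) { x in (x * x).sum() }
// grad == SIMD4<Float>(2, 4, 6, 8)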

File tree

3 files changed: +567 / -13 lines


stdlib/public/core/SIMDVector.swift

Lines changed: 261 additions & 13 deletions
@@ -90,6 +90,13 @@ extension SIMD {
 
   /// A vector with the specified value in all lanes.
   @_transparent
+  // SWIFT_ENABLE_TENSORFLOW
+  @differentiable(vjp: _vjpInit(repeating:)
+                  where Self : Differentiable,
+                        Self.TangentVector : SIMD,
+                        Scalar : BinaryFloatingPoint & Differentiable,
+                        Self.TangentVector == Self,
+                        Scalar.TangentVector == Scalar)
   public init(repeating value: Scalar) {
     self.init()
     for i in indices { self[i] = value }
@@ -779,29 +786,53 @@ extension SIMD where Scalar: FixedWidthInteger {
 
 // Implementations of floating-point operations. These should eventually all
 // be replaced with @_semantics to lower directly to vector IR nodes.
-extension SIMD where Scalar: FloatingPoint {
-  @_transparent
+extension SIMD where Scalar : FloatingPoint {
+  @_transparent
+  // SWIFT_ENABLE_TENSORFLOW
+  @differentiable(vjp: _vjpAdd(lhs:rhs:)
+                  where Self : Differentiable,
+                        Self.TangentVector : SIMD,
+                        Scalar : BinaryFloatingPoint,
+                        Self.TangentVector.Scalar : BinaryFloatingPoint)
   public static func +(lhs: Self, rhs: Self) -> Self {
     var result = Self()
     for i in result.indices { result[i] = lhs[i] + rhs[i] }
     return result
   }
 
   @_transparent
+  // SWIFT_ENABLE_TENSORFLOW
+  @differentiable(vjp: _vjpSubtract(lhs:rhs:)
+                  where Self : Differentiable,
+                        Self.TangentVector : SIMD,
+                        Scalar : BinaryFloatingPoint,
+                        Self.TangentVector.Scalar : BinaryFloatingPoint)
   public static func -(lhs: Self, rhs: Self) -> Self {
     var result = Self()
     for i in result.indices { result[i] = lhs[i] - rhs[i] }
     return result
   }
 
   @_transparent
+  // SWIFT_ENABLE_TENSORFLOW
+  @differentiable(vjp: _vjpMultiply(lhs:rhs:)
+                  where Self : Differentiable,
+                        Self.TangentVector : SIMD,
+                        Scalar : BinaryFloatingPoint,
+                        Self.TangentVector == Self)
   public static func *(lhs: Self, rhs: Self) -> Self {
     var result = Self()
     for i in result.indices { result[i] = lhs[i] * rhs[i] }
     return result
   }
 
   @_transparent
+  // SWIFT_ENABLE_TENSORFLOW
+  @differentiable(vjp: _vjpDivide(lhs:rhs:)
+                  where Self : Differentiable,
+                        Self.TangentVector : SIMD,
+                        Scalar : BinaryFloatingPoint,
+                        Self.TangentVector == Self)
   public static func /(lhs: Self, rhs: Self) -> Self {
     var result = Self()
     for i in result.indices { result[i] = lhs[i] / rhs[i] }
@@ -842,7 +873,16 @@ extension SIMD where Scalar: FloatingPoint {
   }
 
   /// Returns the sum of the scalars in the vector.
-  @_alwaysEmitIntoClient
+  // SWIFT_ENABLE_TENSORFLOW
+  // FIXME: TF-545 we want the sum() func to be marked as
+  // `@_alwaysEmitIntoClient` like before when we define the VJP
+  @inlinable
+  @differentiable(vjp: _vjpSum
+                  where Self : Differentiable,
+                        Self.TangentVector : SIMD,
+                        Scalar : BinaryFloatingPoint & Differentiable,
+                        Scalar.TangentVector : BinaryFloatingPoint,
+                        Self.TangentVector == Self)
   public func sum() -> Scalar {
     // Implementation note: this eventually be defined to lower to either
     // llvm.experimental.vector.reduce.fadd or an explicit tree-sum. Open-
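
The pullback registered here via `_vjpSum` (defined near the end of this diff) broadcasts the incoming scalar cotangent across all lanes. A hedged sketch of the expected behavior, again assuming `gradient(at:in:)` from a Swift for TensorFlow toolchain:

// Hypothetical check, not part of this commit.
// Each lane contributes linearly to sum(), so every partial derivative is 1;
// _vjpSum returns { v in Self(repeating: Scalar(v)) }.
let ones = gradient(at: SIMD4<Float>(1, 2, 3, 4)) { $0.sum() }
// ones == SIMD4<Float>(repeating: 1)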
@@ -1157,60 +1197,112 @@ extension SIMD where Scalar: FixedWidthInteger {
 extension SIMD where Scalar: FloatingPoint {
 
   @_transparent
+  // SWIFT_ENABLE_TENSORFLOW
+  @differentiable(vjp: _vjpNegate(rhs:)
+                  where Self : Differentiable,
+                        Self.TangentVector : SIMD,
+                        Scalar : BinaryFloatingPoint,
+                        Self.TangentVector.Scalar : BinaryFloatingPoint)
   public static prefix func -(rhs: Self) -> Self {
     return 0 - rhs
   }
 
   @_transparent
+  // SWIFT_ENABLE_TENSORFLOW
+  @differentiable(vjp: _vjpAdd(lhs:rhs:)
+                  where Self : Differentiable,
+                        Self.TangentVector : SIMD,
+                        Scalar : Differentiable & BinaryFloatingPoint,
+                        Scalar.TangentVector : BinaryFloatingPoint,
+                        Self.TangentVector.Scalar == Scalar.TangentVector)
   public static func +(lhs: Scalar, rhs: Self) -> Self {
     return Self(repeating: lhs) + rhs
   }
 
   @_transparent
+  // SWIFT_ENABLE_TENSORFLOW
+  @differentiable(vjp: _vjpSubtract(lhs:rhs:)
+                  where Self : Differentiable,
+                        Self.TangentVector : SIMD,
+                        Scalar : Differentiable & BinaryFloatingPoint,
+                        Scalar.TangentVector : BinaryFloatingPoint,
+                        Self.TangentVector.Scalar == Scalar.TangentVector)
   public static func -(lhs: Scalar, rhs: Self) -> Self {
     return Self(repeating: lhs) - rhs
   }
 
   @_transparent
+  // SWIFT_ENABLE_TENSORFLOW
+  @differentiable(vjp: _vjpMultiply(lhs:rhs:)
+                  where Self : Differentiable,
+                        Self.TangentVector : SIMD,
+                        Scalar : BinaryFloatingPoint & Differentiable,
+                        Self.TangentVector == Self,
+                        Scalar.TangentVector == Scalar)
   public static func *(lhs: Scalar, rhs: Self) -> Self {
     return Self(repeating: lhs) * rhs
   }
 
   @_transparent
+  // SWIFT_ENABLE_TENSORFLOW
+  @differentiable(vjp: _vjpDivide(lhs:rhs:)
+                  where Self : Differentiable,
+                        Self.TangentVector : SIMD,
+                        Scalar : BinaryFloatingPoint & Differentiable,
+                        Self.TangentVector == Self,
+                        Scalar.TangentVector == Scalar)
   public static func /(lhs: Scalar, rhs: Self) -> Self {
     return Self(repeating: lhs) / rhs
   }
 
   @_transparent
+  // SWIFT_ENABLE_TENSORFLOW
+  @differentiable(vjp: _vjpAdd(lhs:rhs:)
+                  where Self : Differentiable,
+                        Self.TangentVector : SIMD,
+                        Scalar : Differentiable & BinaryFloatingPoint,
+                        Scalar.TangentVector : BinaryFloatingPoint,
+                        Self.TangentVector.Scalar == Scalar.TangentVector)
   public static func +(lhs: Self, rhs: Scalar) -> Self {
     return lhs + Self(repeating: rhs)
   }
 
   @_transparent
+  // SWIFT_ENABLE_TENSORFLOW
+  @differentiable(vjp: _vjpSubtract(lhs:rhs:)
+                  where Self : Differentiable,
+                        Self.TangentVector : SIMD,
+                        Scalar : Differentiable & BinaryFloatingPoint,
+                        Scalar.TangentVector : BinaryFloatingPoint,
+                        Self.TangentVector.Scalar == Scalar.TangentVector)
   public static func -(lhs: Self, rhs: Scalar) -> Self {
     return lhs - Self(repeating: rhs)
   }
 
   @_transparent
+  // SWIFT_ENABLE_TENSORFLOW
+  @differentiable(vjp: _vjpMultiply(lhs:rhs:)
+                  where Self : Differentiable,
+                        Self.TangentVector : SIMD,
+                        Scalar : BinaryFloatingPoint & Differentiable,
+                        Self.TangentVector == Self,
+                        Scalar.TangentVector == Scalar)
   public static func *(lhs: Self, rhs: Scalar) -> Self {
     return lhs * Self(repeating: rhs)
   }
 
   @_transparent
+  // SWIFT_ENABLE_TENSORFLOW
+  @differentiable(vjp: _vjpDivide(lhs:rhs:)
+                  where Self : Differentiable,
+                        Self.TangentVector : SIMD,
+                        Scalar : BinaryFloatingPoint & Differentiable,
+                        Self.TangentVector == Self,
+                        Scalar.TangentVector == Scalar)
   public static func /(lhs: Self, rhs: Scalar) -> Self {
     return lhs / Self(repeating: rhs)
   }
 
-  @_transparent
-  public static func +=(lhs: inout Self, rhs: Self) {
-    lhs = lhs + rhs
-  }
-
-  @_transparent
-  public static func -=(lhs: inout Self, rhs: Self) {
-    lhs = lhs - rhs
-  }
-
   @_transparent
   public static func *=(lhs: inout Self, rhs: Self) {
     lhs = lhs * rhs
@@ -1407,3 +1499,159 @@ where T: SIMD, T.Scalar: FloatingPoint {
   }
   return result
 }
+
+// SWIFT_ENABLE_TENSORFLOW
+extension SIMD
+  where Self : Differentiable,
+        TangentVector : SIMD,
+        Scalar : BinaryFloatingPoint,
+        TangentVector.Scalar : BinaryFloatingPoint {
+  @inlinable
+  static func _vjpAdd(lhs: Self, rhs: Self)
+    -> (Self, (TangentVector) -> (TangentVector, TangentVector)) {
+    return (lhs + rhs, { v in
+      return (v, v)
+    })
+  }
+
+  @inlinable
+  static func _vjpSubtract(lhs: Self, rhs: Self)
+    -> (Self, (TangentVector) -> (TangentVector, TangentVector)) {
+    return (lhs - rhs, { v in
+      return (v, -v)
+    })
+  }
+
+  @inlinable
+  static func _vjpNegate(rhs: Self)
+    -> (Self, (TangentVector) -> (TangentVector)) {
+    return (-rhs, { v in
+      return -v
+    })
+  }
+}
+
+extension SIMD
+  where Self : Differentiable,
+        TangentVector : SIMD,
+        Scalar : BinaryFloatingPoint,
+        Self.TangentVector == Self {
+  @inlinable
+  static func _vjpMultiply(lhs: Self, rhs: Self)
+    -> (Self, (TangentVector) -> (TangentVector, TangentVector)) {
+    return (lhs * rhs, { v in
+      return (v * rhs, v * lhs)
+    })
+  }
+
+  @inlinable
+  static func _vjpDivide(lhs: Self, rhs: Self)
+    -> (Self, (TangentVector) -> (TangentVector, TangentVector)) {
+    return (lhs / rhs, { v in
+      (v / rhs, -lhs / (rhs * rhs) * v)
+    })
+  }
+}
+
+extension SIMD
+  where Self : Differentiable,
+        TangentVector : SIMD,
+        Scalar : BinaryFloatingPoint & Differentiable,
+        Scalar.TangentVector : BinaryFloatingPoint,
+        TangentVector.Scalar == Scalar.TangentVector {
+  @inlinable
+  static func _vjpAdd(lhs: Scalar, rhs: Self)
+    -> (Self, (TangentVector) -> (Scalar.TangentVector, TangentVector)) {
+    return (lhs + rhs, { v in
+      return (v.sum(), v)
+    })
+  }
+
+  @inlinable
+  static func _vjpSubtract(lhs: Scalar, rhs: Self)
+    -> (Self, (TangentVector) -> (Scalar.TangentVector, TangentVector)) {
+    return (lhs - rhs, { v in
+      return (v.sum(), -v)
+    })
+  }
+
+  @inlinable
+  static func _vjpAdd(lhs: Self, rhs: Scalar)
+    -> (Self, (TangentVector) -> (TangentVector, Scalar.TangentVector)) {
+    return (lhs + rhs, { v in
+      return (v, v.sum())
+    })
+  }
+
+  @inlinable
+  static func _vjpSubtract(lhs: Self, rhs: Scalar)
+    -> (Self, (TangentVector) -> (TangentVector, Scalar.TangentVector)) {
+    return (lhs - rhs, { v in
+      return (v, -v.sum())
+    })
+  }
+}
+
+extension SIMD
+  where Self : Differentiable,
+        TangentVector : SIMD,
+        Scalar : BinaryFloatingPoint & Differentiable,
+        Self.TangentVector == Self,
+        Scalar.TangentVector == Scalar {
+  @inlinable
+  static func _vjpMultiply(lhs: Self, rhs: Scalar)
+    -> (Self, (TangentVector) -> (TangentVector, Scalar.TangentVector)) {
+    return (lhs * rhs, { v in
+      return (v * rhs, (v * lhs).sum())
+    })
+  }
+
+  @inlinable
+  static func _vjpDivide(lhs: Self, rhs: Scalar)
+    -> (Self, (TangentVector) -> (TangentVector, Scalar.TangentVector)) {
+    return (lhs / rhs, { v in
+      (v / rhs, (-lhs / (rhs * rhs) * v).sum())
+    })
+  }
+
+  @inlinable
+  static func _vjpMultiply(lhs: Scalar, rhs: Self)
+    -> (Self, (TangentVector) -> (Scalar.TangentVector, TangentVector)) {
+    return (lhs * rhs, { v in
+      return ((v * rhs).sum(), v * lhs)
+    })
+  }
+
+  @inlinable
+  static func _vjpDivide(lhs: Scalar, rhs: Self)
+    -> (Self, (TangentVector) -> (Scalar.TangentVector, TangentVector)) {
+    return (lhs / rhs, { v in
+      ((v / rhs).sum(), -lhs / (rhs * rhs) * v)
+    })
+  }
+}
+
+extension SIMD
+  where Self : Differentiable,
+        TangentVector : SIMD,
+        Scalar : BinaryFloatingPoint & Differentiable,
+        Scalar.TangentVector : BinaryFloatingPoint,
+        TangentVector == Self {
+  @inlinable
+  func _vjpSum() -> (Scalar, (Scalar.TangentVector) -> TangentVector) {
+    return (sum(), { v in Self(repeating: Scalar(v)) })
+  }
+}
+
+extension SIMD
+  where Self : Differentiable,
+        Self.TangentVector : SIMD,
+        Scalar : BinaryFloatingPoint & Differentiable,
+        Self.TangentVector == Self,
+        Scalar.TangentVector == Scalar {
+  @usableFromInline
+  static func _vjpInit(repeating value: Scalar)
+    -> (Self, (TangentVector) -> Scalar.TangentVector) {
+    return (Self(repeating: value), { v in v.sum() })
+  }
+}
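
A brief sanity check of the scalar/vector product rule implemented by `_vjpMultiply(lhs:rhs:)` above: broadcasting a scalar across lanes means its cotangent is the sum of the lane-wise cotangents, which is why these pullbacks call `.sum()`. A hedged sketch, assuming `pullback(at:_:in:)` from a Swift for TensorFlow toolchain:

// Hypothetical check, not part of this commit.
let s: Float = 3
let x = SIMD4<Float>(1, 2, 3, 4)

// Pullback of (s, x) -> s * x at cotangent v = [1, 1, 1, 1]:
//   ds = (v * x).sum()  // scalar was broadcast, so lane cotangents are summed
//   dx = v * s
let pb = pullback(at: s, x) { s, x in s * x }
let (ds, dx) = pb(SIMD4<Float>(repeating: 1))
// ds == 10, dx == SIMD4<Float>(repeating: 3)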
