@@ -90,6 +90,13 @@ extension SIMD {

  /// A vector with the specified value in all lanes.
  @_transparent
+  // SWIFT_ENABLE_TENSORFLOW
+  @differentiable(vjp: _vjpInit(repeating:)
+    where Self: Differentiable,
+          Self.TangentVector: SIMD,
+          Scalar: BinaryFloatingPoint & Differentiable,
+          Self.TangentVector == Self,
+          Scalar.TangentVector == Scalar)
  public init(repeating value: Scalar) {
    self.init()
    for i in indices { self[i] = value }
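
For illustration, not part of the diff: the `vjp:` argument points the differentiation pass at a custom vector-Jacobian product, a function returning the original result together with a pullback closure. For `init(repeating:)` the pullback sums the incoming vector cotangent back into a scalar, the adjoint of broadcasting. A minimal sketch, assuming a Swift for TensorFlow toolchain in which `SIMD4<Float>` conforms to `Differentiable` with `TangentVector == Self` and the free function `gradient(at:in:)` is available:

    // f(s) = (s * [1, 2, 3, 4]).sum() = 10 * s, so df/ds == 10.
    let ds = gradient(at: Float(3)) { s in
      (SIMD4<Float>(repeating: s) * SIMD4<Float>(1, 2, 3, 4)).sum()
    }
    // ds == 10.0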
@@ -779,29 +786,53 @@ extension SIMD where Scalar: FixedWidthInteger {

// Implementations of floating-point operations. These should eventually all
// be replaced with @_semantics to lower directly to vector IR nodes.
-extension SIMD where Scalar: FloatingPoint {
-  @_transparent
+extension SIMD where Scalar: FloatingPoint {
+  @_transparent
+  // SWIFT_ENABLE_TENSORFLOW
+  @differentiable(vjp: _vjpAdd(lhs:rhs:)
+    where Self: Differentiable,
+          Self.TangentVector: SIMD,
+          Scalar: BinaryFloatingPoint,
+          Self.TangentVector.Scalar: BinaryFloatingPoint)
  public static func +(lhs: Self, rhs: Self) -> Self {
    var result = Self()
    for i in result.indices { result[i] = lhs[i] + rhs[i] }
    return result
  }

  @_transparent
+  // SWIFT_ENABLE_TENSORFLOW
+  @differentiable(vjp: _vjpSubtract(lhs:rhs:)
+    where Self: Differentiable,
+          Self.TangentVector: SIMD,
+          Scalar: BinaryFloatingPoint,
+          Self.TangentVector.Scalar: BinaryFloatingPoint)
  public static func -(lhs: Self, rhs: Self) -> Self {
    var result = Self()
    for i in result.indices { result[i] = lhs[i] - rhs[i] }
    return result
  }

  @_transparent
+  // SWIFT_ENABLE_TENSORFLOW
+  @differentiable(vjp: _vjpMultiply(lhs:rhs:)
+    where Self: Differentiable,
+          Self.TangentVector: SIMD,
+          Scalar: BinaryFloatingPoint,
+          Self.TangentVector == Self)
  public static func *(lhs: Self, rhs: Self) -> Self {
    var result = Self()
    for i in result.indices { result[i] = lhs[i] * rhs[i] }
    return result
  }

  @_transparent
+  // SWIFT_ENABLE_TENSORFLOW
+  @differentiable(vjp: _vjpDivide(lhs:rhs:)
+    where Self: Differentiable,
+          Self.TangentVector: SIMD,
+          Scalar: BinaryFloatingPoint,
+          Self.TangentVector == Self)
  public static func /(lhs: Self, rhs: Self) -> Self {
    var result = Self()
    for i in result.indices { result[i] = lhs[i] / rhs[i] }
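
For illustration, not part of the diff: the VJPs registered on these lane-wise operators apply the usual calculus rules per lane; for example, the pullback of `*` maps a cotangent `v` to `(v * rhs, v * lhs)`. A minimal check, assuming a toolchain that provides `valueWithPullback(at:_:in:)` and the SIMD `Differentiable` conformances:

    let a = SIMD2<Float>(2, 3)
    let b = SIMD2<Float>(10, 20)
    let (value, pullback) = valueWithPullback(at: a, b) { $0 * $1 }
    // value == SIMD2<Float>(20.0, 60.0)
    // pullback(SIMD2<Float>(1, 1)) yields (b, a): the product rule, lane-wise.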
@@ -842,7 +873,16 @@ extension SIMD where Scalar: FloatingPoint {
  }

  /// Returns the sum of the scalars in the vector.
-  @_alwaysEmitIntoClient
+  // SWIFT_ENABLE_TENSORFLOW
+  // FIXME: TF-545 we want the sum() func to be marked as
+  // `@_alwaysEmitIntoClient` like before when we define the VJP
+  @inlinable
+  @differentiable(vjp: _vjpSum
+    where Self: Differentiable,
+          Self.TangentVector: SIMD,
+          Scalar: BinaryFloatingPoint & Differentiable,
+          Scalar.TangentVector: BinaryFloatingPoint,
+          Self.TangentVector == Self)
  public func sum() -> Scalar {
    // Implementation note: this will eventually be defined to lower to either
    // llvm.experimental.vector.reduce.fadd or an explicit tree-sum. Open-
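
For illustration, not part of the diff: `_vjpSum` (defined at the end of this patch) pairs the reduction with a pullback that broadcasts the incoming scalar cotangent to every lane, since each lane contributes with weight 1 to the sum. A minimal sketch, under the same toolchain assumptions as above:

    let x = SIMD4<Float>(1, 2, 3, 4)
    // Every lane's partial derivative of x.sum() is 1.
    let dx = gradient(at: x) { $0.sum() }
    // dx == SIMD4<Float>(1.0, 1.0, 1.0, 1.0)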
@@ -1157,60 +1197,112 @@ extension SIMD where Scalar: FixedWidthInteger {
extension SIMD where Scalar: FloatingPoint {

  @_transparent
+  // SWIFT_ENABLE_TENSORFLOW
+  @differentiable(vjp: _vjpNegate(rhs:)
+    where Self: Differentiable,
+          Self.TangentVector: SIMD,
+          Scalar: BinaryFloatingPoint,
+          Self.TangentVector.Scalar: BinaryFloatingPoint)
  public static prefix func -(rhs: Self) -> Self {
    return 0 - rhs
  }

  @_transparent
+  // SWIFT_ENABLE_TENSORFLOW
+  @differentiable(vjp: _vjpAdd(lhs:rhs:)
+    where Self: Differentiable,
+          Self.TangentVector: SIMD,
+          Scalar: Differentiable & BinaryFloatingPoint,
+          Scalar.TangentVector: BinaryFloatingPoint,
+          Self.TangentVector.Scalar == Scalar.TangentVector)
  public static func +(lhs: Scalar, rhs: Self) -> Self {
    return Self(repeating: lhs) + rhs
  }

  @_transparent
+  // SWIFT_ENABLE_TENSORFLOW
+  @differentiable(vjp: _vjpSubtract(lhs:rhs:)
+    where Self: Differentiable,
+          Self.TangentVector: SIMD,
+          Scalar: Differentiable & BinaryFloatingPoint,
+          Scalar.TangentVector: BinaryFloatingPoint,
+          Self.TangentVector.Scalar == Scalar.TangentVector)
  public static func -(lhs: Scalar, rhs: Self) -> Self {
    return Self(repeating: lhs) - rhs
  }

  @_transparent
+  // SWIFT_ENABLE_TENSORFLOW
+  @differentiable(vjp: _vjpMultiply(lhs:rhs:)
+    where Self: Differentiable,
+          Self.TangentVector: SIMD,
+          Scalar: BinaryFloatingPoint & Differentiable,
+          Self.TangentVector == Self,
+          Scalar.TangentVector == Scalar)
  public static func *(lhs: Scalar, rhs: Self) -> Self {
    return Self(repeating: lhs) * rhs
  }

  @_transparent
+  // SWIFT_ENABLE_TENSORFLOW
+  @differentiable(vjp: _vjpDivide(lhs:rhs:)
+    where Self: Differentiable,
+          Self.TangentVector: SIMD,
+          Scalar: BinaryFloatingPoint & Differentiable,
+          Self.TangentVector == Self,
+          Scalar.TangentVector == Scalar)
  public static func /(lhs: Scalar, rhs: Self) -> Self {
    return Self(repeating: lhs) / rhs
  }

  @_transparent
+  // SWIFT_ENABLE_TENSORFLOW
+  @differentiable(vjp: _vjpAdd(lhs:rhs:)
+    where Self: Differentiable,
+          Self.TangentVector: SIMD,
+          Scalar: Differentiable & BinaryFloatingPoint,
+          Scalar.TangentVector: BinaryFloatingPoint,
+          Self.TangentVector.Scalar == Scalar.TangentVector)
  public static func +(lhs: Self, rhs: Scalar) -> Self {
    return lhs + Self(repeating: rhs)
  }

  @_transparent
+  // SWIFT_ENABLE_TENSORFLOW
+  @differentiable(vjp: _vjpSubtract(lhs:rhs:)
+    where Self: Differentiable,
+          Self.TangentVector: SIMD,
+          Scalar: Differentiable & BinaryFloatingPoint,
+          Scalar.TangentVector: BinaryFloatingPoint,
+          Self.TangentVector.Scalar == Scalar.TangentVector)
  public static func -(lhs: Self, rhs: Scalar) -> Self {
    return lhs - Self(repeating: rhs)
  }

  @_transparent
+  // SWIFT_ENABLE_TENSORFLOW
+  @differentiable(vjp: _vjpMultiply(lhs:rhs:)
+    where Self: Differentiable,
+          Self.TangentVector: SIMD,
+          Scalar: BinaryFloatingPoint & Differentiable,
+          Self.TangentVector == Self,
+          Scalar.TangentVector == Scalar)
  public static func *(lhs: Self, rhs: Scalar) -> Self {
    return lhs * Self(repeating: rhs)
  }

  @_transparent
+  // SWIFT_ENABLE_TENSORFLOW
+  @differentiable(vjp: _vjpDivide(lhs:rhs:)
+    where Self: Differentiable,
+          Self.TangentVector: SIMD,
+          Scalar: BinaryFloatingPoint & Differentiable,
+          Self.TangentVector == Self,
+          Scalar.TangentVector == Scalar)
  public static func /(lhs: Self, rhs: Scalar) -> Self {
    return lhs / Self(repeating: rhs)
  }

-  @_transparent
-  public static func +=(lhs: inout Self, rhs: Self) {
-    lhs = lhs + rhs
-  }
-
-  @_transparent
-  public static func -=(lhs: inout Self, rhs: Self) {
-    lhs = lhs - rhs
-  }
-
  @_transparent
  public static func *=(lhs: inout Self, rhs: Self) {
    lhs = lhs * rhs
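
For illustration, not part of the diff: in the mixed scalar-vector overloads above, the vector operand's cotangent passes through lane-wise while the scalar operand receives the summed cotangent (see the `.sum()` calls in the corresponding `_vjp` helpers below). A minimal sketch, with the same hedged assumptions about the toolchain's `gradient(at:_:in:)`:

    let c: Float = 2
    let x = SIMD4<Float>(1, 2, 3, 4)
    // f(c, x) = (c * x).sum():  df/dc == x.sum() == 10, df/dx == (2, 2, 2, 2).
    let (dc, dx) = gradient(at: c, x) { c, x in (c * x).sum() }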
@@ -1407,3 +1499,159 @@ where T: SIMD, T.Scalar: FloatingPoint {
  }
  return result
}
+
+// SWIFT_ENABLE_TENSORFLOW
+extension SIMD
+  where Self: Differentiable,
+        TangentVector: SIMD,
+        Scalar: BinaryFloatingPoint,
+        TangentVector.Scalar: BinaryFloatingPoint {
+  @inlinable
+  static func _vjpAdd(lhs: Self, rhs: Self)
+    -> (Self, (TangentVector) -> (TangentVector, TangentVector)) {
+    return (lhs + rhs, { v in
+      return (v, v)
+    })
+  }
+
+  @inlinable
+  static func _vjpSubtract(lhs: Self, rhs: Self)
+    -> (Self, (TangentVector) -> (TangentVector, TangentVector)) {
+    return (lhs - rhs, { v in
+      return (v, -v)
+    })
+  }
+
+  @inlinable
+  static func _vjpNegate(rhs: Self)
+    -> (Self, (TangentVector) -> (TangentVector)) {
+    return (-rhs, { v in
+      return -v
+    })
+  }
+}
+
+extension SIMD
+  where Self: Differentiable,
+        TangentVector: SIMD,
+        Scalar: BinaryFloatingPoint,
+        Self.TangentVector == Self {
+  @inlinable
+  static func _vjpMultiply(lhs: Self, rhs: Self)
+    -> (Self, (TangentVector) -> (TangentVector, TangentVector)) {
+    return (lhs * rhs, { v in
+      return (v * rhs, v * lhs)
+    })
+  }
+
+  @inlinable
+  static func _vjpDivide(lhs: Self, rhs: Self)
+    -> (Self, (TangentVector) -> (TangentVector, TangentVector)) {
+    return (lhs / rhs, { v in
+      (v / rhs, -lhs / (rhs * rhs) * v)
+    })
+  }
+}
+
+extension SIMD
+  where Self: Differentiable,
+        TangentVector: SIMD,
+        Scalar: BinaryFloatingPoint & Differentiable,
+        Scalar.TangentVector: BinaryFloatingPoint,
+        TangentVector.Scalar == Scalar.TangentVector {
+  @inlinable
+  static func _vjpAdd(lhs: Scalar, rhs: Self)
+    -> (Self, (TangentVector) -> (Scalar.TangentVector, TangentVector)) {
+    return (lhs + rhs, { v in
+      return (v.sum(), v)
+    })
+  }
+
+  @inlinable
+  static func _vjpSubtract(lhs: Scalar, rhs: Self)
+    -> (Self, (TangentVector) -> (Scalar.TangentVector, TangentVector)) {
+    return (lhs - rhs, { v in
+      return (v.sum(), -v)
+    })
+  }
+
+  @inlinable
+  static func _vjpAdd(lhs: Self, rhs: Scalar)
+    -> (Self, (TangentVector) -> (TangentVector, Scalar.TangentVector)) {
+    return (lhs + rhs, { v in
+      return (v, v.sum())
+    })
+  }
+
+  @inlinable
+  static func _vjpSubtract(lhs: Self, rhs: Scalar)
+    -> (Self, (TangentVector) -> (TangentVector, Scalar.TangentVector)) {
+    return (lhs - rhs, { v in
+      return (v, -v.sum())
+    })
+  }
+}
+
+extension SIMD
+  where Self: Differentiable,
+        TangentVector: SIMD,
+        Scalar: BinaryFloatingPoint & Differentiable,
+        Self.TangentVector == Self,
+        Scalar.TangentVector == Scalar {
+  @inlinable
+  static func _vjpMultiply(lhs: Self, rhs: Scalar)
+    -> (Self, (TangentVector) -> (TangentVector, Scalar.TangentVector)) {
+    return (lhs * rhs, { v in
+      return (v * rhs, (v * lhs).sum())
+    })
+  }
+
+  @inlinable
+  static func _vjpDivide(lhs: Self, rhs: Scalar)
+    -> (Self, (TangentVector) -> (TangentVector, Scalar.TangentVector)) {
+    return (lhs / rhs, { v in
+      (v / rhs, (-lhs / (rhs * rhs) * v).sum())
+    })
+  }
+
+  @inlinable
+  static func _vjpMultiply(lhs: Scalar, rhs: Self)
+    -> (Self, (TangentVector) -> (Scalar.TangentVector, TangentVector)) {
+    return (lhs * rhs, { v in
+      return ((v * rhs).sum(), v * lhs)
+    })
+  }
+
+  @inlinable
+  static func _vjpDivide(lhs: Scalar, rhs: Self)
+    -> (Self, (TangentVector) -> (Scalar.TangentVector, TangentVector)) {
+    return (lhs / rhs, { v in
+      ((v / rhs).sum(), -lhs / (rhs * rhs) * v)
+    })
+  }
+}
+
+extension SIMD
+  where Self: Differentiable,
+        TangentVector: SIMD,
+        Scalar: BinaryFloatingPoint & Differentiable,
+        Scalar.TangentVector: BinaryFloatingPoint,
+        TangentVector == Self {
+  @inlinable
+  func _vjpSum() -> (Scalar, (Scalar.TangentVector) -> TangentVector) {
+    return (sum(), { v in Self(repeating: Scalar(v)) })
+  }
+}
+
+extension SIMD
+  where Self: Differentiable,
+        Self.TangentVector: SIMD,
+        Scalar: BinaryFloatingPoint & Differentiable,
+        Self.TangentVector == Self,
+        Scalar.TangentVector == Scalar {
+  @usableFromInline
+  static func _vjpInit(repeating value: Scalar)
+    -> (Self, (TangentVector) -> Scalar.TangentVector) {
+    return (Self(repeating: value), { v in v.sum() })
+  }
+}