@@ -90,13 +90,6 @@ extension SIMD {
 
   /// A vector with the specified value in all lanes.
   @_transparent
-  // SWIFT_ENABLE_TENSORFLOW
-  @differentiable(vjp: _vjpInit(repeating:)
-    where Self : Differentiable,
-          Self.TangentVector : SIMD,
-          Scalar : BinaryFloatingPoint & Differentiable,
-          Self.TangentVector == Self,
-          Scalar.TangentVector == Scalar)
   public init(repeating value: Scalar) {
     self.init()
     for i in indices { self[i] = value }
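
The change this diff applies throughout: the old TensorFlow-branch `@differentiable(vjp:)` attribute named the VJP at the original declaration, whereas the replacement `@derivative(of:)` attribute (added to the `_vjp*` functions further down) registers the derivative from the VJP side and requires the labeled `(value:, pullback:)` result tuple. A minimal sketch of the new registration style on a hypothetical free function, assuming a toolchain that ships the `_Differentiation` module:

```swift
import _Differentiation

// Hypothetical function, used only to illustrate the attribute migration.
func square(_ x: Float) -> Float { x * x }

// New style: the derivative points back at the original via
// `@derivative(of:)`, and the result tuple must be labeled
// `(value:, pullback:)`.
@derivative(of: square)
func _vjpSquare(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
  (square(x), { v in 2 * x * v })  // d/dx of x*x is 2x
}
```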
@@ -788,51 +781,27 @@ extension SIMD where Scalar: FixedWidthInteger {
 // be replaced with @_semantics to lower directly to vector IR nodes.
 extension SIMD where Scalar : FloatingPoint {
   @_transparent
-  // SWIFT_ENABLE_TENSORFLOW
-  @differentiable(vjp: _vjpAdd(lhs:rhs:)
-    where Self : Differentiable,
-          Self.TangentVector : SIMD,
-          Scalar : BinaryFloatingPoint,
-          Self.TangentVector.Scalar : BinaryFloatingPoint)
   public static func +(lhs: Self, rhs: Self) -> Self {
     var result = Self()
     for i in result.indices { result[i] = lhs[i] + rhs[i] }
     return result
   }
 
   @_transparent
-  // SWIFT_ENABLE_TENSORFLOW
-  @differentiable(vjp: _vjpSubtract(lhs:rhs:)
-    where Self : Differentiable,
-          Self.TangentVector : SIMD,
-          Scalar : BinaryFloatingPoint,
-          Self.TangentVector.Scalar : BinaryFloatingPoint)
   public static func -(lhs: Self, rhs: Self) -> Self {
     var result = Self()
     for i in result.indices { result[i] = lhs[i] - rhs[i] }
     return result
   }
 
   @_transparent
-  // SWIFT_ENABLE_TENSORFLOW
-  @differentiable(vjp: _vjpMultiply(lhs:rhs:)
-    where Self : Differentiable,
-          Self.TangentVector : SIMD,
-          Scalar : BinaryFloatingPoint,
-          Self.TangentVector == Self)
   public static func *(lhs: Self, rhs: Self) -> Self {
     var result = Self()
     for i in result.indices { result[i] = lhs[i] * rhs[i] }
     return result
   }
 
   @_transparent
-  // SWIFT_ENABLE_TENSORFLOW
-  @differentiable(vjp: _vjpDivide(lhs:rhs:)
-    where Self : Differentiable,
-          Self.TangentVector : SIMD,
-          Scalar : BinaryFloatingPoint,
-          Self.TangentVector == Self)
   public static func /(lhs: Self, rhs: Self) -> Self {
     var result = Self()
     for i in result.indices { result[i] = lhs[i] / rhs[i] }
@@ -877,12 +846,6 @@ extension SIMD where Scalar : FloatingPoint {
   // FIXME: TF-545 we want the sum() func to be marked as
   // `@_alwaysEmitIntoClient` like before when we define the VJP
   @inlinable
-  @differentiable(vjp: _vjpSum
-    where Self : Differentiable,
-          Self.TangentVector : SIMD,
-          Scalar : BinaryFloatingPoint & Differentiable,
-          Scalar.TangentVector : BinaryFloatingPoint,
-          Self.TangentVector == Self)
   public func sum() -> Scalar {
     // Implementation note: this will eventually be defined to lower to either
     // llvm.experimental.vector.reduce.fadd or an explicit tree-sum. Open-
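
For context on the implementation note above, a minimal sketch of an explicit tree-sum over a SIMD vector (a hypothetical standalone helper, not the stdlib implementation; assumes a power-of-two lane count): halving the live width each step gives the reduction logarithmic depth instead of a linear chain of dependent adds.

```swift
// Hypothetical tree-sum helper; assumes scalarCount is a power of two.
func treeSum<V: SIMD>(_ v: V) -> V.Scalar where V.Scalar: FloatingPoint {
  var result = v
  var width = v.scalarCount / 2
  while width > 0 {
    // Fold the upper half onto the lower half, pairwise.
    for i in 0..<width { result[i] = result[i] + result[i + width] }
    width /= 2
  }
  return result[0]
}
```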
@@ -1197,108 +1160,46 @@ extension SIMD where Scalar: FixedWidthInteger {
 extension SIMD where Scalar: FloatingPoint {
 
   @_transparent
-  // SWIFT_ENABLE_TENSORFLOW
-  @differentiable(vjp: _vjpNegate(rhs:)
-    where Self : Differentiable,
-          Self.TangentVector : SIMD,
-          Scalar : BinaryFloatingPoint,
-          Self.TangentVector.Scalar : BinaryFloatingPoint)
   public static prefix func -(rhs: Self) -> Self {
     return 0 - rhs
   }
 
   @_transparent
-  // SWIFT_ENABLE_TENSORFLOW
-  @differentiable(vjp: _vjpAdd(lhs:rhs:)
-    where Self : Differentiable,
-          Self.TangentVector : SIMD,
-          Scalar : Differentiable & BinaryFloatingPoint,
-          Scalar.TangentVector : BinaryFloatingPoint,
-          Self.TangentVector.Scalar == Scalar.TangentVector)
   public static func +(lhs: Scalar, rhs: Self) -> Self {
     return Self(repeating: lhs) + rhs
   }
 
   @_transparent
-  // SWIFT_ENABLE_TENSORFLOW
-  @differentiable(vjp: _vjpSubtract(lhs:rhs:)
-    where Self : Differentiable,
-          Self.TangentVector : SIMD,
-          Scalar : Differentiable & BinaryFloatingPoint,
-          Scalar.TangentVector : BinaryFloatingPoint,
-          Self.TangentVector.Scalar == Scalar.TangentVector)
   public static func -(lhs: Scalar, rhs: Self) -> Self {
     return Self(repeating: lhs) - rhs
   }
 
   @_transparent
-  // SWIFT_ENABLE_TENSORFLOW
-  @differentiable(vjp: _vjpMultiply(lhs:rhs:)
-    where Self : Differentiable,
-          Self.TangentVector : SIMD,
-          Scalar : BinaryFloatingPoint & Differentiable,
-          Self.TangentVector == Self,
-          Scalar.TangentVector == Scalar)
   public static func *(lhs: Scalar, rhs: Self) -> Self {
     return Self(repeating: lhs) * rhs
   }
 
   @_transparent
-  // SWIFT_ENABLE_TENSORFLOW
-  @differentiable(vjp: _vjpDivide(lhs:rhs:)
-    where Self : Differentiable,
-          Self.TangentVector : SIMD,
-          Scalar : BinaryFloatingPoint & Differentiable,
-          Self.TangentVector == Self,
-          Scalar.TangentVector == Scalar)
   public static func /(lhs: Scalar, rhs: Self) -> Self {
     return Self(repeating: lhs) / rhs
   }
 
   @_transparent
-  // SWIFT_ENABLE_TENSORFLOW
-  @differentiable(vjp: _vjpAdd(lhs:rhs:)
-    where Self : Differentiable,
-          Self.TangentVector : SIMD,
-          Scalar : Differentiable & BinaryFloatingPoint,
-          Scalar.TangentVector : BinaryFloatingPoint,
-          Self.TangentVector.Scalar == Scalar.TangentVector)
   public static func +(lhs: Self, rhs: Scalar) -> Self {
     return lhs + Self(repeating: rhs)
   }
 
   @_transparent
-  // SWIFT_ENABLE_TENSORFLOW
-  @differentiable(vjp: _vjpSubtract(lhs:rhs:)
-    where Self : Differentiable,
-          Self.TangentVector : SIMD,
-          Scalar : Differentiable & BinaryFloatingPoint,
-          Scalar.TangentVector : BinaryFloatingPoint,
-          Self.TangentVector.Scalar == Scalar.TangentVector)
   public static func -(lhs: Self, rhs: Scalar) -> Self {
     return lhs - Self(repeating: rhs)
   }
 
   @_transparent
-  // SWIFT_ENABLE_TENSORFLOW
-  @differentiable(vjp: _vjpMultiply(lhs:rhs:)
-    where Self : Differentiable,
-          Self.TangentVector : SIMD,
-          Scalar : BinaryFloatingPoint & Differentiable,
-          Self.TangentVector == Self,
-          Scalar.TangentVector == Scalar)
   public static func *(lhs: Self, rhs: Scalar) -> Self {
     return lhs * Self(repeating: rhs)
   }
 
   @_transparent
-  // SWIFT_ENABLE_TENSORFLOW
-  @differentiable(vjp: _vjpDivide(lhs:rhs:)
-    where Self : Differentiable,
-          Self.TangentVector : SIMD,
-          Scalar : BinaryFloatingPoint & Differentiable,
-          Self.TangentVector == Self,
-          Scalar.TangentVector == Scalar)
   public static func /(lhs: Self, rhs: Scalar) -> Self {
     return lhs / Self(repeating: rhs)
   }
@@ -1520,24 +1421,27 @@ extension SIMD
         Scalar : BinaryFloatingPoint,
         TangentVector.Scalar : BinaryFloatingPoint {
   @inlinable
+  @derivative(of: +)
   static func _vjpAdd(lhs: Self, rhs: Self)
-    -> (Self, (TangentVector) -> (TangentVector, TangentVector)) {
+    -> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
     return (lhs + rhs, { v in
       return (v, v)
     })
   }
 
   @inlinable
+  @derivative(of: -)
   static func _vjpSubtract(lhs: Self, rhs: Self)
-    -> (Self, (TangentVector) -> (TangentVector, TangentVector)) {
+    -> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
     return (lhs - rhs, { v in
       return (v, -v)
     })
   }
 
   @inlinable
+  @derivative(of: -)
   static func _vjpNegate(rhs: Self)
-    -> (Self, (TangentVector) -> (TangentVector)) {
+    -> (value: Self, pullback: (TangentVector) -> (TangentVector)) {
     return (-rhs, { v in
       return -v
     })
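
Why these pullbacks look the way they do: addition is linear with identity partial Jacobians, so the incoming cotangent `v` flows through unchanged to both arguments; subtraction negates it for `rhs`, and negation's Jacobian is `-I`. A hedged numeric check, assuming a TensorFlow-era toolchain where `SIMD4<Float>` satisfies these constraints and exposes `valueWithPullback(at:_:in:)` from `_Differentiation`:

```swift
import _Differentiation

let a = SIMD4<Float>(1, 2, 3, 4)
let b = SIMD4<Float>(10, 20, 30, 40)

// The pullback of elementwise subtraction passes the cotangent through
// to `lhs` and negates it for `rhs`, matching `_vjpSubtract` above.
let (value, pb) = valueWithPullback(at: a, b) { $0 - $1 }
// value == SIMD4<Float>(-9, -18, -27, -36)
// pb(.one) == (SIMD4<Float>.one, -SIMD4<Float>.one)
```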
@@ -1550,16 +1454,18 @@ extension SIMD
         Scalar : BinaryFloatingPoint,
         Self.TangentVector == Self {
   @inlinable
+  @derivative(of: *)
   static func _vjpMultiply(lhs: Self, rhs: Self)
-    -> (Self, (TangentVector) -> (TangentVector, TangentVector)) {
+    -> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
     return (lhs * rhs, { v in
       return (v * rhs, v * lhs)
     })
   }
 
   @inlinable
+  @derivative(of: /)
   static func _vjpDivide(lhs: Self, rhs: Self)
-    -> (Self, (TangentVector) -> (TangentVector, TangentVector)) {
+    -> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
     return (lhs / rhs, { v in
       (v / rhs, -lhs / (rhs * rhs) * v)
     })
@@ -1573,32 +1479,36 @@ extension SIMD
         Scalar.TangentVector : BinaryFloatingPoint,
         TangentVector.Scalar == Scalar.TangentVector {
   @inlinable
+  @derivative(of: +)
   static func _vjpAdd(lhs: Scalar, rhs: Self)
-    -> (Self, (TangentVector) -> (Scalar.TangentVector, TangentVector)) {
+    -> (value: Self, pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector)) {
     return (lhs + rhs, { v in
       return (v.sum(), v)
     })
   }
 
   @inlinable
+  @derivative(of: -)
   static func _vjpSubtract(lhs: Scalar, rhs: Self)
-    -> (Self, (TangentVector) -> (Scalar.TangentVector, TangentVector)) {
+    -> (value: Self, pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector)) {
     return (lhs - rhs, { v in
       return (v.sum(), -v)
     })
   }
 
   @inlinable
+  @derivative(of: +)
   static func _vjpAdd(lhs: Self, rhs: Scalar)
-    -> (Self, (TangentVector) -> (TangentVector, Scalar.TangentVector)) {
+    -> (value: Self, pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector)) {
     return (lhs + rhs, { v in
       return (v, v.sum())
     })
   }
 
   @inlinable
+  @derivative(of: -)
   static func _vjpSubtract(lhs: Self, rhs: Scalar)
-    -> (Self, (TangentVector) -> (TangentVector, Scalar.TangentVector)) {
+    -> (value: Self, pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector)) {
     return (lhs - rhs, { v in
       return (v, -v.sum())
     })
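
The `v.sum()` in each scalar-side pullback is the adjoint of broadcasting: `Self(repeating:)` fans a single scalar out to every lane, so the reverse pass must accumulate one cotangent contribution per lane back into that scalar. A hedged check under the same toolchain assumptions as the previous sketch:

```swift
import _Differentiation

let s: Float = 2
let v = SIMD4<Float>(1, 2, 3, 4)

// `s + v` broadcasts `s` across four lanes, so the scalar's cotangent
// is the lane sum of the incoming vector cotangent: 1 + 1 + 1 + 1 = 4.
let (_, pb) = valueWithPullback(at: s, v) { $0 + $1 }
// pb(.one) == (4.0, SIMD4<Float>.one)
```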
@@ -1612,32 +1522,36 @@ extension SIMD
         Self.TangentVector == Self,
         Scalar.TangentVector == Scalar {
   @inlinable
+  @derivative(of: *)
   static func _vjpMultiply(lhs: Self, rhs: Scalar)
-    -> (Self, (TangentVector) -> (TangentVector, Scalar.TangentVector)) {
+    -> (value: Self, pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector)) {
     return (lhs * rhs, { v in
       return (v * rhs, (v * lhs).sum())
     })
   }
 
   @inlinable
+  @derivative(of: /)
   static func _vjpDivide(lhs: Self, rhs: Scalar)
-    -> (Self, (TangentVector) -> (TangentVector, Scalar.TangentVector)) {
+    -> (value: Self, pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector)) {
     return (lhs / rhs, { v in
       (v / rhs, (-lhs / (rhs * rhs) * v).sum())
     })
   }
 
   @inlinable
+  @derivative(of: *)
   static func _vjpMultiply(lhs: Scalar, rhs: Self)
-    -> (Self, (TangentVector) -> (Scalar.TangentVector, TangentVector)) {
+    -> (value: Self, pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector)) {
     return (lhs * rhs, { v in
       return ((v * rhs).sum(), v * lhs)
     })
   }
 
   @inlinable
+  @derivative(of: /)
   static func _vjpDivide(lhs: Scalar, rhs: Self)
-    -> (Self, (TangentVector) -> (Scalar.TangentVector, TangentVector)) {
+    -> (value: Self, pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector)) {
     return (lhs / rhs, { v in
       ((v / rhs).sum(), -lhs / (rhs * rhs) * v)
     })
@@ -1651,7 +1565,8 @@ extension SIMD
         Scalar.TangentVector : BinaryFloatingPoint,
         TangentVector == Self {
   @inlinable
-  func _vjpSum() -> (Scalar, (Scalar.TangentVector) -> TangentVector) {
+  @derivative(of: sum)
+  func _vjpSum() -> (value: Scalar, pullback: (Scalar.TangentVector) -> TangentVector) {
     return (sum(), { v in Self(repeating: Scalar(v)) })
   }
 }
@@ -1663,8 +1578,9 @@ extension SIMD
         Self.TangentVector == Self,
         Scalar.TangentVector == Scalar {
   @usableFromInline
+  @derivative(of: init(repeating:))
   static func _vjpInit(repeating value: Scalar)
-    -> (Self, (TangentVector) -> Scalar.TangentVector) {
+    -> (value: Self, pullback: (TangentVector) -> Scalar.TangentVector) {
     return (Self(repeating: value), { v in v.sum() })
   }
 }
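
Taken together, these registrations make elementwise arithmetic, scalar broadcasting, `sum()`, and `init(repeating:)` differentiable as a set, so they compose under reverse-mode AD. A final hedged end-to-end sketch, under the same toolchain assumptions as the earlier checks:

```swift
import _Differentiation

let x = SIMD4<Float>(1, 2, 3, 4)

// (v * v).sum() chains the `*` pullback with the `sum()` pullback;
// the resulting gradient is 2 * v.
let g = gradient(at: x) { v in (v * v).sum() }
// g == SIMD4<Float>(2, 4, 6, 8)
```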