@@ -213,9 +213,6 @@ extension SIMD {
213
213
/// in this vector. Because of this, the index is always in-range and no trap
214
214
/// can occur.
215
215
@_alwaysEmitIntoClient
216
- // @differentiable(vjp: _vjpSubscript(index:)
217
- // where Self : SIMDDifferentiable,
218
- // Scalar : SIMDDifferentiable)
219
216
public subscript< Index> ( index: SIMD2 < Index > ) -> SIMD2 < Scalar >
220
217
where Index: FixedWidthInteger {
221
218
var result = SIMD2 < Scalar > ( )
@@ -785,23 +782,34 @@ extension SIMD where Scalar: FixedWidthInteger {
785
782
extension SIMD where Scalar: FloatingPoint {
786
783
@_transparent
787
784
@differentiable ( vjp: _vjpAdd ( lhs: rhs: )
788
- where Self : SIMDDifferentiable)
785
+ where Self : Differentiable,
786
+ Self . CotangentVector : SIMD,
787
+ Scalar : BinaryFloatingPoint,
788
+ Self . CotangentVector. Scalar: BinaryFloatingPoint)
789
789
public static func + ( lhs: Self , rhs: Self ) -> Self {
790
790
var result = Self ( )
791
791
for i in result. indices { result [ i] = lhs [ i] + rhs[ i] }
792
792
return result
793
793
}
794
794
795
795
@_transparent
796
- @differentiable ( vjp: _vjpSubtract ( lhs: rhs: )
797
- where Self : SIMDDifferentiable)
796
+ @differentiable ( vjp: _vjpAdd ( lhs: rhs: )
797
+ where Self: Differentiable,
798
+ Self . CotangentVector: SIMD,
799
+ Scalar : BinaryFloatingPoint,
800
+ Self . CotangentVector. Scalar: BinaryFloatingPoint)
798
801
public static func - ( lhs: Self , rhs: Self ) -> Self {
799
802
var result = Self ( )
800
803
for i in result. indices { result [ i] = lhs [ i] - rhs[ i] }
801
804
return result
802
805
}
803
806
804
807
@_transparent
808
+ @differentiable ( vjp: _vjpMultiply ( lhs: rhs: )
809
+ where Self: Differentiable,
810
+ Self . CotangentVector: SIMD,
811
+ Scalar : BinaryFloatingPoint,
812
+ Self . CotangentVector == Self)
805
813
public static func * ( lhs: Self , rhs: Self ) -> Self {
806
814
var result = Self ( )
807
815
for i in result. indices { result [ i] = lhs [ i] * rhs[ i] }
@@ -851,8 +859,11 @@ extension SIMD where Scalar: FloatingPoint {
851
859
/// Returns the sum of the scalars in the vector.
852
860
@_alwaysEmitIntoClient
853
861
@differentiable ( vjp: _vjpSum
854
- where Self : SIMDDifferentiable,
855
- Scalar : SIMDDifferentiable)
862
+ where Self : Differentiable,
863
+ Self . CotangentVector : SIMD,
864
+ Scalar : BinaryFloatingPoint & Differentiable,
865
+ Scalar . CotangentVector : BinaryFloatingPoint,
866
+ Self . CotangentVector == Self)
856
867
public func sum( ) -> Scalar {
857
868
// Implementation note: this eventually be defined to lower to either
858
869
// llvm.experimental.vector.reduce.fadd or an explicit tree-sum. Open-
@@ -1166,59 +1177,95 @@ extension SIMD where Scalar: FixedWidthInteger {
1166
1177
1167
1178
extension SIMD where Scalar: FloatingPoint {
1168
1179
1169
- @_transparent // ????
1180
+ @_transparent
1170
1181
public static prefix func - ( rhs: Self ) -> Self {
1171
1182
return 0 - rhs
1172
1183
}
1173
1184
1174
1185
@_transparent
1175
1186
@differentiable ( vjp: _vjpAdd ( lhs: rhs: )
1176
- where Scalar : SIMDDifferentiable,
1177
- Self : SIMDDifferentiable)
1187
+ where Self: Differentiable,
1188
+ Self . CotangentVector: SIMD,
1189
+ Scalar : Differentiable & BinaryFloatingPoint,
1190
+ Scalar . CotangentVector: BinaryFloatingPoint,
1191
+ Self . CotangentVector. Scalar == Scalar . CotangentVector)
1178
1192
public static func + ( lhs: Scalar , rhs: Self ) -> Self {
1179
1193
return Self ( repeating: lhs) + rhs
1180
1194
}
1181
1195
1182
1196
@_transparent
1183
1197
@differentiable ( vjp: _vjpAdd ( lhs: rhs: )
1184
- where Scalar : SIMDDifferentiable,
1185
- Self : SIMDDifferentiable)
1198
+ where Self: Differentiable,
1199
+ Self . CotangentVector: SIMD,
1200
+ Scalar : Differentiable & BinaryFloatingPoint,
1201
+ Scalar . CotangentVector: BinaryFloatingPoint,
1202
+ Self . CotangentVector. Scalar == Scalar . CotangentVector)
1186
1203
public static func - ( lhs: Scalar , rhs: Self ) -> Self {
1187
1204
return Self ( repeating: lhs) - rhs
1188
1205
}
1189
1206
1190
1207
@_transparent
1208
+ @differentiable ( vjp: _vjpMultiply ( lhs: rhs: )
1209
+ where Self : Differentiable,
1210
+ Self . CotangentVector : SIMD,
1211
+ Scalar : BinaryFloatingPoint & Differentiable,
1212
+ Self . CotangentVector == Self,
1213
+ Scalar . CotangentVector == Scalar)
1191
1214
public static func * ( lhs: Scalar , rhs: Self ) -> Self {
1192
1215
return Self ( repeating: lhs) * rhs
1193
1216
}
1194
1217
1195
1218
@_transparent
1219
+ @differentiable ( vjp: _vjpDivide ( lhs: rhs: )
1220
+ where Self : Differentiable,
1221
+ Self . CotangentVector : SIMD,
1222
+ Scalar : BinaryFloatingPoint & Differentiable,
1223
+ Self . CotangentVector == Self,
1224
+ Scalar . CotangentVector == Scalar)
1196
1225
public static func / ( lhs: Scalar , rhs: Self ) -> Self {
1197
1226
return Self ( repeating: lhs) / rhs
1198
1227
}
1199
1228
1200
1229
@_transparent
1201
1230
@differentiable ( vjp: _vjpAdd ( lhs: rhs: )
1202
- where Scalar : SIMDDifferentiable,
1203
- Self : SIMDDifferentiable)
1231
+ where Self: Differentiable,
1232
+ Self . CotangentVector: SIMD,
1233
+ Scalar : Differentiable & BinaryFloatingPoint,
1234
+ Scalar . CotangentVector: BinaryFloatingPoint,
1235
+ Self . CotangentVector. Scalar == Scalar . CotangentVector)
1204
1236
public static func + ( lhs: Self , rhs: Scalar ) -> Self {
1205
1237
return lhs + Self( repeating: rhs)
1206
1238
}
1207
1239
1208
1240
@_transparent
1209
- @differentiable ( vjp: _vjpSubtract ( lhs: rhs: )
1210
- where Scalar : SIMDDifferentiable,
1211
- Self : SIMDDifferentiable)
1241
+ @differentiable ( vjp: _vjpAdd ( lhs: rhs: )
1242
+ where Self: Differentiable,
1243
+ Self . CotangentVector: SIMD,
1244
+ Scalar : Differentiable & BinaryFloatingPoint,
1245
+ Scalar . CotangentVector: BinaryFloatingPoint,
1246
+ Self . CotangentVector. Scalar == Scalar . CotangentVector)
1212
1247
public static func - ( lhs: Self , rhs: Scalar ) -> Self {
1213
1248
return lhs - Self( repeating: rhs)
1214
1249
}
1215
1250
1216
1251
@_transparent
1252
+ @differentiable ( vjp: _vjpMultiply ( lhs: rhs: )
1253
+ where Self : Differentiable,
1254
+ Self . CotangentVector : SIMD,
1255
+ Scalar : BinaryFloatingPoint & Differentiable,
1256
+ Self . CotangentVector == Self,
1257
+ Scalar . CotangentVector == Scalar)
1217
1258
public static func * ( lhs: Self , rhs: Scalar ) -> Self {
1218
1259
return lhs * Self( repeating: rhs)
1219
1260
}
1220
1261
1221
1262
@_transparent
1263
+ @differentiable ( vjp: _vjpDivide ( lhs: rhs: )
1264
+ where Self : Differentiable,
1265
+ Self . CotangentVector : SIMD,
1266
+ Scalar : BinaryFloatingPoint & Differentiable,
1267
+ Self . CotangentVector == Self,
1268
+ Scalar . CotangentVector == Scalar)
1222
1269
public static func / ( lhs: Self , rhs: Scalar ) -> Self {
1223
1270
return lhs / Self( repeating: rhs)
1224
1271
}
@@ -1415,18 +1462,18 @@ where T: SIMD, T.Scalar: FloatingPoint {
1415
1462
return result
1416
1463
}
1417
1464
1418
- public protocol SIMDDifferentiable : Differentiable
1419
- where Self == Self . TangentVector ,
1420
- Self == Self . CotangentVector ,
1421
- Self == Self . AllDifferentiableVariables { }
1422
-
1423
1465
extension SIMD
1424
- where Self : SIMDDifferentiable ,
1425
- Scalar : FloatingPoint {
1466
+ where Self: Differentiable ,
1467
+ CotangentVector: SIMD ,
1468
+ Scalar : BinaryFloatingPoint ,
1469
+ /* Required in order to use unary negation operator due to following error:
1470
+ >Self.CotangentVector.Scalar' does not conform to protocol 'FloatingPoint'
1471
+ */
1472
+ CotangentVector. Scalar: BinaryFloatingPoint {
1426
1473
@inlinable
1427
1474
static func _vjpAdd(
1428
1475
lhs: Self , rhs: Self
1429
- ) -> ( Self , ( Self ) -> ( Self , Self ) ) {
1476
+ ) -> ( Self , ( CotangentVector ) -> ( CotangentVector , CotangentVector ) ) {
1430
1477
return ( lhs + rhs, { v in
1431
1478
return ( v, v)
1432
1479
} )
@@ -1435,73 +1482,133 @@ extension SIMD
1435
1482
@inlinable
1436
1483
static func _vjpSubtract(
1437
1484
lhs: Self , rhs: Self
1438
- ) -> ( Self , ( Self ) -> ( Self , Self ) ) {
1439
- return ( lhs - rhs, { v in
1485
+ ) -> ( Self , ( CotangentVector ) -> ( CotangentVector , CotangentVector ) ) {
1486
+ return ( lhs - rhs, { ( v : CotangentVector ) in
1440
1487
return ( v, - v)
1441
1488
} )
1442
1489
}
1443
1490
}
1444
1491
1445
1492
extension SIMD
1446
- where Self : SIMDDifferentiable ,
1447
- Scalar : FloatingPoint & SIMDDifferentiable {
1493
+ where Self: Differentiable ,
1494
+ CotangentVector: SIMD ,
1495
+ // error: generic parameter 'Self' could not be inferred: return (lhs * rhs,...
1496
+ Scalar : BinaryFloatingPoint ,
1497
+ // binary operator '*' cannot be applied to operands of type 'Self.CotangentVector' and 'Self'
1498
+ Self. CotangentVector == Self {
1499
+ @inlinable
1500
+ static func _vjpMultiply(
1501
+ lhs: Self , rhs: Self
1502
+ ) -> ( Self , ( CotangentVector ) -> ( CotangentVector , CotangentVector ) ) {
1503
+ return ( lhs * rhs, { ( v: CotangentVector ) in
1504
+ return ( v * rhs, v * lhs)
1505
+ } )
1506
+ }
1507
+
1508
+ @inlinable
1509
+ static func _vjpDivide(
1510
+ lhs: Self , rhs: Self
1511
+ ) -> ( Self , ( CotangentVector ) -> ( CotangentVector , CotangentVector ) ) {
1512
+ return ( lhs / rhs, { ( v: CotangentVector ) in
1513
+ ( v / rhs, - lhs / ( rhs * rhs) * v)
1514
+ } )
1515
+ }
1516
+ }
1517
+
1518
+ extension SIMD
1519
+ where Self : Differentiable ,
1520
+ CotangentVector : SIMD ,
1521
+ Scalar : BinaryFloatingPoint & Differentiable ,
1522
+ Scalar. CotangentVector: BinaryFloatingPoint ,
1523
+ CotangentVector. Scalar == Scalar . CotangentVector {
1448
1524
@inlinable
1449
1525
static func _vjpAdd(
1450
1526
lhs: Scalar , rhs: Self
1451
- ) -> ( Self , ( Self ) -> ( Scalar , Self ) ) {
1452
- return ( lhs + rhs, { v in
1527
+ ) -> ( Self , ( CotangentVector ) -> ( Scalar . CotangentVector , CotangentVector ) ) {
1528
+ return ( lhs + rhs, { ( v : CotangentVector ) in
1453
1529
return ( v. sum ( ) , v)
1454
1530
} )
1455
1531
}
1456
-
1532
+
1457
1533
@inlinable
1458
1534
static func _vjpSubtract(
1459
1535
lhs: Scalar , rhs: Self
1460
- ) -> ( Self , ( Self ) -> ( Scalar , Self ) ) {
1461
- return ( lhs + rhs, { v in
1536
+ ) -> ( Self , ( CotangentVector ) -> ( Scalar . CotangentVector , CotangentVector ) ) {
1537
+ return ( lhs + rhs, { ( v : CotangentVector ) in
1462
1538
return ( v. sum ( ) , - v)
1463
1539
} )
1464
1540
}
1465
-
1541
+
1466
1542
@inlinable
1467
1543
static func _vjpAdd(
1468
1544
lhs: Self , rhs: Scalar
1469
- ) -> ( Self , ( Self ) -> ( Self , Scalar ) ) {
1470
- return ( lhs + rhs, { v in
1545
+ ) -> ( Self , ( CotangentVector ) -> ( CotangentVector , Scalar . CotangentVector ) ) {
1546
+ return ( lhs + rhs, { ( v : CotangentVector ) in
1471
1547
return ( v, v. sum ( ) )
1472
1548
} )
1473
1549
}
1474
-
1550
+
1475
1551
@inlinable
1476
1552
static func _vjpSubtract(
1477
1553
lhs: Self , rhs: Scalar
1478
- ) -> ( Self , ( Self ) -> ( Self , Scalar ) ) {
1479
- return ( lhs + rhs, { v in
1554
+ ) -> ( Self , ( CotangentVector ) -> ( CotangentVector , Scalar . CotangentVector ) ) {
1555
+ return ( lhs + rhs, { ( v : CotangentVector ) in
1480
1556
return ( v, - v. sum ( ) )
1481
1557
} )
1482
1558
}
1483
1559
}
1484
1560
1485
-
1486
1561
extension SIMD
1487
- where Self: SIMDDifferentiable ,
1488
- Scalar: SIMDDifferentiable & FloatingPoint {
1489
- public static func _vjpSum( ) -> ( Scalar , ( Scalar ) -> Self ) {
1490
- return ( self . sum ( ) , { v in SIMD ( repeating: v) } )
1562
+ where Self : Differentiable ,
1563
+ CotangentVector : SIMD ,
1564
+ Scalar : BinaryFloatingPoint & Differentiable ,
1565
+ Self. CotangentVector == Self ,
1566
+ Scalar. CotangentVector == Scalar {
1567
+ @inlinable
1568
+ static func _vjpMultiply(
1569
+ lhs: Self , rhs: Scalar
1570
+ ) -> ( Self , ( CotangentVector ) -> ( CotangentVector , Scalar . CotangentVector ) ) {
1571
+ return ( lhs * rhs, { ( v: CotangentVector ) in
1572
+ return ( v * rhs, ( v * lhs) . sum ( ) )
1573
+ } )
1574
+ }
1575
+
1576
+ @inlinable
1577
+ static func _vjpDivide(
1578
+ lhs: Self , rhs: Scalar
1579
+ ) -> ( Self , ( CotangentVector ) -> ( CotangentVector , Scalar . CotangentVector ) ) {
1580
+ return ( lhs / rhs, { ( v: CotangentVector ) in
1581
+ ( - lhs / ( rhs * rhs) * v, ( v / rhs) . sum ( ) )
1582
+ } )
1583
+ }
1584
+
1585
+ @inlinable
1586
+ static func _vjpMultiply(
1587
+ lhs: Scalar , rhs: Self
1588
+ ) -> ( Self , ( CotangentVector ) -> ( Scalar . CotangentVector , CotangentVector ) ) {
1589
+ return ( lhs * rhs, { ( v: CotangentVector ) in
1590
+ return ( ( v * lhs) . sum ( ) , v * rhs)
1591
+ } )
1592
+ }
1593
+
1594
+ @inlinable
1595
+ static func _vjpDivide(
1596
+ lhs: Scalar , rhs: Self
1597
+ ) -> ( Self , ( CotangentVector ) -> ( Scalar . CotangentVector , CotangentVector ) ) {
1598
+ return ( lhs / rhs, { ( v: CotangentVector ) in
1599
+ ( ( v / rhs) . sum ( ) , - lhs / ( rhs * rhs) * v)
1600
+ } )
1491
1601
}
1492
1602
}
1493
1603
1494
- //extension SIMD
1495
- // where Self : SIMDDifferentiable,
1496
- // Scalar : SIMDDifferentiable & SIMDScalar {
1497
- // public func _vjpSubscript<Index>(index: SIMD2<Index>) ->
1498
- // (SIMD2<Scalar>, (SIMD2<Self.Scalar>) -> (CotangentVector, SIMD2<Index>)) where Index: FixedWidthInteger & SIMDScalar
1499
- // {
1500
- // func pullback(_ v: SIMD2<Self.Scalar>) -> (CotangentVector, SIMD2<Index>) {
1501
- // var gradientOut = SIMD(repeating: 0)
1502
- // gradientOut[index] = gradientIn
1503
- // return (CotangentVector(gradientOut), index)
1504
- // }
1505
- // return (self[index], pullback)
1506
- // }
1507
- //}
1604
+ extension SIMD
1605
+ where Self : Differentiable ,
1606
+ CotangentVector : SIMD ,
1607
+ Scalar : BinaryFloatingPoint & Differentiable ,
1608
+ Scalar. CotangentVector : BinaryFloatingPoint ,
1609
+ CotangentVector == Self {
1610
+ @usableFromInline
1611
+ func _vjpSum( ) -> ( Scalar , ( Scalar . CotangentVector ) -> CotangentVector ) {
1612
+ return ( sum ( ) , { ( v: Scalar . CotangentVector ) in Self . init ( repeating: Scalar ( v) ) } )
1613
+ }
1614
+ }
0 commit comments