@@ -804,7 +804,7 @@ extension SIMD where Scalar : FloatingPoint {
804
804
// SWIFT_ENABLE_TENSORFLOW
805
805
@differentiable ( vjp: _vjpSubtract ( lhs: rhs: )
806
806
where Self : Differentiable,
807
- Self . TangentVector: SIMD,
807
+ Self . TangentVector : SIMD,
808
808
Scalar : BinaryFloatingPoint,
809
809
Self . TangentVector. Scalar : BinaryFloatingPoint)
810
810
public static func - ( lhs: Self , rhs: Self ) -> Self {
@@ -874,9 +874,9 @@ extension SIMD where Scalar : FloatingPoint {
874
874
875
875
/// Returns the sum of the scalars in the vector.
876
876
// SWIFT_ENABLE_TENSORFLOW
877
- @inlinable
878
877
// FIXME: TF-545 we want the sum() func to be marked as
879
878
// `@_alwaysEmitIntoClient` like before when we define the VJP
879
+ @inlinable
880
880
@differentiable ( vjp: _vjpSum
881
881
where Self : Differentiable,
882
882
Self . TangentVector : SIMD,
@@ -1211,7 +1211,7 @@ extension SIMD where Scalar: FloatingPoint {
1211
1211
// SWIFT_ENABLE_TENSORFLOW
1212
1212
@differentiable ( vjp: _vjpAdd ( lhs: rhs: )
1213
1213
where Self : Differentiable,
1214
- Self . TangentVector: SIMD,
1214
+ Self . TangentVector : SIMD,
1215
1215
Scalar : Differentiable & BinaryFloatingPoint,
1216
1216
Scalar . TangentVector : BinaryFloatingPoint,
1217
1217
Self . TangentVector. Scalar == Scalar . TangentVector)
@@ -1246,11 +1246,11 @@ extension SIMD where Scalar: FloatingPoint {
1246
1246
@_transparent
1247
1247
// SWIFT_ENABLE_TENSORFLOW
1248
1248
@differentiable ( vjp: _vjpDivide ( lhs: rhs: )
1249
- where Self : Differentiable,
1250
- Self . TangentVector : SIMD,
1251
- Scalar : BinaryFloatingPoint & Differentiable,
1252
- Self . TangentVector == Self,
1253
- Scalar . TangentVector == Scalar)
1249
+ where Self : Differentiable,
1250
+ Self . TangentVector : SIMD,
1251
+ Scalar : BinaryFloatingPoint & Differentiable,
1252
+ Self . TangentVector == Self,
1253
+ Scalar . TangentVector == Scalar)
1254
1254
public static func / ( lhs: Scalar , rhs: Self ) -> Self {
1255
1255
return Self ( repeating: lhs) / rhs
1256
1256
}
@@ -1290,7 +1290,7 @@ extension SIMD where Scalar: FloatingPoint {
1290
1290
public static func * ( lhs: Self , rhs: Scalar ) -> Self {
1291
1291
return lhs * Self( repeating: rhs)
1292
1292
}
1293
-
1293
+
1294
1294
@_transparent
1295
1295
// SWIFT_ENABLE_TENSORFLOW
1296
1296
@differentiable ( vjp: _vjpDivide ( lhs: rhs: )
@@ -1507,26 +1507,24 @@ extension SIMD
1507
1507
Scalar : BinaryFloatingPoint ,
1508
1508
TangentVector. Scalar : BinaryFloatingPoint {
1509
1509
@inlinable
1510
- static func _vjpAdd(
1511
- lhs: Self , rhs: Self
1512
- ) -> ( Self , ( TangentVector ) -> ( TangentVector , TangentVector ) ) {
1510
+ static func _vjpAdd( lhs: Self , rhs: Self )
1511
+ -> ( Self , ( TangentVector ) -> ( TangentVector , TangentVector ) ) {
1513
1512
return ( lhs + rhs, { v in
1514
1513
return ( v, v)
1515
1514
} )
1516
1515
}
1517
1516
1518
1517
@inlinable
1519
- static func _vjpSubtract(
1520
- lhs: Self , rhs: Self
1521
- ) -> ( Self , ( TangentVector ) -> ( TangentVector , TangentVector ) ) {
1518
+ static func _vjpSubtract( lhs: Self , rhs: Self )
1519
+ -> ( Self , ( TangentVector ) -> ( TangentVector , TangentVector ) ) {
1522
1520
return ( lhs - rhs, { v in
1523
1521
return ( v, - v)
1524
1522
} )
1525
1523
}
1526
1524
1527
1525
@inlinable
1528
1526
static func _vjpNegate( rhs: Self )
1529
- -> ( Self , ( TangentVector ) -> ( TangentVector ) ) {
1527
+ -> ( Self , ( TangentVector ) -> ( TangentVector ) ) {
1530
1528
return ( - rhs, { v in
1531
1529
return - v
1532
1530
} )
@@ -1535,22 +1533,20 @@ extension SIMD
1535
1533
1536
1534
extension SIMD
1537
1535
where Self : Differentiable ,
1538
- TangentVector: SIMD ,
1536
+ TangentVector : SIMD ,
1539
1537
Scalar : BinaryFloatingPoint ,
1540
1538
Self. TangentVector == Self {
1541
1539
@inlinable
1542
- static func _vjpMultiply(
1543
- lhs: Self , rhs: Self
1544
- ) -> ( Self , ( TangentVector ) -> ( TangentVector , TangentVector ) ) {
1540
+ static func _vjpMultiply( lhs: Self , rhs: Self )
1541
+ -> ( Self , ( TangentVector ) -> ( TangentVector , TangentVector ) ) {
1545
1542
return ( lhs * rhs, { v in
1546
1543
return ( v * rhs, v * lhs)
1547
1544
} )
1548
1545
}
1549
1546
1550
1547
@inlinable
1551
- static func _vjpDivide(
1552
- lhs: Self , rhs: Self
1553
- ) -> ( Self , ( TangentVector ) -> ( TangentVector , TangentVector ) ) {
1548
+ static func _vjpDivide( lhs: Self , rhs: Self )
1549
+ -> ( Self , ( TangentVector ) -> ( TangentVector , TangentVector ) ) {
1554
1550
return ( lhs / rhs, { v in
1555
1551
( v / rhs, - lhs / ( rhs * rhs) * v)
1556
1552
} )
@@ -1561,39 +1557,35 @@ extension SIMD
1561
1557
where Self : Differentiable ,
1562
1558
TangentVector : SIMD ,
1563
1559
Scalar : BinaryFloatingPoint & Differentiable ,
1564
- Scalar. TangentVector: BinaryFloatingPoint ,
1560
+ Scalar. TangentVector : BinaryFloatingPoint ,
1565
1561
TangentVector. Scalar == Scalar . TangentVector {
1566
1562
@inlinable
1567
- static func _vjpAdd(
1568
- lhs: Scalar , rhs: Self
1569
- ) -> ( Self , ( TangentVector ) -> ( Scalar . TangentVector , TangentVector ) ) {
1563
+ static func _vjpAdd( lhs: Scalar , rhs: Self )
1564
+ -> ( Self , ( TangentVector ) -> ( Scalar . TangentVector , TangentVector ) ) {
1570
1565
return ( lhs + rhs, { v in
1571
1566
return ( v. sum ( ) , v)
1572
1567
} )
1573
1568
}
1574
1569
1575
1570
@inlinable
1576
- static func _vjpSubtract(
1577
- lhs: Scalar , rhs: Self
1578
- ) -> ( Self , ( TangentVector ) -> ( Scalar . TangentVector , TangentVector ) ) {
1571
+ static func _vjpSubtract( lhs: Scalar , rhs: Self )
1572
+ -> ( Self , ( TangentVector ) -> ( Scalar . TangentVector , TangentVector ) ) {
1579
1573
return ( lhs - rhs, { v in
1580
1574
return ( v. sum ( ) , - v)
1581
1575
} )
1582
1576
}
1583
1577
1584
1578
@inlinable
1585
- static func _vjpAdd(
1586
- lhs: Self , rhs: Scalar
1587
- ) -> ( Self , ( TangentVector ) -> ( TangentVector , Scalar . TangentVector ) ) {
1579
+ static func _vjpAdd( lhs: Self , rhs: Scalar )
1580
+ -> ( Self , ( TangentVector ) -> ( TangentVector , Scalar . TangentVector ) ) {
1588
1581
return ( lhs + rhs, { v in
1589
1582
return ( v, v. sum ( ) )
1590
1583
} )
1591
1584
}
1592
1585
1593
1586
@inlinable
1594
- static func _vjpSubtract(
1595
- lhs: Self , rhs: Scalar
1596
- ) -> ( Self , ( TangentVector ) -> ( TangentVector , Scalar . TangentVector ) ) {
1587
+ static func _vjpSubtract( lhs: Self , rhs: Scalar )
1588
+ -> ( Self , ( TangentVector ) -> ( TangentVector , Scalar . TangentVector ) ) {
1597
1589
return ( lhs - rhs, { v in
1598
1590
return ( v, - v. sum ( ) )
1599
1591
} )
@@ -1607,36 +1599,32 @@ extension SIMD
1607
1599
Self. TangentVector == Self ,
1608
1600
Scalar. TangentVector == Scalar {
1609
1601
@inlinable
1610
- static func _vjpMultiply(
1611
- lhs: Self , rhs: Scalar
1612
- ) -> ( Self , ( TangentVector ) -> ( TangentVector , Scalar . TangentVector ) ) {
1602
+ static func _vjpMultiply( lhs: Self , rhs: Scalar )
1603
+ -> ( Self , ( TangentVector ) -> ( TangentVector , Scalar . TangentVector ) ) {
1613
1604
return ( lhs * rhs, { v in
1614
1605
return ( v * rhs, ( v * lhs) . sum ( ) )
1615
1606
} )
1616
1607
}
1617
1608
1618
1609
@inlinable
1619
- static func _vjpDivide(
1620
- lhs: Self , rhs: Scalar
1621
- ) -> ( Self , ( TangentVector ) -> ( TangentVector , Scalar . TangentVector ) ) {
1610
+ static func _vjpDivide( lhs: Self , rhs: Scalar )
1611
+ -> ( Self , ( TangentVector ) -> ( TangentVector , Scalar . TangentVector ) ) {
1622
1612
return ( lhs / rhs, { v in
1623
1613
( v / rhs, ( - lhs / ( rhs * rhs) * v) . sum ( ) )
1624
1614
} )
1625
1615
}
1626
1616
1627
1617
@inlinable
1628
- static func _vjpMultiply(
1629
- lhs: Scalar , rhs: Self
1630
- ) -> ( Self , ( TangentVector ) -> ( Scalar . TangentVector , TangentVector ) ) {
1618
+ static func _vjpMultiply( lhs: Scalar , rhs: Self )
1619
+ -> ( Self , ( TangentVector ) -> ( Scalar . TangentVector , TangentVector ) ) {
1631
1620
return ( lhs * rhs, { v in
1632
1621
return ( ( v * rhs) . sum ( ) , v * lhs)
1633
1622
} )
1634
1623
}
1635
1624
1636
1625
@inlinable
1637
- static func _vjpDivide(
1638
- lhs: Scalar , rhs: Self
1639
- ) -> ( Self , ( TangentVector ) -> ( Scalar . TangentVector , TangentVector ) ) {
1626
+ static func _vjpDivide( lhs: Scalar , rhs: Self )
1627
+ -> ( Self , ( TangentVector ) -> ( Scalar . TangentVector , TangentVector ) ) {
1640
1628
return ( lhs / rhs, { v in
1641
1629
( ( v / rhs) . sum ( ) , - lhs / ( rhs * rhs) * v)
1642
1630
} )
@@ -1662,8 +1650,8 @@ extension SIMD
1662
1650
Self. TangentVector == Self ,
1663
1651
Scalar. TangentVector == Scalar {
1664
1652
@usableFromInline
1665
- static func _vjpInit( repeating value: Scalar ) ->
1666
- ( Self , ( TangentVector ) -> Scalar . TangentVector ) {
1653
+ static func _vjpInit( repeating value: Scalar )
1654
+ -> ( Self , ( TangentVector ) -> Scalar . TangentVector ) {
1667
1655
return ( Self ( repeating: value) , { v in v. sum ( ) } )
1668
1656
}
1669
1657
}
0 commit comments