Skip to content

Commit 7abb1c3

Browse files
committed
Code cleanup and additonal tests.
1 parent bc32cf9 commit 7abb1c3

File tree

3 files changed

+123
-78
lines changed

3 files changed

+123
-78
lines changed

stdlib/public/core/SIMDVector.swift

Lines changed: 38 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -804,7 +804,7 @@ extension SIMD where Scalar : FloatingPoint {
804804
// SWIFT_ENABLE_TENSORFLOW
805805
@differentiable(vjp: _vjpSubtract(lhs:rhs:)
806806
where Self : Differentiable,
807-
Self.TangentVector: SIMD,
807+
Self.TangentVector : SIMD,
808808
Scalar : BinaryFloatingPoint,
809809
Self.TangentVector.Scalar : BinaryFloatingPoint)
810810
public static func -(lhs: Self, rhs: Self) -> Self {
@@ -874,9 +874,9 @@ extension SIMD where Scalar : FloatingPoint {
874874

875875
/// Returns the sum of the scalars in the vector.
876876
// SWIFT_ENABLE_TENSORFLOW
877-
@inlinable
878877
// FIXME: TF-545 we want the sum() func to be marked as
879878
// `@_alwaysEmitIntoClient` like before when we define the VJP
879+
@inlinable
880880
@differentiable(vjp: _vjpSum
881881
where Self : Differentiable,
882882
Self.TangentVector : SIMD,
@@ -1211,7 +1211,7 @@ extension SIMD where Scalar: FloatingPoint {
12111211
// SWIFT_ENABLE_TENSORFLOW
12121212
@differentiable(vjp: _vjpAdd(lhs:rhs:)
12131213
where Self : Differentiable,
1214-
Self.TangentVector: SIMD,
1214+
Self.TangentVector : SIMD,
12151215
Scalar : Differentiable & BinaryFloatingPoint,
12161216
Scalar.TangentVector : BinaryFloatingPoint,
12171217
Self.TangentVector.Scalar == Scalar.TangentVector)
@@ -1246,11 +1246,11 @@ extension SIMD where Scalar: FloatingPoint {
12461246
@_transparent
12471247
// SWIFT_ENABLE_TENSORFLOW
12481248
@differentiable(vjp: _vjpDivide(lhs:rhs:)
1249-
where Self : Differentiable,
1250-
Self.TangentVector : SIMD,
1251-
Scalar : BinaryFloatingPoint & Differentiable,
1252-
Self.TangentVector == Self,
1253-
Scalar.TangentVector == Scalar)
1249+
where Self : Differentiable,
1250+
Self.TangentVector : SIMD,
1251+
Scalar : BinaryFloatingPoint & Differentiable,
1252+
Self.TangentVector == Self,
1253+
Scalar.TangentVector == Scalar)
12541254
public static func /(lhs: Scalar, rhs: Self) -> Self {
12551255
return Self(repeating: lhs) / rhs
12561256
}
@@ -1290,7 +1290,7 @@ extension SIMD where Scalar: FloatingPoint {
12901290
public static func *(lhs: Self, rhs: Scalar) -> Self {
12911291
return lhs * Self(repeating: rhs)
12921292
}
1293-
1293+
12941294
@_transparent
12951295
// SWIFT_ENABLE_TENSORFLOW
12961296
@differentiable(vjp: _vjpDivide(lhs:rhs:)
@@ -1507,26 +1507,24 @@ extension SIMD
15071507
Scalar : BinaryFloatingPoint,
15081508
TangentVector.Scalar : BinaryFloatingPoint {
15091509
@inlinable
1510-
static func _vjpAdd(
1511-
lhs: Self, rhs: Self
1512-
) -> (Self, (TangentVector) -> (TangentVector, TangentVector)) {
1510+
static func _vjpAdd(lhs: Self, rhs: Self)
1511+
-> (Self, (TangentVector) -> (TangentVector, TangentVector)) {
15131512
return (lhs + rhs, { v in
15141513
return (v, v)
15151514
})
15161515
}
15171516

15181517
@inlinable
1519-
static func _vjpSubtract(
1520-
lhs: Self, rhs: Self
1521-
) -> (Self, (TangentVector) -> (TangentVector, TangentVector)) {
1518+
static func _vjpSubtract(lhs: Self, rhs: Self)
1519+
-> (Self, (TangentVector) -> (TangentVector, TangentVector)) {
15221520
return (lhs - rhs, { v in
15231521
return (v, -v)
15241522
})
15251523
}
15261524

15271525
@inlinable
15281526
static func _vjpNegate(rhs: Self)
1529-
-> (Self, (TangentVector) -> (TangentVector)) {
1527+
-> (Self, (TangentVector) -> (TangentVector)) {
15301528
return (-rhs, { v in
15311529
return -v
15321530
})
@@ -1535,22 +1533,20 @@ extension SIMD
15351533

15361534
extension SIMD
15371535
where Self : Differentiable,
1538-
TangentVector: SIMD,
1536+
TangentVector : SIMD,
15391537
Scalar : BinaryFloatingPoint,
15401538
Self.TangentVector == Self {
15411539
@inlinable
1542-
static func _vjpMultiply(
1543-
lhs: Self, rhs: Self
1544-
) -> (Self, (TangentVector) -> (TangentVector, TangentVector)) {
1540+
static func _vjpMultiply(lhs: Self, rhs: Self)
1541+
-> (Self, (TangentVector) -> (TangentVector, TangentVector)) {
15451542
return (lhs * rhs, { v in
15461543
return (v * rhs, v * lhs)
15471544
})
15481545
}
15491546

15501547
@inlinable
1551-
static func _vjpDivide(
1552-
lhs: Self, rhs: Self
1553-
) -> (Self, (TangentVector) -> (TangentVector, TangentVector)) {
1548+
static func _vjpDivide(lhs: Self, rhs: Self)
1549+
-> (Self, (TangentVector) -> (TangentVector, TangentVector)) {
15541550
return (lhs / rhs, { v in
15551551
(v / rhs, -lhs / (rhs * rhs) * v)
15561552
})
@@ -1561,39 +1557,35 @@ extension SIMD
15611557
where Self : Differentiable,
15621558
TangentVector : SIMD,
15631559
Scalar : BinaryFloatingPoint & Differentiable,
1564-
Scalar.TangentVector: BinaryFloatingPoint,
1560+
Scalar.TangentVector : BinaryFloatingPoint,
15651561
TangentVector.Scalar == Scalar.TangentVector {
15661562
@inlinable
1567-
static func _vjpAdd(
1568-
lhs: Scalar, rhs: Self
1569-
) -> (Self, (TangentVector) -> (Scalar.TangentVector, TangentVector)) {
1563+
static func _vjpAdd(lhs: Scalar, rhs: Self)
1564+
-> (Self, (TangentVector) -> (Scalar.TangentVector, TangentVector)) {
15701565
return (lhs + rhs, { v in
15711566
return (v.sum(), v)
15721567
})
15731568
}
15741569

15751570
@inlinable
1576-
static func _vjpSubtract(
1577-
lhs: Scalar, rhs: Self
1578-
) -> (Self, (TangentVector) -> (Scalar.TangentVector, TangentVector)) {
1571+
static func _vjpSubtract(lhs: Scalar, rhs: Self)
1572+
-> (Self, (TangentVector) -> (Scalar.TangentVector, TangentVector)) {
15791573
return (lhs - rhs, { v in
15801574
return (v.sum(), -v)
15811575
})
15821576
}
15831577

15841578
@inlinable
1585-
static func _vjpAdd(
1586-
lhs: Self, rhs: Scalar
1587-
) -> (Self, (TangentVector) -> (TangentVector, Scalar.TangentVector)) {
1579+
static func _vjpAdd(lhs: Self, rhs: Scalar)
1580+
-> (Self, (TangentVector) -> (TangentVector, Scalar.TangentVector)) {
15881581
return (lhs + rhs, { v in
15891582
return (v, v.sum())
15901583
})
15911584
}
15921585

15931586
@inlinable
1594-
static func _vjpSubtract(
1595-
lhs: Self, rhs: Scalar
1596-
) -> (Self, (TangentVector) -> (TangentVector, Scalar.TangentVector)) {
1587+
static func _vjpSubtract(lhs: Self, rhs: Scalar)
1588+
-> (Self, (TangentVector) -> (TangentVector, Scalar.TangentVector)) {
15971589
return (lhs - rhs, { v in
15981590
return (v, -v.sum())
15991591
})
@@ -1607,36 +1599,32 @@ extension SIMD
16071599
Self.TangentVector == Self,
16081600
Scalar.TangentVector == Scalar {
16091601
@inlinable
1610-
static func _vjpMultiply(
1611-
lhs: Self, rhs: Scalar
1612-
) -> (Self, (TangentVector) -> (TangentVector, Scalar.TangentVector)) {
1602+
static func _vjpMultiply(lhs: Self, rhs: Scalar)
1603+
-> (Self, (TangentVector) -> (TangentVector, Scalar.TangentVector)) {
16131604
return (lhs * rhs, { v in
16141605
return (v * rhs, (v * lhs).sum())
16151606
})
16161607
}
16171608

16181609
@inlinable
1619-
static func _vjpDivide(
1620-
lhs: Self, rhs: Scalar
1621-
) -> (Self, (TangentVector) -> (TangentVector, Scalar.TangentVector)) {
1610+
static func _vjpDivide(lhs: Self, rhs: Scalar)
1611+
-> (Self, (TangentVector) -> (TangentVector, Scalar.TangentVector)) {
16221612
return (lhs / rhs, { v in
16231613
(v / rhs, (-lhs / (rhs * rhs) * v).sum())
16241614
})
16251615
}
16261616

16271617
@inlinable
1628-
static func _vjpMultiply(
1629-
lhs: Scalar, rhs: Self
1630-
) -> (Self, (TangentVector) -> (Scalar.TangentVector, TangentVector)) {
1618+
static func _vjpMultiply(lhs: Scalar, rhs: Self)
1619+
-> (Self, (TangentVector) -> (Scalar.TangentVector, TangentVector)) {
16311620
return (lhs * rhs, { v in
16321621
return ((v * rhs).sum(), v * lhs)
16331622
})
16341623
}
16351624

16361625
@inlinable
1637-
static func _vjpDivide(
1638-
lhs: Scalar, rhs: Self
1639-
) -> (Self, (TangentVector) -> (Scalar.TangentVector, TangentVector)) {
1626+
static func _vjpDivide(lhs: Scalar, rhs: Self)
1627+
-> (Self, (TangentVector) -> (Scalar.TangentVector, TangentVector)) {
16401628
return (lhs / rhs, { v in
16411629
((v / rhs).sum(), -lhs / (rhs * rhs) * v)
16421630
})
@@ -1662,8 +1650,8 @@ extension SIMD
16621650
Self.TangentVector == Self,
16631651
Scalar.TangentVector == Scalar {
16641652
@usableFromInline
1665-
static func _vjpInit(repeating value: Scalar) ->
1666-
(Self, (TangentVector) -> Scalar.TangentVector) {
1653+
static func _vjpInit(repeating value: Scalar)
1654+
-> (Self, (TangentVector) -> Scalar.TangentVector) {
16671655
return (Self(repeating: value), { v in v.sum() })
16681656
}
16691657
}

stdlib/public/core/SIMDVectorTypes.swift.gyb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ extension SIMD${n} : AdditiveArithmetic where Scalar : FloatingPoint {}
194194

195195
extension SIMD${n} : Differentiable
196196
where Scalar : Differentiable & BinaryFloatingPoint,
197-
Scalar.TangentVector: BinaryFloatingPoint {
197+
Scalar.TangentVector : BinaryFloatingPoint {
198198
public typealias TangentVector = SIMD${n}
199199
public typealias AllDifferentiableVariables = SIMD${n}
200200
public func tangentVector(from cotangent: TangentVector) -> TangentVector {

0 commit comments

Comments
 (0)