Skip to content

Commit ffdb4c4

Browse files
committed
Get sum() differentiable.
1 parent 99b8408 commit ffdb4c4

File tree

3 files changed

+31
-28
lines changed

3 files changed

+31
-28
lines changed

lib/Sema/TypeCheckAttr.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2628,7 +2628,8 @@ static FuncDecl *resolveAutoDiffAssociatedFunction(
26282628
auto isABIPublic = [&](AbstractFunctionDecl *func) {
26292629
return func->getFormalAccess() >= AccessLevel::Public ||
26302630
func->getAttrs().hasAttribute<InlinableAttr>() ||
2631-
func->getAttrs().hasAttribute<UsableFromInlineAttr>();
2631+
func->getAttrs().hasAttribute<UsableFromInlineAttr>() ||
2632+
func->getAttrs().hasAttribute<AlwaysEmitIntoClientAttr>();
26322633
};
26332634

26342635
// If the original function is exported (i.e. it is public or

stdlib/public/core/SIMDVector.swift

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -873,14 +873,16 @@ extension SIMD where Scalar : FloatingPoint {
873873
}
874874

875875
/// Returns the sum of the scalars in the vector.
876-
@_alwaysEmitIntoClient
877876
// SWIFT_ENABLE_TENSORFLOW
878-
// @differentiable(vjp: _vjpSum
879-
// where Self : Differentiable,
880-
// Self.TangentVector : SIMD,
881-
// Scalar : BinaryFloatingPoint & Differentiable,
882-
// Scalar.TangentVector : BinaryFloatingPoint,
883-
// Self.TangentVector == Self)
877+
@inlinable
878+
// FIXME: TF-545 we want the sum() func to be marked as
879+
// `@_alwaysEmitIntoClient` like before when we define the VJP
880+
@differentiable(vjp: _vjpSum
881+
where Self : Differentiable,
882+
Self.TangentVector : SIMD,
883+
Scalar : BinaryFloatingPoint & Differentiable,
884+
Scalar.TangentVector : BinaryFloatingPoint,
885+
Self.TangentVector == Self)
884886
public func sum() -> Scalar {
885887
// Implementation note: this eventually be defined to lower to either
886888
// llvm.experimental.vector.reduce.fadd or an explicit tree-sum. Open-
@@ -1650,17 +1652,17 @@ extension SIMD
16501652
}
16511653
}
16521654

1653-
//extension SIMD
1654-
// where Self : Differentiable,
1655-
// TangentVector : SIMD,
1656-
// Scalar : BinaryFloatingPoint & Differentiable,
1657-
// Scalar.TangentVector : BinaryFloatingPoint,
1658-
// TangentVector == Self {
1659-
// @usableFromInline
1660-
// func _vjpSum() -> (Scalar, (Scalar.TangentVector) -> TangentVector) {
1661-
// return (sum(), { v in Self(repeating: Scalar(v)) })
1662-
// }
1663-
//}
1655+
extension SIMD
1656+
where Self : Differentiable,
1657+
TangentVector : SIMD,
1658+
Scalar : BinaryFloatingPoint & Differentiable,
1659+
Scalar.TangentVector : BinaryFloatingPoint,
1660+
TangentVector == Self {
1661+
@inlinable
1662+
func _vjpSum() -> (Scalar, (Scalar.TangentVector) -> TangentVector) {
1663+
return (sum(), { v in Self(repeating: Scalar(v)) })
1664+
}
1665+
}
16641666

16651667
extension SIMD
16661668
where Self : Differentiable,

test/AutoDiff/SIMD.swift

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,15 @@ SIMDTests.test("Negate") {
3030
expectEqual(-a, bp1(a))
3131
}
3232

33-
//SIMDTests.test("Sum") {
34-
// let a = SIMD4<Float>(1, 2, 3, 4)
35-
//
36-
// let foo1 = { (x: SIMD4<Float>) -> Float in
37-
// return x.sum()
38-
// }
39-
// let bp1 = pullback(at: a, in: foo1)
40-
// expectEqual(SIMD4<Float>(3, 3, 3, 3), bp1(3))
41-
//}
33+
SIMDTests.test("Sum") {
34+
let a = SIMD4<Float>(1, 2, 3, 4)
35+
36+
let foo1 = { (x: SIMD4<Float>) -> Float in
37+
return x.sum()
38+
}
39+
let bp1 = pullback(at: a, in: foo1)
40+
expectEqual(SIMD4<Float>(3, 3, 3, 3), bp1(3))
41+
}
4242

4343
SIMDTests.test("Addition") {
4444
let a = SIMD4<Float>(1, 2, 3, 4)

0 commit comments

Comments
 (0)