
[AutoDiff] Replace SIMD use of @differentiable(jvp:vjp:) with @derivative(of:) #28930


Merged · 9 commits · Jan 4, 2020
85 changes: 50 additions & 35 deletions stdlib/public/core/SIMDVector.swift
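The substance of the change: every SIMD derivative that was previously registered by naming its VJP inside `@differentiable(vjp:)` on the original declaration is now registered the other way around, with a `@derivative(of:)` attribute on the derivative function itself, whose result must be spelled as a labeled `(value:, pullback:)` tuple. Below is a minimal sketch of the two registration styles, using a hypothetical `square` function rather than code from this diff (toolchain details vary; newer toolchains also need `import _Differentiation` and spell the attribute `@differentiable(reverse)`):

```swift
// Hypothetical example, not taken from SIMDVector.swift.

// Old style (removed by this PR): the original declaration names its VJP.
// @differentiable(vjp: _vjpSquare)
// func square(_ x: Float) -> Float { x * x }

// New style: the original declaration carries a plain `@differentiable`,
// and the pullback is registered separately on the derivative function.
@differentiable
func square(_ x: Float) -> Float { x * x }

@derivative(of: square)
func _vjpSquare(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
  // `value` is the original result; the pullback maps an incoming cotangent
  // to the cotangent of `x` (d/dx of x*x is 2x).
  (x * x, { v in 2 * x * v })
}
```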
@@ -91,7 +91,7 @@ extension SIMD {
/// A vector with the specified value in all lanes.
@_transparent
// SWIFT_ENABLE_TENSORFLOW
@differentiable(vjp: _vjpInit(repeating:)
@differentiable(
where Self : Differentiable,
Self.TangentVector : SIMD,
Scalar : BinaryFloatingPoint & Differentiable,
@@ -789,7 +789,7 @@ extension SIMD where Scalar: FixedWidthInteger {
extension SIMD where Scalar : FloatingPoint {
@_transparent
// SWIFT_ENABLE_TENSORFLOW
@differentiable(vjp: _vjpAdd(lhs:rhs:)
@differentiable(
where Self : Differentiable,
Self.TangentVector : SIMD,
Scalar : BinaryFloatingPoint,
@@ -802,7 +802,7 @@ extension SIMD where Scalar : FloatingPoint {

@_transparent
// SWIFT_ENABLE_TENSORFLOW
@differentiable(vjp: _vjpSubtract(lhs:rhs:)
@differentiable(
where Self : Differentiable,
Self.TangentVector : SIMD,
Scalar : BinaryFloatingPoint,
@@ -815,7 +815,7 @@ extension SIMD where Scalar : FloatingPoint {

@_transparent
// SWIFT_ENABLE_TENSORFLOW
@differentiable(vjp: _vjpMultiply(lhs:rhs:)
@differentiable(
where Self : Differentiable,
Self.TangentVector : SIMD,
Scalar : BinaryFloatingPoint,
@@ -828,11 +828,11 @@ extension SIMD where Scalar : FloatingPoint {

@_transparent
// SWIFT_ENABLE_TENSORFLOW
@differentiable(vjp: _vjpDivide(lhs:rhs:)
where Self : Differentiable,
Self.TangentVector : SIMD,
Scalar : BinaryFloatingPoint,
Self.TangentVector == Self)
@differentiable(
where Self : Differentiable,
Self.TangentVector : SIMD,
Scalar : BinaryFloatingPoint,
Self.TangentVector == Self)
public static func /(lhs: Self, rhs: Self) -> Self {
var result = Self()
for i in result.indices { result[i] = lhs[i] / rhs[i] }
@@ -877,7 +877,7 @@ extension SIMD where Scalar : FloatingPoint {
// FIXME: TF-545 we want the sum() func to be marked as
// `@_alwaysEmitIntoClient` like before when we define the VJP
@inlinable
@differentiable(vjp: _vjpSum
@differentiable(
where Self : Differentiable,
Self.TangentVector : SIMD,
Scalar : BinaryFloatingPoint & Differentiable,
@@ -1198,7 +1198,7 @@ extension SIMD where Scalar: FloatingPoint {

@_transparent
// SWIFT_ENABLE_TENSORFLOW
@differentiable(vjp: _vjpNegate(rhs:)
@differentiable(
where Self : Differentiable,
Self.TangentVector : SIMD,
Scalar : BinaryFloatingPoint,
@@ -1209,7 +1209,7 @@ extension SIMD where Scalar: FloatingPoint {

@_transparent
// SWIFT_ENABLE_TENSORFLOW
@differentiable(vjp: _vjpAdd(lhs:rhs:)
@differentiable(
where Self : Differentiable,
Self.TangentVector : SIMD,
Scalar : Differentiable & BinaryFloatingPoint,
@@ -1221,7 +1221,7 @@ extension SIMD where Scalar: FloatingPoint {

@_transparent
// SWIFT_ENABLE_TENSORFLOW
@differentiable(vjp: _vjpSubtract(lhs:rhs:)
@differentiable(
where Self : Differentiable,
Self.TangentVector : SIMD,
Scalar : Differentiable & BinaryFloatingPoint,
@@ -1233,7 +1233,7 @@ extension SIMD where Scalar: FloatingPoint {

@_transparent
// SWIFT_ENABLE_TENSORFLOW
@differentiable(vjp: _vjpMultiply(lhs:rhs:)
@differentiable(
where Self : Differentiable,
Self.TangentVector : SIMD,
Scalar : BinaryFloatingPoint & Differentiable,
@@ -1245,7 +1245,7 @@ extension SIMD where Scalar: FloatingPoint {

@_transparent
// SWIFT_ENABLE_TENSORFLOW
@differentiable(vjp: _vjpDivide(lhs:rhs:)
@differentiable(
where Self : Differentiable,
Self.TangentVector : SIMD,
Scalar : BinaryFloatingPoint & Differentiable,
@@ -1257,7 +1257,7 @@ extension SIMD where Scalar: FloatingPoint {

@_transparent
// SWIFT_ENABLE_TENSORFLOW
@differentiable(vjp: _vjpAdd(lhs:rhs:)
@differentiable(
where Self : Differentiable,
Self.TangentVector : SIMD,
Scalar : Differentiable & BinaryFloatingPoint,
@@ -1269,7 +1269,7 @@ extension SIMD where Scalar: FloatingPoint {

@_transparent
// SWIFT_ENABLE_TENSORFLOW
@differentiable(vjp: _vjpSubtract(lhs:rhs:)
@differentiable(
where Self : Differentiable,
Self.TangentVector : SIMD,
Scalar : Differentiable & BinaryFloatingPoint,
@@ -1281,7 +1281,7 @@ extension SIMD where Scalar: FloatingPoint {

@_transparent
// SWIFT_ENABLE_TENSORFLOW
@differentiable(vjp: _vjpMultiply(lhs:rhs:)
@differentiable(
where Self : Differentiable,
Self.TangentVector : SIMD,
Scalar : BinaryFloatingPoint & Differentiable,
@@ -1293,7 +1293,7 @@ extension SIMD where Scalar: FloatingPoint {

@_transparent
// SWIFT_ENABLE_TENSORFLOW
@differentiable(vjp: _vjpDivide(lhs:rhs:)
@differentiable(
where Self : Differentiable,
Self.TangentVector : SIMD,
Scalar : BinaryFloatingPoint & Differentiable,
@@ -1520,24 +1520,27 @@ extension SIMD
Scalar : BinaryFloatingPoint,
TangentVector.Scalar : BinaryFloatingPoint {
@inlinable
@derivative(of: +)
static func _vjpAdd(lhs: Self, rhs: Self)
-> (Self, (TangentVector) -> (TangentVector, TangentVector)) {
-> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
return (lhs + rhs, { v in
return (v, v)
})
}

@inlinable
@derivative(of: -)
static func _vjpSubtract(lhs: Self, rhs: Self)
-> (Self, (TangentVector) -> (TangentVector, TangentVector)) {
-> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
return (lhs - rhs, { v in
return (v, -v)
})
}

@inlinable
@derivative(of: -)
static func _vjpNegate(rhs: Self)
-> (Self, (TangentVector) -> (TangentVector)) {
-> (value: Self, pullback: (TangentVector) -> (TangentVector)) {
return (-rhs, { v in
return -v
})
@@ -1550,16 +1553,18 @@ extension SIMD
Scalar : BinaryFloatingPoint,
Self.TangentVector == Self {
@inlinable
@derivative(of: *)
static func _vjpMultiply(lhs: Self, rhs: Self)
-> (Self, (TangentVector) -> (TangentVector, TangentVector)) {
-> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
return (lhs * rhs, { v in
return (v * rhs, v * lhs)
})
}

@inlinable
@derivative(of: /)
static func _vjpDivide(lhs: Self, rhs: Self)
-> (Self, (TangentVector) -> (TangentVector, TangentVector)) {
-> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
return (lhs / rhs, { v in
(v / rhs, -lhs / (rhs * rhs) * v)
})
@@ -1573,32 +1578,36 @@ extension SIMD
Scalar.TangentVector : BinaryFloatingPoint,
TangentVector.Scalar == Scalar.TangentVector {
@inlinable
@derivative(of: +)
static func _vjpAdd(lhs: Scalar, rhs: Self)
-> (Self, (TangentVector) -> (Scalar.TangentVector, TangentVector)) {
-> (value: Self, pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector)) {
return (lhs + rhs, { v in
return (v.sum(), v)
})
}

@inlinable
@derivative(of: -)
static func _vjpSubtract(lhs: Scalar, rhs: Self)
-> (Self, (TangentVector) -> (Scalar.TangentVector, TangentVector)) {
-> (value: Self, pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector)) {
return (lhs - rhs, { v in
return (v.sum(), -v)
})
}

@inlinable
@derivative(of: +)
static func _vjpAdd(lhs: Self, rhs: Scalar)
-> (Self, (TangentVector) -> (TangentVector, Scalar.TangentVector)) {
-> (value: Self, pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector)) {
return (lhs + rhs, { v in
return (v, v.sum())
})
}

@inlinable
@derivative(of: -)
static func _vjpSubtract(lhs: Self, rhs: Scalar)
-> (Self, (TangentVector) -> (TangentVector, Scalar.TangentVector)) {
-> (value: Self, pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector)) {
return (lhs - rhs, { v in
return (v, -v.sum())
})
@@ -1612,32 +1621,36 @@ extension SIMD
Self.TangentVector == Self,
Scalar.TangentVector == Scalar {
@inlinable
@derivative(of: *)
static func _vjpMultiply(lhs: Self, rhs: Scalar)
-> (Self, (TangentVector) -> (TangentVector, Scalar.TangentVector)) {
-> (value: Self, pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector)) {
return (lhs * rhs, { v in
return (v * rhs, (v * lhs).sum())
})
}

@inlinable
@derivative(of: /)
static func _vjpDivide(lhs: Self, rhs: Scalar)
-> (Self, (TangentVector) -> (TangentVector, Scalar.TangentVector)) {
-> (value: Self, pullback: (TangentVector) -> (TangentVector, Scalar.TangentVector)) {
return (lhs / rhs, { v in
(v / rhs, (-lhs / (rhs * rhs) * v).sum())
})
}

@inlinable
@derivative(of: *)
static func _vjpMultiply(lhs: Scalar, rhs: Self)
-> (Self, (TangentVector) -> (Scalar.TangentVector, TangentVector)) {
-> (value: Self, pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector)) {
return (lhs * rhs, { v in
return ((v * rhs).sum(), v * lhs)
})
}

@inlinable
@derivative(of: /)
static func _vjpDivide(lhs: Scalar, rhs: Self)
-> (Self, (TangentVector) -> (Scalar.TangentVector, TangentVector)) {
-> (value: Self, pullback: (TangentVector) -> (Scalar.TangentVector, TangentVector)) {
return (lhs / rhs, { v in
((v / rhs).sum(), -lhs / (rhs * rhs) * v)
})
@@ -1651,7 +1664,8 @@ extension SIMD
Scalar.TangentVector : BinaryFloatingPoint,
TangentVector == Self {
@inlinable
func _vjpSum() -> (Scalar, (Scalar.TangentVector) -> TangentVector) {
@derivative(of: sum)
func _vjpSum() -> (value: Scalar, pullback: (Scalar.TangentVector) -> TangentVector) {
return (sum(), { v in Self(repeating: Scalar(v)) })
}
}
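Note on the `sum()` pullback just above: each lane contributes to the sum with weight 1, so the pullback broadcasts the incoming scalar cotangent to every lane via `Self(repeating: Scalar(v))`. A hedged spot-check, assuming a toolchain that carries this branch's SIMD `Differentiable` conformances:

```swift
// Assumes the SIMD Differentiable conformances provided on this branch;
// newer toolchains also need `import _Differentiation`.
let g = gradient(at: SIMD4<Float>(1, 2, 3, 4)) { x in x.sum() }
// Each lane's partial derivative of the sum is 1, so g == SIMD4<Float>(1, 1, 1, 1).
```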
@@ -1662,9 +1676,10 @@ extension SIMD
Scalar : BinaryFloatingPoint & Differentiable,
Self.TangentVector == Self,
Scalar.TangentVector == Scalar {
@usableFromInline
@inlinable
@derivative(of: init(repeating:))
static func _vjpInit(repeating value: Scalar)
-> (Self, (TangentVector) -> Scalar.TangentVector) {
-> (value: Self, pullback: (TangentVector) -> Scalar.TangentVector) {
return (Self(repeating: value), { v in v.sum() })
}
}
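
An end-to-end sketch of the re-registered derivatives working together (again a hedged example, assuming this branch's toolchain rather than code from the PR): differentiating `(x * x).sum()` composes the `*` and `sum()` pullbacks above and should produce `2 * x`.

```swift
// Hedged usage sketch; names are illustrative, not part of this PR.
@differentiable
func sumOfSquares(_ x: SIMD4<Float>) -> Float {
  (x * x).sum()
}

let x = SIMD4<Float>(1, 2, 3, 4)
let grad = gradient(at: x, in: sumOfSquares)
// Expected: 2 * x, i.e. SIMD4<Float>(2, 4, 6, 8).
```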