Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit cf9ee4f

Browse files
jon-towsaeta
authored andcommitted
Add @differentiable support to pow(_:_:) and root(_:_:) (#366)
1 parent eb996cd commit cf9ee4f

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

Sources/TensorFlow/Operators/Math.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1221,28 +1221,28 @@ internal func _vjpPow<T: TensorFlowFloatingPoint>(
12211221

12221222
/// Returns the power of the scalar to the tensor, broadcasting the scalar.
12231223
@inlinable
1224-
// @differentiable
1224+
@differentiable(wrt: rhs where T: TensorFlowFloatingPoint)
12251225
public func pow<T: TensorFlowFloatingPoint>(_ lhs: T, _ rhs: Tensor<T>) -> Tensor<T> {
12261226
pow(Tensor(lhs), rhs)
12271227
}
12281228

12291229
/// Returns the power of the tensor to the scalar, broadcasting the scalar.
12301230
@inlinable
1231-
// @differentiable
1231+
@differentiable(wrt: lhs where T: TensorFlowFloatingPoint)
12321232
public func pow<T: TensorFlowFloatingPoint>(_ lhs: Tensor<T>, _ rhs: T) -> Tensor<T> {
12331233
pow(lhs, Tensor(rhs))
12341234
}
12351235

12361236
/// Returns the power of the tensor to the scalar, broadcasting the scalar.
12371237
@inlinable
1238-
// @differentiable
1238+
@differentiable
12391239
public func pow<T: TensorFlowFloatingPoint>(_ x: Tensor<T>, _ n: Int) -> Tensor<T> {
12401240
pow(x, Tensor(T(n)))
12411241
}
12421242

12431243
/// Returns the element-wise `n`th root of the tensor.
12441244
@inlinable
1245-
// @differentiable
1245+
@differentiable
12461246
public func root<T: TensorFlowFloatingPoint>(_ x: Tensor<T>, _ n: Int) -> Tensor<T> {
12471247
sign(x) * pow(abs(x), Tensor(T(1) / T(n)))
12481248
}

0 commit comments

Comments
 (0)