Skip to content

Commit 1261704

Browse files
mikowalsrxwei
authored andcommitted
correct _vjpRsqrt to match Raw.rsqrtGrad (#24888)
correct _vjpRsqrt to match Raw.rsqrtGrad I got the updated derivative from here - https://www.derivative-calculator.net - and this matches the result of Raw.rsqrtGrad. I have just created the pull request by clicking the edit button in GitHub so I can not confirm that it builds or that tests pass.
1 parent 8472823 commit 1261704

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

stdlib/public/TensorFlow/Gradients.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ func _vjpRsqrt<T : TensorFlowFloatingPoint>(
455455
_ x: Tensor<T>
456456
) -> (Tensor<T>, (Tensor<T>) -> Tensor<T>) {
457457
let value = rsqrt(x)
458-
return (value, { v in -v / 2 * value })
458+
return (value, { v in -v / (2 * pow(x, 3 / 2))})
459459
}
460460

461461
@inlinable

0 commit comments

Comments
 (0)