Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 8cf1a66

Browse files
committed
Address review comments.
1 parent 51fe3e9 commit 8cf1a66

File tree

2 files changed

+4
-10
lines changed

2 files changed

+4
-10
lines changed

Sources/TensorFlow/Initializers.swift

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,7 @@ public extension Tensor where Scalar: BinaryFloatingPoint,
449449
}
450450
}
451451

452-
fileprivate extension Tensor where Scalar: BinaryFloatingPoint {
452+
fileprivate extension Tensor where Scalar: TensorFlowFloatingPoint {
453453
private static func glorot(
454454
fromStandardUniform randomUniform: __shared Tensor<Scalar>,
455455
shape: __shared TensorShape
@@ -459,12 +459,7 @@ fileprivate extension Tensor where Scalar: BinaryFloatingPoint {
459459
let fanIn = shape[shape.count - 2] * receptiveField
460460
let fanOut = shape[shape.count - 1] * receptiveField
461461
let minusOneToOne = 2 * randomUniform - 1
462-
#if os(macOS) || os(iOS) || os(watchOS) || os(tvOS)
463-
let _sqrt = Darwin.sqrt as (Scalar) -> Scalar
464-
#else
465-
let _sqrt = Glibc.sqrt as (Scalar) -> Scalar
466-
#endif
467-
return _sqrt(Scalar(6) / Scalar(fanIn + fanOut)) * minusOneToOne
462+
return Scalar.sqrt(Scalar(6) / Scalar(fanIn + fanOut)) * minusOneToOne
468463
}
469464
}
470465

@@ -488,8 +483,7 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
488483
}
489484
}
490485

491-
public extension Tensor where Scalar: BinaryFloatingPoint,
492-
Scalar.RawSignificand: FixedWidthInteger {
486+
public extension Tensor where Scalar: TensorFlowFloatingPoint {
493487
/// Performs Glorot uniform initialization for the specified shape, creating a tensor by
494488
/// randomly sampling scalar values from a uniform distribution between `-limit` and `limit`,
495489
/// where limit is `sqrt(6 / (fanIn + fanOut))` and `fanIn`/`fanOut` represent the number of

Sources/TensorFlow/Operators/Math.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ extension Tensor: ElementaryFunctions where Scalar: TensorFlowFloatingPoint {
132132
TensorFlow.log1p(x)
133133
}
134134

135-
/// `x**y` interpreted as `exp(y * log(x))`
135+
/// `exp(y log(x))` computed without loss of intermediate precision.
136136
///
137137
/// For real types, if `x` is negative the result is NaN, even if `y` has
138138
/// an integral value. For complex types, there is a branch cut on the

0 commit comments

Comments
 (0)