Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 0c1ff41

Browse files
committed
Address review comments.
1 parent 51fe3e9 commit 0c1ff41

File tree

1 file changed

+3
-9
lines changed

1 file changed

+3
-9
lines changed

Sources/TensorFlow/Initializers.swift

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,7 @@ public extension Tensor where Scalar: BinaryFloatingPoint,
449449
}
450450
}
451451

452-
fileprivate extension Tensor where Scalar: BinaryFloatingPoint {
452+
fileprivate extension Tensor where Scalar: TensorFlowFloatingPoint {
453453
private static func glorot(
454454
fromStandardUniform randomUniform: __shared Tensor<Scalar>,
455455
shape: __shared TensorShape
@@ -459,12 +459,7 @@ fileprivate extension Tensor where Scalar: BinaryFloatingPoint {
459459
let fanIn = shape[shape.count - 2] * receptiveField
460460
let fanOut = shape[shape.count - 1] * receptiveField
461461
let minusOneToOne = 2 * randomUniform - 1
462-
#if os(macOS) || os(iOS) || os(watchOS) || os(tvOS)
463-
let _sqrt = Darwin.sqrt as (Scalar) -> Scalar
464-
#else
465-
let _sqrt = Glibc.sqrt as (Scalar) -> Scalar
466-
#endif
467-
return _sqrt(Scalar(6) / Scalar(fanIn + fanOut)) * minusOneToOne
462+
return Scalar.sqrt(Scalar(6) / Scalar(fanIn + fanOut)) * minusOneToOne
468463
}
469464
}
470465

@@ -488,8 +483,7 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint {
488483
}
489484
}
490485

491-
public extension Tensor where Scalar: BinaryFloatingPoint,
492-
Scalar.RawSignificand: FixedWidthInteger {
486+
public extension Tensor where Scalar: TensorFlowFloatingPoint {
493487
/// Performs Glorot uniform initialization for the specified shape, creating a tensor by
494488
/// randomly sampling scalar values from a uniform distribution between `-limit` and `limit`,
495489
/// where limit is `sqrt(6 / (fanIn + fanOut))` and `fanIn`/`fanOut` represent the number of

0 commit comments

Comments
 (0)