Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 473fbff

Browse files
committed
Make library resilient.
Mark layer structs as `@_fixed_layout`. Gate TensorFlow module import on `COMPILING_TENSORFLOW_MODULE` flag. This will be used during stdlib compilation to prevent dubious import.
1 parent a27f7cf commit 473fbff

File tree

4 files changed

+10
-0
lines changed

4 files changed

+10
-0
lines changed

Sources/DeepLearning/Helpers.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
#if !COMPILING_TENSORFLOW_MODULE
1516
import TensorFlow
17+
#endif
1618

1719
// `pow` is defined in Darwin/Glibc on `Float` and `Double`, but there doesn't exist a generic
1820
// version for `FloatingPoint`.

Sources/DeepLearning/Layer.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
#if !COMPILING_TENSORFLOW_MODULE
1516
@_exported import TensorFlow
17+
#endif
1618

1719
/// A neural network layer.
1820
///
@@ -44,6 +46,7 @@ public extension Layer {
4446
}
4547
}
4648

49+
@_fixed_layout
4750
public struct Dense<Scalar>: Layer
4851
where Scalar : FloatingPoint & Differentiable & TensorFlowScalar {
4952
public var weight: Tensor<Scalar>
@@ -63,6 +66,7 @@ public extension Dense where Scalar : BinaryFloatingPoint,
6366
}
6467
}
6568

69+
@_fixed_layout
6670
public struct Conv2D<Scalar>: Layer
6771
where Scalar : FloatingPoint & Differentiable & TensorFlowScalar {
6872
public var filter: Tensor<Scalar>

Sources/DeepLearning/Loss.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
#if !COMPILING_TENSORFLOW_MODULE
1516
import TensorFlow
17+
#endif
1618

1719
@differentiable(vjp: _vjpMSE)
1820
public func meanSquaredError<Scalar: FloatingPoint>(predicted: Tensor<Scalar>,

Sources/DeepLearning/Optimizer.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
#if !COMPILING_TENSORFLOW_MODULE
1516
import TensorFlow
17+
#endif
1618

1719
public protocol Optimizer {
1820
associatedtype Model: Layer

0 commit comments

Comments
 (0)