Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 472b29f

Browse files
Shashi456rxwei
authored andcommitted
Add function/lambda layer (#298)
1 parent a41903b commit 472b29f

File tree

2 files changed

+30
-5
lines changed

2 files changed

+30
-5
lines changed

Sources/TensorFlow/Layers/Core.swift

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,3 +223,19 @@ public extension Dense {
223223
activation: activation)
224224
}
225225
}
226+
227+
/// A layer that encloses a custom differentiable function.
228+
public struct Function<Input: Differentiable, Output: Differentiable>: Layer {
229+
public typealias Body = @differentiable (Input) -> Output
230+
231+
@noDerivative public let body: Body
232+
233+
public init(_ body: @escaping Body) {
234+
self.body = body
235+
}
236+
237+
@differentiable
238+
public func callAsFunction(_ input: Input) -> Output {
239+
return body(input)
240+
}
241+
}

Tests/TensorFlowTests/LayerTests.swift

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ final class LayerTests: XCTestCase {
3838
// Input shapes.
3939
let inputHeight = 2
4040
let inputWidth = 5
41-
41+
4242
let filter = Tensor<Float>(shape: [width, inputChannels, outputChannels],
4343
scalars: [2, 3, 4, 1, 2, 3])
4444
let bias = Tensor<Float>([0])
@@ -256,14 +256,14 @@ final class LayerTests: XCTestCase {
256256
XCTAssertEqual(output.shape, expected)
257257
}
258258

259-
func testEmbedding() {
260-
var layer = Embedding<Float>(vocabularySize: 3, embeddingSize: 5)
259+
func testEmbedding() {
260+
var layer = Embedding<Float>(vocabularySize: 3, embeddingSize: 5)
261261
var data = Tensor<Int32>(shape: [2, 3], scalars: [0, 1, 2, 1, 2, 2])
262262
var input = EmbeddingInput(indices: data)
263263
var output = layer.inferring(from: input)
264264
let expectedShape = TensorShape([2, 3, 5])
265265
XCTAssertEqual(output.shape, expectedShape)
266-
266+
267267
let pretrained = Tensor<Float>(shape:[2, 2], scalars: [0.4, 0.3, 0.2, 0.1])
268268
layer = Embedding<Float>(embeddings: pretrained)
269269
data = Tensor<Int32>(shape: [2, 2], scalars: [0, 1, 1, 1])
@@ -318,6 +318,14 @@ final class LayerTests: XCTestCase {
318318
// XCTAssertEqual(𝛁rnn.cell.bias, [ 0.2496884, 0.66947335, 0.7978788, -0.22378457])
319319
}
320320

321+
func testFunction() {
322+
let tanhLayer = Function<Tensor<Float>, Tensor<Float>>(tanh)
323+
let input = Tensor(shape: [5, 1], scalars: (0..<5).map(Float.init))
324+
let output = tanhLayer.inferring(from: input)
325+
let expected = Tensor<Float>([[0.0], [0.7615942], [0.9640276], [0.9950547], [0.9993292]])
326+
XCTAssertEqual(output, expected)
327+
}
328+
321329
static var allTests = [
322330
("testConv1D", testConv1D),
323331
("testConv1DDilation", testConv1DDilation),
@@ -344,6 +352,7 @@ final class LayerTests: XCTestCase {
344352
("testFlatten", testFlatten),
345353
("testEmbedding", testEmbedding),
346354
("testSimpleRNNCell", testSimpleRNNCell),
347-
("testRNN", testRNN)
355+
("testRNN", testRNN),
356+
("testFunction", testFunction)
348357
]
349358
}

0 commit comments

Comments
 (0)