@@ -38,7 +38,7 @@ final class LayerTests: XCTestCase {
38
38
// Input shapes.
39
39
let inputHeight = 2
40
40
let inputWidth = 5
41
-
41
+
42
42
let filter = Tensor < Float > ( shape: [ width, inputChannels, outputChannels] ,
43
43
scalars: [ 2 , 3 , 4 , 1 , 2 , 3 ] )
44
44
let bias = Tensor < Float > ( [ 0 ] )
@@ -256,14 +256,14 @@ final class LayerTests: XCTestCase {
256
256
XCTAssertEqual ( output. shape, expected)
257
257
}
258
258
259
- func testEmbedding( ) {
260
- var layer = Embedding < Float > ( vocabularySize: 3 , embeddingSize: 5 )
259
+ func testEmbedding( ) {
260
+ var layer = Embedding < Float > ( vocabularySize: 3 , embeddingSize: 5 )
261
261
var data = Tensor < Int32 > ( shape: [ 2 , 3 ] , scalars: [ 0 , 1 , 2 , 1 , 2 , 2 ] )
262
262
var input = EmbeddingInput ( indices: data)
263
263
var output = layer. inferring ( from: input)
264
264
let expectedShape = TensorShape ( [ 2 , 3 , 5 ] )
265
265
XCTAssertEqual ( output. shape, expectedShape)
266
-
266
+
267
267
let pretrained = Tensor < Float > ( shape: [ 2 , 2 ] , scalars: [ 0.4 , 0.3 , 0.2 , 0.1 ] )
268
268
layer = Embedding < Float > ( embeddings: pretrained)
269
269
data = Tensor < Int32 > ( shape: [ 2 , 2 ] , scalars: [ 0 , 1 , 1 , 1 ] )
@@ -318,6 +318,14 @@ final class LayerTests: XCTestCase {
318
318
// XCTAssertEqual(𝛁rnn.cell.bias, [ 0.2496884, 0.66947335, 0.7978788, -0.22378457])
319
319
}
320
320
321
+ func testFunction( ) {
322
+ let tanhLayer = Function < Tensor < Float > , Tensor < Float > > ( tanh)
323
+ let input = Tensor ( shape: [ 5 , 1 ] , scalars: ( 0 ..< 5 ) . map ( Float . init) )
324
+ let output = tanhLayer. inferring ( from: input)
325
+ let expected = Tensor < Float > ( [ [ 0.0 ] , [ 0.7615942 ] , [ 0.9640276 ] , [ 0.9950547 ] , [ 0.9993292 ] ] )
326
+ XCTAssertEqual ( output, expected)
327
+ }
328
+
321
329
static var allTests = [
322
330
( " testConv1D " , testConv1D) ,
323
331
( " testConv1DDilation " , testConv1DDilation) ,
@@ -344,6 +352,7 @@ final class LayerTests: XCTestCase {
344
352
( " testFlatten " , testFlatten) ,
345
353
( " testEmbedding " , testEmbedding) ,
346
354
( " testSimpleRNNCell " , testSimpleRNNCell) ,
347
- ( " testRNN " , testRNN)
355
+ ( " testRNN " , testRNN) ,
356
+ ( " testFunction " , testFunction)
348
357
]
349
358
}
0 commit comments