@@ -1323,6 +1323,75 @@ ForwardModeTests.test("ForceUnwrapping") {
1323
1323
// Array methods from ArrayDifferentiation.swift
1324
1324
//===----------------------------------------------------------------------===//
1325
1325
1326
+ typealias FloatArrayTan = Array < Float > . TangentVector
1327
+
1328
+ ForwardModeTests . test ( " Array.+ " ) {
1329
+ func sumFirstThreeConcatenating( _ a: [ Float ] , _ b: [ Float ] ) -> Float {
1330
+ let c = a + b
1331
+ return c [ 0 ] + c[ 1 ] + c[ 2 ]
1332
+ }
1333
+
1334
+ expectEqual ( 3 , differential ( at: [ 0 , 0 ] , [ 0 , 0 ] , in: sumFirstThreeConcatenating) ( . init( [ 1 , 1 ] ) , . init( [ 1 , 1 ] ) ) )
1335
+ expectEqual ( 0 , differential ( at: [ 0 , 0 ] , [ 0 , 0 ] , in: sumFirstThreeConcatenating) ( . init( [ 0 , 0 ] ) , . init( [ 0 , 1 ] ) ) )
1336
+ expectEqual ( 1 , differential ( at: [ 0 , 0 ] , [ 0 , 0 ] , in: sumFirstThreeConcatenating) ( . init( [ 0 , 1 ] ) , . init( [ 0 , 1 ] ) ) )
1337
+ expectEqual ( 1 , differential ( at: [ 0 , 0 ] , [ 0 , 0 ] , in: sumFirstThreeConcatenating) ( . init( [ 1 , 0 ] ) , . init( [ 0 , 1 ] ) ) )
1338
+ expectEqual ( 1 , differential ( at: [ 0 , 0 ] , [ 0 , 0 ] , in: sumFirstThreeConcatenating) ( . init( [ 0 , 0 ] ) , . init( [ 1 , 1 ] ) ) )
1339
+ expectEqual ( 2 , differential ( at: [ 0 , 0 ] , [ 0 , 0 ] , in: sumFirstThreeConcatenating) ( . init( [ 1 , 1 ] ) , . init( [ 0 , 1 ] ) ) )
1340
+
1341
+ expectEqual (
1342
+ 3 ,
1343
+ differential ( at: [ 0 , 0 , 0 , 0 ] , [ 0 , 0 ] , in: sumFirstThreeConcatenating) ( . init( [ 1 , 1 , 1 , 1 ] ) , . init( [ 1 , 1 ] ) ) )
1344
+ expectEqual (
1345
+ 3 ,
1346
+ differential ( at: [ 0 , 0 , 0 , 0 ] , [ 0 , 0 ] , in: sumFirstThreeConcatenating) ( . init( [ 1 , 1 , 1 , 0 ] ) , . init( [ 0 , 0 ] ) ) )
1347
+
1348
+ expectEqual (
1349
+ 3 ,
1350
+ differential ( at: [ ] , [ 0 , 0 , 0 , 0 ] , in: sumFirstThreeConcatenating) ( . init( [ ] ) , . init( [ 1 , 1 , 1 , 1 ] ) ) )
1351
+ expectEqual (
1352
+ 0 ,
1353
+ differential ( at: [ ] , [ 0 , 0 , 0 , 0 ] , in: sumFirstThreeConcatenating) ( . init( [ ] ) , . init( [ 0 , 0 , 0 , 1 ] ) ) )
1354
+ }
1355
+
1356
+ ForwardModeTests . test ( " Array.init(repeating:count:) " ) {
1357
+ @differentiable
1358
+ func repeating( _ x: Float ) -> [ Float ] {
1359
+ Array ( repeating: x, count: 10 )
1360
+ }
1361
+ expectEqual ( Float ( 10 ) , derivative ( at: . zero) { x in
1362
+ repeating ( x) . differentiableReduce ( 0 , { $0 + $1} )
1363
+ } )
1364
+ expectEqual ( Float ( 20 ) , differential ( at: . zero, in: { x in
1365
+ repeating ( x) . differentiableReduce ( 0 , { $0 + $1} )
1366
+ } ) ( 2 ) )
1367
+ }
1368
+
1369
+ ForwardModeTests . test ( " Array.DifferentiableView.init " ) {
1370
+ @differentiable
1371
+ func constructView( _ x: [ Float ] ) -> Array < Float > . DifferentiableView {
1372
+ return Array< Float> . DifferentiableView( x)
1373
+ }
1374
+
1375
+ let forward = differential ( at: [ 5 , 6 , 7 , 8 ] , in: constructView)
1376
+ expectEqual (
1377
+ FloatArrayTan ( [ 1 , 2 , 3 , 4 ] ) ,
1378
+ forward ( FloatArrayTan ( [ 1 , 2 , 3 , 4 ] ) ) )
1379
+ }
1380
+
1381
+ ForwardModeTests . test ( " Array.DifferentiableView.base " ) {
1382
+ @differentiable
1383
+ func accessBase( _ x: Array < Float > . DifferentiableView ) -> [ Float ] {
1384
+ return x. base
1385
+ }
1386
+
1387
+ let forward = differential (
1388
+ at: Array< Float> . DifferentiableView( [ 5 , 6 , 7 , 8 ] ) ,
1389
+ in: accessBase)
1390
+ expectEqual (
1391
+ FloatArrayTan ( [ 1 , 2 , 3 , 4 ] ) ,
1392
+ forward ( FloatArrayTan ( [ 1 , 2 , 3 , 4 ] ) ) )
1393
+ }
1394
+
1326
1395
ForwardModeTests . test ( " Array.differentiableMap " ) {
1327
1396
let x : [ Float ] = [ 1 , 2 , 3 ]
1328
1397
let tan = Array< Float> . TangentVector( [ 1 , 1 , 1 ] )
@@ -1338,4 +1407,24 @@ ForwardModeTests.test("Array.differentiableMap") {
1338
1407
expectEqual ( [ 2 , 4 , 6 ] , differential ( at: x, in: squareMap) ( tan) )
1339
1408
}
1340
1409
1410
+ ForwardModeTests . test ( " Array.differentiableReduce " ) {
1411
+ let x : [ Float ] = [ 1 , 2 , 3 ]
1412
+ let tan = Array< Float> . TangentVector( [ 1 , 1 , 1 ] )
1413
+
1414
+ func sumReduce( _ a: [ Float ] ) -> Float {
1415
+ return a. differentiableReduce ( 0 , { $0 + $1 } )
1416
+ }
1417
+ expectEqual ( 1 + 1 + 1 , differential ( at: x, in: sumReduce) ( tan) )
1418
+
1419
+ func productReduce( _ a: [ Float ] ) -> Float {
1420
+ return a. differentiableReduce ( 1 , { $0 * $1 } )
1421
+ }
1422
+ expectEqual ( x [ 1 ] * x[ 2 ] + x[ 0 ] * x[ 2 ] + x[ 0 ] * x[ 1 ] , differential ( at: x, in: productReduce) ( tan) )
1423
+
1424
+ func sumOfSquaresReduce( _ a: [ Float ] ) -> Float {
1425
+ return a. differentiableReduce ( 0 , { $0 + $1 * $1 } )
1426
+ }
1427
+ expectEqual ( 2 * x[ 0 ] + 2 * x[ 1 ] + 2 * x[ 2 ] , differential ( at: x, in: sumOfSquaresReduce) ( tan) )
1428
+ }
1429
+
1341
1430
runAllTests ( )
0 commit comments