@@ -262,4 +262,135 @@ TensorADTests.testAllBackends("Side effects") {
262
262
expectEqual ( Tensor ( 48 ) , gradient ( at: Tensor ( 4 ) , in: bar) )
263
263
}
264
264
265
+ TensorADTests . testAllBackends ( " broadcast(toShape:) " ) {
266
+ func foo( tensor: Tensor < Float > , shape: Tensor < Int32 > ) -> Tensor < Float > {
267
+ tensor. broadcast ( toShape: shape)
268
+ }
269
+
270
+ var inputTensor : Tensor < Float >
271
+ var expected : Tensor < Float >
272
+ var pb : ( Tensor < Float > ) -> Tensor < Float >
273
+
274
+ // [3,] -> [3,3]
275
+ pb = pullback ( at: Tensor ( [ 99 , 33 , 55 ] ) ) { x in
276
+ foo ( tensor: x, shape: Tensor ( [ 3 , 3 ] ) )
277
+ }
278
+
279
+ // Test 1: same shape as parameter of pullback
280
+ inputTensor = Tensor ( [
281
+ [ 1 , 2 , 3 ] ,
282
+ [ 1 , 2 , 3 ] ,
283
+ [ 1 , 2 , 3 ] ]
284
+ )
285
+ expected = Tensor ( [ 3 , 6 , 9 ] )
286
+ expectEqual ( expected, pb ( inputTensor) )
287
+
288
+ // Test 2: different shape than parameter of pullback
289
+ inputTensor = Tensor ( [
290
+ [ 1 , 2 , 3 ] ,
291
+ [ 1 , 2 , 3 ] ,
292
+ [ 1 , 2 , 3 ] ,
293
+ [ 1 , 2 , 3 ] ]
294
+ )
295
+ expected = Tensor ( [ 4 , 8 , 12 ] )
296
+ expectEqual ( expected, pb ( inputTensor) )
297
+
298
+ // Test 3: same shape as tensor we are differentiating at
299
+ inputTensor = Tensor ( [ 1 , 2 , 3 ] )
300
+ expected = Tensor ( [ 1 , 2 , 3 ] )
301
+ expectEqual ( expected, pb ( inputTensor) )
302
+
303
+ // Test 4: extremely padded shape as tensor we are differentiating at
304
+ inputTensor = Tensor ( [ [ [ [ [ [ 1 , 2 , 3 ] ] ] ] ] ] )
305
+ expected = Tensor ( [ 1 , 2 , 3 ] )
306
+ expectEqual ( expected, pb ( inputTensor) )
307
+
308
+ // [3,1] -> [3x3]
309
+ pb = pullback ( at: Tensor ( [ [ 99 , 33 , 55 ] ] ) ) { x in
310
+ foo ( tensor: x, shape: Tensor ( [ 3 , 3 ] ) )
311
+ }
312
+
313
+ // Test 5: same shape as parameter of pullback
314
+ inputTensor = Tensor ( [
315
+ [ 1 , 2 , 3 ] ,
316
+ [ 1 , 2 , 3 ] ,
317
+ [ 1 , 2 , 3 ] ]
318
+ )
319
+ expected = Tensor ( [ [ 3 , 6 , 9 ] ] )
320
+ expectEqual ( expected, pb ( inputTensor) )
321
+
322
+ // Test 6: different shape than parameter of pullback
323
+ inputTensor = Tensor ( [
324
+ [ 1 , 2 , 3 ] ,
325
+ [ 1 , 2 , 3 ] ,
326
+ [ 1 , 2 , 3 ] ,
327
+ [ 1 , 2 , 3 ] ]
328
+ )
329
+ expected = Tensor ( [ [ 4 , 8 , 12 ] ] )
330
+ expectEqual ( expected, pb ( inputTensor) )
331
+
332
+ // Test 7: same shape as tensor we are differentiating at
333
+ inputTensor = Tensor ( [ [ 1 , 2 , 3 ] ] )
334
+ expected = Tensor ( [ [ 1 , 2 , 3 ] ] )
335
+ expectEqual ( expected, pb ( inputTensor) )
336
+
337
+ // Test 8: extremely padded shape of tensor we are differentiating at
338
+ inputTensor = Tensor ( [ [ [ [ [ [ 1 , 2 , 3 ] ] ] ] ] ] )
339
+ expected = Tensor ( [ [ 1 , 2 , 3 ] ] )
340
+ expectEqual ( expected, pb ( inputTensor) )
341
+ }
342
+
343
+ TensorADTests . testAllBackends ( " unbroadcast(toShape: " ) {
344
+ func foo( tensor: Tensor < Float > , shape: Tensor < Int32 > ) -> Tensor < Float > {
345
+ tensor. unbroadcast ( toShape: shape)
346
+ }
347
+
348
+ var inputTensor : Tensor < Float >
349
+ var expected : Tensor < Float >
350
+ var pb : ( Tensor < Float > ) -> Tensor < Float >
351
+
352
+ // [3,3] -> [1,3]
353
+ let atTensor : Tensor < Float > = Tensor ( [
354
+ [ 1 , 2 , 3 ] ,
355
+ [ 1 , 2 , 3 ] ,
356
+ [ 1 , 2 , 3 ] ]
357
+ )
358
+ pb = pullback ( at: atTensor) { x in
359
+ foo ( tensor: x, shape: Tensor ( [ 1 , 3 ] ) )
360
+ }
361
+
362
+ // Test 1: same shape as parameter of pullback
363
+ inputTensor = Tensor ( [ [ 1 , 2 , 3 ] ] )
364
+ expected = atTensor
365
+ expectEqual ( expected, pb ( inputTensor) )
366
+
367
+ // Test 2: different shape than parameter of pullback
368
+ inputTensor = Tensor ( [ 2 ] )
369
+ expected = Tensor ( [
370
+ [ 2 , 2 , 2 ] ,
371
+ [ 2 , 2 , 2 ] ,
372
+ [ 2 , 2 , 2 ] ]
373
+ )
374
+ expectEqual ( expected, pb ( inputTensor) )
375
+
376
+ // Test 3: same shape as tensor we are differentiating at
377
+ inputTensor = Tensor ( [
378
+ [ 8 , 1 , 3 ] ,
379
+ [ 8 , 1 , 3 ] ,
380
+ [ 8 , 1 , 3 ] ]
381
+ )
382
+ expected = inputTensor
383
+ expectEqual ( expected, pb ( inputTensor) )
384
+
385
+ // TODO
386
+ // Test 4: extremely padded shape as tensor we are differentiating at
387
+ // inputTensor = Tensor([
388
+ // [[8, 1, 3]],
389
+ // [[8, 1, 3]],
390
+ // [[8, 1, 3]]]
391
+ // )
392
+ // expected = Tensor([1, 2, 3])
393
+ // expectEqual(expected, pb(inputTensor))
394
+ }
395
+
265
396
runAllTests ( )
0 commit comments