Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit efc6acf

Browse files
authored
Port over tensor_autodiff_runtime.swift tests. (#235)
- Port over tensor_autodiff_runtime.swift tests from apple/swift PR 24899.
1 parent 95e5a6d commit efc6acf

File tree

2 files changed

+546
-4
lines changed

2 files changed

+546
-4
lines changed

Tests/TensorFlowTests/OperatorTests/BasicTests.swift

Lines changed: 118 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -443,19 +443,54 @@ final class BasicOperatorTests: XCTestCase {
443443
XCTAssertEqual(result.shape, [1, 3, 1, 2, 1])
444444
}
445445

446-
func testUnbroadcast1() {
446+
func testUnbroadcastRank4ToRank2() {
447447
let x = Tensor<Float>(repeating: 1, shape: [2, 3, 4, 5])
448448
let y = Tensor<Float>(repeating: 1, shape: [4, 5])
449449
let z = x.unbroadcasted(like: y)
450450
XCTAssertEqual(z.array, ShapedArray<Float>(repeating: 6, shape: [4, 5]))
451451
}
452452

453-
func testUnbroadcast2() {
453+
func testUnbroadcastRank4ToRank3() {
454454
let x = Tensor<Float>(repeating: 1, shape: [2, 3, 4, 5])
455455
let y = Tensor<Float>(repeating: 1, shape: [3, 1, 5])
456456
let z = x.unbroadcasted(like: y)
457457
XCTAssertEqual(z.array, ShapedArray<Float>(repeating: 8, shape: [3, 1, 5]))
458458
}
459+
460+
func testUnbroadcast3x3To1x3() {
461+
func foo(tensor: Tensor<Float>, shape: Tensor<Int32>) -> Tensor<Float> {
462+
tensor.unbroadcasted(toShape: shape)
463+
}
464+
465+
// [3,3] -> [1,3]
466+
let atTensor: Tensor<Float> = [
467+
[1, 2, 3],
468+
[1, 2, 3],
469+
[1, 2, 3]]
470+
var pb: (Tensor<Float>) -> Tensor<Float> = pullback(at: atTensor) { x in
471+
foo(tensor: x, shape: [1, 3])
472+
}
473+
474+
// Same shape as parameter of pullback
475+
var inputTensor: Tensor<Float> = [[1, 2, 3]]
476+
var expected: Tensor<Float> = atTensor
477+
XCTAssertEqual(expected, pb(inputTensor))
478+
// Different shape than parameter of pullback
479+
inputTensor = [2]
480+
expected = [
481+
[2, 2, 2],
482+
[2, 2, 2],
483+
[2, 2, 2]]
484+
XCTAssertEqual(expected, pb(inputTensor))
485+
486+
// Same shape as tensor we are differentiating at
487+
inputTensor = [
488+
[8, 1, 3],
489+
[8, 1, 3],
490+
[8, 1, 3]]
491+
expected = inputTensor
492+
XCTAssertEqual(expected, pb(inputTensor))
493+
}
459494

460495
func testSliceUpdate() {
461496
var t1 = Tensor<Float>([[1, 2, 3], [4, 5, 6]])
@@ -482,6 +517,82 @@ final class BasicOperatorTests: XCTestCase {
482517
target .= Tensor(repeating: 1, shape: [1, 3, 1])
483518
XCTAssertEqual(target, Tensor(repeating: 1, shape: [2, 3, 4]))
484519
}
520+
521+
func testBroadcast3x0To3x3() {
522+
func foo(tensor: Tensor<Float>, shape: Tensor<Int32>) -> Tensor<Float> {
523+
tensor.broadcasted(toShape: shape)
524+
}
525+
526+
// [3,] -> [3,3]
527+
var pb: (Tensor<Float>) -> Tensor<Float> = pullback(at: [99, 33, 55]) { x in
528+
foo(tensor: x, shape: [3, 3])
529+
}
530+
531+
// Same shape as parameter of pullback
532+
var inputTensor: Tensor<Float> = [
533+
[1, 2, 3],
534+
[1, 2, 3],
535+
[1, 2, 3]]
536+
var expected: Tensor<Float> = [3, 6, 9]
537+
XCTAssertEqual(expected, pb(inputTensor))
538+
539+
// Different shape than parameter of pullback
540+
inputTensor = [
541+
[1, 2, 3],
542+
[1, 2, 3],
543+
[1, 2, 3],
544+
[1, 2, 3]]
545+
expected = [4, 8, 12]
546+
XCTAssertEqual(expected, pb(inputTensor))
547+
548+
// Same shape as tensor we are differentiating at
549+
inputTensor = [1, 2, 3]
550+
expected = [1, 2, 3]
551+
XCTAssertEqual(expected, pb(inputTensor))
552+
553+
// Extremely padded shape as tensor we are differentiating at
554+
inputTensor = [[[[[[1, 2, 3]]]]]]
555+
expected = [1, 2, 3]
556+
XCTAssertEqual(expected, pb(inputTensor))
557+
}
558+
559+
func testBroadcast3x1To3x3() {
560+
func foo(tensor: Tensor<Float>, shape: Tensor<Int32>) -> Tensor<Float> {
561+
tensor.broadcasted(toShape: shape)
562+
}
563+
564+
// [3,1] -> [3x3]
565+
var pb: (Tensor<Float>) -> Tensor<Float> = pullback(at: [[99, 33, 55]]) { x in
566+
foo(tensor: x, shape: [3, 3])
567+
}
568+
569+
// Same shape as parameter of pullback
570+
var inputTensor: Tensor<Float> = [
571+
[1, 2, 3],
572+
[1, 2, 3],
573+
[1, 2, 3]]
574+
var expected: Tensor<Float> = [[3, 6, 9]]
575+
XCTAssertEqual(expected, pb(inputTensor))
576+
577+
// Different shape than parameter of pullback
578+
inputTensor = [
579+
[1, 2, 3],
580+
[1, 2, 3],
581+
[1, 2, 3],
582+
[1, 2, 3]]
583+
expected = [[4, 8, 12]]
584+
XCTAssertEqual(expected, pb(inputTensor))
585+
586+
// Same shape as tensor we are differentiating at
587+
inputTensor = [[1, 2, 3]]
588+
expected = [[1, 2, 3]]
589+
XCTAssertEqual(expected, pb(inputTensor))
590+
591+
// Extremely padded shape of tensor we are differentiating at
592+
inputTensor = [[[[[[1, 2, 3]]]]]]
593+
expected = [[1, 2, 3]]
594+
XCTAssertEqual(expected, pb(inputTensor))
595+
}
485596

486597
static var allTests = [
487598
("testGathering", testGathering),
@@ -507,9 +618,12 @@ final class BasicOperatorTests: XCTestCase {
507618
("testFlatten0D", testFlatten0D),
508619
("testReshapeToScalar", testReshapeToScalar),
509620
("testReshapeTensor", testReshapeTensor),
510-
("testUnbroadcast1", testUnbroadcast1),
511-
("testUnbroadcast2", testUnbroadcast2),
621+
("testUnbroadcastRank4ToRank2", testUnbroadcastRank4ToRank2),
622+
("testUnbroadcastRank4ToRank3", testUnbroadcastRank4ToRank3),
623+
("testUnbroadcast3x3To1x3", testUnbroadcast3x3To1x3),
512624
("testSliceUpdate", testSliceUpdate),
625+
("testBroadcast3x0To3x3", testBroadcast3x0To3x3),
626+
("testBroadcast3x1To3x3", testBroadcast3x1To3x3),
513627
("testBroadcastTensor", testBroadcastTensor)
514628
]
515629
}

0 commit comments

Comments
 (0)