@@ -66,6 +66,29 @@ final class LazyTensorExplicitTraceTests: XCTestCase {
66
66
XCTAssertEqual ( outputs [ 1 ] . valueDescription, " 13.0 " )
67
67
}
68
68
69
+ func testClosureCaptures( ) {
70
+ let x = Tensor < Float > ( 10.0 )
71
+ let y = x + x
72
+ func fn( input: Tensor < Float > ) -> Tensor < Float > {
73
+ return input * y
74
+ }
75
+ let trace = LazyTensorTraceBuilder . trace ( fn)
76
+ /// Note that the computation x + x is encoded in the trace.
77
+ XCTAssertEqual ( trace. description,
78
+ """
79
+ lazyTrace_4(%0: float) -> (%3) {
80
+ %1 = Const[dtype: float, value: 10.0]()
81
+ %2 = Add[T: float](%1, %1)
82
+ %3 = Mul[T: float](%0, %2)
83
+ }
84
+ """ )
85
+ let outputs = runTrace (
86
+ trace: trace,
87
+ input: Tensor < Float > ( 5.0 ) )
88
+ XCTAssertEqual ( outputs. count, 1 )
89
+ XCTAssertEqual ( outputs [ 0 ] . valueDescription, " 100.0 " )
90
+ }
91
+
69
92
private func runTrace( trace: LazyTensorTrace , input: TensorGroup ) -> [ TFETensorHandle ] {
70
93
let tffunc = TFFunction ( trace: trace)
71
94
let inputHandles = input. _tensorHandles. map { $0. _tfeTensorHandle }
@@ -75,6 +98,7 @@ final class LazyTensorExplicitTraceTests: XCTestCase {
75
98
76
99
static var allTests = [
77
100
( " testSingleInput " , testSingleInput) ,
78
- ( " testTensorGroupInputOutputs " , testTensorGroupInputOutputs)
101
+ ( " testTensorGroupInputOutputs " , testTensorGroupInputOutputs) ,
102
+ ( " testClosureCaptures " , testClosureCaptures)
79
103
]
80
104
}
0 commit comments