@@ -66,7 +66,7 @@ final class LazyTensorExplicitTraceTests: XCTestCase {
66
66
XCTAssertEqual ( outputs [ 1 ] . valueDescription, " 13.0 " )
67
67
}
68
68
69
- func testClosureCaptures ( ) {
69
+ func testClosureCapturesOfTensors ( ) {
70
70
let x = Tensor < Float > ( 10.0 )
71
71
let y = x + x
72
72
func fn( input: Tensor < Float > ) -> Tensor < Float > {
@@ -89,6 +89,28 @@ final class LazyTensorExplicitTraceTests: XCTestCase {
89
89
XCTAssertEqual ( outputs [ 0 ] . valueDescription, " 100.0 " )
90
90
}
91
91
92
+ func testClosureCapturesOfNonTensors( ) {
93
+ let x : Float = 5.0
94
+ func fn( input: Tensor < Float > ) -> Tensor < Float > {
95
+ return input * Tensor < Float > ( x)
96
+ }
97
+ let trace = LazyTensorTraceBuilder . trace ( fn)
98
+ /// Note that the computation x + x is encoded in the trace.
99
+ XCTAssertEqual ( trace. description,
100
+ """
101
+ lazyTrace_3(%0: float) -> (%2) {
102
+ %1 = Const[dtype: float, value: 5.0]()
103
+ %2 = Mul[T: float](%0, %1)
104
+ }
105
+ """ )
106
+ let outputs = runTrace (
107
+ trace: trace,
108
+ input: Tensor < Float > ( 23.0 ) )
109
+ XCTAssertEqual ( outputs. count, 1 )
110
+ XCTAssertEqual ( outputs [ 0 ] . valueDescription, " 115.0 " )
111
+ }
112
+
113
+
92
114
private func runTrace( trace: LazyTensorTrace , input: TensorGroup ) -> [ TFETensorHandle ] {
93
115
let tffunc = TFFunction ( trace: trace)
94
116
let inputHandles = input. _tensorHandles. map { $0. _tfeTensorHandle }
@@ -99,6 +121,7 @@ final class LazyTensorExplicitTraceTests: XCTestCase {
99
121
static var allTests = [
100
122
( " testSingleInput " , testSingleInput) ,
101
123
( " testTensorGroupInputOutputs " , testTensorGroupInputOutputs) ,
102
- ( " testClosureCaptures " , testClosureCaptures)
124
+ ( " testClosureCapturesOfTensors " , testClosureCapturesOfTensors) ,
125
+ ( " testClosureCapturesOfNonTensors " , testClosureCapturesOfNonTensors)
103
126
]
104
127
}
0 commit comments