@@ -103,13 +103,37 @@ final class LazyTensorExplicitTraceTests: XCTestCase {
103
103
%2 = Mul[T: float](%0, %1)
104
104
}
105
105
""" )
106
- let outputs = runTrace (
107
- trace: trace,
108
- input: Tensor < Float > ( 23.0 ) )
106
+ let outputs = runTrace ( trace: trace, input: Tensor < Float > ( 23.0 ) )
109
107
XCTAssertEqual ( outputs. count, 1 )
110
108
XCTAssertEqual ( outputs [ 0 ] . valueDescription, " 115.0 " )
111
109
}
112
110
111
+ func testNestedTracing( ) {
112
+ func square( input: Tensor < Float > ) -> Tensor < Float > {
113
+ return input * input
114
+ }
115
+
116
+ func nestedTrace( input: Tensor < Float > ) -> Tensor < Float > {
117
+ let trace = LazyTensorTraceBuilder . trace ( square)
118
+ let outputs = runTrace ( trace: trace, input: Tensor < Float > ( 3.0 ) )
119
+ XCTAssertEqual ( outputs. count, 1 )
120
+ let handle = TensorHandle < Float > ( handle: outputs [ 0 ] )
121
+ let y = Tensor < Float > ( handle: handle)
122
+ return y + input
123
+ }
124
+
125
+ let trace = LazyTensorTraceBuilder . trace ( nestedTrace)
126
+ XCTAssertEqual ( trace. description,
127
+ """
128
+ lazyTrace_3(%0: float) -> (%2) {
129
+ %1 = Const[dtype: float, value: 9.0]()
130
+ %2 = Add[T: float](%1, %0)
131
+ }
132
+ """ )
133
+ let outputs = runTrace ( trace: trace, input: Tensor < Float > ( 4.0 ) )
134
+ XCTAssertEqual ( outputs. count, 1 )
135
+ XCTAssertEqual ( outputs [ 0 ] . valueDescription, " 13.0 " )
136
+ }
113
137
114
138
private func runTrace( trace: LazyTensorTrace , input: TensorGroup ) -> [ TFETensorHandle ] {
115
139
let tffunc = TFFunction ( trace: trace)
@@ -122,6 +146,7 @@ final class LazyTensorExplicitTraceTests: XCTestCase {
122
146
( " testSingleInput " , testSingleInput) ,
123
147
( " testTensorGroupInputOutputs " , testTensorGroupInputOutputs) ,
124
148
( " testClosureCapturesOfTensors " , testClosureCapturesOfTensors) ,
125
- ( " testClosureCapturesOfNonTensors " , testClosureCapturesOfNonTensors)
149
+ ( " testClosureCapturesOfNonTensors " , testClosureCapturesOfNonTensors) ,
150
+ ( " testNestedTracing " , testNestedTracing)
126
151
]
127
152
}
0 commit comments