@@ -90,40 +90,36 @@ class LazyTensorTraceBuilder {
90
90
}
91
91
92
92
/// Trace the given function and return the trace.
93
- static func trace< In: TensorGroup , Out: TensorGroup > (
94
- _ fn: ( In ) -> Out
95
- ) -> LazyTensorTrace {
96
- assert ( _RuntimeConfig. useLazyTensor, " Lazy tensor is not enabled for tracing. " )
93
+ static func trace< In: TensorGroup , Out: TensorGroup > ( _ fn: ( In ) -> Out ) -> LazyTensorTrace {
94
+ precondition ( _RuntimeConfig. useLazyTensor, " Lazy tensor is not enabled for tracing. " )
97
95
98
96
// Set up inputs for running `fn`
99
- let inputs = In . _typeList. map { Self . makePlaceholder ( with: $0) }
100
- let inputHandles = inputs . map { LazyTensorHandle ( _lazy: $0, index: 0 ) }
97
+ let inputOps = In . _typeList. map { Self . makePlaceholder ( with: $0) }
98
+ let inputHandles = inputOps . map { LazyTensorHandle ( _lazy: $0, index: 0 ) }
101
99
let input = In ( _handles: inputHandles)
102
100
103
101
// Run the function.
104
102
let output : TensorArrayProtocol = fn ( input)
105
103
106
104
// Set up the closure that determines if a `LazyTensorOperation` should be an output.
107
- let outputLazyOperations = output. _tensorHandles. map { ( handle: _AnyTensorHandle ) -> LazyTensorOperation in
105
+ let outputLazyOperations = output. _tensorHandles. map {
106
+ ( handle: _AnyTensorHandle ) -> LazyTensorOperation in
108
107
let lazyOp = lazyTensorOperation ( handle)
109
- assert ( lazyOp != nil , " Found a non-lazy tensor in output when tracing. " )
108
+ precondition ( lazyOp != nil , " Found a non-lazy tensor in output when tracing. " )
110
109
return lazyOp!
111
110
}
112
- let outputIds = Set < ObjectIdentifier > ( outputLazyOperations. map {
113
- ObjectIdentifier ( $0)
114
- } )
111
+ let outputIds = Set < ObjectIdentifier > ( outputLazyOperations. map { ObjectIdentifier ( $0) } )
115
112
let isOutput : ( LazyTensorOperation ) -> Bool = { outputIds. contains ( ObjectIdentifier ( $0) ) }
116
113
117
114
// Create the builder and get the trace.
118
115
let builder = LazyTensorTraceBuilder ( )
119
116
builder. neverPromoteConstants = true
120
117
builder. isOutput = isOutput
121
- /// Set up the inputs for the builder as we need to have specific order.
122
- for inputOp in inputs {
123
- let id = ObjectIdentifier ( inputOp)
124
- builder. updateOperationAndCache ( id, inputOp)
118
+ /// Set up the inputs for the builder as we need to have them in a specific order.
119
+ for inputOp in inputOps {
120
+ builder. updateOperationAndCache ( ObjectIdentifier ( inputOp) , inputOp)
125
121
}
126
- builder. inputs = inputs
122
+ builder. inputs = inputOps
127
123
for lazyOp in outputLazyOperations { _ = builder. collectLazyOperation ( lazyOp) }
128
124
return LazyTensorTrace (
129
125
inputs: builder. inputs,
0 commit comments