@@ -78,24 +78,6 @@ public protocol _TensorFlowDataTypeCompatible {
78
78
/// The underlying TensorFlow data type.
79
79
@inlinable
80
80
static var tensorFlowDataType : TensorDataType { get }
81
-
82
- // Hooks used by the TFPartition pass for primitive operations on tensors.
83
- // These should not be called directly or implemented.
84
-
85
- /// This converts a TensorHandle that is known to have a 0d value into
86
- /// the scalar that it produces. Users should call the _TFGetScalarOrDie
87
- /// wrapper function.
88
- static func _getScalarOrDie( _ handle: TensorHandle < Self > ) -> Self
89
-
90
- /// This converts a TensorHandle into a scalar if it is 0d, or returns nil
91
- /// otherwise. Users should call the _TFGetScalar wrapper function.
92
- static func _getScalar( _ handle: TensorHandle < Self > ) -> Self ?
93
-
94
- /// This indicates that it is safe to hoist the specified computation that
95
- /// creates a tensor to being a parameter that is passed in from outside of
96
- /// the tensor program.
97
- static func _hoistableClosure( _ fn: ( ) -> TensorHandle < Self > )
98
- -> TensorHandle < Self >
99
81
}
100
82
101
83
/// A scalar data type compatible with TensorFlow.
@@ -125,206 +107,67 @@ public protocol TensorFlowFloatingPoint :
125
107
extension Float : TensorFlowFloatingPoint { }
126
108
extension Double : TensorFlowFloatingPoint { }
127
109
128
- // This is the implementation of the _getScalarOrDie requirement for each
129
- // concrete type below. We use this round-about approach to implement the
130
- // global _TFGetScalarOrDie function in order to ensure that the noinline
131
- // SIL functions below have non-generic type signatures. This is important for
132
- // the inner workings of the partitioning pass.
133
- private func _TFGetScalarOrDieImpl< Scalar> (
134
- _ handle: TensorHandle < Scalar >
135
- ) -> Scalar {
136
- return handle. makeHostCopy ( ) . scalar!
137
- }
138
-
139
- // This is the implementation of the _getScalar requirement for each concrete
140
- // type below. We use this round-about approach to implement the
141
- // global _TFGetScalar function in order to ensure that the noinline
142
- // SIL functions below have non-generic type signatures. This is important for
143
- // the inner workings of the partitioning pass.
144
- private func _TFGetScalarImpl< Scalar> (
145
- _ handle: TensorHandle < Scalar >
146
- ) -> Scalar ? {
147
- return handle. makeHostCopy ( ) . scalar
148
- }
149
-
150
110
extension Bool : TensorFlowScalar {
151
111
@inlinable
152
112
public static var tensorFlowDataType : TensorDataType {
153
113
return TensorDataType ( TF_BOOL)
154
114
}
155
- @_silgen_name ( " __tf_get_scalar_or_die_Bool " ) @inline ( never)
156
- public static func _getScalarOrDie( _ handle: TensorHandle < Bool > ) -> Bool {
157
- return _TFGetScalarOrDieImpl ( handle)
158
- }
159
- @_silgen_name ( " __tf_get_scalar_Bool " ) @inline ( never)
160
- public static func _getScalar( _ handle: TensorHandle < Bool > ) -> Bool ? {
161
- return _TFGetScalarImpl ( handle)
162
- }
163
- @_silgen_name ( " __tf_hoistable_Bool " ) @_optimize ( none) @inline ( never)
164
- public static func _hoistableClosure( _ fn: ( ) -> TensorHandle < Bool > )
165
- -> TensorHandle < Bool > {
166
- return fn ( )
167
- }
168
115
}
169
116
170
117
extension Int8 : TensorFlowScalar {
171
118
@inlinable
172
119
public static var tensorFlowDataType : TensorDataType {
173
120
return TensorDataType ( TF_INT8)
174
121
}
175
- @_silgen_name ( " __tf_get_scalar_or_die_Int8 " ) @inline ( never)
176
- public static func _getScalarOrDie( _ handle: TensorHandle < Int8 > ) -> Int8 {
177
- return _TFGetScalarOrDieImpl ( handle)
178
- }
179
- @_silgen_name ( " __tf_get_scalar_Int8 " ) @inline ( never)
180
- public static func _getScalar( _ handle: TensorHandle < Int8 > ) -> Int8 ? {
181
- return _TFGetScalarImpl ( handle)
182
- }
183
- @_silgen_name ( " __tf_hoistable_Int8 " ) @_optimize ( none) @inline ( never)
184
- public static func _hoistableClosure( _ fn: ( ) -> TensorHandle < Int8 > )
185
- -> TensorHandle < Int8 > {
186
- return fn ( )
187
- }
188
122
}
189
123
190
124
extension UInt8 : TensorFlowScalar {
191
125
@inlinable
192
126
public static var tensorFlowDataType : TensorDataType {
193
127
return TensorDataType ( TF_UINT8)
194
128
}
195
- @_silgen_name ( " __tf_get_scalar_or_die_UInt8 " ) @inline ( never)
196
- public static func _getScalarOrDie( _ handle: TensorHandle < UInt8 > ) -> UInt8 {
197
- return _TFGetScalarOrDieImpl ( handle)
198
- }
199
- @_silgen_name ( " __tf_get_scalar_UInt8 " ) @inline ( never)
200
- public static func _getScalar( _ handle: TensorHandle < UInt8 > ) -> UInt8 ? {
201
- return _TFGetScalarImpl ( handle)
202
- }
203
- @_silgen_name ( " __tf_hoistable_UInt8 " ) @_optimize ( none) @inline ( never)
204
- public static func _hoistableClosure( _ fn: ( ) -> TensorHandle < UInt8 > )
205
- -> TensorHandle < UInt8 > {
206
- return fn ( )
207
- }
208
129
}
209
130
210
131
extension Int16 : TensorFlowScalar {
211
132
@inlinable
212
133
public static var tensorFlowDataType : TensorDataType {
213
134
return TensorDataType ( TF_INT16)
214
135
}
215
- @_silgen_name ( " __tf_get_scalar_or_die_Int16 " ) @inline ( never)
216
- public static func _getScalarOrDie( _ handle: TensorHandle < Int16 > ) -> Int16 {
217
- return _TFGetScalarOrDieImpl ( handle)
218
- }
219
- @_silgen_name ( " __tf_get_scalar_Int16 " ) @inline ( never)
220
- public static func _getScalar( _ handle: TensorHandle < Int16 > ) -> Int16 ? {
221
- return _TFGetScalarImpl ( handle)
222
- }
223
- @_silgen_name ( " __tf_hoistable_Int16 " ) @_optimize ( none) @inline ( never)
224
- public static func _hoistableClosure( _ fn: ( ) -> TensorHandle < Int16 > )
225
- -> TensorHandle < Int16 > {
226
- return fn ( )
227
- }
228
136
}
229
137
230
138
extension UInt16 : TensorFlowScalar {
231
139
@inlinable
232
140
public static var tensorFlowDataType : TensorDataType {
233
141
return TensorDataType ( TF_UINT16)
234
142
}
235
- @_silgen_name ( " __tf_get_scalar_or_die_UInt16 " ) @inline ( never)
236
- public static func _getScalarOrDie( _ handle: TensorHandle < UInt16 > ) -> UInt16 {
237
- return _TFGetScalarOrDieImpl ( handle)
238
- }
239
- @_silgen_name ( " __tf_get_scalar_UInt16 " ) @inline ( never)
240
- public static func _getScalar( _ handle: TensorHandle < UInt16 > ) -> UInt16 ? {
241
- return _TFGetScalarImpl ( handle)
242
- }
243
- @_silgen_name ( " __tf_hoistable_UInt16 " ) @_optimize ( none) @inline ( never)
244
- public static func _hoistableClosure( _ fn: ( ) -> TensorHandle < UInt16 > )
245
- -> TensorHandle < UInt16 > {
246
- return fn ( )
247
- }
248
143
}
249
144
250
145
extension Int32 : TensorFlowScalar {
251
146
@inlinable
252
147
public static var tensorFlowDataType : TensorDataType {
253
148
return TensorDataType ( TF_INT32)
254
149
}
255
- @_silgen_name ( " __tf_get_scalar_or_die_Int32 " ) @inline ( never)
256
- public static func _getScalarOrDie( _ handle: TensorHandle < Int32 > ) -> Int32 {
257
- return _TFGetScalarOrDieImpl ( handle)
258
- }
259
- @_silgen_name ( " __tf_get_scalar_Int32 " ) @inline ( never)
260
- public static func _getScalar( _ handle: TensorHandle < Int32 > ) -> Int32 ? {
261
- return _TFGetScalarImpl ( handle)
262
- }
263
- @_silgen_name ( " __tf_hoistable_Int32 " ) @_optimize ( none) @inline ( never)
264
- public static func _hoistableClosure( _ fn: ( ) -> TensorHandle < Int32 > )
265
- -> TensorHandle < Int32 > {
266
- return fn ( )
267
- }
268
150
}
269
151
270
152
extension UInt32 : TensorFlowScalar {
271
153
@inlinable
272
154
public static var tensorFlowDataType : TensorDataType {
273
155
return TensorDataType ( TF_UINT32)
274
156
}
275
- @_silgen_name ( " __tf_get_scalar_or_die_UInt32 " ) @inline ( never)
276
- public static func _getScalarOrDie( _ handle: TensorHandle < UInt32 > ) -> UInt32 {
277
- return _TFGetScalarOrDieImpl ( handle)
278
- }
279
- @_silgen_name ( " __tf_get_scalar_UInt32 " ) @inline ( never)
280
- public static func _getScalar( _ handle: TensorHandle < UInt32 > ) -> UInt32 ? {
281
- return _TFGetScalarImpl ( handle)
282
- }
283
- @_silgen_name ( " __tf_hoistable_UInt32 " ) @_optimize ( none) @inline ( never)
284
- public static func _hoistableClosure( _ fn: ( ) -> TensorHandle < UInt32 > )
285
- -> TensorHandle < UInt32 > {
286
- return fn ( )
287
- }
288
157
}
289
158
290
159
extension Int64 : TensorFlowScalar {
291
160
@inlinable
292
161
public static var tensorFlowDataType : TensorDataType {
293
162
return TensorDataType ( TF_INT64)
294
163
}
295
- @_silgen_name ( " __tf_get_scalar_or_die_Int64 " ) @inline ( never)
296
- public static func _getScalarOrDie( _ handle: TensorHandle < Int64 > ) -> Int64 {
297
- return _TFGetScalarOrDieImpl ( handle)
298
- }
299
- @_silgen_name ( " __tf_get_scalar_Int64 " ) @inline ( never)
300
- public static func _getScalar( _ handle: TensorHandle < Int64 > ) -> Int64 ? {
301
- return _TFGetScalarImpl ( handle)
302
- }
303
- @_silgen_name ( " __tf_hoistable_Int64 " ) @_optimize ( none) @inline ( never)
304
- public static func _hoistableClosure( _ fn: ( ) -> TensorHandle < Int64 > )
305
- -> TensorHandle < Int64 > {
306
- return fn ( )
307
- }
308
164
}
309
165
310
166
extension UInt64 : TensorFlowScalar {
311
167
@inlinable
312
168
public static var tensorFlowDataType : TensorDataType {
313
169
return TensorDataType ( TF_UINT64)
314
170
}
315
- @_silgen_name ( " __tf_get_scalar_or_die_UInt64 " ) @inline ( never)
316
- public static func _getScalarOrDie( _ handle: TensorHandle < UInt64 > ) -> UInt64 {
317
- return _TFGetScalarOrDieImpl ( handle)
318
- }
319
- @_silgen_name ( " __tf_get_scalar_UInt64 " ) @inline ( never)
320
- public static func _getScalar( _ handle: TensorHandle < UInt64 > ) -> UInt64 ? {
321
- return _TFGetScalarImpl ( handle)
322
- }
323
- @_silgen_name ( " __tf_hoistable_UInt64 " ) @_optimize ( none) @inline ( never)
324
- public static func _hoistableClosure( _ fn: ( ) -> TensorHandle < UInt64 > )
325
- -> TensorHandle < UInt64 > {
326
- return fn ( )
327
- }
328
171
}
329
172
330
173
@_fixed_layout
@@ -338,80 +181,25 @@ extension BFloat16 : TensorFlowScalar {
338
181
public static var tensorFlowDataType : TensorDataType {
339
182
return TensorDataType ( TF_BFLOAT16)
340
183
}
341
- @_silgen_name ( " __tf_get_scalar_or_die_BFloat16 " ) @inline ( never)
342
- public static func _getScalarOrDie
343
- ( _ handle: TensorHandle < BFloat16 >
344
- ) -> BFloat16 {
345
- return _TFGetScalarOrDieImpl ( handle)
346
- }
347
- @_silgen_name ( " __tf_get_scalar_BFloat16 " ) @inline ( never)
348
- public static func _getScalar( _ handle: TensorHandle < BFloat16 > ) -> BFloat16 ? {
349
- return _TFGetScalarImpl ( handle)
350
- }
351
- @_silgen_name ( " __tf_hoistable_BFloat16 " ) @_optimize ( none) @inline ( never)
352
- public static func _hoistableClosure(
353
- _ fn: ( ) -> TensorHandle < BFloat16 >
354
- ) -> TensorHandle < BFloat16 > {
355
- return fn ( )
356
- }
357
184
}
358
185
359
186
extension Float : TensorFlowScalar {
360
187
@inlinable
361
188
public static var tensorFlowDataType : TensorDataType {
362
189
return TensorDataType ( TF_FLOAT)
363
190
}
364
- @_silgen_name ( " __tf_get_scalar_or_die_Float " ) @inline ( never)
365
- public static func _getScalarOrDie( _ handle: TensorHandle < Float > ) -> Float {
366
- return _TFGetScalarOrDieImpl ( handle)
367
- }
368
- @_silgen_name ( " __tf_get_scalar_Float " ) @inline ( never)
369
- public static func _getScalar( _ handle: TensorHandle < Float > ) -> Float ? {
370
- return _TFGetScalarImpl ( handle)
371
- }
372
- @_silgen_name ( " __tf_hoistable_Float " ) @_optimize ( none) @inline ( never)
373
- public static func _hoistableClosure( _ fn: ( ) -> TensorHandle < Float > )
374
- -> TensorHandle < Float > {
375
- return fn ( )
376
- }
377
191
}
378
192
379
193
extension Double : TensorFlowScalar {
380
194
@inlinable
381
195
public static var tensorFlowDataType : TensorDataType {
382
196
return TensorDataType ( TF_DOUBLE)
383
197
}
384
- @_silgen_name ( " __tf_get_scalar_or_die_Double " ) @inline ( never)
385
- public static func _getScalarOrDie( _ handle: TensorHandle < Double > ) -> Double {
386
- return _TFGetScalarOrDieImpl ( handle)
387
- }
388
- @_silgen_name ( " __tf_get_scalar_Double " ) @inline ( never)
389
- public static func _getScalar( _ handle: TensorHandle < Double > ) -> Double ? {
390
- return _TFGetScalarImpl ( handle)
391
- }
392
- @_silgen_name ( " __tf_hoistable_Double " ) @_optimize ( none) @inline ( never)
393
- public static func _hoistableClosure( _ fn: ( ) -> TensorHandle < Double > )
394
- -> TensorHandle < Double > {
395
- return fn ( )
396
- }
397
198
}
398
199
399
200
extension String : _TensorFlowDataTypeCompatible {
400
201
@inlinable
401
202
public static var tensorFlowDataType : TensorDataType {
402
203
return TensorDataType ( TF_STRING)
403
204
}
404
- @_silgen_name ( " __tf_get_scalar_or_die_String " ) @inline ( never)
405
- public static func _getScalarOrDie( _ handle: TensorHandle < String > ) -> String {
406
- return _TFGetScalarOrDieImpl ( handle)
407
- }
408
- @_silgen_name ( " __tf_get_scalar_String " ) @inline ( never)
409
- public static func _getScalar( _ handle: TensorHandle < String > ) -> String ? {
410
- return _TFGetScalarImpl ( handle)
411
- }
412
- @_silgen_name ( " __tf_hoistable_String " ) @_optimize ( none) @inline ( never)
413
- public static func _hoistableClosure( _ fn: ( ) -> TensorHandle < String > )
414
- -> TensorHandle < String > {
415
- return fn ( )
416
- }
417
205
}
0 commit comments