@@ -34,9 +34,10 @@ public extension TensorFlowScalar {
34
34
}
35
35
36
36
public extension Tensor {
37
/// Unpacks the given dimension of a rank-`R` tensor into multiple rank-`(R-1)` tensors.
/// Unpacks `N` tensors from this tensor by chipping it along the `axis` dimension, where `N`
/// is inferred from this tensor's shape. For example, given a tensor with shape
/// `[A, B, C, D]`:
///
/// - If `axis == 0` then the `i`-th tensor in the returned array is the slice
///   `self[i, :, :, :]` and each tensor in that array will have shape `[B, C, D]`.
///
/// - Parameters:
///   - axis: Dimension along which to unstack. Negative values wrap around.
///
/// - Precondition: `axis` must be in the range `[-rank, rank)`, where `rank` is the rank of
///   the provided tensors.
///
/// - Returns: Array containing the unstacked tensors.
@inlinable
@differentiable(vjp: _vjpUnstacked(alongAxis:) where Scalar: TensorFlowFloatingPoint)
func unstacked(alongAxis axis: Int = 0) -> [Tensor] {
    // Normalize a negative axis into the canonical `[0, rank)` range so that the `num`
    // and `axis` arguments handed to the raw op agree with each other.
    let resolvedAxis: Int
    if axis < 0 {
        resolvedAxis = axis + rank
    } else {
        resolvedAxis = axis
    }
    return Raw.unpack(value: self, num: Int64(shape[resolvedAxis]), axis: Int64(resolvedAxis))
}
63
65
64
66
/// Splits a tensor into multiple tensors. The tensor is split along dimension `axis` into
/// `count` smaller tensors of equal size.
///
/// - Parameters:
///   - count: Number of pieces to split this tensor into.
///   - axis: The dimension along which to split this tensor. Negative values wrap around.
///
/// - Precondition: `count` must divide the size of dimension `axis` evenly.
/// - Precondition: `axis` must be in the range `[-rank, rank)`, where `rank` is the rank of
///   the provided tensors.
///
/// - Returns: An array containing the tensors parts.
@inlinable
@differentiable(vjp: _vjpSplit(count:alongAxis:) where Scalar: TensorFlowFloatingPoint)
func split(count: Int, alongAxis axis: Int = 0) -> [Tensor] {
    // The raw op takes the split dimension as a scalar tensor.
    let splitDimension = Tensor<Int32>(Int32(axis))
    return Raw.split(splitDim: splitDimension, value: self, numSplit: Int64(count))
}
92
93
93
94
/// Splits a tensor into multiple tensors. The tensor is split into `sizes.shape[0]` pieces.
@@ -109,16 +110,16 @@ public extension Tensor {
109
110
/// - axis: Dimension along which to split this tensor. Negative values wrap around.
110
111
///
111
112
/// - Precondition: The values in `sizes` must add up to the size of dimension `axis`.
112
- /// - Precondition: `axis` must be in the range `[-rank, rank)`, where `rank` is the rank of the
113
- /// provided tensors.
113
+ /// - Precondition: `axis` must be in the range `[-rank, rank)`, where `rank` is the rank of
114
+ /// the provided tensors.
114
115
///
115
116
/// - Returns: Array containing the tensors parts.
116
117
@inlinable
117
118
@differentiable (
118
119
wrt: self ,
119
120
vjp: _vjpSplit ( sizes: alongAxis: ) where Scalar: TensorFlowFloatingPoint)
120
121
func split( sizes: Tensor < Int32 > , alongAxis axis: Int = 0 ) -> [ Tensor ] {
121
- return Raw . splitV (
122
+ Raw . splitV (
122
123
value: self ,
123
124
sizeSplits: sizes,
124
125
splitDim: Tensor < Int32 > ( Int32 ( axis) ) ,
@@ -136,15 +137,15 @@ public extension Tensor {
136
137
/// Returns a tiled tensor, constructed by repeating this tensor according to `multiples`.
@inlinable
@differentiable(wrt: self, vjp: _vjpTiled(multiples:) where Scalar: TensorFlowFloatingPoint)
func tiled(multiples: Tensor<Int32>) -> Tensor {
    return Raw.tile(self, multiples: multiples)
}
141
142
142
143
/// Reshape to the shape of the specified `Tensor`.
/// - Precondition: The number of scalars matches the new shape.
@inlinable
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
func reshaped<T>(like other: Tensor<T>) -> Tensor {
    // Delegate to the tensor-shape overload using the other tensor's dynamic shape.
    return reshaped(toShape: other.shapeTensor)
}
149
150
150
151
/// Reshape to the specified shape.
/// - Precondition: The number of scalars matches the new shape.
@inlinable
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
func reshaped(to newShape: TensorShape) -> Tensor {
    // TODO(TF-433): Remove workaround for differentiating `map`.
    // The immediately-applied closure keeps the `map` out of the differentiated region.
    let dimensions = { newShape.dimensions.map(Int32.init) }()
    return reshaped(toShape: Tensor<Int32>(dimensions))
}
158
159
159
160
/// Reshape to the specified `Tensor` representing a shape.
/// - Precondition: The number of scalars matches the new shape.
@inlinable
@differentiable(wrt: self, vjp: _vjpReshaped(toShape:) where Scalar: TensorFlowFloatingPoint)
func reshaped(toShape newShape: Tensor<Int32>) -> Tensor {
    return Raw.reshape(self, shape: newShape)
}
168
167
169
168
/// Return a copy of the tensor collapsed into a 1-D `Tensor`, in row-major order.
@inlinable
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
func flattened() -> Tensor {
    // A single `-1` dimension lets the reshape infer the full scalar count.
    return reshaped(to: [-1])
}
175
174
176
175
/// Returns a shape-expanded `Tensor`, with a dimension of 1 inserted at the specified shape
/// indices.
@inlinable
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
func expandingShape(at axes: Int...) -> Tensor {
    // Forward the variadic arguments to the array-taking overload.
    return expandingShape(at: axes)
}
183
182
184
183
/// Returns a rank-lifted `Tensor` with a leading dimension of 1.
@inlinable
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
func rankLifted() -> Tensor {
    // Equivalent to expanding the shape at index 0.
    return expandingShape(at: 0)
}
200
199
201
200
/// Removes the specified dimensions of size 1 from the shape of a tensor. If no dimensions are
/// specified, then all dimensions of size 1 will be removed.
@inlinable
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
func squeezingShape(at axes: Int...) -> Tensor {
    // Forward the variadic arguments to the array-taking overload.
    return squeezingShape(at: axes)
}
208
207
209
208
/// Removes the specified dimensions of size 1 from the shape of a tensor. If no dimensions are
/// specified, then all dimensions of size 1 will be removed.
@inlinable
@differentiable(wrt: self, vjp: _vjpSqueezingShape(at:) where Scalar: TensorFlowFloatingPoint)
func squeezingShape(at axes: [Int]) -> Tensor {
    // The raw op expects 32-bit dimension indices.
    let squeezeDimensions = axes.map(Int32.init)
    return Raw.squeeze(self, squeezeDims: squeezeDimensions)
}
216
215
}
217
216
@@ -225,10 +224,8 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {
225
224
}
226
225
227
226
@inlinable
228
- func _vjpTiled(
229
- multiples: Tensor < Int32 >
230
- ) -> ( Tensor , ( Tensor ) -> Tensor ) {
231
- return ( tiled ( multiples: multiples) , { [ shape = shapeTensor] v in
227
+ func _vjpTiled( multiples: Tensor < Int32 > ) -> ( Tensor , ( Tensor ) -> Tensor ) {
228
+ ( tiled ( multiples: multiples) , { [ shape = shapeTensor] v in
232
229
let splitShape = Tensor < Int32 > ( stacking: [ multiples, shape] ) . transposed ( ) . flattened ( )
233
230
let axes = Tensor < Int32 > ( rangeFrom: 0 , to: Int32 ( splitShape. scalarCount) , stride: 2 )
234
231
return v. reshaped ( toShape: splitShape) . sum ( squeezingAxes: axes)
0 commit comments