@@ -34,6 +34,18 @@ public extension TensorFlowScalar {
34
34
}
35
35
36
36
public extension Tensor {
37
+ /// Helper function that assess if `axis` is in the range `[-rank, rank)`, where `rank` is the rank of
38
+ /// the provided tensors.
39
+ @usableFromInline
40
+ internal func preconditionAxis( _ axis: Int ) {
41
+ precondition (
42
+ axis >= - rank && axis < rank,
43
+ """
44
+ The axis must be in the range [-rank, rank)
45
+ of the provided tensors.
46
+ """ )
47
+ }
48
+
37
49
/// Unpacks the given dimension of a rank-`R` tensor into multiple rank-`(R-1)` tensors.
38
50
/// Unpacks `N` tensors from this tensor by chipping it along the `axis` dimension, where `N`
39
51
/// is inferred from this tensor's shape. For example, given a tensor with shape
@@ -59,6 +71,7 @@ public extension Tensor {
59
71
@inlinable
60
72
@differentiable ( vjp: _vjpUnstacked ( alongAxis: ) where Scalar: TensorFlowFloatingPoint)
61
73
func unstacked( alongAxis axis: Int = 0 ) -> [ Tensor ] {
74
+ preconditionAxis ( axis)
62
75
let posAxis = axis < 0 ? axis + rank : axis
63
76
return _Raw. unpack ( value: self , num: Int64 ( shape [ posAxis] ) , axis: Int64 ( posAxis) )
64
77
}
@@ -88,7 +101,11 @@ public extension Tensor {
88
101
@inlinable
89
102
@differentiable ( vjp: _vjpSplit ( count: alongAxis: ) where Scalar: TensorFlowFloatingPoint)
90
103
func split( count: Int , alongAxis axis: Int = 0 ) -> [ Tensor ] {
91
- _Raw. split ( splitDim: Tensor < Int32 > ( Int32 ( axis) ) , value: self , numSplit: Int64 ( count) )
104
+ preconditionAxis ( axis)
105
+ precondition (
106
+ shapeTensor [ axis] . scalarized ( ) % Int32( count) == 0 ,
107
+ " Number of ways to split should evenly divide the split dimension. " )
108
+ return _Raw. split ( splitDim: Tensor < Int32 > ( Int32 ( axis) ) , value: self , numSplit: Int64 ( count) )
92
109
}
93
110
94
111
/// Splits a tensor into multiple tensors. The tensor is split into `sizes.shape[0]` pieces.
@@ -119,7 +136,11 @@ public extension Tensor {
119
136
wrt: self ,
120
137
vjp: _vjpSplit ( sizes: alongAxis: ) where Scalar: TensorFlowFloatingPoint)
121
138
func split( sizes: Tensor < Int32 > , alongAxis axis: Int = 0 ) -> [ Tensor ] {
122
- _Raw. splitV (
139
+ preconditionAxis ( axis)
140
+ precondition (
141
+ shapeTensor [ axis] == sizes. sum ( ) ,
142
+ " The values in sizes must add up to the size of dimension axis. " )
143
+ return _Raw. splitV (
123
144
value: self ,
124
145
sizeSplits: sizes,
125
146
splitDim: Tensor < Int32 > ( Int32 ( axis) ) ,
@@ -133,11 +154,16 @@ public extension Tensor {
133
154
/// values of this tensor are replicated `multiples[i]` times along the `i`'th dimension. For
134
155
/// example, tiling `[a b c d]` by `[2]` produces `[a b c d a b c d]`.
135
156
///
157
+ /// - Precondition: The expected `rank` of multiples must be `1`.
136
158
/// - Precondition: The shape of `multiples` must be `[tensor.rank]`.
137
159
@inlinable
138
160
@differentiable ( wrt: self , vjp: _vjpTiled ( multiples: ) where Scalar: TensorFlowFloatingPoint)
139
161
func tiled( multiples: Tensor < Int32 > ) -> Tensor {
140
- _Raw. tile ( self , multiples: multiples)
162
+ precondition ( multiples. rank == 1 , " The expected rank of multiples must be 1. " )
163
+ precondition (
164
+ rank == multiples. shapeTensor. scalarized ( ) ,
165
+ " The shape of multiples must be [tensor.rank]. " )
166
+ return _Raw. tile ( self , multiples: multiples)
141
167
}
142
168
143
169
/// Reshape to the shape of the specified `Tensor`.
@@ -162,7 +188,23 @@ public extension Tensor {
162
188
@inlinable
163
189
@differentiable ( wrt: self , vjp: _vjpReshaped ( toShape: ) where Scalar: TensorFlowFloatingPoint)
164
190
func reshaped( toShape newShape: Tensor < Int32 > ) -> Tensor {
165
- _Raw. reshape ( self , shape: newShape)
191
+ let totalNegative = newShape. scalars. filter ( { $0 == - 1 } ) . count
192
+ let positiveShapeSizes = newShape. scalars. filter ( { $0 > 0 } )
193
+ let newShapeScalarCount = positiveShapeSizes. reduce ( 1 , { $0 * $1} )
194
+
195
+ precondition ( totalNegative <= 1 , " Only one input size may be -1. " )
196
+
197
+ if totalNegative == 1 {
198
+ precondition (
199
+ scalarCount % Int( newShapeScalarCount) == 0 ,
200
+ " The number of scalars must be a multiple of the new shape. " )
201
+ } else {
202
+ precondition (
203
+ scalarCount == newShapeScalarCount,
204
+ " The number of scalars must match the new shape. " )
205
+ }
206
+
207
+ return _Raw. reshape ( self , shape: newShape)
166
208
}
167
209
168
210
/// Return a copy of the tensor collapsed into a 1-D `Tensor`, in row-major order.
@@ -408,7 +450,8 @@ public extension Tensor {
408
450
atIndices indices: Tensor < Index > ,
409
451
alongAxis axis: Int = 0
410
452
) -> Tensor {
411
- _Raw. gatherV2 ( params: self , indices: indices, axis: Tensor < Int32 > ( Int32 ( axis) ) )
453
+ preconditionAxis ( axis)
454
+ return _Raw. gatherV2 ( params: self , indices: indices, axis: Tensor < Int32 > ( Int32 ( axis) ) )
412
455
}
413
456
414
457
/// Returns slices of this tensor at `indices` along the `axis` dimension, while ignoring the
0 commit comments