@@ -125,6 +125,20 @@ public extension Tensor {
125
125
numSplit: Int64 ( sizes. shape [ 0 ] ) )
126
126
}
127
127
128
+ /// Returns a tiled tensor, constructed by tiling this tensor.
129
+ ///
130
+ /// This constructor creates a new tensor by replicating this tensor `multiples` times. The
131
+ /// constructed tensor's `i`'th dimension has `self.shape[i] * multiples[i]` elements, and the
132
+ /// values of this tensor are replicated `multiples[i]` times along the `i`'th dimension. For
133
+ /// example, tiling `[a b c d]` by `[2]` produces `[a b c d a b c d]`.
134
+ ///
135
+ /// - Precondition: The shape of `multiples` must be `[tensor.rank]`.
136
+ @inlinable
137
+ @differentiable ( wrt: self , vjp: _vjpTiled ( multiples: ) where Scalar: TensorFlowFloatingPoint)
138
+ func tiled( multiples: Tensor < Int32 > ) -> Tensor {
139
+ return Raw . tile ( self , multiples: multiples)
140
+ }
141
+
128
142
/// Reshape to the shape of the specified `Tensor`.
129
143
/// - Precondition: The number of scalars matches the new shape.
130
144
@inlinable
@@ -210,6 +224,17 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {
210
224
return ( result, { v in Tensor ( stacking: v. base, alongAxis: axis) } )
211
225
}
212
226
227
+ @inlinable
228
+ func _vjpTiled(
229
+ multiples: Tensor < Int32 >
230
+ ) -> ( Tensor , ( Tensor ) -> Tensor ) {
231
+ return ( tiled ( multiples: multiples) , { [ shape = shapeTensor] v in
232
+ let splitShape = Tensor < Int32 > ( stacking: [ multiples, shape] ) . transposed ( ) . flattened ( )
233
+ let axes = Tensor < Int32 > ( rangeFrom: 0 , to: Int32 ( splitShape. scalarCount) , stride: 2 )
234
+ return v. reshaped ( toShape: splitShape) . sum ( squeezingAxes: axes)
235
+ } )
236
+ }
237
+
213
238
@inlinable
214
239
func _vjpSplit(
215
240
count: Int ,
0 commit comments