Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 8b699d9

Browse files
eaplataniosrxwei
authored andcommitted
Added support for a 'tiled(multiples:)' and its VJP. (#152)
1 parent 24fc2ba commit 8b699d9

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

Sources/TensorFlow/Operators/Basic.swift

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,20 @@ public extension Tensor {
125125
numSplit: Int64(sizes.shape[0]))
126126
}
127127

128+
/// Returns a tiled tensor, constructed by tiling this tensor.
129+
///
130+
/// This constructor creates a new tensor by replicating this tensor `multiples` times. The
131+
/// constructed tensor's `i`'th dimension has `self.shape[i] * multiples[i]` elements, and the
132+
/// values of this tensor are replicated `multiples[i]` times along the `i`'th dimension. For
133+
/// example, tiling `[a b c d]` by `[2]` produces `[a b c d a b c d]`.
134+
///
135+
/// - Precondition: The shape of `multiples` must be `[tensor.rank]`.
136+
@inlinable
137+
@differentiable(wrt: self, vjp: _vjpTiled(multiples:) where Scalar: TensorFlowFloatingPoint)
138+
func tiled(multiples: Tensor<Int32>) -> Tensor {
139+
return Raw.tile(self, multiples: multiples)
140+
}
141+
128142
/// Reshape to the shape of the specified `Tensor`.
129143
/// - Precondition: The number of scalars matches the new shape.
130144
@inlinable
@@ -210,6 +224,17 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {
210224
return (result, { v in Tensor(stacking: v.base, alongAxis: axis) })
211225
}
212226

227+
@inlinable
228+
func _vjpTiled(
229+
multiples: Tensor<Int32>
230+
) -> (Tensor, (Tensor) -> Tensor) {
231+
return (tiled(multiples: multiples), { [shape = shapeTensor] v in
232+
let splitShape = Tensor<Int32>(stacking: [multiples, shape]).transposed().flattened()
233+
let axes = Tensor<Int32>(rangeFrom: 0, to: Int32(splitShape.scalarCount), stride: 2)
234+
return v.reshaped(toShape: splitShape).sum(squeezingAxes: axes)
235+
})
236+
}
237+
213238
@inlinable
214239
func _vjpSplit(
215240
count: Int,

0 commit comments

Comments
 (0)