@@ -335,6 +335,107 @@ public extension Tensor {
335
335
    /// Concatenation operator.
    ///
    /// Returns `lhs` and `rhs` joined into a single tensor by forwarding to
    /// `concatenated(with:)` with its default axis.
    ///
    /// - Parameters:
    ///   - lhs: The tensor whose entries come first in the result.
    ///   - rhs: The tensor whose entries follow those of `lhs`.
    /// - Returns: The concatenated tensor.
    static func ++ (lhs: Tensor, rhs: Tensor) -> Tensor {
        return lhs.concatenated(with: rhs)
    }
338
+
339
+ /// Gathers slices of this tensor at `indices` along the `axis` dimension.
340
+ ///
341
+ /// For 0-D (scalar) `indices`:
342
+ /// ```
343
+ /// result[p_0, ..., p_{axis-1},
344
+ /// p_{axis + 1}, ..., p_{N-1}] =
345
+ /// self[p_0, ..., p_{axis-1},
346
+ /// indices,
347
+ /// p_{axis + 1}, ..., p_{N-1}]
348
+ /// ```
349
+ ///
350
+ /// For 1-D (vector) `indices`:
351
+ /// ```
352
+ /// result[p_0, ..., p_{axis-1},
353
+ /// i,
354
+ /// p_{axis + 1}, ..., p_{N-1}] =
355
+ /// self[p_0, ..., p_{axis-1},
356
+ /// indices[i],
357
+ /// p_{axis + 1}, ..., p_{N-1}]
358
+ /// ```
359
+ ///
360
+ /// In the general case, produces a resulting tensor where:
361
+ /// ```
362
+ /// result[p_0, ..., p_{axis-1},
363
+ /// i_{batch\_dims}, ..., i_{M-1},
364
+ /// p_{axis + 1}, ..., p_{N-1}] =
365
+ /// self[p_0, ..., p_{axis-1},
366
+ /// indices[i_0, ..., i_{M-1}],
367
+ /// p_{axis + 1}, ..., p_{N-1}]
368
+ /// ```
369
+ /// where `N = self.rank` and `M = indices.rank`.
370
+ ///
371
+ /// The shape of the resulting tensor is:
372
+ /// `self.shape[..<axis] + indices.shape + self.shape[(axis + 1)...]`.
373
+ ///
374
+ /// - Note: On CPU, if an out-of-range index is found, an error is thrown. On GPU, if an
375
+ /// out-of-range index is found, a 0 is stored in the corresponding output values.
376
+ ///
377
+ /// - Parameters:
378
+ /// - indices: Contains the indices to gather at.
379
+ /// - axis: Dimension along which to gather. Negative values wrap around.
380
+ ///
381
+ /// - Precondition: `axis` must be in the range `[-rank, rank)`.
382
+ ///
383
+ /// - Returns: The gathered tensor.
384
+ @inlinable
385
+ @differentiable ( wrt: self , vjp: _vjpGathering where Scalar : TensorFlowFloatingPoint)
386
+ func gathering( atIndices indices: Tensor < Int32 > , alongAxis axis: Int = 0 ) -> Tensor {
387
+ return Raw . gatherV2 ( params: self , indices: indices, axis: Tensor < Int32 > ( Int32 ( axis) ) )
388
+ }
389
+
390
+ /// Gathers values from this tensor according to the provided boolean mask.
391
+ ///
392
+ /// For example:
393
+ /// ```
394
+ /// // 1-D example
395
+ /// // tensor is [0, 1, 2, 3]
396
+ /// // mask is [true, false, true, false]
397
+ /// tensor.gathering(where: mask) // is [0, 2]
398
+ ///
399
+ /// // 2-D example
400
+ /// // tensor is [[1, 2], [3, 4], [5, 6]]
401
+ /// // mask is [true, false, true]
402
+ /// tensor.gathering(where: mask) // is [[1, 2], [5, 6]]
403
+ /// ```
404
+ ///
405
+ /// In general, `0 < mask.rank = K <= tensor.rank`, and the `mask`'s shape must match the first
406
+ /// K dimensions of the `tensor`'s shape. We then have:
407
+ /// `tensor.gathering(where: mask)[i, j1, ..., jd] = tensor[i1, ..., iK, j1, ..., jd]`, where
408
+ /// `[i1, ..., iK]` is the `i`th `true` entry of `mask` (row-major order).
409
+ ///
410
+ /// The `axis` could be used with `mask` to indicate the axis to mask from. In that case,
411
+ /// `axis + mask.rank <= tensor.rank` and the `mask``'s shape must match the first
412
+ /// `axis + mask.rank` dimensions of the `tensor`'s shape.
413
+ ///
414
+ /// - Parameters:
415
+ /// - mask: K-D boolean tensor, where `K <= self.rank`.
416
+ /// - axis: 0-D integer tensor representing the axis in `self` to mask from, where
417
+ /// `K + axis <= self.rank`.
418
+ ///
419
+ /// - Precondition: The `mask` cannot be a scalar: `mask.rank != 0`.
420
+ ///
421
+ /// - Returns: `(self.rank - K + 1)`-dimensional tensor populated by entries in this tensor
422
+ /// corresponding to `true` values in `mask`.
423
+ @inlinable
424
+ // @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
425
+ func gathering( where mask: Tensor < Bool > , alongAxis axis: Int = 0 ) -> Tensor {
426
+ precondition ( mask. rank != 0 , " The boolean mask cannot be a scalar. " )
427
+ // TODO: Remove once control flow AD is supported.
428
+ let rank = self . rank
429
+ let posAxis = { axis < 0 ? axis + rank : axis } ( )
430
+ let leadingSize = shapeTensor [ posAxis ..< posAxis + mask. rank] . product ( ) . rankLifted ( )
431
+ let reshapedTensor = reshaped (
432
+ toShape: Tensor < Int32 > ( concatenating: [
433
+ shapeTensor [ ..< posAxis] ,
434
+ leadingSize,
435
+ shapeTensor [ ( posAxis + mask. rank) ... ] ] ) )
436
+ let indices = Tensor < Int32 > ( mask. flattened ( ) . nonZeroIndices ( ) . squeezingShape ( at: 1 ) )
437
+ return reshapedTensor. gathering ( atIndices: indices, alongAxis: posAxis)
438
+ }
338
439
}
339
440
340
441
internal extension Tensor where Scalar: TensorFlowFloatingPoint {
@@ -375,6 +476,103 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {
375
476
return ( gradients [ 0 ] , gradients [ 1 ] )
376
477
} )
377
478
}
479
+
480
    /// Vector-Jacobian product (pullback) for `gathering(atIndices:alongAxis:)`.
    ///
    /// Returns the gather result together with a pullback that accumulates incoming
    /// gradient slices back into the source positions they were gathered from, via
    /// `Raw.unsortedSegmentSum` — so indices that appear multiple times have their
    /// gradient contributions summed.
    ///
    /// - Parameters:
    ///   - indices: The indices originally gathered at.
    ///   - axis: Dimension along which the gather was performed. Negative values wrap around.
    /// - Returns: The gather result and its pullback.
    @inlinable
    func _vjpGathering(
        atIndices indices: Tensor<Int32>,
        alongAxis axis: Int = 0
    ) -> (Tensor, (Tensor) -> Tensor) {
        let result = gathering(atIndices: indices, alongAxis: axis)
        // Normalize a negative axis into `[0, rank)`.
        let posAxis = axis < 0 ? axis + rank : axis

        // We have a fast gradient implementation for the case when `posAxis == 0`:
        // the gather dimension is already leading, so no transposition is needed.
        if posAxis == 0 {
            return (result, { [shape = shapeTensor] v in
                let indicesCount = indices.scalarCountTensor.rankLifted()
                let valuesShape = Tensor<Int32>(concatenating: [indicesCount, shape[1...]])
                let values = v.reshaped(toShape: valuesShape)
                let valueIndices = indices.reshaped(toShape: indicesCount)
                // Sum the gradient slices that map to the same source index.
                return Raw.unsortedSegmentSum(
                    data: values,
                    segmentIds: valueIndices,
                    numSegments: shape[0])
            })
        }

        return (result, { [shape = shapeTensor] v in
            // Total number of gathered indices, as a rank-1 shape component.
            let indicesSize = Tensor<Int32>(Int32(indices.scalarCount)).rankLifted()
            // Dimensions before / after the gather axis in the source tensor.
            let outerShape = shape[..<posAxis]
            let outerSize = outerShape.scalarCount
            let innerShape = shape[(posAxis + 1)...]
            let innerSize = innerShape.scalarCount
            // Dimension numbers `0 ..< outerSize` (before the gather axis).
            let outerIndices = Tensor<Int32>(rangeFrom: 0, to: Int32(outerSize), stride: 1)
            // Dimension numbers after the gather axis in the reshaped gradient below,
            // where the (flattened) gather dimension sits at position `outerSize`.
            let innerIndices = Tensor<Int32>(
                rangeFrom: Int32(outerSize) + 1,
                to: Int32(outerSize) + 1 + Int32(innerSize),
                stride: 1)
            // Reshape the incoming gradient to `[outerShape, indicesSize, innerShape]`,
            // collapsing the (possibly multi-dimensional) indices dimensions into one.
            let valuesShape = Tensor<Int32>(concatenating: [outerShape, indicesSize, innerShape])
            let values = v.reshaped(toShape: valuesShape)
            let valueIndices = indices.reshaped(toShape: indicesSize)

            // We need to sum up every slice `values[..., i, ....]` corresponding to
            // `tensor[..., indices[i], ...]`. Since `unsortedSegmentSum` does not support an axis
            // parameter, we transpose the gather dimension to the front, then use
            // `unsortedSegmentSum` to build a `[gatherAxis, outerAxes, innerAxes]` tensor with all
            // the gradients affecting each index in `gatherAxis` summed up.
            let permutations = Tensor<Int32>(concatenating: [
                Tensor<Int32>([Int32(outerSize)]),
                outerIndices,
                innerIndices])
            let transposedValues = values.transposed(withPermutations: permutations)
            let gradient = Raw.unsortedSegmentSum(
                data: transposedValues,
                segmentIds: valueIndices,
                numSegments: shape[posAxis])

            // Finally, we invert the above transpose operation by moving dimension 0 back to its
            // original position.
            let inversePermutations = Tensor<Int32>(concatenating: [
                outerIndices + 1,
                Tensor<Int32>([0]),
                innerIndices])
            return gradient.transposed(withPermutations: inversePermutations)
        })
    }
541
+ }
542
+
543
public extension Tensor {
    /// Returns the locations of non-zero / true values in this tensor.
    ///
    /// The coordinates come back as a 2-D tensor: the first dimension (rows) counts the
    /// non-zero elements, and the second dimension (columns) holds the coordinates of each
    /// one. Note that the output's shape therefore varies with how many true values this
    /// tensor contains. Indices are emitted in row-major order.
    ///
    /// For example:
    /// ```
    /// // 'input' is [[true, false], [true, false]]
    /// // 'input' has 2 true values and so the output has 2 rows.
    /// // 'input' has rank of 2, and so the second dimension of the output has size 2.
    /// input.nonZeroIndices() // is [[0, 0], [1, 0]]
    ///
    /// // 'input' is [[[ true, false], [ true, false]],
    /// //             [[false,  true], [false,  true]],
    /// //             [[false, false], [false,  true]]]
    /// // 'input' has 5 true values and so the output has 5 rows.
    /// // 'input' has rank 3, and so the second dimension of the output has size 3.
    /// input.nonZeroIndices() // is [[0, 0, 0],
    /// //                            [0, 1, 0],
    /// //                            [1, 0, 1],
    /// //                            [1, 1, 1],
    /// //                            [2, 1, 1]]
    /// ```
    ///
    /// - Returns: A tensor with shape `(num_true, rank(condition))`.
    @inlinable
    func nonZeroIndices() -> Tensor<Int64> {
        // Delegates to the raw `Where` op, which emits coordinates in row-major order.
        let coordinates: Tensor<Int64> = Raw.where_(self)
        return coordinates
    }
}
379
577
380
578
//===------------------------------------------------------------------------------------------===//
0 commit comments