Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 8f8ad5a

Browse files
eaplataniosrxwei
authored andcommitted
Makes 'Tensor.gathering' generic over 'Int32' and 'Int64'. (#272)
Added a 'TensorFlowIndex' protocol and make 'Int32' and 'Int64' conform to it.
1 parent e363592 commit 8f8ad5a

File tree

2 files changed

+20
-10
lines changed

2 files changed

+20
-10
lines changed

Sources/TensorFlow/Core/DataTypes.swift

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,13 @@ public typealias TensorFlowNumeric = TensorFlowScalar & Numeric
7070
public typealias TensorFlowSignedNumeric = TensorFlowScalar & SignedNumeric
7171
public typealias TensorFlowInteger = TensorFlowScalar & BinaryInteger
7272

73+
/// An integer data type that represents integer types which can be used as tensor indices in
74+
/// TensorFlow.
75+
public protocol TensorFlowIndex: TensorFlowInteger {}
76+
77+
extension Int32: TensorFlowIndex {}
78+
extension Int64: TensorFlowIndex {}
79+
7380
/// A floating-point data type that conforms to `Differentiable` and is compatible with TensorFlow.
7481
///
7582
/// - Note: `Tensor` conditionally conforms to `Differentiable` when the `Scalar` associated type

Sources/TensorFlow/Operators/Basic.swift

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,10 @@ public extension Tensor {
383383
/// - Returns: The gathered tensor.
384384
@inlinable
385385
@differentiable(wrt: self, vjp: _vjpGathering where Scalar : TensorFlowFloatingPoint)
386-
func gathering(atIndices indices: Tensor<Int32>, alongAxis axis: Int = 0) -> Tensor {
386+
func gathering<Index: TensorFlowIndex>(
387+
atIndices indices: Tensor<Index>,
388+
alongAxis axis: Int = 0
389+
) -> Tensor {
387390
return Raw.gatherV2(params: self, indices: indices, axis: Tensor<Int32>(Int32(axis)))
388391
}
389392

@@ -408,15 +411,15 @@ public extension Tensor {
408411
/// - Returns: The gathered tensor.
409412
@inlinable
410413
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
411-
func batchGathering(atIndices indices: Tensor<Int32>) -> Tensor {
414+
func batchGathering<Index: TensorFlowIndex>(atIndices indices: Tensor<Index>) -> Tensor {
412415
var batchIndices = indices
413-
var accumulated = Tensor<Int32>(ones: [])
414-
accumulated *= Swift.withoutDerivative(at: shapeTensor) { $0[1] }
416+
var accumulated = Tensor<Index>(ones: [])
417+
accumulated *= Swift.withoutDerivative(at: shapeTensor) { Tensor<Index>($0[1]) }
415418
let dValue = Swift.withoutDerivative(at: shapeTensor) { $0[0] }
416-
let dIndices = Tensor<Int32>(
417-
rangeFrom: Tensor<Int32>(zeros: []),
418-
to: dValue,
419-
stride: Tensor<Int32>(ones: [])
419+
let dIndices = Tensor<Index>(
420+
rangeFrom: Tensor<Index>(zeros: []),
421+
to: Tensor<Index>(dValue),
422+
stride: Tensor<Index>(ones: [])
420423
) * accumulated
421424
let dShape = Tensor<Int32>(concatenating: [
422425
dValue.rankLifted(),
@@ -519,8 +522,8 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {
519522
}
520523

521524
@inlinable
522-
func _vjpGathering(
523-
atIndices indices: Tensor<Int32>,
525+
func _vjpGathering<Index: TensorFlowIndex>(
526+
atIndices indices: Tensor<Index>,
524527
alongAxis axis: Int = 0
525528
) -> (Tensor, (Tensor) -> Tensor) {
526529
let result = gathering(atIndices: indices, alongAxis: axis)

0 commit comments

Comments
 (0)