Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 75bd243

Browse files
committed
Add topk XLA op.
1 parent b9f6c3b commit 75bd243

File tree

4 files changed

+29
-8
lines changed

4 files changed

+29
-8
lines changed

Sources/CX10/xla_tensor_wrapper.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -660,6 +660,11 @@ OpaqueXLATensor* XLATensor_tan(OpaqueXLATensor* a) {
660660
OpaqueXLATensor* XLATensor_tanh(OpaqueXLATensor* a) {
661661
return new XLATensor(XLATensor::tanh(*a));
662662
}
663+
// Selects the k largest (or, when `largest` is false, smallest) entries of
// `a` along dimension `dim`. Returns a pair of freshly heap-allocated
// XLATensor handles — the caller owns both and must release them.
OpaqueXLATensor_pair XLATensor_topk(OpaqueXLATensor* a, int64_t k,
                                    int64_t dim, bool largest) {
  // Trailing `false` is forwarded verbatim to XLATensor::topk — presumably
  // the `sorted` flag; confirm against the x10 tensor-methods declaration.
  auto values_and_indices = XLATensor::topk(*a, k, dim, largest, false);
  OpaqueXLATensor_pair pair;
  pair.x = new XLATensor(std::get<0>(values_and_indices));
  pair.y = new XLATensor(std::get<1>(values_and_indices));
  return pair;
}
663668
OpaqueXLATensor* XLATensor_tf_Conv(OpaqueXLATensor* input,
664669
OpaqueXLATensor* filter, bool depthwise,
665670
Int64ArrayRef strides, TFPadding padding,

Sources/CX10/xla_tensor_wrapper.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,8 @@ XLA_API OpaqueXLATensor* XLATensor_sum(OpaqueXLATensor* a, Int64ArrayRef dims,
376376
Optional_XLAScalarType dtype);
377377
XLA_API OpaqueXLATensor* XLATensor_tan(OpaqueXLATensor* a);
378378
XLA_API OpaqueXLATensor* XLATensor_tanh(OpaqueXLATensor* a);
379+
// Top-k along dimension `dim` of `a`: the k largest entries when `largest`
// is true, the k smallest otherwise. Both tensors in the returned pair are
// newly allocated and owned by the caller.
XLA_API OpaqueXLATensor_pair XLATensor_topk(OpaqueXLATensor* a, int64_t k,
                                            int64_t dim, bool largest);
379381
XLA_API OpaqueXLATensor*
380382
XLATensor_tf_Conv(OpaqueXLATensor* input, OpaqueXLATensor* filter, bool depthwise,
381383
Int64ArrayRef strides, enum TFPadding padding,

Sources/x10/swift_bindings/XLATensor.swift

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -887,6 +887,12 @@ extension XLATensor {
887887
return XLATensor(_handle: XLATensor_tanh(a.handle))
888888
}
889889

890+
static func topk(_ a: XLATensor, k: Int64, dim: Int64, largest: Bool) -> (XLATensor, XLATensor) {
891+
defer { _fixLifetime(a) }
892+
let output = XLATensor_topk(a.handle, k, dim, largest)
893+
return (XLATensor(_handle: output.x), XLATensor(_handle: output.y))
894+
}
895+
890896
static func tf_Conv(
891897
_ input: XLATensor, _ filter: XLATensor, _ depthwise: Bool, _ strides: [Int64],
892898
_ padding: TFPadding, _ explicit_paddings: [Int64],

Sources/x10/swift_bindings/apis/RawOpsManual.swift

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1425,7 +1425,8 @@ public enum _RawXLA {
14251425
_ startIndices: [Tensor<Int32>],
14261426
_ sliceShape: [Int64]
14271427
) -> Tensor<T> {
1428-
return Tensor(_xla: XLATensor.dynamic_slice(base.xlaTensor, startIndices.map { $0.xlaTensor }, sliceShape))
1428+
return Tensor(
1429+
_xla: XLATensor.dynamic_slice(base.xlaTensor, startIndices.map { $0.xlaTensor }, sliceShape))
14291430
}
14301431

14311432
public static func dynamicUpdateSlice<T: TensorFlowNumeric>(
@@ -3042,17 +3043,17 @@ public enum _RawXLA {
30423043
///
30433044
/// NOTE `tf.reverse` has now changed behavior in preparation for 1.0.
30443045
/// `tf.reverse_v2` is currently an alias that will be deprecated before TF 1.0.
3045-
///
3046+
///
30463047
/// Given a `tensor`, and a `int32` tensor `axis` representing the set of
30473048
/// dimensions of `tensor` to reverse. This operation reverses each dimension
30483049
/// `i` for which there exists `j` s.t. `axis[j] == i`.
3049-
///
3050+
///
30503051
/// `tensor` can have up to 8 dimensions. The number of dimensions specified
30513052
/// in `axis` may be 0 or more entries. If an index is specified more than
30523053
/// once, a InvalidArgument error is raised.
3053-
///
3054+
///
30543055
/// For example:
3055-
///
3056+
///
30563057
/// ```
30573058
/// # tensor 't' is [[[[ 0, 1, 2, 3],
30583059
/// # [ 4, 5, 6, 7],
@@ -3061,23 +3062,23 @@ public enum _RawXLA {
30613062
/// # [16, 17, 18, 19],
30623063
/// # [20, 21, 22, 23]]]]
30633064
/// # tensor 't' shape is [1, 2, 3, 4]
3064-
///
3065+
///
30653066
/// # 'dims' is [3] or 'dims' is [-1]
30663067
/// reverse(t, dims) ==> [[[[ 3, 2, 1, 0],
30673068
/// [ 7, 6, 5, 4],
30683069
/// [ 11, 10, 9, 8]],
30693070
/// [[15, 14, 13, 12],
30703071
/// [19, 18, 17, 16],
30713072
/// [23, 22, 21, 20]]]]
3072-
///
3073+
///
30733074
/// # 'dims' is '[1]' (or 'dims' is '[-3]')
30743075
/// reverse(t, dims) ==> [[[[12, 13, 14, 15],
30753076
/// [16, 17, 18, 19],
30763077
/// [20, 21, 22, 23]
30773078
/// [[ 0, 1, 2, 3],
30783079
/// [ 4, 5, 6, 7],
30793080
/// [ 8, 9, 10, 11]]]]
3080-
///
3081+
///
30813082
/// # 'dims' is '[2]' (or 'dims' is '[-2]')
30823083
/// reverse(t, dims) ==> [[[[8, 9, 10, 11],
30833084
/// [4, 5, 6, 7],
@@ -4259,6 +4260,13 @@ public enum _RawXLA {
42594260
return Tensor(_xla: XLATensor.tanh(x.xlaTensor))
42604261
}
42614262

4263+
public static func topk<T: FloatingPoint & TensorFlowScalar>(
4264+
_ a: Tensor<T>, k: Int64, dim: Int64, largest: Bool
4265+
) -> (Tensor<T>, Tensor<Int64>) {
4266+
let (r0, r1) = XLATensor.topk(a.xlaTensor, k: k, dim: dim, largest: largest)
4267+
return (Tensor(_xla: r0), Tensor(_xla: r1))
4268+
}
4269+
42624270
/// Assign `value` to the sliced l-value reference of `input`.
42634271
///
42644272
/// The values of `value` are assigned to the positions in the tensor `input` that

0 commit comments

Comments
 (0)