
Commit 145fff4

Add topk xla op.

1 parent: b88bbf9

4 files changed: +33 -10 lines

Sources/CX10/xla_tensor_wrapper.cc

Lines changed: 5 additions & 0 deletions
@@ -651,6 +651,11 @@ OpaqueXLATensor* XLATensor_tan(OpaqueXLATensor* a) {
 OpaqueXLATensor* XLATensor_tanh(OpaqueXLATensor* a) {
   return new XLATensor(XLATensor::tanh(*a));
 }
+OpaqueXLATensor_pair XLATensor_topk(OpaqueXLATensor* a, int64_t k,
+                                    int64_t dim, bool largest) {
+  auto result = XLATensor::topk(*a, k, dim, largest, false);
+  return {new XLATensor(std::get<0>(result)), new XLATensor(std::get<1>(result))};
+}
 OpaqueXLATensor* XLATensor_tf_Conv(OpaqueXLATensor* input,
                                    OpaqueXLATensor* filter, bool depthwise,
                                    Int64ArrayRef strides, TFPadding padding,

Sources/CX10/xla_tensor_wrapper.h

Lines changed: 2 additions & 0 deletions
@@ -374,6 +374,8 @@ XLA_API OpaqueXLATensor* XLATensor_sum(OpaqueXLATensor* a, Int64ArrayRef dims,
                                        Optional_XLAScalarType dtype);
 XLA_API OpaqueXLATensor* XLATensor_tan(OpaqueXLATensor* a);
 XLA_API OpaqueXLATensor* XLATensor_tanh(OpaqueXLATensor* a);
+XLA_API OpaqueXLATensor_pair XLATensor_topk(OpaqueXLATensor* a, int64_t k,
+                                            int64_t dim, bool largest);
 XLA_API OpaqueXLATensor*
 XLATensor_tf_Conv(OpaqueXLATensor* input, OpaqueXLATensor* filter, bool depthwise,
                   Int64ArrayRef strides, enum TFPadding padding,

Sources/x10/swift_bindings/XLATensor.swift

Lines changed: 10 additions & 2 deletions
@@ -407,7 +407,9 @@ extension XLATensor {
     return XLATensor(_handle: XLATensor_div(a.handle, b.handle))
   }

-  static func dynamic_slice(_ base: XLATensor, _ start_indices: [XLATensor], _ slice_shape: [Int64]) -> XLATensor {
+  static func dynamic_slice(_ base: XLATensor, _ start_indices: [XLATensor], _ slice_shape: [Int64])
+    -> XLATensor
+  {
     start_indices.withArrayRef { start_indices in
       slice_shape.withArrayRef { slice_shape in
         return XLATensor(_handle: XLATensor_dynamic_slice(base.handle, start_indices, slice_shape))
@@ -761,7 +763,7 @@ extension XLATensor {
   }

   static func replica_id(_ device: Device) -> XLATensor {
-    return XLATensor(_handle: XLATensor_replica_id(device.cdevice));
+    return XLATensor(_handle: XLATensor_replica_id(device.cdevice))
   }

   static func resize_value(_ value: XLATensor, _ dims: [Int64]) -> XLATensor {
@@ -874,6 +876,12 @@ extension XLATensor {
     return XLATensor(_handle: XLATensor_tanh(a.handle))
   }

+  static func topk(_ a: XLATensor, k: Int64, dim: Int64, largest: Bool) -> (XLATensor, XLATensor) {
+    defer { _fixLifetime(a) }
+    let output = XLATensor_topk(a.handle, k, dim, largest)
+    return (XLATensor(_handle: output.x), XLATensor(_handle: output.y))
+  }
+
   static func tf_Conv(
     _ input: XLATensor, _ filter: XLATensor, _ depthwise: Bool, _ strides: [Int64],
     _ padding: TFPadding, _ explicit_paddings: [Int64],
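The binding above pins `a` with `_fixLifetime` so ARC cannot release the Swift wrapper (and with it the underlying handle) before the C call returns, then rewraps the two raw handles from the returned pair; judging by how `_RawXLA.topk` consumes them below, `output.x` holds the values and `output.y` the indices. A minimal usage sketch, assuming an existing `XLATensor` named `t` on an XLA device:

// Sketch only: `t` is a hypothetical XLATensor.
let (values, indices) = XLATensor.topk(t, k: 3, dim: 0, largest: true)
// `values` and `indices` are fresh XLATensor wrappers around the two
// handles delivered through OpaqueXLATensor_pair.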

Sources/x10/swift_bindings/apis/RawOpsManual.swift

Lines changed: 16 additions & 8 deletions
@@ -1425,7 +1425,8 @@ public enum _RawXLA {
     _ startIndices: [Tensor<Int32>],
     _ sliceShape: [Int64]
   ) -> Tensor<T> {
-    return Tensor(_xla: XLATensor.dynamic_slice(base.xlaTensor, startIndices.map { $0.xlaTensor }, sliceShape))
+    return Tensor(
+      _xla: XLATensor.dynamic_slice(base.xlaTensor, startIndices.map { $0.xlaTensor }, sliceShape))
   }

   public static func dynamicUpdateSlice<T: TensorFlowNumeric>(
@@ -3042,17 +3043,17 @@ public enum _RawXLA {
   ///
   /// NOTE `tf.reverse` has now changed behavior in preparation for 1.0.
   /// `tf.reverse_v2` is currently an alias that will be deprecated before TF 1.0.
-  /// 
+  ///
   /// Given a `tensor`, and a `int32` tensor `axis` representing the set of
   /// dimensions of `tensor` to reverse. This operation reverses each dimension
   /// `i` for which there exists `j` s.t. `axis[j] == i`.
-  /// 
+  ///
   /// `tensor` can have up to 8 dimensions. The number of dimensions specified
   /// in `axis` may be 0 or more entries. If an index is specified more than
   /// once, a InvalidArgument error is raised.
-  /// 
+  ///
   /// For example:
-  /// 
+  ///
   /// ```
   /// # tensor 't' is [[[[ 0,  1,  2,  3],
   /// #                  [ 4,  5,  6,  7],
@@ -3061,23 +3062,23 @@ public enum _RawXLA {
   /// #                  [16, 17, 18, 19],
   /// #                  [20, 21, 22, 23]]]]
   /// # tensor 't' shape is [1, 2, 3, 4]
-  /// 
+  ///
   /// # 'dims' is [3] or 'dims' is [-1]
   /// reverse(t, dims) ==> [[[[ 3,  2,  1,  0],
   ///                         [ 7,  6,  5,  4],
   ///                         [ 11, 10,  9,  8]],
   ///                        [[15, 14, 13, 12],
   ///                         [19, 18, 17, 16],
   ///                         [23, 22, 21, 20]]]]
-  /// 
+  ///
   /// # 'dims' is '[1]' (or 'dims' is '[-3]')
   /// reverse(t, dims) ==> [[[[12, 13, 14, 15],
   ///                         [16, 17, 18, 19],
   ///                         [20, 21, 22, 23]
   ///                        [[ 0,  1,  2,  3],
   ///                         [ 4,  5,  6,  7],
   ///                         [ 8,  9, 10, 11]]]]
-  /// 
+  ///
   /// # 'dims' is '[2]' (or 'dims' is '[-2]')
   /// reverse(t, dims) ==> [[[[8, 9, 10, 11],
   ///                         [4, 5, 6, 7],
@@ -4259,6 +4260,13 @@ public enum _RawXLA {
     return Tensor(_xla: XLATensor.tanh(x.xlaTensor))
   }

+  public static func topk<T: FloatingPoint & TensorFlowScalar>(
+    _ a: Tensor<T>, k: Int64, dim: Int64, largest: Bool
+  ) -> (Tensor<T>, Tensor<Int64>) {
+    let (r0, r1) = XLATensor.topk(a.xlaTensor, k: k, dim: dim, largest: largest)
+    return (Tensor(_xla: r0), Tensor(_xla: r1))
+  }
+
   /// Assign `value` to the sliced l-value reference of `input`.
   ///
   /// The values of `value` are assigned to the positions in the tensor `input` that
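Together the four files plumb top-k from the C++ tensor library up to the raw-ops surface. A hedged end-to-end sketch follows; the `Device.defaultXLA` placement and the array initializer are assumptions about the surrounding Swift for TensorFlow API, not part of this commit:

import TensorFlow

// Pick the 3 largest entries along dimension 0 of a 1-D tensor.
let x = Tensor<Float>([5, 1, 4, 2, 3], on: Device.defaultXLA)
let (values, indices) = _RawXLA.topk(x, k: 3, dim: 0, largest: true)
// values:  the top-3 scalars (5, 4, 3 here)
// indices: their Int64 positions along dim 0 (0, 2, 4)

Note that the C++ wrapper passes `false` as the final argument to `XLATensor::topk`; assuming PyTorch-style top-k semantics, that is the `sorted` flag, so callers should not rely on the results coming back in sorted order.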
