This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Add topk xla op. #1065

Merged (1 commit, Aug 20, 2020)
5 changes: 5 additions & 0 deletions Sources/CX10/xla_tensor_wrapper.cc
@@ -660,6 +660,11 @@ OpaqueXLATensor* XLATensor_tan(OpaqueXLATensor* a) {
OpaqueXLATensor* XLATensor_tanh(OpaqueXLATensor* a) {
  return new XLATensor(XLATensor::tanh(*a));
}
OpaqueXLATensor_pair XLATensor_topk(OpaqueXLATensor* a, int64_t k,
                                    int64_t dim, bool largest) {
  // XLATensor::topk yields a (values, indices) tuple; the trailing `false`
  // is presumably its sorted flag, left off here.
  auto result = XLATensor::topk(*a, k, dim, largest, false);
  return {new XLATensor(std::get<0>(result)),
          new XLATensor(std::get<1>(result))};
}
OpaqueXLATensor* XLATensor_tf_Conv(OpaqueXLATensor* input,
                                   OpaqueXLATensor* filter, bool depthwise,
                                   Int64ArrayRef strides, TFPadding padding,
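The wrapper above hands both results back by value through `OpaqueXLATensor_pair`. Its definition is not part of this diff; judging from the Swift binding further down, which reads `output.x` and `output.y`, Clang presumably imports it along these lines (a sketch of the assumed field layout, not the actual header):

```swift
// Hypothetical Swift-side view of the C struct, inferred from the
// `output.x` / `output.y` accesses in XLATensor.swift below.
struct OpaqueXLATensor_pair {
  var x: OpaquePointer?  // first XLATensor handle (top-k values)
  var y: OpaquePointer?  // second XLATensor handle (top-k indices)
}
```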
2 changes: 2 additions & 0 deletions Sources/CX10/xla_tensor_wrapper.h
@@ -376,6 +376,8 @@ XLA_API OpaqueXLATensor* XLATensor_sum(OpaqueXLATensor* a, Int64ArrayRef dims,
                                       Optional_XLAScalarType dtype);
XLA_API OpaqueXLATensor* XLATensor_tan(OpaqueXLATensor* a);
XLA_API OpaqueXLATensor* XLATensor_tanh(OpaqueXLATensor* a);
XLA_API OpaqueXLATensor_pair XLATensor_topk(OpaqueXLATensor* a, int64_t k,
                                            int64_t dim, bool largest);
XLA_API OpaqueXLATensor*
XLATensor_tf_Conv(OpaqueXLATensor* input, OpaqueXLATensor* filter, bool depthwise,
                  Int64ArrayRef strides, enum TFPadding padding,
6 changes: 6 additions & 0 deletions Sources/x10/swift_bindings/XLATensor.swift
@@ -887,6 +887,12 @@ extension XLATensor {
    return XLATensor(_handle: XLATensor_tanh(a.handle))
  }

  static func topk(_ a: XLATensor, k: Int64, dim: Int64, largest: Bool) -> (XLATensor, XLATensor) {
    // Keep `a` alive for the duration of the C call so its handle stays valid.
    defer { _fixLifetime(a) }
    let output = XLATensor_topk(a.handle, k, dim, largest)
    return (XLATensor(_handle: output.x), XLATensor(_handle: output.y))
  }

  static func tf_Conv(
    _ input: XLATensor, _ filter: XLATensor, _ depthwise: Bool, _ strides: [Int64],
    _ padding: TFPadding, _ explicit_paddings: [Int64],
24 changes: 16 additions & 8 deletions Sources/x10/swift_bindings/apis/RawOpsManual.swift
@@ -1425,7 +1425,8 @@ public enum _RawXLA {
    _ startIndices: [Tensor<Int32>],
    _ sliceShape: [Int64]
  ) -> Tensor<T> {
    return Tensor(
      _xla: XLATensor.dynamic_slice(base.xlaTensor, startIndices.map { $0.xlaTensor }, sliceShape))
  }

  public static func dynamicUpdateSlice<T: TensorFlowNumeric>(
@@ -3042,17 +3043,17 @@ public enum _RawXLA {
  ///
  /// NOTE `tf.reverse` has now changed behavior in preparation for 1.0.
  /// `tf.reverse_v2` is currently an alias that will be deprecated before TF 1.0.
  ///
  /// Given a `tensor` and an `int32` tensor `axis` representing the set of
  /// dimensions of `tensor` to reverse, this operation reverses each dimension
  /// `i` for which there exists `j` s.t. `axis[j] == i`.
  ///
  /// `tensor` can have up to 8 dimensions. The number of dimensions specified
  /// in `axis` may be 0 or more. If an index is specified more than
  /// once, an InvalidArgument error is raised.
  ///
  /// For example:
  ///
  /// ```
  /// # tensor 't' is [[[[ 0,  1,  2,  3],
  /// #                  [ 4,  5,  6,  7],
@@ -3061,23 +3062,23 @@
  /// #                  [16, 17, 18, 19],
  /// #                  [20, 21, 22, 23]]]]
  /// # tensor 't' shape is [1, 2, 3, 4]
  ///
  /// # 'dims' is [3] or 'dims' is [-1]
  /// reverse(t, dims) ==> [[[[ 3,  2,  1,  0],
  ///                         [ 7,  6,  5,  4],
  ///                         [11, 10,  9,  8]],
  ///                        [[15, 14, 13, 12],
  ///                         [19, 18, 17, 16],
  ///                         [23, 22, 21, 20]]]]
  ///
  /// # 'dims' is '[1]' (or 'dims' is '[-3]')
  /// reverse(t, dims) ==> [[[[12, 13, 14, 15],
  ///                         [16, 17, 18, 19],
  ///                         [20, 21, 22, 23]],
  ///                        [[ 0,  1,  2,  3],
  ///                         [ 4,  5,  6,  7],
  ///                         [ 8,  9, 10, 11]]]]
  ///
  /// # 'dims' is '[2]' (or 'dims' is '[-2]')
  /// reverse(t, dims) ==> [[[[ 8,  9, 10, 11],
  ///                         [ 4,  5,  6,  7],
@@ -4259,6 +4260,13 @@ public enum _RawXLA {
    return Tensor(_xla: XLATensor.tanh(x.xlaTensor))
  }

  /// Returns the `k` largest (or, with `largest: false`, smallest) entries of
  /// `a` along dimension `dim`, together with their `Int64` indices.
  public static func topk<T: FloatingPoint & TensorFlowScalar>(
    _ a: Tensor<T>, k: Int64, dim: Int64, largest: Bool
  ) -> (Tensor<T>, Tensor<Int64>) {
    let (r0, r1) = XLATensor.topk(a.xlaTensor, k: k, dim: dim, largest: largest)
    return (Tensor(_xla: r0), Tensor(_xla: r1))
  }

  /// Assign `value` to the sliced l-value reference of `input`.
  ///
  /// The values of `value` are assigned to the positions in the tensor `input` that
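For a sense of how the new op is called, here is a hedged usage sketch at the `_RawXLA` level. It assumes an XLA device is available and that `Device.defaultXLA` and `Tensor(shape:scalars:on:)` are the usual way to place tensors on it in this tree; neither appears in this diff.

```swift
import TensorFlow

// Pick each row's 2 largest entries along dimension 1 of a 2x3 tensor.
let x = Tensor<Float>(shape: [2, 3], scalars: [1, 5, 3, 4, 2, 6],
                      on: Device.defaultXLA)
let (values, indices) = _RawXLA.topk(x, k: 2, dim: 1, largest: true)
// values  holds {5, 3} for row 0 and {6, 4} for row 1.
// indices holds their Int64 positions within each row: {1, 2} and {2, 0}.
// Note: the C++ wrapper passes `false` as XLATensor::topk's final flag,
// so ordering within each top-k group is not necessarily sorted.
```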