Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 3469a1b

Browse files
authored
Add eq,ge,gt,le,lt,ne,pow,relu,rem,squeeze,threshold specs. (#1072)
1 parent 5d6ebd7 commit 3469a1b

File tree

6 files changed

+587
-44
lines changed

6 files changed

+587
-44
lines changed

Sources/CX10/xla_tensor_ops_wrapper.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
#include "tensorflow/compiler/tf2xla/xla_tensor/softmax_builder.h"
2929
#include "tensorflow/compiler/tf2xla/xla_tensor/reduction.h"
3030
#include "tensorflow/compiler/tf2xla/xla_tensor/tensor_util.h"
31+
#include "tensorflow/compiler/tf2xla/xla_tensor/elementwise.h"
32+
#include "tensorflow/compiler/tf2xla/xla_tensor/data_ops.h"
3133
#include "tensorflow/compiler/tf2xla/xla_tensor/convert_ops.h"
3234
#include "tensorflow/compiler/xla/client/lib/constants.h"
3335

@@ -46,13 +48,29 @@ void OpFieldToString(std::ostream& stream, const char* field_name, bool value) {
4648
void OpFieldToString(std::ostream& stream, const char* field_name, xla::int64 value) {
4749
stream << ", " << field_name << "=" << value;
4850
}
51+
void OpFieldToString(std::ostream& stream, const char* field_name, float value) {
52+
stream << ", " << field_name << "=" << value;
53+
}
4954
} // namespace swift_xla
5055

5156
namespace swift_xla {
5257
namespace ir {
5358
namespace ops {
5459
namespace {
5560

61+
using BinaryOpBuilder = xla::XlaOp(*)(xla::XlaOp, xla::XlaOp, absl::Span<const xla::int64>);
62+
template <BinaryOpBuilder T>
63+
xla::XlaOp LowerBinaryOp(xla::XlaOp lhs, xla::XlaOp rhs) {
64+
std::tie(lhs, rhs) = XlaHelpers::Promote(lhs, rhs);
65+
return T(lhs, rhs, {});
66+
}
67+
68+
xla::XlaOp LowerSqueeze(xla::XlaOp input, int dim) {
69+
if (dim == -1) return SqueezeAllTrivialDimensions(input);
70+
XLA_CHECK_GE(dim, 0);
71+
return SqueezeTrivialDimension(input, dim);
72+
}
73+
5674
xla::XlaOp LowerCumSum(xla::XlaOp input, xla::int64 dim,
5775
c10::optional<at::ScalarType> dtype, bool exclusive,
5876
bool reverse) {

0 commit comments

Comments
 (0)