Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Add eq,ge,gt,le,lt,ne,pow,relu,rem,squeeze,threshold specs. #1072

Merged
merged 1 commit into from
Sep 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions Sources/CX10/xla_tensor_ops_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
#include "tensorflow/compiler/tf2xla/xla_tensor/softmax_builder.h"
#include "tensorflow/compiler/tf2xla/xla_tensor/reduction.h"
#include "tensorflow/compiler/tf2xla/xla_tensor/tensor_util.h"
#include "tensorflow/compiler/tf2xla/xla_tensor/elementwise.h"
#include "tensorflow/compiler/tf2xla/xla_tensor/data_ops.h"
#include "tensorflow/compiler/tf2xla/xla_tensor/convert_ops.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"

Expand All @@ -46,13 +48,29 @@ void OpFieldToString(std::ostream& stream, const char* field_name, bool value) {
void OpFieldToString(std::ostream& stream, const char* field_name, xla::int64 value) {
stream << ", " << field_name << "=" << value;
}
void OpFieldToString(std::ostream& stream, const char* field_name, float value) {
stream << ", " << field_name << "=" << value;
}
} // namespace swift_xla

namespace swift_xla {
namespace ir {
namespace ops {
namespace {

using BinaryOpBuilder = xla::XlaOp(*)(xla::XlaOp, xla::XlaOp, absl::Span<const xla::int64>);
template <BinaryOpBuilder T>
xla::XlaOp LowerBinaryOp(xla::XlaOp lhs, xla::XlaOp rhs) {
std::tie(lhs, rhs) = XlaHelpers::Promote(lhs, rhs);
return T(lhs, rhs, {});
}

xla::XlaOp LowerSqueeze(xla::XlaOp input, int dim) {
if (dim == -1) return SqueezeAllTrivialDimensions(input);
XLA_CHECK_GE(dim, 0);
return SqueezeTrivialDimension(input, dim);
}

xla::XlaOp LowerCumSum(xla::XlaOp input, xla::int64 dim,
c10::optional<at::ScalarType> dtype, bool exclusive,
bool reverse) {
Expand Down
Loading