28
28
#include " tensorflow/compiler/tf2xla/xla_tensor/softmax_builder.h"
29
29
#include " tensorflow/compiler/tf2xla/xla_tensor/reduction.h"
30
30
#include " tensorflow/compiler/tf2xla/xla_tensor/tensor_util.h"
31
+ #include " tensorflow/compiler/tf2xla/xla_tensor/elementwise.h"
32
+ #include " tensorflow/compiler/tf2xla/xla_tensor/data_ops.h"
31
33
#include " tensorflow/compiler/tf2xla/xla_tensor/convert_ops.h"
32
34
#include " tensorflow/compiler/xla/client/lib/constants.h"
33
35
@@ -46,13 +48,29 @@ void OpFieldToString(std::ostream& stream, const char* field_name, bool value) {
46
48
void OpFieldToString (std::ostream& stream, const char * field_name, xla::int64 value) {
47
49
stream << " , " << field_name << " =" << value;
48
50
}
51
+ void OpFieldToString (std::ostream& stream, const char * field_name, float value) {
52
+ stream << " , " << field_name << " =" << value;
53
+ }
49
54
} // namespace swift_xla
50
55
51
56
namespace swift_xla {
52
57
namespace ir {
53
58
namespace ops {
54
59
namespace {
55
60
61
+ using BinaryOpBuilder = xla::XlaOp(*)(xla::XlaOp, xla::XlaOp, absl::Span<const xla::int64>);
62
+ template <BinaryOpBuilder T>
63
+ xla::XlaOp LowerBinaryOp (xla::XlaOp lhs, xla::XlaOp rhs) {
64
+ std::tie (lhs, rhs) = XlaHelpers::Promote (lhs, rhs);
65
+ return T (lhs, rhs, {});
66
+ }
67
+
68
+ xla::XlaOp LowerSqueeze (xla::XlaOp input, int dim) {
69
+ if (dim == -1 ) return SqueezeAllTrivialDimensions (input);
70
+ XLA_CHECK_GE (dim, 0 );
71
+ return SqueezeTrivialDimension (input, dim);
72
+ }
73
+
56
74
xla::XlaOp LowerCumSum (xla::XlaOp input, xla::int64 dim,
57
75
c10::optional<at::ScalarType> dtype, bool exclusive,
58
76
bool reverse) {
0 commit comments