18
18
#define XLA_API __attribute__ ((__visibility__(" default" )))
19
19
#endif
20
20
21
- #include " xla_tensor_wrapper.h"
22
-
23
21
#include " tensorflow/compiler/xla/xla_client/debug_macros.h"
24
22
#include " tensorflow/compiler/xla/xla_client/util.h"
23
+ #include " tensorflow/compiler/tf2xla/xla_tensor/convert_ops.h"
24
+ #include " tensorflow/compiler/tf2xla/xla_tensor/data_ops.h"
25
+ #include " tensorflow/compiler/tf2xla/xla_tensor/elementwise.h"
25
26
#include " tensorflow/compiler/tf2xla/xla_tensor/helpers.h"
26
27
#include " tensorflow/compiler/tf2xla/xla_tensor/lowering_context.h"
27
28
#include " tensorflow/compiler/tf2xla/xla_tensor/ops/infer_output_shape.h"
28
- #include " tensorflow/compiler/tf2xla/xla_tensor/softmax_builder.h"
29
29
#include " tensorflow/compiler/tf2xla/xla_tensor/reduction.h"
30
+ #include " tensorflow/compiler/tf2xla/xla_tensor/softmax_builder.h"
30
31
#include " tensorflow/compiler/tf2xla/xla_tensor/tensor_util.h"
31
- #include " tensorflow/compiler/tf2xla/xla_tensor/elementwise.h"
32
- #include " tensorflow/compiler/tf2xla/xla_tensor/data_ops.h"
33
- #include " tensorflow/compiler/tf2xla/xla_tensor/convert_ops.h"
32
+ #include " tensorflow/compiler/tf2xla/xla_tensor/xla_lower_util.h"
34
33
#include " tensorflow/compiler/xla/client/lib/constants.h"
34
+ #include " tensorflow/compiler/xla/client/lib/math.h"
35
+ #include " xla_tensor_wrapper.h"
35
36
36
37
namespace at {
37
38
// Hashes an optional dtype; an absent value hashes as the sentinel -1.
xla::hash_t Hash(const c10::optional<at::ScalarType>& dtype) {
  return xla::util::Hash(swift_xla::OptionalOr<int>(dtype, -1));
}
41
+ xla::hash_t Hash (const at::Scalar& value) {
42
+ return value.isFloatingPoint () ? xla::util::Hash (value.toDouble ())
43
+ : xla::util::Hash (value.toLong ());
44
+ }
40
45
}
41
46
namespace swift_xla {
42
47
void OpFieldToString (std::ostream& stream, const char * field_name, const c10::optional<at::ScalarType>& dtype) {
@@ -51,20 +56,45 @@ void OpFieldToString(std::ostream& stream, const char* field_name, xla::int64 va
51
56
// Appends ", <field_name>=<value>" to the stream for a float field.
void OpFieldToString(std::ostream& stream, const char* field_name,
                     float value) {
  stream << ", " << field_name << "=" << value;
}
59
+ void OpFieldToString (std::ostream& stream, const char * field_name,
60
+ const std::vector<xla::int64>& value) {
61
+ stream << " , " << field_name << " =[" ;
62
+ for (size_t i = 0 ; i < value.size (); ++i) {
63
+ if (i != 0 ) stream << " , " ;
64
+ stream << value[i];
65
+ }
66
+ stream << " ]" ;
67
+ }
68
+ void OpFieldToString (std::ostream& stream, const char * field_name,
69
+ const at::Scalar& value) {
70
+ stream << " , " << field_name << " =" ;
71
+ if (value.isFloatingPoint ())
72
+ stream << value.toDouble ();
73
+ else
74
+ stream << value.toLong ();
75
+ }
54
76
} // namespace swift_xla
55
77
56
78
namespace swift_xla {
57
79
namespace ir {
58
80
namespace ops {
59
81
namespace {
60
82
61
- using BinaryOpBuilder = xla::XlaOp(*)(xla::XlaOp, xla::XlaOp, absl::Span<const xla::int64>);
62
- template <BinaryOpBuilder T>
83
+ using BinaryOpBuilderWithDim = xla::XlaOp (*)(xla::XlaOp, xla::XlaOp,
84
+ absl::Span<const xla::int64>);
85
+ template <BinaryOpBuilderWithDim T>
63
86
xla::XlaOp LowerBinaryOp (xla::XlaOp lhs, xla::XlaOp rhs) {
64
87
std::tie (lhs, rhs) = XlaHelpers::Promote (lhs, rhs);
65
88
return T (lhs, rhs, {});
66
89
}
67
90
91
+ using BinaryOpBuilder = xla::XlaOp (*)(xla::XlaOp, xla::XlaOp);
92
+ template <BinaryOpBuilder T>
93
+ xla::XlaOp LowerBinaryValueOp (xla::XlaOp lhs, xla::XlaOp rhs) {
94
+ std::tie (lhs, rhs) = XlaHelpers::PromoteValues (lhs, rhs);
95
+ return T (lhs, rhs);
96
+ }
97
+
68
98
xla::XlaOp LowerSqueeze (xla::XlaOp input, int dim) {
69
99
if (dim == -1 ) return SqueezeAllTrivialDimensions (input);
70
100
XLA_CHECK_GE (dim, 0 );
@@ -107,6 +137,94 @@ xla::Shape CumOpShapeFn(const Value& input, xla::int64 dim,
107
137
return input.shape ();
108
138
}
109
139
140
+ xla::XlaOp LowerClamp (xla::XlaOp xla_input, xla::XlaOp xla_min,
141
+ xla::XlaOp xla_max) {
142
+ xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp (xla_input);
143
+ xla_min = ConvertTo (xla_min, XlaHelpers::TypeOfXlaOp (xla_min), input_type,
144
+ /* device=*/ nullptr );
145
+ xla_max = ConvertTo (xla_max, XlaHelpers::TypeOfXlaOp (xla_max), input_type,
146
+ /* device=*/ nullptr );
147
+ return xla::Clamp (xla_min, xla_input, xla_max);
148
+ }
149
+
150
+ xla::XlaOp LowerMean (xla::XlaOp input,
151
+ const std::vector<xla::int64>& dimensions,
152
+ bool keep_reduced_dimensions,
153
+ const c10::optional<at::ScalarType>& dtype) {
154
+ xla::XlaOp result = BuildMean (input, dimensions, keep_reduced_dimensions);
155
+ return dtype ? xla::ConvertElementType (
156
+ result, MakeXlaPrimitiveType (*dtype, /* device=*/ nullptr ))
157
+ : result;
158
+ }
159
+
160
+ xla::XlaOp LowerSum (xla::XlaOp input, absl::Span<const xla::int64> dimensions,
161
+ bool keep_reduced_dimensions,
162
+ c10::optional<at::ScalarType> dtype) {
163
+ return BuildSum (CastToScalarType (input, dtype), dimensions,
164
+ keep_reduced_dimensions);
165
+ }
166
+
167
+ std::vector<xla::int64> CanonicalizeFlip (xla::Shape shape,
168
+ absl::Span<const xla::int64> dims) {
169
+ auto dimensions =
170
+ XlaHelpers::GetCanonicalDimensionIndices (dims, shape.rank ());
171
+ std::set<xla::int64> unique_dims (dimensions.begin (), dimensions.end ());
172
+ XLA_CHECK_EQ (unique_dims.size (), dimensions.size ());
173
+ return dimensions;
174
+ }
175
+
176
+ std::vector<xla::int64> CanonicalizeExpand (xla::Shape shape,
177
+ absl::Span<const xla::int64> dims) {
178
+ std::vector<xla::int64> dimensions (dims.begin (), dims.end ());
179
+ XLA_CHECK_GE (dimensions.size (), shape.rank ()) << shape;
180
+ xla::int64 base = dimensions.size () - shape.rank ();
181
+ for (size_t i = 0 ; i < shape.rank (); ++i) {
182
+ if (dimensions[base + i] == -1 ) {
183
+ dimensions[base + i] = shape.dimensions (i);
184
+ }
185
+ }
186
+ return dimensions;
187
+ }
188
+
189
+ xla::XlaOp LowerPad (xla::XlaOp input, absl::Span<const xla::int64> pad,
190
+ const at::Scalar& value) {
191
+ const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp (input);
192
+ return xla::Pad (input,
193
+ XlaHelpers::ScalarValue (value, input_shape.element_type (),
194
+ input.builder ()),
195
+ XlaHelpers::MakeXlaPaddingConfigFromNdPadding (pad));
196
+ }
197
+
198
+ std::vector<xla::int64> CanonicalizePad (xla::Shape shape,
199
+ absl::Span<const xla::int64> pad) {
200
+ std::vector<xla::int64> complete_pad (pad.begin (), pad.end ());
201
+ complete_pad.resize (2 * shape.rank ());
202
+ return complete_pad;
203
+ }
204
+
205
+ xla::int64 SliceGetStride (xla::int64 start, xla::int64 end, xla::int64 stride) {
206
+ if (stride == 0 ) {
207
+ XLA_CHECK_EQ (start, end);
208
+ stride = 1 ;
209
+ }
210
+ return stride;
211
+ }
212
+
213
+ xla::XlaOp LowerSlice (xla::XlaOp input, xla::int64 dim, xla::int64 start,
214
+ xla::int64 end, xla::int64 stride) {
215
+ return xla::SliceInDim (input, start, end, SliceGetStride (start, end, stride),
216
+ dim);
217
+ }
218
+
219
+ xla::Shape ShapeSlice (const Value& input, xla::int64 dim, xla::int64 start,
220
+ xla::int64 end, xla::int64 stride) {
221
+ xla::int64 effective_stride = SliceGetStride (start, end, stride);
222
+ xla::Shape select_shape (input.shape ());
223
+ select_shape.set_dimensions (
224
+ dim, (end - start + effective_stride - 1 ) / effective_stride);
225
+ return select_shape;
226
+ }
227
+
110
228
} // namespace
111
229
} // namespace ops
112
230
} // namespace ir
0 commit comments