Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 61e7724

Browse files
authored
Add specs for assorted ops. (#1073)
abs, acos, acosh, add, all, any, argmax, argmin, asin, asinh, atan, atanh, ceil, clamp, constant_pad_nd, cos, cosh, div, exp, expand, expm1, flip, floor, is_finite, is_inf, is_nan, log, log1p, logicalAnd, logicalNot, logicalOr, matmul, max, maximum, mean, min, minimum, mul, mm, neg, rsqrt, sigmoid, sign, sin, sinh, slice, sqrt, sub, sum, tan, and tanh
1 parent 3469a1b commit 61e7724

File tree

6 files changed

+2609
-471
lines changed

6 files changed

+2609
-471
lines changed

Sources/CX10/xla_tensor_ops_wrapper.cc

Lines changed: 126 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,25 +18,30 @@
1818
#define XLA_API __attribute__((__visibility__("default")))
1919
#endif
2020

21-
#include "xla_tensor_wrapper.h"
22-
2321
#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
2422
#include "tensorflow/compiler/xla/xla_client/util.h"
23+
#include "tensorflow/compiler/tf2xla/xla_tensor/convert_ops.h"
24+
#include "tensorflow/compiler/tf2xla/xla_tensor/data_ops.h"
25+
#include "tensorflow/compiler/tf2xla/xla_tensor/elementwise.h"
2526
#include "tensorflow/compiler/tf2xla/xla_tensor/helpers.h"
2627
#include "tensorflow/compiler/tf2xla/xla_tensor/lowering_context.h"
2728
#include "tensorflow/compiler/tf2xla/xla_tensor/ops/infer_output_shape.h"
28-
#include "tensorflow/compiler/tf2xla/xla_tensor/softmax_builder.h"
2929
#include "tensorflow/compiler/tf2xla/xla_tensor/reduction.h"
30+
#include "tensorflow/compiler/tf2xla/xla_tensor/softmax_builder.h"
3031
#include "tensorflow/compiler/tf2xla/xla_tensor/tensor_util.h"
31-
#include "tensorflow/compiler/tf2xla/xla_tensor/elementwise.h"
32-
#include "tensorflow/compiler/tf2xla/xla_tensor/data_ops.h"
33-
#include "tensorflow/compiler/tf2xla/xla_tensor/convert_ops.h"
32+
#include "tensorflow/compiler/tf2xla/xla_tensor/xla_lower_util.h"
3433
#include "tensorflow/compiler/xla/client/lib/constants.h"
34+
#include "tensorflow/compiler/xla/client/lib/math.h"
35+
#include "xla_tensor_wrapper.h"
3536

3637
namespace at {
3738
xla::hash_t Hash(const c10::optional<at::ScalarType>& dtype) {
3839
return xla::util::Hash(swift_xla::OptionalOr<int>(dtype, -1));
3940
}
41+
xla::hash_t Hash(const at::Scalar& value) {
42+
return value.isFloatingPoint() ? xla::util::Hash(value.toDouble())
43+
: xla::util::Hash(value.toLong());
44+
}
4045
}
4146
namespace swift_xla {
4247
void OpFieldToString(std::ostream& stream, const char* field_name, const c10::optional<at::ScalarType>& dtype) {
@@ -51,20 +56,45 @@ void OpFieldToString(std::ostream& stream, const char* field_name, xla::int64 va
5156
// Appends ", <field_name>=<value>" to an op's textual description.
void OpFieldToString(std::ostream& stream, const char* field_name, float value) {
  stream << ", " << field_name << "=";
  stream << value;
}
59+
void OpFieldToString(std::ostream& stream, const char* field_name,
60+
const std::vector<xla::int64>& value) {
61+
stream << ", " << field_name << "=[";
62+
for (size_t i = 0; i < value.size(); ++i) {
63+
if (i != 0) stream << ", ";
64+
stream << value[i];
65+
}
66+
stream << "]";
67+
}
68+
void OpFieldToString(std::ostream& stream, const char* field_name,
69+
const at::Scalar& value) {
70+
stream << ", " << field_name << "=";
71+
if (value.isFloatingPoint())
72+
stream << value.toDouble();
73+
else
74+
stream << value.toLong();
75+
}
5476
} // namespace swift_xla
5577

5678
namespace swift_xla {
5779
namespace ir {
5880
namespace ops {
5981
namespace {
6082

61-
using BinaryOpBuilder = xla::XlaOp(*)(xla::XlaOp, xla::XlaOp, absl::Span<const xla::int64>);
62-
template <BinaryOpBuilder T>
83+
using BinaryOpBuilderWithDim = xla::XlaOp (*)(xla::XlaOp, xla::XlaOp,
84+
absl::Span<const xla::int64>);
85+
template <BinaryOpBuilderWithDim T>
6386
xla::XlaOp LowerBinaryOp(xla::XlaOp lhs, xla::XlaOp rhs) {
6487
std::tie(lhs, rhs) = XlaHelpers::Promote(lhs, rhs);
6588
return T(lhs, rhs, {});
6689
}
6790

91+
using BinaryOpBuilder = xla::XlaOp (*)(xla::XlaOp, xla::XlaOp);
92+
template <BinaryOpBuilder T>
93+
xla::XlaOp LowerBinaryValueOp(xla::XlaOp lhs, xla::XlaOp rhs) {
94+
std::tie(lhs, rhs) = XlaHelpers::PromoteValues(lhs, rhs);
95+
return T(lhs, rhs);
96+
}
97+
6898
xla::XlaOp LowerSqueeze(xla::XlaOp input, int dim) {
6999
if (dim == -1) return SqueezeAllTrivialDimensions(input);
70100
XLA_CHECK_GE(dim, 0);
@@ -107,6 +137,94 @@ xla::Shape CumOpShapeFn(const Value& input, xla::int64 dim,
107137
return input.shape();
108138
}
109139

140+
// Clamps xla_input to [xla_min, xla_max]. Both bounds are first cast to the
// input's element type so xla::Clamp sees uniform operand types.
xla::XlaOp LowerClamp(xla::XlaOp xla_input, xla::XlaOp xla_min,
                      xla::XlaOp xla_max) {
  const xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input);
  // Shared conversion for both bounds; /*device=*/nullptr as in other lowerings.
  auto cast_to_input_type = [&](xla::XlaOp bound) {
    return ConvertTo(bound, XlaHelpers::TypeOfXlaOp(bound), input_type,
                     /*device=*/nullptr);
  };
  return xla::Clamp(cast_to_input_type(xla_min), xla_input,
                    cast_to_input_type(xla_max));
}
149+
150+
// Builds a mean reduction over `dimensions`; when a target dtype is given,
// the reduced result is converted to that element type.
xla::XlaOp LowerMean(xla::XlaOp input,
                     const std::vector<xla::int64>& dimensions,
                     bool keep_reduced_dimensions,
                     const c10::optional<at::ScalarType>& dtype) {
  xla::XlaOp reduced = BuildMean(input, dimensions, keep_reduced_dimensions);
  if (!dtype) {
    return reduced;
  }
  return xla::ConvertElementType(
      reduced, MakeXlaPrimitiveType(*dtype, /*device=*/nullptr));
}
159+
160+
// Builds a sum reduction. The input is cast to `dtype` (when present) before
// reducing, so the accumulation happens in the requested type.
xla::XlaOp LowerSum(xla::XlaOp input, absl::Span<const xla::int64> dimensions,
                    bool keep_reduced_dimensions,
                    c10::optional<at::ScalarType> dtype) {
  xla::XlaOp casted = CastToScalarType(input, dtype);
  return BuildSum(casted, dimensions, keep_reduced_dimensions);
}
166+
167+
// Maps possibly-negative flip dimensions onto [0, rank) and verifies no
// dimension appears twice.
std::vector<xla::int64> CanonicalizeFlip(xla::Shape shape,
                                         absl::Span<const xla::int64> dims) {
  auto canonical_dims =
      XlaHelpers::GetCanonicalDimensionIndices(dims, shape.rank());
  // A duplicate would make the flip ill-defined; the set collapses duplicates,
  // so its size must equal the original count.
  std::set<xla::int64> distinct(canonical_dims.begin(), canonical_dims.end());
  XLA_CHECK_EQ(distinct.size(), canonical_dims.size());
  return canonical_dims;
}
175+
176+
// Completes an expand() size list. `dims` may be longer than the input rank
// (extra leading dimensions) and may contain -1 entries, which mean "keep the
// corresponding input dimension's size". Returns the fully resolved sizes.
std::vector<xla::int64> CanonicalizeExpand(xla::Shape shape,
                                           absl::Span<const xla::int64> dims) {
  std::vector<xla::int64> dimensions(dims.begin(), dims.end());
  XLA_CHECK_GE(dimensions.size(), shape.rank()) << shape;
  // Offset of the input's first dimension within the (longer) target list.
  xla::int64 base = dimensions.size() - shape.rank();
  // Use a signed index to match shape.rank()'s signed type; the original
  // size_t loop compared signed against unsigned on every iteration.
  for (xla::int64 i = 0; i < shape.rank(); ++i) {
    if (dimensions[base + i] == -1) {
      // -1 is the placeholder for "existing size of this trailing dimension".
      dimensions[base + i] = shape.dimensions(i);
    }
  }
  return dimensions;
}
188+
189+
// Pads `input` with the scalar `value` according to the nd-padding spec,
// after materializing the scalar in the input's element type.
xla::XlaOp LowerPad(xla::XlaOp input, absl::Span<const xla::int64> pad,
                    const at::Scalar& value) {
  const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
  xla::XlaOp pad_value = XlaHelpers::ScalarValue(
      value, input_shape.element_type(), input.builder());
  return xla::Pad(input, pad_value,
                  XlaHelpers::MakeXlaPaddingConfigFromNdPadding(pad));
}
197+
198+
// Normalizes a padding spec to exactly 2 * rank entries (low/high pair per
// dimension): shorter specs are zero-extended, longer ones truncated.
std::vector<xla::int64> CanonicalizePad(xla::Shape shape,
                                        absl::Span<const xla::int64> pad) {
  std::vector<xla::int64> normalized(pad.begin(), pad.end());
  normalized.resize(2 * shape.rank(), 0);
  return normalized;
}
204+
205+
// Returns a usable slice stride: a zero stride is only legal for an empty
// slice (start == end) and is normalized to 1.
xla::int64 SliceGetStride(xla::int64 start, xla::int64 end, xla::int64 stride) {
  if (stride != 0) {
    return stride;
  }
  XLA_CHECK_EQ(start, end);
  return 1;
}
212+
213+
// Strided slice along a single dimension; a zero stride is normalized first.
xla::XlaOp LowerSlice(xla::XlaOp input, xla::int64 dim, xla::int64 start,
                      xla::int64 end, xla::int64 stride) {
  xla::int64 effective_stride = SliceGetStride(start, end, stride);
  return xla::SliceInDim(input, start, end, effective_stride, dim);
}
218+
219+
// Computes the result shape of a strided slice: the sliced dimension holds
// ceil((end - start) / stride) elements.
xla::Shape ShapeSlice(const Value& input, xla::int64 dim, xla::int64 start,
                      xla::int64 end, xla::int64 stride) {
  xla::int64 effective_stride = SliceGetStride(start, end, stride);
  xla::Shape result_shape(input.shape());
  // Integer ceiling division of the span by the (non-zero) stride.
  xla::int64 sliced_size =
      (end - start + effective_stride - 1) / effective_stride;
  result_shape.set_dimensions(dim, sliced_size);
  return result_shape;
}
227+
110228
} // namespace
111229
} // namespace ops
112230
} // namespace ir

0 commit comments

Comments
 (0)