
Commit c4e338f

Convert more xla ops to be generated from a spec. (#1076)

Added ops include: log_softmax, logical_cast, softmax, tf_OneHot, tf_StatelessRandomUniform, tf_UnsortedSegmentSum, truncated_normal, update_slice, and xla_slice.

1 parent a2fa8c4 commit c4e338f

File tree

7 files changed: +655 -116 lines changed

Sources/CX10/xla_tensor_ops_wrapper.cc

Lines changed: 100 additions & 0 deletions
@@ -24,14 +24,18 @@
 #include "tensorflow/compiler/tf2xla/xla_tensor/data_ops.h"
 #include "tensorflow/compiler/tf2xla/xla_tensor/elementwise.h"
 #include "tensorflow/compiler/tf2xla/xla_tensor/helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_tensor/layout_manager.h"
 #include "tensorflow/compiler/tf2xla/xla_tensor/lowering_context.h"
 #include "tensorflow/compiler/tf2xla/xla_tensor/ops/infer_output_shape.h"
 #include "tensorflow/compiler/tf2xla/xla_tensor/reduction.h"
+#include "tensorflow/compiler/tf2xla/xla_tensor/segment_reduction_ops.h"
 #include "tensorflow/compiler/tf2xla/xla_tensor/softmax_builder.h"
 #include "tensorflow/compiler/tf2xla/xla_tensor/tensor_util.h"
 #include "tensorflow/compiler/tf2xla/xla_tensor/xla_lower_util.h"
+#include "tensorflow/compiler/tf2xla/lib/random.h"
 #include "tensorflow/compiler/xla/client/lib/constants.h"
 #include "tensorflow/compiler/xla/client/lib/math.h"
+#include "tensorflow/compiler/xla/client/lib/prng.h"
 #include "xla_tensor_wrapper.h"

 namespace at {
@@ -157,6 +161,18 @@ xla::XlaOp LowerMean(xla::XlaOp input,
                    : result;
 }

+xla::XlaOp LowerLogicalCast(xla::XlaOp input, at::ScalarType dtype) {
+  const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
+  return ConvertToRaw(input, input_shape.element_type(),
+                      MakeXlaPrimitiveType(dtype, /*device=*/nullptr),
+                      TensorTypeToRawXlaType(dtype), /*device=*/nullptr);
+}
+xla::Shape ShapeLogicalCast(const Value& input, at::ScalarType dtype) {
+  xla::Shape result = input.shape();
+  result.set_element_type(MakeXlaPrimitiveType(dtype, /*device=*/nullptr));
+  return result;
+}
+
 xla::XlaOp LowerSum(xla::XlaOp input, absl::Span<const xla::int64> dimensions,
                     bool keep_reduced_dimensions,
                     c10::optional<at::ScalarType> dtype) {
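The logical_cast op turns on the distinction between a tensor's logical element type and the raw type that backs it on the device: ShapeLogicalCast merely relabels the shape's element type, while ConvertToRaw reconciles the raw representation. As a rough standalone analogy only (plain C++, not X10 code), compare a value-converting cast with a bit-preserving relabeling of the same storage:

    // Hypothetical illustration, not X10 API: a value-converting cast
    // versus a bit-preserving reinterpretation of the same bytes.
    #include <cstdint>
    #include <cstring>
    #include <iostream>

    int main() {
      float f = 1.0f;

      // Value conversion: the float 1.0f becomes the integer 1.
      std::uint32_t converted = static_cast<std::uint32_t>(f);

      // Bit-preserving relabeling: the raw bytes of 1.0f (0x3f800000)
      // are kept and only the type "label" changes.
      std::uint32_t reinterpreted;
      std::memcpy(&reinterpreted, &f, sizeof f);

      std::cout << converted << " vs 0x" << std::hex << reinterpreted << "\n";
      // Prints: 1 vs 0x3f800000
    }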
@@ -225,6 +241,90 @@ xla::Shape ShapeSlice(const Value& input, xla::int64 dim, xla::int64 start,
   return select_shape;
 }

+xla::XlaOp LowerWhere(xla::XlaOp condition, xla::XlaOp input,
+                      xla::XlaOp other) {
+  xla::XlaOp pred_condition =
+      ConvertTo(condition, XlaHelpers::TypeOfXlaOp(condition),
+                xla::PrimitiveType::PRED, /*device=*/nullptr);
+  std::tie(input, other) = XlaHelpers::PromoteShapes(input, other);
+  return xla::Select(pred_condition, input, other);
+}
+
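LowerWhere converts the condition to a PRED operand and then selects elementwise between the two shape-promoted inputs. A minimal scalar model of that behavior, written as a hypothetical standalone helper rather than X10 code:

    // Scalar model of the `where` lowering above: any nonzero condition
    // value acts as true (mirroring the PRED conversion), selecting
    // elementwise between input and other (mirroring xla::Select).
    #include <cstddef>
    #include <iostream>
    #include <vector>

    std::vector<float> Where(const std::vector<int>& condition,
                             const std::vector<float>& input,
                             const std::vector<float>& other) {
      std::vector<float> result(input.size());
      for (std::size_t i = 0; i < input.size(); ++i) {
        result[i] = condition[i] != 0 ? input[i] : other[i];
      }
      return result;
    }

    int main() {
      for (float v : Where({1, 0, 1}, {10, 20, 30}, {-1, -2, -3}))
        std::cout << v << " ";  // Prints: 10 -2 30
    }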
+xla::XlaOp BuildOneHot(xla::XlaOp indices, xla::XlaOp on_value,
+                       xla::XlaOp off_value, xla::int64 depth,
+                       xla::int64 axis) {
+  xla::XlaBuilder* builder = indices.builder();
+  xla::Shape indices_shape = XlaHelpers::ShapeOfXlaOp(indices);
+  std::vector<xla::int64> broadcast_dims(indices_shape.dimensions().size());
+  if (axis < 0) axis = axis + broadcast_dims.size() + 1;
+  std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
+  std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
+
+  std::vector<xla::int64> output_dimensions(indices_shape.dimensions().size() +
+                                            1);
+  output_dimensions.assign(indices_shape.dimensions().begin(),
+                           indices_shape.dimensions().end());
+  output_dimensions.insert(output_dimensions.begin() + axis, depth);
+  xla::Shape iota_shape = xla::ShapeUtil::MakeShape(
+      indices_shape.element_type(), output_dimensions);
+
+  return xla::Select(
+      xla::Eq(indices, xla::Iota(builder, iota_shape, axis), broadcast_dims),
+      xla::Broadcast(on_value, output_dimensions),
+      xla::Broadcast(off_value, output_dimensions));
+}
+
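BuildOneHot materializes the one-hot tensor in a single device-side expression: an Iota along the new axis is compared against the broadcast indices, and Select picks on_value where they match and off_value elsewhere. A scalar model of the result it computes, again as a hypothetical helper:

    // Scalar model of BuildOneHot: each output row holds on_value at the
    // position named by its index and off_value everywhere else, which is
    // what the Iota/Eq/Select combination computes in one shot on device.
    #include <iostream>
    #include <vector>

    std::vector<std::vector<float>> OneHot(const std::vector<int>& indices,
                                           int depth, float on_value,
                                           float off_value) {
      std::vector<std::vector<float>> result;
      for (int index : indices) {
        std::vector<float> row(depth, off_value);
        if (index >= 0 && index < depth) row[index] = on_value;
        result.push_back(row);
      }
      return result;
    }

    int main() {
      for (const auto& row : OneHot({0, 2, 1}, 3, 1.0f, 0.0f)) {
        for (float v : row) std::cout << v << " ";
        std::cout << "\n";
      }
      // Prints:
      // 1 0 0
      // 0 0 1
      // 0 1 0
    }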
+xla::XlaOp LowerTfUnsortedSegmentSum(xla::XlaOp data, xla::XlaOp indices,
+                                     xla::int64 num_segments) {
+  const xla::Shape& data_shape = XlaHelpers::ShapeOfXlaOp(data);
+  xla::XlaOp init_value = xla::Zero(data.builder(), data_shape.element_type());
+  auto combine = [](xla::XlaOp a, xla::XlaOp b) { return a + b; };
+  return UnsortedSegmentReduce(data, indices, init_value, num_segments,
+                               combine);
+}
+
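LowerTfUnsortedSegmentSum delegates to UnsortedSegmentReduce with a zero init_value and addition as the combiner. Sketched as a hypothetical scalar helper, the reduction amounts to:

    // Scalar model of the unsorted segment sum lowered above: every data
    // element is accumulated into the output slot named by its segment
    // index, starting from the zero init_value; the += here plays the
    // role of the `combine` lambda.
    #include <cstddef>
    #include <iostream>
    #include <vector>

    std::vector<float> UnsortedSegmentSum(const std::vector<float>& data,
                                          const std::vector<int>& segment_ids,
                                          int num_segments) {
      std::vector<float> result(num_segments, 0.0f);  // init_value = 0
      for (std::size_t i = 0; i < data.size(); ++i) {
        result[segment_ids[i]] += data[i];  // combine(a, b) = a + b
      }
      return result;
    }

    int main() {
      // Elements 0 and 2 fall in segment 0; elements 1 and 3 in segment 1.
      for (float v : UnsortedSegmentSum({1, 2, 3, 4}, {0, 1, 0, 1}, 2))
        std::cout << v << " ";  // Prints: 4 6
    }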
+xla::XlaOp LowerTfStatelessRandomUniform(xla::Shape shape, xla::XlaOp seeds,
+                                         xla::XlaOp minval, xla::XlaOp maxval,
+                                         LoweringContext* loctx = nullptr) {
+  xla::BitGeneratorTy generator;
+  if (!loctx || loctx->device().hw_type == swift_xla::DeviceType::TPU) {
+    generator = xla::ThreeFryBitGenerator;
+  } else {
+    generator = [](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) {
+      std::tie(state, key) = xla::ScramblePhiloxKey(key);
+      return xla::PhiloxBitGenerator(key, state, shape);
+    };
+  }
+  xla::XlaOp seed0 = xla::Reshape(xla::Slice(seeds, {0}, {1}, {1}), {});
+  xla::XlaOp seed1 = xla::Reshape(xla::Slice(seeds, {1}, {2}, {1}), {});
+  xla::XlaOp key = ConvertElementType(seed0, xla::U64) |
+                   ShiftLeft(ConvertElementType(seed1, xla::U64),
+                             ConstantR0WithType(seeds.builder(), xla::U64, 32));
+  xla::XlaOp initial_state =
+      xla::ConstantR0WithType(seeds.builder(), xla::U64, 0);
+  xla::PrimitiveType type = shape.element_type();
+  xla::XlaOp output;
+  switch (type) {
+    case xla::F32:
+    case xla::F64: {
+      return xla::UniformFloatingPointDistribution(
+                 key, initial_state, generator, minval, maxval, shape)
+          .value;
+    }
+    case xla::S32:
+    case xla::S64: {
+      return xla::UniformIntDistribution(key, initial_state, generator, minval,
+                                         maxval, shape)
+          .value;
+    }
+    default: {
+      XLA_ERROR() << "Types other than F32, S32 and S64 are not implemented by "
+                     "StatelessRngUniform; got "
+                  << xla::primitive_util::LowercasePrimitiveTypeName(type);
+    }
+  }
+}
+
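Two details in this lowering are worth noting: the bit generator is ThreeFry on TPU (or when no lowering context is available) and a scrambled Philox generator elsewhere, and the two 32-bit seeds are packed into a single 64-bit PRNG key with seed0 in the low bits and seed1 shifted into the high 32 bits. The key packing, rendered in plain C++:

    // Plain-C++ rendering of the seed-to-key packing above (hypothetical
    // helper, not X10 code): seed0 occupies the low 32 bits, seed1 the
    // high 32 bits of the 64-bit key.
    #include <cstdint>
    #include <iostream>

    std::uint64_t MakeKey(std::uint32_t seed0, std::uint32_t seed1) {
      return static_cast<std::uint64_t>(seed0) |
             (static_cast<std::uint64_t>(seed1) << 32);
    }

    int main() {
      std::cout << std::hex << MakeKey(0xdeadbeef, 0xcafef00d) << "\n";
      // Prints: cafef00ddeadbeef
    }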
 }  // namespace
 }  // namespace ops
 }  // namespace ir
