This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Convert more xla ops to be generated from a spec. #1076

Merged: 1 commit, Sep 10, 2020
100 changes: 100 additions & 0 deletions Sources/CX10/xla_tensor_ops_wrapper.cc
@@ -24,14 +24,18 @@
#include "tensorflow/compiler/tf2xla/xla_tensor/data_ops.h"
#include "tensorflow/compiler/tf2xla/xla_tensor/elementwise.h"
#include "tensorflow/compiler/tf2xla/xla_tensor/helpers.h"
#include "tensorflow/compiler/tf2xla/xla_tensor/layout_manager.h"
#include "tensorflow/compiler/tf2xla/xla_tensor/lowering_context.h"
#include "tensorflow/compiler/tf2xla/xla_tensor/ops/infer_output_shape.h"
#include "tensorflow/compiler/tf2xla/xla_tensor/reduction.h"
#include "tensorflow/compiler/tf2xla/xla_tensor/segment_reduction_ops.h"
#include "tensorflow/compiler/tf2xla/xla_tensor/softmax_builder.h"
#include "tensorflow/compiler/tf2xla/xla_tensor/tensor_util.h"
#include "tensorflow/compiler/tf2xla/xla_tensor/xla_lower_util.h"
#include "tensorflow/compiler/tf2xla/lib/random.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/prng.h"
#include "xla_tensor_wrapper.h"

namespace at {
@@ -157,6 +161,18 @@ xla::XlaOp LowerMean(xla::XlaOp input,
: result;
}

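// Lowers a logical cast: reinterprets |input| as |dtype| by converting
// between the corresponding raw XLA primitive types.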
xla::XlaOp LowerLogicalCast(xla::XlaOp input, at::ScalarType dtype) {
const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
return ConvertToRaw(input, input_shape.element_type(),
MakeXlaPrimitiveType(dtype, /*device=*/nullptr),
TensorTypeToRawXlaType(dtype), /*device=*/nullptr);
}

xla::Shape ShapeLogicalCast(const Value& input, at::ScalarType dtype) {
xla::Shape result = input.shape();
result.set_element_type(MakeXlaPrimitiveType(dtype, /*device=*/nullptr));
return result;
}

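// Lowers a sum reduction over |dimensions|, optionally keeping the reduced
// dimensions, with an optional result |dtype|.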
xla::XlaOp LowerSum(xla::XlaOp input, absl::Span<const xla::int64> dimensions,
bool keep_reduced_dimensions,
c10::optional<at::ScalarType> dtype) {
@@ -225,6 +241,90 @@ xla::Shape ShapeSlice(const Value& input, xla::int64 dim, xla::int64 start,
return select_shape;
}

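// Lowers a where/select: casts |condition| to PRED, promotes |input| and
// |other| to a common shape, and selects elementwise.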
xla::XlaOp LowerWhere(xla::XlaOp condition, xla::XlaOp input,
xla::XlaOp other) {
xla::XlaOp pred_condition =
ConvertTo(condition, XlaHelpers::TypeOfXlaOp(condition),
xla::PrimitiveType::PRED, /*device=*/nullptr);
std::tie(input, other) = XlaHelpers::PromoteShapes(input, other);
return xla::Select(pred_condition, input, other);
}

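// Builds a one-hot tensor: inserts a size-|depth| axis at |axis| and fills
// it with |on_value| where the iota along that axis matches |indices|, and
// with |off_value| elsewhere.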
xla::XlaOp BuildOneHot(xla::XlaOp indices, xla::XlaOp on_value,
xla::XlaOp off_value, xla::int64 depth,
xla::int64 axis) {
xla::XlaBuilder* builder = indices.builder();
xla::Shape indices_shape = XlaHelpers::ShapeOfXlaOp(indices);
std::vector<xla::int64> broadcast_dims(indices_shape.dimensions().size());
if (axis < 0) axis = axis + broadcast_dims.size() + 1;
std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);

std::vector<xla::int64> output_dimensions(indices_shape.dimensions().begin(),
                                          indices_shape.dimensions().end());
output_dimensions.insert(output_dimensions.begin() + axis, depth);
xla::Shape iota_shape = xla::ShapeUtil::MakeShape(
indices_shape.element_type(), output_dimensions);

return xla::Select(
xla::Eq(indices, xla::Iota(builder, iota_shape, axis), broadcast_dims),
xla::Broadcast(on_value, output_dimensions),
xla::Broadcast(off_value, output_dimensions));
}

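// Lowers tf.unsorted_segment_sum as an UnsortedSegmentReduce with a zero
// initial value and addition as the combiner.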
xla::XlaOp LowerTfUnsortedSegmentSum(xla::XlaOp data, xla::XlaOp indices,
xla::int64 num_segments) {
const xla::Shape& data_shape = XlaHelpers::ShapeOfXlaOp(data);
xla::XlaOp init_value = xla::Zero(data.builder(), data_shape.element_type());
auto combine = [](xla::XlaOp a, xla::XlaOp b) { return a + b; };
return UnsortedSegmentReduce(data, indices, init_value, num_segments,
combine);
}

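// Lowers tf.stateless_random_uniform for floating-point and integral
// element types.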
xla::XlaOp LowerTfStatelessRandomUniform(xla::Shape shape, xla::XlaOp seeds,
xla::XlaOp minval, xla::XlaOp maxval,
LoweringContext* loctx = nullptr) {
xla::BitGeneratorTy generator;
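// Use ThreeFry on TPU (or when no device is known); otherwise use the
// Philox generator with a scrambled key.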
if (!loctx || loctx->device().hw_type == swift_xla::DeviceType::TPU) {
generator = xla::ThreeFryBitGenerator;
} else {
generator = [](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) {
std::tie(state, key) = xla::ScramblePhiloxKey(key);
return xla::PhiloxBitGenerator(key, state, shape);
};
}
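// Pack the two 32-bit seeds into a single 64-bit key: seed1 in the high
// bits, seed0 in the low bits.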
xla::XlaOp seed0 = xla::Reshape(xla::Slice(seeds, {0}, {1}, {1}), {});
xla::XlaOp seed1 = xla::Reshape(xla::Slice(seeds, {1}, {2}, {1}), {});
xla::XlaOp key = ConvertElementType(seed0, xla::U64) |
ShiftLeft(ConvertElementType(seed1, xla::U64),
ConstantR0WithType(seeds.builder(), xla::U64, 32));
xla::XlaOp initial_state =
xla::ConstantR0WithType(seeds.builder(), xla::U64, 0);
xla::PrimitiveType type = shape.element_type();
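// Floating-point and integral element types use different XLA uniform
// distribution builders.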
switch (type) {
case xla::F32:
case xla::F64: {
return xla::UniformFloatingPointDistribution(
key, initial_state, generator, minval, maxval, shape)
.value;
}
case xla::S32:
case xla::S64: {
return xla::UniformIntDistribution(key, initial_state, generator, minval,
maxval, shape)
.value;
}
default: {
XLA_ERROR() << "Types other than F32, F64, S32 and S64 are not implemented by "
"StatelessRngUniform; got "
<< xla::primitive_util::LowercasePrimitiveTypeName(type);
}
}
}

} // namespace
} // namespace ops
} // namespace ir