|
24 | 24 | #include "tensorflow/compiler/tf2xla/xla_tensor/data_ops.h"
|
25 | 25 | #include "tensorflow/compiler/tf2xla/xla_tensor/elementwise.h"
|
26 | 26 | #include "tensorflow/compiler/tf2xla/xla_tensor/helpers.h"
|
| 27 | +#include "tensorflow/compiler/tf2xla/xla_tensor/layout_manager.h" |
27 | 28 | #include "tensorflow/compiler/tf2xla/xla_tensor/lowering_context.h"
|
28 | 29 | #include "tensorflow/compiler/tf2xla/xla_tensor/ops/infer_output_shape.h"
|
29 | 30 | #include "tensorflow/compiler/tf2xla/xla_tensor/reduction.h"
|
| 31 | +#include "tensorflow/compiler/tf2xla/xla_tensor/segment_reduction_ops.h" |
30 | 32 | #include "tensorflow/compiler/tf2xla/xla_tensor/softmax_builder.h"
|
31 | 33 | #include "tensorflow/compiler/tf2xla/xla_tensor/tensor_util.h"
|
32 | 34 | #include "tensorflow/compiler/tf2xla/xla_tensor/xla_lower_util.h"
|
| 35 | +#include "tensorflow/compiler/tf2xla/lib/random.h" |
33 | 36 | #include "tensorflow/compiler/xla/client/lib/constants.h"
|
34 | 37 | #include "tensorflow/compiler/xla/client/lib/math.h"
|
| 38 | +#include "tensorflow/compiler/xla/client/lib/prng.h" |
35 | 39 | #include "xla_tensor_wrapper.h"
|
36 | 40 |
|
37 | 41 | namespace at {
|
@@ -157,6 +161,18 @@ xla::XlaOp LowerMean(xla::XlaOp input,
|
157 | 161 | : result;
|
158 | 162 | }
|
159 | 163 |
|
| 164 | +xla::XlaOp LowerLogicalCast(xla::XlaOp input, at::ScalarType dtype) { |
| 165 | + const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input); |
| 166 | + return ConvertToRaw(input, input_shape.element_type(), |
| 167 | + MakeXlaPrimitiveType(dtype, /*device=*/nullptr), |
| 168 | + TensorTypeToRawXlaType(dtype), /*device=*/nullptr); |
| 169 | +} |
| 170 | +xla::Shape ShapeLogicalCast(const Value& input, at::ScalarType dtype) { |
| 171 | + xla::Shape result = input.shape(); |
| 172 | + result.set_element_type(MakeXlaPrimitiveType(dtype, /*device=*/nullptr)); |
| 173 | + return result; |
| 174 | +} |
| 175 | + |
160 | 176 | xla::XlaOp LowerSum(xla::XlaOp input, absl::Span<const xla::int64> dimensions,
|
161 | 177 | bool keep_reduced_dimensions,
|
162 | 178 | c10::optional<at::ScalarType> dtype) {
|
@@ -225,6 +241,90 @@ xla::Shape ShapeSlice(const Value& input, xla::int64 dim, xla::int64 start,
|
225 | 241 | return select_shape;
|
226 | 242 | }
|
227 | 243 |
|
| 244 | +xla::XlaOp LowerWhere(xla::XlaOp condition, xla::XlaOp input, |
| 245 | + xla::XlaOp other) { |
| 246 | + xla::XlaOp pred_condition = |
| 247 | + ConvertTo(condition, XlaHelpers::TypeOfXlaOp(condition), |
| 248 | + xla::PrimitiveType::PRED, /*device=*/nullptr); |
| 249 | + std::tie(input, other) = XlaHelpers::PromoteShapes(input, other); |
| 250 | + return xla::Select(pred_condition, input, other); |
| 251 | +} |
| 252 | + |
| 253 | +xla::XlaOp BuildOneHot(xla::XlaOp indices, xla::XlaOp on_value, |
| 254 | + xla::XlaOp off_value, xla::int64 depth, |
| 255 | + xla::int64 axis) { |
| 256 | + xla::XlaBuilder* builder = indices.builder(); |
| 257 | + xla::Shape indices_shape = XlaHelpers::ShapeOfXlaOp(indices); |
| 258 | + std::vector<xla::int64> broadcast_dims(indices_shape.dimensions().size()); |
| 259 | + if (axis < 0) axis = axis + broadcast_dims.size() + 1; |
| 260 | + std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); |
| 261 | + std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); |
| 262 | + |
| 263 | + std::vector<xla::int64> output_dimensions(indices_shape.dimensions().size() + |
| 264 | + 1); |
| 265 | + output_dimensions.assign(indices_shape.dimensions().begin(), |
| 266 | + indices_shape.dimensions().end()); |
| 267 | + output_dimensions.insert(output_dimensions.begin() + axis, depth); |
| 268 | + xla::Shape iota_shape = xla::ShapeUtil::MakeShape( |
| 269 | + indices_shape.element_type(), output_dimensions); |
| 270 | + |
| 271 | + return xla::Select( |
| 272 | + xla::Eq(indices, xla::Iota(builder, iota_shape, axis), broadcast_dims), |
| 273 | + xla::Broadcast(on_value, output_dimensions), |
| 274 | + xla::Broadcast(off_value, output_dimensions)); |
| 275 | +} |
| 276 | + |
| 277 | +xla::XlaOp LowerTfUnsortedSegmentSum(xla::XlaOp data, xla::XlaOp indices, |
| 278 | + xla::int64 num_segments) { |
| 279 | + const xla::Shape& data_shape = XlaHelpers::ShapeOfXlaOp(data); |
| 280 | + xla::XlaOp init_value = xla::Zero(data.builder(), data_shape.element_type()); |
| 281 | + auto combine = [](xla::XlaOp a, xla::XlaOp b) { return a + b; }; |
| 282 | + return UnsortedSegmentReduce(data, indices, init_value, num_segments, |
| 283 | + combine); |
| 284 | +} |
| 285 | + |
| 286 | +xla::XlaOp LowerTfStatelessRandomUniform(xla::Shape shape, xla::XlaOp seeds, |
| 287 | + xla::XlaOp minval, xla::XlaOp maxval, |
| 288 | + LoweringContext* loctx = nullptr) { |
| 289 | + xla::BitGeneratorTy generator; |
| 290 | + if (!loctx || loctx->device().hw_type == swift_xla::DeviceType::TPU) { |
| 291 | + generator = xla::ThreeFryBitGenerator; |
| 292 | + } else { |
| 293 | + generator = [](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) { |
| 294 | + std::tie(state, key) = xla::ScramblePhiloxKey(key); |
| 295 | + return xla::PhiloxBitGenerator(key, state, shape); |
| 296 | + }; |
| 297 | + } |
| 298 | + xla::XlaOp seed0 = xla::Reshape(xla::Slice(seeds, {0}, {1}, {1}), {}); |
| 299 | + xla::XlaOp seed1 = xla::Reshape(xla::Slice(seeds, {1}, {2}, {1}), {}); |
| 300 | + xla::XlaOp key = ConvertElementType(seed0, xla::U64) | |
| 301 | + ShiftLeft(ConvertElementType(seed1, xla::U64), |
| 302 | + ConstantR0WithType(seeds.builder(), xla::U64, 32)); |
| 303 | + xla::XlaOp initial_state = |
| 304 | + xla::ConstantR0WithType(seeds.builder(), xla::U64, 0); |
| 305 | + xla::PrimitiveType type = shape.element_type(); |
| 306 | + xla::XlaOp output; |
| 307 | + switch (type) { |
| 308 | + case xla::F32: |
| 309 | + case xla::F64: { |
| 310 | + return xla::UniformFloatingPointDistribution( |
| 311 | + key, initial_state, generator, minval, maxval, shape) |
| 312 | + .value; |
| 313 | + } |
| 314 | + case xla::S32: |
| 315 | + case xla::S64: { |
| 316 | + return xla::UniformIntDistribution(key, initial_state, generator, minval, |
| 317 | + maxval, shape) |
| 318 | + .value; |
| 319 | + } |
| 320 | + default: { |
| 321 | + XLA_ERROR() << "Types other than F32, S32 and S64 are not implemented by " |
| 322 | + "StatelessRngUniform; got " |
| 323 | + << xla::primitive_util::LowercasePrimitiveTypeName(type); |
| 324 | + } |
| 325 | + } |
| 326 | +} |
| 327 | + |
228 | 328 | } // namespace
|
229 | 329 | } // namespace ops
|
230 | 330 | } // namespace ir
|
|
0 commit comments