|
38 | 38 | #include "tensorflow/compiler/xla/client/lib/constants.h"
|
39 | 39 | #include "tensorflow/compiler/xla/client/lib/math.h"
|
40 | 40 | #include "tensorflow/compiler/xla/client/lib/prng.h"
|
| 41 | +#include "tensorflow/compiler/xla/client/lib/qr.h" |
| 42 | +#include "tensorflow/compiler/xla/client/lib/svd.h" |
41 | 43 | #include "xla_tensor_wrapper.h"
|
42 | 44 |
|
43 | 45 | namespace at {
|
@@ -423,6 +425,55 @@ xla::int64 CanonicalizeCat(absl::Span<const Value> inputs, xla::int64 dim) {
|
423 | 425 | return dim;
|
424 | 426 | }
|
425 | 427 |
|
| 428 | +std::vector<xla::XlaOp> LowerQR(xla::XlaOp input, bool some) { |
| 429 | + xla::QRDecompositionResult qr_result = |
| 430 | + xla::QRDecomposition(input, /*full_matrices=*/!some, |
| 431 | + /*block_size=*/128, XlaHelpers::mat_mul_precision()) |
| 432 | + .ValueOrDie(); |
| 433 | + xla::XlaOp q = qr_result.q; |
| 434 | + xla::XlaOp r = qr_result.r; |
| 435 | + return {q, r}; |
| 436 | +} |
| 437 | + |
// Lowers a singular value decomposition to the XLA approximate SVD,
// returning {U, S, V} as a three-element op list.
//
// `compute_uv` / `full_matrix` mirror the at::svd flags: when compute_uv is
// false, U and V are replaced by zero tensors of the same shape (matching
// the PyTorch contract of still returning three tensors); when full_matrix
// is false, U and V are sliced down to the reduced ("thin") factorization.
std::vector<xla::XlaOp> LowerSVD(xla::XlaOp input, bool compute_uv,
                                 bool full_matrix) {
  // Iterative/approximate SVD; iteration count and tolerance are fixed here.
  xla::SVDResult svd_result =
      xla::SVD(input, /*max_iter=*/100, /*epsilon=*/1e-6,
               XlaHelpers::mat_mul_precision());
  const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
  xla::XlaOp u = svd_result.u;
  xla::XlaOp v = svd_result.v;
  if (!compute_uv) {
    // U/V were not requested: emit zeros with the shapes SVD produced so the
    // returned tuple shape stays stable regardless of the flag.
    u = xla::Zeros(input.builder(), XlaHelpers::ShapeOfXlaOp(u));
    v = xla::Zeros(input.builder(), XlaHelpers::ShapeOfXlaOp(v));
  } else if (!full_matrix) {
    // Reduced factorization: keep only the first min(m, n) columns of U and
    // V. m/n are the two trailing (matrix) dimensions of the batched input.
    xla::int64 m_dim = input_shape.dimensions(input_shape.rank() - 2);
    xla::int64 n_dim = input_shape.dimensions(input_shape.rank() - 1);
    std::vector<xla::int64> base_indices(input_shape.rank(), 0);

    // U is sliced to [..., m, min(m, n)].
    // NOTE(review): assumes xla::SVD's U is at least that large in its last
    // dimension and shares the leading batch dims with the input — confirm
    // against the xla::SVD implementation.
    auto u_sizes = xla::util::ToVector<xla::int64>(input_shape.dimensions());
    u_sizes[input_shape.rank() - 1] = std::min(m_dim, n_dim);
    u = BuildSlice(u, base_indices, u_sizes);

    // V is sliced to [..., n, min(m, n)] (V, not V^T, per the at::svd
    // convention — presumably; verify against the caller's expectations).
    auto v_sizes = xla::util::ToVector<xla::int64>(input_shape.dimensions());
    v_sizes[input_shape.rank() - 2] = n_dim;
    v_sizes[input_shape.rank() - 1] = std::min(m_dim, n_dim);
    v = BuildSlice(v, base_indices, v_sizes);
  }
  // d holds the singular values (the "S" of U*S*V^T).
  return {u, svd_result.d, v};
}
| 465 | + |
| 466 | +xla::Shape ShapeOfXlaOpList(absl::Span<const xla::XlaOp> ops) { |
| 467 | + xla::Shape result; |
| 468 | + result.set_element_type(xla::TUPLE); |
| 469 | + result.mutable_tuple_shapes()->reserve(ops.size()); |
| 470 | + for (const auto& op : ops) { |
| 471 | + xla::ShapeUtil::AppendShapeToTuple(XlaHelpers::ShapeOfXlaOp(op), &result); |
| 472 | + } |
| 473 | + TF_DCHECK_OK(xla::ShapeUtil::ValidateShapeWithOptionalLayout(result)); |
| 474 | + return result; |
| 475 | +} |
| 476 | + |
426 | 477 | } // namespace
|
427 | 478 | } // namespace ops
|
428 | 479 | } // namespace ir
|
|
0 commit comments