Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 2aa292e

Browse files
authored
Convert qr and add svd to xla. (#1082)
1 parent 84797b2 commit 2aa292e

File tree

11 files changed

+520
-163
lines changed

11 files changed

+520
-163
lines changed

Sources/CX10/xla_tensor_ops_wrapper.cc

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838
#include "tensorflow/compiler/xla/client/lib/constants.h"
3939
#include "tensorflow/compiler/xla/client/lib/math.h"
4040
#include "tensorflow/compiler/xla/client/lib/prng.h"
41+
#include "tensorflow/compiler/xla/client/lib/qr.h"
42+
#include "tensorflow/compiler/xla/client/lib/svd.h"
4143
#include "xla_tensor_wrapper.h"
4244

4345
namespace at {
@@ -423,6 +425,55 @@ xla::int64 CanonicalizeCat(absl::Span<const Value> inputs, xla::int64 dim) {
423425
return dim;
424426
}
425427

428+
// Lowers a QR factorization to the XLA client-library QR routine.
//
// `some` follows the torch-style convention where some=true requests the
// reduced factorization, so it maps to full_matrices=!some. The block size
// of 128 and the matmul precision come from the shared helper so this
// lowering matches the file's other linear-algebra ops.
// Returns {Q, R} as a two-element op list.
std::vector<xla::XlaOp> LowerQR(xla::XlaOp input, bool some) {
  auto decomposition =
      xla::QRDecomposition(input, /*full_matrices=*/!some,
                           /*block_size=*/128,
                           XlaHelpers::mat_mul_precision())
          .ValueOrDie();
  return {decomposition.q, decomposition.r};
}
437+
438+
// Lowers an SVD to the XLA client-library SVD routine.
//
// Presumably mirrors torch.svd semantics — TODO confirm against the caller:
// when compute_uv is false, U and V are replaced by zero tensors of the
// shapes the XLA SVD produced; when full_matrix is false, U and V are
// sliced down to the reduced ("some") factorization.
// Returns {U, singular values (svd_result.d), V}.
std::vector<xla::XlaOp> LowerSVD(xla::XlaOp input, bool compute_uv,
                                 bool full_matrix) {
  // max_iter/epsilon are fixed convergence parameters for the iterative
  // XLA SVD; precision matches the file's other matmul-based lowerings.
  xla::SVDResult svd_result =
      xla::SVD(input, /*max_iter=*/100, /*epsilon=*/1e-6,
               XlaHelpers::mat_mul_precision());
  const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
  xla::XlaOp u = svd_result.u;
  xla::XlaOp v = svd_result.v;
  if (!compute_uv) {
    // U/V not requested: return zeros of the shapes XLA produced so the
    // output arity and shapes stay consistent for the caller.
    u = xla::Zeros(input.builder(), XlaHelpers::ShapeOfXlaOp(u));
    v = xla::Zeros(input.builder(), XlaHelpers::ShapeOfXlaOp(v));
  } else if (!full_matrix) {
    // Reduced factorization: slice U and V down using the input's trailing
    // two dimensions m x n (input is treated as a batch of matrices).
    xla::int64 m_dim = input_shape.dimensions(input_shape.rank() - 2);
    xla::int64 n_dim = input_shape.dimensions(input_shape.rank() - 1);
    // Slice from the origin in every dimension.
    std::vector<xla::int64> base_indices(input_shape.rank(), 0);

    // U becomes [..., m, min(m, n)].
    auto u_sizes = xla::util::ToVector<xla::int64>(input_shape.dimensions());
    u_sizes[input_shape.rank() - 1] = std::min(m_dim, n_dim);
    u = BuildSlice(u, base_indices, u_sizes);

    // V becomes [..., n, min(m, n)] — note its second-to-last dimension is
    // n (not m), hence the extra assignment before the shared min().
    auto v_sizes = xla::util::ToVector<xla::int64>(input_shape.dimensions());
    v_sizes[input_shape.rank() - 2] = n_dim;
    v_sizes[input_shape.rank() - 1] = std::min(m_dim, n_dim);
    v = BuildSlice(v, base_indices, v_sizes);
  }
  return {u, svd_result.d, v};
}
465+
466+
// Builds the tuple shape that describes a list of XLA ops, i.e. a TUPLE
// whose element shapes are the shapes of `ops` in order. The result is
// validated (layout optional) before being returned.
xla::Shape ShapeOfXlaOpList(absl::Span<const xla::XlaOp> ops) {
  xla::Shape tuple_shape;
  tuple_shape.set_element_type(xla::TUPLE);
  // Reserve up front so appending each element shape never reallocates.
  tuple_shape.mutable_tuple_shapes()->reserve(ops.size());
  for (const xla::XlaOp& op : ops) {
    xla::ShapeUtil::AppendShapeToTuple(XlaHelpers::ShapeOfXlaOp(op),
                                       &tuple_shape);
  }
  TF_DCHECK_OK(xla::ShapeUtil::ValidateShapeWithOptionalLayout(tuple_shape));
  return tuple_shape;
}
476+
426477
} // namespace
427478
} // namespace ops
428479
} // namespace ir

0 commit comments

Comments (0)