Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit d0ea5cb

Browse files
authored
Some more XLATensor cleanup. (#1110)
1 parent d91e91e commit d0ea5cb

File tree

9 files changed

+441
-292
lines changed

9 files changed

+441
-292
lines changed

Sources/CX10/xla_tensor_ops_wrapper.cc

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,11 @@ xla::XlaOp LowerBinaryValueOp(xla::XlaOp lhs, xla::XlaOp rhs) {
144144
return T(lhs, rhs);
145145
}
146146

147+
std::vector<xla::XlaOp> LowerBroadcastTensors(xla::XlaOp lhs, xla::XlaOp rhs) {
148+
std::tie(lhs, rhs) = XlaHelpers::PromoteValues(lhs, rhs);
149+
return {lhs, rhs};
150+
}
151+
147152
xla::XlaOp LowerSqueeze(xla::XlaOp input, int dim) {
148153
if (dim == -1) return SqueezeAllTrivialDimensions(input);
149154
XLA_CHECK_GE(dim, 0);
@@ -324,9 +329,7 @@ xla::XlaOp LowerTfUnsortedSegmentSum(xla::XlaOp data, xla::XlaOp indices,
324329
combine);
325330
}
326331

327-
xla::XlaOp LowerTfStatelessRandomUniform(xla::Shape shape, xla::XlaOp seeds,
328-
xla::XlaOp minval, xla::XlaOp maxval,
329-
LoweringContext* loctx = nullptr) {
332+
xla::BitGeneratorTy GetBestGenerator(LoweringContext* loctx = nullptr) {
330333
xla::BitGeneratorTy generator;
331334
if (!loctx || loctx->device().hw_type == swift_xla::DeviceType::TPU) {
332335
generator = xla::ThreeFryBitGenerator;
@@ -336,6 +339,30 @@ xla::XlaOp LowerTfStatelessRandomUniform(xla::Shape shape, xla::XlaOp seeds,
336339
return xla::PhiloxBitGenerator(key, state, shape);
337340
};
338341
}
342+
return generator;
343+
}
344+
345+
xla::XlaOp LowerTfStatelessRandomNormal(xla::Shape shape, xla::XlaOp seeds,
346+
at::ScalarType dtype,
347+
LoweringContext* loctx = nullptr) {
348+
auto generator = GetBestGenerator(loctx);
349+
xla::XlaOp seed0 = xla::Reshape(xla::Slice(seeds, {0}, {1}, {1}), {});
350+
xla::XlaOp seed1 = xla::Reshape(xla::Slice(seeds, {1}, {2}, {1}), {});
351+
xla::XlaOp initial_state =
352+
xla::ConstantR0WithType(seeds.builder(), xla::U64, 0);
353+
xla::XlaOp key = ConvertElementType(seed0, xla::U64) |
354+
ShiftLeft(ConvertElementType(seed1, xla::U64),
355+
ConstantR0WithType(seeds.builder(), xla::U64, 32));
356+
xla::XlaOp normal =
357+
xla::NormalFloatingPointDistribution(key, initial_state, generator, shape)
358+
.value;
359+
return normal;
360+
}
361+
362+
xla::XlaOp LowerTfStatelessRandomUniform(xla::Shape shape, xla::XlaOp seeds,
363+
xla::XlaOp minval, xla::XlaOp maxval,
364+
LoweringContext* loctx = nullptr) {
365+
auto generator = GetBestGenerator(loctx);
339366
xla::XlaOp seed0 = xla::Reshape(xla::Slice(seeds, {0}, {1}, {1}), {});
340367
xla::XlaOp seed1 = xla::Reshape(xla::Slice(seeds, {1}, {2}, {1}), {});
341368
xla::XlaOp key = ConvertElementType(seed0, xla::U64) |

0 commit comments

Comments
 (0)