@@ -144,6 +144,11 @@ xla::XlaOp LowerBinaryValueOp(xla::XlaOp lhs, xla::XlaOp rhs) {
144
144
return T (lhs, rhs);
145
145
}
146
146
147
+ std::vector<xla::XlaOp> LowerBroadcastTensors (xla::XlaOp lhs, xla::XlaOp rhs) {
148
+ std::tie (lhs, rhs) = XlaHelpers::PromoteValues (lhs, rhs);
149
+ return {lhs, rhs};
150
+ }
151
+
147
152
xla::XlaOp LowerSqueeze (xla::XlaOp input, int dim) {
148
153
if (dim == -1 ) return SqueezeAllTrivialDimensions (input);
149
154
XLA_CHECK_GE (dim, 0 );
@@ -324,9 +329,7 @@ xla::XlaOp LowerTfUnsortedSegmentSum(xla::XlaOp data, xla::XlaOp indices,
324
329
combine);
325
330
}
326
331
327
- xla::XlaOp LowerTfStatelessRandomUniform (xla::Shape shape, xla::XlaOp seeds,
328
- xla::XlaOp minval, xla::XlaOp maxval,
329
- LoweringContext* loctx = nullptr ) {
332
+ xla::BitGeneratorTy GetBestGenerator (LoweringContext* loctx = nullptr ) {
330
333
xla::BitGeneratorTy generator;
331
334
if (!loctx || loctx->device ().hw_type == swift_xla::DeviceType::TPU) {
332
335
generator = xla::ThreeFryBitGenerator;
@@ -336,6 +339,30 @@ xla::XlaOp LowerTfStatelessRandomUniform(xla::Shape shape, xla::XlaOp seeds,
336
339
return xla::PhiloxBitGenerator (key, state, shape);
337
340
};
338
341
}
342
+ return generator;
343
+ }
344
+
345
+ xla::XlaOp LowerTfStatelessRandomNormal (xla::Shape shape, xla::XlaOp seeds,
346
+ at::ScalarType dtype,
347
+ LoweringContext* loctx = nullptr ) {
348
+ auto generator = GetBestGenerator (loctx);
349
+ xla::XlaOp seed0 = xla::Reshape (xla::Slice (seeds, {0 }, {1 }, {1 }), {});
350
+ xla::XlaOp seed1 = xla::Reshape (xla::Slice (seeds, {1 }, {2 }, {1 }), {});
351
+ xla::XlaOp initial_state =
352
+ xla::ConstantR0WithType (seeds.builder (), xla::U64, 0 );
353
+ xla::XlaOp key = ConvertElementType (seed0, xla::U64) |
354
+ ShiftLeft (ConvertElementType (seed1, xla::U64),
355
+ ConstantR0WithType (seeds.builder (), xla::U64, 32 ));
356
+ xla::XlaOp normal =
357
+ xla::NormalFloatingPointDistribution (key, initial_state, generator, shape)
358
+ .value ;
359
+ return normal;
360
+ }
361
+
362
+ xla::XlaOp LowerTfStatelessRandomUniform (xla::Shape shape, xla::XlaOp seeds,
363
+ xla::XlaOp minval, xla::XlaOp maxval,
364
+ LoweringContext* loctx = nullptr ) {
365
+ auto generator = GetBestGenerator (loctx);
339
366
xla::XlaOp seed0 = xla::Reshape (xla::Slice (seeds, {0 }, {1 }, {1 }), {});
340
367
xla::XlaOp seed1 = xla::Reshape (xla::Slice (seeds, {1 }, {2 }, {1 }), {});
341
368
xla::XlaOp key = ConvertElementType (seed0, xla::U64) |
0 commit comments