@@ -35,6 +35,14 @@ static Value createConst(Location loc, Type type, int value,
  return rewriter.create<arith::ConstantOp>(loc, attr);
}

+/// Creates shapedType using shape from cloneFrom and base type from cloneTo
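+/// (e.g. cloning from vector<4xbf16> to i16 yields vector<4xi16>; a scalar
+/// cloneFrom simply returns cloneTo unchanged).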
+static Type cloneToShapedType(Type cloneFrom, Type cloneTo) {
+  if (auto shapedTy = dyn_cast<ShapedType>(cloneFrom)) {
+    return shapedTy.clone(cloneTo);
+  }
+  return cloneTo;
+}
+
namespace {

/// Expands CeilDivUIOp (n, m) into
@@ -225,12 +233,8 @@ struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
      return rewriter.notifyMatchFailure(op, "not a ext of bf16 to f32.");
    }

-    Type i16Ty = b.getI16Type();
-    Type i32Ty = b.getI32Type();
-    if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
-      i16Ty = shapedTy.clone(i16Ty);
-      i32Ty = shapedTy.clone(i32Ty);
-    }
+    Type i16Ty = cloneToShapedType(operandTy, b.getI16Type());
+    Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());

    Value bitcast = b.create<arith::BitcastOp>(i16Ty, operand);
    Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
@@ -264,14 +268,8 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
          op, "only applicable to default rounding mode.");
    }

-    Type i16Ty = b.getI16Type();
-    Type i32Ty = b.getI32Type();
-    Type f32Ty = b.getF32Type();
-    if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
-      i16Ty = shapedTy.clone(i16Ty);
-      i32Ty = shapedTy.clone(i32Ty);
-      f32Ty = shapedTy.clone(f32Ty);
-    }
+    Type i16Ty = cloneToShapedType(operandTy, b.getI16Type());
+    Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());

    // Algorithm borrowed from this excellent code:
    // https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/c10/util/BFloat16.h#L60-L79
@@ -291,7 +289,7 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
    // Constant used to make the rounding bias.
    Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
    // Constant used to generate a quiet NaN.
-    Value c7FC0_i16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
+    Value c7FC0I16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
    // Small constants used to address bits.
    Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
    Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter);
@@ -313,18 +311,104 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
    // Now that the rounding-bias has been added, truncating the low bits
    // yields the correctly rounded result.
    Value biasedAndShifted = b.create<arith::ShRUIOp>(biased, c16);
-    Value normalCaseResult_i16 =
+    Value normalCaseResultI16 =
        b.create<arith::TruncIOp>(i16Ty, biasedAndShifted);
    // Select either the above-computed result, or a quiet NaN constant
    // if the input was NaN.
    Value select =
-        b.create<arith::SelectOp>(isNan, c7FC0_i16, normalCaseResult_i16);
+        b.create<arith::SelectOp>(isNan, c7FC0I16, normalCaseResultI16);
    Value result = b.create<arith::BitcastOp>(resultTy, select);
    rewriter.replaceOp(op, result);
    return success();
  }
};

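+/*
+ExtF from F8E8M0 places the 8 exponent bits of the operand into the exponent
+field of an F32 value, leaving the sign bit and the mantissa zero. The only
+special case is the byte 0xFF, which encodes NaN and is mapped to an F32 NaN.
+*/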
+struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::ExtFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    Value operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type resultTy = op.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultETy = getElementTypeOrSelf(resultTy);
+
+    if (!llvm::isa<Float8E8M0FNUType>(operandETy)) {
+      return rewriter.notifyMatchFailure(op, "not a ext of F8E8M0FNU");
+    }
+
+    Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
+    Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
+    Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
+
+    Value bitcast = b.create<arith::BitcastOp>(i8Ty, operand);
+    // Constants for the NaN encodings and the f32 mantissa width.
+    Value cF8NaN = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
+    Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
+    Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
+
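+    // Shifting the exponent byte left by the f32 mantissa width (23) places
+    // it in the f32 exponent field; e.g. 0x7F << 23 == 0x3F800000, i.e. 1.0f.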
+    Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
+    Value f32Bits = b.create<arith::ShLIOp>(exti, cF32MantissaWidth);
+
+    Value isNan =
+        b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
+    // Select the f32 NaN bit pattern when the input encodes NaN.
+    f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
+    Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
+    if (resultETy.getIntOrFloatBitWidth() < 32) {
+      result = b.create<arith::TruncFOp>(resultTy, result);
+    } else if (resultETy.getIntOrFloatBitWidth() > 32) {
+      result = b.create<arith::ExtFOp>(resultTy, result);
+    }
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+/*
+TruncF to F8E8M0 is expected to extract the exponent bits of an F32 value.
+Since all Infs and NaNs are mapped to the same exponent bits in the F32 type,
+they all map to NaN in the F8E8M0 type.
+*/
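+// For example, 1.0f has bit pattern 0x3F800000; shifting right by 23 leaves
+// the biased exponent 0x7F, which is the f8E8M0FNU encoding of 1.0.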
+struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::TruncFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    Value operand = op.getOperand();
+    Type operandTy = operand.getType();
+    Type operandETy = getElementTypeOrSelf(operandTy);
+    Type resultTy = op.getType();
+    Type resultETy = getElementTypeOrSelf(resultTy);
+    if (!llvm::isa<Float8E8M0FNUType>(resultETy)) {
+      return rewriter.notifyMatchFailure(op, "not a truncf to f8E8M0FNU");
+    }
+
+    if (op.getRoundingmodeAttr()) {
+      return rewriter.notifyMatchFailure(
+          op, "only applicable to default rounding mode.");
+    }
+
+    Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
+    Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
+    Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
+
+    if (operandETy.getIntOrFloatBitWidth() < 32) {
+      operand = b.create<arith::ExtFOp>(f32Ty, operand);
+    } else if (operandETy.getIntOrFloatBitWidth() > 32) {
+      operand = b.create<arith::TruncFOp>(f32Ty, operand);
+    }
+    Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand);
+    Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
+    Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
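+    // The logical shift leaves the sign and exponent in the low nine bits;
+    // truncating to i8 keeps only the eight exponent bits and drops the sign.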
+    Value exp8Bits = b.create<arith::TruncIOp>(i8Ty, f32SignExp);
+    Value result = b.create<arith::BitcastOp>(resultTy, exp8Bits);
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
struct ArithExpandOpsPass
    : public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
  using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
@@ -353,20 +437,34 @@ struct ArithExpandOpsPass
    if (includeBf16) {
      arith::populateExpandBFloat16Patterns(patterns);
-      target.addDynamicallyLegalOp<arith::ExtFOp>(
-        [](arith::ExtFOp op) {
-          Type inETy = getElementTypeOrSelf(op.getOperand().getType());
-          Type outETy = getElementTypeOrSelf(op.getType());
-          return !(inETy.isBF16() && outETy.isF32());
-        });
-
-      target.addDynamicallyLegalOp<arith::TruncFOp>(
-        [](arith::TruncFOp op) {
-          Type inETy = getElementTypeOrSelf(op.getOperand().getType());
-          Type outETy = getElementTypeOrSelf(op.getType());
-          return !(inETy.isF32() && outETy.isBF16());
-        });
    }
+    if (includeF8E8M0) {
+      arith::populateExpandF8E8M0Patterns(patterns);
+    }
+
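+    // Legality of ExtFOp/TruncFOp now depends on both the includeBf16 and
+    // includeF8E8M0 options, so a single callback per op handles both cases.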
+    target.addDynamicallyLegalOp<arith::ExtFOp>(
+      [=](arith::ExtFOp op) {
+        Type inETy = getElementTypeOrSelf(op.getOperand().getType());
+        Type outETy = getElementTypeOrSelf(op.getType());
+        bool legalTypes = true;
+        if (includeBf16)
+          legalTypes &= !(inETy.isBF16() && outETy.isF32());
+        if (includeF8E8M0)
+          legalTypes &= !llvm::isa<Float8E8M0FNUType>(inETy);
+        return legalTypes;
+      });
+
+    target.addDynamicallyLegalOp<arith::TruncFOp>(
+      [=](arith::TruncFOp op) {
+        Type inETy = getElementTypeOrSelf(op.getOperand().getType());
+        Type outETy = getElementTypeOrSelf(op.getType());
+        bool legalTypes = true;
+        if (includeBf16)
+          legalTypes &= !(inETy.isF32() && outETy.isBF16());
+        if (includeF8E8M0)
+          legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy));
+        return legalTypes;
+      });

    // clang-format on
    if (failed(applyPartialConversion(getOperation(), target,
@@ -389,6 +487,11 @@ void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) {
      patterns.getContext());
}

+void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) {
+  patterns.add<F8E8M0ExtFOpConverter, F8E8M0TruncFOpConverter>(
+      patterns.getContext());
+}
+
void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
  populateCeilFloorDivExpandOpsPatterns(patterns);
  // clang-format off