6
6
//
7
7
// ===----------------------------------------------------------------------===//
8
8
9
- #include " mlir/Dialect/Arith/Transforms/Passes.h"
10
-
11
9
#include " mlir/Dialect/Arith/IR/Arith.h"
10
+ #include " mlir/Dialect/Arith/Transforms/Passes.h"
12
11
#include " mlir/Dialect/Vector/IR/VectorOps.h"
12
+ #include " mlir/IR/BuiltinTypeInterfaces.h"
13
13
#include " mlir/IR/ImplicitLocOpBuilder.h"
14
14
#include " mlir/IR/TypeUtilities.h"
15
15
#include " mlir/Transforms/DialectConversion.h"
@@ -31,7 +31,6 @@ static Value createConst(Location loc, Type type, int value,
31
31
return rewriter.create <arith::ConstantOp>(
32
32
loc, DenseElementsAttr::get (shapedTy, attr));
33
33
}
34
-
35
34
return rewriter.create <arith::ConstantOp>(loc, attr);
36
35
}
37
36
@@ -357,9 +356,10 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
357
356
f32Bits = b.create <arith::SelectOp>(isNan, cF32NaN, f32Bits);
358
357
Value result = b.create <arith::BitcastOp>(f32Ty, f32Bits);
359
358
if (resultETy.getIntOrFloatBitWidth () < 32 ) {
360
- result = b.create <arith::TruncFOp>(resultTy, result);
359
+ result = b.create <arith::TruncFOp>(resultTy, result, nullptr ,
360
+ op.getFastmathAttr ());
361
361
} else if (resultETy.getIntOrFloatBitWidth () > 32 ) {
362
- result = b.create <arith::ExtFOp>(resultTy, result);
362
+ result = b.create <arith::ExtFOp>(resultTy, result, op. getFastmathAttr () );
363
363
}
364
364
rewriter.replaceOp (op, result);
365
365
return success ();
@@ -395,9 +395,10 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
395
395
Type f32Ty = cloneToShapedType (operandTy, b.getF32Type ());
396
396
397
397
if (operandETy.getIntOrFloatBitWidth () < 32 ) {
398
- operand = b.create <arith::ExtFOp>(f32Ty, operand);
398
+ operand = b.create <arith::ExtFOp>(f32Ty, operand, op. getFastmathAttr () );
399
399
} else if (operandETy.getIntOrFloatBitWidth () > 32 ) {
400
- operand = b.create <arith::TruncFOp>(f32Ty, operand);
400
+ operand = b.create <arith::TruncFOp>(
401
+ f32Ty, operand, op.getRoundingmodeAttr (), op.getFastmathAttr ());
401
402
}
402
403
Value f32Bits = b.create <arith::BitcastOp>(i32Ty, operand);
403
404
Value cF32MantissaWidth = createConst (op->getLoc (), i32Ty, 23 , rewriter);
@@ -409,6 +410,83 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
409
410
}
410
411
};
411
412
413
+ struct ScalingExtFOpConverter : public OpRewritePattern <arith::ScalingExtFOp> {
414
+ using OpRewritePattern::OpRewritePattern;
415
+ LogicalResult matchAndRewrite (arith::ScalingExtFOp op,
416
+ PatternRewriter &rewriter) const final {
417
+ ImplicitLocOpBuilder b (op.getLoc (), rewriter);
418
+ Value inputOperand = op.getIn ();
419
+ Value scaleOperand = op.getScale ();
420
+ Type scaleTy = scaleOperand.getType ();
421
+ Type scaleETy = getElementTypeOrSelf (scaleOperand);
422
+ // allow implicit exponent extraction from 16/32 bits floats
423
+ if (scaleETy.getIntOrFloatBitWidth () >= 16 ) {
424
+ scaleETy = b.getF8E8M0Type ();
425
+ scaleTy = cloneToShapedType (scaleTy, scaleETy);
426
+ scaleOperand = b.create <arith::TruncFOp>(scaleTy, scaleOperand, nullptr ,
427
+ op.getFastmathAttr ());
428
+ }
429
+ if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
430
+ return rewriter.notifyMatchFailure (
431
+ op, " scaling_extf is using scales of type which can not be converted "
432
+ " to f8E8M0FNU" );
433
+ }
434
+ Type resultTy = op.getType ();
435
+ // extf on scale will essentially create floating point number
436
+ // of type resulTy that is 2^scale and will also propagate NaNs
437
+ Value scaleExt =
438
+ b.create <arith::ExtFOp>(resultTy, scaleOperand, op.getFastmathAttr ());
439
+ Value inputExt =
440
+ b.create <arith::ExtFOp>(resultTy, inputOperand, op.getFastmathAttr ());
441
+ Value result =
442
+ b.create <arith::MulFOp>(inputExt, scaleExt, op.getFastmathAttr ());
443
+ rewriter.replaceOp (op, result);
444
+ return success ();
445
+ }
446
+ };
447
+
448
+ /*
449
+ Expands arith.ScalingTruncFOp(in, scale) into
450
+ scale = arith.truncf(scale) : scaleTy -> f8E8M0FNU
451
+ result = arith.truncf(in / (2^scale))
452
+ */
453
+ struct ScalingTruncFOpConverter
454
+ : public OpRewritePattern<arith::ScalingTruncFOp> {
455
+ using OpRewritePattern::OpRewritePattern;
456
+ LogicalResult matchAndRewrite (arith::ScalingTruncFOp op,
457
+ PatternRewriter &rewriter) const final {
458
+ ImplicitLocOpBuilder b (op.getLoc (), rewriter);
459
+ Value inputOperand = op.getIn ();
460
+ Value scaleOperand = op.getScale ();
461
+ Type scaleTy = scaleOperand.getType ();
462
+ Type scaleETy = getElementTypeOrSelf (scaleOperand);
463
+ // allow implicit exponent extraction from 16/32 bits floats
464
+ if (scaleETy.getIntOrFloatBitWidth () >= 16 ) {
465
+ scaleETy = b.getF8E8M0Type ();
466
+ scaleTy = cloneToShapedType (scaleTy, scaleETy);
467
+ scaleOperand = b.create <arith::TruncFOp>(scaleTy, scaleOperand, nullptr ,
468
+ op.getFastmathAttr ());
469
+ }
470
+ if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
471
+ return rewriter.notifyMatchFailure (
472
+ op, " scaling_truncf is using scales type which can not be converted "
473
+ " to f8E8M0FNU" );
474
+ }
475
+ Type resultTy = op.getType ();
476
+ Type inputTy = inputOperand.getType ();
477
+ // this will create a floating point number of type
478
+ // inputTy that is 2^scale and will also propagate NaNs
479
+ scaleOperand =
480
+ b.create <arith::ExtFOp>(inputTy, scaleOperand, op.getFastmathAttr ());
481
+ Value result = b.create <arith::DivFOp>(inputOperand, scaleOperand,
482
+ op.getFastmathAttr ());
483
+ Value resultCast = b.create <arith::TruncFOp>(
484
+ resultTy, result, op.getRoundingmodeAttr (), op.getFastmathAttr ());
485
+ rewriter.replaceOp (op, resultCast);
486
+ return success ();
487
+ }
488
+ };
489
+
412
490
struct ArithExpandOpsPass
413
491
: public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
414
492
using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
@@ -432,7 +510,9 @@ struct ArithExpandOpsPass
432
510
arith::MaximumFOp,
433
511
arith::MinimumFOp,
434
512
arith::MaxNumFOp,
435
- arith::MinNumFOp
513
+ arith::MinNumFOp,
514
+ arith::ScalingExtFOp,
515
+ arith::ScalingTruncFOp
436
516
>();
437
517
438
518
if (includeBf16) {
@@ -492,8 +572,15 @@ void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) {
492
572
patterns.getContext ());
493
573
}
494
574
575
+ void mlir::arith::populateExpandScalingExtTruncPatterns (
576
+ RewritePatternSet &patterns) {
577
+ patterns.add <ScalingExtFOpConverter, ScalingTruncFOpConverter>(
578
+ patterns.getContext ());
579
+ }
580
+
495
581
void mlir::arith::populateArithExpandOpsPatterns (RewritePatternSet &patterns) {
496
582
populateCeilFloorDivExpandOpsPatterns (patterns);
583
+ populateExpandScalingExtTruncPatterns (patterns);
497
584
// clang-format off
498
585
patterns.add <
499
586
MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
@@ -503,7 +590,7 @@ void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
503
590
MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
504
591
MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
505
592
MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,
506
- MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
593
+ MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
507
594
>(patterns.getContext ());
508
595
// clang-format on
509
596
}
0 commit comments