-
Notifications
You must be signed in to change notification settings - Fork 14.3k
Introduce arith.scaling_extf
and arith.scaling_truncf
#141965
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
1ed7462
91bb889
8eebbea
acc6658
6797446
3ad83bd
9f755c2
5e49a72
682573e
de4497b
e239157
646465c
20b0928
80c080f
b5df100
b6589ae
12c52a6
b3cadf2
5558b03
fc90780
8f91e28
dc7b67f
109ddc5
f3d9865
95a7558
3ccb208
d154341
d8a76fa
a0aa490
ff66dad
10a1bc3
3c7980d
f7c1b79
229f6b8
45e7dba
80061d6
a38ac5e
8151fc7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,10 +6,10 @@ | |
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Dialect/Arith/Transforms/Passes.h" | ||
|
||
#include "mlir/Dialect/Arith/IR/Arith.h" | ||
#include "mlir/Dialect/Arith/Transforms/Passes.h" | ||
#include "mlir/Dialect/Vector/IR/VectorOps.h" | ||
#include "mlir/IR/BuiltinTypeInterfaces.h" | ||
#include "mlir/IR/ImplicitLocOpBuilder.h" | ||
#include "mlir/IR/TypeUtilities.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
|
@@ -31,7 +31,6 @@ static Value createConst(Location loc, Type type, int value, | |
return rewriter.create<arith::ConstantOp>( | ||
loc, DenseElementsAttr::get(shapedTy, attr)); | ||
} | ||
|
||
return rewriter.create<arith::ConstantOp>(loc, attr); | ||
} | ||
|
||
|
@@ -357,9 +356,10 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> { | |
f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits); | ||
Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits); | ||
if (resultETy.getIntOrFloatBitWidth() < 32) { | ||
result = b.create<arith::TruncFOp>(resultTy, result); | ||
result = b.create<arith::TruncFOp>(resultTy, result, nullptr, | ||
op.getFastmathAttr()); | ||
} else if (resultETy.getIntOrFloatBitWidth() > 32) { | ||
result = b.create<arith::ExtFOp>(resultTy, result); | ||
result = b.create<arith::ExtFOp>(resultTy, result, op.getFastmathAttr()); | ||
} | ||
rewriter.replaceOp(op, result); | ||
return success(); | ||
|
@@ -395,9 +395,10 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> { | |
Type f32Ty = cloneToShapedType(operandTy, b.getF32Type()); | ||
|
||
if (operandETy.getIntOrFloatBitWidth() < 32) { | ||
operand = b.create<arith::ExtFOp>(f32Ty, operand); | ||
operand = b.create<arith::ExtFOp>(f32Ty, operand, op.getFastmathAttr()); | ||
} else if (operandETy.getIntOrFloatBitWidth() > 32) { | ||
operand = b.create<arith::TruncFOp>(f32Ty, operand); | ||
operand = b.create<arith::TruncFOp>( | ||
f32Ty, operand, op.getRoundingmodeAttr(), op.getFastmathAttr()); | ||
} | ||
Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand); | ||
Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter); | ||
|
@@ -409,6 +410,83 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> { | |
} | ||
}; | ||
|
||
// Expands arith.scaling_extf(in, scale) into | ||
// scale = arith.truncf(scale) : scaleTy -> f8E8M0FNU (only emitted when the | ||
// scale element type is 16 bits or wider) | ||
// result = arith.mulf(arith.extf(in), arith.extf(scale)) | ||
// The op's fastmath attribute is forwarded to every op created here. | ||
struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> { | ||
using OpRewritePattern::OpRewritePattern; | ||
LogicalResult matchAndRewrite(arith::ScalingExtFOp op, | ||
PatternRewriter &rewriter) const final { | ||
ImplicitLocOpBuilder b(op.getLoc(), rewriter); | ||
Value inputOperand = op.getIn(); | ||
Value scaleOperand = op.getScale(); | ||
Type scaleTy = scaleOperand.getType(); | ||
Type scaleETy = getElementTypeOrSelf(scaleOperand); | ||
// allow implicit exponent extraction from 16/32 bits floats | ||
// (the truncf is given nullptr in its rounding-mode argument position) | ||
if (scaleETy.getIntOrFloatBitWidth() >= 16) { | ||
scaleETy = b.getF8E8M0Type(); | ||
scaleTy = cloneToShapedType(scaleTy, scaleETy); | ||
scaleOperand = b.create<arith::TruncFOp>(scaleTy, scaleOperand, nullptr, | ||
op.getFastmathAttr()); | ||
} | ||
// Scales narrower than 16 bits reach this check with their original element | ||
// type; anything other than f8E8M0FNU is rejected without rewriting the op. | ||
if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) { | ||
return rewriter.notifyMatchFailure( | ||
op, "scaling_extf is using scales of type which can not be converted " | ||
"to f8E8M0FNU"); | ||
} | ||
Type resultTy = op.getType(); | ||
// extf on scale will essentially create floating point number | ||
// of type resultTy that is 2^scale and will also propagate NaNs | ||
Value scaleExt = | ||
b.create<arith::ExtFOp>(resultTy, scaleOperand, op.getFastmathAttr()); | ||
// The input is extended to the same result type so that the multiplication | ||
// below operates on two values of type resultTy. | ||
Value inputExt = | ||
b.create<arith::ExtFOp>(resultTy, inputOperand, op.getFastmathAttr()); | ||
Value result = | ||
b.create<arith::MulFOp>(inputExt, scaleExt, op.getFastmathAttr()); | ||
rewriter.replaceOp(op, result); | ||
return success(); | ||
} | ||
}; | ||
|
||
/* | ||
Expands arith.ScalingTruncFOp(in, scale) into | ||
scale = arith.truncf(scale) : scaleTy -> f8E8M0FNU | ||
result = arith.truncf(in / (2^scale)) | ||
The scale truncf is only emitted when the scale element type is 16 bits or | ||
wider. The op's fastmath attribute is forwarded to every created op; the | ||
rounding-mode attribute is forwarded only to the final truncf. | ||
*/ | ||
struct ScalingTruncFOpConverter | ||
: public OpRewritePattern<arith::ScalingTruncFOp> { | ||
using OpRewritePattern::OpRewritePattern; | ||
LogicalResult matchAndRewrite(arith::ScalingTruncFOp op, | ||
PatternRewriter &rewriter) const final { | ||
ImplicitLocOpBuilder b(op.getLoc(), rewriter); | ||
Value inputOperand = op.getIn(); | ||
Value scaleOperand = op.getScale(); | ||
Type scaleTy = scaleOperand.getType(); | ||
Type scaleETy = getElementTypeOrSelf(scaleOperand); | ||
// allow implicit exponent extraction from 16/32 bits floats | ||
// (the truncf is given nullptr in its rounding-mode argument position) | ||
if (scaleETy.getIntOrFloatBitWidth() >= 16) { | ||
scaleETy = b.getF8E8M0Type(); | ||
scaleTy = cloneToShapedType(scaleTy, scaleETy); | ||
scaleOperand = b.create<arith::TruncFOp>(scaleTy, scaleOperand, nullptr, | ||
op.getFastmathAttr()); | ||
} | ||
// Scales narrower than 16 bits reach this check with their original element | ||
// type; anything other than f8E8M0FNU is rejected without rewriting the op. | ||
if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) { | ||
return rewriter.notifyMatchFailure( | ||
op, "scaling_truncf is using scales type which can not be converted " | ||
"to f8E8M0FNU"); | ||
} | ||
Type resultTy = op.getType(); | ||
Type inputTy = inputOperand.getType(); | ||
// this will create a floating point number of type | ||
// inputTy that is 2^scale and will also propagate NaNs | ||
scaleOperand = | ||
b.create<arith::ExtFOp>(inputTy, scaleOperand, op.getFastmathAttr()); | ||
// Divide the input by the extended scale before narrowing to resultTy. | ||
Value result = b.create<arith::DivFOp>(inputOperand, scaleOperand, | ||
op.getFastmathAttr()); | ||
// Final narrowing step; this is the only op that receives the rounding mode. | ||
Value resultCast = b.create<arith::TruncFOp>( | ||
resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. there are other arith ops, shouldn't we propagate to those as well? also for ScalingExtFOpConverter There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should we check resultTy <= f32? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Verify() checks that output width is smaller compared to input.
No, other There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, verify checks that output width is smaller than input width. But I understand the output of this function is always f32. Then, I wonder if somebody can do input, scale -> f128, result -> f64. Then, it's true that output width < input width and we are still trying to truncate "result" which is f32 into f64. Not sure if I misunderstood something? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In practice, Float64/80/128 dtypes are something that is not expected. I think it is safe to assume F32 is the largest dtype that can appear on the input.
No, why do you think so ? Output dtype will be whatever user has specified. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I mean result of the function before truncation. result.dtype = f32, right?
I think arith dialect is not supposed to be hardware specific, so even though for us it's not expected. I'd prefer to enforce or check the assumption somehow. But it seems ok for me anyway, whatever you decide. |
||
rewriter.replaceOp(op, resultCast); | ||
return success(); | ||
} | ||
}; | ||
|
||
struct ArithExpandOpsPass | ||
: public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> { | ||
using ArithExpandOpsPassBase::ArithExpandOpsPassBase; | ||
|
@@ -432,7 +510,9 @@ struct ArithExpandOpsPass | |
arith::MaximumFOp, | ||
arith::MinimumFOp, | ||
arith::MaxNumFOp, | ||
arith::MinNumFOp | ||
arith::MinNumFOp, | ||
arith::ScalingExtFOp, | ||
arith::ScalingTruncFOp | ||
>(); | ||
|
||
if (includeBf16) { | ||
|
@@ -492,8 +572,15 @@ void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) { | |
patterns.getContext()); | ||
} | ||
|
||
// Adds the ScalingExtFOpConverter and ScalingTruncFOpConverter rewrite | ||
// patterns to `patterns`, constructing them with the set's own context. | ||
void mlir::arith::populateExpandScalingExtTruncPatterns( | ||
RewritePatternSet &patterns) { | ||
patterns.add<ScalingExtFOpConverter, ScalingTruncFOpConverter>( | ||
patterns.getContext()); | ||
} | ||
|
||
void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) { | ||
populateCeilFloorDivExpandOpsPatterns(patterns); | ||
populateExpandScalingExtTruncPatterns(patterns); | ||
// clang-format off | ||
patterns.add< | ||
MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>, | ||
|
@@ -503,7 +590,7 @@ void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) { | |
MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>, | ||
MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>, | ||
MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>, | ||
MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT> | ||
MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT> | ||
>(patterns.getContext()); | ||
// clang-format on | ||
} |
Uh oh!
There was an error while loading. Please reload this page.