
Introduce arith.scaling_extf and arith.scaling_truncf #141965

Merged: 38 commits, Jun 9, 2025
Commits (38)
1ed7462
Make it elementwise op
umangyadav May 28, 2025
91bb889
Add flushing logic
umangyadav May 28, 2025
8eebbea
Fix build issues
umangyadav May 28, 2025
acc6658
clamping on exponent
umangyadav May 29, 2025
6797446
propagate rounding mode and fast math attrs
umangyadav May 29, 2025
3ad83bd
Add some more notes
umangyadav May 29, 2025
9f755c2
Merge branch 'main' into scaling_cvt
umangyadav May 29, 2025
5e49a72
add scaling_extf tests
umangyadav May 29, 2025
682573e
Fix some issues
umangyadav May 29, 2025
de4497b
add test for scaling_truncf
umangyadav May 29, 2025
e239157
add some more tests
umangyadav May 29, 2025
646465c
Fix Formatting
umangyadav May 29, 2025
20b0928
Merge branch 'main' into scaling_cvt
umangyadav May 29, 2025
80c080f
Remove TODO
umangyadav May 29, 2025
b5df100
Update mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
umangyadav May 29, 2025
b6589ae
Update mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
umangyadav May 29, 2025
12c52a6
Update mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
umangyadav May 29, 2025
b3cadf2
Update mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
umangyadav May 29, 2025
5558b03
Merge remote-tracking branch 'upstream/main' into scaling_cvt
umangyadav May 29, 2025
fc90780
Allow implicit truncf to f8E8M0FN type to extract exponent bits
umangyadav May 29, 2025
8f91e28
Use floating point to normalize scales
umangyadav May 30, 2025
dc7b67f
Rewrite description
umangyadav May 30, 2025
109ddc5
change error message
umangyadav May 30, 2025
f3d9865
some nits
umangyadav May 30, 2025
95a7558
Merge remote-tracking branch 'upstream/main' into scaling_cvt
umangyadav May 30, 2025
3ccb208
Formatting
umangyadav May 30, 2025
d154341
Change comment
umangyadav May 30, 2025
d8a76fa
Update mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
umangyadav May 31, 2025
a0aa490
Update mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
umangyadav May 31, 2025
ff66dad
address some review comments
umangyadav May 31, 2025
10a1bc3
Merge branch 'main' into scaling_cvt
umangyadav May 31, 2025
3c7980d
Merge remote-tracking branch 'upstream/main' into scaling_cvt
umangyadav Jun 2, 2025
f7c1b79
Fix docs
umangyadav Jun 2, 2025
229f6b8
Merge remote-tracking branch 'upstream/main' into scaling_cvt
umangyadav Jun 6, 2025
45e7dba
Simplify arith.scaling_truncf to just do division and truncation. Deno…
umangyadav Jun 6, 2025
80061d6
address review comments and add tests
umangyadav Jun 6, 2025
a38ac5e
Formatting
umangyadav Jun 6, 2025
8151fc7
Merge branch 'main' into scaling_cvt
umangyadav Jun 9, 2025
109 changes: 109 additions & 0 deletions mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -1215,6 +1215,58 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFast
attr-dict `:` type($in) `to` type($out) }];
}

//===----------------------------------------------------------------------===//
// Scaling ExtFOp
//===----------------------------------------------------------------------===//
def Arith_ScalingExtFOp
: Arith_Op<
"scaling_extf", [Pure, SameInputOutputTensorDims,
DeclareOpInterfaceMethods<ArithFastMathInterface>,
DeclareOpInterfaceMethods<CastOpInterface>]>,
Arguments<(ins FloatLike:$in, FloatLike:$scale,
OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
Results<(outs FloatLike:$out)> {
let summary = "Upcasts input floats using provided scales values following "
"OCP MXFP Spec";
let description = [{
This operation upcasts input floating-point values using provided scale
values. It expects both scales and the input operand to be of the same shape,
making the operation elementwise. Scales are usually calculated per block
following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.

If scales are calculated per block where blockSize != 1, then scales may
require broadcasting to make this operation elementwise. For example, let's
say the input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and
assuming quantization happens on the last axis, the input can be reshaped to
`<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated
per block on the last axis. Therefore, scales will be of shape
`<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other
shape as long as it is broadcast compatible with the input, e.g.,
`<1 x 1 x ... (dimN/blockSize) x 1>`.

In this example, before calling into `arith.scaling_extf`, scales must be
broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note
that there could be multiple quantization axes. Internally,
`arith.scaling_extf` would perform the following:

```
resultTy = get_type(result)
scaleTy = get_type(scale)
inputTy = get_type(input)
scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0
scale.extf = arith.extf(scale.exponent) : f8E8M0 to resultTy
input.extf = arith.extf(input) : inputTy to resultTy
result = arith.mulf(scale.extf, input.extf)
```
It propagates NaN values. Therefore, if either scale or the input element
contains NaN, then the output element value will also be a NaN.
}];
let hasVerifier = 1;
let assemblyFormat =
[{ $in `,` $scale (`fastmath` `` $fastmath^)? attr-dict `:`
type($in) `,` type($scale) `to` type($out)}];
}
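
As an illustration of the assembly format above, usage might look roughly as follows. This is a sketch only: the element types, shapes, and SSA names are chosen for illustration and are not taken from this PR's tests.

```
// Scales already broadcast to the input shape and stored as f8E8M0FNU.
%r0 = arith.scaling_extf %in, %scale : vector<32xf4E2M1FN>, vector<32xf8E8M0FNU> to vector<32xf32>

// Wider scale types are also accepted; the expansion pass extracts the
// exponent via a truncf to f8E8M0FNU (see ExpandOps.cpp further down).
%r1 = arith.scaling_extf %in, %scale_f32 fastmath<fast> : vector<32xf4E2M1FN>, vector<32xf32> to vector<32xf32>
```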

//===----------------------------------------------------------------------===//
// TruncIOp
//===----------------------------------------------------------------------===//
@@ -1280,6 +1332,63 @@ def Arith_TruncFOp :
attr-dict `:` type($in) `to` type($out) }];
}

//===----------------------------------------------------------------------===//
// Scaling TruncFOp
//===----------------------------------------------------------------------===//

def Arith_ScalingTruncFOp
: Arith_Op<"scaling_truncf",
[Pure, SameInputOutputTensorDims,
DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
DeclareOpInterfaceMethods<ArithFastMathInterface>,
DeclareOpInterfaceMethods<CastOpInterface>]>,
Arguments<(ins FloatLike:$in, FloatLike:$scale,
OptionalAttr<Arith_RoundingModeAttr>:$roundingmode,
OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
Results<(outs FloatLike:$out)> {
let summary = "Downcasts input floating point values using provided scales "
"values following OCP MXFP Spec";
let description = [{
This operation downcasts input using the provided scale values. It expects
both scales and the input operand to be of the same shape and, therefore,
makes the operation elementwise. Scales are usually calculated per block
following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
Users are required to normalize and clamp the scales as necessary before passing
them to this operation. The OCP MXFP spec also calls for flushing denorms on the
input operand, which should be handled during lowering by passing the appropriate
fastMath flag to this operation.

If scales are calculated per block where blockSize != 1, scales may require
broadcasting to make this operation elementwise. For example, let's say the
input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and
assuming quantization happens on the last axis, the input can be reshaped to
`<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated
per block on the last axis. Therefore, scales will be of shape
`<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other
shape as long as it is broadcast compatible with the input, e.g.,
`<1 x 1 x ... (dimN/blockSize) x 1>`.

In this example, before calling into `arith.scaling_truncf`, scales must be
broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note
that there could be multiple quantization axes. Internally,
`arith.scaling_truncf` would perform the following:

```
scaleTy = get_type(scale)
inputTy = get_type(input)
resultTy = get_type(result)
scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0
scale.extf = arith.extf(scale.exponent) : f8E8M0 to inputTy
result = arith.divf(input, scale.extf)
result.cast = arith.truncf(result, resultTy)
```
}];
let hasVerifier = 1;
let assemblyFormat =
[{ $in `,` $scale ($roundingmode^)? (`fastmath` `` $fastmath^)? attr-dict `:`
type($in) `,` type($scale) `to` type($out)}];
}
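
Analogously, a hedged usage sketch for the downcast (again, the types, shapes, value names, and rounding mode are illustrative only, not taken from this PR's tests):

```
// Downcast f32 values to f4E2M1FN using already-broadcast f8E8M0FNU scales,
// with an explicit rounding mode and fast-math flags.
%t0 = arith.scaling_truncf %x, %scale to_nearest_even fastmath<fast> : vector<32xf32>, vector<32xf8E8M0FNU> to vector<32xf4E2M1FN>
```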

//===----------------------------------------------------------------------===//
// UIToFPOp
//===----------------------------------------------------------------------===//
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -62,6 +62,9 @@ void populateExpandBFloat16Patterns(RewritePatternSet &patterns);
/// Add patterns to expand Arith f8e8m0 patterns to lower level bitcasts/shifts.
void populateExpandF8E8M0Patterns(RewritePatternSet &patterns);

/// Add patterns to expand scaling ExtF/TruncF ops to equivalent arith ops
void populateExpandScalingExtTruncPatterns(RewritePatternSet &patterns);

/// Add patterns to expand Arith ops.
void populateArithExpandOpsPatterns(RewritePatternSet &patterns);

1 change: 1 addition & 0 deletions mlir/include/mlir/IR/Builders.h
@@ -60,6 +60,7 @@ class Builder {
Attribute metadata = Attribute());

// Types.
FloatType getF8E8M0Type();
FloatType getBF16Type();
FloatType getF16Type();
FloatType getTF32Type();
26 changes: 26 additions & 0 deletions mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -1451,6 +1451,19 @@ bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {

LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }

//===----------------------------------------------------------------------===//
// ScalingExtFOp
//===----------------------------------------------------------------------===//

bool arith::ScalingExtFOp::areCastCompatible(TypeRange inputs,
TypeRange outputs) {
return checkWidthChangeCast<std::greater, FloatType>(inputs.front(), outputs);
}

LogicalResult arith::ScalingExtFOp::verify() {
return verifyExtOp<FloatType>(*this);
}

//===----------------------------------------------------------------------===//
// TruncIOp
//===----------------------------------------------------------------------===//
@@ -1565,6 +1578,19 @@ LogicalResult arith::TruncFOp::verify() {
return verifyTruncateOp<FloatType>(*this);
}

//===----------------------------------------------------------------------===//
// ScalingTruncFOp
//===----------------------------------------------------------------------===//

bool arith::ScalingTruncFOp::areCastCompatible(TypeRange inputs,
TypeRange outputs) {
return checkWidthChangeCast<std::less, FloatType>(inputs.front(), outputs);
}

LogicalResult arith::ScalingTruncFOp::verify() {
return verifyTruncateOp<FloatType>(*this);
}

//===----------------------------------------------------------------------===//
// AndIOp
//===----------------------------------------------------------------------===//
105 changes: 96 additions & 9 deletions mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp
@@ -6,10 +6,10 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/Transforms/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -31,7 +31,6 @@ static Value createConst(Location loc, Type type, int value,
return rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(shapedTy, attr));
}

return rewriter.create<arith::ConstantOp>(loc, attr);
}

@@ -357,9 +356,10 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
if (resultETy.getIntOrFloatBitWidth() < 32) {
result = b.create<arith::TruncFOp>(resultTy, result);
result = b.create<arith::TruncFOp>(resultTy, result, nullptr,
op.getFastmathAttr());
} else if (resultETy.getIntOrFloatBitWidth() > 32) {
result = b.create<arith::ExtFOp>(resultTy, result);
result = b.create<arith::ExtFOp>(resultTy, result, op.getFastmathAttr());
}
rewriter.replaceOp(op, result);
return success();
@@ -395,9 +395,10 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());

if (operandETy.getIntOrFloatBitWidth() < 32) {
operand = b.create<arith::ExtFOp>(f32Ty, operand);
operand = b.create<arith::ExtFOp>(f32Ty, operand, op.getFastmathAttr());
} else if (operandETy.getIntOrFloatBitWidth() > 32) {
operand = b.create<arith::TruncFOp>(f32Ty, operand);
operand = b.create<arith::TruncFOp>(
f32Ty, operand, op.getRoundingmodeAttr(), op.getFastmathAttr());
}
Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand);
Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
@@ -409,6 +410,83 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
}
};

struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
PatternRewriter &rewriter) const final {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
Value inputOperand = op.getIn();
Value scaleOperand = op.getScale();
Type scaleTy = scaleOperand.getType();
Type scaleETy = getElementTypeOrSelf(scaleOperand);
// allow implicit exponent extraction from 16/32 bits floats
if (scaleETy.getIntOrFloatBitWidth() >= 16) {
scaleETy = b.getF8E8M0Type();
scaleTy = cloneToShapedType(scaleTy, scaleETy);
scaleOperand = b.create<arith::TruncFOp>(scaleTy, scaleOperand, nullptr,
op.getFastmathAttr());
}
if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
return rewriter.notifyMatchFailure(
op, "scaling_extf is using scales of type which can not be converted "
"to f8E8M0FNU");
}
Type resultTy = op.getType();
// extf on scale will essentially create a floating point number
// of type resultTy that is 2^scale and will also propagate NaNs
Value scaleExt =
b.create<arith::ExtFOp>(resultTy, scaleOperand, op.getFastmathAttr());
Value inputExt =
b.create<arith::ExtFOp>(resultTy, inputOperand, op.getFastmathAttr());
Value result =
b.create<arith::MulFOp>(inputExt, scaleExt, op.getFastmathAttr());
rewriter.replaceOp(op, result);
return success();
}
};
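
For reference, when the scale is already f8E8M0FNU this pattern rewrites a scalar arith.scaling_extf into roughly the following sequence. This is a sketch with illustrative names; the op's fastmath attribute, if present, is propagated to each generated op.

```
// %in : f4E2M1FN, %scale : f8E8M0FNU, result type f32
%scale_ext = arith.extf %scale : f8E8M0FNU to f32
%in_ext    = arith.extf %in : f4E2M1FN to f32
%res       = arith.mulf %in_ext, %scale_ext : f32
```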

/*
Expands arith.ScalingTruncFOp(in, scale) into
scale = arith.truncf(scale) : scaleTy -> f8E8M0FNU
result = arith.truncf(in / (2^scale))
*/
struct ScalingTruncFOpConverter
: public OpRewritePattern<arith::ScalingTruncFOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
PatternRewriter &rewriter) const final {
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
Value inputOperand = op.getIn();
Value scaleOperand = op.getScale();
Type scaleTy = scaleOperand.getType();
Type scaleETy = getElementTypeOrSelf(scaleOperand);
// allow implicit exponent extraction from 16/32 bits floats
if (scaleETy.getIntOrFloatBitWidth() >= 16) {
scaleETy = b.getF8E8M0Type();
scaleTy = cloneToShapedType(scaleTy, scaleETy);
scaleOperand = b.create<arith::TruncFOp>(scaleTy, scaleOperand, nullptr,
op.getFastmathAttr());
}
if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
return rewriter.notifyMatchFailure(
op, "scaling_truncf is using scales type which can not be converted "
"to f8E8M0FNU");
}
Type resultTy = op.getType();
Type inputTy = inputOperand.getType();
// this will create a floating point number of type
// inputTy that is 2^scale and will also propagate NaNs
scaleOperand =
b.create<arith::ExtFOp>(inputTy, scaleOperand, op.getFastmathAttr());
Value result = b.create<arith::DivFOp>(inputOperand, scaleOperand,
op.getFastmathAttr());
Value resultCast = b.create<arith::TruncFOp>(
resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr());
Review thread on the final truncation above:

Contributor: There are other arith ops here; shouldn't we propagate the attributes to those as well? Also for ScalingExtFOpConverter.

Contributor: Should we check resultTy <= f32?

umangyadav (PR author):
> should we check resultTy <= f32?
Verify() checks that the output width is smaller than the input width:
https://github.com/umangyadav/llvm-project/blob/d1543414578abf95a495b4eb6fe9b6201de8e9f6/mlir/lib/Dialect/Arith/IR/ArithOps.cpp#L1587
> there are other arith ops, shouldn't we propagate to those as well?
No, the other arith.truncf ops are mainly for the scale dtype conversion, which operates only on the exponent and is not really affected by the rounding mode or fast-math flags.

Contributor: Yes, verify checks that the output width is smaller than the input width. But I understand the output of this function is always f32. Then I wonder whether somebody can do input, scale -> f128, result -> f64. It is true that output width < input width, and we would still be trying to truncate "result", which is f32, into f64. Not sure if I misunderstood something?

umangyadav (PR author): In practice, Float64/80/128 dtypes are not expected; I think it is safe to assume f32 is the largest dtype that can appear on the input. Also, Verify() is a strict check (output_bit_width < input_bit_width), so this would never really be truncating to an f32 resultTy in practice.
> But I understand the output of this function is always f32
No, why do you think so? The output dtype will be whatever the user has specified.

dhernandez0 (Jun 2, 2025):
> No, why do you think so? Output dtype will be whatever user has specified.
I mean the result of the function before the truncation: result.dtype = f32, right?
> In practice, Float64/80/128 dtypes are something that is not expected.
I think the arith dialect is not supposed to be hardware-specific, so even though those types are not expected for us, I'd prefer to enforce or check the assumption somehow. But it seems OK to me either way, whatever you decide.
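
To make the corner case under discussion concrete, here is a hedged sketch of IR that the verifier as written would accept, since it only requires the output width to be smaller than the input width; the wide types are purely illustrative.

```
// Accepted by the verifier: 64 < 128, even though such wide types are
// not expected in practice.
%y = arith.scaling_truncf %x, %s : f128, f8E8M0FNU to f64
```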

rewriter.replaceOp(op, resultCast);
return success();
}
};
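
Correspondingly, a scalar arith.scaling_truncf with an f8E8M0FNU scale expands to roughly the following. This is a sketch with illustrative names; the rounding mode and fastmath attributes of the op are forwarded to the final truncf.

```
// %in : f32, %scale : f8E8M0FNU, result type f4E2M1FN
%scale_ext = arith.extf %scale : f8E8M0FNU to f32
%div       = arith.divf %in, %scale_ext : f32
%res       = arith.truncf %div : f32 to f4E2M1FN
```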

struct ArithExpandOpsPass
: public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
@@ -432,7 +510,9 @@ struct ArithExpandOpsPass
arith::MaximumFOp,
arith::MinimumFOp,
arith::MaxNumFOp,
arith::MinNumFOp
arith::MinNumFOp,
arith::ScalingExtFOp,
arith::ScalingTruncFOp
>();

if (includeBf16) {
@@ -492,8 +572,15 @@ void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) {
patterns.getContext());
}

void mlir::arith::populateExpandScalingExtTruncPatterns(
RewritePatternSet &patterns) {
patterns.add<ScalingExtFOpConverter, ScalingTruncFOpConverter>(
patterns.getContext());
}

void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
populateCeilFloorDivExpandOpsPatterns(patterns);
populateExpandScalingExtTruncPatterns(patterns);
// clang-format off
patterns.add<
MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
@@ -503,7 +590,7 @@ void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,
MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
>(patterns.getContext());
// clang-format on
}
2 changes: 2 additions & 0 deletions mlir/lib/IR/Builders.cpp
@@ -34,6 +34,8 @@ Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
// Types.
//===----------------------------------------------------------------------===//

FloatType Builder::getF8E8M0Type() { return Float8E8M0FNUType::get(context); }

FloatType Builder::getBF16Type() { return BFloat16Type::get(context); }

FloatType Builder::getF16Type() { return Float16Type::get(context); }