Commit 7f08503

umangyadav, pashu123, and krzysz00 authored
Introduce arith.scaling_extf and arith.scaling_truncf (#141965)
This PR adds the `arith.scaling_truncf` and `arith.scaling_extf` operations, which support block quantization following the OCP MXFP spec: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf

The OCP MXFP spec comes with a reference implementation: https://github.com/microsoft/microxcaling/tree/main

The most relevant piece of the reference code is the `_quantize_mx` method: https://github.com/microsoft/microxcaling/blob/7bc41952de394f5cc5e782baf132e7c7542eb4e4/mx/mx_ops.py#L173

Both `arith.scaling_truncf` and `arith.scaling_extf` are designed to be elementwise operations. See their descriptions in `ArithOps.td` for more details.

Internally, `arith.scaling_truncf` computes `arith.truncf(arith.divf(input, 2^scale))`. Any necessary broadcasting, clamping, normalization, and NaN propagation must be applied to `scale` before calling `arith.scaling_truncf`. `arith.scaling_extf` computes `arith.mulf(2^scale, input)` after taking care of the necessary data type conversions.

CC: @krzysz00 @dhernandez0 @bjacob @pashu123 @MaheshRavishankar @tgymnich

---------

Co-authored-by: Prashant Kumar <[email protected]>
Co-authored-by: Krzysztof Drewniak <[email protected]>
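The scale arithmetic described in the commit message can be sketched in scalar C++. This is a minimal illustration, not the MLIR lowering: the function names are hypothetical, the scale is passed as a plain integer exponent, and the final narrowing cast of `arith.scaling_truncf` is omitted so that only the division and multiplication by `2^scale` are shown.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical scalar model of arith.scaling_truncf without the final
// narrowing cast: divides the input by 2^scaleExp.
inline float scalingTruncfModel(float input, int scaleExp) {
  // std::ldexp(x, -e) computes x * 2^(-e), i.e. x / 2^e.
  return std::ldexp(input, -scaleExp);
}

// Hypothetical scalar model of arith.scaling_extf: multiplies by 2^scaleExp.
inline float scalingExtfModel(float input, int scaleExp) {
  return std::ldexp(input, scaleExp);
}

// Usage example: scaling down and back up by the same exponent is exact
// here because both values are representable as floats.
inline void scalingRoundTripExample() {
  float q = scalingTruncfModel(48.0f, 4); // 48 / 2^4
  assert(q == 3.0f);
  assert(scalingExtfModel(q, 4) == 48.0f);
}
```

In the real operations the scale is itself a floating-point operand whose exponent is extracted through `f8E8M0`, so the integer exponent above stands in for that step.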
1 parent 5d6218d commit 7f08503

7 files changed, +411 -13 lines changed

mlir/include/mlir/Dialect/Arith/IR/ArithOps.td

Lines changed: 109 additions & 0 deletions
@@ -1215,6 +1215,58 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFast
                           attr-dict `:` type($in) `to` type($out) }];
 }

+//===----------------------------------------------------------------------===//
+// Scaling ExtFOp
+//===----------------------------------------------------------------------===//
+def Arith_ScalingExtFOp
+    : Arith_Op<
+          "scaling_extf", [Pure, SameInputOutputTensorDims,
+                           DeclareOpInterfaceMethods<ArithFastMathInterface>,
+                           DeclareOpInterfaceMethods<CastOpInterface>]>,
+      Arguments<(ins FloatLike:$in, FloatLike:$scale,
+          OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
+      Results<(outs FloatLike:$out)> {
+  let summary = "Upcasts input floats using provided scales values following "
+                "OCP MXFP Spec";
+  let description = [{
+    This operation upcasts input floating-point values using provided scale
+    values. It expects both scales and the input operand to be of the same shape,
+    making the operation elementwise. Scales are usually calculated per block
+    following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
+
+    If scales are calculated per block where blockSize != 1, then scales may
+    require broadcasting to make this operation elementwise. For example, let's
+    say the input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and
+    assuming quantization happens on the last axis, the input can be reshaped to
+    `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated
+    per block on the last axis. Therefore, scales will be of shape
+    `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other
+    shape as long as it is broadcast compatible with the input, e.g.,
+    `<1 x 1 x ... (dimN/blockSize) x 1>`.
+
+    In this example, before calling into `arith.scaling_extf`, scales must be
+    broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note
+    that there could be multiple quantization axes. Internally,
+    `arith.scaling_extf` would perform the following:
+
+    ```
+    resultTy = get_type(result)
+    scaleTy  = get_type(scale)
+    inputTy = get_type(input)
+    scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0
+    scale.extf = arith.extf(scale.exponent) : f8E8M0 to resultTy
+    input.extf = arith.extf(input) : inputTy to resultTy
+    result = arith.mulf(scale.extf, input.extf)
+    ```
+    It propagates NaN values. Therefore, if either scale or the input element
+    contains NaN, then the output element value will also be a NaN.
+  }];
+  let hasVerifier = 1;
+  let assemblyFormat =
+      [{ $in `,` $scale (`fastmath` `` $fastmath^)? attr-dict `:`
+      type($in) `,` type($scale) `to` type($out)}];
+}
+
 //===----------------------------------------------------------------------===//
 // TruncIOp
 //===----------------------------------------------------------------------===//
@@ -1280,6 +1332,63 @@ def Arith_TruncFOp :
                           attr-dict `:` type($in) `to` type($out) }];
 }

+//===----------------------------------------------------------------------===//
+// Scaling TruncFOp
+//===----------------------------------------------------------------------===//
+
+def Arith_ScalingTruncFOp
+    : Arith_Op<"scaling_truncf",
+               [Pure, SameInputOutputTensorDims,
+                DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
+                DeclareOpInterfaceMethods<ArithFastMathInterface>,
+                DeclareOpInterfaceMethods<CastOpInterface>]>,
+      Arguments<(ins FloatLike:$in, FloatLike:$scale,
+          OptionalAttr<Arith_RoundingModeAttr>:$roundingmode,
+          OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
+      Results<(outs FloatLike:$out)> {
+  let summary = "Downcasts input floating point values using provided scales "
+                "values following OCP MXFP Spec";
+  let description = [{
+    This operation downcasts input using the provided scale values. It expects
+    both scales and the input operand to be of the same shape and, therefore,
+    makes the operation elementwise. Scales are usually calculated per block
+    following the OCP MXFP spec as described in https://arxiv.org/abs/2310.10537.
+    Users are required to normalize and clamp the scales as necessary before calling
+    passing them to this operation. OCP MXFP spec also does the flushing of denorms
+    on the input operand, which should be handled during lowering by passing appropriate
+    fastMath flag to this operation.
+
+    If scales are calculated per block where blockSize != 1, scales may require
+    broadcasting to make this operation elementwise. For example, let's say the
+    input is of shape `<dim1 x dim2 x ... dimN>`. Given blockSize != 1 and
+    assuming quantization happens on the last axis, the input can be reshaped to
+    `<dim1 x dim2 x ... (dimN/blockSize) x blockSize>`. Scales will be calculated
+    per block on the last axis. Therefore, scales will be of shape
+    `<dim1 x dim2 x ... (dimN/blockSize) x 1>`. Scales could also be of some other
+    shape as long as it is broadcast compatible with the input, e.g.,
+    `<1 x 1 x ... (dimN/blockSize) x 1>`.
+
+    In this example, before calling into `arith.scaling_truncf`, scales must be
+    broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note
+    that there could be multiple quantization axes. Internally,
+    `arith.scaling_truncf` would perform the following:
+
+    ```
+    scaleTy = get_type(scale)
+    inputTy = get_type(input)
+    resultTy = get_type(result)
+    scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0
+    scale.extf = arith.extf(scale.exponent) : f8E8M0 to inputTy
+    result = arith.divf(input, scale.extf)
+    result.cast = arith.truncf(result, resultTy)
+    ```
+  }];
+  let hasVerifier = 1;
+  let assemblyFormat =
+      [{ $in `,` $scale ($roundingmode^)? (`fastmath` `` $fastmath^)? attr-dict `:`
+      type($in) `,` type($scale) `to` type($out)}];
+}
+
 //===----------------------------------------------------------------------===//
 // UIToFPOp
 //===----------------------------------------------------------------------===//
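The op descriptions above require scales to be broadcast to the input shape before the op is called. A minimal sketch of that step follows, assuming a flattened row-major buffer with quantization on the last axis, so each block is a contiguous run of `blockSize` elements. The helper name `broadcastScales` is hypothetical and is not part of this commit.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Expands one scale per block into one scale per element, so the result has
// the same length as the flattened input. `blockScales` holds one entry per
// block, in the same order as the blocks appear in the input.
inline std::vector<float> broadcastScales(const std::vector<float> &blockScales,
                                          std::size_t blockSize) {
  assert(blockSize > 0 && "blockSize must be positive");
  std::vector<float> scales;
  scales.reserve(blockScales.size() * blockSize);
  for (float s : blockScales)
    scales.insert(scales.end(), blockSize, s); // repeat s blockSize times
  return scales;
}
```

For an input of 6 elements with `blockSize = 3`, two block scales expand to six per-element scales, which matches the same-shape requirement of the ops.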

mlir/include/mlir/Dialect/Arith/Transforms/Passes.h

Lines changed: 3 additions & 0 deletions
@@ -62,6 +62,9 @@ void populateExpandBFloat16Patterns(RewritePatternSet &patterns);
 /// Add patterns to expand Arith f8e8m0 patterns to lower level bitcasts/shifts.
 void populateExpandF8E8M0Patterns(RewritePatternSet &patterns);

+/// Add patterns to expand scaling ExtF/TruncF ops to equivalent arith ops
+void populateExpandScalingExtTruncPatterns(RewritePatternSet &patterns);
+
 /// Add patterns to expand Arith ops.
 void populateArithExpandOpsPatterns(RewritePatternSet &patterns);

mlir/include/mlir/IR/Builders.h

Lines changed: 1 addition & 0 deletions
@@ -60,6 +60,7 @@ class Builder {
                        Attribute metadata = Attribute());

   // Types.
+  FloatType getF8E8M0Type();
   FloatType getBF16Type();
   FloatType getF16Type();
   FloatType getTF32Type();

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 26 additions & 0 deletions
@@ -1451,6 +1451,19 @@ bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {

 LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }

+//===----------------------------------------------------------------------===//
+// ScalingExtFOp
+//===----------------------------------------------------------------------===//
+
+bool arith::ScalingExtFOp::areCastCompatible(TypeRange inputs,
+                                             TypeRange outputs) {
+  return checkWidthChangeCast<std::greater, FloatType>(inputs.front(), outputs);
+}
+
+LogicalResult arith::ScalingExtFOp::verify() {
+  return verifyExtOp<FloatType>(*this);
+}
+
 //===----------------------------------------------------------------------===//
 // TruncIOp
 //===----------------------------------------------------------------------===//
@@ -1565,6 +1578,19 @@ LogicalResult arith::TruncFOp::verify() {
   return verifyTruncateOp<FloatType>(*this);
 }

+//===----------------------------------------------------------------------===//
+// ScalingTruncFOp
+//===----------------------------------------------------------------------===//
+
+bool arith::ScalingTruncFOp::areCastCompatible(TypeRange inputs,
+                                               TypeRange outputs) {
+  return checkWidthChangeCast<std::less, FloatType>(inputs.front(), outputs);
+}
+
+LogicalResult arith::ScalingTruncFOp::verify() {
+  return verifyTruncateOp<FloatType>(*this);
+}
+
 //===----------------------------------------------------------------------===//
 // AndIOp
 //===----------------------------------------------------------------------===//
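The two `areCastCompatible` implementations above differ only in the comparator passed to `checkWidthChangeCast`. The idea can be sketched outside MLIR as below. This is a simplified stand-in under stated assumptions: it takes bit widths directly instead of MLIR types, it assumes the comparator is applied as (destination width, source width), and all names are hypothetical.

```cpp
#include <cassert>
#include <functional>

// Simplified stand-in for a width-change check: compares the destination
// bit width against the source bit width with the given comparator.
template <template <typename> class WidthComparator>
bool widthChangeIsValid(unsigned srcBits, unsigned dstBits) {
  return WidthComparator<unsigned>()(dstBits, srcBits);
}

// scaling_extf must widen: the output is strictly wider than the input.
inline bool scalingExtfWidthsOk(unsigned inBits, unsigned outBits) {
  return widthChangeIsValid<std::greater>(inBits, outBits);
}

// scaling_truncf must narrow: the output is strictly narrower than the input.
inline bool scalingTruncfWidthsOk(unsigned inBits, unsigned outBits) {
  return widthChangeIsValid<std::less>(inBits, outBits);
}
```

Note that only `inputs.front()` is passed to the check in the diff, so the scale operand's width does not take part in this comparison.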

mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp

Lines changed: 96 additions & 9 deletions
@@ -6,10 +6,10 @@
 //
 //===----------------------------------------------------------------------===//

-#include "mlir/Dialect/Arith/Transforms/Passes.h"
-
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -31,7 +31,6 @@ static Value createConst(Location loc, Type type, int value,
     return rewriter.create<arith::ConstantOp>(
         loc, DenseElementsAttr::get(shapedTy, attr));
   }
-
   return rewriter.create<arith::ConstantOp>(loc, attr);
 }

@@ -357,9 +356,10 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
     f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
     Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
     if (resultETy.getIntOrFloatBitWidth() < 32) {
-      result = b.create<arith::TruncFOp>(resultTy, result);
+      result = b.create<arith::TruncFOp>(resultTy, result, nullptr,
+                                         op.getFastmathAttr());
     } else if (resultETy.getIntOrFloatBitWidth() > 32) {
-      result = b.create<arith::ExtFOp>(resultTy, result);
+      result = b.create<arith::ExtFOp>(resultTy, result, op.getFastmathAttr());
     }
     rewriter.replaceOp(op, result);
     return success();
@@ -395,9 +395,10 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
     Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());

     if (operandETy.getIntOrFloatBitWidth() < 32) {
-      operand = b.create<arith::ExtFOp>(f32Ty, operand);
+      operand = b.create<arith::ExtFOp>(f32Ty, operand, op.getFastmathAttr());
     } else if (operandETy.getIntOrFloatBitWidth() > 32) {
-      operand = b.create<arith::TruncFOp>(f32Ty, operand);
+      operand = b.create<arith::TruncFOp>(
+          f32Ty, operand, op.getRoundingmodeAttr(), op.getFastmathAttr());
     }
     Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand);
     Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
@@ -409,6 +410,83 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
   }
 };

+struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    Value inputOperand = op.getIn();
+    Value scaleOperand = op.getScale();
+    Type scaleTy = scaleOperand.getType();
+    Type scaleETy = getElementTypeOrSelf(scaleOperand);
+    // allow implicit exponent extraction from 16/32 bits floats
+    if (scaleETy.getIntOrFloatBitWidth() >= 16) {
+      scaleETy = b.getF8E8M0Type();
+      scaleTy = cloneToShapedType(scaleTy, scaleETy);
+      scaleOperand = b.create<arith::TruncFOp>(scaleTy, scaleOperand, nullptr,
+                                               op.getFastmathAttr());
+    }
+    if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
+      return rewriter.notifyMatchFailure(
+          op, "scaling_extf is using scales of type which can not be converted "
+              "to f8E8M0FNU");
+    }
+    Type resultTy = op.getType();
+    // extf on scale will essentially create floating point number
+    // of type resulTy that is 2^scale and will also propagate NaNs
+    Value scaleExt =
+        b.create<arith::ExtFOp>(resultTy, scaleOperand, op.getFastmathAttr());
+    Value inputExt =
+        b.create<arith::ExtFOp>(resultTy, inputOperand, op.getFastmathAttr());
+    Value result =
+        b.create<arith::MulFOp>(inputExt, scaleExt, op.getFastmathAttr());
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+/*
+Expands arith.ScalingTruncFOp(in, scale) into
+  scale = arith.truncf(scale) : scaleTy -> f8E8M0FNU
+  result = arith.truncf(in / (2^scale))
+ */
+struct ScalingTruncFOpConverter
+    : public OpRewritePattern<arith::ScalingTruncFOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+    Value inputOperand = op.getIn();
+    Value scaleOperand = op.getScale();
+    Type scaleTy = scaleOperand.getType();
+    Type scaleETy = getElementTypeOrSelf(scaleOperand);
+    // allow implicit exponent extraction from 16/32 bits floats
+    if (scaleETy.getIntOrFloatBitWidth() >= 16) {
+      scaleETy = b.getF8E8M0Type();
+      scaleTy = cloneToShapedType(scaleTy, scaleETy);
+      scaleOperand = b.create<arith::TruncFOp>(scaleTy, scaleOperand, nullptr,
+                                               op.getFastmathAttr());
+    }
+    if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) {
+      return rewriter.notifyMatchFailure(
+          op, "scaling_truncf is using scales type which can not be converted "
+              "to f8E8M0FNU");
+    }
+    Type resultTy = op.getType();
+    Type inputTy = inputOperand.getType();
+    // this will create a floating point number of type
+    // inputTy that is 2^scale and will also propagate NaNs
+    scaleOperand =
+        b.create<arith::ExtFOp>(inputTy, scaleOperand, op.getFastmathAttr());
+    Value result = b.create<arith::DivFOp>(inputOperand, scaleOperand,
+                                           op.getFastmathAttr());
+    Value resultCast = b.create<arith::TruncFOp>(
+        resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr());
+    rewriter.replaceOp(op, resultCast);
+    return success();
+  }
+};
+
 struct ArithExpandOpsPass
     : public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
   using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
@@ -432,7 +510,9 @@ struct ArithExpandOpsPass
       arith::MaximumFOp,
       arith::MinimumFOp,
       arith::MaxNumFOp,
-      arith::MinNumFOp
+      arith::MinNumFOp,
+      arith::ScalingExtFOp,
+      arith::ScalingTruncFOp
     >();

     if (includeBf16) {
@@ -492,8 +572,15 @@ void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) {
       patterns.getContext());
 }

+void mlir::arith::populateExpandScalingExtTruncPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<ScalingExtFOpConverter, ScalingTruncFOpConverter>(
+      patterns.getContext());
+}
+
 void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
   populateCeilFloorDivExpandOpsPatterns(patterns);
+  populateExpandScalingExtTruncPatterns(patterns);
   // clang-format off
   patterns.add<
     MaxMinIOpConverter<MaxSIOp, arith::CmpIPredicate::sgt>,
@@ -503,7 +590,7 @@ void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
     MaximumMinimumFOpConverter<MaximumFOp, arith::CmpFPredicate::UGT>,
     MaximumMinimumFOpConverter<MinimumFOp, arith::CmpFPredicate::ULT>,
     MaxNumMinNumFOpConverter<MaxNumFOp, arith::CmpFPredicate::UGT>,
-    MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
+    MaxNumMinNumFOpConverter<MinNumFOp, arith::CmpFPredicate::ULT>
    >(patterns.getContext());
   // clang-format on
 }
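The converters above route the scale through `f8E8M0` so that only its exponent survives. A scalar sketch of that idea follows. It is an illustration under assumptions, not the MLIR lowering: the diff shows only fragments of the f8E8M0 converters (the 23-bit mantissa width constant and the NaN select), so the shift-based encoding here is inferred from those fragments, the final narrowing cast is omitted, and all function names are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <limits>

// Keeps only the biased exponent byte of an f32: a right shift by the
// 23-bit mantissa width followed by truncation to 8 bits.
inline std::uint8_t f32ToE8M0(float value) {
  std::uint32_t bits;
  std::memcpy(&bits, &value, sizeof(bits));
  return static_cast<std::uint8_t>(bits >> 23); // the sign bit is dropped
}

// Rebuilds an f32 from the exponent byte. For e in 1..254 the result is
// 2^(e - 127); 0xFF is treated as NaN; e == 0 maps to +0.0f in this sketch.
inline float e8m0ToF32(std::uint8_t e) {
  if (e == 0xFF)
    return std::numeric_limits<float>::quiet_NaN();
  std::uint32_t bits = static_cast<std::uint32_t>(e) << 23;
  float result;
  std::memcpy(&result, &bits, sizeof(result));
  return result;
}

// Scalar model of the expanded scaling_truncf, minus the narrowing cast:
// the scale is reduced to a power of two before the division.
inline float scaledDivide(float input, float scale) {
  return input / e8m0ToF32(f32ToE8M0(scale));
}
```

Because the mantissa is discarded, a scale of `10.0f` behaves like `8.0f` in this model, which is why the op descriptions ask callers to normalize scales beforehand.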

mlir/lib/IR/Builders.cpp

Lines changed: 2 additions & 0 deletions
@@ -34,6 +34,8 @@ Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
 // Types.
 //===----------------------------------------------------------------------===//

+FloatType Builder::getF8E8M0Type() { return Float8E8M0FNUType::get(context); }
+
 FloatType Builder::getBF16Type() { return BFloat16Type::get(context); }

 FloatType Builder::getF16Type() { return Float16Type::get(context); }
