Skip to content

Commit 66052c8

Browse files
umangyadavJaddyen
authored andcommitted
Add arith expansion of f8E8M0 type for extf/trunc ops (llvm#140332)
F8E8M0 floating type is supposed to represent biased exponent bits of F32 type in OCP Micro scaling floating point formats. https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf This PR expands `arith.truncf` and `arith.extf` to support this behavior. For the `arith.truncf` thing to note here is that F8E8M0FNU type has one NaN representation which is encoded as `0xFF`. Therefore alll kinds of NaNs and +/-Inf in Float32Type would map to NaN in F8E8M0FNU. F8E8M0FNU doesn't have a sign bit therefore it is a lossy and irreversible downcast. cc: @krzysz00 @MaheshRavishankar @Muzammiluddin-Syed-ECE
1 parent 7c0c418 commit 66052c8

File tree

4 files changed

+270
-34
lines changed

4 files changed

+270
-34
lines changed

mlir/include/mlir/Dialect/Arith/Transforms/Passes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ void populateCeilFloorDivExpandOpsPatterns(RewritePatternSet &patterns);
5959
/// Add patterns to expand Arith bf16 patterns to lower level bitcasts/shifts.
6060
void populateExpandBFloat16Patterns(RewritePatternSet &patterns);
6161

62+
/// Add patterns to expand Arith f8e8m0 patterns to lower level bitcasts/shifts.
63+
void populateExpandF8E8M0Patterns(RewritePatternSet &patterns);
64+
6265
/// Add patterns to expand Arith ops.
6366
void populateArithExpandOpsPatterns(RewritePatternSet &patterns);
6467

mlir/include/mlir/Dialect/Arith/Transforms/Passes.td

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@ include "mlir/Pass/PassBase.td"
1414
def ArithExpandOpsPass : Pass<"arith-expand"> {
1515
let summary = "Legalize Arith ops to be convertible to LLVM.";
1616
let dependentDialects = ["vector::VectorDialect"];
17-
let options = [
18-
Option<"includeBf16", "include-bf16", "bool", /*default=*/"false",
19-
"Enable the BF16 expansion patterns">,
17+
let options =
18+
[Option<"includeBf16", "include-bf16", "bool", /*default=*/"false",
19+
"Enable the BF16 expansion patterns">,
20+
Option<"includeF8E8M0", "include-f8e8m0", "bool", /*default=*/"false",
21+
"Enable the F8E8M0 expansion patterns">,
2022
];
2123
}
2224

mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp

Lines changed: 133 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,14 @@ static Value createConst(Location loc, Type type, int value,
3535
return rewriter.create<arith::ConstantOp>(loc, attr);
3636
}
3737

38+
/// Creates shapedType using shape from cloneFrom and base type from cloneTo
39+
static Type cloneToShapedType(Type cloneFrom, Type cloneTo) {
40+
if (auto shapedTy = dyn_cast<ShapedType>(cloneFrom)) {
41+
return shapedTy.clone(cloneTo);
42+
}
43+
return cloneTo;
44+
}
45+
3846
namespace {
3947

4048
/// Expands CeilDivUIOp (n, m) into
@@ -225,12 +233,8 @@ struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
225233
return rewriter.notifyMatchFailure(op, "not a ext of bf16 to f32.");
226234
}
227235

228-
Type i16Ty = b.getI16Type();
229-
Type i32Ty = b.getI32Type();
230-
if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
231-
i16Ty = shapedTy.clone(i16Ty);
232-
i32Ty = shapedTy.clone(i32Ty);
233-
}
236+
Type i16Ty = cloneToShapedType(operandTy, b.getI16Type());
237+
Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
234238

235239
Value bitcast = b.create<arith::BitcastOp>(i16Ty, operand);
236240
Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
@@ -264,14 +268,8 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
264268
op, "only applicable to default rounding mode.");
265269
}
266270

267-
Type i16Ty = b.getI16Type();
268-
Type i32Ty = b.getI32Type();
269-
Type f32Ty = b.getF32Type();
270-
if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
271-
i16Ty = shapedTy.clone(i16Ty);
272-
i32Ty = shapedTy.clone(i32Ty);
273-
f32Ty = shapedTy.clone(f32Ty);
274-
}
271+
Type i16Ty = cloneToShapedType(operandTy, b.getI16Type());
272+
Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
275273

276274
// Algorithm borrowed from this excellent code:
277275
// https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/c10/util/BFloat16.h#L60-L79
@@ -291,7 +289,7 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
291289
// Constant used to make the rounding bias.
292290
Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter);
293291
// Constant used to generate a quiet NaN.
294-
Value c7FC0_i16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
292+
Value c7FC0I16 = createConst(op.getLoc(), i16Ty, 0x7fc0, rewriter);
295293
// Small constants used to address bits.
296294
Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter);
297295
Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter);
@@ -313,18 +311,104 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
313311
// Now that the rounding-bias has been added, truncating the low bits
314312
// yields the correctly rounded result.
315313
Value biasedAndShifted = b.create<arith::ShRUIOp>(biased, c16);
316-
Value normalCaseResult_i16 =
314+
Value normalCaseResultI16 =
317315
b.create<arith::TruncIOp>(i16Ty, biasedAndShifted);
318316
// Select either the above-computed result, or a quiet NaN constant
319317
// if the input was NaN.
320318
Value select =
321-
b.create<arith::SelectOp>(isNan, c7FC0_i16, normalCaseResult_i16);
319+
b.create<arith::SelectOp>(isNan, c7FC0I16, normalCaseResultI16);
322320
Value result = b.create<arith::BitcastOp>(resultTy, select);
323321
rewriter.replaceOp(op, result);
324322
return success();
325323
}
326324
};
327325

326+
struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
327+
using OpRewritePattern::OpRewritePattern;
328+
LogicalResult matchAndRewrite(arith::ExtFOp op,
329+
PatternRewriter &rewriter) const final {
330+
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
331+
Value operand = op.getOperand();
332+
Type operandTy = operand.getType();
333+
Type resultTy = op.getType();
334+
Type operandETy = getElementTypeOrSelf(operandTy);
335+
Type resultETy = getElementTypeOrSelf(resultTy);
336+
337+
if (!llvm::isa<Float8E8M0FNUType>(operandETy)) {
338+
return rewriter.notifyMatchFailure(op, "not a ext of F8E8M0FNU");
339+
}
340+
341+
Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
342+
Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
343+
Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
344+
345+
Value bitcast = b.create<arith::BitcastOp>(i8Ty, operand);
346+
// create constants for NaNs
347+
Value cF8NaN = createConst(op.getLoc(), i8Ty, 0xff, rewriter);
348+
Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter);
349+
Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
350+
351+
Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
352+
Value f32Bits = b.create<arith::ShLIOp>(exti, cF32MantissaWidth);
353+
354+
Value isNan =
355+
b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN);
356+
// select for NaNs
357+
f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits);
358+
Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits);
359+
if (resultETy.getIntOrFloatBitWidth() < 32) {
360+
result = b.create<arith::TruncFOp>(resultTy, result);
361+
} else if (resultETy.getIntOrFloatBitWidth() > 32) {
362+
result = b.create<arith::ExtFOp>(resultTy, result);
363+
}
364+
rewriter.replaceOp(op, result);
365+
return success();
366+
}
367+
};
368+
369+
/*
370+
TruncF to F8E8M0 is expected to extract exponent bits out of F32 type
371+
Since All kinds of Infs and NaNs are mapped to same exponent bits in F32 type,
372+
they all map to NaN in F8E8M0 Type.
373+
*/
374+
struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
375+
using OpRewritePattern::OpRewritePattern;
376+
LogicalResult matchAndRewrite(arith::TruncFOp op,
377+
PatternRewriter &rewriter) const final {
378+
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
379+
Value operand = op.getOperand();
380+
Type operandTy = operand.getType();
381+
Type operandETy = getElementTypeOrSelf(operandTy);
382+
Type resultTy = op.getType();
383+
Type resultETy = getElementTypeOrSelf(resultTy);
384+
if (!llvm::isa<Float8E8M0FNUType>(resultETy)) {
385+
return rewriter.notifyMatchFailure(op, "not a truncf to f8E8M0FNU");
386+
}
387+
388+
if (op.getRoundingmodeAttr()) {
389+
return rewriter.notifyMatchFailure(
390+
op, "only applicable to default rounding mode.");
391+
}
392+
393+
Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
394+
Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
395+
Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
396+
397+
if (operandETy.getIntOrFloatBitWidth() < 32) {
398+
operand = b.create<arith::ExtFOp>(f32Ty, operand);
399+
} else if (operandETy.getIntOrFloatBitWidth() > 32) {
400+
operand = b.create<arith::TruncFOp>(f32Ty, operand);
401+
}
402+
Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand);
403+
Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter);
404+
Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth);
405+
Value exp8Bits = b.create<arith::TruncIOp>(i8Ty, f32SignExp);
406+
Value result = b.create<arith::BitcastOp>(resultTy, exp8Bits);
407+
rewriter.replaceOp(op, result);
408+
return success();
409+
}
410+
};
411+
328412
struct ArithExpandOpsPass
329413
: public arith::impl::ArithExpandOpsPassBase<ArithExpandOpsPass> {
330414
using ArithExpandOpsPassBase::ArithExpandOpsPassBase;
@@ -353,20 +437,34 @@ struct ArithExpandOpsPass
353437

354438
if (includeBf16) {
355439
arith::populateExpandBFloat16Patterns(patterns);
356-
target.addDynamicallyLegalOp<arith::ExtFOp>(
357-
[](arith::ExtFOp op) {
358-
Type inETy = getElementTypeOrSelf(op.getOperand().getType());
359-
Type outETy = getElementTypeOrSelf(op.getType());
360-
return !(inETy.isBF16() && outETy.isF32());
361-
});
362-
363-
target.addDynamicallyLegalOp<arith::TruncFOp>(
364-
[](arith::TruncFOp op) {
365-
Type inETy = getElementTypeOrSelf(op.getOperand().getType());
366-
Type outETy = getElementTypeOrSelf(op.getType());
367-
return !(inETy.isF32() && outETy.isBF16());
368-
});
369440
}
441+
if (includeF8E8M0) {
442+
arith::populateExpandF8E8M0Patterns(patterns);
443+
}
444+
445+
target.addDynamicallyLegalOp<arith::ExtFOp>(
446+
[=](arith::ExtFOp op) {
447+
Type inETy = getElementTypeOrSelf(op.getOperand().getType());
448+
Type outETy = getElementTypeOrSelf(op.getType());
449+
bool legalTypes = true;
450+
if (includeBf16)
451+
legalTypes &= !(inETy.isBF16() && outETy.isF32());
452+
if (includeF8E8M0)
453+
legalTypes &= !llvm::isa<Float8E8M0FNUType>(inETy);
454+
return legalTypes;
455+
});
456+
457+
target.addDynamicallyLegalOp<arith::TruncFOp>(
458+
[=](arith::TruncFOp op) {
459+
Type inETy = getElementTypeOrSelf(op.getOperand().getType());
460+
Type outETy = getElementTypeOrSelf(op.getType());
461+
bool legalTypes = true;
462+
if (includeBf16)
463+
legalTypes &= !(inETy.isF32() && outETy.isBF16());
464+
if (includeF8E8M0)
465+
legalTypes &= !(llvm::isa<Float8E8M0FNUType>(outETy));
466+
return legalTypes;
467+
});
370468

371469
// clang-format on
372470
if (failed(applyPartialConversion(getOperation(), target,
@@ -389,6 +487,11 @@ void mlir::arith::populateExpandBFloat16Patterns(RewritePatternSet &patterns) {
389487
patterns.getContext());
390488
}
391489

490+
void mlir::arith::populateExpandF8E8M0Patterns(RewritePatternSet &patterns) {
491+
patterns.add<F8E8M0ExtFOpConverter, F8E8M0TruncFOpConverter>(
492+
patterns.getContext());
493+
}
494+
392495
void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) {
393496
populateCeilFloorDivExpandOpsPatterns(patterns);
394497
// clang-format off

mlir/test/Dialect/Arith/expand-ops.mlir

Lines changed: 129 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s -arith-expand="include-bf16=true" -split-input-file | FileCheck %s
1+
// RUN: mlir-opt %s -arith-expand="include-bf16=true include-f8e8m0=true" -split-input-file | FileCheck %s
22

33
// Test ceil divide with signed integer
44
// CHECK-LABEL: func @ceildivi
@@ -248,6 +248,134 @@ func.func @truncf_vector_f32(%arg0 : vector<4xf32>) -> vector<4xbf16> {
248248
// CHECK-LABEL: @truncf_vector_f32
249249
// CHECK-NOT: arith.truncf
250250

251+
// -----
252+
func.func @truncf_f32_to_f8E8M0FNU(%arg0 : f32) -> f8E8M0FNU {
253+
%0 = arith.truncf %arg0 : f32 to f8E8M0FNU
254+
return %0 : f8E8M0FNU
255+
}
256+
// CHECK-LABLE: @truncf_f32_to_f8E8M0FNU
257+
// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f32 to i32
258+
// CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32
259+
// CHECK: %[[SHRUI:.+]] = arith.shrui %[[BITCAST]], %[[C23_i32]] : i32
260+
// CHECK: %[[TRUNCI:.+]] = arith.trunci %[[SHRUI]] : i32 to i8
261+
// CHECK: %[[RESULT:.+]] = arith.bitcast %[[TRUNCI]] : i8 to f8E8M0FNU
262+
// CHECK: return %[[RESULT]]
263+
264+
// -----
265+
266+
func.func @truncf_f16_to_f8E8M0FNU(%arg0 : f16) -> f8E8M0FNU {
267+
%0 = arith.truncf %arg0 : f16 to f8E8M0FNU
268+
return %0 : f8E8M0FNU
269+
}
270+
// CHECK-LABLE: @truncf_f16_to_f8E8M0FNU
271+
// CHECK: %[[EXTF:.+]] = arith.extf %arg0 : f16 to f32
272+
// CHECK: %[[BITCAST:.+]] = arith.bitcast %[[EXTF]] : f32 to i32
273+
// CHECK: %[[C23_i32:.+]] = arith.constant 23 : i32
274+
// CHECK: %[[SHRUI:.+]] = arith.shrui %[[BITCAST]], %[[C23_i32]] : i32
275+
// CHECK: %[[TRUNCI:.+]] = arith.trunci %[[SHRUI]] : i32 to i8
276+
// CHECK: %[[RESULT:.+]] = arith.bitcast %[[TRUNCI]] : i8 to f8E8M0FNU
277+
// CHECK: return %[[RESULT]]
278+
279+
// -----
280+
281+
func.func @truncf_vector_f32_to_f8E8M0FNU(%arg0 : vector<4xf32>) -> vector<4xf8E8M0FNU> {
282+
%0 = arith.truncf %arg0 : vector<4xf32> to vector<4xf8E8M0FNU>
283+
return %0 : vector<4xf8E8M0FNU>
284+
}
285+
286+
// CHECK-LABEL: @truncf_vector_f32_to_f8E8M0FNU
287+
// CHECK-NOT: arith.truncf
288+
289+
// -----
290+
291+
func.func @truncf_vector_f16_to_f8E8M0FNU(%arg0 : vector<4xf16>) -> vector<4xf8E8M0FNU> {
292+
%0 = arith.truncf %arg0 : vector<4xf16> to vector<4xf8E8M0FNU>
293+
return %0 : vector<4xf8E8M0FNU>
294+
}
295+
296+
// CHECK-LABEL: @truncf_vector_f16_to_f8E8M0FNU
297+
// CHECK-NOT: arith.truncf
298+
299+
// -----
300+
301+
func.func @truncf_vector_bf16_to_f8E8M0FNU(%arg0 : vector<4xbf16>) -> vector<4xf8E8M0FNU> {
302+
%0 = arith.truncf %arg0 : vector<4xbf16> to vector<4xf8E8M0FNU>
303+
return %0 : vector<4xf8E8M0FNU>
304+
}
305+
306+
// CHECK-LABEL: @truncf_vector_bf16_to_f8E8M0FNU
307+
// CHECK-NOT: arith.truncf
308+
309+
310+
// -----
311+
func.func @extf_f8E8M0FNU_to_f32(%arg0 : f8E8M0FNU) -> f32 {
312+
%0 = arith.extf %arg0 : f8E8M0FNU to f32
313+
return %0 : f32
314+
}
315+
316+
// CHECK-LABLE: @extf_f8E8M0FNU_to_f32
317+
// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f8E8M0FNU to i8
318+
// CHECK-DAG: %[[CF8NAN:.+]] = arith.constant -1 : i8
319+
// CHECK-DAG: %[[CF32NAN:.+]] = arith.constant -1 : i32
320+
// CHECK-DAG: %[[C23_i32:.+]] = arith.constant 23 : i32
321+
// CHECK: %[[EXTUI:.+]] = arith.extui %[[BITCAST]] : i8 to i32
322+
// CHECK: %[[SHLI:.+]] = arith.shli %[[EXTUI]], %[[C23_i32]] : i32
323+
// CHECK: %[[CMP_NAN:.+]] = arith.cmpi eq, %[[BITCAST]], %[[CF8NAN]] : i8
324+
// CHECK: %[[SELECT_NAN:.+]] = arith.select %[[CMP_NAN]], %[[CF32NAN]], %[[SHLI]] : i32
325+
// CHECK: %[[RESULT:.+]] = arith.bitcast %[[SELECT_NAN]] : i32 to f32
326+
// CHECK: return %[[RESULT]]
327+
328+
// -----
329+
330+
func.func @extf_f8E8M0FNU_to_f16(%arg0 : f8E8M0FNU) -> f16 {
331+
%0 = arith.extf %arg0 : f8E8M0FNU to f16
332+
return %0 : f16
333+
}
334+
335+
// CHECK-LABLE: @extf_f8E8M0FNU_to_f16
336+
// CHECK: %[[BITCAST:.+]] = arith.bitcast %arg0 : f8E8M0FNU to i8
337+
// CHECK-DAG: %[[CF8NAN:.+]] = arith.constant -1 : i8
338+
// CHECK-DAG: %[[CF32NAN:.+]] = arith.constant -1 : i32
339+
// CHECK-DAG: %[[C23_i32:.+]] = arith.constant 23 : i32
340+
// CHECK: %[[EXTUI:.+]] = arith.extui %[[BITCAST]] : i8 to i32
341+
// CHECK: %[[SHLI:.+]] = arith.shli %[[EXTUI]], %[[C23_i32]] : i32
342+
// CHECK: %[[CMP_NAN:.+]] = arith.cmpi eq, %[[BITCAST]], %[[CF8NAN]] : i8
343+
// CHECK: %[[SELECT_NAN:.+]] = arith.select %[[CMP_NAN]], %[[CF32NAN]], %[[SHLI]] : i32
344+
// CHECK: %[[F32_RESULT:.+]] = arith.bitcast %[[SELECT_NAN]] : i32 to f32
345+
// CHECK: %[[F16_RESULT:.+]] = arith.truncf %[[F32_RESULT]] : f32 to f16
346+
// CHECK: return %[[F16_RESULT]]
347+
348+
// -----
349+
350+
func.func @extf_vector_f8E8M0FNU_to_f32(%arg0 : vector<4xf8E8M0FNU>) -> vector<4xf32> {
351+
%0 = arith.extf %arg0 : vector<4xf8E8M0FNU> to vector<4xf32>
352+
return %0 : vector<4xf32>
353+
}
354+
355+
// CHECK-LABEL: @extf_vector_f8E8M0FNU_to_f32
356+
// CHECK-NOT: arith.extf
357+
358+
// -----
359+
360+
func.func @extf_vector_f8E8M0FNU_to_f16(%arg0 : vector<4xf8E8M0FNU>) -> vector<4xf16> {
361+
%0 = arith.extf %arg0 : vector<4xf8E8M0FNU> to vector<4xf16>
362+
return %0 : vector<4xf16>
363+
}
364+
365+
// CHECK-LABEL: @extf_vector_f8E8M0FNU_to_f16
366+
// CHECK-NOT: arith.extf
367+
368+
// -----
369+
370+
func.func @extf_vector_f8E8M0FNU_to_bf16(%arg0 : vector<4xf8E8M0FNU>) -> vector<4xbf16> {
371+
%0 = arith.extf %arg0 : vector<4xf8E8M0FNU> to vector<4xbf16>
372+
return %0 : vector<4xbf16>
373+
}
374+
375+
// CHECK-LABEL: @extf_vector_f8E8M0FNU_to_bf16
376+
// CHECK-NOT: arith.extf
377+
378+
251379
// -----
252380

253381
func.func @maxsi(%a: i32, %b: i32) -> i32 {

0 commit comments

Comments
 (0)