Skip to content

Commit 7042fcc

Browse files
author
Ivy Zhang
authored
[MLIR][Arith][Resubmit] add fastMathAttr on arith::extf and arith::truncf (#95346)
Add an `fastMathAttr` on `arith::extf` and `arith::truncf`. If these two ops are inserted by some promotion passes (like legalize-to-f32 / emulate-unsupported-floats), they will be labeled as `FastMathFlags::contract`, denoting that they can be then `eliminated by canonicalizer`. The `elimination` can help improve performance, while may introduce some numerical differences.
1 parent daac13f commit 7042fcc

File tree

6 files changed

+198
-24
lines changed

6 files changed

+198
-24
lines changed

mlir/include/mlir/Dialect/Arith/IR/ArithOps.td

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1199,7 +1199,7 @@ def Arith_ExtSIOp : Arith_IToICastOp<"extsi"> {
11991199
// ExtFOp
12001200
//===----------------------------------------------------------------------===//
12011201

1202-
def Arith_ExtFOp : Arith_FToFCastOp<"extf"> {
1202+
def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
12031203
let summary = "cast from floating-point to wider floating-point";
12041204
let description = [{
12051205
Cast a floating-point value to a larger floating-point-typed value.
@@ -1208,6 +1208,13 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf"> {
12081208
}];
12091209
let hasVerifier = 1;
12101210
let hasFolder = 1;
1211+
1212+
let arguments = (ins FloatLike:$in,
1213+
OptionalAttr<Arith_FastMathAttr>:$fastmath);
1214+
let results = (outs FloatLike:$out);
1215+
1216+
let assemblyFormat = [{ $in (`fastmath` `` $fastmath^)?
1217+
attr-dict `:` type($in) `to` type($out) }];
12111218
}
12121219

12131220
//===----------------------------------------------------------------------===//
@@ -1246,9 +1253,11 @@ def Arith_TruncFOp :
12461253
Arith_Op<"truncf",
12471254
[Pure, SameOperandsAndResultShape, SameInputOutputTensorDims,
12481255
DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
1256+
DeclareOpInterfaceMethods<ArithFastMathInterface>,
12491257
DeclareOpInterfaceMethods<CastOpInterface>]>,
12501258
Arguments<(ins FloatLike:$in,
1251-
OptionalAttr<Arith_RoundingModeAttr>:$roundingmode)>,
1259+
OptionalAttr<Arith_RoundingModeAttr>:$roundingmode,
1260+
OptionalAttr<Arith_FastMathAttr>:$fastmath)>,
12521261
Results<(outs FloatLike:$out)> {
12531262
let summary = "cast from floating-point to narrower floating-point";
12541263
let description = [{
@@ -1267,7 +1276,9 @@ def Arith_TruncFOp :
12671276

12681277
let hasFolder = 1;
12691278
let hasVerifier = 1;
1270-
let assemblyFormat = "$in ($roundingmode^)? attr-dict `:` type($in) `to` type($out)";
1279+
let assemblyFormat = [{ $in ($roundingmode^)?
1280+
(`fastmath` `` $fastmath^)?
1281+
attr-dict `:` type($in) `to` type($out) }];
12711282
}
12721283

12731284
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1390,6 +1390,22 @@ LogicalResult arith::ExtSIOp::verify() {
13901390
/// Fold extension of float constants when there is no information loss due the
13911391
/// difference in fp semantics.
13921392
OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1393+
if (auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
1394+
if (truncFOp.getOperand().getType() == getType()) {
1395+
arith::FastMathFlags truncFMF =
1396+
truncFOp.getFastmath().value_or(arith::FastMathFlags::none);
1397+
bool isTruncContract =
1398+
bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
1399+
arith::FastMathFlags extFMF =
1400+
getFastmath().value_or(arith::FastMathFlags::none);
1401+
bool isExtContract =
1402+
bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
1403+
if (isTruncContract && isExtContract) {
1404+
return truncFOp.getOperand();
1405+
}
1406+
}
1407+
}
1408+
13931409
auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
13941410
const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
13951411
return constFoldCastOp<FloatAttr, FloatAttr>(

mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,11 @@ void EmulateFloatPattern::rewrite(Operation *op, ArrayRef<Value> operands,
9494
SmallVector<Value> newResults(expandedOp->getResults());
9595
for (auto [res, oldType, newType] : llvm::zip_equal(
9696
MutableArrayRef{newResults}, op->getResultTypes(), resultTypes)) {
97-
if (oldType != newType)
98-
res = rewriter.create<arith::TruncFOp>(loc, oldType, res);
97+
if (oldType != newType) {
98+
auto truncFOp = rewriter.create<arith::TruncFOp>(loc, oldType, res);
99+
truncFOp.setFastmath(arith::FastMathFlags::contract);
100+
res = truncFOp.getResult();
101+
}
99102
}
100103
rewriter.replaceOp(op, newResults);
101104
}
@@ -114,7 +117,9 @@ void mlir::arith::populateEmulateUnsupportedFloatsConversions(
114117
});
115118
converter.addTargetMaterialization(
116119
[](OpBuilder &b, Type target, ValueRange input, Location loc) {
117-
return b.create<arith::ExtFOp>(loc, target, input);
120+
auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
121+
extFOp.setFastmath(arith::FastMathFlags::contract);
122+
return extFOp;
118123
});
119124
}
120125

mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,9 @@ void mlir::math::populateLegalizeToF32TypeConverter(
5757
});
5858
typeConverter.addTargetMaterialization(
5959
[](OpBuilder &b, Type target, ValueRange input, Location loc) {
60-
return b.create<arith::ExtFOp>(loc, target, input);
60+
auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
61+
extFOp.setFastmath(arith::FastMathFlags::contract);
62+
return extFOp;
6163
});
6264
}
6365

@@ -84,8 +86,11 @@ LogicalResult LegalizeToF32RewritePattern::matchAndRewrite(
8486
SmallVector<Value> results = (*legalized)->getResults();
8587
for (auto [result, newType, origType] : llvm::zip_equal(
8688
results, (*legalized)->getResultTypes(), op->getResultTypes())) {
87-
if (newType != origType)
88-
result = rewriter.create<arith::TruncFOp>(loc, origType, result);
89+
if (newType != origType) {
90+
auto truncFOp = rewriter.create<arith::TruncFOp>(loc, origType, result);
91+
truncFOp.setFastmath(arith::FastMathFlags::contract);
92+
result = truncFOp.getResult();
93+
}
8994
}
9095
rewriter.replaceOp(op, results);
9196
return success();

mlir/test/Dialect/Arith/canonicalize.mlir

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3031,6 +3031,143 @@ func.func @mulsi_extended_i0() -> (i0, i0) {
30313031
return %mulsi_extended#0, %mulsi_extended#1 : i0, i0
30323032
}
30333033

3034+
// CHECK-LABEL: @sequences_fastmath_contract
3035+
// CHECK-SAME: ([[ARG0:%.+]]: bf16)
3036+
// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
3037+
// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
3038+
// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
3039+
// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
3040+
// CHECK: return [[TRUNCF]] : bf16
3041+
func.func @sequences_fastmath_contract(%arg0: bf16) -> bf16 {
3042+
%0 = arith.extf %arg0 fastmath<contract> : bf16 to f32
3043+
%1 = math.absf %0 : f32
3044+
%2 = arith.truncf %1 fastmath<contract> : f32 to bf16
3045+
%3 = arith.extf %2 fastmath<contract> : bf16 to f32
3046+
%4 = math.sin %3 : f32
3047+
%5 = arith.truncf %4 fastmath<contract> : f32 to bf16
3048+
return %5 : bf16
3049+
}
3050+
3051+
// CHECK-LABEL: @sequences_no_fastmath
3052+
// CHECK-SAME: ([[ARG0:%.+]]: bf16)
3053+
// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
3054+
// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
3055+
// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[ABSF]]
3056+
// CHECK: [[EXTF1:%.+]] = arith.extf [[TRUNCF1]]
3057+
// CHECK: [[SIN:%.+]] = math.sin [[EXTF1]]
3058+
// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
3059+
// CHECK: return [[TRUNCF]] : bf16
3060+
func.func @sequences_no_fastmath(%arg0: bf16) -> bf16 {
3061+
%0 = arith.extf %arg0 : bf16 to f32
3062+
%1 = math.absf %0 : f32
3063+
%2 = arith.truncf %1 : f32 to bf16
3064+
%3 = arith.extf %2 : bf16 to f32
3065+
%4 = math.sin %3 : f32
3066+
%5 = arith.truncf %4 : f32 to bf16
3067+
return %5 : bf16
3068+
}
3069+
3070+
// CHECK-LABEL: @eliminate_cast_to_f16
3071+
// CHECK: return [[arg0:%.+]] : f32
3072+
func.func @eliminate_cast_to_f16(%arg0: f32) -> f32 {
3073+
%0 = arith.truncf %arg0 fastmath<contract> : f32 to f16
3074+
%1 = arith.extf %0 fastmath<contract> : f16 to f32
3075+
return %1 : f32
3076+
}
3077+
3078+
// CHECK-LABEL: @eliminate_cast_to_bf16
3079+
// CHECK: return [[arg0:%.+]] : f32
3080+
func.func @eliminate_cast_to_bf16(%arg0: f32) -> f32 {
3081+
%0 = arith.truncf %arg0 fastmath<contract> : f32 to bf16
3082+
%1 = arith.extf %0 fastmath<contract> : bf16 to f32
3083+
return %1 : f32
3084+
}
3085+
3086+
// CHECK-LABEL: @bf16_sin_vector
3087+
// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
3088+
// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
3089+
// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
3090+
// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
3091+
// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
3092+
// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16>
3093+
func.func @bf16_sin_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
3094+
%0 = arith.extf %arg0 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
3095+
%1 = math.absf %0 : vector<32x32x32xf32>
3096+
%2 = arith.truncf %1 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
3097+
%3 = arith.extf %2 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
3098+
%4 = math.sin %3 : vector<32x32x32xf32>
3099+
%5 = arith.truncf %4 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
3100+
return %5 : vector<32x32x32xbf16>
3101+
}
3102+
3103+
// CHECK-LABEL: @f16_sin_vector
3104+
// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xf16>)
3105+
// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
3106+
// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
3107+
// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
3108+
// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
3109+
// CHECK: return [[TRUNCF]] : vector<32x32x32xf16>
3110+
func.func @f16_sin_vector(%arg0: vector<32x32x32xf16>) -> vector<32x32x32xf16> {
3111+
%0 = arith.extf %arg0 fastmath<contract> : vector<32x32x32xf16> to vector<32x32x32xf32>
3112+
%1 = math.absf %0 : vector<32x32x32xf32>
3113+
%2 = arith.truncf %1 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xf16>
3114+
%3 = arith.extf %2 fastmath<contract> : vector<32x32x32xf16> to vector<32x32x32xf32>
3115+
%4 = math.sin %3 : vector<32x32x32xf32>
3116+
%5 = arith.truncf %4 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xf16>
3117+
return %5 : vector<32x32x32xf16>
3118+
}
3119+
3120+
// CHECK-LABEL: @bf16_branch_vector
3121+
// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
3122+
// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
3123+
// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
3124+
// CHECK-DAG: [[SIN:%.+]] = math.sin [[ABSF]]
3125+
// CHECK-DAG: [[COS:%.+]] = math.cos [[ABSF]]
3126+
// CHECK: [[ADDF:%.+]] = arith.addf [[SIN]], [[COS]]
3127+
// CHECK: [[TRUNCF:%.+]] = arith.truncf [[ADDF]]
3128+
// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16>
3129+
func.func @bf16_branch_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
3130+
%0 = arith.extf %arg0 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
3131+
%1 = math.absf %0 : vector<32x32x32xf32>
3132+
%2 = arith.truncf %1 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
3133+
%3 = arith.extf %2 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
3134+
%4 = math.sin %3 : vector<32x32x32xf32>
3135+
%5 = arith.truncf %4 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
3136+
%6 = arith.extf %5 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
3137+
%7 = math.cos %3 : vector<32x32x32xf32>
3138+
%8 = arith.truncf %7 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
3139+
%9 = arith.extf %8 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
3140+
%10 = arith.addf %6, %9 : vector<32x32x32xf32>
3141+
%11 = arith.truncf %10 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
3142+
return %11 : vector<32x32x32xbf16>
3143+
}
3144+
3145+
// CHECK-LABEL: @bf16_fma
3146+
// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>, [[ARG1:%.+]]: vector<32x32x32xbf16>, [[ARG2:%.+]]: vector<32x32x32xbf16>)
3147+
// CHECK: [[EXTF0:%.+]] = arith.extf [[ARG0]]
3148+
// CHECK: [[ABSF:%.+]] = math.absf [[EXTF0]]
3149+
// CHECK-DAG: [[SIN:%.+]] = math.sin [[ABSF]]
3150+
// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[SIN]]
3151+
// CHECK-DAG: [[FMA:%.+]] = math.fma [[TRUNCF0]], [[ARG1]], [[ARG2]]
3152+
// CHECK: [[EXTF1:%.+]] = arith.extf [[FMA]]
3153+
// CHECK: [[ADDF:%.+]] = arith.addf [[EXTF1]], [[SIN]]
3154+
// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[ADDF]]
3155+
// CHECK: return [[TRUNCF1]] : vector<32x32x32xbf16>
3156+
func.func @bf16_fma(%arg0: vector<32x32x32xbf16>, %arg1: vector<32x32x32xbf16>, %arg2: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
3157+
%0 = arith.extf %arg0 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
3158+
%1 = math.absf %0 : vector<32x32x32xf32>
3159+
%2 = arith.truncf %1 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
3160+
%3 = arith.extf %2 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
3161+
%4 = math.sin %3 : vector<32x32x32xf32>
3162+
%5 = arith.truncf %4 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
3163+
%6 = arith.extf %5 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
3164+
%7 = math.fma %5, %arg1, %arg2 : vector<32x32x32xbf16>
3165+
%8 = arith.extf %7 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
3166+
%9 = arith.addf %8, %6 : vector<32x32x32xf32>
3167+
%10 = arith.truncf %9 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
3168+
return %10 : vector<32x32x32xbf16>
3169+
}
3170+
30343171
{-#
30353172
dialect_resources: {
30363173
builtin: {

mlir/test/Dialect/Arith/emulate-unsupported-floats.mlir

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@ func.func @basic_expansion(%x: bf16) -> bf16 {
44
// CHECK-LABEL: @basic_expansion
55
// CHECK-SAME: [[X:%.+]]: bf16
66
// CHECK-DAG: [[C:%.+]] = arith.constant {{.*}} : bf16
7-
// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] : bf16 to f32
8-
// CHECK-DAG: [[C_EXP:%.+]] = arith.extf [[C]] : bf16 to f32
7+
// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] fastmath<contract> : bf16 to f32
8+
// CHECK-DAG: [[C_EXP:%.+]] = arith.extf [[C]] fastmath<contract> : bf16 to f32
99
// CHECK: [[Y_EXP:%.+]] = arith.addf [[X_EXP]], [[C_EXP]] : f32
10-
// CHECK: [[Y:%.+]] = arith.truncf [[Y_EXP]] : f32 to bf16
10+
// CHECK: [[Y:%.+]] = arith.truncf [[Y_EXP]] fastmath<contract> : f32 to bf16
1111
// CHECK: return [[Y]]
1212
%c = arith.constant 1.0 : bf16
1313
%y = arith.addf %x, %c : bf16
@@ -19,15 +19,15 @@ func.func @basic_expansion(%x: bf16) -> bf16 {
1919
func.func @chained(%x: bf16, %y: bf16, %z: bf16) -> i1 {
2020
// CHECK-LABEL: @chained
2121
// CHECK-SAME: [[X:%.+]]: bf16, [[Y:%.+]]: bf16, [[Z:%.+]]: bf16
22-
// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] : bf16 to f32
23-
// CHECK-DAG: [[Y_EXP:%.+]] = arith.extf [[Y]] : bf16 to f32
24-
// CHECK-DAG: [[Z_EXP:%.+]] = arith.extf [[Z]] : bf16 to f32
22+
// CHECK-DAG: [[X_EXP:%.+]] = arith.extf [[X]] fastmath<contract> : bf16 to f32
23+
// CHECK-DAG: [[Y_EXP:%.+]] = arith.extf [[Y]] fastmath<contract> : bf16 to f32
24+
// CHECK-DAG: [[Z_EXP:%.+]] = arith.extf [[Z]] fastmath<contract> : bf16 to f32
2525
// CHECK: [[P_EXP:%.+]] = arith.addf [[X_EXP]], [[Y_EXP]] : f32
26-
// CHECK: [[P:%.+]] = arith.truncf [[P_EXP]] : f32 to bf16
27-
// CHECK: [[P_EXP2:%.+]] = arith.extf [[P]] : bf16 to f32
26+
// CHECK: [[P:%.+]] = arith.truncf [[P_EXP]] fastmath<contract> : f32 to bf16
27+
// CHECK: [[P_EXP2:%.+]] = arith.extf [[P]] fastmath<contract> : bf16 to f32
2828
// CHECK: [[Q_EXP:%.+]] = arith.mulf [[P_EXP2]], [[Z_EXP]]
29-
// CHECK: [[Q:%.+]] = arith.truncf [[Q_EXP]] : f32 to bf16
30-
// CHECK: [[Q_EXP2:%.+]] = arith.extf [[Q]] : bf16 to f32
29+
// CHECK: [[Q:%.+]] = arith.truncf [[Q_EXP]] fastmath<contract> : f32 to bf16
30+
// CHECK: [[Q_EXP2:%.+]] = arith.extf [[Q]] fastmath<contract> : bf16 to f32
3131
// CHECK: [[RES:%.+]] = arith.cmpf ole, [[P_EXP2]], [[Q_EXP2]] : f32
3232
// CHECK: return [[RES]]
3333
%p = arith.addf %x, %y : bf16
@@ -41,12 +41,12 @@ func.func @chained(%x: bf16, %y: bf16, %z: bf16) -> i1 {
4141
func.func @memops(%a: memref<4xf8E4M3FNUZ>, %b: memref<4xf8E4M3FNUZ>) {
4242
// CHECK-LABEL: @memops
4343
// CHECK: [[V:%.+]] = memref.load {{.*}} : memref<4xf8E4M3FNUZ>
44-
// CHECK: [[V_EXP:%.+]] = arith.extf [[V]] : f8E4M3FNUZ to f32
44+
// CHECK: [[V_EXP:%.+]] = arith.extf [[V]] fastmath<contract> : f8E4M3FNUZ to f32
4545
// CHECK: memref.store [[V]]
4646
// CHECK: [[W:%.+]] = memref.load
47-
// CHECK: [[W_EXP:%.+]] = arith.extf [[W]] : f8E4M3FNUZ to f32
47+
// CHECK: [[W_EXP:%.+]] = arith.extf [[W]] fastmath<contract> : f8E4M3FNUZ to f32
4848
// CHECK: [[X_EXP:%.+]] = arith.addf [[V_EXP]], [[W_EXP]] : f32
49-
// CHECK: [[X:%.+]] = arith.truncf [[X_EXP]] : f32 to f8E4M3FNUZ
49+
// CHECK: [[X:%.+]] = arith.truncf [[X_EXP]] fastmath<contract> : f32 to f8E4M3FNUZ
5050
// CHECK: memref.store [[X]]
5151
%c0 = arith.constant 0 : index
5252
%c1 = arith.constant 1 : index
@@ -63,9 +63,9 @@ func.func @memops(%a: memref<4xf8E4M3FNUZ>, %b: memref<4xf8E4M3FNUZ>) {
6363
func.func @vectors(%a: vector<4xf8E4M3FNUZ>) -> vector<4xf32> {
6464
// CHECK-LABEL: @vectors
6565
// CHECK-SAME: [[A:%.+]]: vector<4xf8E4M3FNUZ>
66-
// CHECK: [[A_EXP:%.+]] = arith.extf [[A]] : vector<4xf8E4M3FNUZ> to vector<4xf32>
66+
// CHECK: [[A_EXP:%.+]] = arith.extf [[A]] fastmath<contract> : vector<4xf8E4M3FNUZ> to vector<4xf32>
6767
// CHECK: [[B_EXP:%.+]] = arith.mulf [[A_EXP]], [[A_EXP]] : vector<4xf32>
68-
// CHECK: [[B:%.+]] = arith.truncf [[B_EXP]] : vector<4xf32> to vector<4xf8E4M3FNUZ>
68+
// CHECK: [[B:%.+]] = arith.truncf [[B_EXP]] fastmath<contract> : vector<4xf32> to vector<4xf8E4M3FNUZ>
6969
// CHECK: [[RET:%.+]] = arith.extf [[B]] : vector<4xf8E4M3FNUZ> to vector<4xf32>
7070
// CHECK: return [[RET]]
7171
%b = arith.mulf %a, %a : vector<4xf8E4M3FNUZ>

0 commit comments

Comments
 (0)