Skip to content

Commit 6784bf7

Browse files
author
Ivy Zhang
authored
[MLIR][Arith] add fastMathAttr on arith::extf and arith::truncf (#93443)
Add an `fastMathAttr` on `arith::extf` and `arith::truncf`. If these two ops are inserted by some promotion passes (like legalize-to-f32 / emulate-unsupported-floats), they will be labeled as `FastMathFlags::contract`, denoting that they can be then `eliminated by canonicalizer`. The `elimination` can help improve performance, while may introduce some numerical differences.
1 parent 86d8aec commit 6784bf7

File tree

7 files changed

+213
-40
lines changed

7 files changed

+213
-40
lines changed

mlir/include/mlir/Dialect/Arith/IR/ArithOps.td

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1199,7 +1199,7 @@ def Arith_ExtSIOp : Arith_IToICastOp<"extsi"> {
11991199
// ExtFOp
12001200
//===----------------------------------------------------------------------===//
12011201

1202-
def Arith_ExtFOp : Arith_FToFCastOp<"extf"> {
1202+
def Arith_ExtFOp : Arith_FToFCastOp<"extf", [DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
12031203
let summary = "cast from floating-point to wider floating-point";
12041204
let description = [{
12051205
Cast a floating-point value to a larger floating-point-typed value.
@@ -1208,6 +1208,13 @@ def Arith_ExtFOp : Arith_FToFCastOp<"extf"> {
12081208
}];
12091209
let hasVerifier = 1;
12101210
let hasFolder = 1;
1211+
1212+
let arguments = (ins FloatLike:$in, DefaultValuedAttr<
1213+
Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath);
1214+
let results = (outs FloatLike:$out);
1215+
1216+
let assemblyFormat = [{ $in (`fastmath` `` $fastmath^)?
1217+
attr-dict `:` type($in) `to` type($out) }];
12111218
}
12121219

12131220
//===----------------------------------------------------------------------===//
@@ -1246,8 +1253,11 @@ def Arith_TruncFOp :
12461253
Arith_Op<"truncf",
12471254
[Pure, SameOperandsAndResultShape, SameInputOutputTensorDims,
12481255
DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
1256+
DeclareOpInterfaceMethods<ArithFastMathInterface>,
12491257
DeclareOpInterfaceMethods<CastOpInterface>]>,
12501258
Arguments<(ins FloatLike:$in,
1259+
DefaultValuedAttr<
1260+
Arith_FastMathAttr, "::mlir::arith::FastMathFlags::none">:$fastmath,
12511261
OptionalAttr<Arith_RoundingModeAttr>:$roundingmode)>,
12521262
Results<(outs FloatLike:$out)> {
12531263
let summary = "cast from floating-point to narrower floating-point";
@@ -1267,7 +1277,9 @@ def Arith_TruncFOp :
12671277

12681278
let hasFolder = 1;
12691279
let hasVerifier = 1;
1270-
let assemblyFormat = "$in ($roundingmode^)? attr-dict `:` type($in) `to` type($out)";
1280+
let assemblyFormat = [{ $in ($roundingmode^)?
1281+
(`fastmath` `` $fastmath^)?
1282+
attr-dict `:` type($in) `to` type($out) }];
12711283
}
12721284

12731285
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1390,6 +1390,20 @@ LogicalResult arith::ExtSIOp::verify() {
13901390
/// Fold extension of float constants when there is no information loss due the
13911391
/// difference in fp semantics.
13921392
OpFoldResult arith::ExtFOp::fold(FoldAdaptor adaptor) {
1393+
if (auto truncFOp = getOperand().getDefiningOp<TruncFOp>()) {
1394+
if (truncFOp.getOperand().getType() == getType()) {
1395+
arith::FastMathFlags truncFMF = truncFOp.getFastmath();
1396+
bool isTruncContract =
1397+
bitEnumContainsAll(truncFMF, arith::FastMathFlags::contract);
1398+
arith::FastMathFlags extFMF = getFastmath();
1399+
bool isExtContract =
1400+
bitEnumContainsAll(extFMF, arith::FastMathFlags::contract);
1401+
if (isTruncContract && isExtContract) {
1402+
return truncFOp.getOperand();
1403+
}
1404+
}
1405+
}
1406+
13931407
auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
13941408
const llvm::fltSemantics &targetSemantics = resElemType.getFloatSemantics();
13951409
return constFoldCastOp<FloatAttr, FloatAttr>(

mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,11 @@ void EmulateFloatPattern::rewrite(Operation *op, ArrayRef<Value> operands,
9494
SmallVector<Value> newResults(expandedOp->getResults());
9595
for (auto [res, oldType, newType] : llvm::zip_equal(
9696
MutableArrayRef{newResults}, op->getResultTypes(), resultTypes)) {
97-
if (oldType != newType)
98-
res = rewriter.create<arith::TruncFOp>(loc, oldType, res);
97+
if (oldType != newType) {
98+
auto truncFOp = rewriter.create<arith::TruncFOp>(loc, oldType, res);
99+
truncFOp.setFastmath(arith::FastMathFlags::contract);
100+
res = truncFOp.getResult();
101+
}
99102
}
100103
rewriter.replaceOp(op, newResults);
101104
}
@@ -114,7 +117,9 @@ void mlir::arith::populateEmulateUnsupportedFloatsConversions(
114117
});
115118
converter.addTargetMaterialization(
116119
[](OpBuilder &b, Type target, ValueRange input, Location loc) {
117-
return b.create<arith::ExtFOp>(loc, target, input);
120+
auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
121+
extFOp.setFastmath(arith::FastMathFlags::contract);
122+
return extFOp;
118123
});
119124
}
120125

mlir/lib/Dialect/Math/Transforms/LegalizeToF32.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,9 @@ void mlir::math::populateLegalizeToF32TypeConverter(
5757
});
5858
typeConverter.addTargetMaterialization(
5959
[](OpBuilder &b, Type target, ValueRange input, Location loc) {
60-
return b.create<arith::ExtFOp>(loc, target, input);
60+
auto extFOp = b.create<arith::ExtFOp>(loc, target, input);
61+
extFOp.setFastmath(arith::FastMathFlags::contract);
62+
return extFOp;
6163
});
6264
}
6365

@@ -84,8 +86,11 @@ LogicalResult LegalizeToF32RewritePattern::matchAndRewrite(
8486
SmallVector<Value> results = (*legalized)->getResults();
8587
for (auto [result, newType, origType] : llvm::zip_equal(
8688
results, (*legalized)->getResultTypes(), op->getResultTypes())) {
87-
if (newType != origType)
88-
result = rewriter.create<arith::TruncFOp>(loc, origType, result);
89+
if (newType != origType) {
90+
auto truncFOp = rewriter.create<arith::TruncFOp>(loc, origType, result);
91+
truncFOp.setFastmath(arith::FastMathFlags::contract);
92+
result = truncFOp.getResult();
93+
}
8994
}
9095
rewriter.replaceOp(op, results);
9196
return success();

mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -162,23 +162,23 @@ func.func @uitofp(%arg0 : i32, %arg1 : i64) {
162162
// Checking conversion of integer types to floating point.
163163
// CHECK-LABEL: @fpext
164164
func.func @fpext(%arg0 : f16, %arg1 : f32) {
165-
// CHECK-NEXT: = llvm.fpext {{.*}} : f16 to f32
165+
// CHECK-NEXT: = llvm.fpext {{.*}} {fastmath = #arith.fastmath<none>} : f16 to f32
166166
%0 = arith.extf %arg0: f16 to f32
167-
// CHECK-NEXT: = llvm.fpext {{.*}} : f16 to f64
167+
// CHECK-NEXT: = llvm.fpext {{.*}} {fastmath = #arith.fastmath<none>} : f16 to f64
168168
%1 = arith.extf %arg0: f16 to f64
169-
// CHECK-NEXT: = llvm.fpext {{.*}} : f32 to f64
169+
// CHECK-NEXT: = llvm.fpext {{.*}} {fastmath = #arith.fastmath<none>} : f32 to f64
170170
%2 = arith.extf %arg1: f32 to f64
171171
return
172172
}
173173

174174
// Checking conversion of integer types to floating point.
175175
// CHECK-LABEL: @fpext
176176
func.func @fpext_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>) {
177-
// CHECK-NEXT: = llvm.fpext {{.*}} : vector<2xf16> to vector<2xf32>
177+
// CHECK-NEXT: = llvm.fpext {{.*}} {fastmath = #arith.fastmath<none>} : vector<2xf16> to vector<2xf32>
178178
%0 = arith.extf %arg0: vector<2xf16> to vector<2xf32>
179-
// CHECK-NEXT: = llvm.fpext {{.*}} : vector<2xf16> to vector<2xf64>
179+
// CHECK-NEXT: = llvm.fpext {{.*}} {fastmath = #arith.fastmath<none>} : vector<2xf16> to vector<2xf64>
180180
%1 = arith.extf %arg0: vector<2xf16> to vector<2xf64>
181-
// CHECK-NEXT: = llvm.fpext {{.*}} : vector<2xf32> to vector<2xf64>
181+
// CHECK-NEXT: = llvm.fpext {{.*}} {fastmath = #arith.fastmath<none>} : vector<2xf32> to vector<2xf64>
182182
%2 = arith.extf %arg1: vector<2xf32> to vector<2xf64>
183183
return
184184
}
@@ -268,38 +268,38 @@ func.func @uitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : v
268268
// Checking conversion of integer types to floating point.
269269
// CHECK-LABEL: @fptrunc
270270
func.func @fptrunc(%arg0 : f32, %arg1 : f64) {
271-
// CHECK-NEXT: = llvm.fptrunc {{.*}} : f32 to f16
271+
// CHECK-NEXT: = llvm.fptrunc {{.*}} {fastmath = #arith.fastmath<none>} : f32 to f16
272272
%0 = arith.truncf %arg0: f32 to f16
273-
// CHECK-NEXT: = llvm.fptrunc {{.*}} : f64 to f16
273+
// CHECK-NEXT: = llvm.fptrunc {{.*}} {fastmath = #arith.fastmath<none>} : f64 to f16
274274
%1 = arith.truncf %arg1: f64 to f16
275-
// CHECK-NEXT: = llvm.fptrunc {{.*}} : f64 to f32
275+
// CHECK-NEXT: = llvm.fptrunc {{.*}} {fastmath = #arith.fastmath<none>} : f64 to f32
276276
%2 = arith.truncf %arg1: f64 to f32
277277
return
278278
}
279279

280280
// Checking conversion of integer types to floating point.
281281
// CHECK-LABEL: @fptrunc
282282
func.func @fptrunc_vector(%arg0 : vector<2xf32>, %arg1 : vector<2xf64>) {
283-
// CHECK-NEXT: = llvm.fptrunc {{.*}} : vector<2xf32> to vector<2xf16>
283+
// CHECK-NEXT: = llvm.fptrunc {{.*}} {fastmath = #arith.fastmath<none>} : vector<2xf32> to vector<2xf16>
284284
%0 = arith.truncf %arg0: vector<2xf32> to vector<2xf16>
285-
// CHECK-NEXT: = llvm.fptrunc {{.*}} : vector<2xf64> to vector<2xf16>
285+
// CHECK-NEXT: = llvm.fptrunc {{.*}} {fastmath = #arith.fastmath<none>} : vector<2xf64> to vector<2xf16>
286286
%1 = arith.truncf %arg1: vector<2xf64> to vector<2xf16>
287-
// CHECK-NEXT: = llvm.fptrunc {{.*}} : vector<2xf64> to vector<2xf32>
287+
// CHECK-NEXT: = llvm.fptrunc {{.*}} {fastmath = #arith.fastmath<none>} : vector<2xf64> to vector<2xf32>
288288
%2 = arith.truncf %arg1: vector<2xf64> to vector<2xf32>
289289
return
290290
}
291291

292292
// CHECK-LABEL: experimental_constrained_fptrunc
293293
func.func @experimental_constrained_fptrunc(%arg0 : f64) {
294-
// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearest ignore : f64 to f32
294+
// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearest ignore {fastmath = #arith.fastmath<none>} : f64 to f32
295295
%0 = arith.truncf %arg0 to_nearest_even : f64 to f32
296-
// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} downward ignore : f64 to f32
296+
// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} downward ignore {fastmath = #arith.fastmath<none>} : f64 to f32
297297
%1 = arith.truncf %arg0 downward : f64 to f32
298-
// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} upward ignore : f64 to f32
298+
// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} upward ignore {fastmath = #arith.fastmath<none>} : f64 to f32
299299
%2 = arith.truncf %arg0 upward : f64 to f32
300-
// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} towardzero ignore : f64 to f32
300+
// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} towardzero ignore {fastmath = #arith.fastmath<none>} : f64 to f32
301301
%3 = arith.truncf %arg0 toward_zero : f64 to f32
302-
// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearestaway ignore : f64 to f32
302+
// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearestaway ignore {fastmath = #arith.fastmath<none>} : f64 to f32
303303
%4 = arith.truncf %arg0 to_nearest_away : f64 to f32
304304
return
305305
}

mlir/test/Dialect/Arith/canonicalize.mlir

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3031,6 +3031,143 @@ func.func @mulsi_extended_i0() -> (i0, i0) {
30313031
return %mulsi_extended#0, %mulsi_extended#1 : i0, i0
30323032
}
30333033

3034+
// CHECK-LABEL: @sequences_fastmath_contract
3035+
// CHECK-SAME: ([[ARG0:%.+]]: bf16)
3036+
// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
3037+
// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
3038+
// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
3039+
// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
3040+
// CHECK: return [[TRUNCF]] : bf16
3041+
func.func @sequences_fastmath_contract(%arg0: bf16) -> bf16 {
3042+
%0 = arith.extf %arg0 fastmath<contract> : bf16 to f32
3043+
%1 = math.absf %0 : f32
3044+
%2 = arith.truncf %1 fastmath<contract> : f32 to bf16
3045+
%3 = arith.extf %2 fastmath<contract> : bf16 to f32
3046+
%4 = math.sin %3 : f32
3047+
%5 = arith.truncf %4 fastmath<contract> : f32 to bf16
3048+
return %5 : bf16
3049+
}
3050+
3051+
// CHECK-LABEL: @sequences_no_fastmath
3052+
// CHECK-SAME: ([[ARG0:%.+]]: bf16)
3053+
// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
3054+
// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
3055+
// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[ABSF]]
3056+
// CHECK: [[EXTF1:%.+]] = arith.extf [[TRUNCF1]]
3057+
// CHECK: [[SIN:%.+]] = math.sin [[EXTF1]]
3058+
// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
3059+
// CHECK: return [[TRUNCF]] : bf16
3060+
func.func @sequences_no_fastmath(%arg0: bf16) -> bf16 {
3061+
%0 = arith.extf %arg0 : bf16 to f32
3062+
%1 = math.absf %0 : f32
3063+
%2 = arith.truncf %1 : f32 to bf16
3064+
%3 = arith.extf %2 : bf16 to f32
3065+
%4 = math.sin %3 : f32
3066+
%5 = arith.truncf %4 : f32 to bf16
3067+
return %5 : bf16
3068+
}
3069+
3070+
// CHECK-LABEL: @eliminate_cast_to_f16
3071+
// CHECK: return [[arg0:%.+]] : f32
3072+
func.func @eliminate_cast_to_f16(%arg0: f32) -> f32 {
3073+
%0 = arith.truncf %arg0 fastmath<contract> : f32 to f16
3074+
%1 = arith.extf %0 fastmath<contract> : f16 to f32
3075+
return %1 : f32
3076+
}
3077+
3078+
// CHECK-LABEL: @eliminate_cast_to_bf16
3079+
// CHECK: return [[arg0:%.+]] : f32
3080+
func.func @eliminate_cast_to_bf16(%arg0: f32) -> f32 {
3081+
%0 = arith.truncf %arg0 fastmath<contract> : f32 to bf16
3082+
%1 = arith.extf %0 fastmath<contract> : bf16 to f32
3083+
return %1 : f32
3084+
}
3085+
3086+
// CHECK-LABEL: @bf16_sin_vector
3087+
// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
3088+
// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
3089+
// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
3090+
// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
3091+
// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
3092+
// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16>
3093+
func.func @bf16_sin_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
3094+
%0 = arith.extf %arg0 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
3095+
%1 = math.absf %0 : vector<32x32x32xf32>
3096+
%2 = arith.truncf %1 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
3097+
%3 = arith.extf %2 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
3098+
%4 = math.sin %3 : vector<32x32x32xf32>
3099+
%5 = arith.truncf %4 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
3100+
return %5 : vector<32x32x32xbf16>
3101+
}
3102+
3103+
// CHECK-LABEL: @f16_sin_vector
3104+
// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xf16>)
3105+
// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
3106+
// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
3107+
// CHECK: [[SIN:%.+]] = math.sin [[ABSF]]
3108+
// CHECK: [[TRUNCF:%.+]] = arith.truncf [[SIN]]
3109+
// CHECK: return [[TRUNCF]] : vector<32x32x32xf16>
3110+
func.func @f16_sin_vector(%arg0: vector<32x32x32xf16>) -> vector<32x32x32xf16> {
3111+
%0 = arith.extf %arg0 fastmath<contract> : vector<32x32x32xf16> to vector<32x32x32xf32>
3112+
%1 = math.absf %0 : vector<32x32x32xf32>
3113+
%2 = arith.truncf %1 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xf16>
3114+
%3 = arith.extf %2 fastmath<contract> : vector<32x32x32xf16> to vector<32x32x32xf32>
3115+
%4 = math.sin %3 : vector<32x32x32xf32>
3116+
%5 = arith.truncf %4 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xf16>
3117+
return %5 : vector<32x32x32xf16>
3118+
}
3119+
3120+
// CHECK-LABEL: @bf16_branch_vector
3121+
// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>)
3122+
// CHECK: [[EXTF:%.+]] = arith.extf [[ARG0]]
3123+
// CHECK: [[ABSF:%.+]] = math.absf [[EXTF]]
3124+
// CHECK-DAG: [[SIN:%.+]] = math.sin [[ABSF]]
3125+
// CHECK-DAG: [[COS:%.+]] = math.cos [[ABSF]]
3126+
// CHECK: [[ADDF:%.+]] = arith.addf [[SIN]], [[COS]]
3127+
// CHECK: [[TRUNCF:%.+]] = arith.truncf [[ADDF]]
3128+
// CHECK: return [[TRUNCF]] : vector<32x32x32xbf16>
3129+
func.func @bf16_branch_vector(%arg0: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
3130+
%0 = arith.extf %arg0 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
3131+
%1 = math.absf %0 : vector<32x32x32xf32>
3132+
%2 = arith.truncf %1 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
3133+
%3 = arith.extf %2 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
3134+
%4 = math.sin %3 : vector<32x32x32xf32>
3135+
%5 = arith.truncf %4 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
3136+
%6 = arith.extf %5 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
3137+
%7 = math.cos %3 : vector<32x32x32xf32>
3138+
%8 = arith.truncf %7 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
3139+
%9 = arith.extf %8 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
3140+
%10 = arith.addf %6, %9 : vector<32x32x32xf32>
3141+
%11 = arith.truncf %10 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
3142+
return %11 : vector<32x32x32xbf16>
3143+
}
3144+
3145+
// CHECK-LABEL: @bf16_fma
3146+
// CHECK-SAME: ([[ARG0:%.+]]: vector<32x32x32xbf16>, [[ARG1:%.+]]: vector<32x32x32xbf16>, [[ARG2:%.+]]: vector<32x32x32xbf16>)
3147+
// CHECK: [[EXTF0:%.+]] = arith.extf [[ARG0]]
3148+
// CHECK: [[ABSF:%.+]] = math.absf [[EXTF0]]
3149+
// CHECK-DAG: [[SIN:%.+]] = math.sin [[ABSF]]
3150+
// CHECK: [[TRUNCF0:%.+]] = arith.truncf [[SIN]]
3151+
// CHECK-DAG: [[FMA:%.+]] = math.fma [[TRUNCF0]], [[ARG1]], [[ARG2]]
3152+
// CHECK: [[EXTF1:%.+]] = arith.extf [[FMA]]
3153+
// CHECK: [[ADDF:%.+]] = arith.addf [[EXTF1]], [[SIN]]
3154+
// CHECK: [[TRUNCF1:%.+]] = arith.truncf [[ADDF]]
3155+
// CHECK: return [[TRUNCF1]] : vector<32x32x32xbf16>
3156+
func.func @bf16_fma(%arg0: vector<32x32x32xbf16>, %arg1: vector<32x32x32xbf16>, %arg2: vector<32x32x32xbf16>) -> vector<32x32x32xbf16> {
3157+
%0 = arith.extf %arg0 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
3158+
%1 = math.absf %0 : vector<32x32x32xf32>
3159+
%2 = arith.truncf %1 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
3160+
%3 = arith.extf %2 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
3161+
%4 = math.sin %3 : vector<32x32x32xf32>
3162+
%5 = arith.truncf %4 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
3163+
%6 = arith.extf %5 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
3164+
%7 = math.fma %5, %arg1, %arg2 : vector<32x32x32xbf16>
3165+
%8 = arith.extf %7 fastmath<contract> : vector<32x32x32xbf16> to vector<32x32x32xf32>
3166+
%9 = arith.addf %8, %6 : vector<32x32x32xf32>
3167+
%10 = arith.truncf %9 fastmath<contract> : vector<32x32x32xf32> to vector<32x32x32xbf16>
3168+
return %10 : vector<32x32x32xbf16>
3169+
}
3170+
30343171
{-#
30353172
dialect_resources: {
30363173
builtin: {

0 commit comments

Comments
 (0)