Skip to content

Commit 9f6bf00

Browse files
authored
[DAGCombine] Add DAG optimisation for BF16_TO_FP (#69426)
fold bf16_to_fp(op & 0xffff) -> bf16_to_fp(op)
1 parent ae0b263 commit 9f6bf00

File tree

3 files changed

+17
-73
lines changed

3 files changed

+17
-73
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -546,6 +546,7 @@ namespace {
546546
SDValue visitFP_TO_FP16(SDNode *N);
547547
SDValue visitFP16_TO_FP(SDNode *N);
548548
SDValue visitFP_TO_BF16(SDNode *N);
549+
SDValue visitBF16_TO_FP(SDNode *N);
549550
SDValue visitVECREDUCE(SDNode *N);
550551
SDValue visitVPOp(SDNode *N);
551552
SDValue visitGET_FPENV_MEM(SDNode *N);
@@ -2047,6 +2048,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
20472048
case ISD::FP_TO_FP16: return visitFP_TO_FP16(N);
20482049
case ISD::FP16_TO_FP: return visitFP16_TO_FP(N);
20492050
case ISD::FP_TO_BF16: return visitFP_TO_BF16(N);
2051+
case ISD::BF16_TO_FP: return visitBF16_TO_FP(N);
20502052
case ISD::FREEZE: return visitFREEZE(N);
20512053
case ISD::GET_FPENV_MEM: return visitGET_FPENV_MEM(N);
20522054
case ISD::SET_FPENV_MEM: return visitSET_FPENV_MEM(N);
@@ -26256,14 +26258,17 @@ SDValue DAGCombiner::visitFP_TO_FP16(SDNode *N) {
2625626258
}
2625726259

2625826260
SDValue DAGCombiner::visitFP16_TO_FP(SDNode *N) {
26261+
auto Op = N->getOpcode();
26262+
assert((Op == ISD::FP16_TO_FP || Op == ISD::BF16_TO_FP) &&
26263+
"opcode should be FP16_TO_FP or BF16_TO_FP.");
2625926264
SDValue N0 = N->getOperand(0);
2626026265

26261-
// fold fp16_to_fp(op & 0xffff) -> fp16_to_fp(op)
26266+
// fold fp16_to_fp(op & 0xffff) -> fp16_to_fp(op) or
26267+
// fold bf16_to_fp(op & 0xffff) -> bf16_to_fp(op)
2626226268
if (!TLI.shouldKeepZExtForFP16Conv() && N0->getOpcode() == ISD::AND) {
2626326269
ConstantSDNode *AndConst = getAsNonOpaqueConstant(N0.getOperand(1));
2626426270
if (AndConst && AndConst->getAPIntValue() == 0xffff) {
26265-
return DAG.getNode(ISD::FP16_TO_FP, SDLoc(N), N->getValueType(0),
26266-
N0.getOperand(0));
26271+
return DAG.getNode(Op, SDLoc(N), N->getValueType(0), N0.getOperand(0));
2626726272
}
2626826273
}
2626926274

@@ -26280,6 +26285,11 @@ SDValue DAGCombiner::visitFP_TO_BF16(SDNode *N) {
2628026285
return SDValue();
2628126286
}
2628226287

26288+
SDValue DAGCombiner::visitBF16_TO_FP(SDNode *N) {
26289+
// fold bf16_to_fp(op & 0xffff) -> bf16_to_fp(op)
26290+
return visitFP16_TO_FP(N);
26291+
}
26292+
2628326293
SDValue DAGCombiner::visitVECREDUCE(SDNode *N) {
2628426294
SDValue N0 = N->getOperand(0);
2628526295
EVT VT = N0.getValueType();

llvm/test/CodeGen/RISCV/bfloat-convert.ll

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,6 @@ define i16 @fcvt_si_bf16(bfloat %a) nounwind {
3939
; RV64ID-LABEL: fcvt_si_bf16:
4040
; RV64ID: # %bb.0:
4141
; RV64ID-NEXT: fmv.x.w a0, fa0
42-
; RV64ID-NEXT: slli a0, a0, 48
43-
; RV64ID-NEXT: srli a0, a0, 48
4442
; RV64ID-NEXT: slli a0, a0, 16
4543
; RV64ID-NEXT: fmv.w.x fa5, a0
4644
; RV64ID-NEXT: fcvt.l.s a0, fa5, rtz
@@ -100,8 +98,6 @@ define i16 @fcvt_si_bf16_sat(bfloat %a) nounwind {
10098
; RV64ID-LABEL: fcvt_si_bf16_sat:
10199
; RV64ID: # %bb.0: # %start
102100
; RV64ID-NEXT: fmv.x.w a0, fa0
103-
; RV64ID-NEXT: slli a0, a0, 48
104-
; RV64ID-NEXT: srli a0, a0, 48
105101
; RV64ID-NEXT: slli a0, a0, 16
106102
; RV64ID-NEXT: fmv.w.x fa5, a0
107103
; RV64ID-NEXT: feq.s a0, fa5, fa5
@@ -145,8 +141,6 @@ define i16 @fcvt_ui_bf16(bfloat %a) nounwind {
145141
; RV64ID-LABEL: fcvt_ui_bf16:
146142
; RV64ID: # %bb.0:
147143
; RV64ID-NEXT: fmv.x.w a0, fa0
148-
; RV64ID-NEXT: slli a0, a0, 48
149-
; RV64ID-NEXT: srli a0, a0, 48
150144
; RV64ID-NEXT: slli a0, a0, 16
151145
; RV64ID-NEXT: fmv.w.x fa5, a0
152146
; RV64ID-NEXT: fcvt.lu.s a0, fa5, rtz
@@ -196,8 +190,6 @@ define i16 @fcvt_ui_bf16_sat(bfloat %a) nounwind {
196190
; RV64ID-NEXT: lui a0, %hi(.LCPI3_0)
197191
; RV64ID-NEXT: flw fa5, %lo(.LCPI3_0)(a0)
198192
; RV64ID-NEXT: fmv.x.w a0, fa0
199-
; RV64ID-NEXT: slli a0, a0, 48
200-
; RV64ID-NEXT: srli a0, a0, 48
201193
; RV64ID-NEXT: slli a0, a0, 16
202194
; RV64ID-NEXT: fmv.w.x fa4, a0
203195
; RV64ID-NEXT: fmv.w.x fa3, zero
@@ -235,8 +227,6 @@ define i32 @fcvt_w_bf16(bfloat %a) nounwind {
235227
; RV64ID-LABEL: fcvt_w_bf16:
236228
; RV64ID: # %bb.0:
237229
; RV64ID-NEXT: fmv.x.w a0, fa0
238-
; RV64ID-NEXT: slli a0, a0, 48
239-
; RV64ID-NEXT: srli a0, a0, 48
240230
; RV64ID-NEXT: slli a0, a0, 16
241231
; RV64ID-NEXT: fmv.w.x fa5, a0
242232
; RV64ID-NEXT: fcvt.l.s a0, fa5, rtz
@@ -281,8 +271,6 @@ define i32 @fcvt_w_bf16_sat(bfloat %a) nounwind {
281271
; RV64ID-LABEL: fcvt_w_bf16_sat:
282272
; RV64ID: # %bb.0: # %start
283273
; RV64ID-NEXT: fmv.x.w a0, fa0
284-
; RV64ID-NEXT: slli a0, a0, 48
285-
; RV64ID-NEXT: srli a0, a0, 48
286274
; RV64ID-NEXT: slli a0, a0, 16
287275
; RV64ID-NEXT: fmv.w.x fa5, a0
288276
; RV64ID-NEXT: fcvt.w.s a0, fa5, rtz
@@ -321,8 +309,6 @@ define i32 @fcvt_wu_bf16(bfloat %a) nounwind {
321309
; RV64ID-LABEL: fcvt_wu_bf16:
322310
; RV64ID: # %bb.0:
323311
; RV64ID-NEXT: fmv.x.w a0, fa0
324-
; RV64ID-NEXT: slli a0, a0, 48
325-
; RV64ID-NEXT: srli a0, a0, 48
326312
; RV64ID-NEXT: slli a0, a0, 16
327313
; RV64ID-NEXT: fmv.w.x fa5, a0
328314
; RV64ID-NEXT: fcvt.lu.s a0, fa5, rtz
@@ -361,8 +347,6 @@ define i32 @fcvt_wu_bf16_multiple_use(bfloat %x, ptr %y) nounwind {
361347
; RV64ID-LABEL: fcvt_wu_bf16_multiple_use:
362348
; RV64ID: # %bb.0:
363349
; RV64ID-NEXT: fmv.x.w a0, fa0
364-
; RV64ID-NEXT: slli a0, a0, 48
365-
; RV64ID-NEXT: srli a0, a0, 48
366350
; RV64ID-NEXT: slli a0, a0, 16
367351
; RV64ID-NEXT: fmv.w.x fa5, a0
368352
; RV64ID-NEXT: fcvt.lu.s a0, fa5, rtz
@@ -413,8 +397,6 @@ define i32 @fcvt_wu_bf16_sat(bfloat %a) nounwind {
413397
; RV64ID-LABEL: fcvt_wu_bf16_sat:
414398
; RV64ID: # %bb.0: # %start
415399
; RV64ID-NEXT: fmv.x.w a0, fa0
416-
; RV64ID-NEXT: slli a0, a0, 48
417-
; RV64ID-NEXT: srli a0, a0, 48
418400
; RV64ID-NEXT: slli a0, a0, 16
419401
; RV64ID-NEXT: fmv.w.x fa5, a0
420402
; RV64ID-NEXT: fcvt.wu.s a0, fa5, rtz
@@ -463,8 +445,6 @@ define i64 @fcvt_l_bf16(bfloat %a) nounwind {
463445
; RV64ID-LABEL: fcvt_l_bf16:
464446
; RV64ID: # %bb.0:
465447
; RV64ID-NEXT: fmv.x.w a0, fa0
466-
; RV64ID-NEXT: slli a0, a0, 48
467-
; RV64ID-NEXT: srli a0, a0, 48
468448
; RV64ID-NEXT: slli a0, a0, 16
469449
; RV64ID-NEXT: fmv.w.x fa5, a0
470450
; RV64ID-NEXT: fcvt.l.s a0, fa5, rtz
@@ -606,8 +586,6 @@ define i64 @fcvt_l_bf16_sat(bfloat %a) nounwind {
606586
; RV64ID-LABEL: fcvt_l_bf16_sat:
607587
; RV64ID: # %bb.0: # %start
608588
; RV64ID-NEXT: fmv.x.w a0, fa0
609-
; RV64ID-NEXT: slli a0, a0, 48
610-
; RV64ID-NEXT: srli a0, a0, 48
611589
; RV64ID-NEXT: slli a0, a0, 16
612590
; RV64ID-NEXT: fmv.w.x fa5, a0
613591
; RV64ID-NEXT: fcvt.l.s a0, fa5, rtz
@@ -654,8 +632,6 @@ define i64 @fcvt_lu_bf16(bfloat %a) nounwind {
654632
; RV64ID-LABEL: fcvt_lu_bf16:
655633
; RV64ID: # %bb.0:
656634
; RV64ID-NEXT: fmv.x.w a0, fa0
657-
; RV64ID-NEXT: slli a0, a0, 48
658-
; RV64ID-NEXT: srli a0, a0, 48
659635
; RV64ID-NEXT: slli a0, a0, 16
660636
; RV64ID-NEXT: fmv.w.x fa5, a0
661637
; RV64ID-NEXT: fcvt.lu.s a0, fa5, rtz
@@ -730,8 +706,6 @@ define i64 @fcvt_lu_bf16_sat(bfloat %a) nounwind {
730706
; RV64ID-LABEL: fcvt_lu_bf16_sat:
731707
; RV64ID: # %bb.0: # %start
732708
; RV64ID-NEXT: fmv.x.w a0, fa0
733-
; RV64ID-NEXT: slli a0, a0, 48
734-
; RV64ID-NEXT: srli a0, a0, 48
735709
; RV64ID-NEXT: slli a0, a0, 16
736710
; RV64ID-NEXT: fmv.w.x fa5, a0
737711
; RV64ID-NEXT: fcvt.lu.s a0, fa5, rtz
@@ -1200,8 +1174,6 @@ define float @fcvt_s_bf16(bfloat %a) nounwind {
12001174
; RV64ID-LABEL: fcvt_s_bf16:
12011175
; RV64ID: # %bb.0:
12021176
; RV64ID-NEXT: fmv.x.w a0, fa0
1203-
; RV64ID-NEXT: slli a0, a0, 48
1204-
; RV64ID-NEXT: srli a0, a0, 48
12051177
; RV64ID-NEXT: slli a0, a0, 16
12061178
; RV64ID-NEXT: fmv.w.x fa0, a0
12071179
; RV64ID-NEXT: ret
@@ -1313,8 +1285,6 @@ define double @fcvt_d_bf16(bfloat %a) nounwind {
13131285
; RV64ID-LABEL: fcvt_d_bf16:
13141286
; RV64ID: # %bb.0:
13151287
; RV64ID-NEXT: fmv.x.w a0, fa0
1316-
; RV64ID-NEXT: slli a0, a0, 48
1317-
; RV64ID-NEXT: srli a0, a0, 48
13181288
; RV64ID-NEXT: slli a0, a0, 16
13191289
; RV64ID-NEXT: fmv.w.x fa5, a0
13201290
; RV64ID-NEXT: fcvt.d.s fa0, fa5
@@ -1521,8 +1491,6 @@ define signext i8 @fcvt_w_s_i8(bfloat %a) nounwind {
15211491
; RV64ID-LABEL: fcvt_w_s_i8:
15221492
; RV64ID: # %bb.0:
15231493
; RV64ID-NEXT: fmv.x.w a0, fa0
1524-
; RV64ID-NEXT: slli a0, a0, 48
1525-
; RV64ID-NEXT: srli a0, a0, 48
15261494
; RV64ID-NEXT: slli a0, a0, 16
15271495
; RV64ID-NEXT: fmv.w.x fa5, a0
15281496
; RV64ID-NEXT: fcvt.l.s a0, fa5, rtz
@@ -1582,8 +1550,6 @@ define signext i8 @fcvt_w_s_sat_i8(bfloat %a) nounwind {
15821550
; RV64ID-LABEL: fcvt_w_s_sat_i8:
15831551
; RV64ID: # %bb.0: # %start
15841552
; RV64ID-NEXT: fmv.x.w a0, fa0
1585-
; RV64ID-NEXT: slli a0, a0, 48
1586-
; RV64ID-NEXT: srli a0, a0, 48
15871553
; RV64ID-NEXT: slli a0, a0, 16
15881554
; RV64ID-NEXT: fmv.w.x fa5, a0
15891555
; RV64ID-NEXT: feq.s a0, fa5, fa5
@@ -1627,8 +1593,6 @@ define zeroext i8 @fcvt_wu_s_i8(bfloat %a) nounwind {
16271593
; RV64ID-LABEL: fcvt_wu_s_i8:
16281594
; RV64ID: # %bb.0:
16291595
; RV64ID-NEXT: fmv.x.w a0, fa0
1630-
; RV64ID-NEXT: slli a0, a0, 48
1631-
; RV64ID-NEXT: srli a0, a0, 48
16321596
; RV64ID-NEXT: slli a0, a0, 16
16331597
; RV64ID-NEXT: fmv.w.x fa5, a0
16341598
; RV64ID-NEXT: fcvt.lu.s a0, fa5, rtz
@@ -1676,8 +1640,6 @@ define zeroext i8 @fcvt_wu_s_sat_i8(bfloat %a) nounwind {
16761640
; RV64ID-LABEL: fcvt_wu_s_sat_i8:
16771641
; RV64ID: # %bb.0: # %start
16781642
; RV64ID-NEXT: fmv.x.w a0, fa0
1679-
; RV64ID-NEXT: slli a0, a0, 48
1680-
; RV64ID-NEXT: srli a0, a0, 48
16811643
; RV64ID-NEXT: slli a0, a0, 16
16821644
; RV64ID-NEXT: fmv.w.x fa5, a0
16831645
; RV64ID-NEXT: fmv.w.x fa4, zero
@@ -1731,8 +1693,6 @@ define zeroext i32 @fcvt_wu_bf16_sat_zext(bfloat %a) nounwind {
17311693
; RV64ID-LABEL: fcvt_wu_bf16_sat_zext:
17321694
; RV64ID: # %bb.0: # %start
17331695
; RV64ID-NEXT: fmv.x.w a0, fa0
1734-
; RV64ID-NEXT: slli a0, a0, 48
1735-
; RV64ID-NEXT: srli a0, a0, 48
17361696
; RV64ID-NEXT: slli a0, a0, 16
17371697
; RV64ID-NEXT: fmv.w.x fa5, a0
17381698
; RV64ID-NEXT: fcvt.wu.s a0, fa5, rtz
@@ -1784,8 +1744,6 @@ define signext i32 @fcvt_w_bf16_sat_sext(bfloat %a) nounwind {
17841744
; RV64ID-LABEL: fcvt_w_bf16_sat_sext:
17851745
; RV64ID: # %bb.0: # %start
17861746
; RV64ID-NEXT: fmv.x.w a0, fa0
1787-
; RV64ID-NEXT: slli a0, a0, 48
1788-
; RV64ID-NEXT: srli a0, a0, 48
17891747
; RV64ID-NEXT: slli a0, a0, 16
17901748
; RV64ID-NEXT: fmv.w.x fa5, a0
17911749
; RV64ID-NEXT: fcvt.w.s a0, fa5, rtz

llvm/test/CodeGen/RISCV/bfloat.ll

Lines changed: 4 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,6 @@ define float @bfloat_to_float(bfloat %a) nounwind {
164164
;
165165
; RV64ID-LP64-LABEL: bfloat_to_float:
166166
; RV64ID-LP64: # %bb.0:
167-
; RV64ID-LP64-NEXT: slli a0, a0, 48
168-
; RV64ID-LP64-NEXT: srli a0, a0, 48
169167
; RV64ID-LP64-NEXT: slli a0, a0, 16
170168
; RV64ID-LP64-NEXT: ret
171169
;
@@ -179,8 +177,6 @@ define float @bfloat_to_float(bfloat %a) nounwind {
179177
; RV64ID-LP64D-LABEL: bfloat_to_float:
180178
; RV64ID-LP64D: # %bb.0:
181179
; RV64ID-LP64D-NEXT: fmv.x.w a0, fa0
182-
; RV64ID-LP64D-NEXT: slli a0, a0, 48
183-
; RV64ID-LP64D-NEXT: srli a0, a0, 48
184180
; RV64ID-LP64D-NEXT: slli a0, a0, 16
185181
; RV64ID-LP64D-NEXT: fmv.w.x fa0, a0
186182
; RV64ID-LP64D-NEXT: ret
@@ -223,8 +219,6 @@ define double @bfloat_to_double(bfloat %a) nounwind {
223219
;
224220
; RV64ID-LP64-LABEL: bfloat_to_double:
225221
; RV64ID-LP64: # %bb.0:
226-
; RV64ID-LP64-NEXT: slli a0, a0, 48
227-
; RV64ID-LP64-NEXT: srli a0, a0, 48
228222
; RV64ID-LP64-NEXT: slli a0, a0, 16
229223
; RV64ID-LP64-NEXT: fmv.w.x fa5, a0
230224
; RV64ID-LP64-NEXT: fcvt.d.s fa5, fa5
@@ -242,8 +236,6 @@ define double @bfloat_to_double(bfloat %a) nounwind {
242236
; RV64ID-LP64D-LABEL: bfloat_to_double:
243237
; RV64ID-LP64D: # %bb.0:
244238
; RV64ID-LP64D-NEXT: fmv.x.w a0, fa0
245-
; RV64ID-LP64D-NEXT: slli a0, a0, 48
246-
; RV64ID-LP64D-NEXT: srli a0, a0, 48
247239
; RV64ID-LP64D-NEXT: slli a0, a0, 16
248240
; RV64ID-LP64D-NEXT: fmv.w.x fa5, a0
249241
; RV64ID-LP64D-NEXT: fcvt.d.s fa0, fa5
@@ -366,10 +358,6 @@ define bfloat @bfloat_add(bfloat %a, bfloat %b) nounwind {
366358
; RV64ID-LP64: # %bb.0:
367359
; RV64ID-LP64-NEXT: addi sp, sp, -16
368360
; RV64ID-LP64-NEXT: sd ra, 8(sp) # 8-byte Folded Spill
369-
; RV64ID-LP64-NEXT: lui a2, 16
370-
; RV64ID-LP64-NEXT: addi a2, a2, -1
371-
; RV64ID-LP64-NEXT: and a0, a0, a2
372-
; RV64ID-LP64-NEXT: and a1, a1, a2
373361
; RV64ID-LP64-NEXT: slli a1, a1, 16
374362
; RV64ID-LP64-NEXT: fmv.w.x fa5, a1
375363
; RV64ID-LP64-NEXT: slli a0, a0, 16
@@ -408,11 +396,7 @@ define bfloat @bfloat_add(bfloat %a, bfloat %b) nounwind {
408396
; RV64ID-LP64D-NEXT: addi sp, sp, -16
409397
; RV64ID-LP64D-NEXT: sd ra, 8(sp) # 8-byte Folded Spill
410398
; RV64ID-LP64D-NEXT: fmv.x.w a0, fa0
411-
; RV64ID-LP64D-NEXT: lui a1, 16
412-
; RV64ID-LP64D-NEXT: addi a1, a1, -1
413-
; RV64ID-LP64D-NEXT: and a0, a0, a1
414-
; RV64ID-LP64D-NEXT: fmv.x.w a2, fa1
415-
; RV64ID-LP64D-NEXT: and a1, a2, a1
399+
; RV64ID-LP64D-NEXT: fmv.x.w a1, fa1
416400
; RV64ID-LP64D-NEXT: slli a1, a1, 16
417401
; RV64ID-LP64D-NEXT: fmv.w.x fa5, a1
418402
; RV64ID-LP64D-NEXT: slli a0, a0, 16
@@ -604,12 +588,8 @@ define void @bfloat_store(ptr %a, bfloat %b, bfloat %c) nounwind {
604588
; RV64ID-LP64-NEXT: sd ra, 8(sp) # 8-byte Folded Spill
605589
; RV64ID-LP64-NEXT: sd s0, 0(sp) # 8-byte Folded Spill
606590
; RV64ID-LP64-NEXT: mv s0, a0
607-
; RV64ID-LP64-NEXT: lui a0, 16
608-
; RV64ID-LP64-NEXT: addi a0, a0, -1
609-
; RV64ID-LP64-NEXT: and a1, a1, a0
610-
; RV64ID-LP64-NEXT: and a0, a2, a0
611-
; RV64ID-LP64-NEXT: slli a0, a0, 16
612-
; RV64ID-LP64-NEXT: fmv.w.x fa5, a0
591+
; RV64ID-LP64-NEXT: slli a2, a2, 16
592+
; RV64ID-LP64-NEXT: fmv.w.x fa5, a2
613593
; RV64ID-LP64-NEXT: slli a1, a1, 16
614594
; RV64ID-LP64-NEXT: fmv.w.x fa4, a1
615595
; RV64ID-LP64-NEXT: fadd.s fa5, fa4, fa5
@@ -651,11 +631,7 @@ define void @bfloat_store(ptr %a, bfloat %b, bfloat %c) nounwind {
651631
; RV64ID-LP64D-NEXT: sd s0, 0(sp) # 8-byte Folded Spill
652632
; RV64ID-LP64D-NEXT: mv s0, a0
653633
; RV64ID-LP64D-NEXT: fmv.x.w a0, fa0
654-
; RV64ID-LP64D-NEXT: lui a1, 16
655-
; RV64ID-LP64D-NEXT: addi a1, a1, -1
656-
; RV64ID-LP64D-NEXT: and a0, a0, a1
657-
; RV64ID-LP64D-NEXT: fmv.x.w a2, fa1
658-
; RV64ID-LP64D-NEXT: and a1, a2, a1
634+
; RV64ID-LP64D-NEXT: fmv.x.w a1, fa1
659635
; RV64ID-LP64D-NEXT: slli a1, a1, 16
660636
; RV64ID-LP64D-NEXT: fmv.w.x fa5, a1
661637
; RV64ID-LP64D-NEXT: slli a0, a0, 16

0 commit comments

Comments
 (0)