Skip to content

Commit 6c6903d

Browse files
committed
[NVPTX] Add folding for cvt.rn.bf16x2.f32 (#116109)
1 parent 5587627 commit 6c6903d

File tree

4 files changed

+236
-113
lines changed

4 files changed

+236
-113
lines changed

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -727,6 +727,20 @@ let hasSideEffects = false in {
727727
def CVT_f16x2_e5m2x2 : CVT_f16x2_fp8<"e5m2">;
728728
}
729729

730+
def fpround_oneuse : PatFrag<(ops node:$a), (fpround node:$a), [{
731+
return N->hasOneUse();
732+
}]>;
733+
734+
def : Pat<(v2bf16 (build_vector (bf16 (fpround_oneuse Float32Regs:$a)),
735+
(bf16 (fpround_oneuse Float32Regs:$b)))),
736+
(CVT_bf16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>,
737+
Requires<[hasPTX<70>, hasSM<80>, hasBF16Math]>;
738+
739+
def : Pat<(v2f16 (build_vector (f16 (fpround_oneuse Float32Regs:$a)),
740+
(f16 (fpround_oneuse Float32Regs:$b)))),
741+
(CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>,
742+
Requires<[hasPTX<70>, hasSM<80>, useFP16Math]>;
743+
730744
//-----------------------------------
731745
// Selection instructions (selp)
732746
//-----------------------------------

llvm/test/CodeGen/NVPTX/bf16-instructions.ll

Lines changed: 18 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
204204
;
205205
; SM80-LABEL: test_faddx2(
206206
; SM80: {
207-
; SM80-NEXT: .reg .b16 %rs<7>;
207+
; SM80-NEXT: .reg .b16 %rs<5>;
208208
; SM80-NEXT: .reg .b32 %r<4>;
209209
; SM80-NEXT: .reg .f32 %f<7>;
210210
; SM80-EMPTY:
@@ -216,18 +216,16 @@ define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
216216
; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1;
217217
; SM80-NEXT: cvt.f32.bf16 %f2, %rs4;
218218
; SM80-NEXT: add.rn.f32 %f3, %f2, %f1;
219-
; SM80-NEXT: cvt.rn.bf16.f32 %rs5, %f3;
220219
; SM80-NEXT: cvt.f32.bf16 %f4, %rs1;
221220
; SM80-NEXT: cvt.f32.bf16 %f5, %rs3;
222221
; SM80-NEXT: add.rn.f32 %f6, %f5, %f4;
223-
; SM80-NEXT: cvt.rn.bf16.f32 %rs6, %f6;
224-
; SM80-NEXT: mov.b32 %r3, {%rs6, %rs5};
222+
; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
225223
; SM80-NEXT: st.param.b32 [func_retval0], %r3;
226224
; SM80-NEXT: ret;
227225
;
228226
; SM80-FTZ-LABEL: test_faddx2(
229227
; SM80-FTZ: {
230-
; SM80-FTZ-NEXT: .reg .b16 %rs<7>;
228+
; SM80-FTZ-NEXT: .reg .b16 %rs<5>;
231229
; SM80-FTZ-NEXT: .reg .b32 %r<4>;
232230
; SM80-FTZ-NEXT: .reg .f32 %f<7>;
233231
; SM80-FTZ-EMPTY:
@@ -239,12 +237,10 @@ define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
239237
; SM80-FTZ-NEXT: mov.b32 {%rs3, %rs4}, %r1;
240238
; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f2, %rs4;
241239
; SM80-FTZ-NEXT: add.rn.ftz.f32 %f3, %f2, %f1;
242-
; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs5, %f3;
243240
; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f4, %rs1;
244241
; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f5, %rs3;
245242
; SM80-FTZ-NEXT: add.rn.ftz.f32 %f6, %f5, %f4;
246-
; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs6, %f6;
247-
; SM80-FTZ-NEXT: mov.b32 %r3, {%rs6, %rs5};
243+
; SM80-FTZ-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
248244
; SM80-FTZ-NEXT: st.param.b32 [func_retval0], %r3;
249245
; SM80-FTZ-NEXT: ret;
250246
;
@@ -311,7 +307,7 @@ define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
311307
;
312308
; SM80-LABEL: test_fsubx2(
313309
; SM80: {
314-
; SM80-NEXT: .reg .b16 %rs<7>;
310+
; SM80-NEXT: .reg .b16 %rs<5>;
315311
; SM80-NEXT: .reg .b32 %r<4>;
316312
; SM80-NEXT: .reg .f32 %f<7>;
317313
; SM80-EMPTY:
@@ -323,18 +319,16 @@ define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
323319
; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1;
324320
; SM80-NEXT: cvt.f32.bf16 %f2, %rs4;
325321
; SM80-NEXT: sub.rn.f32 %f3, %f2, %f1;
326-
; SM80-NEXT: cvt.rn.bf16.f32 %rs5, %f3;
327322
; SM80-NEXT: cvt.f32.bf16 %f4, %rs1;
328323
; SM80-NEXT: cvt.f32.bf16 %f5, %rs3;
329324
; SM80-NEXT: sub.rn.f32 %f6, %f5, %f4;
330-
; SM80-NEXT: cvt.rn.bf16.f32 %rs6, %f6;
331-
; SM80-NEXT: mov.b32 %r3, {%rs6, %rs5};
325+
; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
332326
; SM80-NEXT: st.param.b32 [func_retval0], %r3;
333327
; SM80-NEXT: ret;
334328
;
335329
; SM80-FTZ-LABEL: test_fsubx2(
336330
; SM80-FTZ: {
337-
; SM80-FTZ-NEXT: .reg .b16 %rs<7>;
331+
; SM80-FTZ-NEXT: .reg .b16 %rs<5>;
338332
; SM80-FTZ-NEXT: .reg .b32 %r<4>;
339333
; SM80-FTZ-NEXT: .reg .f32 %f<7>;
340334
; SM80-FTZ-EMPTY:
@@ -346,12 +340,10 @@ define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
346340
; SM80-FTZ-NEXT: mov.b32 {%rs3, %rs4}, %r1;
347341
; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f2, %rs4;
348342
; SM80-FTZ-NEXT: sub.rn.ftz.f32 %f3, %f2, %f1;
349-
; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs5, %f3;
350343
; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f4, %rs1;
351344
; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f5, %rs3;
352345
; SM80-FTZ-NEXT: sub.rn.ftz.f32 %f6, %f5, %f4;
353-
; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs6, %f6;
354-
; SM80-FTZ-NEXT: mov.b32 %r3, {%rs6, %rs5};
346+
; SM80-FTZ-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
355347
; SM80-FTZ-NEXT: st.param.b32 [func_retval0], %r3;
356348
; SM80-FTZ-NEXT: ret;
357349
;
@@ -418,7 +410,7 @@ define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
418410
;
419411
; SM80-LABEL: test_fmulx2(
420412
; SM80: {
421-
; SM80-NEXT: .reg .b16 %rs<7>;
413+
; SM80-NEXT: .reg .b16 %rs<5>;
422414
; SM80-NEXT: .reg .b32 %r<4>;
423415
; SM80-NEXT: .reg .f32 %f<7>;
424416
; SM80-EMPTY:
@@ -430,18 +422,16 @@ define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
430422
; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1;
431423
; SM80-NEXT: cvt.f32.bf16 %f2, %rs4;
432424
; SM80-NEXT: mul.rn.f32 %f3, %f2, %f1;
433-
; SM80-NEXT: cvt.rn.bf16.f32 %rs5, %f3;
434425
; SM80-NEXT: cvt.f32.bf16 %f4, %rs1;
435426
; SM80-NEXT: cvt.f32.bf16 %f5, %rs3;
436427
; SM80-NEXT: mul.rn.f32 %f6, %f5, %f4;
437-
; SM80-NEXT: cvt.rn.bf16.f32 %rs6, %f6;
438-
; SM80-NEXT: mov.b32 %r3, {%rs6, %rs5};
428+
; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
439429
; SM80-NEXT: st.param.b32 [func_retval0], %r3;
440430
; SM80-NEXT: ret;
441431
;
442432
; SM80-FTZ-LABEL: test_fmulx2(
443433
; SM80-FTZ: {
444-
; SM80-FTZ-NEXT: .reg .b16 %rs<7>;
434+
; SM80-FTZ-NEXT: .reg .b16 %rs<5>;
445435
; SM80-FTZ-NEXT: .reg .b32 %r<4>;
446436
; SM80-FTZ-NEXT: .reg .f32 %f<7>;
447437
; SM80-FTZ-EMPTY:
@@ -453,12 +443,10 @@ define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
453443
; SM80-FTZ-NEXT: mov.b32 {%rs3, %rs4}, %r1;
454444
; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f2, %rs4;
455445
; SM80-FTZ-NEXT: mul.rn.ftz.f32 %f3, %f2, %f1;
456-
; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs5, %f3;
457446
; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f4, %rs1;
458447
; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f5, %rs3;
459448
; SM80-FTZ-NEXT: mul.rn.ftz.f32 %f6, %f5, %f4;
460-
; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs6, %f6;
461-
; SM80-FTZ-NEXT: mov.b32 %r3, {%rs6, %rs5};
449+
; SM80-FTZ-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
462450
; SM80-FTZ-NEXT: st.param.b32 [func_retval0], %r3;
463451
; SM80-FTZ-NEXT: ret;
464452
;
@@ -525,7 +513,7 @@ define <2 x bfloat> @test_fdiv(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
525513
;
526514
; SM80-LABEL: test_fdiv(
527515
; SM80: {
528-
; SM80-NEXT: .reg .b16 %rs<7>;
516+
; SM80-NEXT: .reg .b16 %rs<5>;
529517
; SM80-NEXT: .reg .b32 %r<4>;
530518
; SM80-NEXT: .reg .f32 %f<7>;
531519
; SM80-EMPTY:
@@ -537,18 +525,16 @@ define <2 x bfloat> @test_fdiv(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
537525
; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1;
538526
; SM80-NEXT: cvt.f32.bf16 %f2, %rs4;
539527
; SM80-NEXT: div.rn.f32 %f3, %f2, %f1;
540-
; SM80-NEXT: cvt.rn.bf16.f32 %rs5, %f3;
541528
; SM80-NEXT: cvt.f32.bf16 %f4, %rs1;
542529
; SM80-NEXT: cvt.f32.bf16 %f5, %rs3;
543530
; SM80-NEXT: div.rn.f32 %f6, %f5, %f4;
544-
; SM80-NEXT: cvt.rn.bf16.f32 %rs6, %f6;
545-
; SM80-NEXT: mov.b32 %r3, {%rs6, %rs5};
531+
; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
546532
; SM80-NEXT: st.param.b32 [func_retval0], %r3;
547533
; SM80-NEXT: ret;
548534
;
549535
; SM80-FTZ-LABEL: test_fdiv(
550536
; SM80-FTZ: {
551-
; SM80-FTZ-NEXT: .reg .b16 %rs<7>;
537+
; SM80-FTZ-NEXT: .reg .b16 %rs<5>;
552538
; SM80-FTZ-NEXT: .reg .b32 %r<4>;
553539
; SM80-FTZ-NEXT: .reg .f32 %f<7>;
554540
; SM80-FTZ-EMPTY:
@@ -560,18 +546,16 @@ define <2 x bfloat> @test_fdiv(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
560546
; SM80-FTZ-NEXT: mov.b32 {%rs3, %rs4}, %r1;
561547
; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f2, %rs4;
562548
; SM80-FTZ-NEXT: div.rn.ftz.f32 %f3, %f2, %f1;
563-
; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs5, %f3;
564549
; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f4, %rs1;
565550
; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f5, %rs3;
566551
; SM80-FTZ-NEXT: div.rn.ftz.f32 %f6, %f5, %f4;
567-
; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs6, %f6;
568-
; SM80-FTZ-NEXT: mov.b32 %r3, {%rs6, %rs5};
552+
; SM80-FTZ-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
569553
; SM80-FTZ-NEXT: st.param.b32 [func_retval0], %r3;
570554
; SM80-FTZ-NEXT: ret;
571555
;
572556
; SM90-LABEL: test_fdiv(
573557
; SM90: {
574-
; SM90-NEXT: .reg .b16 %rs<7>;
558+
; SM90-NEXT: .reg .b16 %rs<5>;
575559
; SM90-NEXT: .reg .b32 %r<4>;
576560
; SM90-NEXT: .reg .f32 %f<7>;
577561
; SM90-EMPTY:
@@ -583,12 +567,10 @@ define <2 x bfloat> @test_fdiv(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
583567
; SM90-NEXT: mov.b32 {%rs3, %rs4}, %r1;
584568
; SM90-NEXT: cvt.f32.bf16 %f2, %rs4;
585569
; SM90-NEXT: div.rn.f32 %f3, %f2, %f1;
586-
; SM90-NEXT: cvt.rn.bf16.f32 %rs5, %f3;
587570
; SM90-NEXT: cvt.f32.bf16 %f4, %rs1;
588571
; SM90-NEXT: cvt.f32.bf16 %f5, %rs3;
589572
; SM90-NEXT: div.rn.f32 %f6, %f5, %f4;
590-
; SM90-NEXT: cvt.rn.bf16.f32 %rs6, %f6;
591-
; SM90-NEXT: mov.b32 %r3, {%rs6, %rs5};
573+
; SM90-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
592574
; SM90-NEXT: st.param.b32 [func_retval0], %r3;
593575
; SM90-NEXT: ret;
594576
%r = fdiv <2 x bfloat> %a, %b

llvm/test/CodeGen/NVPTX/bf16x2-instructions-approx.ll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ target datalayout = "e-m:o-i64:64-i128:128-n32:64-S128"
77
declare <2 x bfloat> @llvm.sin.f16(<2 x bfloat> %a) #0
88
declare <2 x bfloat> @llvm.cos.f16(<2 x bfloat> %a) #0
99

10+
1011
define <2 x bfloat> @test_sin(<2 x bfloat> %a) #0 #1 {
1112
; CHECK-LABEL: test_sin(
1213
; CHECK: {

0 commit comments

Comments
 (0)