Skip to content

[NVPTX] Add folding for cvt.rn.bf16x2.f32 #116109

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 14, 2024

Conversation

AlexMaclean
Copy link
Member

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Nov 13, 2024

@llvm/pr-subscribers-backend-nvptx

Author: Alex MacLean (AlexMaclean)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/116109.diff

5 Files Affected:

  • (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+14)
  • (modified) llvm/test/CodeGen/NVPTX/bf16-instructions.ll (+18-36)
  • (modified) llvm/test/CodeGen/NVPTX/bf16x2-instructions-approx.ll (+2-6)
  • (modified) llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll (+10-26)
  • (modified) llvm/test/CodeGen/NVPTX/convert-sm80.ll (+21)
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 2658ca32716378..b58918bc75df72 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -739,6 +739,20 @@ let hasSideEffects = false in {
   def CVT_f16x2_e5m2x2 : CVT_f16x2_fp8<"e5m2">;
 }
 
+def fpround_oneuse : PatFrag<(ops node:$a), (fpround node:$a), [{
+  return N->hasOneUse();
+}]>;
+
+def : Pat<(v2bf16 (build_vector (bf16 (fpround_oneuse Float32Regs:$a)),
+                                (bf16 (fpround_oneuse Float32Regs:$b)))),
+          (CVT_bf16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>,
+      Requires<[hasPTX<70>, hasSM<80>, hasBF16Math]>;
+
+def : Pat<(v2f16 (build_vector (f16 (fpround_oneuse Float32Regs:$a)),
+                               (f16 (fpround_oneuse Float32Regs:$b)))),
+          (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>,
+      Requires<[hasPTX<70>, hasSM<80>, useFP16Math]>;
+
 //-----------------------------------
 // Selection instructions (selp)
 //-----------------------------------
diff --git a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
index 80815b3ca37c05..ded3019e173b4c 100644
--- a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
@@ -204,7 +204,7 @@ define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ;
 ; SM80-LABEL: test_faddx2(
 ; SM80:       {
-; SM80-NEXT:    .reg .b16 %rs<7>;
+; SM80-NEXT:    .reg .b16 %rs<5>;
 ; SM80-NEXT:    .reg .b32 %r<4>;
 ; SM80-NEXT:    .reg .f32 %f<7>;
 ; SM80-EMPTY:
@@ -216,18 +216,16 @@ define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM80-NEXT:    mov.b32 {%rs3, %rs4}, %r1;
 ; SM80-NEXT:    cvt.f32.bf16 %f2, %rs4;
 ; SM80-NEXT:    add.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT:    cvt.rn.bf16.f32 %rs5, %f3;
 ; SM80-NEXT:    cvt.f32.bf16 %f4, %rs1;
 ; SM80-NEXT:    cvt.f32.bf16 %f5, %rs3;
 ; SM80-NEXT:    add.rn.f32 %f6, %f5, %f4;
-; SM80-NEXT:    cvt.rn.bf16.f32 %rs6, %f6;
-; SM80-NEXT:    mov.b32 %r3, {%rs6, %rs5};
+; SM80-NEXT:    cvt.rn.bf16x2.f32 %r3, %f6, %f3;
 ; SM80-NEXT:    st.param.b32 [func_retval0], %r3;
 ; SM80-NEXT:    ret;
 ;
 ; SM80-FTZ-LABEL: test_faddx2(
 ; SM80-FTZ:       {
-; SM80-FTZ-NEXT:    .reg .b16 %rs<7>;
+; SM80-FTZ-NEXT:    .reg .b16 %rs<5>;
 ; SM80-FTZ-NEXT:    .reg .b32 %r<4>;
 ; SM80-FTZ-NEXT:    .reg .f32 %f<7>;
 ; SM80-FTZ-EMPTY:
@@ -239,12 +237,10 @@ define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM80-FTZ-NEXT:    mov.b32 {%rs3, %rs4}, %r1;
 ; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f2, %rs4;
 ; SM80-FTZ-NEXT:    add.rn.ftz.f32 %f3, %f2, %f1;
-; SM80-FTZ-NEXT:    cvt.rn.bf16.f32 %rs5, %f3;
 ; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f4, %rs1;
 ; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f5, %rs3;
 ; SM80-FTZ-NEXT:    add.rn.ftz.f32 %f6, %f5, %f4;
-; SM80-FTZ-NEXT:    cvt.rn.bf16.f32 %rs6, %f6;
-; SM80-FTZ-NEXT:    mov.b32 %r3, {%rs6, %rs5};
+; SM80-FTZ-NEXT:    cvt.rn.bf16x2.f32 %r3, %f6, %f3;
 ; SM80-FTZ-NEXT:    st.param.b32 [func_retval0], %r3;
 ; SM80-FTZ-NEXT:    ret;
 ;
@@ -311,7 +307,7 @@ define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ;
 ; SM80-LABEL: test_fsubx2(
 ; SM80:       {
-; SM80-NEXT:    .reg .b16 %rs<7>;
+; SM80-NEXT:    .reg .b16 %rs<5>;
 ; SM80-NEXT:    .reg .b32 %r<4>;
 ; SM80-NEXT:    .reg .f32 %f<7>;
 ; SM80-EMPTY:
@@ -323,18 +319,16 @@ define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM80-NEXT:    mov.b32 {%rs3, %rs4}, %r1;
 ; SM80-NEXT:    cvt.f32.bf16 %f2, %rs4;
 ; SM80-NEXT:    sub.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT:    cvt.rn.bf16.f32 %rs5, %f3;
 ; SM80-NEXT:    cvt.f32.bf16 %f4, %rs1;
 ; SM80-NEXT:    cvt.f32.bf16 %f5, %rs3;
 ; SM80-NEXT:    sub.rn.f32 %f6, %f5, %f4;
-; SM80-NEXT:    cvt.rn.bf16.f32 %rs6, %f6;
-; SM80-NEXT:    mov.b32 %r3, {%rs6, %rs5};
+; SM80-NEXT:    cvt.rn.bf16x2.f32 %r3, %f6, %f3;
 ; SM80-NEXT:    st.param.b32 [func_retval0], %r3;
 ; SM80-NEXT:    ret;
 ;
 ; SM80-FTZ-LABEL: test_fsubx2(
 ; SM80-FTZ:       {
-; SM80-FTZ-NEXT:    .reg .b16 %rs<7>;
+; SM80-FTZ-NEXT:    .reg .b16 %rs<5>;
 ; SM80-FTZ-NEXT:    .reg .b32 %r<4>;
 ; SM80-FTZ-NEXT:    .reg .f32 %f<7>;
 ; SM80-FTZ-EMPTY:
@@ -346,12 +340,10 @@ define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM80-FTZ-NEXT:    mov.b32 {%rs3, %rs4}, %r1;
 ; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f2, %rs4;
 ; SM80-FTZ-NEXT:    sub.rn.ftz.f32 %f3, %f2, %f1;
-; SM80-FTZ-NEXT:    cvt.rn.bf16.f32 %rs5, %f3;
 ; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f4, %rs1;
 ; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f5, %rs3;
 ; SM80-FTZ-NEXT:    sub.rn.ftz.f32 %f6, %f5, %f4;
-; SM80-FTZ-NEXT:    cvt.rn.bf16.f32 %rs6, %f6;
-; SM80-FTZ-NEXT:    mov.b32 %r3, {%rs6, %rs5};
+; SM80-FTZ-NEXT:    cvt.rn.bf16x2.f32 %r3, %f6, %f3;
 ; SM80-FTZ-NEXT:    st.param.b32 [func_retval0], %r3;
 ; SM80-FTZ-NEXT:    ret;
 ;
@@ -418,7 +410,7 @@ define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ;
 ; SM80-LABEL: test_fmulx2(
 ; SM80:       {
-; SM80-NEXT:    .reg .b16 %rs<7>;
+; SM80-NEXT:    .reg .b16 %rs<5>;
 ; SM80-NEXT:    .reg .b32 %r<4>;
 ; SM80-NEXT:    .reg .f32 %f<7>;
 ; SM80-EMPTY:
@@ -430,18 +422,16 @@ define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM80-NEXT:    mov.b32 {%rs3, %rs4}, %r1;
 ; SM80-NEXT:    cvt.f32.bf16 %f2, %rs4;
 ; SM80-NEXT:    mul.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT:    cvt.rn.bf16.f32 %rs5, %f3;
 ; SM80-NEXT:    cvt.f32.bf16 %f4, %rs1;
 ; SM80-NEXT:    cvt.f32.bf16 %f5, %rs3;
 ; SM80-NEXT:    mul.rn.f32 %f6, %f5, %f4;
-; SM80-NEXT:    cvt.rn.bf16.f32 %rs6, %f6;
-; SM80-NEXT:    mov.b32 %r3, {%rs6, %rs5};
+; SM80-NEXT:    cvt.rn.bf16x2.f32 %r3, %f6, %f3;
 ; SM80-NEXT:    st.param.b32 [func_retval0], %r3;
 ; SM80-NEXT:    ret;
 ;
 ; SM80-FTZ-LABEL: test_fmulx2(
 ; SM80-FTZ:       {
-; SM80-FTZ-NEXT:    .reg .b16 %rs<7>;
+; SM80-FTZ-NEXT:    .reg .b16 %rs<5>;
 ; SM80-FTZ-NEXT:    .reg .b32 %r<4>;
 ; SM80-FTZ-NEXT:    .reg .f32 %f<7>;
 ; SM80-FTZ-EMPTY:
@@ -453,12 +443,10 @@ define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM80-FTZ-NEXT:    mov.b32 {%rs3, %rs4}, %r1;
 ; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f2, %rs4;
 ; SM80-FTZ-NEXT:    mul.rn.ftz.f32 %f3, %f2, %f1;
-; SM80-FTZ-NEXT:    cvt.rn.bf16.f32 %rs5, %f3;
 ; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f4, %rs1;
 ; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f5, %rs3;
 ; SM80-FTZ-NEXT:    mul.rn.ftz.f32 %f6, %f5, %f4;
-; SM80-FTZ-NEXT:    cvt.rn.bf16.f32 %rs6, %f6;
-; SM80-FTZ-NEXT:    mov.b32 %r3, {%rs6, %rs5};
+; SM80-FTZ-NEXT:    cvt.rn.bf16x2.f32 %r3, %f6, %f3;
 ; SM80-FTZ-NEXT:    st.param.b32 [func_retval0], %r3;
 ; SM80-FTZ-NEXT:    ret;
 ;
@@ -525,7 +513,7 @@ define <2 x bfloat> @test_fdiv(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ;
 ; SM80-LABEL: test_fdiv(
 ; SM80:       {
-; SM80-NEXT:    .reg .b16 %rs<7>;
+; SM80-NEXT:    .reg .b16 %rs<5>;
 ; SM80-NEXT:    .reg .b32 %r<4>;
 ; SM80-NEXT:    .reg .f32 %f<7>;
 ; SM80-EMPTY:
@@ -537,18 +525,16 @@ define <2 x bfloat> @test_fdiv(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM80-NEXT:    mov.b32 {%rs3, %rs4}, %r1;
 ; SM80-NEXT:    cvt.f32.bf16 %f2, %rs4;
 ; SM80-NEXT:    div.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT:    cvt.rn.bf16.f32 %rs5, %f3;
 ; SM80-NEXT:    cvt.f32.bf16 %f4, %rs1;
 ; SM80-NEXT:    cvt.f32.bf16 %f5, %rs3;
 ; SM80-NEXT:    div.rn.f32 %f6, %f5, %f4;
-; SM80-NEXT:    cvt.rn.bf16.f32 %rs6, %f6;
-; SM80-NEXT:    mov.b32 %r3, {%rs6, %rs5};
+; SM80-NEXT:    cvt.rn.bf16x2.f32 %r3, %f6, %f3;
 ; SM80-NEXT:    st.param.b32 [func_retval0], %r3;
 ; SM80-NEXT:    ret;
 ;
 ; SM80-FTZ-LABEL: test_fdiv(
 ; SM80-FTZ:       {
-; SM80-FTZ-NEXT:    .reg .b16 %rs<7>;
+; SM80-FTZ-NEXT:    .reg .b16 %rs<5>;
 ; SM80-FTZ-NEXT:    .reg .b32 %r<4>;
 ; SM80-FTZ-NEXT:    .reg .f32 %f<7>;
 ; SM80-FTZ-EMPTY:
@@ -560,18 +546,16 @@ define <2 x bfloat> @test_fdiv(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM80-FTZ-NEXT:    mov.b32 {%rs3, %rs4}, %r1;
 ; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f2, %rs4;
 ; SM80-FTZ-NEXT:    div.rn.ftz.f32 %f3, %f2, %f1;
-; SM80-FTZ-NEXT:    cvt.rn.bf16.f32 %rs5, %f3;
 ; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f4, %rs1;
 ; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f5, %rs3;
 ; SM80-FTZ-NEXT:    div.rn.ftz.f32 %f6, %f5, %f4;
-; SM80-FTZ-NEXT:    cvt.rn.bf16.f32 %rs6, %f6;
-; SM80-FTZ-NEXT:    mov.b32 %r3, {%rs6, %rs5};
+; SM80-FTZ-NEXT:    cvt.rn.bf16x2.f32 %r3, %f6, %f3;
 ; SM80-FTZ-NEXT:    st.param.b32 [func_retval0], %r3;
 ; SM80-FTZ-NEXT:    ret;
 ;
 ; SM90-LABEL: test_fdiv(
 ; SM90:       {
-; SM90-NEXT:    .reg .b16 %rs<7>;
+; SM90-NEXT:    .reg .b16 %rs<5>;
 ; SM90-NEXT:    .reg .b32 %r<4>;
 ; SM90-NEXT:    .reg .f32 %f<7>;
 ; SM90-EMPTY:
@@ -583,12 +567,10 @@ define <2 x bfloat> @test_fdiv(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM90-NEXT:    mov.b32 {%rs3, %rs4}, %r1;
 ; SM90-NEXT:    cvt.f32.bf16 %f2, %rs4;
 ; SM90-NEXT:    div.rn.f32 %f3, %f2, %f1;
-; SM90-NEXT:    cvt.rn.bf16.f32 %rs5, %f3;
 ; SM90-NEXT:    cvt.f32.bf16 %f4, %rs1;
 ; SM90-NEXT:    cvt.f32.bf16 %f5, %rs3;
 ; SM90-NEXT:    div.rn.f32 %f6, %f5, %f4;
-; SM90-NEXT:    cvt.rn.bf16.f32 %rs6, %f6;
-; SM90-NEXT:    mov.b32 %r3, {%rs6, %rs5};
+; SM90-NEXT:    cvt.rn.bf16x2.f32 %r3, %f6, %f3;
 ; SM90-NEXT:    st.param.b32 [func_retval0], %r3;
 ; SM90-NEXT:    ret;
   %r = fdiv <2 x bfloat> %a, %b
diff --git a/llvm/test/CodeGen/NVPTX/bf16x2-instructions-approx.ll b/llvm/test/CodeGen/NVPTX/bf16x2-instructions-approx.ll
index a53c90ac6db8b6..d4f5ea6158218a 100644
--- a/llvm/test/CodeGen/NVPTX/bf16x2-instructions-approx.ll
+++ b/llvm/test/CodeGen/NVPTX/bf16x2-instructions-approx.ll
@@ -13,9 +13,7 @@ declare <2 x bfloat> @llvm.cos.f16(<2 x bfloat> %a) #0
 ; CHECK-DAG:  cvt.f32.bf16     [[AF1:%f[0-9]+]], [[A1]];
 ; CHECK-DAG:  sin.approx.f32  [[RF0:%f[0-9]+]], [[AF0]];
 ; CHECK-DAG:  sin.approx.f32  [[RF1:%f[0-9]+]], [[AF1]];
-; CHECK-DAG:  cvt.rn.bf16.f32  [[R0:%rs[0-9]+]], [[RF0]];
-; CHECK-DAG:  cvt.rn.bf16.f32  [[R1:%rs[0-9]+]], [[RF1]];
-; CHECK:      mov.b32         [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; CHECK:      cvt.rn.bf16x2.f32         [[R:%r[0-9]+]], [[RF0]], [[RF1]]
 ; CHECK:      st.param.b32    [func_retval0], [[R]];
 ; CHECK:      ret;
 define <2 x bfloat> @test_sin(<2 x bfloat> %a) #0 #1 {
@@ -30,9 +28,7 @@ define <2 x bfloat> @test_sin(<2 x bfloat> %a) #0 #1 {
 ; CHECK-DAG:  cvt.f32.bf16     [[AF1:%f[0-9]+]], [[A1]];
 ; CHECK-DAG:  cos.approx.f32  [[RF0:%f[0-9]+]], [[AF0]];
 ; CHECK-DAG:  cos.approx.f32  [[RF1:%f[0-9]+]], [[AF1]];
-; CHECK-DAG:  cvt.rn.bf16.f32  [[R0:%rs[0-9]+]], [[RF0]];
-; CHECK-DAG:  cvt.rn.bf16.f32  [[R1:%rs[0-9]+]], [[RF1]];
-; CHECK:      mov.b32         [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; CHECK:      cvt.rn.bf16x2.f32         [[R:%r[0-9]+]], [[RF0]], [[RF1]]
 ; CHECK:      st.param.b32    [func_retval0], [[R]];
 ; CHECK:      ret;
 define <2 x bfloat> @test_cos(<2 x bfloat> %a) #0 #1 {
diff --git a/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll
index 925ae4245a4c20..e820435f1710f0 100644
--- a/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll
@@ -26,9 +26,7 @@ define <2 x bfloat> @test_ret_const() #0 {
 ; SM80-DAG:  cvt.f32.bf16    [[FA1:%f[0-9]+]], [[A1]]
 ; SM80-DAG:  add.rn.f32     [[FR0:%f[0-9]+]], [[FA0]], 0f3F800000;
 ; SM80-DAG:  add.rn.f32     [[FR1:%f[0-9]+]], [[FA1]], 0f40000000;
-; SM80-DAG:  cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[FR0]]
-; SM80-DAG:  cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[FR1]]
-; SM80-DAG:  mov.b32        [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; SM80-DAG:  cvt.rn.bf16x2.f32        [[R:%r[0-9]+]], [[FR0]], [[FR1]];
 ;
 ; CHECK-NEXT: st.param.b32    [func_retval0], [[R]];
 ; CHECK-NEXT: ret;
@@ -68,9 +66,7 @@ define bfloat @test_fadd_imm_1(bfloat %a) #0 {
 ; SM80-DAG:   cvt.f32.bf16    [[FB1:%f[0-9]+]], [[B1]];
 ; SM80-DAG:   sub.rn.f32      [[FR0:%f[0-9]+]], [[FA0]], [[FB0]];
 ; SM80-DAG:   sub.rn.f32      [[FR1:%f[0-9]+]], [[FA1]], [[FB1]];
-; SM80-DAG:   cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[FR0]];
-; SM80-DAG:   cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[FR1]];
-; SM80:       mov.b32         [[R:%r[0-9]+]], {[[R0]], [[R1]]};
+; SM80-DAG:   cvt.rn.bf16x2.f32        [[R:%r[0-9]+]], [[FR0]], [[FR1]];
 
 ; CHECK:      st.param.b32    [func_retval0], [[R]];
 ; CHECK:      ret;
@@ -93,9 +89,7 @@ define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM80-DAG:   cvt.f32.bf16    [[FB1:%f[0-9]+]], [[B1]];
 ; SM80-DAG:   mul.rn.f32      [[FR0:%f[0-9]+]], [[FA0]], [[FB0]];
 ; SM80-DAG:   mul.rn.f32      [[FR1:%f[0-9]+]], [[FA1]], [[FB1]];
-; SM80-DAG:   cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[FR0]];
-; SM80-DAG:   cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[FR1]];
-; SM80:       mov.b32         [[R:%r[0-9]+]], {[[R0]], [[R1]]};
+; SM80-DAG:  cvt.rn.bf16x2.f32        [[R:%r[0-9]+]], [[FR0]], [[FR1]];
 
 ; CHECK:      st.param.b32    [func_retval0], [[R]];
 ; CHECK:      ret;
@@ -116,9 +110,7 @@ define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; CHECK-DAG:  cvt.f32.bf16     [[FB1:%f[0-9]+]], [[B1]];
 ; CHECK-DAG:  div.rn.f32      [[FR0:%f[0-9]+]], [[FA0]], [[FB0]];
 ; CHECK-DAG:  div.rn.f32      [[FR1:%f[0-9]+]], [[FA1]], [[FB1]];
-; CHECK-DAG:  cvt.rn.bf16.f32  [[R0:%rs[0-9]+]], [[FR0]];
-; CHECK-DAG:  cvt.rn.bf16.f32  [[R1:%rs[0-9]+]], [[FR1]];
-; CHECK-NEXT: mov.b32         [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; CHECK:  cvt.rn.bf16x2.f32        [[R:%r[0-9]+]], [[FR0]], [[FR1]];
 ; CHECK-NEXT: st.param.b32    [func_retval0], [[R]];
 ; CHECK-NEXT: ret;
 
@@ -287,9 +279,7 @@ define <2 x bfloat> @test_select_cc_bf16_f32(<2 x bfloat> %a, <2 x bfloat> %b,
 
 ; CHECK-LABEL: test_fptrunc_2xfloat(
 ; CHECK:      ld.param.v2.f32 {[[A0:%f[0-9]+]], [[A1:%f[0-9]+]]}, [test_fptrunc_2xfloat_param_0];
-; CHECK-DAG:  cvt.rn.bf16.f32  [[R0:%rs[0-9]+]], [[A0]];
-; CHECK-DAG:  cvt.rn.bf16.f32  [[R1:%rs[0-9]+]], [[A1]];
-; CHECK:      mov.b32         [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; CHECK:      cvt.rn.bf16x2.f32        [[R:%r[0-9]+]], [[A0]], [[A1]];
 ; CHECK:      st.param.b32    [func_retval0], [[R]];
 ; CHECK:      ret;
 define <2 x bfloat> @test_fptrunc_2xfloat(<2 x float> %a) #0 {
@@ -359,9 +349,7 @@ declare <2 x bfloat> @llvm.fmuladd.f16(<2 x bfloat> %a, <2 x bfloat> %b, <2 x bf
 ; CHECK-DAG:  cvt.f32.bf16     [[AF1:%f[0-9]+]], [[A1]];
 ; CHECK-DAG:  sqrt.rn.f32     [[RF0:%f[0-9]+]], [[AF0]];
 ; CHECK-DAG:  sqrt.rn.f32     [[RF1:%f[0-9]+]], [[AF1]];
-; CHECK-DAG:  cvt.rn.bf16.f32  [[R0:%rs[0-9]+]], [[RF0]];
-; CHECK-DAG:  cvt.rn.bf16.f32  [[R1:%rs[0-9]+]], [[RF1]];
-; CHECK:      mov.b32         [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; CHECK-DAG:  cvt.rn.bf16x2.f32        [[R:%r[0-9]+]], [[RF0]], [[RF1]];
 ; CHECK:      st.param.b32    [func_retval0], [[R]];
 ; CHECK:      ret;
 define <2 x bfloat> @test_sqrt(<2 x bfloat> %a) #0 {
@@ -436,9 +424,7 @@ define <2 x bfloat> @test_maxnum(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM80-DAG:   cvt.f32.bf16     [[FA1:%f[0-9]+]], [[A1]];
 ; SM80-DAG:  cvt.rmi.f32.f32 [[RF0:%f[0-9]+]], [[FA0]];
 ; SM80-DAG:  cvt.rmi.f32.f32 [[RF1:%f[0-9]+]], [[FA1]];
-; SM80-DAG:  cvt.rn.bf16.f32  [[R0:%rs[0-9]+]], [[RF0]];
-; SM80-DAG:  cvt.rn.bf16.f32  [[R1:%rs[0-9]+]], [[RF1]];
-; CHECK:      mov.b32         [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; SM80:      cvt.rn.bf16x2.f32        [[R:%r[0-9]+]], [[RF0]], [[RF1]];
 ; CHECK:      st.param.b32    [func_retval0], [[R]];
 ; CHECK:      ret;
 define <2 x bfloat> @test_floor(<2 x bfloat> %a) #0 {
@@ -455,9 +441,7 @@ define <2 x bfloat> @test_floor(<2 x bfloat> %a) #0 {
 ; SM80-DAG:   cvt.f32.bf16     [[FA1:%f[0-9]+]], [[A1]];
 ; SM80-DAG:   cvt.rpi.f32.f32 [[RF0:%f[0-9]+]], [[FA0]];
 ; SM80-DAG:   cvt.rpi.f32.f32 [[RF1:%f[0-9]+]], [[FA1]];
-; SM80-DAG:  cvt.rn.bf16.f32  [[R0:%rs[0-9]+]], [[RF0]];
-; SM80-DAG:  cvt.rn.bf16.f32  [[R1:%rs[0-9]+]], [[RF1]];
-; CHECK:      mov.b32         [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; SM80:       cvt.rn.bf16x2.f32        [[R:%r[0-9]+]], [[RF0]], [[RF1]];
 ; CHECK:      st.param.b32    [func_retval0], [[R]];
 ; CHECK:      ret;
 define <2 x bfloat> @test_ceil(<2 x bfloat> %a) #0 {
@@ -470,7 +454,7 @@ define <2 x bfloat> @test_ceil(<2 x bfloat> %a) #0 {
 ; CHECK-DAG:  mov.b32         {[[A0:%rs[0-9]+]], [[A1:%rs[0-9]+]]}, [[A]];
 ; SM90:  cvt.rzi.bf16.bf16 [[R1:%rs[0-9]+]], [[A1]];
 ; SM90:  cvt.rzi.bf16.bf16 [[R0:%rs[0-9]+]], [[A0]];
-; CHECK:      mov.b32         [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; SM90:      mov.b32         [[R:%r[0-9]+]], {[[R0]], [[R1]]}
 ; CHECK:      st.param.b32    [func_retval0], [[R]];
 ; CHECK:      ret;
 define <2 x bfloat> @test_trunc(<2 x bfloat> %a) #0 {
@@ -483,7 +467,7 @@ define <2 x bfloat> @test_trunc(<2 x bfloat> %a) #0 {
 ; CHECK-DAG:  mov.b32         {[[A0:%rs[0-9]+]], [[A1:%rs[0-9]+]]}, [[A]];
 ; SM90:  cvt.rni.bf16.bf16 [[R1:%rs[0-9]+]], [[A1]];
 ; SM90:  cvt.rni.bf16.bf16 [[R0:%rs[0-9]+]], [[A0]];
-; CHECK:      mov.b32         [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; SM90:      mov.b32         [[R:%r[0-9]+]], {[[R0]], [[R1]]}
 ; CHECK:      st.param.b32    [func_retval0], [[R]];
 ; CHECK:      ret;
 define <2 x bfloat> @test_rint(<2 x bfloat> %a) #0 {
diff --git a/llvm/test/CodeGen/NVPTX/convert-sm80.ll b/llvm/test/CodeGen/NVPTX/convert-sm80.ll
index 4e30cebfe90251..417a3b72f6fde3 100644
--- a/llvm/test/CodeGen/NVPTX/convert-sm80.ll
+++ b/llvm/test/CodeGen/NVPTX/convert-sm80.ll
@@ -135,3 +135,24 @@ ret i32 %val
 }
 
 declare i32 @llvm.nvvm.f2tf32.rna(float)
+
+
+define <2 x bfloat> @fold_ff2bf16x2(float %a, float %b) {
+; CHECK-LABEL: fold_ff2bf16x2
+; CHECK: cvt.rn.bf16x2.f32
+  %ah = fptrunc float %a to bfloat
+  %bh = fptrunc float %b to bfloat
+  %v0 = insertelement <2 x bfloat> poison, bfloat %ah, i64 0
+  %v1 = insertelement <2 x bfloat> %v0, bfloat %bh, i64 1
+  ret <2 x bfloat> %v1
+}
+
+define <2 x half> @fold_ff2f16x2(float %a, float %b) {
+; CHECK-LABEL: fold_ff2f16x2
+; CHECK: cvt.rn.f16x2.f32
+  %ah = fptrunc float %a to half
+  %bh = fptrunc float %b to half
+  %v0 = insertelement <2 x half> poison, half %ah, i64 0
+  %v1 = insertelement <2 x half> %v0, half %bh, i64 1
+  ret <2 x half> %v1
+}

Copy link
Member

@Artem-B Artem-B left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice. LGTM modulo test nit.



define <2 x bfloat> @fold_ff2bf16x2(float %a, float %b) {
; CHECK-LABEL: fold_ff2bf16x2
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may as well automate check generation in this test file, too. For vector construction, theorder of inputs is important and we're not capturing it in this file right now.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, I've updated the test file.

@AlexMaclean AlexMaclean merged commit 90cbd4a into llvm:main Nov 14, 2024
8 checks passed
@tomnatan30
Copy link
Contributor

This change is breaking triton tests (results in huge numeric disparities, e.g. https://github.com/triton-lang/triton/blob/main/python/test/unit/language/test_core.py), we'll need to revert until a fix forward can be merged.

tomnatan30 added a commit that referenced this pull request Nov 15, 2024
cota pushed a commit that referenced this pull request Nov 15, 2024
Reverts #116109

This change is breaking triton tests (results in huge numeric
disparities, e.g.
https://github.com/triton-lang/triton/blob/main/python/test/unit/language/test_core.py),
we'll need to revert until a fix forward can be merged.
@justinfargnoli
Copy link
Contributor

Thank you for identifying and reverting this @tomnatan30!

@AlexMaclean
Copy link
Member Author

Apologies for this issue. #116417 should fix. @tomnatan30 please review if you can.

AlexMaclean added a commit to AlexMaclean/llvm-project that referenced this pull request Nov 18, 2024
AlexMaclean added a commit to AlexMaclean/llvm-project that referenced this pull request Dec 6, 2024
AlexMaclean added a commit that referenced this pull request Dec 6, 2024
Reland #116109.

Fixes issue where operands were flipped. 

Per the PTX spec, a mov instruction packs the first operand as low, and
the second operand as high:
> ```
> // pack two 16-bit elements into .b32
> d = a.x | (a.y << 16)
> ```
On the other hand cvt.rn.f16x2.f32 instructions take high, than low
operands:
> For .f16x2 and .bf16x2 instruction type, two inputs a and b of .f32
type are converted into .f16 or .bf16 type and the converted values are
packed in the destination register d, such that the value converted from
input a is stored in the upper half of d and the value converted from
input b is stored in the lower half of d
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants