Skip to content

Reland "[NVPTX] Add folding for cvt.rn.bf16x2.f32" #116417

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Dec 6, 2024

Conversation

AlexMaclean
Copy link
Member

@AlexMaclean AlexMaclean commented Nov 15, 2024

Reland #116109.

Fixes issue where operands were flipped.

Per the PTX spec, a mov instruction packs the first operand as low, and the second operand as high:

// pack two 16-bit elements into .b32
d = a.x | (a.y << 16)

On the other hand cvt.rn.f16x2.f32 instructions take high, than low operands:

For .f16x2 and .bf16x2 instruction type, two inputs a and b of .f32 type are converted into .f16 or .bf16 type and the converted values are packed in the destination register d, such that the value converted from input a is stored in the upper half of d and the value converted from input b is stored in the lower half of d

Copy link
Member

@Artem-B Artem-B left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like flipped args between LLVM and PTX is a common failure pattern.
We should probably switch argument names from a/b to hi/lo where bit order matters or e0/e1 for vectors to make it a bit easier catching such issues early next time. It's not fool-proof but would at least give a hint where particular argument is supposed to be placed.

@llvmbot
Copy link
Member

llvmbot commented Nov 16, 2024

@llvm/pr-subscribers-backend-nvptx

Author: Alex MacLean (AlexMaclean)

Changes

Reland #116109.

Fixes issue where operands were flipped.

Per the PTX spec, a mov instruction packs the first operand as low, and the second operand as high:
> &gt; // pack two 16-bit elements into .b32 &gt; d = a.x | (a.y &lt;&lt; 16) &gt;
On the other hand cvt.rn.f16x2.f32 instructions take high, than low operands:
> For .f16x2 and .bf16x2 instruction type, two inputs a and b of .f32 type are converted into .f16 or .bf16 type and the converted values are packed in the destination register d, such that the value converted from input a is stored in the upper half of d and the value converted from input b is stored in the lower half of d


Patch is 31.79 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116417.diff

5 Files Affected:

  • (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+14)
  • (modified) llvm/test/CodeGen/NVPTX/bf16-instructions.ll (+54-72)
  • (modified) llvm/test/CodeGen/NVPTX/bf16x2-instructions-approx.ll (+2-6)
  • (modified) llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll (+10-26)
  • (modified) llvm/test/CodeGen/NVPTX/convert-sm80.ll (+203-77)
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index a16935dcbb93be..77821f3bfa33c4 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -739,6 +739,20 @@ let hasSideEffects = false in {
   def CVT_f16x2_e5m2x2 : CVT_f16x2_fp8<"e5m2">;
 }
 
+def fpround_oneuse : PatFrag<(ops node:$a), (fpround node:$a), [{
+  return N->hasOneUse();
+}]>;
+
+def : Pat<(v2bf16 (build_vector (bf16 (fpround_oneuse Float32Regs:$a)),
+                                (bf16 (fpround_oneuse Float32Regs:$b)))),
+          (CVT_bf16x2_f32 Float32Regs:$b, Float32Regs:$a, CvtRN)>,
+      Requires<[hasPTX<70>, hasSM<80>, hasBF16Math]>;
+
+def : Pat<(v2f16 (build_vector (f16 (fpround_oneuse Float32Regs:$a)),
+                               (f16 (fpround_oneuse Float32Regs:$b)))),
+          (CVT_f16x2_f32 Float32Regs:$b, Float32Regs:$a, CvtRN)>,
+      Requires<[hasPTX<70>, hasSM<80>, useFP16Math]>;
+
 //-----------------------------------
 // Selection instructions (selp)
 //-----------------------------------
diff --git a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
index 80815b3ca37c05..eee31be80e9826 100644
--- a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
@@ -204,7 +204,7 @@ define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ;
 ; SM80-LABEL: test_faddx2(
 ; SM80:       {
-; SM80-NEXT:    .reg .b16 %rs<7>;
+; SM80-NEXT:    .reg .b16 %rs<5>;
 ; SM80-NEXT:    .reg .b32 %r<4>;
 ; SM80-NEXT:    .reg .f32 %f<7>;
 ; SM80-EMPTY:
@@ -212,22 +212,20 @@ define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM80-NEXT:    ld.param.b32 %r1, [test_faddx2_param_0];
 ; SM80-NEXT:    ld.param.b32 %r2, [test_faddx2_param_1];
 ; SM80-NEXT:    mov.b32 {%rs1, %rs2}, %r2;
-; SM80-NEXT:    cvt.f32.bf16 %f1, %rs2;
+; SM80-NEXT:    cvt.f32.bf16 %f1, %rs1;
 ; SM80-NEXT:    mov.b32 {%rs3, %rs4}, %r1;
-; SM80-NEXT:    cvt.f32.bf16 %f2, %rs4;
+; SM80-NEXT:    cvt.f32.bf16 %f2, %rs3;
 ; SM80-NEXT:    add.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT:    cvt.rn.bf16.f32 %rs5, %f3;
-; SM80-NEXT:    cvt.f32.bf16 %f4, %rs1;
-; SM80-NEXT:    cvt.f32.bf16 %f5, %rs3;
+; SM80-NEXT:    cvt.f32.bf16 %f4, %rs2;
+; SM80-NEXT:    cvt.f32.bf16 %f5, %rs4;
 ; SM80-NEXT:    add.rn.f32 %f6, %f5, %f4;
-; SM80-NEXT:    cvt.rn.bf16.f32 %rs6, %f6;
-; SM80-NEXT:    mov.b32 %r3, {%rs6, %rs5};
+; SM80-NEXT:    cvt.rn.bf16x2.f32 %r3, %f6, %f3;
 ; SM80-NEXT:    st.param.b32 [func_retval0], %r3;
 ; SM80-NEXT:    ret;
 ;
 ; SM80-FTZ-LABEL: test_faddx2(
 ; SM80-FTZ:       {
-; SM80-FTZ-NEXT:    .reg .b16 %rs<7>;
+; SM80-FTZ-NEXT:    .reg .b16 %rs<5>;
 ; SM80-FTZ-NEXT:    .reg .b32 %r<4>;
 ; SM80-FTZ-NEXT:    .reg .f32 %f<7>;
 ; SM80-FTZ-EMPTY:
@@ -235,16 +233,14 @@ define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM80-FTZ-NEXT:    ld.param.b32 %r1, [test_faddx2_param_0];
 ; SM80-FTZ-NEXT:    ld.param.b32 %r2, [test_faddx2_param_1];
 ; SM80-FTZ-NEXT:    mov.b32 {%rs1, %rs2}, %r2;
-; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f1, %rs2;
+; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f1, %rs1;
 ; SM80-FTZ-NEXT:    mov.b32 {%rs3, %rs4}, %r1;
-; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f2, %rs4;
+; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f2, %rs3;
 ; SM80-FTZ-NEXT:    add.rn.ftz.f32 %f3, %f2, %f1;
-; SM80-FTZ-NEXT:    cvt.rn.bf16.f32 %rs5, %f3;
-; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f4, %rs1;
-; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f5, %rs3;
+; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f4, %rs2;
+; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f5, %rs4;
 ; SM80-FTZ-NEXT:    add.rn.ftz.f32 %f6, %f5, %f4;
-; SM80-FTZ-NEXT:    cvt.rn.bf16.f32 %rs6, %f6;
-; SM80-FTZ-NEXT:    mov.b32 %r3, {%rs6, %rs5};
+; SM80-FTZ-NEXT:    cvt.rn.bf16x2.f32 %r3, %f6, %f3;
 ; SM80-FTZ-NEXT:    st.param.b32 [func_retval0], %r3;
 ; SM80-FTZ-NEXT:    ret;
 ;
@@ -311,7 +307,7 @@ define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ;
 ; SM80-LABEL: test_fsubx2(
 ; SM80:       {
-; SM80-NEXT:    .reg .b16 %rs<7>;
+; SM80-NEXT:    .reg .b16 %rs<5>;
 ; SM80-NEXT:    .reg .b32 %r<4>;
 ; SM80-NEXT:    .reg .f32 %f<7>;
 ; SM80-EMPTY:
@@ -319,22 +315,20 @@ define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM80-NEXT:    ld.param.b32 %r1, [test_fsubx2_param_0];
 ; SM80-NEXT:    ld.param.b32 %r2, [test_fsubx2_param_1];
 ; SM80-NEXT:    mov.b32 {%rs1, %rs2}, %r2;
-; SM80-NEXT:    cvt.f32.bf16 %f1, %rs2;
+; SM80-NEXT:    cvt.f32.bf16 %f1, %rs1;
 ; SM80-NEXT:    mov.b32 {%rs3, %rs4}, %r1;
-; SM80-NEXT:    cvt.f32.bf16 %f2, %rs4;
+; SM80-NEXT:    cvt.f32.bf16 %f2, %rs3;
 ; SM80-NEXT:    sub.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT:    cvt.rn.bf16.f32 %rs5, %f3;
-; SM80-NEXT:    cvt.f32.bf16 %f4, %rs1;
-; SM80-NEXT:    cvt.f32.bf16 %f5, %rs3;
+; SM80-NEXT:    cvt.f32.bf16 %f4, %rs2;
+; SM80-NEXT:    cvt.f32.bf16 %f5, %rs4;
 ; SM80-NEXT:    sub.rn.f32 %f6, %f5, %f4;
-; SM80-NEXT:    cvt.rn.bf16.f32 %rs6, %f6;
-; SM80-NEXT:    mov.b32 %r3, {%rs6, %rs5};
+; SM80-NEXT:    cvt.rn.bf16x2.f32 %r3, %f6, %f3;
 ; SM80-NEXT:    st.param.b32 [func_retval0], %r3;
 ; SM80-NEXT:    ret;
 ;
 ; SM80-FTZ-LABEL: test_fsubx2(
 ; SM80-FTZ:       {
-; SM80-FTZ-NEXT:    .reg .b16 %rs<7>;
+; SM80-FTZ-NEXT:    .reg .b16 %rs<5>;
 ; SM80-FTZ-NEXT:    .reg .b32 %r<4>;
 ; SM80-FTZ-NEXT:    .reg .f32 %f<7>;
 ; SM80-FTZ-EMPTY:
@@ -342,16 +336,14 @@ define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM80-FTZ-NEXT:    ld.param.b32 %r1, [test_fsubx2_param_0];
 ; SM80-FTZ-NEXT:    ld.param.b32 %r2, [test_fsubx2_param_1];
 ; SM80-FTZ-NEXT:    mov.b32 {%rs1, %rs2}, %r2;
-; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f1, %rs2;
+; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f1, %rs1;
 ; SM80-FTZ-NEXT:    mov.b32 {%rs3, %rs4}, %r1;
-; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f2, %rs4;
+; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f2, %rs3;
 ; SM80-FTZ-NEXT:    sub.rn.ftz.f32 %f3, %f2, %f1;
-; SM80-FTZ-NEXT:    cvt.rn.bf16.f32 %rs5, %f3;
-; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f4, %rs1;
-; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f5, %rs3;
+; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f4, %rs2;
+; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f5, %rs4;
 ; SM80-FTZ-NEXT:    sub.rn.ftz.f32 %f6, %f5, %f4;
-; SM80-FTZ-NEXT:    cvt.rn.bf16.f32 %rs6, %f6;
-; SM80-FTZ-NEXT:    mov.b32 %r3, {%rs6, %rs5};
+; SM80-FTZ-NEXT:    cvt.rn.bf16x2.f32 %r3, %f6, %f3;
 ; SM80-FTZ-NEXT:    st.param.b32 [func_retval0], %r3;
 ; SM80-FTZ-NEXT:    ret;
 ;
@@ -418,7 +410,7 @@ define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ;
 ; SM80-LABEL: test_fmulx2(
 ; SM80:       {
-; SM80-NEXT:    .reg .b16 %rs<7>;
+; SM80-NEXT:    .reg .b16 %rs<5>;
 ; SM80-NEXT:    .reg .b32 %r<4>;
 ; SM80-NEXT:    .reg .f32 %f<7>;
 ; SM80-EMPTY:
@@ -426,22 +418,20 @@ define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM80-NEXT:    ld.param.b32 %r1, [test_fmulx2_param_0];
 ; SM80-NEXT:    ld.param.b32 %r2, [test_fmulx2_param_1];
 ; SM80-NEXT:    mov.b32 {%rs1, %rs2}, %r2;
-; SM80-NEXT:    cvt.f32.bf16 %f1, %rs2;
+; SM80-NEXT:    cvt.f32.bf16 %f1, %rs1;
 ; SM80-NEXT:    mov.b32 {%rs3, %rs4}, %r1;
-; SM80-NEXT:    cvt.f32.bf16 %f2, %rs4;
+; SM80-NEXT:    cvt.f32.bf16 %f2, %rs3;
 ; SM80-NEXT:    mul.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT:    cvt.rn.bf16.f32 %rs5, %f3;
-; SM80-NEXT:    cvt.f32.bf16 %f4, %rs1;
-; SM80-NEXT:    cvt.f32.bf16 %f5, %rs3;
+; SM80-NEXT:    cvt.f32.bf16 %f4, %rs2;
+; SM80-NEXT:    cvt.f32.bf16 %f5, %rs4;
 ; SM80-NEXT:    mul.rn.f32 %f6, %f5, %f4;
-; SM80-NEXT:    cvt.rn.bf16.f32 %rs6, %f6;
-; SM80-NEXT:    mov.b32 %r3, {%rs6, %rs5};
+; SM80-NEXT:    cvt.rn.bf16x2.f32 %r3, %f6, %f3;
 ; SM80-NEXT:    st.param.b32 [func_retval0], %r3;
 ; SM80-NEXT:    ret;
 ;
 ; SM80-FTZ-LABEL: test_fmulx2(
 ; SM80-FTZ:       {
-; SM80-FTZ-NEXT:    .reg .b16 %rs<7>;
+; SM80-FTZ-NEXT:    .reg .b16 %rs<5>;
 ; SM80-FTZ-NEXT:    .reg .b32 %r<4>;
 ; SM80-FTZ-NEXT:    .reg .f32 %f<7>;
 ; SM80-FTZ-EMPTY:
@@ -449,16 +439,14 @@ define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM80-FTZ-NEXT:    ld.param.b32 %r1, [test_fmulx2_param_0];
 ; SM80-FTZ-NEXT:    ld.param.b32 %r2, [test_fmulx2_param_1];
 ; SM80-FTZ-NEXT:    mov.b32 {%rs1, %rs2}, %r2;
-; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f1, %rs2;
+; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f1, %rs1;
 ; SM80-FTZ-NEXT:    mov.b32 {%rs3, %rs4}, %r1;
-; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f2, %rs4;
+; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f2, %rs3;
 ; SM80-FTZ-NEXT:    mul.rn.ftz.f32 %f3, %f2, %f1;
-; SM80-FTZ-NEXT:    cvt.rn.bf16.f32 %rs5, %f3;
-; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f4, %rs1;
-; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f5, %rs3;
+; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f4, %rs2;
+; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f5, %rs4;
 ; SM80-FTZ-NEXT:    mul.rn.ftz.f32 %f6, %f5, %f4;
-; SM80-FTZ-NEXT:    cvt.rn.bf16.f32 %rs6, %f6;
-; SM80-FTZ-NEXT:    mov.b32 %r3, {%rs6, %rs5};
+; SM80-FTZ-NEXT:    cvt.rn.bf16x2.f32 %r3, %f6, %f3;
 ; SM80-FTZ-NEXT:    st.param.b32 [func_retval0], %r3;
 ; SM80-FTZ-NEXT:    ret;
 ;
@@ -525,7 +513,7 @@ define <2 x bfloat> @test_fdiv(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ;
 ; SM80-LABEL: test_fdiv(
 ; SM80:       {
-; SM80-NEXT:    .reg .b16 %rs<7>;
+; SM80-NEXT:    .reg .b16 %rs<5>;
 ; SM80-NEXT:    .reg .b32 %r<4>;
 ; SM80-NEXT:    .reg .f32 %f<7>;
 ; SM80-EMPTY:
@@ -533,22 +521,20 @@ define <2 x bfloat> @test_fdiv(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM80-NEXT:    ld.param.b32 %r1, [test_fdiv_param_0];
 ; SM80-NEXT:    ld.param.b32 %r2, [test_fdiv_param_1];
 ; SM80-NEXT:    mov.b32 {%rs1, %rs2}, %r2;
-; SM80-NEXT:    cvt.f32.bf16 %f1, %rs2;
+; SM80-NEXT:    cvt.f32.bf16 %f1, %rs1;
 ; SM80-NEXT:    mov.b32 {%rs3, %rs4}, %r1;
-; SM80-NEXT:    cvt.f32.bf16 %f2, %rs4;
+; SM80-NEXT:    cvt.f32.bf16 %f2, %rs3;
 ; SM80-NEXT:    div.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT:    cvt.rn.bf16.f32 %rs5, %f3;
-; SM80-NEXT:    cvt.f32.bf16 %f4, %rs1;
-; SM80-NEXT:    cvt.f32.bf16 %f5, %rs3;
+; SM80-NEXT:    cvt.f32.bf16 %f4, %rs2;
+; SM80-NEXT:    cvt.f32.bf16 %f5, %rs4;
 ; SM80-NEXT:    div.rn.f32 %f6, %f5, %f4;
-; SM80-NEXT:    cvt.rn.bf16.f32 %rs6, %f6;
-; SM80-NEXT:    mov.b32 %r3, {%rs6, %rs5};
+; SM80-NEXT:    cvt.rn.bf16x2.f32 %r3, %f6, %f3;
 ; SM80-NEXT:    st.param.b32 [func_retval0], %r3;
 ; SM80-NEXT:    ret;
 ;
 ; SM80-FTZ-LABEL: test_fdiv(
 ; SM80-FTZ:       {
-; SM80-FTZ-NEXT:    .reg .b16 %rs<7>;
+; SM80-FTZ-NEXT:    .reg .b16 %rs<5>;
 ; SM80-FTZ-NEXT:    .reg .b32 %r<4>;
 ; SM80-FTZ-NEXT:    .reg .f32 %f<7>;
 ; SM80-FTZ-EMPTY:
@@ -556,22 +542,20 @@ define <2 x bfloat> @test_fdiv(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM80-FTZ-NEXT:    ld.param.b32 %r1, [test_fdiv_param_0];
 ; SM80-FTZ-NEXT:    ld.param.b32 %r2, [test_fdiv_param_1];
 ; SM80-FTZ-NEXT:    mov.b32 {%rs1, %rs2}, %r2;
-; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f1, %rs2;
+; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f1, %rs1;
 ; SM80-FTZ-NEXT:    mov.b32 {%rs3, %rs4}, %r1;
-; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f2, %rs4;
+; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f2, %rs3;
 ; SM80-FTZ-NEXT:    div.rn.ftz.f32 %f3, %f2, %f1;
-; SM80-FTZ-NEXT:    cvt.rn.bf16.f32 %rs5, %f3;
-; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f4, %rs1;
-; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f5, %rs3;
+; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f4, %rs2;
+; SM80-FTZ-NEXT:    cvt.ftz.f32.bf16 %f5, %rs4;
 ; SM80-FTZ-NEXT:    div.rn.ftz.f32 %f6, %f5, %f4;
-; SM80-FTZ-NEXT:    cvt.rn.bf16.f32 %rs6, %f6;
-; SM80-FTZ-NEXT:    mov.b32 %r3, {%rs6, %rs5};
+; SM80-FTZ-NEXT:    cvt.rn.bf16x2.f32 %r3, %f6, %f3;
 ; SM80-FTZ-NEXT:    st.param.b32 [func_retval0], %r3;
 ; SM80-FTZ-NEXT:    ret;
 ;
 ; SM90-LABEL: test_fdiv(
 ; SM90:       {
-; SM90-NEXT:    .reg .b16 %rs<7>;
+; SM90-NEXT:    .reg .b16 %rs<5>;
 ; SM90-NEXT:    .reg .b32 %r<4>;
 ; SM90-NEXT:    .reg .f32 %f<7>;
 ; SM90-EMPTY:
@@ -579,16 +563,14 @@ define <2 x bfloat> @test_fdiv(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM90-NEXT:    ld.param.b32 %r1, [test_fdiv_param_0];
 ; SM90-NEXT:    ld.param.b32 %r2, [test_fdiv_param_1];
 ; SM90-NEXT:    mov.b32 {%rs1, %rs2}, %r2;
-; SM90-NEXT:    cvt.f32.bf16 %f1, %rs2;
+; SM90-NEXT:    cvt.f32.bf16 %f1, %rs1;
 ; SM90-NEXT:    mov.b32 {%rs3, %rs4}, %r1;
-; SM90-NEXT:    cvt.f32.bf16 %f2, %rs4;
+; SM90-NEXT:    cvt.f32.bf16 %f2, %rs3;
 ; SM90-NEXT:    div.rn.f32 %f3, %f2, %f1;
-; SM90-NEXT:    cvt.rn.bf16.f32 %rs5, %f3;
-; SM90-NEXT:    cvt.f32.bf16 %f4, %rs1;
-; SM90-NEXT:    cvt.f32.bf16 %f5, %rs3;
+; SM90-NEXT:    cvt.f32.bf16 %f4, %rs2;
+; SM90-NEXT:    cvt.f32.bf16 %f5, %rs4;
 ; SM90-NEXT:    div.rn.f32 %f6, %f5, %f4;
-; SM90-NEXT:    cvt.rn.bf16.f32 %rs6, %f6;
-; SM90-NEXT:    mov.b32 %r3, {%rs6, %rs5};
+; SM90-NEXT:    cvt.rn.bf16x2.f32 %r3, %f6, %f3;
 ; SM90-NEXT:    st.param.b32 [func_retval0], %r3;
 ; SM90-NEXT:    ret;
   %r = fdiv <2 x bfloat> %a, %b
diff --git a/llvm/test/CodeGen/NVPTX/bf16x2-instructions-approx.ll b/llvm/test/CodeGen/NVPTX/bf16x2-instructions-approx.ll
index a53c90ac6db8b6..f93825b5a2f6c9 100644
--- a/llvm/test/CodeGen/NVPTX/bf16x2-instructions-approx.ll
+++ b/llvm/test/CodeGen/NVPTX/bf16x2-instructions-approx.ll
@@ -13,9 +13,7 @@ declare <2 x bfloat> @llvm.cos.f16(<2 x bfloat> %a) #0
 ; CHECK-DAG:  cvt.f32.bf16     [[AF1:%f[0-9]+]], [[A1]];
 ; CHECK-DAG:  sin.approx.f32  [[RF0:%f[0-9]+]], [[AF0]];
 ; CHECK-DAG:  sin.approx.f32  [[RF1:%f[0-9]+]], [[AF1]];
-; CHECK-DAG:  cvt.rn.bf16.f32  [[R0:%rs[0-9]+]], [[RF0]];
-; CHECK-DAG:  cvt.rn.bf16.f32  [[R1:%rs[0-9]+]], [[RF1]];
-; CHECK:      mov.b32         [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; CHECK:      cvt.rn.bf16x2.f32         [[R:%r[0-9]+]], [[RF1]], [[RF0]]
 ; CHECK:      st.param.b32    [func_retval0], [[R]];
 ; CHECK:      ret;
 define <2 x bfloat> @test_sin(<2 x bfloat> %a) #0 #1 {
@@ -30,9 +28,7 @@ define <2 x bfloat> @test_sin(<2 x bfloat> %a) #0 #1 {
 ; CHECK-DAG:  cvt.f32.bf16     [[AF1:%f[0-9]+]], [[A1]];
 ; CHECK-DAG:  cos.approx.f32  [[RF0:%f[0-9]+]], [[AF0]];
 ; CHECK-DAG:  cos.approx.f32  [[RF1:%f[0-9]+]], [[AF1]];
-; CHECK-DAG:  cvt.rn.bf16.f32  [[R0:%rs[0-9]+]], [[RF0]];
-; CHECK-DAG:  cvt.rn.bf16.f32  [[R1:%rs[0-9]+]], [[RF1]];
-; CHECK:      mov.b32         [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; CHECK:      cvt.rn.bf16x2.f32         [[R:%r[0-9]+]], [[RF1]], [[RF0]]
 ; CHECK:      st.param.b32    [func_retval0], [[R]];
 ; CHECK:      ret;
 define <2 x bfloat> @test_cos(<2 x bfloat> %a) #0 #1 {
diff --git a/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll
index 925ae4245a4c20..db618a6be01c8a 100644
--- a/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll
@@ -26,9 +26,7 @@ define <2 x bfloat> @test_ret_const() #0 {
 ; SM80-DAG:  cvt.f32.bf16    [[FA1:%f[0-9]+]], [[A1]]
 ; SM80-DAG:  add.rn.f32     [[FR0:%f[0-9]+]], [[FA0]], 0f3F800000;
 ; SM80-DAG:  add.rn.f32     [[FR1:%f[0-9]+]], [[FA1]], 0f40000000;
-; SM80-DAG:  cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[FR0]]
-; SM80-DAG:  cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[FR1]]
-; SM80-DAG:  mov.b32        [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; SM80-DAG:  cvt.rn.bf16x2.f32        [[R:%r[0-9]+]], [[FR1]], [[FR0]];
 ;
 ; CHECK-NEXT: st.param.b32    [func_retval0], [[R]];
 ; CHECK-NEXT: ret;
@@ -68,9 +66,7 @@ define bfloat @test_fadd_imm_1(bfloat %a) #0 {
 ; SM80-DAG:   cvt.f32.bf16    [[FB1:%f[0-9]+]], [[B1]];
 ; SM80-DAG:   sub.rn.f32      [[FR0:%f[0-9]+]], [[FA0]], [[FB0]];
 ; SM80-DAG:   sub.rn.f32      [[FR1:%f[0-9]+]], [[FA1]], [[FB1]];
-; SM80-DAG:   cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[FR0]];
-; SM80-DAG:   cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[FR1]];
-; SM80:       mov.b32         [[R:%r[0-9]+]], {[[R0]], [[R1]]};
+; SM80-DAG:   cvt.rn.bf16x2.f32        [[R:%r[0-9]+]], [[FR1]], [[FR0]];
 
 ; CHECK:      st.param.b32    [func_retval0], [[R]];
 ; CHECK:      ret;
@@ -93,9 +89,7 @@ define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM80-DAG:   cvt.f32.bf16    [[FB1:%f[0-9]+]], [[B1]];
 ; SM80-DAG:   mul.rn.f32      [[FR0:%f[0-9]+]], [[FA0]], [[FB0]];
 ; SM80-DAG:   mul.rn.f32      [[FR1:%f[0-9]+]], [[FA1]], [[FB1]];
-; SM80-DAG:   cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[FR0]];
-; SM80-DAG:   cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[FR1]];
-; SM80:       mov.b32         [[R:%r[0-9]+]], {[[R0]], [[R1]]};
+; SM80-DAG:  cvt.rn.bf16x2.f32        [[R:%r[0-9]+]], [[FR1]], [[FR0]];
 
 ; CHECK:      st.param.b32    [func_retval0], [[R]];
 ; CHECK:      ret;
@@ -116,9 +110,7 @@ define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; CHECK-DAG:  cvt.f32.bf16     [[FB1:%f[0-9]+]], [[B1]];
 ; CHECK-DAG:  div.rn.f32      [[FR0:%f[0-9]+]], [[FA0]], [[FB0]];
 ; CHECK-DAG:  div.rn.f32      [[FR1:%f[0-9]+]], [[FA1]], [[FB1]];
-; CHECK-DAG:  cvt.rn.bf16.f32  [[R0:%rs[0-9]+]], [[FR0]];
-; CHECK-DAG:  cvt.rn.bf16.f32  [[R1:%rs[0-9]+]], [[FR1]];
-; CHECK-NEXT: mov.b32         [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; CHECK:  cvt.rn.bf16x2.f32        [[R:%r[0-9]+]], [[FR1]], [[FR0]];
 ; CHECK-NEXT: st.param.b32    [func_retval0], [[R]];
 ; CHECK-NEXT: ret;
 
@@ -287,9 +279,7 @@ define <2 x bfloat> @test_select_cc_bf16_f32(<2 x bfloat> %a, <2 x bfloat> %b,
 
 ; CHECK-LABEL: test_fptrunc_2xfloat(
 ; CHECK:      ld.param.v2.f32 {[[A0:%f[0-9]+]], [[A1:%f[0-9]+]]}, [test_fptrunc_2xfloat_param_0];
-; CHECK-DAG:  cvt.rn.bf16.f32  [[R0:%rs[0-9]+]], [[A0]];
-; CHECK-DAG:  cvt.rn.bf16.f32  [[R1:%rs[0-9]+]], [[A1]];
-; CHECK:      mov.b32         [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; CHECK:      cvt.rn.bf16x2.f32        [[R:%r[0-9]+]], [[A1]], [[A0]];
 ; CHECK:      st.param.b32    [func_retval0], [[R]];
 ; CHECK:      ret;
 define <2 x bfloat> @test_fptrunc_2xfloat(<2 x float> %a) #0 {
@@ -359,9 +349,7 @@ declare <2 x bfloat> @llvm.fmuladd.f16(<2 x bfloat> %a, <2 x bfloat> %b, <2 x bf
 ; CHECK-DAG:  cvt.f32.bf16     [[AF1:%f[0-9]+]], [[A1]];
 ; CHECK-DAG:  sqrt.rn.f32     [[RF0:%f[0-9]+]], [[AF0]];
 ; CHECK-DAG:  sqrt.rn.f32     [[RF1:%f[0-9]+]], [[AF1]];
-; CHECK-DAG:  cvt.rn.bf16.f32  [[R0:%rs[0-9]+]], [[RF0]];
-; CHECK-DAG:  cvt.rn.bf16.f32  [[R1:%rs[0-9]+]], [[RF1]];
-; CHECK:      mov.b32         [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; CHECK-DAG:  cvt.rn.bf16x2.f32        [[R:%r[0-9]+]], [[RF1]], [[RF0]];
 ; CHECK:      st.param.b32    [func_retval0], [[R]];
 ; CHECK:      ret;
 define <2 x bfloat> @test_sqrt(<2 x bfloat> %a) #0 {
@@ -436,9 +424,7 @@ define <2 x bfloat> @test_maxnum(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
 ; SM80-DAG:   cvt.f32.bf16     [[FA1:%f[0-9]+]], [[A1]];
 ; SM80-DAG:  cvt.rmi.f32.f32 [[RF0:%f[0-9]+]], [[FA0]];
 ; SM80-DAG:  cvt.rmi.f32.f32 [[RF1:%f[0-9]+]], [[FA1]];
-; SM80-DAG:  cvt.rn.bf16.f32  [[R0:%rs[0-9]+]], [[RF0]];
-; SM80-DAG:  cvt.rn.bf16.f32  [[R1:%rs[0-9]+]], [[RF1]];
-; CHECK:      mov.b32         [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; SM80:      cvt.rn.bf16x2.f32        [[R:%r[0-9]+]], [[RF1]], [[RF0]];
 ; CHECK:      st.param.b32    [func_retval0], [[R]];
 ; CHECK:      ret;
 define <2 x bfloat> @test_floor(<2 x bfloat> %a) #0 {
@@ -455,9 +441,7 @@ define <2 x bfloat> @test_floor(<2 x bfloat> %a) #0 {
 ; SM80-DAG:   cvt.f32.bf16     [[FA1:%f[0-9]+]], [[A1]];
 ; SM80-DAG:   cvt.rpi.f32.f32 [[RF0:%f[0-9]+]], [[FA0]];
 ; SM80-DAG:   cvt.rpi.f32.f32 [[RF1:%f[0-9]+]], [[FA1]];
-; SM80-DAG:  cvt.rn.bf16.f32  [[R0:%rs[0-9]+]], [[RF0]];
-; SM80-DAG:  cvt.rn.bf16.f32  [[R1:%rs[0-9]+]], [[RF1]];
-; CHECK:      mov.b32         [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; SM80:       cvt.rn.bf16x2.f32        [[R:%r[0-9]+]], [[RF1]], [[RF0]];
 ; CHECK:      st.param.b32    [func_retval0], [[R]];
 ; CHECK:      ret;
 define <2 x bfloat> @test_ceil(<2 x bfloat> %a) #0 {
@@ -470,7 +454,7 @@ define <2 x bfloat> @test_ceil(<2 x bfloat> %a) #0 {
 ; CHECK-DAG:  mov.b32         {[[A0:%rs[0-9]+]], [[A1:%rs[0-9]+]]}, [[A]];
 ; SM90:  cvt.rzi.bf16.bf16 [[R1:%rs[0-9]+]], [[A1]];
 ; SM90:  cvt.rzi.bf16.bf16 [[R0:%rs[0-9]+]], [[A0]];
-; CHECK:      mov.b32         [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; SM90:      mov.b32         [[R:%r[0-9]+]], {[[R0]], [[R1]]}
 ; CHECK:      st.param.b32    [func_retval0], [[R]];
 ; CHECK:      ret;
 define <2 x bfloat> @test_trunc(<2 x bfloat> %a) #0 {
@@ -483,7 +467,7 @@ define <2 x bfloat> @test_trunc(<2 x bfloat> %a) #0 {
 ; CHECK-DAG:  mov.b32         {[[A0:%rs[0-9]+]], [[A1:%rs[0-9]+]]}, [[A]];
 ; SM90:  cvt.rni.bf16.bf16 [[R1:%rs[0-9]+]], [[A1]];
 ; SM90:  cvt.rni.bf16.bf16 [[R0:%rs[0-9]+]], [[A0]];
-; CHECK:      mov.b32         [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; SM90:      mov.b32         [[R:%r[0-9]+]], {[[R0]]...
[truncated]

@AlexMaclean
Copy link
Member Author

Looks like flipped args between LLVM and PTX is a common failure pattern. We should probably switch argument names from a/b to hi/lo where bit order matters or e0/e1 for vectors to make it a bit easier catching such issues early next time. It's not fool-proof but would at least give a hint where particular argument is supposed to be placed.

Yea, inconsistent argument ordering in PTX does make things tricky. I've switched the patterns in this MR to use lo and hi as you suggested.

@Artem-B
Copy link
Member

Artem-B commented Nov 18, 2024

Yea, inconsistent argument ordering in PTX does make things tricky. I've switched the patterns in this MR to use lo and hi as you suggested.

Tests could use the same treatment -- without looking up the docs for nvvm intrinsics it's hard to tell where a/b or f1/f2 are supposed to end up in the instruction arguments. While now I can be confident that the tablegen pattern is correct, I still would have hard time telling if the test patterns check what we want them to check.

@AlexMaclean AlexMaclean force-pushed the dev/amaclean/reland-cvt2 branch from 87dba4e to 1003ab5 Compare November 18, 2024 22:40
@AlexMaclean
Copy link
Member Author

@tomnatan30 would it be possible to confirm on your end that this latest iteration no longer causes issues?

@AlexMaclean AlexMaclean force-pushed the dev/amaclean/reland-cvt2 branch from 1003ab5 to 07b2045 Compare December 6, 2024 18:34
@AlexMaclean AlexMaclean merged commit 4b24ab4 into llvm:main Dec 6, 2024
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants