Reland "[NVPTX] Add folding for cvt.rn.bf16x2.f32" #116417
Conversation
Looks like flipped arguments between LLVM and PTX are a common failure pattern. We should probably switch argument names from a/b to hi/lo where bit order matters, or e0/e1 for vectors, to make it a bit easier to catch such issues early next time. It's not fool-proof, but it would at least give a hint about where a particular argument is supposed to be placed.
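For illustration only (a sketch of the naming suggestion, not a change made in this PR; only the operand names differ from the pattern in the diff below), the bf16x2 pattern re-spelled that way might look like:

// Element 0 of the build_vector is the low half of the packed result;
// the cvt instruction takes the high source first, then the low one.
def : Pat<(v2bf16 (build_vector (bf16 (fpround_oneuse Float32Regs:$lo)),
                                (bf16 (fpround_oneuse Float32Regs:$hi)))),
          (CVT_bf16x2_f32 Float32Regs:$hi, Float32Regs:$lo, CvtRN)>,
      Requires<[hasPTX<70>, hasSM<80>, hasBF16Math]>;

Spelled this way, an accidental swap of the two output operands is visible directly in the pattern, without consulting the PTX spec.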
@llvm/pr-subscribers-backend-nvptx

Author: Alex MacLean (AlexMaclean)

Changes

Reland #116109. Fixes an issue where the operands were flipped. Per the PTX spec, a mov instruction packs the first operand as low and the second operand as high.

Patch is 31.79 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116417.diff

5 Files Affected:
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index a16935dcbb93be..77821f3bfa33c4 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -739,6 +739,20 @@ let hasSideEffects = false in {
def CVT_f16x2_e5m2x2 : CVT_f16x2_fp8<"e5m2">;
}
+def fpround_oneuse : PatFrag<(ops node:$a), (fpround node:$a), [{
+ return N->hasOneUse();
+}]>;
+
+def : Pat<(v2bf16 (build_vector (bf16 (fpround_oneuse Float32Regs:$a)),
+ (bf16 (fpround_oneuse Float32Regs:$b)))),
+ (CVT_bf16x2_f32 Float32Regs:$b, Float32Regs:$a, CvtRN)>,
+ Requires<[hasPTX<70>, hasSM<80>, hasBF16Math]>;
+
+def : Pat<(v2f16 (build_vector (f16 (fpround_oneuse Float32Regs:$a)),
+ (f16 (fpround_oneuse Float32Regs:$b)))),
+ (CVT_f16x2_f32 Float32Regs:$b, Float32Regs:$a, CvtRN)>,
+ Requires<[hasPTX<70>, hasSM<80>, useFP16Math]>;
+
//-----------------------------------
// Selection instructions (selp)
//-----------------------------------
diff --git a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
index 80815b3ca37c05..eee31be80e9826 100644
--- a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
@@ -204,7 +204,7 @@ define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
;
; SM80-LABEL: test_faddx2(
; SM80: {
-; SM80-NEXT: .reg .b16 %rs<7>;
+; SM80-NEXT: .reg .b16 %rs<5>;
; SM80-NEXT: .reg .b32 %r<4>;
; SM80-NEXT: .reg .f32 %f<7>;
; SM80-EMPTY:
@@ -212,22 +212,20 @@ define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM80-NEXT: ld.param.b32 %r1, [test_faddx2_param_0];
; SM80-NEXT: ld.param.b32 %r2, [test_faddx2_param_1];
; SM80-NEXT: mov.b32 {%rs1, %rs2}, %r2;
-; SM80-NEXT: cvt.f32.bf16 %f1, %rs2;
+; SM80-NEXT: cvt.f32.bf16 %f1, %rs1;
; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1;
-; SM80-NEXT: cvt.f32.bf16 %f2, %rs4;
+; SM80-NEXT: cvt.f32.bf16 %f2, %rs3;
; SM80-NEXT: add.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT: cvt.rn.bf16.f32 %rs5, %f3;
-; SM80-NEXT: cvt.f32.bf16 %f4, %rs1;
-; SM80-NEXT: cvt.f32.bf16 %f5, %rs3;
+; SM80-NEXT: cvt.f32.bf16 %f4, %rs2;
+; SM80-NEXT: cvt.f32.bf16 %f5, %rs4;
; SM80-NEXT: add.rn.f32 %f6, %f5, %f4;
-; SM80-NEXT: cvt.rn.bf16.f32 %rs6, %f6;
-; SM80-NEXT: mov.b32 %r3, {%rs6, %rs5};
+; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
; SM80-NEXT: st.param.b32 [func_retval0], %r3;
; SM80-NEXT: ret;
;
; SM80-FTZ-LABEL: test_faddx2(
; SM80-FTZ: {
-; SM80-FTZ-NEXT: .reg .b16 %rs<7>;
+; SM80-FTZ-NEXT: .reg .b16 %rs<5>;
; SM80-FTZ-NEXT: .reg .b32 %r<4>;
; SM80-FTZ-NEXT: .reg .f32 %f<7>;
; SM80-FTZ-EMPTY:
@@ -235,16 +233,14 @@ define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM80-FTZ-NEXT: ld.param.b32 %r1, [test_faddx2_param_0];
; SM80-FTZ-NEXT: ld.param.b32 %r2, [test_faddx2_param_1];
; SM80-FTZ-NEXT: mov.b32 {%rs1, %rs2}, %r2;
-; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f1, %rs2;
+; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f1, %rs1;
; SM80-FTZ-NEXT: mov.b32 {%rs3, %rs4}, %r1;
-; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f2, %rs4;
+; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f2, %rs3;
; SM80-FTZ-NEXT: add.rn.ftz.f32 %f3, %f2, %f1;
-; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs5, %f3;
-; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f4, %rs1;
-; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f5, %rs3;
+; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f4, %rs2;
+; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f5, %rs4;
; SM80-FTZ-NEXT: add.rn.ftz.f32 %f6, %f5, %f4;
-; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs6, %f6;
-; SM80-FTZ-NEXT: mov.b32 %r3, {%rs6, %rs5};
+; SM80-FTZ-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
; SM80-FTZ-NEXT: st.param.b32 [func_retval0], %r3;
; SM80-FTZ-NEXT: ret;
;
@@ -311,7 +307,7 @@ define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
;
; SM80-LABEL: test_fsubx2(
; SM80: {
-; SM80-NEXT: .reg .b16 %rs<7>;
+; SM80-NEXT: .reg .b16 %rs<5>;
; SM80-NEXT: .reg .b32 %r<4>;
; SM80-NEXT: .reg .f32 %f<7>;
; SM80-EMPTY:
@@ -319,22 +315,20 @@ define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM80-NEXT: ld.param.b32 %r1, [test_fsubx2_param_0];
; SM80-NEXT: ld.param.b32 %r2, [test_fsubx2_param_1];
; SM80-NEXT: mov.b32 {%rs1, %rs2}, %r2;
-; SM80-NEXT: cvt.f32.bf16 %f1, %rs2;
+; SM80-NEXT: cvt.f32.bf16 %f1, %rs1;
; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1;
-; SM80-NEXT: cvt.f32.bf16 %f2, %rs4;
+; SM80-NEXT: cvt.f32.bf16 %f2, %rs3;
; SM80-NEXT: sub.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT: cvt.rn.bf16.f32 %rs5, %f3;
-; SM80-NEXT: cvt.f32.bf16 %f4, %rs1;
-; SM80-NEXT: cvt.f32.bf16 %f5, %rs3;
+; SM80-NEXT: cvt.f32.bf16 %f4, %rs2;
+; SM80-NEXT: cvt.f32.bf16 %f5, %rs4;
; SM80-NEXT: sub.rn.f32 %f6, %f5, %f4;
-; SM80-NEXT: cvt.rn.bf16.f32 %rs6, %f6;
-; SM80-NEXT: mov.b32 %r3, {%rs6, %rs5};
+; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
; SM80-NEXT: st.param.b32 [func_retval0], %r3;
; SM80-NEXT: ret;
;
; SM80-FTZ-LABEL: test_fsubx2(
; SM80-FTZ: {
-; SM80-FTZ-NEXT: .reg .b16 %rs<7>;
+; SM80-FTZ-NEXT: .reg .b16 %rs<5>;
; SM80-FTZ-NEXT: .reg .b32 %r<4>;
; SM80-FTZ-NEXT: .reg .f32 %f<7>;
; SM80-FTZ-EMPTY:
@@ -342,16 +336,14 @@ define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM80-FTZ-NEXT: ld.param.b32 %r1, [test_fsubx2_param_0];
; SM80-FTZ-NEXT: ld.param.b32 %r2, [test_fsubx2_param_1];
; SM80-FTZ-NEXT: mov.b32 {%rs1, %rs2}, %r2;
-; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f1, %rs2;
+; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f1, %rs1;
; SM80-FTZ-NEXT: mov.b32 {%rs3, %rs4}, %r1;
-; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f2, %rs4;
+; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f2, %rs3;
; SM80-FTZ-NEXT: sub.rn.ftz.f32 %f3, %f2, %f1;
-; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs5, %f3;
-; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f4, %rs1;
-; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f5, %rs3;
+; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f4, %rs2;
+; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f5, %rs4;
; SM80-FTZ-NEXT: sub.rn.ftz.f32 %f6, %f5, %f4;
-; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs6, %f6;
-; SM80-FTZ-NEXT: mov.b32 %r3, {%rs6, %rs5};
+; SM80-FTZ-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
; SM80-FTZ-NEXT: st.param.b32 [func_retval0], %r3;
; SM80-FTZ-NEXT: ret;
;
@@ -418,7 +410,7 @@ define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
;
; SM80-LABEL: test_fmulx2(
; SM80: {
-; SM80-NEXT: .reg .b16 %rs<7>;
+; SM80-NEXT: .reg .b16 %rs<5>;
; SM80-NEXT: .reg .b32 %r<4>;
; SM80-NEXT: .reg .f32 %f<7>;
; SM80-EMPTY:
@@ -426,22 +418,20 @@ define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM80-NEXT: ld.param.b32 %r1, [test_fmulx2_param_0];
; SM80-NEXT: ld.param.b32 %r2, [test_fmulx2_param_1];
; SM80-NEXT: mov.b32 {%rs1, %rs2}, %r2;
-; SM80-NEXT: cvt.f32.bf16 %f1, %rs2;
+; SM80-NEXT: cvt.f32.bf16 %f1, %rs1;
; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1;
-; SM80-NEXT: cvt.f32.bf16 %f2, %rs4;
+; SM80-NEXT: cvt.f32.bf16 %f2, %rs3;
; SM80-NEXT: mul.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT: cvt.rn.bf16.f32 %rs5, %f3;
-; SM80-NEXT: cvt.f32.bf16 %f4, %rs1;
-; SM80-NEXT: cvt.f32.bf16 %f5, %rs3;
+; SM80-NEXT: cvt.f32.bf16 %f4, %rs2;
+; SM80-NEXT: cvt.f32.bf16 %f5, %rs4;
; SM80-NEXT: mul.rn.f32 %f6, %f5, %f4;
-; SM80-NEXT: cvt.rn.bf16.f32 %rs6, %f6;
-; SM80-NEXT: mov.b32 %r3, {%rs6, %rs5};
+; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
; SM80-NEXT: st.param.b32 [func_retval0], %r3;
; SM80-NEXT: ret;
;
; SM80-FTZ-LABEL: test_fmulx2(
; SM80-FTZ: {
-; SM80-FTZ-NEXT: .reg .b16 %rs<7>;
+; SM80-FTZ-NEXT: .reg .b16 %rs<5>;
; SM80-FTZ-NEXT: .reg .b32 %r<4>;
; SM80-FTZ-NEXT: .reg .f32 %f<7>;
; SM80-FTZ-EMPTY:
@@ -449,16 +439,14 @@ define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM80-FTZ-NEXT: ld.param.b32 %r1, [test_fmulx2_param_0];
; SM80-FTZ-NEXT: ld.param.b32 %r2, [test_fmulx2_param_1];
; SM80-FTZ-NEXT: mov.b32 {%rs1, %rs2}, %r2;
-; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f1, %rs2;
+; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f1, %rs1;
; SM80-FTZ-NEXT: mov.b32 {%rs3, %rs4}, %r1;
-; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f2, %rs4;
+; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f2, %rs3;
; SM80-FTZ-NEXT: mul.rn.ftz.f32 %f3, %f2, %f1;
-; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs5, %f3;
-; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f4, %rs1;
-; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f5, %rs3;
+; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f4, %rs2;
+; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f5, %rs4;
; SM80-FTZ-NEXT: mul.rn.ftz.f32 %f6, %f5, %f4;
-; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs6, %f6;
-; SM80-FTZ-NEXT: mov.b32 %r3, {%rs6, %rs5};
+; SM80-FTZ-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
; SM80-FTZ-NEXT: st.param.b32 [func_retval0], %r3;
; SM80-FTZ-NEXT: ret;
;
@@ -525,7 +513,7 @@ define <2 x bfloat> @test_fdiv(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
;
; SM80-LABEL: test_fdiv(
; SM80: {
-; SM80-NEXT: .reg .b16 %rs<7>;
+; SM80-NEXT: .reg .b16 %rs<5>;
; SM80-NEXT: .reg .b32 %r<4>;
; SM80-NEXT: .reg .f32 %f<7>;
; SM80-EMPTY:
@@ -533,22 +521,20 @@ define <2 x bfloat> @test_fdiv(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM80-NEXT: ld.param.b32 %r1, [test_fdiv_param_0];
; SM80-NEXT: ld.param.b32 %r2, [test_fdiv_param_1];
; SM80-NEXT: mov.b32 {%rs1, %rs2}, %r2;
-; SM80-NEXT: cvt.f32.bf16 %f1, %rs2;
+; SM80-NEXT: cvt.f32.bf16 %f1, %rs1;
; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1;
-; SM80-NEXT: cvt.f32.bf16 %f2, %rs4;
+; SM80-NEXT: cvt.f32.bf16 %f2, %rs3;
; SM80-NEXT: div.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT: cvt.rn.bf16.f32 %rs5, %f3;
-; SM80-NEXT: cvt.f32.bf16 %f4, %rs1;
-; SM80-NEXT: cvt.f32.bf16 %f5, %rs3;
+; SM80-NEXT: cvt.f32.bf16 %f4, %rs2;
+; SM80-NEXT: cvt.f32.bf16 %f5, %rs4;
; SM80-NEXT: div.rn.f32 %f6, %f5, %f4;
-; SM80-NEXT: cvt.rn.bf16.f32 %rs6, %f6;
-; SM80-NEXT: mov.b32 %r3, {%rs6, %rs5};
+; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
; SM80-NEXT: st.param.b32 [func_retval0], %r3;
; SM80-NEXT: ret;
;
; SM80-FTZ-LABEL: test_fdiv(
; SM80-FTZ: {
-; SM80-FTZ-NEXT: .reg .b16 %rs<7>;
+; SM80-FTZ-NEXT: .reg .b16 %rs<5>;
; SM80-FTZ-NEXT: .reg .b32 %r<4>;
; SM80-FTZ-NEXT: .reg .f32 %f<7>;
; SM80-FTZ-EMPTY:
@@ -556,22 +542,20 @@ define <2 x bfloat> @test_fdiv(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM80-FTZ-NEXT: ld.param.b32 %r1, [test_fdiv_param_0];
; SM80-FTZ-NEXT: ld.param.b32 %r2, [test_fdiv_param_1];
; SM80-FTZ-NEXT: mov.b32 {%rs1, %rs2}, %r2;
-; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f1, %rs2;
+; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f1, %rs1;
; SM80-FTZ-NEXT: mov.b32 {%rs3, %rs4}, %r1;
-; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f2, %rs4;
+; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f2, %rs3;
; SM80-FTZ-NEXT: div.rn.ftz.f32 %f3, %f2, %f1;
-; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs5, %f3;
-; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f4, %rs1;
-; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f5, %rs3;
+; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f4, %rs2;
+; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f5, %rs4;
; SM80-FTZ-NEXT: div.rn.ftz.f32 %f6, %f5, %f4;
-; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs6, %f6;
-; SM80-FTZ-NEXT: mov.b32 %r3, {%rs6, %rs5};
+; SM80-FTZ-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
; SM80-FTZ-NEXT: st.param.b32 [func_retval0], %r3;
; SM80-FTZ-NEXT: ret;
;
; SM90-LABEL: test_fdiv(
; SM90: {
-; SM90-NEXT: .reg .b16 %rs<7>;
+; SM90-NEXT: .reg .b16 %rs<5>;
; SM90-NEXT: .reg .b32 %r<4>;
; SM90-NEXT: .reg .f32 %f<7>;
; SM90-EMPTY:
@@ -579,16 +563,14 @@ define <2 x bfloat> @test_fdiv(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM90-NEXT: ld.param.b32 %r1, [test_fdiv_param_0];
; SM90-NEXT: ld.param.b32 %r2, [test_fdiv_param_1];
; SM90-NEXT: mov.b32 {%rs1, %rs2}, %r2;
-; SM90-NEXT: cvt.f32.bf16 %f1, %rs2;
+; SM90-NEXT: cvt.f32.bf16 %f1, %rs1;
; SM90-NEXT: mov.b32 {%rs3, %rs4}, %r1;
-; SM90-NEXT: cvt.f32.bf16 %f2, %rs4;
+; SM90-NEXT: cvt.f32.bf16 %f2, %rs3;
; SM90-NEXT: div.rn.f32 %f3, %f2, %f1;
-; SM90-NEXT: cvt.rn.bf16.f32 %rs5, %f3;
-; SM90-NEXT: cvt.f32.bf16 %f4, %rs1;
-; SM90-NEXT: cvt.f32.bf16 %f5, %rs3;
+; SM90-NEXT: cvt.f32.bf16 %f4, %rs2;
+; SM90-NEXT: cvt.f32.bf16 %f5, %rs4;
; SM90-NEXT: div.rn.f32 %f6, %f5, %f4;
-; SM90-NEXT: cvt.rn.bf16.f32 %rs6, %f6;
-; SM90-NEXT: mov.b32 %r3, {%rs6, %rs5};
+; SM90-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
; SM90-NEXT: st.param.b32 [func_retval0], %r3;
; SM90-NEXT: ret;
%r = fdiv <2 x bfloat> %a, %b
diff --git a/llvm/test/CodeGen/NVPTX/bf16x2-instructions-approx.ll b/llvm/test/CodeGen/NVPTX/bf16x2-instructions-approx.ll
index a53c90ac6db8b6..f93825b5a2f6c9 100644
--- a/llvm/test/CodeGen/NVPTX/bf16x2-instructions-approx.ll
+++ b/llvm/test/CodeGen/NVPTX/bf16x2-instructions-approx.ll
@@ -13,9 +13,7 @@ declare <2 x bfloat> @llvm.cos.f16(<2 x bfloat> %a) #0
; CHECK-DAG: cvt.f32.bf16 [[AF1:%f[0-9]+]], [[A1]];
; CHECK-DAG: sin.approx.f32 [[RF0:%f[0-9]+]], [[AF0]];
; CHECK-DAG: sin.approx.f32 [[RF1:%f[0-9]+]], [[AF1]];
-; CHECK-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[RF0]];
-; CHECK-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[RF1]];
-; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; CHECK: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[RF1]], [[RF0]]
; CHECK: st.param.b32 [func_retval0], [[R]];
; CHECK: ret;
define <2 x bfloat> @test_sin(<2 x bfloat> %a) #0 #1 {
@@ -30,9 +28,7 @@ define <2 x bfloat> @test_sin(<2 x bfloat> %a) #0 #1 {
; CHECK-DAG: cvt.f32.bf16 [[AF1:%f[0-9]+]], [[A1]];
; CHECK-DAG: cos.approx.f32 [[RF0:%f[0-9]+]], [[AF0]];
; CHECK-DAG: cos.approx.f32 [[RF1:%f[0-9]+]], [[AF1]];
-; CHECK-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[RF0]];
-; CHECK-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[RF1]];
-; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; CHECK: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[RF1]], [[RF0]]
; CHECK: st.param.b32 [func_retval0], [[R]];
; CHECK: ret;
define <2 x bfloat> @test_cos(<2 x bfloat> %a) #0 #1 {
diff --git a/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll
index 925ae4245a4c20..db618a6be01c8a 100644
--- a/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll
@@ -26,9 +26,7 @@ define <2 x bfloat> @test_ret_const() #0 {
; SM80-DAG: cvt.f32.bf16 [[FA1:%f[0-9]+]], [[A1]]
; SM80-DAG: add.rn.f32 [[FR0:%f[0-9]+]], [[FA0]], 0f3F800000;
; SM80-DAG: add.rn.f32 [[FR1:%f[0-9]+]], [[FA1]], 0f40000000;
-; SM80-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[FR0]]
-; SM80-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[FR1]]
-; SM80-DAG: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; SM80-DAG: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[FR1]], [[FR0]];
;
; CHECK-NEXT: st.param.b32 [func_retval0], [[R]];
; CHECK-NEXT: ret;
@@ -68,9 +66,7 @@ define bfloat @test_fadd_imm_1(bfloat %a) #0 {
; SM80-DAG: cvt.f32.bf16 [[FB1:%f[0-9]+]], [[B1]];
; SM80-DAG: sub.rn.f32 [[FR0:%f[0-9]+]], [[FA0]], [[FB0]];
; SM80-DAG: sub.rn.f32 [[FR1:%f[0-9]+]], [[FA1]], [[FB1]];
-; SM80-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[FR0]];
-; SM80-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[FR1]];
-; SM80: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]};
+; SM80-DAG: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[FR1]], [[FR0]];
; CHECK: st.param.b32 [func_retval0], [[R]];
; CHECK: ret;
@@ -93,9 +89,7 @@ define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM80-DAG: cvt.f32.bf16 [[FB1:%f[0-9]+]], [[B1]];
; SM80-DAG: mul.rn.f32 [[FR0:%f[0-9]+]], [[FA0]], [[FB0]];
; SM80-DAG: mul.rn.f32 [[FR1:%f[0-9]+]], [[FA1]], [[FB1]];
-; SM80-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[FR0]];
-; SM80-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[FR1]];
-; SM80: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]};
+; SM80-DAG: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[FR1]], [[FR0]];
; CHECK: st.param.b32 [func_retval0], [[R]];
; CHECK: ret;
@@ -116,9 +110,7 @@ define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; CHECK-DAG: cvt.f32.bf16 [[FB1:%f[0-9]+]], [[B1]];
; CHECK-DAG: div.rn.f32 [[FR0:%f[0-9]+]], [[FA0]], [[FB0]];
; CHECK-DAG: div.rn.f32 [[FR1:%f[0-9]+]], [[FA1]], [[FB1]];
-; CHECK-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[FR0]];
-; CHECK-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[FR1]];
-; CHECK-NEXT: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; CHECK: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[FR1]], [[FR0]];
; CHECK-NEXT: st.param.b32 [func_retval0], [[R]];
; CHECK-NEXT: ret;
@@ -287,9 +279,7 @@ define <2 x bfloat> @test_select_cc_bf16_f32(<2 x bfloat> %a, <2 x bfloat> %b,
; CHECK-LABEL: test_fptrunc_2xfloat(
; CHECK: ld.param.v2.f32 {[[A0:%f[0-9]+]], [[A1:%f[0-9]+]]}, [test_fptrunc_2xfloat_param_0];
-; CHECK-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[A0]];
-; CHECK-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[A1]];
-; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; CHECK: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[A1]], [[A0]];
; CHECK: st.param.b32 [func_retval0], [[R]];
; CHECK: ret;
define <2 x bfloat> @test_fptrunc_2xfloat(<2 x float> %a) #0 {
@@ -359,9 +349,7 @@ declare <2 x bfloat> @llvm.fmuladd.f16(<2 x bfloat> %a, <2 x bfloat> %b, <2 x bf
; CHECK-DAG: cvt.f32.bf16 [[AF1:%f[0-9]+]], [[A1]];
; CHECK-DAG: sqrt.rn.f32 [[RF0:%f[0-9]+]], [[AF0]];
; CHECK-DAG: sqrt.rn.f32 [[RF1:%f[0-9]+]], [[AF1]];
-; CHECK-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[RF0]];
-; CHECK-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[RF1]];
-; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; CHECK-DAG: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[RF1]], [[RF0]];
; CHECK: st.param.b32 [func_retval0], [[R]];
; CHECK: ret;
define <2 x bfloat> @test_sqrt(<2 x bfloat> %a) #0 {
@@ -436,9 +424,7 @@ define <2 x bfloat> @test_maxnum(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM80-DAG: cvt.f32.bf16 [[FA1:%f[0-9]+]], [[A1]];
; SM80-DAG: cvt.rmi.f32.f32 [[RF0:%f[0-9]+]], [[FA0]];
; SM80-DAG: cvt.rmi.f32.f32 [[RF1:%f[0-9]+]], [[FA1]];
-; SM80-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[RF0]];
-; SM80-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[RF1]];
-; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; SM80: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[RF1]], [[RF0]];
; CHECK: st.param.b32 [func_retval0], [[R]];
; CHECK: ret;
define <2 x bfloat> @test_floor(<2 x bfloat> %a) #0 {
@@ -455,9 +441,7 @@ define <2 x bfloat> @test_floor(<2 x bfloat> %a) #0 {
; SM80-DAG: cvt.f32.bf16 [[FA1:%f[0-9]+]], [[A1]];
; SM80-DAG: cvt.rpi.f32.f32 [[RF0:%f[0-9]+]], [[FA0]];
; SM80-DAG: cvt.rpi.f32.f32 [[RF1:%f[0-9]+]], [[FA1]];
-; SM80-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[RF0]];
-; SM80-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[RF1]];
-; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; SM80: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[RF1]], [[RF0]];
; CHECK: st.param.b32 [func_retval0], [[R]];
; CHECK: ret;
define <2 x bfloat> @test_ceil(<2 x bfloat> %a) #0 {
@@ -470,7 +454,7 @@ define <2 x bfloat> @test_ceil(<2 x bfloat> %a) #0 {
; CHECK-DAG: mov.b32 {[[A0:%rs[0-9]+]], [[A1:%rs[0-9]+]]}, [[A]];
; SM90: cvt.rzi.bf16.bf16 [[R1:%rs[0-9]+]], [[A1]];
; SM90: cvt.rzi.bf16.bf16 [[R0:%rs[0-9]+]], [[A0]];
-; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; SM90: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}
; CHECK: st.param.b32 [func_retval0], [[R]];
; CHECK: ret;
define <2 x bfloat> @test_trunc(<2 x bfloat> %a) #0 {
@@ -483,7 +467,7 @@ define <2 x bfloat> @test_trunc(<2 x bfloat> %a) #0 {
; CHECK-DAG: mov.b32 {[[A0:%rs[0-9]+]], [[A1:%rs[0-9]+]]}, [[A]];
; SM90: cvt.rni.bf16.bf16 [[R1:%rs[0-9]+]], [[A1]];
; SM90: cvt.rni.bf16.bf16 [[R0:%rs[0-9]+]], [[A0]];
-; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; SM90: mov.b32 [[R:%r[0-9]+]], {[[R0]]...
[truncated]
Yeah, inconsistent argument ordering in PTX does make things tricky. I've switched the patterns in this MR to use the suggested naming.
Tests could use the same treatment -- without looking up the docs for the nvvm intrinsics it's hard to tell where a/b or f1/f2 are supposed to end up in the instruction arguments. While I can now be confident that the tablegen pattern is correct, I would still have a hard time telling whether the test patterns check what we want them to check.
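For illustration (not a change made in this PR; the FileCheck variable names are invented), checks that encode the intended placement directly in the names might read:

; CHECK-DAG: cvt.f32.bf16 [[FLO:%f[0-9]+]], [[ALO:%rs[0-9]+]];
; CHECK-DAG: cvt.f32.bf16 [[FHI:%f[0-9]+]], [[AHI:%rs[0-9]+]];
; CHECK:     cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[FHI]], [[FLO]];
; CHECK:     st.param.b32 [func_retval0], [[R]];

With lo/hi spelled out in the capture names, a flipped operand on the final cvt line would stand out at a glance.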
Force-pushed from 87dba4e to 1003ab5.
@tomnatan30 would it be possible to confirm on your end that this latest iteration no longer causes issues?
Force-pushed from 1003ab5 to 07b2045.
Reland #116109. Fixes an issue where the operands were flipped.

Per the PTX spec, a mov instruction packs the first operand as low and the second operand as high. The cvt.rn.f16x2.f32 instructions, on the other hand, take the high operand first, then the low operand.
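For illustration only (a minimal sketch, not part of the patch; the function, value, and register names are made up), IR of the shape below is the kind of input the new pattern folds, and the comments contrast the operand order of the old mov-based packing with the folded cvt:

define <2 x bfloat> @pack_two(float %e0, float %e1) {
  ; Before the fold (sketch): two cvt.rn.bf16.f32 into 16-bit regs, then
  ;   mov.b32 %r, {%rs_lo, %rs_hi};         the first mov operand becomes the low half
  ; After the fold (sketch):
  ;   cvt.rn.bf16x2.f32 %r, %f_hi, %f_lo;   the cvt takes the high source first
  %lo = fptrunc float %e0 to bfloat   ; element 0 -> low half of the packed b32
  %hi = fptrunc float %e1 to bfloat   ; element 1 -> high half
  %v0 = insertelement <2 x bfloat> poison, bfloat %lo, i64 0
  %v1 = insertelement <2 x bfloat> %v0, bfloat %hi, i64 1
  ret <2 x bfloat> %v1
}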