-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[NVPTX] Add folding for cvt.rn.bf16x2.f32 #116109
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[NVPTX] Add folding for cvt.rn.bf16x2.f32 #116109
Conversation
@llvm/pr-subscribers-backend-nvptx Author: Alex MacLean (AlexMaclean) ChangesFull diff: https://github.com/llvm/llvm-project/pull/116109.diff 5 Files Affected:
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 2658ca32716378..b58918bc75df72 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -739,6 +739,20 @@ let hasSideEffects = false in {
def CVT_f16x2_e5m2x2 : CVT_f16x2_fp8<"e5m2">;
}
+def fpround_oneuse : PatFrag<(ops node:$a), (fpround node:$a), [{
+ return N->hasOneUse();
+}]>;
+
+def : Pat<(v2bf16 (build_vector (bf16 (fpround_oneuse Float32Regs:$a)),
+ (bf16 (fpround_oneuse Float32Regs:$b)))),
+ (CVT_bf16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>,
+ Requires<[hasPTX<70>, hasSM<80>, hasBF16Math]>;
+
+def : Pat<(v2f16 (build_vector (f16 (fpround_oneuse Float32Regs:$a)),
+ (f16 (fpround_oneuse Float32Regs:$b)))),
+ (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>,
+ Requires<[hasPTX<70>, hasSM<80>, useFP16Math]>;
+
//-----------------------------------
// Selection instructions (selp)
//-----------------------------------
diff --git a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
index 80815b3ca37c05..ded3019e173b4c 100644
--- a/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/bf16-instructions.ll
@@ -204,7 +204,7 @@ define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
;
; SM80-LABEL: test_faddx2(
; SM80: {
-; SM80-NEXT: .reg .b16 %rs<7>;
+; SM80-NEXT: .reg .b16 %rs<5>;
; SM80-NEXT: .reg .b32 %r<4>;
; SM80-NEXT: .reg .f32 %f<7>;
; SM80-EMPTY:
@@ -216,18 +216,16 @@ define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1;
; SM80-NEXT: cvt.f32.bf16 %f2, %rs4;
; SM80-NEXT: add.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT: cvt.rn.bf16.f32 %rs5, %f3;
; SM80-NEXT: cvt.f32.bf16 %f4, %rs1;
; SM80-NEXT: cvt.f32.bf16 %f5, %rs3;
; SM80-NEXT: add.rn.f32 %f6, %f5, %f4;
-; SM80-NEXT: cvt.rn.bf16.f32 %rs6, %f6;
-; SM80-NEXT: mov.b32 %r3, {%rs6, %rs5};
+; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
; SM80-NEXT: st.param.b32 [func_retval0], %r3;
; SM80-NEXT: ret;
;
; SM80-FTZ-LABEL: test_faddx2(
; SM80-FTZ: {
-; SM80-FTZ-NEXT: .reg .b16 %rs<7>;
+; SM80-FTZ-NEXT: .reg .b16 %rs<5>;
; SM80-FTZ-NEXT: .reg .b32 %r<4>;
; SM80-FTZ-NEXT: .reg .f32 %f<7>;
; SM80-FTZ-EMPTY:
@@ -239,12 +237,10 @@ define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM80-FTZ-NEXT: mov.b32 {%rs3, %rs4}, %r1;
; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f2, %rs4;
; SM80-FTZ-NEXT: add.rn.ftz.f32 %f3, %f2, %f1;
-; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs5, %f3;
; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f4, %rs1;
; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f5, %rs3;
; SM80-FTZ-NEXT: add.rn.ftz.f32 %f6, %f5, %f4;
-; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs6, %f6;
-; SM80-FTZ-NEXT: mov.b32 %r3, {%rs6, %rs5};
+; SM80-FTZ-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
; SM80-FTZ-NEXT: st.param.b32 [func_retval0], %r3;
; SM80-FTZ-NEXT: ret;
;
@@ -311,7 +307,7 @@ define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
;
; SM80-LABEL: test_fsubx2(
; SM80: {
-; SM80-NEXT: .reg .b16 %rs<7>;
+; SM80-NEXT: .reg .b16 %rs<5>;
; SM80-NEXT: .reg .b32 %r<4>;
; SM80-NEXT: .reg .f32 %f<7>;
; SM80-EMPTY:
@@ -323,18 +319,16 @@ define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1;
; SM80-NEXT: cvt.f32.bf16 %f2, %rs4;
; SM80-NEXT: sub.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT: cvt.rn.bf16.f32 %rs5, %f3;
; SM80-NEXT: cvt.f32.bf16 %f4, %rs1;
; SM80-NEXT: cvt.f32.bf16 %f5, %rs3;
; SM80-NEXT: sub.rn.f32 %f6, %f5, %f4;
-; SM80-NEXT: cvt.rn.bf16.f32 %rs6, %f6;
-; SM80-NEXT: mov.b32 %r3, {%rs6, %rs5};
+; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
; SM80-NEXT: st.param.b32 [func_retval0], %r3;
; SM80-NEXT: ret;
;
; SM80-FTZ-LABEL: test_fsubx2(
; SM80-FTZ: {
-; SM80-FTZ-NEXT: .reg .b16 %rs<7>;
+; SM80-FTZ-NEXT: .reg .b16 %rs<5>;
; SM80-FTZ-NEXT: .reg .b32 %r<4>;
; SM80-FTZ-NEXT: .reg .f32 %f<7>;
; SM80-FTZ-EMPTY:
@@ -346,12 +340,10 @@ define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM80-FTZ-NEXT: mov.b32 {%rs3, %rs4}, %r1;
; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f2, %rs4;
; SM80-FTZ-NEXT: sub.rn.ftz.f32 %f3, %f2, %f1;
-; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs5, %f3;
; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f4, %rs1;
; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f5, %rs3;
; SM80-FTZ-NEXT: sub.rn.ftz.f32 %f6, %f5, %f4;
-; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs6, %f6;
-; SM80-FTZ-NEXT: mov.b32 %r3, {%rs6, %rs5};
+; SM80-FTZ-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
; SM80-FTZ-NEXT: st.param.b32 [func_retval0], %r3;
; SM80-FTZ-NEXT: ret;
;
@@ -418,7 +410,7 @@ define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
;
; SM80-LABEL: test_fmulx2(
; SM80: {
-; SM80-NEXT: .reg .b16 %rs<7>;
+; SM80-NEXT: .reg .b16 %rs<5>;
; SM80-NEXT: .reg .b32 %r<4>;
; SM80-NEXT: .reg .f32 %f<7>;
; SM80-EMPTY:
@@ -430,18 +422,16 @@ define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1;
; SM80-NEXT: cvt.f32.bf16 %f2, %rs4;
; SM80-NEXT: mul.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT: cvt.rn.bf16.f32 %rs5, %f3;
; SM80-NEXT: cvt.f32.bf16 %f4, %rs1;
; SM80-NEXT: cvt.f32.bf16 %f5, %rs3;
; SM80-NEXT: mul.rn.f32 %f6, %f5, %f4;
-; SM80-NEXT: cvt.rn.bf16.f32 %rs6, %f6;
-; SM80-NEXT: mov.b32 %r3, {%rs6, %rs5};
+; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
; SM80-NEXT: st.param.b32 [func_retval0], %r3;
; SM80-NEXT: ret;
;
; SM80-FTZ-LABEL: test_fmulx2(
; SM80-FTZ: {
-; SM80-FTZ-NEXT: .reg .b16 %rs<7>;
+; SM80-FTZ-NEXT: .reg .b16 %rs<5>;
; SM80-FTZ-NEXT: .reg .b32 %r<4>;
; SM80-FTZ-NEXT: .reg .f32 %f<7>;
; SM80-FTZ-EMPTY:
@@ -453,12 +443,10 @@ define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM80-FTZ-NEXT: mov.b32 {%rs3, %rs4}, %r1;
; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f2, %rs4;
; SM80-FTZ-NEXT: mul.rn.ftz.f32 %f3, %f2, %f1;
-; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs5, %f3;
; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f4, %rs1;
; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f5, %rs3;
; SM80-FTZ-NEXT: mul.rn.ftz.f32 %f6, %f5, %f4;
-; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs6, %f6;
-; SM80-FTZ-NEXT: mov.b32 %r3, {%rs6, %rs5};
+; SM80-FTZ-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
; SM80-FTZ-NEXT: st.param.b32 [func_retval0], %r3;
; SM80-FTZ-NEXT: ret;
;
@@ -525,7 +513,7 @@ define <2 x bfloat> @test_fdiv(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
;
; SM80-LABEL: test_fdiv(
; SM80: {
-; SM80-NEXT: .reg .b16 %rs<7>;
+; SM80-NEXT: .reg .b16 %rs<5>;
; SM80-NEXT: .reg .b32 %r<4>;
; SM80-NEXT: .reg .f32 %f<7>;
; SM80-EMPTY:
@@ -537,18 +525,16 @@ define <2 x bfloat> @test_fdiv(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM80-NEXT: mov.b32 {%rs3, %rs4}, %r1;
; SM80-NEXT: cvt.f32.bf16 %f2, %rs4;
; SM80-NEXT: div.rn.f32 %f3, %f2, %f1;
-; SM80-NEXT: cvt.rn.bf16.f32 %rs5, %f3;
; SM80-NEXT: cvt.f32.bf16 %f4, %rs1;
; SM80-NEXT: cvt.f32.bf16 %f5, %rs3;
; SM80-NEXT: div.rn.f32 %f6, %f5, %f4;
-; SM80-NEXT: cvt.rn.bf16.f32 %rs6, %f6;
-; SM80-NEXT: mov.b32 %r3, {%rs6, %rs5};
+; SM80-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
; SM80-NEXT: st.param.b32 [func_retval0], %r3;
; SM80-NEXT: ret;
;
; SM80-FTZ-LABEL: test_fdiv(
; SM80-FTZ: {
-; SM80-FTZ-NEXT: .reg .b16 %rs<7>;
+; SM80-FTZ-NEXT: .reg .b16 %rs<5>;
; SM80-FTZ-NEXT: .reg .b32 %r<4>;
; SM80-FTZ-NEXT: .reg .f32 %f<7>;
; SM80-FTZ-EMPTY:
@@ -560,18 +546,16 @@ define <2 x bfloat> @test_fdiv(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM80-FTZ-NEXT: mov.b32 {%rs3, %rs4}, %r1;
; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f2, %rs4;
; SM80-FTZ-NEXT: div.rn.ftz.f32 %f3, %f2, %f1;
-; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs5, %f3;
; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f4, %rs1;
; SM80-FTZ-NEXT: cvt.ftz.f32.bf16 %f5, %rs3;
; SM80-FTZ-NEXT: div.rn.ftz.f32 %f6, %f5, %f4;
-; SM80-FTZ-NEXT: cvt.rn.bf16.f32 %rs6, %f6;
-; SM80-FTZ-NEXT: mov.b32 %r3, {%rs6, %rs5};
+; SM80-FTZ-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
; SM80-FTZ-NEXT: st.param.b32 [func_retval0], %r3;
; SM80-FTZ-NEXT: ret;
;
; SM90-LABEL: test_fdiv(
; SM90: {
-; SM90-NEXT: .reg .b16 %rs<7>;
+; SM90-NEXT: .reg .b16 %rs<5>;
; SM90-NEXT: .reg .b32 %r<4>;
; SM90-NEXT: .reg .f32 %f<7>;
; SM90-EMPTY:
@@ -583,12 +567,10 @@ define <2 x bfloat> @test_fdiv(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM90-NEXT: mov.b32 {%rs3, %rs4}, %r1;
; SM90-NEXT: cvt.f32.bf16 %f2, %rs4;
; SM90-NEXT: div.rn.f32 %f3, %f2, %f1;
-; SM90-NEXT: cvt.rn.bf16.f32 %rs5, %f3;
; SM90-NEXT: cvt.f32.bf16 %f4, %rs1;
; SM90-NEXT: cvt.f32.bf16 %f5, %rs3;
; SM90-NEXT: div.rn.f32 %f6, %f5, %f4;
-; SM90-NEXT: cvt.rn.bf16.f32 %rs6, %f6;
-; SM90-NEXT: mov.b32 %r3, {%rs6, %rs5};
+; SM90-NEXT: cvt.rn.bf16x2.f32 %r3, %f6, %f3;
; SM90-NEXT: st.param.b32 [func_retval0], %r3;
; SM90-NEXT: ret;
%r = fdiv <2 x bfloat> %a, %b
diff --git a/llvm/test/CodeGen/NVPTX/bf16x2-instructions-approx.ll b/llvm/test/CodeGen/NVPTX/bf16x2-instructions-approx.ll
index a53c90ac6db8b6..d4f5ea6158218a 100644
--- a/llvm/test/CodeGen/NVPTX/bf16x2-instructions-approx.ll
+++ b/llvm/test/CodeGen/NVPTX/bf16x2-instructions-approx.ll
@@ -13,9 +13,7 @@ declare <2 x bfloat> @llvm.cos.f16(<2 x bfloat> %a) #0
; CHECK-DAG: cvt.f32.bf16 [[AF1:%f[0-9]+]], [[A1]];
; CHECK-DAG: sin.approx.f32 [[RF0:%f[0-9]+]], [[AF0]];
; CHECK-DAG: sin.approx.f32 [[RF1:%f[0-9]+]], [[AF1]];
-; CHECK-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[RF0]];
-; CHECK-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[RF1]];
-; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; CHECK: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[RF0]], [[RF1]]
; CHECK: st.param.b32 [func_retval0], [[R]];
; CHECK: ret;
define <2 x bfloat> @test_sin(<2 x bfloat> %a) #0 #1 {
@@ -30,9 +28,7 @@ define <2 x bfloat> @test_sin(<2 x bfloat> %a) #0 #1 {
; CHECK-DAG: cvt.f32.bf16 [[AF1:%f[0-9]+]], [[A1]];
; CHECK-DAG: cos.approx.f32 [[RF0:%f[0-9]+]], [[AF0]];
; CHECK-DAG: cos.approx.f32 [[RF1:%f[0-9]+]], [[AF1]];
-; CHECK-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[RF0]];
-; CHECK-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[RF1]];
-; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; CHECK: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[RF0]], [[RF1]]
; CHECK: st.param.b32 [func_retval0], [[R]];
; CHECK: ret;
define <2 x bfloat> @test_cos(<2 x bfloat> %a) #0 #1 {
diff --git a/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll b/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll
index 925ae4245a4c20..e820435f1710f0 100644
--- a/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll
@@ -26,9 +26,7 @@ define <2 x bfloat> @test_ret_const() #0 {
; SM80-DAG: cvt.f32.bf16 [[FA1:%f[0-9]+]], [[A1]]
; SM80-DAG: add.rn.f32 [[FR0:%f[0-9]+]], [[FA0]], 0f3F800000;
; SM80-DAG: add.rn.f32 [[FR1:%f[0-9]+]], [[FA1]], 0f40000000;
-; SM80-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[FR0]]
-; SM80-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[FR1]]
-; SM80-DAG: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; SM80-DAG: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[FR0]], [[FR1]];
;
; CHECK-NEXT: st.param.b32 [func_retval0], [[R]];
; CHECK-NEXT: ret;
@@ -68,9 +66,7 @@ define bfloat @test_fadd_imm_1(bfloat %a) #0 {
; SM80-DAG: cvt.f32.bf16 [[FB1:%f[0-9]+]], [[B1]];
; SM80-DAG: sub.rn.f32 [[FR0:%f[0-9]+]], [[FA0]], [[FB0]];
; SM80-DAG: sub.rn.f32 [[FR1:%f[0-9]+]], [[FA1]], [[FB1]];
-; SM80-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[FR0]];
-; SM80-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[FR1]];
-; SM80: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]};
+; SM80-DAG: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[FR0]], [[FR1]];
; CHECK: st.param.b32 [func_retval0], [[R]];
; CHECK: ret;
@@ -93,9 +89,7 @@ define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM80-DAG: cvt.f32.bf16 [[FB1:%f[0-9]+]], [[B1]];
; SM80-DAG: mul.rn.f32 [[FR0:%f[0-9]+]], [[FA0]], [[FB0]];
; SM80-DAG: mul.rn.f32 [[FR1:%f[0-9]+]], [[FA1]], [[FB1]];
-; SM80-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[FR0]];
-; SM80-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[FR1]];
-; SM80: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]};
+; SM80-DAG: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[FR0]], [[FR1]];
; CHECK: st.param.b32 [func_retval0], [[R]];
; CHECK: ret;
@@ -116,9 +110,7 @@ define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; CHECK-DAG: cvt.f32.bf16 [[FB1:%f[0-9]+]], [[B1]];
; CHECK-DAG: div.rn.f32 [[FR0:%f[0-9]+]], [[FA0]], [[FB0]];
; CHECK-DAG: div.rn.f32 [[FR1:%f[0-9]+]], [[FA1]], [[FB1]];
-; CHECK-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[FR0]];
-; CHECK-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[FR1]];
-; CHECK-NEXT: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; CHECK: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[FR0]], [[FR1]];
; CHECK-NEXT: st.param.b32 [func_retval0], [[R]];
; CHECK-NEXT: ret;
@@ -287,9 +279,7 @@ define <2 x bfloat> @test_select_cc_bf16_f32(<2 x bfloat> %a, <2 x bfloat> %b,
; CHECK-LABEL: test_fptrunc_2xfloat(
; CHECK: ld.param.v2.f32 {[[A0:%f[0-9]+]], [[A1:%f[0-9]+]]}, [test_fptrunc_2xfloat_param_0];
-; CHECK-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[A0]];
-; CHECK-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[A1]];
-; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; CHECK: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[A0]], [[A1]];
; CHECK: st.param.b32 [func_retval0], [[R]];
; CHECK: ret;
define <2 x bfloat> @test_fptrunc_2xfloat(<2 x float> %a) #0 {
@@ -359,9 +349,7 @@ declare <2 x bfloat> @llvm.fmuladd.f16(<2 x bfloat> %a, <2 x bfloat> %b, <2 x bf
; CHECK-DAG: cvt.f32.bf16 [[AF1:%f[0-9]+]], [[A1]];
; CHECK-DAG: sqrt.rn.f32 [[RF0:%f[0-9]+]], [[AF0]];
; CHECK-DAG: sqrt.rn.f32 [[RF1:%f[0-9]+]], [[AF1]];
-; CHECK-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[RF0]];
-; CHECK-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[RF1]];
-; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; CHECK-DAG: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[RF0]], [[RF1]];
; CHECK: st.param.b32 [func_retval0], [[R]];
; CHECK: ret;
define <2 x bfloat> @test_sqrt(<2 x bfloat> %a) #0 {
@@ -436,9 +424,7 @@ define <2 x bfloat> @test_maxnum(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
; SM80-DAG: cvt.f32.bf16 [[FA1:%f[0-9]+]], [[A1]];
; SM80-DAG: cvt.rmi.f32.f32 [[RF0:%f[0-9]+]], [[FA0]];
; SM80-DAG: cvt.rmi.f32.f32 [[RF1:%f[0-9]+]], [[FA1]];
-; SM80-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[RF0]];
-; SM80-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[RF1]];
-; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; SM80: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[RF0]], [[RF1]];
; CHECK: st.param.b32 [func_retval0], [[R]];
; CHECK: ret;
define <2 x bfloat> @test_floor(<2 x bfloat> %a) #0 {
@@ -455,9 +441,7 @@ define <2 x bfloat> @test_floor(<2 x bfloat> %a) #0 {
; SM80-DAG: cvt.f32.bf16 [[FA1:%f[0-9]+]], [[A1]];
; SM80-DAG: cvt.rpi.f32.f32 [[RF0:%f[0-9]+]], [[FA0]];
; SM80-DAG: cvt.rpi.f32.f32 [[RF1:%f[0-9]+]], [[FA1]];
-; SM80-DAG: cvt.rn.bf16.f32 [[R0:%rs[0-9]+]], [[RF0]];
-; SM80-DAG: cvt.rn.bf16.f32 [[R1:%rs[0-9]+]], [[RF1]];
-; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; SM80: cvt.rn.bf16x2.f32 [[R:%r[0-9]+]], [[RF0]], [[RF1]];
; CHECK: st.param.b32 [func_retval0], [[R]];
; CHECK: ret;
define <2 x bfloat> @test_ceil(<2 x bfloat> %a) #0 {
@@ -470,7 +454,7 @@ define <2 x bfloat> @test_ceil(<2 x bfloat> %a) #0 {
; CHECK-DAG: mov.b32 {[[A0:%rs[0-9]+]], [[A1:%rs[0-9]+]]}, [[A]];
; SM90: cvt.rzi.bf16.bf16 [[R1:%rs[0-9]+]], [[A1]];
; SM90: cvt.rzi.bf16.bf16 [[R0:%rs[0-9]+]], [[A0]];
-; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; SM90: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}
; CHECK: st.param.b32 [func_retval0], [[R]];
; CHECK: ret;
define <2 x bfloat> @test_trunc(<2 x bfloat> %a) #0 {
@@ -483,7 +467,7 @@ define <2 x bfloat> @test_trunc(<2 x bfloat> %a) #0 {
; CHECK-DAG: mov.b32 {[[A0:%rs[0-9]+]], [[A1:%rs[0-9]+]]}, [[A]];
; SM90: cvt.rni.bf16.bf16 [[R1:%rs[0-9]+]], [[A1]];
; SM90: cvt.rni.bf16.bf16 [[R0:%rs[0-9]+]], [[A0]];
-; CHECK: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}
+; SM90: mov.b32 [[R:%r[0-9]+]], {[[R0]], [[R1]]}
; CHECK: st.param.b32 [func_retval0], [[R]];
; CHECK: ret;
define <2 x bfloat> @test_rint(<2 x bfloat> %a) #0 {
diff --git a/llvm/test/CodeGen/NVPTX/convert-sm80.ll b/llvm/test/CodeGen/NVPTX/convert-sm80.ll
index 4e30cebfe90251..417a3b72f6fde3 100644
--- a/llvm/test/CodeGen/NVPTX/convert-sm80.ll
+++ b/llvm/test/CodeGen/NVPTX/convert-sm80.ll
@@ -135,3 +135,24 @@ ret i32 %val
}
declare i32 @llvm.nvvm.f2tf32.rna(float)
+
+
+define <2 x bfloat> @fold_ff2bf16x2(float %a, float %b) {
+; CHECK-LABEL: fold_ff2bf16x2
+; CHECK: cvt.rn.bf16x2.f32
+ %ah = fptrunc float %a to bfloat
+ %bh = fptrunc float %b to bfloat
+ %v0 = insertelement <2 x bfloat> poison, bfloat %ah, i64 0
+ %v1 = insertelement <2 x bfloat> %v0, bfloat %bh, i64 1
+ ret <2 x bfloat> %v1
+}
+
+define <2 x half> @fold_ff2f16x2(float %a, float %b) {
+; CHECK-LABEL: fold_ff2f16x2
+; CHECK: cvt.rn.f16x2.f32
+ %ah = fptrunc float %a to half
+ %bh = fptrunc float %b to half
+ %v0 = insertelement <2 x half> poison, half %ah, i64 0
+ %v1 = insertelement <2 x half> %v0, half %bh, i64 1
+ ret <2 x half> %v1
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice. LGTM modulo test nit.
|
||
|
||
define <2 x bfloat> @fold_ff2bf16x2(float %a, float %b) { | ||
; CHECK-LABEL: fold_ff2bf16x2 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We may as well automate check generation in this test file, too. For vector construction, theorder of inputs is important and we're not capturing it in this file right now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good, I've updated the test file.
This change is breaking triton tests (results in huge numeric disparities, e.g. https://github.com/triton-lang/triton/blob/main/python/test/unit/language/test_core.py), we'll need to revert until a fix forward can be merged. |
This reverts commit 90cbd4a.
Reverts #116109 This change is breaking triton tests (results in huge numeric disparities, e.g. https://github.com/triton-lang/triton/blob/main/python/test/unit/language/test_core.py), we'll need to revert until a fix forward can be merged.
Thank you for identifying and reverting this @tomnatan30! |
Apologies for this issue. #116417 should fix. @tomnatan30 please review if you can. |
Reland #116109. Fixes issue where operands were flipped. Per the PTX spec, a mov instruction packs the first operand as low, and the second operand as high: > ``` > // pack two 16-bit elements into .b32 > d = a.x | (a.y << 16) > ``` On the other hand cvt.rn.f16x2.f32 instructions take high, than low operands: > For .f16x2 and .bf16x2 instruction type, two inputs a and b of .f32 type are converted into .f16 or .bf16 type and the converted values are packed in the destination register d, such that the value converted from input a is stored in the upper half of d and the value converted from input b is stored in the lower half of d
No description provided.