Skip to content

[DAGCombine] Add DAG optimisation for BF16_TO_FP #69426

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Dec 27, 2023

Conversation

sunshaoce
Copy link
Contributor

Before

slli a0, a0, 48
srli a0, a0, 48
slli a0, a0, 16

After

slli a0, a0, 16

@llvmbot llvmbot added the llvm:SelectionDAG SelectionDAGISel as well label Oct 18, 2023
@llvmbot
Copy link
Member

llvmbot commented Oct 18, 2023

@llvm/pr-subscribers-llvm-selectiondag

Author: Shao-Ce SUN (sunshaoce)

Changes

Before

slli a0, a0, 48
srli a0, a0, 48
slli a0, a0, 16

After

slli a0, a0, 16

Full diff: https://github.com/llvm/llvm-project/pull/69426.diff

3 Files Affected:

  • (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+19)
  • (modified) llvm/test/CodeGen/RISCV/bfloat-convert.ll (-42)
  • (modified) llvm/test/CodeGen/RISCV/bfloat.ll (+4-28)
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 2dfdddad3cc389f..1e9d2176befd93b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -545,6 +545,7 @@ namespace {
     SDValue visitFP_TO_FP16(SDNode *N);
     SDValue visitFP16_TO_FP(SDNode *N);
     SDValue visitFP_TO_BF16(SDNode *N);
+    SDValue visitBF16_TO_FP(SDNode *N);
     SDValue visitVECREDUCE(SDNode *N);
     SDValue visitVPOp(SDNode *N);
     SDValue visitGET_FPENV_MEM(SDNode *N);
@@ -1912,6 +1913,7 @@ void DAGCombiner::Run(CombineLevel AtLevel) {
 
 SDValue DAGCombiner::visit(SDNode *N) {
   switch (N->getOpcode()) {
+    // clang-format off
   default: break;
   case ISD::TokenFactor:        return visitTokenFactor(N);
   case ISD::MERGE_VALUES:       return visitMERGE_VALUES(N);
@@ -2043,6 +2045,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
   case ISD::FP_TO_FP16:         return visitFP_TO_FP16(N);
   case ISD::FP16_TO_FP:         return visitFP16_TO_FP(N);
   case ISD::FP_TO_BF16:         return visitFP_TO_BF16(N);
+  case ISD::BF16_TO_FP:         return visitBF16_TO_FP(N);
   case ISD::FREEZE:             return visitFREEZE(N);
   case ISD::GET_FPENV_MEM:      return visitGET_FPENV_MEM(N);
   case ISD::SET_FPENV_MEM:      return visitSET_FPENV_MEM(N);
@@ -2064,6 +2067,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
 #define BEGIN_REGISTER_VP_SDNODE(SDOPC, ...) case ISD::SDOPC:
 #include "llvm/IR/VPIntrinsics.def"
     return visitVPOp(N);
+    // clang-format on
   }
   return SDValue();
 }
@@ -26219,6 +26223,21 @@ SDValue DAGCombiner::visitFP_TO_BF16(SDNode *N) {
   return SDValue();
 }
 
+SDValue DAGCombiner::visitBF16_TO_FP(SDNode *N) {
+  SDValue N0 = N->getOperand(0);
+
+  // fold bf16_to_fp(op & 0xffff) -> bf16_to_fp(op)
+  if (!TLI.shouldKeepZExtForFP16Conv() && N0->getOpcode() == ISD::AND) {
+    ConstantSDNode *AndConst = getAsNonOpaqueConstant(N0.getOperand(1));
+    if (AndConst && AndConst->getAPIntValue() == 0xffff) {
+      return DAG.getNode(ISD::BF16_TO_FP, SDLoc(N), N->getValueType(0),
+                         N0.getOperand(0));
+    }
+  }
+
+  return SDValue();
+}
+
 SDValue DAGCombiner::visitVECREDUCE(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   EVT VT = N0.getValueType();
diff --git a/llvm/test/CodeGen/RISCV/bfloat-convert.ll b/llvm/test/CodeGen/RISCV/bfloat-convert.ll
index 8a0c4240d161bfb..bfa2c3bb4a8ba66 100644
--- a/llvm/test/CodeGen/RISCV/bfloat-convert.ll
+++ b/llvm/test/CodeGen/RISCV/bfloat-convert.ll
@@ -39,8 +39,6 @@ define i16 @fcvt_si_bf16(bfloat %a) nounwind {
 ; RV64ID-LABEL: fcvt_si_bf16:
 ; RV64ID:       # %bb.0:
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
-; RV64ID-NEXT:    slli a0, a0, 48
-; RV64ID-NEXT:    srli a0, a0, 48
 ; RV64ID-NEXT:    slli a0, a0, 16
 ; RV64ID-NEXT:    fmv.w.x fa5, a0
 ; RV64ID-NEXT:    fcvt.l.s a0, fa5, rtz
@@ -100,8 +98,6 @@ define i16 @fcvt_si_bf16_sat(bfloat %a) nounwind {
 ; RV64ID-LABEL: fcvt_si_bf16_sat:
 ; RV64ID:       # %bb.0: # %start
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
-; RV64ID-NEXT:    slli a0, a0, 48
-; RV64ID-NEXT:    srli a0, a0, 48
 ; RV64ID-NEXT:    slli a0, a0, 16
 ; RV64ID-NEXT:    fmv.w.x fa5, a0
 ; RV64ID-NEXT:    feq.s a0, fa5, fa5
@@ -145,8 +141,6 @@ define i16 @fcvt_ui_bf16(bfloat %a) nounwind {
 ; RV64ID-LABEL: fcvt_ui_bf16:
 ; RV64ID:       # %bb.0:
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
-; RV64ID-NEXT:    slli a0, a0, 48
-; RV64ID-NEXT:    srli a0, a0, 48
 ; RV64ID-NEXT:    slli a0, a0, 16
 ; RV64ID-NEXT:    fmv.w.x fa5, a0
 ; RV64ID-NEXT:    fcvt.lu.s a0, fa5, rtz
@@ -196,8 +190,6 @@ define i16 @fcvt_ui_bf16_sat(bfloat %a) nounwind {
 ; RV64ID-NEXT:    lui a0, %hi(.LCPI3_0)
 ; RV64ID-NEXT:    flw fa5, %lo(.LCPI3_0)(a0)
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
-; RV64ID-NEXT:    slli a0, a0, 48
-; RV64ID-NEXT:    srli a0, a0, 48
 ; RV64ID-NEXT:    slli a0, a0, 16
 ; RV64ID-NEXT:    fmv.w.x fa4, a0
 ; RV64ID-NEXT:    fmv.w.x fa3, zero
@@ -235,8 +227,6 @@ define i32 @fcvt_w_bf16(bfloat %a) nounwind {
 ; RV64ID-LABEL: fcvt_w_bf16:
 ; RV64ID:       # %bb.0:
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
-; RV64ID-NEXT:    slli a0, a0, 48
-; RV64ID-NEXT:    srli a0, a0, 48
 ; RV64ID-NEXT:    slli a0, a0, 16
 ; RV64ID-NEXT:    fmv.w.x fa5, a0
 ; RV64ID-NEXT:    fcvt.l.s a0, fa5, rtz
@@ -281,8 +271,6 @@ define i32 @fcvt_w_bf16_sat(bfloat %a) nounwind {
 ; RV64ID-LABEL: fcvt_w_bf16_sat:
 ; RV64ID:       # %bb.0: # %start
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
-; RV64ID-NEXT:    slli a0, a0, 48
-; RV64ID-NEXT:    srli a0, a0, 48
 ; RV64ID-NEXT:    slli a0, a0, 16
 ; RV64ID-NEXT:    fmv.w.x fa5, a0
 ; RV64ID-NEXT:    fcvt.w.s a0, fa5, rtz
@@ -321,8 +309,6 @@ define i32 @fcvt_wu_bf16(bfloat %a) nounwind {
 ; RV64ID-LABEL: fcvt_wu_bf16:
 ; RV64ID:       # %bb.0:
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
-; RV64ID-NEXT:    slli a0, a0, 48
-; RV64ID-NEXT:    srli a0, a0, 48
 ; RV64ID-NEXT:    slli a0, a0, 16
 ; RV64ID-NEXT:    fmv.w.x fa5, a0
 ; RV64ID-NEXT:    fcvt.lu.s a0, fa5, rtz
@@ -361,8 +347,6 @@ define i32 @fcvt_wu_bf16_multiple_use(bfloat %x, ptr %y) nounwind {
 ; RV64ID-LABEL: fcvt_wu_bf16_multiple_use:
 ; RV64ID:       # %bb.0:
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
-; RV64ID-NEXT:    slli a0, a0, 48
-; RV64ID-NEXT:    srli a0, a0, 48
 ; RV64ID-NEXT:    slli a0, a0, 16
 ; RV64ID-NEXT:    fmv.w.x fa5, a0
 ; RV64ID-NEXT:    fcvt.lu.s a0, fa5, rtz
@@ -413,8 +397,6 @@ define i32 @fcvt_wu_bf16_sat(bfloat %a) nounwind {
 ; RV64ID-LABEL: fcvt_wu_bf16_sat:
 ; RV64ID:       # %bb.0: # %start
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
-; RV64ID-NEXT:    slli a0, a0, 48
-; RV64ID-NEXT:    srli a0, a0, 48
 ; RV64ID-NEXT:    slli a0, a0, 16
 ; RV64ID-NEXT:    fmv.w.x fa5, a0
 ; RV64ID-NEXT:    fcvt.wu.s a0, fa5, rtz
@@ -463,8 +445,6 @@ define i64 @fcvt_l_bf16(bfloat %a) nounwind {
 ; RV64ID-LABEL: fcvt_l_bf16:
 ; RV64ID:       # %bb.0:
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
-; RV64ID-NEXT:    slli a0, a0, 48
-; RV64ID-NEXT:    srli a0, a0, 48
 ; RV64ID-NEXT:    slli a0, a0, 16
 ; RV64ID-NEXT:    fmv.w.x fa5, a0
 ; RV64ID-NEXT:    fcvt.l.s a0, fa5, rtz
@@ -606,8 +586,6 @@ define i64 @fcvt_l_bf16_sat(bfloat %a) nounwind {
 ; RV64ID-LABEL: fcvt_l_bf16_sat:
 ; RV64ID:       # %bb.0: # %start
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
-; RV64ID-NEXT:    slli a0, a0, 48
-; RV64ID-NEXT:    srli a0, a0, 48
 ; RV64ID-NEXT:    slli a0, a0, 16
 ; RV64ID-NEXT:    fmv.w.x fa5, a0
 ; RV64ID-NEXT:    fcvt.l.s a0, fa5, rtz
@@ -654,8 +632,6 @@ define i64 @fcvt_lu_bf16(bfloat %a) nounwind {
 ; RV64ID-LABEL: fcvt_lu_bf16:
 ; RV64ID:       # %bb.0:
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
-; RV64ID-NEXT:    slli a0, a0, 48
-; RV64ID-NEXT:    srli a0, a0, 48
 ; RV64ID-NEXT:    slli a0, a0, 16
 ; RV64ID-NEXT:    fmv.w.x fa5, a0
 ; RV64ID-NEXT:    fcvt.lu.s a0, fa5, rtz
@@ -730,8 +706,6 @@ define i64 @fcvt_lu_bf16_sat(bfloat %a) nounwind {
 ; RV64ID-LABEL: fcvt_lu_bf16_sat:
 ; RV64ID:       # %bb.0: # %start
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
-; RV64ID-NEXT:    slli a0, a0, 48
-; RV64ID-NEXT:    srli a0, a0, 48
 ; RV64ID-NEXT:    slli a0, a0, 16
 ; RV64ID-NEXT:    fmv.w.x fa5, a0
 ; RV64ID-NEXT:    fcvt.lu.s a0, fa5, rtz
@@ -1200,8 +1174,6 @@ define float @fcvt_s_bf16(bfloat %a) nounwind {
 ; RV64ID-LABEL: fcvt_s_bf16:
 ; RV64ID:       # %bb.0:
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
-; RV64ID-NEXT:    slli a0, a0, 48
-; RV64ID-NEXT:    srli a0, a0, 48
 ; RV64ID-NEXT:    slli a0, a0, 16
 ; RV64ID-NEXT:    fmv.w.x fa0, a0
 ; RV64ID-NEXT:    ret
@@ -1313,8 +1285,6 @@ define double @fcvt_d_bf16(bfloat %a) nounwind {
 ; RV64ID-LABEL: fcvt_d_bf16:
 ; RV64ID:       # %bb.0:
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
-; RV64ID-NEXT:    slli a0, a0, 48
-; RV64ID-NEXT:    srli a0, a0, 48
 ; RV64ID-NEXT:    slli a0, a0, 16
 ; RV64ID-NEXT:    fmv.w.x fa5, a0
 ; RV64ID-NEXT:    fcvt.d.s fa0, fa5
@@ -1521,8 +1491,6 @@ define signext i8 @fcvt_w_s_i8(bfloat %a) nounwind {
 ; RV64ID-LABEL: fcvt_w_s_i8:
 ; RV64ID:       # %bb.0:
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
-; RV64ID-NEXT:    slli a0, a0, 48
-; RV64ID-NEXT:    srli a0, a0, 48
 ; RV64ID-NEXT:    slli a0, a0, 16
 ; RV64ID-NEXT:    fmv.w.x fa5, a0
 ; RV64ID-NEXT:    fcvt.l.s a0, fa5, rtz
@@ -1582,8 +1550,6 @@ define signext i8 @fcvt_w_s_sat_i8(bfloat %a) nounwind {
 ; RV64ID-LABEL: fcvt_w_s_sat_i8:
 ; RV64ID:       # %bb.0: # %start
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
-; RV64ID-NEXT:    slli a0, a0, 48
-; RV64ID-NEXT:    srli a0, a0, 48
 ; RV64ID-NEXT:    slli a0, a0, 16
 ; RV64ID-NEXT:    fmv.w.x fa5, a0
 ; RV64ID-NEXT:    feq.s a0, fa5, fa5
@@ -1627,8 +1593,6 @@ define zeroext i8 @fcvt_wu_s_i8(bfloat %a) nounwind {
 ; RV64ID-LABEL: fcvt_wu_s_i8:
 ; RV64ID:       # %bb.0:
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
-; RV64ID-NEXT:    slli a0, a0, 48
-; RV64ID-NEXT:    srli a0, a0, 48
 ; RV64ID-NEXT:    slli a0, a0, 16
 ; RV64ID-NEXT:    fmv.w.x fa5, a0
 ; RV64ID-NEXT:    fcvt.lu.s a0, fa5, rtz
@@ -1676,8 +1640,6 @@ define zeroext i8 @fcvt_wu_s_sat_i8(bfloat %a) nounwind {
 ; RV64ID-LABEL: fcvt_wu_s_sat_i8:
 ; RV64ID:       # %bb.0: # %start
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
-; RV64ID-NEXT:    slli a0, a0, 48
-; RV64ID-NEXT:    srli a0, a0, 48
 ; RV64ID-NEXT:    slli a0, a0, 16
 ; RV64ID-NEXT:    fmv.w.x fa5, a0
 ; RV64ID-NEXT:    fmv.w.x fa4, zero
@@ -1731,8 +1693,6 @@ define zeroext i32 @fcvt_wu_bf16_sat_zext(bfloat %a) nounwind {
 ; RV64ID-LABEL: fcvt_wu_bf16_sat_zext:
 ; RV64ID:       # %bb.0: # %start
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
-; RV64ID-NEXT:    slli a0, a0, 48
-; RV64ID-NEXT:    srli a0, a0, 48
 ; RV64ID-NEXT:    slli a0, a0, 16
 ; RV64ID-NEXT:    fmv.w.x fa5, a0
 ; RV64ID-NEXT:    fcvt.wu.s a0, fa5, rtz
@@ -1784,8 +1744,6 @@ define signext i32 @fcvt_w_bf16_sat_sext(bfloat %a) nounwind {
 ; RV64ID-LABEL: fcvt_w_bf16_sat_sext:
 ; RV64ID:       # %bb.0: # %start
 ; RV64ID-NEXT:    fmv.x.w a0, fa0
-; RV64ID-NEXT:    slli a0, a0, 48
-; RV64ID-NEXT:    srli a0, a0, 48
 ; RV64ID-NEXT:    slli a0, a0, 16
 ; RV64ID-NEXT:    fmv.w.x fa5, a0
 ; RV64ID-NEXT:    fcvt.w.s a0, fa5, rtz
diff --git a/llvm/test/CodeGen/RISCV/bfloat.ll b/llvm/test/CodeGen/RISCV/bfloat.ll
index 5013f76f9b0b33a..d62f35388123f7c 100644
--- a/llvm/test/CodeGen/RISCV/bfloat.ll
+++ b/llvm/test/CodeGen/RISCV/bfloat.ll
@@ -164,8 +164,6 @@ define float @bfloat_to_float(bfloat %a) nounwind {
 ;
 ; RV64ID-LP64-LABEL: bfloat_to_float:
 ; RV64ID-LP64:       # %bb.0:
-; RV64ID-LP64-NEXT:    slli a0, a0, 48
-; RV64ID-LP64-NEXT:    srli a0, a0, 48
 ; RV64ID-LP64-NEXT:    slli a0, a0, 16
 ; RV64ID-LP64-NEXT:    ret
 ;
@@ -179,8 +177,6 @@ define float @bfloat_to_float(bfloat %a) nounwind {
 ; RV64ID-LP64D-LABEL: bfloat_to_float:
 ; RV64ID-LP64D:       # %bb.0:
 ; RV64ID-LP64D-NEXT:    fmv.x.w a0, fa0
-; RV64ID-LP64D-NEXT:    slli a0, a0, 48
-; RV64ID-LP64D-NEXT:    srli a0, a0, 48
 ; RV64ID-LP64D-NEXT:    slli a0, a0, 16
 ; RV64ID-LP64D-NEXT:    fmv.w.x fa0, a0
 ; RV64ID-LP64D-NEXT:    ret
@@ -223,8 +219,6 @@ define double @bfloat_to_double(bfloat %a) nounwind {
 ;
 ; RV64ID-LP64-LABEL: bfloat_to_double:
 ; RV64ID-LP64:       # %bb.0:
-; RV64ID-LP64-NEXT:    slli a0, a0, 48
-; RV64ID-LP64-NEXT:    srli a0, a0, 48
 ; RV64ID-LP64-NEXT:    slli a0, a0, 16
 ; RV64ID-LP64-NEXT:    fmv.w.x fa5, a0
 ; RV64ID-LP64-NEXT:    fcvt.d.s fa5, fa5
@@ -242,8 +236,6 @@ define double @bfloat_to_double(bfloat %a) nounwind {
 ; RV64ID-LP64D-LABEL: bfloat_to_double:
 ; RV64ID-LP64D:       # %bb.0:
 ; RV64ID-LP64D-NEXT:    fmv.x.w a0, fa0
-; RV64ID-LP64D-NEXT:    slli a0, a0, 48
-; RV64ID-LP64D-NEXT:    srli a0, a0, 48
 ; RV64ID-LP64D-NEXT:    slli a0, a0, 16
 ; RV64ID-LP64D-NEXT:    fmv.w.x fa5, a0
 ; RV64ID-LP64D-NEXT:    fcvt.d.s fa0, fa5
@@ -366,10 +358,6 @@ define bfloat @bfloat_add(bfloat %a, bfloat %b) nounwind {
 ; RV64ID-LP64:       # %bb.0:
 ; RV64ID-LP64-NEXT:    addi sp, sp, -16
 ; RV64ID-LP64-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
-; RV64ID-LP64-NEXT:    lui a2, 16
-; RV64ID-LP64-NEXT:    addi a2, a2, -1
-; RV64ID-LP64-NEXT:    and a0, a0, a2
-; RV64ID-LP64-NEXT:    and a1, a1, a2
 ; RV64ID-LP64-NEXT:    slli a1, a1, 16
 ; RV64ID-LP64-NEXT:    fmv.w.x fa5, a1
 ; RV64ID-LP64-NEXT:    slli a0, a0, 16
@@ -408,11 +396,7 @@ define bfloat @bfloat_add(bfloat %a, bfloat %b) nounwind {
 ; RV64ID-LP64D-NEXT:    addi sp, sp, -16
 ; RV64ID-LP64D-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
 ; RV64ID-LP64D-NEXT:    fmv.x.w a0, fa0
-; RV64ID-LP64D-NEXT:    lui a1, 16
-; RV64ID-LP64D-NEXT:    addi a1, a1, -1
-; RV64ID-LP64D-NEXT:    and a0, a0, a1
-; RV64ID-LP64D-NEXT:    fmv.x.w a2, fa1
-; RV64ID-LP64D-NEXT:    and a1, a2, a1
+; RV64ID-LP64D-NEXT:    fmv.x.w a1, fa1
 ; RV64ID-LP64D-NEXT:    slli a1, a1, 16
 ; RV64ID-LP64D-NEXT:    fmv.w.x fa5, a1
 ; RV64ID-LP64D-NEXT:    slli a0, a0, 16
@@ -604,12 +588,8 @@ define void @bfloat_store(ptr %a, bfloat %b, bfloat %c) nounwind {
 ; RV64ID-LP64-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
 ; RV64ID-LP64-NEXT:    sd s0, 0(sp) # 8-byte Folded Spill
 ; RV64ID-LP64-NEXT:    mv s0, a0
-; RV64ID-LP64-NEXT:    lui a0, 16
-; RV64ID-LP64-NEXT:    addi a0, a0, -1
-; RV64ID-LP64-NEXT:    and a1, a1, a0
-; RV64ID-LP64-NEXT:    and a0, a2, a0
-; RV64ID-LP64-NEXT:    slli a0, a0, 16
-; RV64ID-LP64-NEXT:    fmv.w.x fa5, a0
+; RV64ID-LP64-NEXT:    slli a2, a2, 16
+; RV64ID-LP64-NEXT:    fmv.w.x fa5, a2
 ; RV64ID-LP64-NEXT:    slli a1, a1, 16
 ; RV64ID-LP64-NEXT:    fmv.w.x fa4, a1
 ; RV64ID-LP64-NEXT:    fadd.s fa5, fa4, fa5
@@ -651,11 +631,7 @@ define void @bfloat_store(ptr %a, bfloat %b, bfloat %c) nounwind {
 ; RV64ID-LP64D-NEXT:    sd s0, 0(sp) # 8-byte Folded Spill
 ; RV64ID-LP64D-NEXT:    mv s0, a0
 ; RV64ID-LP64D-NEXT:    fmv.x.w a0, fa0
-; RV64ID-LP64D-NEXT:    lui a1, 16
-; RV64ID-LP64D-NEXT:    addi a1, a1, -1
-; RV64ID-LP64D-NEXT:    and a0, a0, a1
-; RV64ID-LP64D-NEXT:    fmv.x.w a2, fa1
-; RV64ID-LP64D-NEXT:    and a1, a2, a1
+; RV64ID-LP64D-NEXT:    fmv.x.w a1, fa1
 ; RV64ID-LP64D-NEXT:    slli a1, a1, 16
 ; RV64ID-LP64D-NEXT:    fmv.w.x fa5, a1
 ; RV64ID-LP64D-NEXT:    slli a0, a0, 16

@@ -26219,6 +26223,21 @@ SDValue DAGCombiner::visitFP_TO_BF16(SDNode *N) {
return SDValue();
}

SDValue DAGCombiner::visitBF16_TO_FP(SDNode *N) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Combine this with visitFP16_TO_FP above.

But couldn't SimplifyDemandedBits handle this optimization in a more generic way?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SimplifyDemandedBits should handle shl(lshr(x,c),c) mask patterns already - are there bitcasts in the way?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried SimplifyDemandedBits, but it didn't work, maybe there's something wrong with my code.

SDValue DAGCombiner::visitFP16_TO_FP(SDNode *N) {
  SDValue N0 = N->getOperand(0);
  APInt DemandedBits = APInt::getAllOnes(N0.getScalarValueSizeInBits());
  if (SimplifyDemandedBits(N0, DemandedBits)) {
    if (N->getOpcode() != ISD::DELETED_NODE)
      AddToWorklist(N);
    return SDValue(N, 0);
  }

  return SDValue();
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need to use APInt DemandedBits = APInt::getLoBitsSet(N0.getScalarValueSizeInBits(), 16); so that only the lower 16 bits are demanded.

@sunshaoce
Copy link
Contributor Author

ping

@@ -2043,6 +2044,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
case ISD::FP_TO_FP16: return visitFP_TO_FP16(N);
case ISD::FP16_TO_FP: return visitFP16_TO_FP(N);
case ISD::FP_TO_BF16: return visitFP_TO_BF16(N);
case ISD::BF16_TO_FP: return visitFP16_TO_FP(N);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should combine the cases with the other call to visitFP16_TO_FP. I did have to check the body to make sure this was OK, so should also rename the function to reflect it handles both cases

@sunshaoce sunshaoce merged commit 9f6bf00 into llvm:main Dec 27, 2023
@sunshaoce sunshaoce deleted the BF16_TO_FP branch December 27, 2023 09:20
@@ -26137,14 +26139,17 @@ SDValue DAGCombiner::visitFP_TO_FP16(SDNode *N) {
}

SDValue DAGCombiner::visitFP16_TO_FP(SDNode *N) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@arsenm didn't you request this function to be renamed?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
llvm:SelectionDAG SelectionDAGISel as well
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants