
[RISCV] Add FMA support to combineOp_VLToVWOp_VL. #100454


Merged: 1 commit into llvm:main on Jul 26, 2024
Conversation

@topperc (Collaborator) commented Jul 24, 2024

This adds FMA to the widening web support we have for add, sub, mul, and shl.

Extra care needs to be taken to not widen the third FMA operand.
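
To make the constraint concrete, here is a minimal standalone C++ model of the two rules this patch adds. The opcode names mirror the RISCVISD nodes in the diff below, but this is an illustrative sketch, not the LLVM implementation:

```cpp
#include <cassert>
#include <string>

// Mapping added to NodeExtensionHelper: each single-width VL FMA node
// has a widening counterpart (vfwmacc and friends).
std::string getWideningOpcode(const std::string &Opc) {
  if (Opc == "VFMADD_VL")  return "VFWMADD_VL";
  if (Opc == "VFMSUB_VL")  return "VFWMSUB_VL";
  if (Opc == "VFNMADD_VL") return "VFWNMADD_VL";
  if (Opc == "VFNMSUB_VL") return "VFWNMSUB_VL";
  assert(false && "Unexpected opcode");
  return "";
}

// Only the multiplicands (operand indices 0 and 1) of an FMA feed the
// multiply; the addend (index 2) is already at the wide element type,
// so it must never be replaced by a narrow pre-extension value.
bool operandMayBeNarrowed(unsigned OperandIdx) { return OperandIdx < 2; }
```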

@llvmbot (Member) commented Jul 24, 2024

@llvm/pr-subscribers-backend-risc-v

Author: Craig Topper (topperc)

Changes

This adds FMA to the widening web support we have for add, sub, mul, and shl.

Extra care needs to be taken to not widen the third FMA operand.


Full diff: https://github.com/llvm/llvm-project/pull/100454.diff

3 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+56-66)
  • (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfw-web-simplification.ll (+109)
  • (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmacc.ll (+4-16)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index d40d4997d7614..7a657e481d9b1 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -14328,6 +14328,14 @@ struct NodeExtensionHelper {
       return RISCVISD::VFWSUB_VL;
     case RISCVISD::FMUL_VL:
       return RISCVISD::VFWMUL_VL;
+    case RISCVISD::VFMADD_VL:
+      return RISCVISD::VFWMADD_VL;
+    case RISCVISD::VFMSUB_VL:
+      return RISCVISD::VFWMSUB_VL;
+    case RISCVISD::VFNMADD_VL:
+      return RISCVISD::VFWNMADD_VL;
+    case RISCVISD::VFNMSUB_VL:
+      return RISCVISD::VFWNMSUB_VL;
     default:
       llvm_unreachable("Unexpected opcode");
     }
@@ -14521,6 +14529,11 @@ struct NodeExtensionHelper {
              Subtarget.hasStdExtZvbb();
     case RISCVISD::SHL_VL:
       return Subtarget.hasStdExtZvbb();
+    case RISCVISD::VFMADD_VL:
+    case RISCVISD::VFNMSUB_VL:
+    case RISCVISD::VFNMADD_VL:
+    case RISCVISD::VFMSUB_VL:
+      return true;
     default:
       return false;
     }
@@ -14601,6 +14614,10 @@ struct NodeExtensionHelper {
     case RISCVISD::FADD_VL:
     case RISCVISD::FMUL_VL:
     case RISCVISD::VFWADD_W_VL:
+    case RISCVISD::VFMADD_VL:
+    case RISCVISD::VFNMSUB_VL:
+    case RISCVISD::VFNMADD_VL:
+    case RISCVISD::VFMSUB_VL:
       return true;
     case ISD::SUB:
     case RISCVISD::SUB_VL:
@@ -14816,6 +14833,10 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
     Strategies.push_back(canFoldToVW_W);
     break;
   case RISCVISD::FMUL_VL:
+  case RISCVISD::VFMADD_VL:
+  case RISCVISD::VFMSUB_VL:
+  case RISCVISD::VFNMADD_VL:
+  case RISCVISD::VFNMSUB_VL:
     Strategies.push_back(canFoldToVWWithSameExtension);
     break;
   case ISD::MUL:
@@ -14852,7 +14873,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
 }
 } // End anonymous namespace.
 
-/// Combine a binary operation to its equivalent VW or VW_W form.
+/// Combine a binary or FMA operation to its equivalent VW or VW_W form.
 /// The supported combines are:
 /// add | add_vl | or disjoint -> vwadd(u) | vwadd(u)_w
 /// sub | sub_vl -> vwsub(u) | vwsub(u)_w
@@ -14865,9 +14886,9 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
 /// vwsub_w(u) -> vwsub(u)
 /// vfwadd_w -> vfwadd
 /// vfwsub_w -> vfwsub
-static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,
-                                           TargetLowering::DAGCombinerInfo &DCI,
-                                           const RISCVSubtarget &Subtarget) {
+static SDValue combineOp_VLToVWOp_VL(SDNode *N,
+                                     TargetLowering::DAGCombinerInfo &DCI,
+                                     const RISCVSubtarget &Subtarget) {
   SelectionDAG &DAG = DCI.DAG;
   if (DCI.isBeforeLegalize())
     return SDValue();
@@ -14883,19 +14904,26 @@ static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,
 
   while (!Worklist.empty()) {
     SDNode *Root = Worklist.pop_back_val();
-    if (!NodeExtensionHelper::isSupportedRoot(Root, Subtarget))
-      return SDValue();
 
     NodeExtensionHelper LHS(Root, 0, DAG, Subtarget);
     NodeExtensionHelper RHS(Root, 1, DAG, Subtarget);
-    auto AppendUsersIfNeeded = [&Worklist,
+    auto AppendUsersIfNeeded = [&Worklist, &Subtarget,
                                 &Inserted](const NodeExtensionHelper &Op) {
       if (Op.needToPromoteOtherUsers()) {
-        for (SDNode *TheUse : Op.OrigOperand->uses()) {
+        for (SDNode::use_iterator UI = Op.OrigOperand->use_begin(),
+                                  UE = Op.OrigOperand->use_end();
+             UI != UE; ++UI) {
+          SDNode *TheUse = *UI;
+          if (!NodeExtensionHelper::isSupportedRoot(TheUse, Subtarget))
+            return false;
+          // We only support the first 2 operands of FMA.
+          if (UI.getOperandNo() >= 2)
+            return false;
           if (Inserted.insert(TheUse).second)
             Worklist.push_back(TheUse);
         }
       }
+      return true;
     };
 
    // Control the compile time by limiting the number of nodes we look at in
@@ -14923,9 +14951,11 @@ static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,
          // we would be leaving the old input (since it may still be used),
           // and the new one.
           if (Res->LHSExt.has_value())
-            AppendUsersIfNeeded(LHS);
+            if (!AppendUsersIfNeeded(LHS))
+              return SDValue();
           if (Res->RHSExt.has_value())
-            AppendUsersIfNeeded(RHS);
+            if (!AppendUsersIfNeeded(RHS))
+              return SDValue();
           break;
         }
       }
@@ -15012,7 +15042,7 @@ static SDValue performVWADDSUBW_VLCombine(SDNode *N,
   assert(Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWADDU_W_VL ||
          Opc == RISCVISD::VWSUB_W_VL || Opc == RISCVISD::VWSUBU_W_VL);
 
-  if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
+  if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
     return V;
 
   return combineVWADDSUBWSelect(N, DCI.DAG);
@@ -15427,8 +15457,11 @@ static SDValue combineVFMADD_VLWithVFNEG_VL(SDNode *N, SelectionDAG &DAG) {
                      VL);
 }
 
-static SDValue performVFMADD_VLCombine(SDNode *N, SelectionDAG &DAG,
+static SDValue performVFMADD_VLCombine(SDNode *N,
+                                       TargetLowering::DAGCombinerInfo &DCI,
                                        const RISCVSubtarget &Subtarget) {
+  SelectionDAG &DAG = DCI.DAG;
+
   if (SDValue V = combineVFMADD_VLWithVFNEG_VL(N, DAG))
     return V;
 
@@ -15440,50 +15473,7 @@ static SDValue performVFMADD_VLCombine(SDNode *N, SelectionDAG &DAG,
   if (N->isTargetStrictFPOpcode())
     return SDValue();
 
-  // Try to form widening FMA.
-  SDValue Op0 = N->getOperand(0);
-  SDValue Op1 = N->getOperand(1);
-  SDValue Mask = N->getOperand(3);
-  SDValue VL = N->getOperand(4);
-
-  if (Op0.getOpcode() != RISCVISD::FP_EXTEND_VL ||
-      Op1.getOpcode() != RISCVISD::FP_EXTEND_VL)
-    return SDValue();
-
-  // TODO: Refactor to handle more complex cases similar to
-  // combineBinOp_VLToVWBinOp_VL.
-  if ((!Op0.hasOneUse() || !Op1.hasOneUse()) &&
-      (Op0 != Op1 || !Op0->hasNUsesOfValue(2, 0)))
-    return SDValue();
-
-  // Check the mask and VL are the same.
-  if (Op0.getOperand(1) != Mask || Op0.getOperand(2) != VL ||
-      Op1.getOperand(1) != Mask || Op1.getOperand(2) != VL)
-    return SDValue();
-
-  unsigned NewOpc;
-  switch (N->getOpcode()) {
-  default:
-    llvm_unreachable("Unexpected opcode");
-  case RISCVISD::VFMADD_VL:
-    NewOpc = RISCVISD::VFWMADD_VL;
-    break;
-  case RISCVISD::VFNMSUB_VL:
-    NewOpc = RISCVISD::VFWNMSUB_VL;
-    break;
-  case RISCVISD::VFNMADD_VL:
-    NewOpc = RISCVISD::VFWNMADD_VL;
-    break;
-  case RISCVISD::VFMSUB_VL:
-    NewOpc = RISCVISD::VFWMSUB_VL;
-    break;
-  }
-
-  Op0 = Op0.getOperand(0);
-  Op1 = Op1.getOperand(0);
-
-  return DAG.getNode(NewOpc, SDLoc(N), N->getValueType(0), Op0, Op1,
-                     N->getOperand(2), Mask, VL);
+  return combineOp_VLToVWOp_VL(N, DCI, Subtarget);
 }
 
 static SDValue performSRACombine(SDNode *N, SelectionDAG &DAG,
@@ -16680,28 +16670,28 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     break;
   }
   case ISD::ADD: {
-    if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
+    if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
       return V;
     if (SDValue V = combineToVWMACC(N, DAG, Subtarget))
       return V;
     return performADDCombine(N, DCI, Subtarget);
   }
   case ISD::SUB: {
-    if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
+    if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
       return V;
     return performSUBCombine(N, DAG, Subtarget);
   }
   case ISD::AND:
     return performANDCombine(N, DCI, Subtarget);
   case ISD::OR: {
-    if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
+    if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
       return V;
     return performORCombine(N, DCI, Subtarget);
   }
   case ISD::XOR:
     return performXORCombine(N, DAG, Subtarget);
   case ISD::MUL:
-    if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
+    if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
       return V;
     return performMULCombine(N, DAG, DCI, Subtarget);
   case ISD::SDIV:
@@ -17126,7 +17116,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     break;
   }
   case RISCVISD::SHL_VL:
-    if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
+    if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
       return V;
     [[fallthrough]];
   case RISCVISD::SRA_VL:
@@ -17151,7 +17141,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::SRL:
   case ISD::SHL: {
     if (N->getOpcode() == ISD::SHL) {
-      if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
+      if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
         return V;
     }
     SDValue ShAmt = N->getOperand(1);
@@ -17167,7 +17157,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     break;
   }
   case RISCVISD::ADD_VL:
-    if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
+    if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
       return V;
     return combineToVWMACC(N, DAG, Subtarget);
   case RISCVISD::VWADD_W_VL:
@@ -17177,7 +17167,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     return performVWADDSUBW_VLCombine(N, DCI, Subtarget);
   case RISCVISD::SUB_VL:
   case RISCVISD::MUL_VL:
-    return combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget);
+    return combineOp_VLToVWOp_VL(N, DCI, Subtarget);
   case RISCVISD::VFMADD_VL:
   case RISCVISD::VFNMADD_VL:
   case RISCVISD::VFMSUB_VL:
@@ -17186,7 +17176,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
   case RISCVISD::STRICT_VFNMADD_VL:
   case RISCVISD::STRICT_VFMSUB_VL:
   case RISCVISD::STRICT_VFNMSUB_VL:
-    return performVFMADD_VLCombine(N, DAG, Subtarget);
+    return performVFMADD_VLCombine(N, DCI, Subtarget);
   case RISCVISD::FADD_VL:
   case RISCVISD::FSUB_VL:
   case RISCVISD::FMUL_VL:
@@ -17195,7 +17185,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     if (N->getValueType(0).getVectorElementType() == MVT::f32 &&
         !Subtarget.hasVInstructionsF16())
       return SDValue();
-    return combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget);
+    return combineOp_VLToVWOp_VL(N, DCI, Subtarget);
   }
   case ISD::LOAD:
   case ISD::STORE: {
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfw-web-simplification.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfw-web-simplification.ll
index 3a99f535e9071..cb50ca4a72120 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfw-web-simplification.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfw-web-simplification.ll
@@ -97,3 +97,112 @@ define void @vfwmul_v2f32_multiple_users(ptr %x, ptr %y, ptr %z, <2 x float> %a,
   store <2 x double> %g, ptr %z
   ret void
 }
+
+define void @vfwmacc_v2f32_multiple_users(ptr %x, ptr %y, ptr %z, <2 x float> %a, <2 x float> %b, <2 x float> %b2, <2 x double> %w) {
+; NO_FOLDING-LABEL: vfwmacc_v2f32_multiple_users:
+; NO_FOLDING:       # %bb.0:
+; NO_FOLDING-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v12, v8
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v8, v9
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v9, v10
+; NO_FOLDING-NEXT:    vsetvli zero, zero, e64, m1, ta, ma
+; NO_FOLDING-NEXT:    vfmul.vv v10, v12, v8
+; NO_FOLDING-NEXT:    vfmadd.vv v12, v9, v11
+; NO_FOLDING-NEXT:    vfsub.vv v8, v8, v9
+; NO_FOLDING-NEXT:    vse64.v v10, (a0)
+; NO_FOLDING-NEXT:    vse64.v v12, (a1)
+; NO_FOLDING-NEXT:    vse64.v v8, (a2)
+; NO_FOLDING-NEXT:    ret
+;
+; FOLDING-LABEL: vfwmacc_v2f32_multiple_users:
+; FOLDING:       # %bb.0:
+; FOLDING-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; FOLDING-NEXT:    vfwmul.vv v12, v8, v9
+; FOLDING-NEXT:    vfwmacc.vv v11, v8, v10
+; FOLDING-NEXT:    vfwsub.vv v8, v9, v10
+; FOLDING-NEXT:    vse64.v v12, (a0)
+; FOLDING-NEXT:    vse64.v v11, (a1)
+; FOLDING-NEXT:    vse64.v v8, (a2)
+; FOLDING-NEXT:    ret
+  %c = fpext <2 x float> %a to <2 x double>
+  %d = fpext <2 x float> %b to <2 x double>
+  %d2 = fpext <2 x float> %b2 to <2 x double>
+  %e = fmul <2 x double> %c, %d
+  %f = call <2 x double> @llvm.fma(<2 x double> %c, <2 x double> %d2, <2 x double> %w)
+  %g = fsub <2 x double> %d, %d2
+  store <2 x double> %e, ptr %x
+  store <2 x double> %f, ptr %y
+  store <2 x double> %g, ptr %z
+  ret void
+}
+
+; Negative test. We can't fold because the FMA addend is a user.
+define void @vfwmacc_v2f32_multiple_users_addend_user(ptr %x, ptr %y, ptr %z, <2 x float> %a, <2 x float> %b, <2 x float> %b2) {
+; NO_FOLDING-LABEL: vfwmacc_v2f32_multiple_users_addend_user:
+; NO_FOLDING:       # %bb.0:
+; NO_FOLDING-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v11, v8
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v8, v9
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v9, v10
+; NO_FOLDING-NEXT:    vsetvli zero, zero, e64, m1, ta, ma
+; NO_FOLDING-NEXT:    vfmul.vv v10, v11, v8
+; NO_FOLDING-NEXT:    vfmadd.vv v11, v9, v8
+; NO_FOLDING-NEXT:    vfsub.vv v8, v8, v9
+; NO_FOLDING-NEXT:    vse64.v v10, (a0)
+; NO_FOLDING-NEXT:    vse64.v v11, (a1)
+; NO_FOLDING-NEXT:    vse64.v v8, (a2)
+; NO_FOLDING-NEXT:    ret
+;
+; FOLDING-LABEL: vfwmacc_v2f32_multiple_users_addend_user:
+; FOLDING:       # %bb.0:
+; FOLDING-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; FOLDING-NEXT:    vfwcvt.f.f.v v11, v8
+; FOLDING-NEXT:    vfwcvt.f.f.v v8, v9
+; FOLDING-NEXT:    vfwcvt.f.f.v v9, v10
+; FOLDING-NEXT:    vsetvli zero, zero, e64, m1, ta, ma
+; FOLDING-NEXT:    vfmul.vv v10, v11, v8
+; FOLDING-NEXT:    vfmadd.vv v11, v9, v8
+; FOLDING-NEXT:    vfsub.vv v8, v8, v9
+; FOLDING-NEXT:    vse64.v v10, (a0)
+; FOLDING-NEXT:    vse64.v v11, (a1)
+; FOLDING-NEXT:    vse64.v v8, (a2)
+; FOLDING-NEXT:    ret
+  %c = fpext <2 x float> %a to <2 x double>
+  %d = fpext <2 x float> %b to <2 x double>
+  %d2 = fpext <2 x float> %b2 to <2 x double>
+  %e = fmul <2 x double> %c, %d
+  %f = call <2 x double> @llvm.fma(<2 x double> %c, <2 x double> %d2, <2 x double> %d)
+  %g = fsub <2 x double> %d, %d2
+  store <2 x double> %e, ptr %x
+  store <2 x double> %f, ptr %y
+  store <2 x double> %g, ptr %z
+  ret void
+}
+
+; Negative test. We can't fold because the FMA addend is a user.
+define void @vfwmacc_v2f32_addend_user(ptr %x, <2 x float> %a, <2 x float> %b) {
+; NO_FOLDING-LABEL: vfwmacc_v2f32_addend_user:
+; NO_FOLDING:       # %bb.0:
+; NO_FOLDING-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v10, v8
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v8, v9
+; NO_FOLDING-NEXT:    vsetvli zero, zero, e64, m1, ta, ma
+; NO_FOLDING-NEXT:    vfmadd.vv v8, v10, v8
+; NO_FOLDING-NEXT:    vse64.v v8, (a0)
+; NO_FOLDING-NEXT:    ret
+;
+; FOLDING-LABEL: vfwmacc_v2f32_addend_user:
+; FOLDING:       # %bb.0:
+; FOLDING-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; FOLDING-NEXT:    vfwcvt.f.f.v v10, v8
+; FOLDING-NEXT:    vfwcvt.f.f.v v8, v9
+; FOLDING-NEXT:    vsetvli zero, zero, e64, m1, ta, ma
+; FOLDING-NEXT:    vfmadd.vv v8, v10, v8
+; FOLDING-NEXT:    vse64.v v8, (a0)
+; FOLDING-NEXT:    ret
+  %c = fpext <2 x float> %a to <2 x double>
+  %d = fpext <2 x float> %b to <2 x double>
+  %f = call <2 x double> @llvm.fma(<2 x double> %c, <2 x double> %d, <2 x double> %d)
+  store <2 x double> %f, ptr %x
+  ret void
+}
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmacc.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmacc.ll
index 1803b52aca674..5140d89b78307 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmacc.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmacc.ll
@@ -2031,11 +2031,8 @@ define <8 x double> @vfwnmsac_fv_v8f64_v8f16(<8 x double> %va, <8 x half> %vb, h
 define <2 x float> @vfwmacc_vf2_v2f32(<2 x float> %va, <2 x half> %vb, half %c) {
 ; CHECK-LABEL: vfwmacc_vf2_v2f32:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    fcvt.s.h fa5, fa0
 ; CHECK-NEXT:    vsetivli zero, 2, e16, mf4, ta, ma
-; CHECK-NEXT:    vfwcvt.f.f.v v10, v9
-; CHECK-NEXT:    vsetvli zero, zero, e32, mf2, ta, ma
-; CHECK-NEXT:    vfmacc.vf v8, fa5, v10
+; CHECK-NEXT:    vfwmacc.vf v8, fa0, v9
 ; CHECK-NEXT:    ret
   %cext = fpext half %c to float
   %head = insertelement <2 x float> poison, float %cext, i32 0
@@ -2048,11 +2045,8 @@ define <2 x float> @vfwmacc_vf2_v2f32(<2 x float> %va, <2 x half> %vb, half %c)
 define <2 x float> @vfwmsac_vf2_v2f32(<2 x float> %va, <2 x half> %vb, half %c) {
 ; CHECK-LABEL: vfwmsac_vf2_v2f32:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    fcvt.s.h fa5, fa0
 ; CHECK-NEXT:    vsetivli zero, 2, e16, mf4, ta, ma
-; CHECK-NEXT:    vfwcvt.f.f.v v10, v9
-; CHECK-NEXT:    vsetvli zero, zero, e32, mf2, ta, ma
-; CHECK-NEXT:    vfmsac.vf v8, fa5, v10
+; CHECK-NEXT:    vfwmsac.vf v8, fa0, v9
 ; CHECK-NEXT:    ret
   %cext = fpext half %c to float
   %head = insertelement <2 x float> poison, float %cext, i32 0
@@ -2066,11 +2060,8 @@ define <2 x float> @vfwmsac_vf2_v2f32(<2 x float> %va, <2 x half> %vb, half %c)
 define <2 x float> @vfwnmacc_vf2_v2f32(<2 x float> %va, <2 x half> %vb, half %c) {
 ; CHECK-LABEL: vfwnmacc_vf2_v2f32:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    fcvt.s.h fa5, fa0
 ; CHECK-NEXT:    vsetivli zero, 2, e16, mf4, ta, ma
-; CHECK-NEXT:    vfwcvt.f.f.v v10, v9
-; CHECK-NEXT:    vsetvli zero, zero, e32, mf2, ta, ma
-; CHECK-NEXT:    vfnmacc.vf v8, fa5, v10
+; CHECK-NEXT:    vfwnmacc.vf v8, fa0, v9
 ; CHECK-NEXT:    ret
   %cext = fpext half %c to float
   %head = insertelement <2 x float> poison, float %cext, i32 0
@@ -2085,11 +2076,8 @@ define <2 x float> @vfwnmacc_vf2_v2f32(<2 x float> %va, <2 x half> %vb, half %c)
 define <2 x float> @vfwnmsac_vf2_v2f32(<2 x float> %va, <2 x half> %vb, half %c) {
 ; CHECK-LABEL: vfwnmsac_vf2_v2f32:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    fcvt.s.h fa5, fa0
 ; CHECK-NEXT:    vsetivli zero, 2, e16, mf4, ta, ma
-; CHECK-NEXT:    vfwcvt.f.f.v v10, v9
-; CHECK-NEXT:    vsetvli zero, zero, e32, mf2, ta, ma
-; CHECK-NEXT:    vfnmsac.vf v8, fa5, v10
+; CHECK-NEXT:    vfwnmsac.vf v8, fa0, v9
 ; CHECK-NEXT:    ret
   %cext = fpext half %c to float
   %head = insertelement <2 x float> poison, float %cext, i32 0

Inline review thread on llvm/lib/Target/RISCV/RISCVISelLowering.cpp, at the new check in AppendUsersIfNeeded:

    if (!NodeExtensionHelper::isSupportedRoot(TheUse, Subtarget))
      return false;
    // We only support the first 2 operands of FMA.
    if (UI.getOperandNo() >= 2)

A collaborator commented:

Should we assert that this is an FMA node?
I am wondering whether an assert would make sense, so that if someone updates isSupportedRoot and forgets to update this part of the code, the mismatch is caught.

What do you think?
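
For concreteness, the suggestion would look roughly like the sketch below, placed in AppendUsersIfNeeded. This is only one possible shape; `isFMA` is a hypothetical helper (per the reply that follows, an earlier draft of the patch had one):

```cpp
// Hypothetical helper over the four VL FMA opcodes the patch handles.
static bool isFMA(const SDNode *N) {
  unsigned Opc = N->getOpcode();
  return Opc == RISCVISD::VFMADD_VL || Opc == RISCVISD::VFMSUB_VL ||
         Opc == RISCVISD::VFNMADD_VL || Opc == RISCVISD::VFNMSUB_VL;
}

// In AppendUsersIfNeeded: a use at operand index >= 2 is expected only
// for FMA roots, so a mismatch would flag a new supported root that was
// added without revisiting this check.
if (UI.getOperandNo() >= 2) {
  assert(isFMA(TheUse) && "unexpected supported root with >2 operands");
  return false;
}
```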

@topperc (Collaborator, Author) replied:

I had that in the code originally, with an isFMA function. I removed it because it was the only use of the function. I figured that anyone adding a new supported root that used more than operands 0 and 1 would also have to change the creation of the two NodeExtensionHelper objects 12 or so lines above this, so maybe they'd notice they needed to change this too.

Upon further review, I just realized that all of the binary ops have a third passthru vector operand that we should have been checking for and excluding all along. It's very often undef, so it will take a more detailed review to figure out how to test that case.
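
A check of the kind described would look roughly like the following sketch, assuming operand 2 of the _VL binary nodes is the passthru; where the check belongs and how to test it is the "more detailed review" mentioned above:

```cpp
// Sketch: bail out of the widening combine when the passthru of a _VL
// binary node is live, since folding could otherwise change the value
// of masked-off or tail lanes.
SDValue Passthru = N->getOperand(2); // (op0, op1, passthru, mask, vl)
if (!Passthru.isUndef())
  return SDValue();
```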

The collaborator replied:

I vaguely remember that we checked for the passthru operand being undef.
Is it something we somehow removed, or do I just misremember?

@topperc replied:

The only undef check I see so far is in fillUpExtensionSupportForSplat, but that's only for RISCVISD::VMV_V_X_VL.

@sun-jacobi (Member) left a comment:

LGTM :)

@topperc merged commit b582b65 into llvm:main on Jul 26, 2024 (9 checks passed).
@topperc deleted the pr/fma-web branch on July 26, 2024 at 16:12.