[RISCV] Add FMA support to combineOp_VLToVWOp_VL. #100454
Conversation
This adds FMA to the widening web support we have for add, sub, mul, and shl. Extra care needs to be taken to not widen the third FMA operand.
@llvm/pr-subscribers-backend-risc-v

Author: Craig Topper (topperc)

Changes

This adds FMA to the widening web support we have for add, sub, mul, and shl. Extra care needs to be taken to not widen the third FMA operand.

Full diff: https://github.com/llvm/llvm-project/pull/100454.diff

3 Files Affected:
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index d40d4997d7614..7a657e481d9b1 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -14328,6 +14328,14 @@ struct NodeExtensionHelper {
return RISCVISD::VFWSUB_VL;
case RISCVISD::FMUL_VL:
return RISCVISD::VFWMUL_VL;
+ case RISCVISD::VFMADD_VL:
+ return RISCVISD::VFWMADD_VL;
+ case RISCVISD::VFMSUB_VL:
+ return RISCVISD::VFWMSUB_VL;
+ case RISCVISD::VFNMADD_VL:
+ return RISCVISD::VFWNMADD_VL;
+ case RISCVISD::VFNMSUB_VL:
+ return RISCVISD::VFWNMSUB_VL;
default:
llvm_unreachable("Unexpected opcode");
}
@@ -14521,6 +14529,11 @@ struct NodeExtensionHelper {
Subtarget.hasStdExtZvbb();
case RISCVISD::SHL_VL:
return Subtarget.hasStdExtZvbb();
+ case RISCVISD::VFMADD_VL:
+ case RISCVISD::VFNMSUB_VL:
+ case RISCVISD::VFNMADD_VL:
+ case RISCVISD::VFMSUB_VL:
+ return true;
default:
return false;
}
@@ -14601,6 +14614,10 @@ struct NodeExtensionHelper {
case RISCVISD::FADD_VL:
case RISCVISD::FMUL_VL:
case RISCVISD::VFWADD_W_VL:
+ case RISCVISD::VFMADD_VL:
+ case RISCVISD::VFNMSUB_VL:
+ case RISCVISD::VFNMADD_VL:
+ case RISCVISD::VFMSUB_VL:
return true;
case ISD::SUB:
case RISCVISD::SUB_VL:
@@ -14816,6 +14833,10 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
Strategies.push_back(canFoldToVW_W);
break;
case RISCVISD::FMUL_VL:
+ case RISCVISD::VFMADD_VL:
+ case RISCVISD::VFMSUB_VL:
+ case RISCVISD::VFNMADD_VL:
+ case RISCVISD::VFNMSUB_VL:
Strategies.push_back(canFoldToVWWithSameExtension);
break;
case ISD::MUL:
@@ -14852,7 +14873,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
}
} // End anonymous namespace.
-/// Combine a binary operation to its equivalent VW or VW_W form.
+/// Combine a binary or FMA operation to its equivalent VW or VW_W form.
/// The supported combines are:
/// add | add_vl | or disjoint -> vwadd(u) | vwadd(u)_w
/// sub | sub_vl -> vwsub(u) | vwsub(u)_w
@@ -14865,9 +14886,9 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
/// vwsub_w(u) -> vwsub(u)
/// vfwadd_w -> vfwadd
/// vfwsub_w -> vfwsub
-static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,
- TargetLowering::DAGCombinerInfo &DCI,
- const RISCVSubtarget &Subtarget) {
+static SDValue combineOp_VLToVWOp_VL(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI,
+ const RISCVSubtarget &Subtarget) {
SelectionDAG &DAG = DCI.DAG;
if (DCI.isBeforeLegalize())
return SDValue();
@@ -14883,19 +14904,26 @@ static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,
while (!Worklist.empty()) {
SDNode *Root = Worklist.pop_back_val();
- if (!NodeExtensionHelper::isSupportedRoot(Root, Subtarget))
- return SDValue();
NodeExtensionHelper LHS(Root, 0, DAG, Subtarget);
NodeExtensionHelper RHS(Root, 1, DAG, Subtarget);
- auto AppendUsersIfNeeded = [&Worklist,
+ auto AppendUsersIfNeeded = [&Worklist, &Subtarget,
&Inserted](const NodeExtensionHelper &Op) {
if (Op.needToPromoteOtherUsers()) {
- for (SDNode *TheUse : Op.OrigOperand->uses()) {
+ for (SDNode::use_iterator UI = Op.OrigOperand->use_begin(),
+ UE = Op.OrigOperand->use_end();
+ UI != UE; ++UI) {
+ SDNode *TheUse = *UI;
+ if (!NodeExtensionHelper::isSupportedRoot(TheUse, Subtarget))
+ return false;
+ // We only support the first 2 operands of FMA.
+ if (UI.getOperandNo() >= 2)
+ return false;
if (Inserted.insert(TheUse).second)
Worklist.push_back(TheUse);
}
}
+ return true;
};
// Control the compile time by limiting the number of node we look at in
@@ -14923,9 +14951,11 @@ static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,
// we would be leaving the old input (since it is may still be used),
// and the new one.
if (Res->LHSExt.has_value())
- AppendUsersIfNeeded(LHS);
+ if (!AppendUsersIfNeeded(LHS))
+ return SDValue();
if (Res->RHSExt.has_value())
- AppendUsersIfNeeded(RHS);
+ if (!AppendUsersIfNeeded(RHS))
+ return SDValue();
break;
}
}
@@ -15012,7 +15042,7 @@ static SDValue performVWADDSUBW_VLCombine(SDNode *N,
assert(Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWADDU_W_VL ||
Opc == RISCVISD::VWSUB_W_VL || Opc == RISCVISD::VWSUBU_W_VL);
- if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
+ if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
return V;
return combineVWADDSUBWSelect(N, DCI.DAG);
@@ -15427,8 +15457,11 @@ static SDValue combineVFMADD_VLWithVFNEG_VL(SDNode *N, SelectionDAG &DAG) {
VL);
}
-static SDValue performVFMADD_VLCombine(SDNode *N, SelectionDAG &DAG,
+static SDValue performVFMADD_VLCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI,
const RISCVSubtarget &Subtarget) {
+ SelectionDAG &DAG = DCI.DAG;
+
if (SDValue V = combineVFMADD_VLWithVFNEG_VL(N, DAG))
return V;
@@ -15440,50 +15473,7 @@ static SDValue performVFMADD_VLCombine(SDNode *N, SelectionDAG &DAG,
if (N->isTargetStrictFPOpcode())
return SDValue();
- // Try to form widening FMA.
- SDValue Op0 = N->getOperand(0);
- SDValue Op1 = N->getOperand(1);
- SDValue Mask = N->getOperand(3);
- SDValue VL = N->getOperand(4);
-
- if (Op0.getOpcode() != RISCVISD::FP_EXTEND_VL ||
- Op1.getOpcode() != RISCVISD::FP_EXTEND_VL)
- return SDValue();
-
- // TODO: Refactor to handle more complex cases similar to
- // combineBinOp_VLToVWBinOp_VL.
- if ((!Op0.hasOneUse() || !Op1.hasOneUse()) &&
- (Op0 != Op1 || !Op0->hasNUsesOfValue(2, 0)))
- return SDValue();
-
- // Check the mask and VL are the same.
- if (Op0.getOperand(1) != Mask || Op0.getOperand(2) != VL ||
- Op1.getOperand(1) != Mask || Op1.getOperand(2) != VL)
- return SDValue();
-
- unsigned NewOpc;
- switch (N->getOpcode()) {
- default:
- llvm_unreachable("Unexpected opcode");
- case RISCVISD::VFMADD_VL:
- NewOpc = RISCVISD::VFWMADD_VL;
- break;
- case RISCVISD::VFNMSUB_VL:
- NewOpc = RISCVISD::VFWNMSUB_VL;
- break;
- case RISCVISD::VFNMADD_VL:
- NewOpc = RISCVISD::VFWNMADD_VL;
- break;
- case RISCVISD::VFMSUB_VL:
- NewOpc = RISCVISD::VFWMSUB_VL;
- break;
- }
-
- Op0 = Op0.getOperand(0);
- Op1 = Op1.getOperand(0);
-
- return DAG.getNode(NewOpc, SDLoc(N), N->getValueType(0), Op0, Op1,
- N->getOperand(2), Mask, VL);
+ return combineOp_VLToVWOp_VL(N, DCI, Subtarget);
}
static SDValue performSRACombine(SDNode *N, SelectionDAG &DAG,
@@ -16680,28 +16670,28 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
break;
}
case ISD::ADD: {
- if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
+ if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
return V;
if (SDValue V = combineToVWMACC(N, DAG, Subtarget))
return V;
return performADDCombine(N, DCI, Subtarget);
}
case ISD::SUB: {
- if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
+ if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
return V;
return performSUBCombine(N, DAG, Subtarget);
}
case ISD::AND:
return performANDCombine(N, DCI, Subtarget);
case ISD::OR: {
- if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
+ if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
return V;
return performORCombine(N, DCI, Subtarget);
}
case ISD::XOR:
return performXORCombine(N, DAG, Subtarget);
case ISD::MUL:
- if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
+ if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
return V;
return performMULCombine(N, DAG, DCI, Subtarget);
case ISD::SDIV:
@@ -17126,7 +17116,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
break;
}
case RISCVISD::SHL_VL:
- if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
+ if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
return V;
[[fallthrough]];
case RISCVISD::SRA_VL:
@@ -17151,7 +17141,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
case ISD::SRL:
case ISD::SHL: {
if (N->getOpcode() == ISD::SHL) {
- if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
+ if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
return V;
}
SDValue ShAmt = N->getOperand(1);
@@ -17167,7 +17157,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
break;
}
case RISCVISD::ADD_VL:
- if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
+ if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
return V;
return combineToVWMACC(N, DAG, Subtarget);
case RISCVISD::VWADD_W_VL:
@@ -17177,7 +17167,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
return performVWADDSUBW_VLCombine(N, DCI, Subtarget);
case RISCVISD::SUB_VL:
case RISCVISD::MUL_VL:
- return combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget);
+ return combineOp_VLToVWOp_VL(N, DCI, Subtarget);
case RISCVISD::VFMADD_VL:
case RISCVISD::VFNMADD_VL:
case RISCVISD::VFMSUB_VL:
@@ -17186,7 +17176,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
case RISCVISD::STRICT_VFNMADD_VL:
case RISCVISD::STRICT_VFMSUB_VL:
case RISCVISD::STRICT_VFNMSUB_VL:
- return performVFMADD_VLCombine(N, DAG, Subtarget);
+ return performVFMADD_VLCombine(N, DCI, Subtarget);
case RISCVISD::FADD_VL:
case RISCVISD::FSUB_VL:
case RISCVISD::FMUL_VL:
@@ -17195,7 +17185,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
if (N->getValueType(0).getVectorElementType() == MVT::f32 &&
!Subtarget.hasVInstructionsF16())
return SDValue();
- return combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget);
+ return combineOp_VLToVWOp_VL(N, DCI, Subtarget);
}
case ISD::LOAD:
case ISD::STORE: {
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfw-web-simplification.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfw-web-simplification.ll
index 3a99f535e9071..cb50ca4a72120 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfw-web-simplification.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfw-web-simplification.ll
@@ -97,3 +97,112 @@ define void @vfwmul_v2f32_multiple_users(ptr %x, ptr %y, ptr %z, <2 x float> %a,
store <2 x double> %g, ptr %z
ret void
}
+
+define void @vfwmacc_v2f32_multiple_users(ptr %x, ptr %y, ptr %z, <2 x float> %a, <2 x float> %b, <2 x float> %b2, <2 x double> %w) {
+; NO_FOLDING-LABEL: vfwmacc_v2f32_multiple_users:
+; NO_FOLDING: # %bb.0:
+; NO_FOLDING-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
+; NO_FOLDING-NEXT: vfwcvt.f.f.v v12, v8
+; NO_FOLDING-NEXT: vfwcvt.f.f.v v8, v9
+; NO_FOLDING-NEXT: vfwcvt.f.f.v v9, v10
+; NO_FOLDING-NEXT: vsetvli zero, zero, e64, m1, ta, ma
+; NO_FOLDING-NEXT: vfmul.vv v10, v12, v8
+; NO_FOLDING-NEXT: vfmadd.vv v12, v9, v11
+; NO_FOLDING-NEXT: vfsub.vv v8, v8, v9
+; NO_FOLDING-NEXT: vse64.v v10, (a0)
+; NO_FOLDING-NEXT: vse64.v v12, (a1)
+; NO_FOLDING-NEXT: vse64.v v8, (a2)
+; NO_FOLDING-NEXT: ret
+;
+; FOLDING-LABEL: vfwmacc_v2f32_multiple_users:
+; FOLDING: # %bb.0:
+; FOLDING-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
+; FOLDING-NEXT: vfwmul.vv v12, v8, v9
+; FOLDING-NEXT: vfwmacc.vv v11, v8, v10
+; FOLDING-NEXT: vfwsub.vv v8, v9, v10
+; FOLDING-NEXT: vse64.v v12, (a0)
+; FOLDING-NEXT: vse64.v v11, (a1)
+; FOLDING-NEXT: vse64.v v8, (a2)
+; FOLDING-NEXT: ret
+ %c = fpext <2 x float> %a to <2 x double>
+ %d = fpext <2 x float> %b to <2 x double>
+ %d2 = fpext <2 x float> %b2 to <2 x double>
+ %e = fmul <2 x double> %c, %d
+ %f = call <2 x double> @llvm.fma(<2 x double> %c, <2 x double> %d2, <2 x double> %w)
+ %g = fsub <2 x double> %d, %d2
+ store <2 x double> %e, ptr %x
+ store <2 x double> %f, ptr %y
+ store <2 x double> %g, ptr %z
+ ret void
+}
+
+; Negative test. We can't fold because the FMA addend is a user.
+define void @vfwmacc_v2f32_multiple_users_addend_user(ptr %x, ptr %y, ptr %z, <2 x float> %a, <2 x float> %b, <2 x float> %b2) {
+; NO_FOLDING-LABEL: vfwmacc_v2f32_multiple_users_addend_user:
+; NO_FOLDING: # %bb.0:
+; NO_FOLDING-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
+; NO_FOLDING-NEXT: vfwcvt.f.f.v v11, v8
+; NO_FOLDING-NEXT: vfwcvt.f.f.v v8, v9
+; NO_FOLDING-NEXT: vfwcvt.f.f.v v9, v10
+; NO_FOLDING-NEXT: vsetvli zero, zero, e64, m1, ta, ma
+; NO_FOLDING-NEXT: vfmul.vv v10, v11, v8
+; NO_FOLDING-NEXT: vfmadd.vv v11, v9, v8
+; NO_FOLDING-NEXT: vfsub.vv v8, v8, v9
+; NO_FOLDING-NEXT: vse64.v v10, (a0)
+; NO_FOLDING-NEXT: vse64.v v11, (a1)
+; NO_FOLDING-NEXT: vse64.v v8, (a2)
+; NO_FOLDING-NEXT: ret
+;
+; FOLDING-LABEL: vfwmacc_v2f32_multiple_users_addend_user:
+; FOLDING: # %bb.0:
+; FOLDING-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
+; FOLDING-NEXT: vfwcvt.f.f.v v11, v8
+; FOLDING-NEXT: vfwcvt.f.f.v v8, v9
+; FOLDING-NEXT: vfwcvt.f.f.v v9, v10
+; FOLDING-NEXT: vsetvli zero, zero, e64, m1, ta, ma
+; FOLDING-NEXT: vfmul.vv v10, v11, v8
+; FOLDING-NEXT: vfmadd.vv v11, v9, v8
+; FOLDING-NEXT: vfsub.vv v8, v8, v9
+; FOLDING-NEXT: vse64.v v10, (a0)
+; FOLDING-NEXT: vse64.v v11, (a1)
+; FOLDING-NEXT: vse64.v v8, (a2)
+; FOLDING-NEXT: ret
+ %c = fpext <2 x float> %a to <2 x double>
+ %d = fpext <2 x float> %b to <2 x double>
+ %d2 = fpext <2 x float> %b2 to <2 x double>
+ %e = fmul <2 x double> %c, %d
+ %f = call <2 x double> @llvm.fma(<2 x double> %c, <2 x double> %d2, <2 x double> %d)
+ %g = fsub <2 x double> %d, %d2
+ store <2 x double> %e, ptr %x
+ store <2 x double> %f, ptr %y
+ store <2 x double> %g, ptr %z
+ ret void
+}
+
+; Negative test. We can't fold because the FMA addend is a user.
+define void @vfwmacc_v2f32_addend_user(ptr %x, <2 x float> %a, <2 x float> %b) {
+; NO_FOLDING-LABEL: vfwmacc_v2f32_addend_user:
+; NO_FOLDING: # %bb.0:
+; NO_FOLDING-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
+; NO_FOLDING-NEXT: vfwcvt.f.f.v v10, v8
+; NO_FOLDING-NEXT: vfwcvt.f.f.v v8, v9
+; NO_FOLDING-NEXT: vsetvli zero, zero, e64, m1, ta, ma
+; NO_FOLDING-NEXT: vfmadd.vv v8, v10, v8
+; NO_FOLDING-NEXT: vse64.v v8, (a0)
+; NO_FOLDING-NEXT: ret
+;
+; FOLDING-LABEL: vfwmacc_v2f32_addend_user:
+; FOLDING: # %bb.0:
+; FOLDING-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
+; FOLDING-NEXT: vfwcvt.f.f.v v10, v8
+; FOLDING-NEXT: vfwcvt.f.f.v v8, v9
+; FOLDING-NEXT: vsetvli zero, zero, e64, m1, ta, ma
+; FOLDING-NEXT: vfmadd.vv v8, v10, v8
+; FOLDING-NEXT: vse64.v v8, (a0)
+; FOLDING-NEXT: ret
+ %c = fpext <2 x float> %a to <2 x double>
+ %d = fpext <2 x float> %b to <2 x double>
+ %f = call <2 x double> @llvm.fma(<2 x double> %c, <2 x double> %d, <2 x double> %d)
+ store <2 x double> %f, ptr %x
+ ret void
+}
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmacc.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmacc.ll
index 1803b52aca674..5140d89b78307 100644
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmacc.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfwmacc.ll
@@ -2031,11 +2031,8 @@ define <8 x double> @vfwnmsac_fv_v8f64_v8f16(<8 x double> %va, <8 x half> %vb, h
define <2 x float> @vfwmacc_vf2_v2f32(<2 x float> %va, <2 x half> %vb, half %c) {
; CHECK-LABEL: vfwmacc_vf2_v2f32:
; CHECK: # %bb.0:
-; CHECK-NEXT: fcvt.s.h fa5, fa0
; CHECK-NEXT: vsetivli zero, 2, e16, mf4, ta, ma
-; CHECK-NEXT: vfwcvt.f.f.v v10, v9
-; CHECK-NEXT: vsetvli zero, zero, e32, mf2, ta, ma
-; CHECK-NEXT: vfmacc.vf v8, fa5, v10
+; CHECK-NEXT: vfwmacc.vf v8, fa0, v9
; CHECK-NEXT: ret
%cext = fpext half %c to float
%head = insertelement <2 x float> poison, float %cext, i32 0
@@ -2048,11 +2045,8 @@ define <2 x float> @vfwmacc_vf2_v2f32(<2 x float> %va, <2 x half> %vb, half %c)
define <2 x float> @vfwmsac_vf2_v2f32(<2 x float> %va, <2 x half> %vb, half %c) {
; CHECK-LABEL: vfwmsac_vf2_v2f32:
; CHECK: # %bb.0:
-; CHECK-NEXT: fcvt.s.h fa5, fa0
; CHECK-NEXT: vsetivli zero, 2, e16, mf4, ta, ma
-; CHECK-NEXT: vfwcvt.f.f.v v10, v9
-; CHECK-NEXT: vsetvli zero, zero, e32, mf2, ta, ma
-; CHECK-NEXT: vfmsac.vf v8, fa5, v10
+; CHECK-NEXT: vfwmsac.vf v8, fa0, v9
; CHECK-NEXT: ret
%cext = fpext half %c to float
%head = insertelement <2 x float> poison, float %cext, i32 0
@@ -2066,11 +2060,8 @@ define <2 x float> @vfwmsac_vf2_v2f32(<2 x float> %va, <2 x half> %vb, half %c)
define <2 x float> @vfwnmacc_vf2_v2f32(<2 x float> %va, <2 x half> %vb, half %c) {
; CHECK-LABEL: vfwnmacc_vf2_v2f32:
; CHECK: # %bb.0:
-; CHECK-NEXT: fcvt.s.h fa5, fa0
; CHECK-NEXT: vsetivli zero, 2, e16, mf4, ta, ma
-; CHECK-NEXT: vfwcvt.f.f.v v10, v9
-; CHECK-NEXT: vsetvli zero, zero, e32, mf2, ta, ma
-; CHECK-NEXT: vfnmacc.vf v8, fa5, v10
+; CHECK-NEXT: vfwnmacc.vf v8, fa0, v9
; CHECK-NEXT: ret
%cext = fpext half %c to float
%head = insertelement <2 x float> poison, float %cext, i32 0
@@ -2085,11 +2076,8 @@ define <2 x float> @vfwnmacc_vf2_v2f32(<2 x float> %va, <2 x half> %vb, half %c)
define <2 x float> @vfwnmsac_vf2_v2f32(<2 x float> %va, <2 x half> %vb, half %c) {
; CHECK-LABEL: vfwnmsac_vf2_v2f32:
; CHECK: # %bb.0:
-; CHECK-NEXT: fcvt.s.h fa5, fa0
; CHECK-NEXT: vsetivli zero, 2, e16, mf4, ta, ma
-; CHECK-NEXT: vfwcvt.f.f.v v10, v9
-; CHECK-NEXT: vsetvli zero, zero, e32, mf2, ta, ma
-; CHECK-NEXT: vfnmsac.vf v8, fa5, v10
+; CHECK-NEXT: vfwnmsac.vf v8, fa0, v9
; CHECK-NEXT: ret
%cext = fpext half %c to float
%head = insertelement <2 x float> poison, float %cext, i32 0
        if (!NodeExtensionHelper::isSupportedRoot(TheUse, Subtarget))
          return false;
        // We only support the first 2 operands of FMA.
        if (UI.getOperandNo() >= 2)
Should we assert that this is an FMA node? I am wondering whether it would make sense to assert, in case someone updates isSupportedRoot and forgets to update this part of the code. What do you think?
I had that in the code originally with an isFMA function. I removed it because it was the only use of the function. I figured that if anyone added a new supported root that didn't use only operands 0 and 1, they'd also have to change the creation of the two NodeExtensionHelper objects 12 or so lines above this, so maybe they'd notice they needed to change this too.
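
For reference, a minimal sketch of what such an isFMA predicate could look like; the helper below is hypothetical (it is not part of this patch) and simply switches over the four VL FMA opcodes the combine now handles:

// Hypothetical helper, not part of the patch: identify the VL FMA opcodes
// that combineOp_VLToVWOp_VL treats as roots whose third (addend) operand
// must not be widened.
static bool isFMA(unsigned Opcode) {
  switch (Opcode) {
  case RISCVISD::VFMADD_VL:
  case RISCVISD::VFMSUB_VL:
  case RISCVISD::VFNMADD_VL:
  case RISCVISD::VFNMSUB_VL:
    return true;
  default:
    return false;
  }
}

Note that an assert keyed only on UI.getOperandNo() >= 2 would not be a drop-in replacement, since (as discussed below) the binary *_VL nodes also carry a passthru vector operand at index 2.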
Upon further review, I just realized that all of the binary ops have a third passthru vector operand that we should have been checking for and excluding all along. It's very often undef, so it will take a more detailed review to figure out how to test that case.
I vaguely remember that we checked for the passthru operand being undef. Is it something we somehow removed, or do I just misremember?
The only undef check I see so far is in fillUpExtensionSupportForSplat, but that's only for RISCVISD::VMV_V_X_VL.
LGTM :)