[RISCV] Add FMA support to combineOp_VLToVWOp_VL. #100454

Merged · 1 commit · Jul 26, 2024

122 changes: 56 additions & 66 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -14328,6 +14328,14 @@ struct NodeExtensionHelper {
return RISCVISD::VFWSUB_VL;
case RISCVISD::FMUL_VL:
return RISCVISD::VFWMUL_VL;
case RISCVISD::VFMADD_VL:
return RISCVISD::VFWMADD_VL;
case RISCVISD::VFMSUB_VL:
return RISCVISD::VFWMSUB_VL;
case RISCVISD::VFNMADD_VL:
return RISCVISD::VFWNMADD_VL;
case RISCVISD::VFNMSUB_VL:
return RISCVISD::VFWNMSUB_VL;
default:
llvm_unreachable("Unexpected opcode");
}
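For reference, the opcode mapping these new cases establish can be sketched as a standalone table (a hypothetical mock enum for illustration, not the real RISCVISD enumerators):

```cpp
#include <cassert>
#include <stdexcept>

// Hypothetical mock of the RISCVISD opcodes involved; the real enumerators
// live in RISCVISelLowering and have different underlying values.
enum Opcode {
  VFMADD_VL, VFMSUB_VL, VFNMADD_VL, VFNMSUB_VL,
  VFWMADD_VL, VFWMSUB_VL, VFWNMADD_VL, VFWNMSUB_VL,
};

// Mirrors the mapping the patch adds: each single-width VL FMA node has a
// widening counterpart that folds the fp_extend of its multiplicands.
Opcode getWideningOpcode(Opcode Op) {
  switch (Op) {
  case VFMADD_VL:  return VFWMADD_VL;
  case VFMSUB_VL:  return VFWMSUB_VL;
  case VFNMADD_VL: return VFWNMADD_VL;
  case VFNMSUB_VL: return VFWNMSUB_VL;
  default: throw std::logic_error("not a single-width VL FMA opcode");
  }
}
```

The negated variants map to their own widening forms, so the sign handling of the original FMA is preserved after widening.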
@@ -14521,6 +14529,11 @@ struct NodeExtensionHelper {
Subtarget.hasStdExtZvbb();
case RISCVISD::SHL_VL:
return Subtarget.hasStdExtZvbb();
case RISCVISD::VFMADD_VL:
case RISCVISD::VFNMSUB_VL:
case RISCVISD::VFNMADD_VL:
case RISCVISD::VFMSUB_VL:
return true;
default:
return false;
}
@@ -14601,6 +14614,10 @@ struct NodeExtensionHelper {
case RISCVISD::FADD_VL:
case RISCVISD::FMUL_VL:
case RISCVISD::VFWADD_W_VL:
case RISCVISD::VFMADD_VL:
case RISCVISD::VFNMSUB_VL:
case RISCVISD::VFNMADD_VL:
case RISCVISD::VFMSUB_VL:
return true;
case ISD::SUB:
case RISCVISD::SUB_VL:
@@ -14816,6 +14833,10 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
Strategies.push_back(canFoldToVW_W);
break;
case RISCVISD::FMUL_VL:
case RISCVISD::VFMADD_VL:
case RISCVISD::VFMSUB_VL:
case RISCVISD::VFNMADD_VL:
case RISCVISD::VFNMSUB_VL:
Strategies.push_back(canFoldToVWWithSameExtension);
break;
case ISD::MUL:
@@ -14852,7 +14873,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
}
} // End anonymous namespace.

/// Combine a binary operation to its equivalent VW or VW_W form.
/// Combine a binary or FMA operation to its equivalent VW or VW_W form.
/// The supported combines are:
/// add | add_vl | or disjoint -> vwadd(u) | vwadd(u)_w
/// sub | sub_vl -> vwsub(u) | vwsub(u)_w
Expand All @@ -14865,9 +14886,9 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
/// vwsub_w(u) -> vwsub(u)
/// vfwadd_w -> vfwadd
/// vfwsub_w -> vfwsub
static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
const RISCVSubtarget &Subtarget) {
static SDValue combineOp_VLToVWOp_VL(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
const RISCVSubtarget &Subtarget) {
SelectionDAG &DAG = DCI.DAG;
if (DCI.isBeforeLegalize())
return SDValue();
@@ -14883,19 +14904,26 @@ static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,

while (!Worklist.empty()) {
SDNode *Root = Worklist.pop_back_val();
if (!NodeExtensionHelper::isSupportedRoot(Root, Subtarget))
return SDValue();

NodeExtensionHelper LHS(Root, 0, DAG, Subtarget);
NodeExtensionHelper RHS(Root, 1, DAG, Subtarget);
auto AppendUsersIfNeeded = [&Worklist,
auto AppendUsersIfNeeded = [&Worklist, &Subtarget,
&Inserted](const NodeExtensionHelper &Op) {
if (Op.needToPromoteOtherUsers()) {
for (SDNode *TheUse : Op.OrigOperand->uses()) {
for (SDNode::use_iterator UI = Op.OrigOperand->use_begin(),
UE = Op.OrigOperand->use_end();
UI != UE; ++UI) {
SDNode *TheUse = *UI;
if (!NodeExtensionHelper::isSupportedRoot(TheUse, Subtarget))
return false;
// We only support the first 2 operands of FMA.
if (UI.getOperandNo() >= 2)
Collaborator:
Should we assert that this is an FMA node? I am wondering whether it would make sense to assert, in case someone updates isSupportedRoot and forgets to update this part of the code. What do you think?

Collaborator (Author):
I had that in the code originally with an isFMA function. I removed it because it was the only use of the function. I figured that if anyone added a new supported root that used more than operands 0 and 1, they would also have to change the creation of the two NodeExtensionHelper objects about 12 lines above this, so they would likely notice that this check needs updating too.

Upon further review, I just realized that all of the binary ops have a third passthru vector operand that we should have been checking for and excluding all along. It is very often undef, so it will take a more detailed review to figure out how to test that case.

Collaborator:
I vaguely remember that we checked for the passthru operand being undef. Is that something we removed at some point, or do I misremember?

Collaborator (Author):
The only undef check I see so far is in fillUpExtensionSupportForSplat, but that is only for RISCVISD::VMV_V_X_VL.
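A rough standalone sketch of the passthru check this discussion is asking for (hypothetical mock types for illustration; the real fix would inspect the node's merge operand, e.g. via SDValue::isUndef):

```cpp
#include <cassert>

// Hypothetical mock: a VL binary/FMA node carries a passthru (merge) operand
// in addition to its data operands, mask, and VL.
struct MockOperand {
  bool Undef;
  bool isUndef() const { return Undef; }
};
struct MockOpVL {
  MockOperand Passthru;
};

// Rewriting a node with a live passthru into its widening form would drop the
// merged lanes, so only an undef passthru is safe to ignore.
bool passthruAllowsWidening(const MockOpVL &N) {
  return N.Passthru.isUndef();
}
```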

return false;
if (Inserted.insert(TheUse).second)
Worklist.push_back(TheUse);
}
}
return true;
};
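The lambda above rejects any user that consumes the extend in an operand slot other than the first two, which for an FMA means the addend. That rule can be sketched standalone (a hypothetical Use record standing in for LLVM's SDNode::use_iterator and its getOperandNo):

```cpp
#include <cassert>
#include <vector>

// Hypothetical miniature model of the check in AppendUsersIfNeeded: each use
// records which operand slot of the user it feeds.
struct Use {
  int OperandNo;
};

// Only operand slots 0 and 1 (the multiplicands) can absorb a widening
// extend; a use feeding the addend (slot 2) or later slots blocks the fold.
bool allUsesArePromotable(const std::vector<Use> &Uses) {
  for (const Use &U : Uses)
    if (U.OperandNo >= 2)
      return false;
  return true;
}
```

This is why the new negative tests below, where the extended value also feeds the FMA addend, must not be folded.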

// Control the compile time by limiting the number of nodes we look at in
Expand Down Expand Up @@ -14923,9 +14951,11 @@ static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,
// we would be leaving the old input (since it may still be used),
// and the new one.
if (Res->LHSExt.has_value())
AppendUsersIfNeeded(LHS);
if (!AppendUsersIfNeeded(LHS))
return SDValue();
if (Res->RHSExt.has_value())
AppendUsersIfNeeded(RHS);
if (!AppendUsersIfNeeded(RHS))
return SDValue();
break;
}
}
@@ -15012,7 +15042,7 @@ static SDValue performVWADDSUBW_VLCombine(SDNode *N,
assert(Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWADDU_W_VL ||
Opc == RISCVISD::VWSUB_W_VL || Opc == RISCVISD::VWSUBU_W_VL);

if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
return V;

return combineVWADDSUBWSelect(N, DCI.DAG);
@@ -15427,8 +15457,11 @@ static SDValue combineVFMADD_VLWithVFNEG_VL(SDNode *N, SelectionDAG &DAG) {
VL);
}

static SDValue performVFMADD_VLCombine(SDNode *N, SelectionDAG &DAG,
static SDValue performVFMADD_VLCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
const RISCVSubtarget &Subtarget) {
SelectionDAG &DAG = DCI.DAG;

if (SDValue V = combineVFMADD_VLWithVFNEG_VL(N, DAG))
return V;

@@ -15440,50 +15473,7 @@ static SDValue performVFMADD_VLCombine(SDNode *N, SelectionDAG &DAG,
if (N->isTargetStrictFPOpcode())
return SDValue();

// Try to form widening FMA.
SDValue Op0 = N->getOperand(0);
SDValue Op1 = N->getOperand(1);
SDValue Mask = N->getOperand(3);
SDValue VL = N->getOperand(4);

if (Op0.getOpcode() != RISCVISD::FP_EXTEND_VL ||
Op1.getOpcode() != RISCVISD::FP_EXTEND_VL)
return SDValue();

// TODO: Refactor to handle more complex cases similar to
// combineBinOp_VLToVWBinOp_VL.
if ((!Op0.hasOneUse() || !Op1.hasOneUse()) &&
(Op0 != Op1 || !Op0->hasNUsesOfValue(2, 0)))
return SDValue();

// Check the mask and VL are the same.
if (Op0.getOperand(1) != Mask || Op0.getOperand(2) != VL ||
Op1.getOperand(1) != Mask || Op1.getOperand(2) != VL)
return SDValue();

unsigned NewOpc;
switch (N->getOpcode()) {
default:
llvm_unreachable("Unexpected opcode");
case RISCVISD::VFMADD_VL:
NewOpc = RISCVISD::VFWMADD_VL;
break;
case RISCVISD::VFNMSUB_VL:
NewOpc = RISCVISD::VFWNMSUB_VL;
break;
case RISCVISD::VFNMADD_VL:
NewOpc = RISCVISD::VFWNMADD_VL;
break;
case RISCVISD::VFMSUB_VL:
NewOpc = RISCVISD::VFWMSUB_VL;
break;
}

Op0 = Op0.getOperand(0);
Op1 = Op1.getOperand(0);

return DAG.getNode(NewOpc, SDLoc(N), N->getValueType(0), Op0, Op1,
N->getOperand(2), Mask, VL);
return combineOp_VLToVWOp_VL(N, DCI, Subtarget);
}

static SDValue performSRACombine(SDNode *N, SelectionDAG &DAG,
@@ -16680,28 +16670,28 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
break;
}
case ISD::ADD: {
if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
return V;
if (SDValue V = combineToVWMACC(N, DAG, Subtarget))
return V;
return performADDCombine(N, DCI, Subtarget);
}
case ISD::SUB: {
if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
return V;
return performSUBCombine(N, DAG, Subtarget);
}
case ISD::AND:
return performANDCombine(N, DCI, Subtarget);
case ISD::OR: {
if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
return V;
return performORCombine(N, DCI, Subtarget);
}
case ISD::XOR:
return performXORCombine(N, DAG, Subtarget);
case ISD::MUL:
if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
return V;
return performMULCombine(N, DAG, DCI, Subtarget);
case ISD::SDIV:
@@ -17126,7 +17116,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
break;
}
case RISCVISD::SHL_VL:
if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
return V;
[[fallthrough]];
case RISCVISD::SRA_VL:
Expand All @@ -17151,7 +17141,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
case ISD::SRL:
case ISD::SHL: {
if (N->getOpcode() == ISD::SHL) {
if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
return V;
}
SDValue ShAmt = N->getOperand(1);
Expand All @@ -17167,7 +17157,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
break;
}
case RISCVISD::ADD_VL:
if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
return V;
return combineToVWMACC(N, DAG, Subtarget);
case RISCVISD::VWADD_W_VL:
Expand All @@ -17177,7 +17167,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
return performVWADDSUBW_VLCombine(N, DCI, Subtarget);
case RISCVISD::SUB_VL:
case RISCVISD::MUL_VL:
return combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget);
return combineOp_VLToVWOp_VL(N, DCI, Subtarget);
case RISCVISD::VFMADD_VL:
case RISCVISD::VFNMADD_VL:
case RISCVISD::VFMSUB_VL:
Expand All @@ -17186,7 +17176,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
case RISCVISD::STRICT_VFNMADD_VL:
case RISCVISD::STRICT_VFMSUB_VL:
case RISCVISD::STRICT_VFNMSUB_VL:
return performVFMADD_VLCombine(N, DAG, Subtarget);
return performVFMADD_VLCombine(N, DCI, Subtarget);
case RISCVISD::FADD_VL:
case RISCVISD::FSUB_VL:
case RISCVISD::FMUL_VL:
Expand All @@ -17195,7 +17185,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
if (N->getValueType(0).getVectorElementType() == MVT::f32 &&
!Subtarget.hasVInstructionsF16())
return SDValue();
return combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget);
return combineOp_VLToVWOp_VL(N, DCI, Subtarget);
}
case ISD::LOAD:
case ISD::STORE: {
109 changes: 109 additions & 0 deletions llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfw-web-simplification.ll
@@ -97,3 +97,112 @@ define void @vfwmul_v2f32_multiple_users(ptr %x, ptr %y, ptr %z, <2 x float> %a,
store <2 x double> %g, ptr %z
ret void
}

define void @vfwmacc_v2f32_multiple_users(ptr %x, ptr %y, ptr %z, <2 x float> %a, <2 x float> %b, <2 x float> %b2, <2 x double> %w) {
; NO_FOLDING-LABEL: vfwmacc_v2f32_multiple_users:
; NO_FOLDING: # %bb.0:
; NO_FOLDING-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
; NO_FOLDING-NEXT: vfwcvt.f.f.v v12, v8
; NO_FOLDING-NEXT: vfwcvt.f.f.v v8, v9
; NO_FOLDING-NEXT: vfwcvt.f.f.v v9, v10
; NO_FOLDING-NEXT: vsetvli zero, zero, e64, m1, ta, ma
; NO_FOLDING-NEXT: vfmul.vv v10, v12, v8
; NO_FOLDING-NEXT: vfmadd.vv v12, v9, v11
; NO_FOLDING-NEXT: vfsub.vv v8, v8, v9
; NO_FOLDING-NEXT: vse64.v v10, (a0)
; NO_FOLDING-NEXT: vse64.v v12, (a1)
; NO_FOLDING-NEXT: vse64.v v8, (a2)
; NO_FOLDING-NEXT: ret
;
; FOLDING-LABEL: vfwmacc_v2f32_multiple_users:
; FOLDING: # %bb.0:
; FOLDING-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
; FOLDING-NEXT: vfwmul.vv v12, v8, v9
; FOLDING-NEXT: vfwmacc.vv v11, v8, v10
; FOLDING-NEXT: vfwsub.vv v8, v9, v10
; FOLDING-NEXT: vse64.v v12, (a0)
; FOLDING-NEXT: vse64.v v11, (a1)
; FOLDING-NEXT: vse64.v v8, (a2)
; FOLDING-NEXT: ret
%c = fpext <2 x float> %a to <2 x double>
%d = fpext <2 x float> %b to <2 x double>
%d2 = fpext <2 x float> %b2 to <2 x double>
%e = fmul <2 x double> %c, %d
%f = call <2 x double> @llvm.fma(<2 x double> %c, <2 x double> %d2, <2 x double> %w)
%g = fsub <2 x double> %d, %d2
store <2 x double> %e, ptr %x
store <2 x double> %f, ptr %y
store <2 x double> %g, ptr %z
ret void
}

; Negative test. We can't fold because the FMA addend is a user.
define void @vfwmacc_v2f32_multiple_users_addend_user(ptr %x, ptr %y, ptr %z, <2 x float> %a, <2 x float> %b, <2 x float> %b2) {
; NO_FOLDING-LABEL: vfwmacc_v2f32_multiple_users_addend_user:
; NO_FOLDING: # %bb.0:
; NO_FOLDING-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
; NO_FOLDING-NEXT: vfwcvt.f.f.v v11, v8
; NO_FOLDING-NEXT: vfwcvt.f.f.v v8, v9
; NO_FOLDING-NEXT: vfwcvt.f.f.v v9, v10
; NO_FOLDING-NEXT: vsetvli zero, zero, e64, m1, ta, ma
; NO_FOLDING-NEXT: vfmul.vv v10, v11, v8
; NO_FOLDING-NEXT: vfmadd.vv v11, v9, v8
; NO_FOLDING-NEXT: vfsub.vv v8, v8, v9
; NO_FOLDING-NEXT: vse64.v v10, (a0)
; NO_FOLDING-NEXT: vse64.v v11, (a1)
; NO_FOLDING-NEXT: vse64.v v8, (a2)
; NO_FOLDING-NEXT: ret
;
; FOLDING-LABEL: vfwmacc_v2f32_multiple_users_addend_user:
; FOLDING: # %bb.0:
; FOLDING-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
; FOLDING-NEXT: vfwcvt.f.f.v v11, v8
; FOLDING-NEXT: vfwcvt.f.f.v v8, v9
; FOLDING-NEXT: vfwcvt.f.f.v v9, v10
; FOLDING-NEXT: vsetvli zero, zero, e64, m1, ta, ma
; FOLDING-NEXT: vfmul.vv v10, v11, v8
; FOLDING-NEXT: vfmadd.vv v11, v9, v8
; FOLDING-NEXT: vfsub.vv v8, v8, v9
; FOLDING-NEXT: vse64.v v10, (a0)
; FOLDING-NEXT: vse64.v v11, (a1)
; FOLDING-NEXT: vse64.v v8, (a2)
; FOLDING-NEXT: ret
%c = fpext <2 x float> %a to <2 x double>
%d = fpext <2 x float> %b to <2 x double>
%d2 = fpext <2 x float> %b2 to <2 x double>
%e = fmul <2 x double> %c, %d
%f = call <2 x double> @llvm.fma(<2 x double> %c, <2 x double> %d2, <2 x double> %d)
%g = fsub <2 x double> %d, %d2
store <2 x double> %e, ptr %x
store <2 x double> %f, ptr %y
store <2 x double> %g, ptr %z
ret void
}

; Negative test. We can't fold because the FMA addend is a user.
define void @vfwmacc_v2f32_addend_user(ptr %x, <2 x float> %a, <2 x float> %b) {
; NO_FOLDING-LABEL: vfwmacc_v2f32_addend_user:
; NO_FOLDING: # %bb.0:
; NO_FOLDING-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
; NO_FOLDING-NEXT: vfwcvt.f.f.v v10, v8
; NO_FOLDING-NEXT: vfwcvt.f.f.v v8, v9
; NO_FOLDING-NEXT: vsetvli zero, zero, e64, m1, ta, ma
; NO_FOLDING-NEXT: vfmadd.vv v8, v10, v8
; NO_FOLDING-NEXT: vse64.v v8, (a0)
; NO_FOLDING-NEXT: ret
;
; FOLDING-LABEL: vfwmacc_v2f32_addend_user:
; FOLDING: # %bb.0:
; FOLDING-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
; FOLDING-NEXT: vfwcvt.f.f.v v10, v8
; FOLDING-NEXT: vfwcvt.f.f.v v8, v9
; FOLDING-NEXT: vsetvli zero, zero, e64, m1, ta, ma
; FOLDING-NEXT: vfmadd.vv v8, v10, v8
; FOLDING-NEXT: vse64.v v8, (a0)
; FOLDING-NEXT: ret
%c = fpext <2 x float> %a to <2 x double>
%d = fpext <2 x float> %b to <2 x double>
%f = call <2 x double> @llvm.fma(<2 x double> %c, <2 x double> %d, <2 x double> %d)
store <2 x double> %f, ptr %x
ret void
}