
Commit b582b65

[RISCV] Add FMA support to combineOp_VLToVWOp_VL. (llvm#100454)
This adds FMA to the widening web support we have for add, sub, mul, and shl. Extra care needs to be taken to not widen the third FMA operand.
1 parent 9086f9d commit b582b65
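
In IR terms, the new combine folds fp-extends feeding the first two FMA operands into a widening FMA while leaving an already-wide addend untouched. A minimal sketch of the pattern that now folds (the function name is illustrative; the values mirror the tests added below):

define <2 x double> @fma_widen_sketch(<2 x float> %a, <2 x float> %b, <2 x double> %w) {
  ; Both multiplicands are f32 values extended to f64.
  %c = fpext <2 x float> %a to <2 x double>
  %d = fpext <2 x float> %b to <2 x double>
  ; Operands 0 and 1 may fold into vfwmacc.vv; operand 2 (%w) must stay wide.
  %f = call <2 x double> @llvm.fma.v2f64(<2 x double> %c, <2 x double> %d, <2 x double> %w)
  ret <2 x double> %f
}

With the patch, the two vfwcvt.f.f.v conversions plus a full-width vfmadd.vv become a single vfwmacc.vv, as the FOLDING check lines in the new test demonstrate.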

3 files changed (+169, -82 lines)

llvm/lib/Target/RISCV/RISCVISelLowering.cpp (56 additions & 66 deletions)
@@ -14309,6 +14309,14 @@ struct NodeExtensionHelper {
       return RISCVISD::VFWSUB_VL;
     case RISCVISD::FMUL_VL:
       return RISCVISD::VFWMUL_VL;
+    case RISCVISD::VFMADD_VL:
+      return RISCVISD::VFWMADD_VL;
+    case RISCVISD::VFMSUB_VL:
+      return RISCVISD::VFWMSUB_VL;
+    case RISCVISD::VFNMADD_VL:
+      return RISCVISD::VFWNMADD_VL;
+    case RISCVISD::VFNMSUB_VL:
+      return RISCVISD::VFWNMSUB_VL;
     default:
       llvm_unreachable("Unexpected opcode");
     }
@@ -14502,6 +14510,11 @@ struct NodeExtensionHelper {
              Subtarget.hasStdExtZvbb();
     case RISCVISD::SHL_VL:
       return Subtarget.hasStdExtZvbb();
+    case RISCVISD::VFMADD_VL:
+    case RISCVISD::VFNMSUB_VL:
+    case RISCVISD::VFNMADD_VL:
+    case RISCVISD::VFMSUB_VL:
+      return true;
     default:
       return false;
     }
@@ -14582,6 +14595,10 @@ struct NodeExtensionHelper {
     case RISCVISD::FADD_VL:
     case RISCVISD::FMUL_VL:
     case RISCVISD::VFWADD_W_VL:
+    case RISCVISD::VFMADD_VL:
+    case RISCVISD::VFNMSUB_VL:
+    case RISCVISD::VFNMADD_VL:
+    case RISCVISD::VFMSUB_VL:
       return true;
     case ISD::SUB:
     case RISCVISD::SUB_VL:
@@ -14797,6 +14814,10 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
     Strategies.push_back(canFoldToVW_W);
     break;
   case RISCVISD::FMUL_VL:
+  case RISCVISD::VFMADD_VL:
+  case RISCVISD::VFMSUB_VL:
+  case RISCVISD::VFNMADD_VL:
+  case RISCVISD::VFNMSUB_VL:
     Strategies.push_back(canFoldToVWWithSameExtension);
     break;
   case ISD::MUL:
@@ -14833,7 +14854,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
 }
 } // End anonymous namespace.
 
-/// Combine a binary operation to its equivalent VW or VW_W form.
+/// Combine a binary or FMA operation to its equivalent VW or VW_W form.
 /// The supported combines are:
 /// add | add_vl | or disjoint -> vwadd(u) | vwadd(u)_w
 /// sub | sub_vl -> vwsub(u) | vwsub(u)_w
@@ -14846,9 +14867,9 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
 /// vwsub_w(u) -> vwsub(u)
 /// vfwadd_w -> vfwadd
 /// vfwsub_w -> vfwsub
-static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,
-                                           TargetLowering::DAGCombinerInfo &DCI,
-                                           const RISCVSubtarget &Subtarget) {
+static SDValue combineOp_VLToVWOp_VL(SDNode *N,
+                                     TargetLowering::DAGCombinerInfo &DCI,
+                                     const RISCVSubtarget &Subtarget) {
   SelectionDAG &DAG = DCI.DAG;
   if (DCI.isBeforeLegalize())
     return SDValue();
@@ -14864,19 +14885,26 @@ static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,
 
   while (!Worklist.empty()) {
     SDNode *Root = Worklist.pop_back_val();
-    if (!NodeExtensionHelper::isSupportedRoot(Root, Subtarget))
-      return SDValue();
 
     NodeExtensionHelper LHS(Root, 0, DAG, Subtarget);
     NodeExtensionHelper RHS(Root, 1, DAG, Subtarget);
-    auto AppendUsersIfNeeded = [&Worklist,
+    auto AppendUsersIfNeeded = [&Worklist, &Subtarget,
                                 &Inserted](const NodeExtensionHelper &Op) {
       if (Op.needToPromoteOtherUsers()) {
-        for (SDNode *TheUse : Op.OrigOperand->uses()) {
+        for (SDNode::use_iterator UI = Op.OrigOperand->use_begin(),
+                                  UE = Op.OrigOperand->use_end();
+             UI != UE; ++UI) {
+          SDNode *TheUse = *UI;
+          if (!NodeExtensionHelper::isSupportedRoot(TheUse, Subtarget))
+            return false;
+          // We only support the first 2 operands of FMA.
+          if (UI.getOperandNo() >= 2)
+            return false;
           if (Inserted.insert(TheUse).second)
            Worklist.push_back(TheUse);
         }
       }
+      return true;
     };
 
     // Control the compile time by limiting the number of node we look at in
@@ -14904,9 +14932,11 @@ static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,
       // we would be leaving the old input (since it is may still be used),
       // and the new one.
       if (Res->LHSExt.has_value())
-        AppendUsersIfNeeded(LHS);
+        if (!AppendUsersIfNeeded(LHS))
+          return SDValue();
       if (Res->RHSExt.has_value())
-        AppendUsersIfNeeded(RHS);
+        if (!AppendUsersIfNeeded(RHS))
+          return SDValue();
       break;
     }
   }
@@ -14993,7 +15023,7 @@ static SDValue performVWADDSUBW_VLCombine(SDNode *N,
   assert(Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWADDU_W_VL ||
          Opc == RISCVISD::VWSUB_W_VL || Opc == RISCVISD::VWSUBU_W_VL);
 
-  if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
+  if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
     return V;
 
   return combineVWADDSUBWSelect(N, DCI.DAG);
@@ -15408,8 +15438,11 @@ static SDValue combineVFMADD_VLWithVFNEG_VL(SDNode *N, SelectionDAG &DAG) {
                      VL);
 }
 
-static SDValue performVFMADD_VLCombine(SDNode *N, SelectionDAG &DAG,
+static SDValue performVFMADD_VLCombine(SDNode *N,
+                                       TargetLowering::DAGCombinerInfo &DCI,
                                        const RISCVSubtarget &Subtarget) {
+  SelectionDAG &DAG = DCI.DAG;
+
   if (SDValue V = combineVFMADD_VLWithVFNEG_VL(N, DAG))
     return V;
 
@@ -15421,50 +15454,7 @@ static SDValue performVFMADD_VLCombine(SDNode *N, SelectionDAG &DAG,
   if (N->isTargetStrictFPOpcode())
     return SDValue();
 
-  // Try to form widening FMA.
-  SDValue Op0 = N->getOperand(0);
-  SDValue Op1 = N->getOperand(1);
-  SDValue Mask = N->getOperand(3);
-  SDValue VL = N->getOperand(4);
-
-  if (Op0.getOpcode() != RISCVISD::FP_EXTEND_VL ||
-      Op1.getOpcode() != RISCVISD::FP_EXTEND_VL)
-    return SDValue();
-
-  // TODO: Refactor to handle more complex cases similar to
-  // combineBinOp_VLToVWBinOp_VL.
-  if ((!Op0.hasOneUse() || !Op1.hasOneUse()) &&
-      (Op0 != Op1 || !Op0->hasNUsesOfValue(2, 0)))
-    return SDValue();
-
-  // Check the mask and VL are the same.
-  if (Op0.getOperand(1) != Mask || Op0.getOperand(2) != VL ||
-      Op1.getOperand(1) != Mask || Op1.getOperand(2) != VL)
-    return SDValue();
-
-  unsigned NewOpc;
-  switch (N->getOpcode()) {
-  default:
-    llvm_unreachable("Unexpected opcode");
-  case RISCVISD::VFMADD_VL:
-    NewOpc = RISCVISD::VFWMADD_VL;
-    break;
-  case RISCVISD::VFNMSUB_VL:
-    NewOpc = RISCVISD::VFWNMSUB_VL;
-    break;
-  case RISCVISD::VFNMADD_VL:
-    NewOpc = RISCVISD::VFWNMADD_VL;
-    break;
-  case RISCVISD::VFMSUB_VL:
-    NewOpc = RISCVISD::VFWMSUB_VL;
-    break;
-  }
-
-  Op0 = Op0.getOperand(0);
-  Op1 = Op1.getOperand(0);
-
-  return DAG.getNode(NewOpc, SDLoc(N), N->getValueType(0), Op0, Op1,
-                     N->getOperand(2), Mask, VL);
+  return combineOp_VLToVWOp_VL(N, DCI, Subtarget);
 }
 
 static SDValue performSRACombine(SDNode *N, SelectionDAG &DAG,
@@ -16661,28 +16651,28 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     break;
   }
   case ISD::ADD: {
-    if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
+    if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
       return V;
     if (SDValue V = combineToVWMACC(N, DAG, Subtarget))
      return V;
    return performADDCombine(N, DCI, Subtarget);
   }
   case ISD::SUB: {
-    if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
+    if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
      return V;
    return performSUBCombine(N, DAG, Subtarget);
   }
   case ISD::AND:
     return performANDCombine(N, DCI, Subtarget);
   case ISD::OR: {
-    if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
+    if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
      return V;
    return performORCombine(N, DCI, Subtarget);
   }
   case ISD::XOR:
     return performXORCombine(N, DAG, Subtarget);
   case ISD::MUL:
-    if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
+    if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
      return V;
    return performMULCombine(N, DAG, DCI, Subtarget);
   case ISD::SDIV:
@@ -17107,7 +17097,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     break;
   }
   case RISCVISD::SHL_VL:
-    if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
+    if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
       return V;
     [[fallthrough]];
   case RISCVISD::SRA_VL:
@@ -17132,7 +17122,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::SRL:
   case ISD::SHL: {
     if (N->getOpcode() == ISD::SHL) {
-      if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
+      if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
         return V;
     }
     SDValue ShAmt = N->getOperand(1);
@@ -17148,7 +17138,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     break;
   }
   case RISCVISD::ADD_VL:
-    if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
+    if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
       return V;
     return combineToVWMACC(N, DAG, Subtarget);
   case RISCVISD::VWADD_W_VL:
@@ -17158,7 +17148,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     return performVWADDSUBW_VLCombine(N, DCI, Subtarget);
   case RISCVISD::SUB_VL:
   case RISCVISD::MUL_VL:
-    return combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget);
+    return combineOp_VLToVWOp_VL(N, DCI, Subtarget);
   case RISCVISD::VFMADD_VL:
   case RISCVISD::VFNMADD_VL:
   case RISCVISD::VFMSUB_VL:
@@ -17167,7 +17157,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
   case RISCVISD::STRICT_VFNMADD_VL:
   case RISCVISD::STRICT_VFMSUB_VL:
   case RISCVISD::STRICT_VFNMSUB_VL:
-    return performVFMADD_VLCombine(N, DAG, Subtarget);
+    return performVFMADD_VLCombine(N, DCI, Subtarget);
   case RISCVISD::FADD_VL:
   case RISCVISD::FSUB_VL:
   case RISCVISD::FMUL_VL:
@@ -17176,7 +17166,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     if (N->getValueType(0).getVectorElementType() == MVT::f32 &&
         !Subtarget.hasVInstructionsF16())
       return SDValue();
-    return combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget);
+    return combineOp_VLToVWOp_VL(N, DCI, Subtarget);
   }
   case ISD::LOAD:
   case ISD::STORE: {

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vfw-web-simplification.ll (109 additions & 0 deletions)
@@ -97,3 +97,112 @@ define void @vfwmul_v2f32_multiple_users(ptr %x, ptr %y, ptr %z, <2 x float> %a,
   store <2 x double> %g, ptr %z
   ret void
 }
+
+define void @vfwmacc_v2f32_multiple_users(ptr %x, ptr %y, ptr %z, <2 x float> %a, <2 x float> %b, <2 x float> %b2, <2 x double> %w) {
+; NO_FOLDING-LABEL: vfwmacc_v2f32_multiple_users:
+; NO_FOLDING:       # %bb.0:
+; NO_FOLDING-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v12, v8
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v8, v9
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v9, v10
+; NO_FOLDING-NEXT:    vsetvli zero, zero, e64, m1, ta, ma
+; NO_FOLDING-NEXT:    vfmul.vv v10, v12, v8
+; NO_FOLDING-NEXT:    vfmadd.vv v12, v9, v11
+; NO_FOLDING-NEXT:    vfsub.vv v8, v8, v9
+; NO_FOLDING-NEXT:    vse64.v v10, (a0)
+; NO_FOLDING-NEXT:    vse64.v v12, (a1)
+; NO_FOLDING-NEXT:    vse64.v v8, (a2)
+; NO_FOLDING-NEXT:    ret
+;
+; FOLDING-LABEL: vfwmacc_v2f32_multiple_users:
+; FOLDING:       # %bb.0:
+; FOLDING-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; FOLDING-NEXT:    vfwmul.vv v12, v8, v9
+; FOLDING-NEXT:    vfwmacc.vv v11, v8, v10
+; FOLDING-NEXT:    vfwsub.vv v8, v9, v10
+; FOLDING-NEXT:    vse64.v v12, (a0)
+; FOLDING-NEXT:    vse64.v v11, (a1)
+; FOLDING-NEXT:    vse64.v v8, (a2)
+; FOLDING-NEXT:    ret
+  %c = fpext <2 x float> %a to <2 x double>
+  %d = fpext <2 x float> %b to <2 x double>
+  %d2 = fpext <2 x float> %b2 to <2 x double>
+  %e = fmul <2 x double> %c, %d
+  %f = call <2 x double> @llvm.fma(<2 x double> %c, <2 x double> %d2, <2 x double> %w)
+  %g = fsub <2 x double> %d, %d2
+  store <2 x double> %e, ptr %x
+  store <2 x double> %f, ptr %y
+  store <2 x double> %g, ptr %z
+  ret void
+}
+
+; Negative test. We can't fold because the FMA addend is a user.
+define void @vfwmacc_v2f32_multiple_users_addend_user(ptr %x, ptr %y, ptr %z, <2 x float> %a, <2 x float> %b, <2 x float> %b2) {
+; NO_FOLDING-LABEL: vfwmacc_v2f32_multiple_users_addend_user:
+; NO_FOLDING:       # %bb.0:
+; NO_FOLDING-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v11, v8
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v8, v9
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v9, v10
+; NO_FOLDING-NEXT:    vsetvli zero, zero, e64, m1, ta, ma
+; NO_FOLDING-NEXT:    vfmul.vv v10, v11, v8
+; NO_FOLDING-NEXT:    vfmadd.vv v11, v9, v8
+; NO_FOLDING-NEXT:    vfsub.vv v8, v8, v9
+; NO_FOLDING-NEXT:    vse64.v v10, (a0)
+; NO_FOLDING-NEXT:    vse64.v v11, (a1)
+; NO_FOLDING-NEXT:    vse64.v v8, (a2)
+; NO_FOLDING-NEXT:    ret
+;
+; FOLDING-LABEL: vfwmacc_v2f32_multiple_users_addend_user:
+; FOLDING:       # %bb.0:
+; FOLDING-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; FOLDING-NEXT:    vfwcvt.f.f.v v11, v8
+; FOLDING-NEXT:    vfwcvt.f.f.v v8, v9
+; FOLDING-NEXT:    vfwcvt.f.f.v v9, v10
+; FOLDING-NEXT:    vsetvli zero, zero, e64, m1, ta, ma
+; FOLDING-NEXT:    vfmul.vv v10, v11, v8
+; FOLDING-NEXT:    vfmadd.vv v11, v9, v8
+; FOLDING-NEXT:    vfsub.vv v8, v8, v9
+; FOLDING-NEXT:    vse64.v v10, (a0)
+; FOLDING-NEXT:    vse64.v v11, (a1)
+; FOLDING-NEXT:    vse64.v v8, (a2)
+; FOLDING-NEXT:    ret
+  %c = fpext <2 x float> %a to <2 x double>
+  %d = fpext <2 x float> %b to <2 x double>
+  %d2 = fpext <2 x float> %b2 to <2 x double>
+  %e = fmul <2 x double> %c, %d
+  %f = call <2 x double> @llvm.fma(<2 x double> %c, <2 x double> %d2, <2 x double> %d)
+  %g = fsub <2 x double> %d, %d2
+  store <2 x double> %e, ptr %x
+  store <2 x double> %f, ptr %y
+  store <2 x double> %g, ptr %z
+  ret void
+}
+
+; Negative test. We can't fold because the FMA addend is a user.
+define void @vfwmacc_v2f32_addend_user(ptr %x, <2 x float> %a, <2 x float> %b) {
+; NO_FOLDING-LABEL: vfwmacc_v2f32_addend_user:
+; NO_FOLDING:       # %bb.0:
+; NO_FOLDING-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v10, v8
+; NO_FOLDING-NEXT:    vfwcvt.f.f.v v8, v9
+; NO_FOLDING-NEXT:    vsetvli zero, zero, e64, m1, ta, ma
+; NO_FOLDING-NEXT:    vfmadd.vv v8, v10, v8
+; NO_FOLDING-NEXT:    vse64.v v8, (a0)
+; NO_FOLDING-NEXT:    ret
+;
+; FOLDING-LABEL: vfwmacc_v2f32_addend_user:
+; FOLDING:       # %bb.0:
+; FOLDING-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; FOLDING-NEXT:    vfwcvt.f.f.v v10, v8
+; FOLDING-NEXT:    vfwcvt.f.f.v v8, v9
+; FOLDING-NEXT:    vsetvli zero, zero, e64, m1, ta, ma
+; FOLDING-NEXT:    vfmadd.vv v8, v10, v8
+; FOLDING-NEXT:    vse64.v v8, (a0)
+; FOLDING-NEXT:    ret
+  %c = fpext <2 x float> %a to <2 x double>
+  %d = fpext <2 x float> %b to <2 x double>
+  %f = call <2 x double> @llvm.fma(<2 x double> %c, <2 x double> %d, <2 x double> %d)
+  store <2 x double> %f, ptr %x
+  ret void
+}
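
The NO_FOLDING and FOLDING prefixes above come from running llc with different extension-web size limits. A sketch of the kind of RUN lines this test family uses (the exact flags and attribute lists are assumptions based on the sibling vw-web-simplification test, not shown in this diff):

; RUN: llc -mtriple=riscv64 -mattr=+v,+f,+d -verify-machineinstrs %s -o - \
; RUN:   --riscv-lower-ext-max-web-size=1 | FileCheck %s --check-prefixes=NO_FOLDING
; RUN: llc -mtriple=riscv64 -mattr=+v,+f,+d -verify-machineinstrs %s -o - \
; RUN:   --riscv-lower-ext-max-web-size=3 | FileCheck %s --check-prefixes=FOLDING

Capping the web size at 1 makes combineOp_VLToVWOp_VL give up before it can promote the extension's other users, which is what produces the unwidened NO_FOLDING output.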
