@@ -14309,6 +14309,14 @@ struct NodeExtensionHelper {
14309
14309
return RISCVISD::VFWSUB_VL;
14310
14310
case RISCVISD::FMUL_VL:
14311
14311
return RISCVISD::VFWMUL_VL;
14312
+ case RISCVISD::VFMADD_VL:
14313
+ return RISCVISD::VFWMADD_VL;
14314
+ case RISCVISD::VFMSUB_VL:
14315
+ return RISCVISD::VFWMSUB_VL;
14316
+ case RISCVISD::VFNMADD_VL:
14317
+ return RISCVISD::VFWNMADD_VL;
14318
+ case RISCVISD::VFNMSUB_VL:
14319
+ return RISCVISD::VFWNMSUB_VL;
14312
14320
default:
14313
14321
llvm_unreachable("Unexpected opcode");
14314
14322
}
@@ -14502,6 +14510,11 @@ struct NodeExtensionHelper {
14502
14510
Subtarget.hasStdExtZvbb();
14503
14511
case RISCVISD::SHL_VL:
14504
14512
return Subtarget.hasStdExtZvbb();
14513
+ case RISCVISD::VFMADD_VL:
14514
+ case RISCVISD::VFNMSUB_VL:
14515
+ case RISCVISD::VFNMADD_VL:
14516
+ case RISCVISD::VFMSUB_VL:
14517
+ return true;
14505
14518
default:
14506
14519
return false;
14507
14520
}
@@ -14582,6 +14595,10 @@ struct NodeExtensionHelper {
14582
14595
case RISCVISD::FADD_VL:
14583
14596
case RISCVISD::FMUL_VL:
14584
14597
case RISCVISD::VFWADD_W_VL:
14598
+ case RISCVISD::VFMADD_VL:
14599
+ case RISCVISD::VFNMSUB_VL:
14600
+ case RISCVISD::VFNMADD_VL:
14601
+ case RISCVISD::VFMSUB_VL:
14585
14602
return true;
14586
14603
case ISD::SUB:
14587
14604
case RISCVISD::SUB_VL:
@@ -14797,6 +14814,10 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
14797
14814
Strategies.push_back(canFoldToVW_W);
14798
14815
break;
14799
14816
case RISCVISD::FMUL_VL:
14817
+ case RISCVISD::VFMADD_VL:
14818
+ case RISCVISD::VFMSUB_VL:
14819
+ case RISCVISD::VFNMADD_VL:
14820
+ case RISCVISD::VFNMSUB_VL:
14800
14821
Strategies.push_back(canFoldToVWWithSameExtension);
14801
14822
break;
14802
14823
case ISD::MUL:
@@ -14833,7 +14854,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
14833
14854
}
14834
14855
} // End anonymous namespace.
14835
14856
14836
- /// Combine a binary operation to its equivalent VW or VW_W form.
14857
+ /// Combine a binary or FMA operation to its equivalent VW or VW_W form.
14837
14858
/// The supported combines are:
14838
14859
/// add | add_vl | or disjoint -> vwadd(u) | vwadd(u)_w
14839
14860
/// sub | sub_vl -> vwsub(u) | vwsub(u)_w
@@ -14846,9 +14867,9 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
14846
14867
/// vwsub_w(u) -> vwsub(u)
14847
14868
/// vfwadd_w -> vfwadd
14848
14869
/// vfwsub_w -> vfwsub
14849
- static SDValue combineBinOp_VLToVWBinOp_VL (SDNode *N,
14850
- TargetLowering::DAGCombinerInfo &DCI,
14851
- const RISCVSubtarget &Subtarget) {
14870
+ static SDValue combineOp_VLToVWOp_VL (SDNode *N,
14871
+ TargetLowering::DAGCombinerInfo &DCI,
14872
+ const RISCVSubtarget &Subtarget) {
14852
14873
SelectionDAG &DAG = DCI.DAG;
14853
14874
if (DCI.isBeforeLegalize())
14854
14875
return SDValue();
@@ -14864,19 +14885,26 @@ static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,
14864
14885
14865
14886
while (!Worklist.empty()) {
14866
14887
SDNode *Root = Worklist.pop_back_val();
14867
- if (!NodeExtensionHelper::isSupportedRoot(Root, Subtarget))
14868
- return SDValue();
14869
14888
14870
14889
NodeExtensionHelper LHS(Root, 0, DAG, Subtarget);
14871
14890
NodeExtensionHelper RHS(Root, 1, DAG, Subtarget);
14872
- auto AppendUsersIfNeeded = [&Worklist,
14891
+ auto AppendUsersIfNeeded = [&Worklist, &Subtarget,
14873
14892
&Inserted](const NodeExtensionHelper &Op) {
14874
14893
if (Op.needToPromoteOtherUsers()) {
14875
- for (SDNode *TheUse : Op.OrigOperand->uses()) {
14894
+ for (SDNode::use_iterator UI = Op.OrigOperand->use_begin(),
14895
+ UE = Op.OrigOperand->use_end();
14896
+ UI != UE; ++UI) {
14897
+ SDNode *TheUse = *UI;
14898
+ if (!NodeExtensionHelper::isSupportedRoot(TheUse, Subtarget))
14899
+ return false;
14900
+ // We only support the first 2 operands of FMA.
14901
+ if (UI.getOperandNo() >= 2)
14902
+ return false;
14876
14903
if (Inserted.insert(TheUse).second)
14877
14904
Worklist.push_back(TheUse);
14878
14905
}
14879
14906
}
14907
+ return true;
14880
14908
};
14881
14909
14882
14910
// Control the compile time by limiting the number of node we look at in
@@ -14904,9 +14932,11 @@ static SDValue combineBinOp_VLToVWBinOp_VL(SDNode *N,
14904
14932
// we would be leaving the old input (since it is may still be used),
14905
14933
// and the new one.
14906
14934
if (Res->LHSExt.has_value())
14907
- AppendUsersIfNeeded(LHS);
14935
+ if (!AppendUsersIfNeeded(LHS))
14936
+ return SDValue();
14908
14937
if (Res->RHSExt.has_value())
14909
- AppendUsersIfNeeded(RHS);
14938
+ if (!AppendUsersIfNeeded(RHS))
14939
+ return SDValue();
14910
14940
break;
14911
14941
}
14912
14942
}
@@ -14993,7 +15023,7 @@ static SDValue performVWADDSUBW_VLCombine(SDNode *N,
14993
15023
assert(Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWADDU_W_VL ||
14994
15024
Opc == RISCVISD::VWSUB_W_VL || Opc == RISCVISD::VWSUBU_W_VL);
14995
15025
14996
- if (SDValue V = combineBinOp_VLToVWBinOp_VL (N, DCI, Subtarget))
15026
+ if (SDValue V = combineOp_VLToVWOp_VL (N, DCI, Subtarget))
14997
15027
return V;
14998
15028
14999
15029
return combineVWADDSUBWSelect(N, DCI.DAG);
@@ -15408,8 +15438,11 @@ static SDValue combineVFMADD_VLWithVFNEG_VL(SDNode *N, SelectionDAG &DAG) {
15408
15438
VL);
15409
15439
}
15410
15440
15411
- static SDValue performVFMADD_VLCombine(SDNode *N, SelectionDAG &DAG,
15441
+ static SDValue performVFMADD_VLCombine(SDNode *N,
15442
+ TargetLowering::DAGCombinerInfo &DCI,
15412
15443
const RISCVSubtarget &Subtarget) {
15444
+ SelectionDAG &DAG = DCI.DAG;
15445
+
15413
15446
if (SDValue V = combineVFMADD_VLWithVFNEG_VL(N, DAG))
15414
15447
return V;
15415
15448
@@ -15421,50 +15454,7 @@ static SDValue performVFMADD_VLCombine(SDNode *N, SelectionDAG &DAG,
15421
15454
if (N->isTargetStrictFPOpcode())
15422
15455
return SDValue();
15423
15456
15424
- // Try to form widening FMA.
15425
- SDValue Op0 = N->getOperand(0);
15426
- SDValue Op1 = N->getOperand(1);
15427
- SDValue Mask = N->getOperand(3);
15428
- SDValue VL = N->getOperand(4);
15429
-
15430
- if (Op0.getOpcode() != RISCVISD::FP_EXTEND_VL ||
15431
- Op1.getOpcode() != RISCVISD::FP_EXTEND_VL)
15432
- return SDValue();
15433
-
15434
- // TODO: Refactor to handle more complex cases similar to
15435
- // combineBinOp_VLToVWBinOp_VL.
15436
- if ((!Op0.hasOneUse() || !Op1.hasOneUse()) &&
15437
- (Op0 != Op1 || !Op0->hasNUsesOfValue(2, 0)))
15438
- return SDValue();
15439
-
15440
- // Check the mask and VL are the same.
15441
- if (Op0.getOperand(1) != Mask || Op0.getOperand(2) != VL ||
15442
- Op1.getOperand(1) != Mask || Op1.getOperand(2) != VL)
15443
- return SDValue();
15444
-
15445
- unsigned NewOpc;
15446
- switch (N->getOpcode()) {
15447
- default:
15448
- llvm_unreachable("Unexpected opcode");
15449
- case RISCVISD::VFMADD_VL:
15450
- NewOpc = RISCVISD::VFWMADD_VL;
15451
- break;
15452
- case RISCVISD::VFNMSUB_VL:
15453
- NewOpc = RISCVISD::VFWNMSUB_VL;
15454
- break;
15455
- case RISCVISD::VFNMADD_VL:
15456
- NewOpc = RISCVISD::VFWNMADD_VL;
15457
- break;
15458
- case RISCVISD::VFMSUB_VL:
15459
- NewOpc = RISCVISD::VFWMSUB_VL;
15460
- break;
15461
- }
15462
-
15463
- Op0 = Op0.getOperand(0);
15464
- Op1 = Op1.getOperand(0);
15465
-
15466
- return DAG.getNode(NewOpc, SDLoc(N), N->getValueType(0), Op0, Op1,
15467
- N->getOperand(2), Mask, VL);
15457
+ return combineOp_VLToVWOp_VL(N, DCI, Subtarget);
15468
15458
}
15469
15459
15470
15460
static SDValue performSRACombine(SDNode *N, SelectionDAG &DAG,
@@ -16661,28 +16651,28 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
16661
16651
break;
16662
16652
}
16663
16653
case ISD::ADD: {
16664
- if (SDValue V = combineBinOp_VLToVWBinOp_VL (N, DCI, Subtarget))
16654
+ if (SDValue V = combineOp_VLToVWOp_VL (N, DCI, Subtarget))
16665
16655
return V;
16666
16656
if (SDValue V = combineToVWMACC(N, DAG, Subtarget))
16667
16657
return V;
16668
16658
return performADDCombine(N, DCI, Subtarget);
16669
16659
}
16670
16660
case ISD::SUB: {
16671
- if (SDValue V = combineBinOp_VLToVWBinOp_VL (N, DCI, Subtarget))
16661
+ if (SDValue V = combineOp_VLToVWOp_VL (N, DCI, Subtarget))
16672
16662
return V;
16673
16663
return performSUBCombine(N, DAG, Subtarget);
16674
16664
}
16675
16665
case ISD::AND:
16676
16666
return performANDCombine(N, DCI, Subtarget);
16677
16667
case ISD::OR: {
16678
- if (SDValue V = combineBinOp_VLToVWBinOp_VL (N, DCI, Subtarget))
16668
+ if (SDValue V = combineOp_VLToVWOp_VL (N, DCI, Subtarget))
16679
16669
return V;
16680
16670
return performORCombine(N, DCI, Subtarget);
16681
16671
}
16682
16672
case ISD::XOR:
16683
16673
return performXORCombine(N, DAG, Subtarget);
16684
16674
case ISD::MUL:
16685
- if (SDValue V = combineBinOp_VLToVWBinOp_VL (N, DCI, Subtarget))
16675
+ if (SDValue V = combineOp_VLToVWOp_VL (N, DCI, Subtarget))
16686
16676
return V;
16687
16677
return performMULCombine(N, DAG, DCI, Subtarget);
16688
16678
case ISD::SDIV:
@@ -17107,7 +17097,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
17107
17097
break;
17108
17098
}
17109
17099
case RISCVISD::SHL_VL:
17110
- if (SDValue V = combineBinOp_VLToVWBinOp_VL (N, DCI, Subtarget))
17100
+ if (SDValue V = combineOp_VLToVWOp_VL (N, DCI, Subtarget))
17111
17101
return V;
17112
17102
[[fallthrough]];
17113
17103
case RISCVISD::SRA_VL:
@@ -17132,7 +17122,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
17132
17122
case ISD::SRL:
17133
17123
case ISD::SHL: {
17134
17124
if (N->getOpcode() == ISD::SHL) {
17135
- if (SDValue V = combineBinOp_VLToVWBinOp_VL (N, DCI, Subtarget))
17125
+ if (SDValue V = combineOp_VLToVWOp_VL (N, DCI, Subtarget))
17136
17126
return V;
17137
17127
}
17138
17128
SDValue ShAmt = N->getOperand(1);
@@ -17148,7 +17138,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
17148
17138
break;
17149
17139
}
17150
17140
case RISCVISD::ADD_VL:
17151
- if (SDValue V = combineBinOp_VLToVWBinOp_VL (N, DCI, Subtarget))
17141
+ if (SDValue V = combineOp_VLToVWOp_VL (N, DCI, Subtarget))
17152
17142
return V;
17153
17143
return combineToVWMACC(N, DAG, Subtarget);
17154
17144
case RISCVISD::VWADD_W_VL:
@@ -17158,7 +17148,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
17158
17148
return performVWADDSUBW_VLCombine(N, DCI, Subtarget);
17159
17149
case RISCVISD::SUB_VL:
17160
17150
case RISCVISD::MUL_VL:
17161
- return combineBinOp_VLToVWBinOp_VL (N, DCI, Subtarget);
17151
+ return combineOp_VLToVWOp_VL (N, DCI, Subtarget);
17162
17152
case RISCVISD::VFMADD_VL:
17163
17153
case RISCVISD::VFNMADD_VL:
17164
17154
case RISCVISD::VFMSUB_VL:
@@ -17167,7 +17157,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
17167
17157
case RISCVISD::STRICT_VFNMADD_VL:
17168
17158
case RISCVISD::STRICT_VFMSUB_VL:
17169
17159
case RISCVISD::STRICT_VFNMSUB_VL:
17170
- return performVFMADD_VLCombine(N, DAG , Subtarget);
17160
+ return performVFMADD_VLCombine(N, DCI , Subtarget);
17171
17161
case RISCVISD::FADD_VL:
17172
17162
case RISCVISD::FSUB_VL:
17173
17163
case RISCVISD::FMUL_VL:
@@ -17176,7 +17166,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
17176
17166
if (N->getValueType(0).getVectorElementType() == MVT::f32 &&
17177
17167
!Subtarget.hasVInstructionsF16())
17178
17168
return SDValue();
17179
- return combineBinOp_VLToVWBinOp_VL (N, DCI, Subtarget);
17169
+ return combineOp_VLToVWOp_VL (N, DCI, Subtarget);
17180
17170
}
17181
17171
case ISD::LOAD:
17182
17172
case ISD::STORE: {
0 commit comments