@@ -189,6 +189,12 @@ class VectorLegalizer {
189
189
190
190
void PromoteSTRICT (SDNode *Node, SmallVectorImpl<SDValue> &Results);
191
191
192
+ // / Calculate the reduction using a type of higher precision and round the
193
+ // / result to match the original type. Setting NonArithmetic signifies the
194
+ // / rounding of the result does not affect its value.
195
+ void PromoteFloatVECREDUCE (SDNode *Node, SmallVectorImpl<SDValue> &Results,
196
+ bool NonArithmetic);
197
+
192
198
public:
193
199
VectorLegalizer (SelectionDAG& dag) :
194
200
DAG (dag), TLI(dag.getTargetLoweringInfo()) {}
@@ -500,20 +506,14 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
500
506
case ISD::VECREDUCE_UMAX:
501
507
case ISD::VECREDUCE_UMIN:
502
508
case ISD::VECREDUCE_FADD:
503
- case ISD::VECREDUCE_FMUL:
504
- case ISD::VECTOR_FIND_LAST_ACTIVE:
505
- Action = TLI.getOperationAction (Node->getOpcode (),
506
- Node->getOperand (0 ).getValueType ());
507
- break ;
508
509
case ISD::VECREDUCE_FMAX:
509
- case ISD::VECREDUCE_FMIN:
510
510
case ISD::VECREDUCE_FMAXIMUM:
511
+ case ISD::VECREDUCE_FMIN:
511
512
case ISD::VECREDUCE_FMINIMUM:
513
+ case ISD::VECREDUCE_FMUL:
514
+ case ISD::VECTOR_FIND_LAST_ACTIVE:
512
515
Action = TLI.getOperationAction (Node->getOpcode (),
513
516
Node->getOperand (0 ).getValueType ());
514
- // Defer non-vector results to LegalizeDAG.
515
- if (Action == TargetLowering::Promote)
516
- Action = TargetLowering::Legal;
517
517
break ;
518
518
case ISD::VECREDUCE_SEQ_FADD:
519
519
case ISD::VECREDUCE_SEQ_FMUL:
@@ -688,6 +688,24 @@ void VectorLegalizer::PromoteSTRICT(SDNode *Node,
688
688
Results.push_back (Round.getValue (1 ));
689
689
}
690
690
691
+ void VectorLegalizer::PromoteFloatVECREDUCE (SDNode *Node,
692
+ SmallVectorImpl<SDValue> &Results,
693
+ bool NonArithmetic) {
694
+ MVT OpVT = Node->getOperand (0 ).getSimpleValueType ();
695
+ assert (OpVT.isFloatingPoint () && " Expected floating point reduction!" );
696
+ MVT NewOpVT = TLI.getTypeToPromoteTo (Node->getOpcode (), OpVT);
697
+
698
+ SDLoc DL (Node);
699
+ SDValue NewOp = DAG.getNode (ISD::FP_EXTEND, DL, NewOpVT, Node->getOperand (0 ));
700
+ SDValue Rdx =
701
+ DAG.getNode (Node->getOpcode (), DL, NewOpVT.getVectorElementType (), NewOp,
702
+ Node->getFlags ());
703
+ SDValue Res =
704
+ DAG.getNode (ISD::FP_ROUND, DL, Node->getValueType (0 ), Rdx,
705
+ DAG.getIntPtrConstant (NonArithmetic, DL, /* isTarget=*/ true ));
706
+ Results.push_back (Res);
707
+ }
708
+
691
709
void VectorLegalizer::Promote (SDNode *Node, SmallVectorImpl<SDValue> &Results) {
692
710
// For a few operations there is a specific concept for promotion based on
693
711
// the operand's type.
@@ -719,6 +737,15 @@ void VectorLegalizer::Promote(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
719
737
case ISD::STRICT_FMA:
720
738
PromoteSTRICT (Node, Results);
721
739
return ;
740
+ case ISD::VECREDUCE_FADD:
741
+ PromoteFloatVECREDUCE (Node, Results, /* NonArithmetic=*/ false );
742
+ return ;
743
+ case ISD::VECREDUCE_FMAX:
744
+ case ISD::VECREDUCE_FMAXIMUM:
745
+ case ISD::VECREDUCE_FMIN:
746
+ case ISD::VECREDUCE_FMINIMUM:
747
+ PromoteFloatVECREDUCE (Node, Results, /* NonArithmetic=*/ true );
748
+ return ;
722
749
case ISD::FP_ROUND:
723
750
case ISD::FP_EXTEND:
724
751
// These operations are used to do promotion so they can't be promoted
0 commit comments