Skip to content

Commit 705636a

Browse files
authored
[SelectionDAG][RISCV] Move VP_REDUCE* legalization to LegalizeDAG.cpp. (#90522)
LegalizeVectorType is responsible for legalizing nodes that perform an operation on each element may need to scalarize. This is not true for nodes like VP_REDUCE.*, BUILD_VECTOR, SHUFFLE_VECTOR, EXTRACT_SUBVECTOR, etc. This patch drops any nodes with a scalar result from LegalizeVectorOps and handles them in LegalizeDAG instead. This required moving the reduction promotion to LegalizeDAG. I have removed the support integer promotion as it was incorrect for integer min/max reductions. Since it was untested, it was best to assert on it until it was really needed. There are a couple regressions that can be fixed with a small DAG combine which I will do as a follow up.
1 parent 18268ac commit 705636a

File tree

4 files changed

+85
-83
lines changed

4 files changed

+85
-83
lines changed

llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,13 @@ class SelectionDAGLegalize {
180180
SmallVectorImpl<SDValue> &Results);
181181
SDValue PromoteLegalFP_TO_INT_SAT(SDNode *Node, const SDLoc &dl);
182182

183+
/// Implements vector reduce operation promotion.
184+
///
185+
/// All vector operands are promoted to a vector type with larger element
186+
/// type, and the start value is promoted to a larger scalar type. Then the
187+
/// result is truncated back to the original scalar type.
188+
void PromoteReduction(SDNode *Node, SmallVectorImpl<SDValue> &Results);
189+
183190
SDValue ExpandPARITY(SDValue Op, const SDLoc &dl);
184191

185192
SDValue ExpandExtractFromVectorThroughStack(SDValue Op);
@@ -2979,6 +2986,47 @@ SDValue SelectionDAGLegalize::ExpandPARITY(SDValue Op, const SDLoc &dl) {
29792986
return DAG.getNode(ISD::AND, dl, VT, Result, DAG.getConstant(1, dl, VT));
29802987
}
29812988

2989+
void SelectionDAGLegalize::PromoteReduction(SDNode *Node,
2990+
SmallVectorImpl<SDValue> &Results) {
2991+
MVT VecVT = Node->getOperand(1).getSimpleValueType();
2992+
MVT NewVecVT = TLI.getTypeToPromoteTo(Node->getOpcode(), VecVT);
2993+
MVT ScalarVT = Node->getSimpleValueType(0);
2994+
MVT NewScalarVT = NewVecVT.getVectorElementType();
2995+
2996+
SDLoc DL(Node);
2997+
SmallVector<SDValue, 4> Operands(Node->getNumOperands());
2998+
2999+
// promote the initial value.
3000+
// FIXME: Support integer.
3001+
assert(Node->getOperand(0).getValueType().isFloatingPoint() &&
3002+
"Only FP promotion is supported");
3003+
Operands[0] =
3004+
DAG.getNode(ISD::FP_EXTEND, DL, NewScalarVT, Node->getOperand(0));
3005+
3006+
for (unsigned j = 1; j != Node->getNumOperands(); ++j)
3007+
if (Node->getOperand(j).getValueType().isVector() &&
3008+
!(ISD::isVPOpcode(Node->getOpcode()) &&
3009+
ISD::getVPMaskIdx(Node->getOpcode()) == j)) { // Skip mask operand.
3010+
// promote the vector operand.
3011+
// FIXME: Support integer.
3012+
assert(Node->getOperand(j).getValueType().isFloatingPoint() &&
3013+
"Only FP promotion is supported");
3014+
Operands[j] =
3015+
DAG.getNode(ISD::FP_EXTEND, DL, NewVecVT, Node->getOperand(j));
3016+
} else {
3017+
Operands[j] = Node->getOperand(j); // Skip VL operand.
3018+
}
3019+
3020+
SDValue Res = DAG.getNode(Node->getOpcode(), DL, NewScalarVT, Operands,
3021+
Node->getFlags());
3022+
3023+
assert(ScalarVT.isFloatingPoint() && "Only FP promotion is supported");
3024+
Res = DAG.getNode(ISD::FP_ROUND, DL, ScalarVT, Res,
3025+
DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
3026+
3027+
Results.push_back(Res);
3028+
}
3029+
29823030
bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
29833031
LLVM_DEBUG(dbgs() << "Trying to expand node\n");
29843032
SmallVector<SDValue, 8> Results;
@@ -4955,7 +5003,12 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
49555003
if (Node->getOpcode() == ISD::STRICT_UINT_TO_FP ||
49565004
Node->getOpcode() == ISD::STRICT_SINT_TO_FP ||
49575005
Node->getOpcode() == ISD::STRICT_FSETCC ||
4958-
Node->getOpcode() == ISD::STRICT_FSETCCS)
5006+
Node->getOpcode() == ISD::STRICT_FSETCCS ||
5007+
Node->getOpcode() == ISD::VP_REDUCE_FADD ||
5008+
Node->getOpcode() == ISD::VP_REDUCE_FMUL ||
5009+
Node->getOpcode() == ISD::VP_REDUCE_FMAX ||
5010+
Node->getOpcode() == ISD::VP_REDUCE_FMIN ||
5011+
Node->getOpcode() == ISD::VP_REDUCE_SEQ_FADD)
49595012
OVT = Node->getOperand(1).getSimpleValueType();
49605013
if (Node->getOpcode() == ISD::BR_CC ||
49615014
Node->getOpcode() == ISD::SELECT_CC)
@@ -5613,6 +5666,13 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
56135666
DAG.getIntPtrConstant(0, dl, /*isTarget=*/true)));
56145667
break;
56155668
}
5669+
case ISD::VP_REDUCE_FADD:
5670+
case ISD::VP_REDUCE_FMUL:
5671+
case ISD::VP_REDUCE_FMAX:
5672+
case ISD::VP_REDUCE_FMIN:
5673+
case ISD::VP_REDUCE_SEQ_FADD:
5674+
PromoteReduction(Node, Results);
5675+
break;
56165676
}
56175677

56185678
// Replace the original node with the legalized result.

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp

Lines changed: 5 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -176,13 +176,6 @@ class VectorLegalizer {
176176
/// truncated back to the original type.
177177
void PromoteFP_TO_INT(SDNode *Node, SmallVectorImpl<SDValue> &Results);
178178

179-
/// Implements vector reduce operation promotion.
180-
///
181-
/// All vector operands are promoted to a vector type with larger element
182-
/// type, and the start value is promoted to a larger scalar type. Then the
183-
/// result is truncated back to the original scalar type.
184-
void PromoteReduction(SDNode *Node, SmallVectorImpl<SDValue> &Results);
185-
186179
/// Implements vector setcc operation promotion.
187180
///
188181
/// All vector operands are promoted to a vector type with larger element
@@ -510,6 +503,11 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
510503
if (Action != TargetLowering::Legal) \
511504
break; \
512505
} \
506+
/* Defer non-vector results to LegalizeDAG. */ \
507+
if (!Node->getValueType(0).isVector()) { \
508+
Action = TargetLowering::Legal; \
509+
break; \
510+
} \
513511
Action = TLI.getOperationAction(Node->getOpcode(), LegalizeVT); \
514512
} break;
515513
#include "llvm/IR/VPIntrinsics.def"
@@ -580,50 +578,6 @@ bool VectorLegalizer::LowerOperationWrapper(SDNode *Node,
580578
return true;
581579
}
582580

583-
void VectorLegalizer::PromoteReduction(SDNode *Node,
584-
SmallVectorImpl<SDValue> &Results) {
585-
MVT VecVT = Node->getOperand(1).getSimpleValueType();
586-
MVT NewVecVT = TLI.getTypeToPromoteTo(Node->getOpcode(), VecVT);
587-
MVT ScalarVT = Node->getSimpleValueType(0);
588-
MVT NewScalarVT = NewVecVT.getVectorElementType();
589-
590-
SDLoc DL(Node);
591-
SmallVector<SDValue, 4> Operands(Node->getNumOperands());
592-
593-
// promote the initial value.
594-
if (Node->getOperand(0).getValueType().isFloatingPoint())
595-
Operands[0] =
596-
DAG.getNode(ISD::FP_EXTEND, DL, NewScalarVT, Node->getOperand(0));
597-
else
598-
Operands[0] =
599-
DAG.getNode(ISD::ANY_EXTEND, DL, NewScalarVT, Node->getOperand(0));
600-
601-
for (unsigned j = 1; j != Node->getNumOperands(); ++j)
602-
if (Node->getOperand(j).getValueType().isVector() &&
603-
!(ISD::isVPOpcode(Node->getOpcode()) &&
604-
ISD::getVPMaskIdx(Node->getOpcode()) == j)) // Skip mask operand.
605-
// promote the vector operand.
606-
if (Node->getOperand(j).getValueType().isFloatingPoint())
607-
Operands[j] =
608-
DAG.getNode(ISD::FP_EXTEND, DL, NewVecVT, Node->getOperand(j));
609-
else
610-
Operands[j] =
611-
DAG.getNode(ISD::ANY_EXTEND, DL, NewVecVT, Node->getOperand(j));
612-
else
613-
Operands[j] = Node->getOperand(j); // Skip VL operand.
614-
615-
SDValue Res = DAG.getNode(Node->getOpcode(), DL, NewScalarVT, Operands,
616-
Node->getFlags());
617-
618-
if (ScalarVT.isFloatingPoint())
619-
Res = DAG.getNode(ISD::FP_ROUND, DL, ScalarVT, Res,
620-
DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
621-
else
622-
Res = DAG.getNode(ISD::TRUNCATE, DL, ScalarVT, Res);
623-
624-
Results.push_back(Res);
625-
}
626-
627581
void VectorLegalizer::PromoteSETCC(SDNode *Node,
628582
SmallVectorImpl<SDValue> &Results) {
629583
MVT VecVT = Node->getOperand(0).getSimpleValueType();
@@ -708,23 +662,6 @@ void VectorLegalizer::Promote(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
708662
// Promote the operation by extending the operand.
709663
PromoteFP_TO_INT(Node, Results);
710664
return;
711-
case ISD::VP_REDUCE_ADD:
712-
case ISD::VP_REDUCE_MUL:
713-
case ISD::VP_REDUCE_AND:
714-
case ISD::VP_REDUCE_OR:
715-
case ISD::VP_REDUCE_XOR:
716-
case ISD::VP_REDUCE_SMAX:
717-
case ISD::VP_REDUCE_SMIN:
718-
case ISD::VP_REDUCE_UMAX:
719-
case ISD::VP_REDUCE_UMIN:
720-
case ISD::VP_REDUCE_FADD:
721-
case ISD::VP_REDUCE_FMUL:
722-
case ISD::VP_REDUCE_FMAX:
723-
case ISD::VP_REDUCE_FMIN:
724-
case ISD::VP_REDUCE_SEQ_FADD:
725-
// Promote the operation by extending the operand.
726-
PromoteReduction(Node, Results);
727-
return;
728665
case ISD::VP_SETCC:
729666
case ISD::SETCC:
730667
// Promote the operation by extending the operand.

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-int-vp.ll

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -802,25 +802,27 @@ define signext i32 @vpreduce_xor_v64i32(i32 signext %s, <64 x i32> %v, <64 x i1>
802802
; CHECK-LABEL: vpreduce_xor_v64i32:
803803
; CHECK: # %bb.0:
804804
; CHECK-NEXT: vsetivli zero, 4, e8, mf2, ta, ma
805-
; CHECK-NEXT: li a3, 32
806805
; CHECK-NEXT: vslidedown.vi v24, v0, 4
807-
; CHECK-NEXT: mv a2, a1
808-
; CHECK-NEXT: bltu a1, a3, .LBB49_2
806+
; CHECK-NEXT: addi a2, a1, -32
807+
; CHECK-NEXT: sltu a3, a1, a2
808+
; CHECK-NEXT: addi a3, a3, -1
809+
; CHECK-NEXT: li a4, 32
810+
; CHECK-NEXT: and a2, a3, a2
811+
; CHECK-NEXT: bltu a1, a4, .LBB49_2
809812
; CHECK-NEXT: # %bb.1:
810-
; CHECK-NEXT: li a2, 32
813+
; CHECK-NEXT: li a1, 32
811814
; CHECK-NEXT: .LBB49_2:
812815
; CHECK-NEXT: vsetvli zero, zero, e32, m2, ta, ma
813816
; CHECK-NEXT: vmv.s.x v25, a0
814-
; CHECK-NEXT: vsetvli zero, a2, e32, m8, ta, ma
817+
; CHECK-NEXT: vsetvli zero, a1, e32, m8, ta, ma
815818
; CHECK-NEXT: vredxor.vs v25, v8, v25, v0.t
816-
; CHECK-NEXT: addi a0, a1, -32
817-
; CHECK-NEXT: sltu a1, a1, a0
818-
; CHECK-NEXT: addi a1, a1, -1
819-
; CHECK-NEXT: and a0, a1, a0
820-
; CHECK-NEXT: vsetvli zero, a0, e32, m8, ta, ma
821-
; CHECK-NEXT: vmv1r.v v0, v24
822-
; CHECK-NEXT: vredxor.vs v25, v16, v25, v0.t
823819
; CHECK-NEXT: vmv.x.s a0, v25
820+
; CHECK-NEXT: vsetivli zero, 1, e32, m8, ta, ma
821+
; CHECK-NEXT: vmv.s.x v8, a0
822+
; CHECK-NEXT: vsetvli zero, a2, e32, m8, ta, ma
823+
; CHECK-NEXT: vmv1r.v v0, v24
824+
; CHECK-NEXT: vredxor.vs v8, v16, v8, v0.t
825+
; CHECK-NEXT: vmv.x.s a0, v8
824826
; CHECK-NEXT: ret
825827
%r = call i32 @llvm.vp.reduce.xor.v64i32(i32 %s, <64 x i32> %v, <64 x i1> %m, i32 %evl)
826828
ret i32 %r

llvm/test/CodeGen/RISCV/rvv/vreductions-int-vp.ll

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1115,10 +1115,13 @@ define signext i32 @vpreduce_umax_nxv32i32(i32 signext %s, <vscale x 32 x i32> %
11151115
; CHECK-NEXT: vmv.s.x v25, a0
11161116
; CHECK-NEXT: vsetvli zero, a1, e32, m8, ta, ma
11171117
; CHECK-NEXT: vredmaxu.vs v25, v8, v25, v0.t
1118+
; CHECK-NEXT: vmv.x.s a0, v25
1119+
; CHECK-NEXT: vsetivli zero, 1, e32, m8, ta, ma
1120+
; CHECK-NEXT: vmv.s.x v8, a0
11181121
; CHECK-NEXT: vsetvli zero, a2, e32, m8, ta, ma
11191122
; CHECK-NEXT: vmv1r.v v0, v24
1120-
; CHECK-NEXT: vredmaxu.vs v25, v16, v25, v0.t
1121-
; CHECK-NEXT: vmv.x.s a0, v25
1123+
; CHECK-NEXT: vredmaxu.vs v8, v16, v8, v0.t
1124+
; CHECK-NEXT: vmv.x.s a0, v8
11221125
; CHECK-NEXT: ret
11231126
%r = call i32 @llvm.vp.reduce.umax.nxv32i32(i32 %s, <vscale x 32 x i32> %v, <vscale x 32 x i1> %m, i32 %evl)
11241127
ret i32 %r

0 commit comments

Comments
 (0)