Skip to content

Commit ef5d932

Browse files
committed
[DAG][RISCV] Use vp_reduce_fadd/fmul when widening types for FP reductions
This is a follow-up to llvm#105455 which updates the VPIntrinsic mappings for the fadd and fmul cases, and supports both ordered and unordered reductions. This allows the use of a single wider operation with a restricted EVL instead of padding the vector with the neutral element. This has all the same tradeoffs as the previous patch.
1 parent 26a8a85 commit ef5d932

File tree

5 files changed

+67
-86
lines changed

5 files changed

+67
-86
lines changed

llvm/include/llvm/IR/VPIntrinsics.def

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -722,27 +722,29 @@ HELPER_REGISTER_REDUCTION_VP(vp_reduce_fminimum, VP_REDUCE_FMINIMUM,
722722
#error \
723723
"The internal helper macro HELPER_REGISTER_REDUCTION_SEQ_VP is already defined!"
724724
#endif
725-
#define HELPER_REGISTER_REDUCTION_SEQ_VP(VPID, VPSD, SEQ_VPSD, INTRIN) \
725+
#define HELPER_REGISTER_REDUCTION_SEQ_VP(VPID, VPSD, SEQ_VPSD, SDOPC, SEQ_SDOPC, INTRIN) \
726726
BEGIN_REGISTER_VP_INTRINSIC(VPID, 2, 3) \
727727
BEGIN_REGISTER_VP_SDNODE(VPSD, 1, VPID, 2, 3) \
728728
VP_PROPERTY_REDUCTION(0, 1) \
729+
VP_PROPERTY_FUNCTIONAL_SDOPC(SDOPC) \
729730
END_REGISTER_VP_SDNODE(VPSD) \
730731
BEGIN_REGISTER_VP_SDNODE(SEQ_VPSD, 1, VPID, 2, 3) \
731732
HELPER_MAP_VPID_TO_VPSD(VPID, SEQ_VPSD) \
733+
VP_PROPERTY_FUNCTIONAL_SDOPC(SEQ_SDOPC) \
732734
VP_PROPERTY_REDUCTION(0, 1) \
733735
END_REGISTER_VP_SDNODE(SEQ_VPSD) \
734736
VP_PROPERTY_FUNCTIONAL_INTRINSIC(INTRIN) \
735737
END_REGISTER_VP_INTRINSIC(VPID)
736738

737739
// llvm.vp.reduce.fadd(start,x,mask,vlen)
738740
HELPER_REGISTER_REDUCTION_SEQ_VP(vp_reduce_fadd, VP_REDUCE_FADD,
739-
VP_REDUCE_SEQ_FADD,
740-
vector_reduce_fadd)
741+
VP_REDUCE_SEQ_FADD, VECREDUCE_FADD,
742+
VECREDUCE_SEQ_FADD, vector_reduce_fadd)
741743

742744
// llvm.vp.reduce.fmul(start,x,mask,vlen)
743745
HELPER_REGISTER_REDUCTION_SEQ_VP(vp_reduce_fmul, VP_REDUCE_FMUL,
744-
VP_REDUCE_SEQ_FMUL,
745-
vector_reduce_fmul)
746+
VP_REDUCE_SEQ_FMUL, VECREDUCE_FMUL,
747+
VECREDUCE_SEQ_FMUL, vector_reduce_fmul)
746748

747749
#undef HELPER_REGISTER_REDUCTION_SEQ_VP
748750

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7311,8 +7311,6 @@ SDValue DAGTypeLegalizer::WidenVecOp_VECREDUCE(SDNode *N) {
73117311
// Generate a vp.reduce_op if it is custom/legal for the target. This avoids
73127312
// needing to pad the source vector, because the inactive lanes can simply be
73137313
// disabled and not contribute to the result.
7314-
// TODO: VECREDUCE_FADD, VECREDUCE_FMUL aren't currently mapped correctly,
7315-
// and thus don't take this path.
73167314
if (auto VPOpcode = ISD::getVPForBaseOpcode(Opc);
73177315
VPOpcode && TLI.isOperationLegalOrCustom(*VPOpcode, WideVT)) {
73187316
SDValue Start = NeutralElem;
@@ -7351,6 +7349,7 @@ SDValue DAGTypeLegalizer::WidenVecOp_VECREDUCE_SEQ(SDNode *N) {
73517349
SDValue VecOp = N->getOperand(1);
73527350
SDValue Op = GetWidenedVector(VecOp);
73537351

7352+
EVT VT = N->getValueType(0);
73547353
EVT OrigVT = VecOp.getValueType();
73557354
EVT WideVT = Op.getValueType();
73567355
EVT ElemVT = OrigVT.getVectorElementType();
@@ -7364,6 +7363,19 @@ SDValue DAGTypeLegalizer::WidenVecOp_VECREDUCE_SEQ(SDNode *N) {
73647363
unsigned OrigElts = OrigVT.getVectorMinNumElements();
73657364
unsigned WideElts = WideVT.getVectorMinNumElements();
73667365

7366+
// Generate a vp.reduce_op if it is custom/legal for the target. This avoids
7367+
// needing to pad the source vector, because the inactive lanes can simply be
7368+
// disabled and not contribute to the result.
7369+
if (auto VPOpcode = ISD::getVPForBaseOpcode(Opc);
7370+
VPOpcode && TLI.isOperationLegalOrCustom(*VPOpcode, WideVT)) {
7371+
EVT WideMaskVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
7372+
WideVT.getVectorElementCount());
7373+
SDValue Mask = DAG.getAllOnesConstant(dl, WideMaskVT);
7374+
SDValue EVL = DAG.getElementCount(dl, TLI.getVPExplicitVectorLengthTy(),
7375+
OrigVT.getVectorElementCount());
7376+
return DAG.getNode(*VPOpcode, dl, VT, {AccOp, Op, Mask, EVL}, Flags);
7377+
}
7378+
73677379
if (WideVT.isScalableVector()) {
73687380
unsigned GCD = std::gcd(OrigElts, WideElts);
73697381
EVT SplatVT = EVT::getVectorVT(*DAG.getContext(), ElemVT,
@@ -7372,14 +7384,14 @@ SDValue DAGTypeLegalizer::WidenVecOp_VECREDUCE_SEQ(SDNode *N) {
73727384
for (unsigned Idx = OrigElts; Idx < WideElts; Idx = Idx + GCD)
73737385
Op = DAG.getNode(ISD::INSERT_SUBVECTOR, dl, WideVT, Op, SplatNeutral,
73747386
DAG.getVectorIdxConstant(Idx, dl));
7375-
return DAG.getNode(Opc, dl, N->getValueType(0), AccOp, Op, Flags);
7387+
return DAG.getNode(Opc, dl, VT, AccOp, Op, Flags);
73767388
}
73777389

73787390
for (unsigned Idx = OrigElts; Idx < WideElts; Idx++)
73797391
Op = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, WideVT, Op, NeutralElem,
73807392
DAG.getVectorIdxConstant(Idx, dl));
73817393

7382-
return DAG.getNode(Opc, dl, N->getValueType(0), AccOp, Op, Flags);
7394+
return DAG.getNode(Opc, dl, VT, AccOp, Op, Flags);
73837395
}
73847396

73857397
SDValue DAGTypeLegalizer::WidenVecOp_VP_REDUCE(SDNode *N) {

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-formation.ll

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -791,12 +791,7 @@ define float @reduce_fadd_16xi32_prefix5(ptr %p) {
791791
; CHECK-NEXT: vle32.v v8, (a0)
792792
; CHECK-NEXT: lui a0, 524288
793793
; CHECK-NEXT: vmv.s.x v10, a0
794-
; CHECK-NEXT: vsetivli zero, 6, e32, m2, tu, ma
795-
; CHECK-NEXT: vslideup.vi v8, v10, 5
796-
; CHECK-NEXT: vsetivli zero, 7, e32, m2, tu, ma
797-
; CHECK-NEXT: vslideup.vi v8, v10, 6
798-
; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
799-
; CHECK-NEXT: vslideup.vi v8, v10, 7
794+
; CHECK-NEXT: vsetivli zero, 5, e32, m2, ta, ma
800795
; CHECK-NEXT: vfredusum.vs v8, v8, v10
801796
; CHECK-NEXT: vfmv.f.s fa0, v8
802797
; CHECK-NEXT: ret
@@ -880,7 +875,7 @@ define float @reduce_fadd_4xi32_non_associative(ptr %p) {
880875
; CHECK-NEXT: vfmv.f.s fa5, v9
881876
; CHECK-NEXT: lui a0, 524288
882877
; CHECK-NEXT: vmv.s.x v9, a0
883-
; CHECK-NEXT: vslideup.vi v8, v9, 3
878+
; CHECK-NEXT: vsetivli zero, 3, e32, m1, ta, ma
884879
; CHECK-NEXT: vfredusum.vs v8, v8, v9
885880
; CHECK-NEXT: vfmv.f.s fa4, v8
886881
; CHECK-NEXT: fadd.s fa0, fa4, fa5

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-fp.ll

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,6 @@ define half @vreduce_fadd_v7f16(ptr %x, half %s) {
9898
; CHECK: # %bb.0:
9999
; CHECK-NEXT: vsetivli zero, 7, e16, m1, ta, ma
100100
; CHECK-NEXT: vle16.v v8, (a0)
101-
; CHECK-NEXT: lui a0, 1048568
102-
; CHECK-NEXT: vmv.s.x v9, a0
103-
; CHECK-NEXT: vsetivli zero, 8, e16, m1, ta, ma
104-
; CHECK-NEXT: vslideup.vi v8, v9, 7
105101
; CHECK-NEXT: vfmv.s.f v9, fa0
106102
; CHECK-NEXT: vfredusum.vs v8, v8, v9
107103
; CHECK-NEXT: vfmv.f.s fa0, v8
@@ -470,10 +466,6 @@ define float @vreduce_fadd_v7f32(ptr %x, float %s) {
470466
; CHECK: # %bb.0:
471467
; CHECK-NEXT: vsetivli zero, 7, e32, m2, ta, ma
472468
; CHECK-NEXT: vle32.v v8, (a0)
473-
; CHECK-NEXT: lui a0, 524288
474-
; CHECK-NEXT: vmv.s.x v10, a0
475-
; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
476-
; CHECK-NEXT: vslideup.vi v8, v10, 7
477469
; CHECK-NEXT: vfmv.s.f v10, fa0
478470
; CHECK-NEXT: vfredusum.vs v8, v8, v10
479471
; CHECK-NEXT: vfmv.f.s fa0, v8
@@ -488,10 +480,6 @@ define float @vreduce_ord_fadd_v7f32(ptr %x, float %s) {
488480
; CHECK: # %bb.0:
489481
; CHECK-NEXT: vsetivli zero, 7, e32, m2, ta, ma
490482
; CHECK-NEXT: vle32.v v8, (a0)
491-
; CHECK-NEXT: lui a0, 524288
492-
; CHECK-NEXT: vmv.s.x v10, a0
493-
; CHECK-NEXT: vsetivli zero, 8, e32, m2, ta, ma
494-
; CHECK-NEXT: vslideup.vi v8, v10, 7
495483
; CHECK-NEXT: vfmv.s.f v10, fa0
496484
; CHECK-NEXT: vfredosum.vs v8, v8, v10
497485
; CHECK-NEXT: vfmv.f.s fa0, v8

llvm/test/CodeGen/RISCV/rvv/vreductions-fp-sdnode.ll

Lines changed: 42 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -889,17 +889,12 @@ define half @vreduce_ord_fadd_nxv3f16(<vscale x 3 x half> %v, half %s) {
889889
; CHECK-NEXT: csrr a0, vlenb
890890
; CHECK-NEXT: srli a0, a0, 3
891891
; CHECK-NEXT: slli a1, a0, 1
892-
; CHECK-NEXT: add a1, a1, a0
893892
; CHECK-NEXT: add a0, a1, a0
894-
; CHECK-NEXT: lui a2, 1048568
895-
; CHECK-NEXT: vsetvli a3, zero, e16, m1, ta, ma
896-
; CHECK-NEXT: vmv.v.x v9, a2
897-
; CHECK-NEXT: vsetvli zero, a0, e16, m1, ta, ma
898-
; CHECK-NEXT: vslideup.vx v8, v9, a1
899-
; CHECK-NEXT: vsetvli a0, zero, e16, m1, ta, ma
893+
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, ma
900894
; CHECK-NEXT: vfmv.s.f v9, fa0
901-
; CHECK-NEXT: vfredosum.vs v8, v8, v9
902-
; CHECK-NEXT: vfmv.f.s fa0, v8
895+
; CHECK-NEXT: vsetvli zero, a0, e16, m1, ta, ma
896+
; CHECK-NEXT: vfredosum.vs v9, v8, v9
897+
; CHECK-NEXT: vfmv.f.s fa0, v9
903898
; CHECK-NEXT: ret
904899
%red = call half @llvm.vector.reduce.fadd.nxv3f16(half %s, <vscale x 3 x half> %v)
905900
ret half %red
@@ -910,18 +905,15 @@ declare half @llvm.vector.reduce.fadd.nxv6f16(half, <vscale x 6 x half>)
910905
define half @vreduce_ord_fadd_nxv6f16(<vscale x 6 x half> %v, half %s) {
911906
; CHECK-LABEL: vreduce_ord_fadd_nxv6f16:
912907
; CHECK: # %bb.0:
913-
; CHECK-NEXT: lui a0, 1048568
914-
; CHECK-NEXT: vsetvli a1, zero, e16, m1, ta, ma
915-
; CHECK-NEXT: vmv.v.x v10, a0
916908
; CHECK-NEXT: csrr a0, vlenb
917-
; CHECK-NEXT: srli a0, a0, 2
918-
; CHECK-NEXT: add a1, a0, a0
919-
; CHECK-NEXT: vsetvli zero, a1, e16, m1, ta, ma
920-
; CHECK-NEXT: vslideup.vx v9, v10, a0
921-
; CHECK-NEXT: vsetvli a0, zero, e16, m2, ta, ma
909+
; CHECK-NEXT: srli a1, a0, 3
910+
; CHECK-NEXT: slli a1, a1, 1
911+
; CHECK-NEXT: sub a0, a0, a1
912+
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, ma
922913
; CHECK-NEXT: vfmv.s.f v10, fa0
923-
; CHECK-NEXT: vfredosum.vs v8, v8, v10
924-
; CHECK-NEXT: vfmv.f.s fa0, v8
914+
; CHECK-NEXT: vsetvli zero, a0, e16, m2, ta, ma
915+
; CHECK-NEXT: vfredosum.vs v10, v8, v10
916+
; CHECK-NEXT: vfmv.f.s fa0, v10
925917
; CHECK-NEXT: ret
926918
%red = call half @llvm.vector.reduce.fadd.nxv6f16(half %s, <vscale x 6 x half> %v)
927919
ret half %red
@@ -932,22 +924,15 @@ declare half @llvm.vector.reduce.fadd.nxv10f16(half, <vscale x 10 x half>)
932924
define half @vreduce_ord_fadd_nxv10f16(<vscale x 10 x half> %v, half %s) {
933925
; CHECK-LABEL: vreduce_ord_fadd_nxv10f16:
934926
; CHECK: # %bb.0:
935-
; CHECK-NEXT: lui a0, 1048568
936-
; CHECK-NEXT: vsetvli a1, zero, e16, m1, ta, ma
937-
; CHECK-NEXT: vmv.v.x v12, a0
938927
; CHECK-NEXT: csrr a0, vlenb
939-
; CHECK-NEXT: srli a0, a0, 2
940-
; CHECK-NEXT: add a1, a0, a0
941-
; CHECK-NEXT: vsetvli zero, a1, e16, m1, ta, ma
942-
; CHECK-NEXT: vslideup.vx v10, v12, a0
943-
; CHECK-NEXT: vsetvli zero, a0, e16, m1, tu, ma
944-
; CHECK-NEXT: vmv.v.v v11, v12
945-
; CHECK-NEXT: vsetvli zero, a1, e16, m1, ta, ma
946-
; CHECK-NEXT: vslideup.vx v11, v12, a0
947-
; CHECK-NEXT: vsetvli a0, zero, e16, m4, ta, ma
928+
; CHECK-NEXT: srli a0, a0, 3
929+
; CHECK-NEXT: li a1, 10
930+
; CHECK-NEXT: mul a0, a0, a1
931+
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, ma
948932
; CHECK-NEXT: vfmv.s.f v12, fa0
949-
; CHECK-NEXT: vfredosum.vs v8, v8, v12
950-
; CHECK-NEXT: vfmv.f.s fa0, v8
933+
; CHECK-NEXT: vsetvli zero, a0, e16, m4, ta, ma
934+
; CHECK-NEXT: vfredosum.vs v12, v8, v12
935+
; CHECK-NEXT: vfmv.f.s fa0, v12
951936
; CHECK-NEXT: ret
952937
%red = call half @llvm.vector.reduce.fadd.nxv10f16(half %s, <vscale x 10 x half> %v)
953938
ret half %red
@@ -958,13 +943,16 @@ declare half @llvm.vector.reduce.fadd.nxv12f16(half, <vscale x 12 x half>)
958943
define half @vreduce_ord_fadd_nxv12f16(<vscale x 12 x half> %v, half %s) {
959944
; CHECK-LABEL: vreduce_ord_fadd_nxv12f16:
960945
; CHECK: # %bb.0:
961-
; CHECK-NEXT: lui a0, 1048568
962-
; CHECK-NEXT: vsetvli a1, zero, e16, m1, ta, ma
963-
; CHECK-NEXT: vmv.v.x v11, a0
946+
; CHECK-NEXT: csrr a0, vlenb
947+
; CHECK-NEXT: srli a0, a0, 3
948+
; CHECK-NEXT: slli a1, a0, 2
949+
; CHECK-NEXT: slli a0, a0, 4
950+
; CHECK-NEXT: sub a0, a0, a1
951+
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, ma
964952
; CHECK-NEXT: vfmv.s.f v12, fa0
965-
; CHECK-NEXT: vsetvli a0, zero, e16, m4, ta, ma
966-
; CHECK-NEXT: vfredosum.vs v8, v8, v12
967-
; CHECK-NEXT: vfmv.f.s fa0, v8
953+
; CHECK-NEXT: vsetvli zero, a0, e16, m4, ta, ma
954+
; CHECK-NEXT: vfredosum.vs v12, v8, v12
955+
; CHECK-NEXT: vfmv.f.s fa0, v12
968956
; CHECK-NEXT: ret
969957
%red = call half @llvm.vector.reduce.fadd.nxv12f16(half %s, <vscale x 12 x half> %v)
970958
ret half %red
@@ -977,17 +965,14 @@ define half @vreduce_fadd_nxv3f16(<vscale x 3 x half> %v, half %s) {
977965
; CHECK-NEXT: csrr a0, vlenb
978966
; CHECK-NEXT: srli a0, a0, 3
979967
; CHECK-NEXT: slli a1, a0, 1
980-
; CHECK-NEXT: add a1, a1, a0
981968
; CHECK-NEXT: add a0, a1, a0
982-
; CHECK-NEXT: lui a2, 1048568
983-
; CHECK-NEXT: vsetvli a3, zero, e16, m1, ta, ma
984-
; CHECK-NEXT: vmv.v.x v9, a2
985-
; CHECK-NEXT: vsetvli zero, a0, e16, m1, ta, ma
986-
; CHECK-NEXT: vslideup.vx v8, v9, a1
987-
; CHECK-NEXT: vsetvli a0, zero, e16, m1, ta, ma
969+
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, ma
988970
; CHECK-NEXT: vfmv.s.f v9, fa0
989-
; CHECK-NEXT: vfredusum.vs v8, v8, v9
990-
; CHECK-NEXT: vfmv.f.s fa0, v8
971+
; CHECK-NEXT: lui a1, 1048568
972+
; CHECK-NEXT: vmv.s.x v10, a1
973+
; CHECK-NEXT: vsetvli zero, a0, e16, m1, ta, ma
974+
; CHECK-NEXT: vfredusum.vs v10, v8, v9
975+
; CHECK-NEXT: vfmv.f.s fa0, v10
991976
; CHECK-NEXT: ret
992977
%red = call reassoc half @llvm.vector.reduce.fadd.nxv3f16(half %s, <vscale x 3 x half> %v)
993978
ret half %red
@@ -996,18 +981,17 @@ define half @vreduce_fadd_nxv3f16(<vscale x 3 x half> %v, half %s) {
996981
define half @vreduce_fadd_nxv6f16(<vscale x 6 x half> %v, half %s) {
997982
; CHECK-LABEL: vreduce_fadd_nxv6f16:
998983
; CHECK: # %bb.0:
999-
; CHECK-NEXT: lui a0, 1048568
1000-
; CHECK-NEXT: vsetvli a1, zero, e16, m1, ta, ma
1001-
; CHECK-NEXT: vmv.v.x v10, a0
1002984
; CHECK-NEXT: csrr a0, vlenb
1003-
; CHECK-NEXT: srli a0, a0, 2
1004-
; CHECK-NEXT: add a1, a0, a0
1005-
; CHECK-NEXT: vsetvli zero, a1, e16, m1, ta, ma
1006-
; CHECK-NEXT: vslideup.vx v9, v10, a0
1007-
; CHECK-NEXT: vsetvli a0, zero, e16, m2, ta, ma
985+
; CHECK-NEXT: srli a1, a0, 3
986+
; CHECK-NEXT: slli a1, a1, 1
987+
; CHECK-NEXT: sub a0, a0, a1
988+
; CHECK-NEXT: vsetivli zero, 1, e16, m1, ta, ma
1008989
; CHECK-NEXT: vfmv.s.f v10, fa0
1009-
; CHECK-NEXT: vfredusum.vs v8, v8, v10
1010-
; CHECK-NEXT: vfmv.f.s fa0, v8
990+
; CHECK-NEXT: lui a1, 1048568
991+
; CHECK-NEXT: vmv.s.x v11, a1
992+
; CHECK-NEXT: vsetvli zero, a0, e16, m2, ta, ma
993+
; CHECK-NEXT: vfredusum.vs v11, v8, v10
994+
; CHECK-NEXT: vfmv.f.s fa0, v11
1011995
; CHECK-NEXT: ret
1012996
%red = call reassoc half @llvm.vector.reduce.fadd.nxv6f16(half %s, <vscale x 6 x half> %v)
1013997
ret half %red

0 commit comments

Comments
 (0)