Skip to content

Commit 2c585f0

Browse files
committed
WIP on fixed-vector-patterns: 7231776 Recommit "[DAG] Reducing instructions by better legalization handling of AVGFLOORU for illegal data types" (llvm#101223)
2 parents 7231776 + 516319d commit 2c585f0

12 files changed

+533
-426
lines changed

llvm/lib/Target/RISCV/RISCVFeatures.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -874,6 +874,12 @@ def HasVInstructionsBF16Minimal : Predicate<"Subtarget->hasVInstructionsBF16Mini
874874
def HasVInstructionsF16 : Predicate<"Subtarget->hasVInstructionsF16()">;
875875
def HasVInstructionsF64 : Predicate<"Subtarget->hasVInstructionsF64()">;
876876

877+
878+
foreach i = { 6-16 } in {
879+
defvar I = !shl(1, i);
880+
def HasZvl#I#b : Predicate<"Subtarget->getRealMinVLen() >= " # I # "">;
881+
}
882+
877883
def HasVInstructionsFullMultiply : Predicate<"Subtarget->hasVInstructionsFullMultiply()">;
878884

879885
// Hypervisor Extensions

llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3316,6 +3316,8 @@ bool RISCVDAGToDAGISel::selectVLOp(SDValue N, SDValue &VL) {
33163316
}
33173317

33183318
static SDValue findVSplat(SDValue N) {
3319+
if (N.getOpcode() == ISD::EXTRACT_SUBVECTOR)
3320+
N = N.getOperand(0);
33193321
if (N.getOpcode() == ISD::INSERT_SUBVECTOR) {
33203322
if (!N.getOperand(0).isUndef())
33213323
return SDValue();

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1289,6 +1289,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
12891289
{ISD::CTLZ, ISD::CTLZ_ZERO_UNDEF, ISD::CTTZ_ZERO_UNDEF}, VT,
12901290
Custom);
12911291
}
1292+
1293+
setOperationAction({ISD::LOAD, ISD::STORE, ISD::ADD, ISD::SUB, ISD::MUL, ISD::AND, ISD::OR, ISD::XOR, ISD::SHL, ISD::SRL, ISD::SRA, ISD::UMAX, ISD::SMAX, ISD::UMIN, ISD::SMIN, ISD::SIGN_EXTEND, ISD::ZERO_EXTEND }, VT, Legal);
12921294
}
12931295

12941296
for (MVT VT : MVT::fp_fixedlen_vector_valuetypes()) {
@@ -3714,6 +3716,7 @@ static SDValue lowerBuildVectorOfConstants(SDValue Op, SelectionDAG &DAG,
37143716
: RISCVISD::VMV_V_X_VL;
37153717
if (!VT.isFloatingPoint())
37163718
Splat = DAG.getNode(ISD::ANY_EXTEND, DL, XLenVT, Splat);
3719+
return DAG.getNode(Opc, DL, VT, DAG.getUNDEF(VT), Splat, VL);
37173720
Splat =
37183721
DAG.getNode(Opc, DL, ContainerVT, DAG.getUNDEF(ContainerVT), Splat, VL);
37193722
return convertFromScalableVector(VT, Splat, DAG, Subtarget);

llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td

Lines changed: 78 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,9 @@ class VTypeInfo<ValueType Vec, ValueType Mas, int Sew, LMULInfo M,
264264
RegisterClass ScalarRegClass = ScalarReg;
265265
// The pattern fragment which produces the AVL operand, representing the
266266
// "natural" vector length for this type. For scalable vectors this is VLMax.
267-
OutPatFrag AVL = VLMax;
267+
OutPatFrag AVL = !if(Vec.isScalable, VLMax, !if(!lt(Vec.nElem, 32),
268+
OutPatFrag<(ops), (XLenVT Vec.nElem)>,
269+
OutPatFrag<(ops), (ADDI (XLenVT X0), (XLenVT Vec.nElem))>));
268270

269271
string ScalarSuffix = !cond(!eq(Scal, XLenVT) : "X",
270272
!eq(Scal, f16) : "FPR16",
@@ -280,6 +282,36 @@ class GroupVTypeInfo<ValueType Vec, ValueType VecM1, ValueType Mas, int Sew,
280282
ValueType VectorM1 = VecM1;
281283
}
282284

285+
multiclass FixedVTypeInfo<ValueType Vec, ValueType Mask, int Sew> {
286+
foreach lmul = MxSet<Sew>.m in {
287+
// if we only need 32 bits per vec reg, skip
288+
defvar bitspervec = !div(!mul(Vec.Size, 8), lmul.octuple);
289+
if !ge(bitspervec, 64) then
290+
def "" # lmul.MX : VTypeInfo<Vec, Mask, Sew, lmul>;
291+
}
292+
}
293+
294+
defset list<VTypeInfo> FixedLengthVectors = {
295+
defvar FixedLengthVectorVTs = !filter(vec, ValueTypes, !and(vec.isVector, !not(vec.isScalable), !eq(!shl(1, !logtwo(vec.nElem)), vec.nElem), !le(vec.nElem, 128)));
296+
defvar SupportedFixedLengthVectorVTs = !filter(vec, FixedLengthVectorVTs,
297+
!or(!eq(vec.ElementType, i8),
298+
!eq(vec.ElementType, i16),
299+
!eq(vec.ElementType, i32),
300+
!eq(vec.ElementType, i64)));
301+
foreach VT = SupportedFixedLengthVectorVTs in {
302+
defm "" # VT.LLVMName : FixedVTypeInfo<VT, !cast<VTVec>("v"#VT.nElem#"i1"), VT.ElementType.Size>;
303+
}
304+
// defm V2I32 : FixedVTypeInfo<v2i32, v2i1, 32>;
305+
// defm V4I32 : FixedVTypeInfo<v4i32, v4i1, 32>;
306+
// defm V1I64 : FixedVTypeInfo<v1i64, v1i1, 64>;
307+
// defm V2I64 : FixedVTypeInfo<v2i64, v2i1, 64>;
308+
// defm V3I64 : FixedVTypeInfo<v3i64, v3i1, 64>;
309+
// defm V4I64 : FixedVTypeInfo<v4i64, v4i1, 64>;
310+
// defm V8I64 : FixedVTypeInfo<v8i64, v8i1, 64>;
311+
// defm V16I64 : FixedVTypeInfo<v16i64, v16i1, 64>;
312+
// defm V32I64 : FixedVTypeInfo<v32i64, v32i1, 64>;
313+
}
314+
283315
defset list<VTypeInfo> AllVectors = {
284316
defset list<VTypeInfo> AllIntegerVectors = {
285317
defset list<VTypeInfo> NoGroupIntegerVectors = {
@@ -371,6 +403,8 @@ defset list<VTypeInfo> AllVectors = {
371403
}
372404
}
373405

406+
defvar AllScalableAndFixedIntegerVectors = !listconcat(AllIntegerVectors, FixedLengthVectors);
407+
374408
// This functor is used to obtain the int vector type that has the same SEW and
375409
// multiplier as the input parameter type
376410
class GetIntVTypeInfo<VTypeInfo vti> {
@@ -438,6 +472,15 @@ defset list<VTypeInfoToWide> AllWidenableIntVectors = {
438472
def : VTypeInfoToWide<VI32M4, VI64M8>;
439473
}
440474

475+
defset list<VTypeInfoToWide> AllFixedWidenableIntVectors = {
476+
foreach vti = !filter(vti, FixedLengthVectors, !and(!lt(vti.Vector.ElementType.Size, 64), !ne(vti.LMul, V_M8))) in {
477+
defvar wide_vti = !cast<VTypeInfo>("v" # vti.Vector.nElem # "i" # !mul(vti.Vector.ElementType.Size, 2) # octuple_to_str<!mul(vti.LMul.octuple, 2)>.ret);
478+
def : VTypeInfoToWide<vti, wide_vti>;
479+
}
480+
}
481+
482+
defvar AllScalableAndFixedWidenableIntVectors = !listconcat(AllWidenableIntVectors, AllFixedWidenableIntVectors);
483+
441484
defset list<VTypeInfoToWide> AllWidenableFloatVectors = {
442485
def : VTypeInfoToWide<VF16MF4, VF32MF2>;
443486
def : VTypeInfoToWide<VF16MF2, VF32M1>;
@@ -469,6 +512,21 @@ defset list<VTypeInfoToFraction> AllFractionableVF2IntVectors = {
469512
def : VTypeInfoToFraction<VI64M8, VI32M4>;
470513
}
471514

515+
multiclass FixedFractionableIntVectors<int vf> {
516+
defvar min_sew = !mul(8, vf);
517+
defvar min_lmul = !mul(1, vf);
518+
foreach vti = !filter(vti, FixedLengthVectors, !and(!ge(vti.Vector.ElementType.Size, min_sew), !ge(vti.LMul.octuple, min_lmul))) in {
519+
defvar narrow_vti = !cast<VTypeInfo>("v" # vti.Vector.nElem # "i" # !div(vti.Vector.ElementType.Size, vf) # octuple_to_str<!div(vti.LMul.octuple, vf)>.ret);
520+
def : VTypeInfoToFraction<vti, narrow_vti>;
521+
}
522+
}
523+
524+
defset list<VTypeInfoToFraction> AllFixedFractionableVF2IntVectors = {
525+
defm : FixedFractionableIntVectors<2>;
526+
}
527+
528+
defvar AllFixedAndScalableFractionableVF2IntVectors = !listconcat(AllFractionableVF2IntVectors, AllFixedFractionableVF2IntVectors);
529+
472530
defset list<VTypeInfoToFraction> AllFractionableVF4IntVectors = {
473531
def : VTypeInfoToFraction<VI32MF2, VI8MF8>;
474532
def : VTypeInfoToFraction<VI32M1, VI8MF4>;
@@ -481,13 +539,25 @@ defset list<VTypeInfoToFraction> AllFractionableVF4IntVectors = {
481539
def : VTypeInfoToFraction<VI64M8, VI16M2>;
482540
}
483541

542+
defset list<VTypeInfoToFraction> AllFixedFractionableVF4IntVectors = {
543+
defm : FixedFractionableIntVectors<4>;
544+
}
545+
546+
defvar AllFixedAndScalableFractionableVF4IntVectors = !listconcat(AllFractionableVF4IntVectors, AllFixedFractionableVF4IntVectors);
547+
484548
defset list<VTypeInfoToFraction> AllFractionableVF8IntVectors = {
485549
def : VTypeInfoToFraction<VI64M1, VI8MF8>;
486550
def : VTypeInfoToFraction<VI64M2, VI8MF4>;
487551
def : VTypeInfoToFraction<VI64M4, VI8MF2>;
488552
def : VTypeInfoToFraction<VI64M8, VI8M1>;
489553
}
490554

555+
defset list<VTypeInfoToFraction> AllFixedFractionableVF8IntVectors = {
556+
defm : FixedFractionableIntVectors<8>;
557+
}
558+
559+
defvar AllFixedAndScalableFractionableVF8IntVectors = !listconcat(AllFractionableVF8IntVectors, AllFixedFractionableVF8IntVectors);
560+
491561
defset list<VTypeInfoToWide> AllWidenableIntToFloatVectors = {
492562
def : VTypeInfoToWide<VI8MF8, VF16MF4>;
493563
def : VTypeInfoToWide<VI8MF4, VF16MF2>;
@@ -750,12 +820,18 @@ class VPseudo<Instruction instr, LMULInfo m, dag outs, dag ins, int sew = 0> :
750820
}
751821

752822
class GetVTypePredicates<VTypeInfo vti> {
753-
list<Predicate> Predicates = !cond(!eq(vti.Scalar, f16) : [HasVInstructionsF16],
823+
defvar Zvls = !foreach(i, !range(6, 16), !shl(1, i));
824+
defvar MinZvl = !head(!filter(zvl, Zvls, !ge(!div(!mul(zvl, vti.LMul.octuple), 8), vti.Vector.Size)));
825+
defvar FixedMinLength =
826+
!if(vti.Vector.isScalable, []<Predicate>, [!cast<Predicate>("HasZvl"#MinZvl#"b")]);
827+
828+
defvar EltPredicate = !cond(!eq(vti.Scalar, f16) : [HasVInstructionsF16],
754829
!eq(vti.Scalar, bf16) : [HasVInstructionsBF16Minimal],
755830
!eq(vti.Scalar, f32) : [HasVInstructionsAnyF],
756831
!eq(vti.Scalar, f64) : [HasVInstructionsF64],
757832
!eq(vti.SEW, 64) : [HasVInstructionsI64],
758833
true : [HasVInstructions]);
834+
list<Predicate> Predicates = !listconcat(EltPredicate, FixedMinLength);
759835
}
760836

761837
class VPseudoUSLoadNoMask<VReg RetClass,

llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ class VPatBinarySDNode_XI<SDPatternOperator vop,
123123
avl, log2sew, TA_MA)>;
124124

125125
multiclass VPatBinarySDNode_VV_VX<SDPatternOperator vop, string instruction_name,
126-
list<VTypeInfo> vtilist = AllIntegerVectors,
126+
list<VTypeInfo> vtilist = AllScalableAndFixedIntegerVectors,
127127
bit isSEWAware = 0> {
128128
foreach vti = vtilist in {
129129
let Predicates = GetVTypePredicates<vti>.Predicates in {
@@ -141,7 +141,7 @@ multiclass VPatBinarySDNode_VV_VX<SDPatternOperator vop, string instruction_name
141141
multiclass VPatBinarySDNode_VV_VX_VI<SDPatternOperator vop, string instruction_name,
142142
Operand ImmType = simm5>
143143
: VPatBinarySDNode_VV_VX<vop, instruction_name> {
144-
foreach vti = AllIntegerVectors in {
144+
foreach vti = AllScalableAndFixedIntegerVectors in {
145145
let Predicates = GetVTypePredicates<vti>.Predicates in
146146
def : VPatBinarySDNode_XI<vop, instruction_name, "VI",
147147
vti.Vector, vti.Vector, vti.Log2SEW,
@@ -470,7 +470,7 @@ multiclass VPatNConvertFP2ISDNode_W<SDPatternOperator vop,
470470

471471
multiclass VPatWidenBinarySDNode_VV_VX<SDNode op, PatFrags extop1, PatFrags extop2,
472472
string instruction_name> {
473-
foreach vtiToWti = AllWidenableIntVectors in {
473+
foreach vtiToWti = AllScalableAndFixedWidenableIntVectors in {
474474
defvar vti = vtiToWti.Vti;
475475
defvar wti = vtiToWti.Wti;
476476
let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
@@ -491,7 +491,7 @@ multiclass VPatWidenBinarySDNode_VV_VX<SDNode op, PatFrags extop1, PatFrags exto
491491

492492
multiclass VPatWidenBinarySDNode_WV_WX<SDNode op, PatFrags extop,
493493
string instruction_name> {
494-
foreach vtiToWti = AllWidenableIntVectors in {
494+
foreach vtiToWti = AllScalableAndFixedWidenableIntVectors in {
495495
defvar vti = vtiToWti.Vti;
496496
defvar wti = vtiToWti.Wti;
497497
let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
@@ -834,7 +834,7 @@ multiclass VPatWidenFPNegMulSacSDNode_VV_VF_RM<string instruction_name> {
834834
}
835835

836836
multiclass VPatMultiplyAddSDNode_VV_VX<SDNode op, string instruction_name> {
837-
foreach vti = AllIntegerVectors in {
837+
foreach vti = AllScalableAndFixedIntegerVectors in {
838838
defvar suffix = vti.LMul.MX;
839839
let Predicates = GetVTypePredicates<vti>.Predicates in {
840840
// NOTE: We choose VMADD because it has the most commuting freedom. So it
@@ -877,7 +877,7 @@ multiclass VPatAVGADD_VV_VX_RM<SDNode vop, int vxrm, string suffix = ""> {
877877
//===----------------------------------------------------------------------===//
878878

879879
// 7.4. Vector Unit-Stride Instructions
880-
foreach vti = AllVectors in
880+
foreach vti = AllScalableAndFixedIntegerVectors in
881881
let Predicates = !if(!eq(vti.Scalar, f16), [HasVInstructionsF16Minimal],
882882
GetVTypePredicates<vti>.Predicates) in
883883
defm : VPatUSLoadStoreSDNode<vti.Vector, vti.Log2SEW, vti.LMul,
@@ -893,7 +893,7 @@ defm : VPatBinarySDNode_VV_VX_VI<add, "PseudoVADD">;
893893
defm : VPatBinarySDNode_VV_VX<sub, "PseudoVSUB">;
894894
// Handle VRSUB specially since it's the only integer binary op with reversed
895895
// pattern operands
896-
foreach vti = AllIntegerVectors in {
896+
foreach vti = AllScalableAndFixedIntegerVectors in {
897897
// FIXME: The AddedComplexity here is covering up a missing matcher for
898898
// widening vwsub.vx which can recognize a extended folded into the
899899
// scalar of the splat.
@@ -922,7 +922,7 @@ defm : VPatWidenBinarySDNode_VV_VX_WV_WX<sub, zext_oneuse, "PseudoVWSUBU">;
922922
defm : VPatWidenBinarySDNode_VV_VX_WV_WX<sub, anyext_oneuse, "PseudoVWSUBU">;
923923

924924
// shl (ext v, splat 1) is a special case of widening add.
925-
foreach vtiToWti = AllWidenableIntVectors in {
925+
foreach vtiToWti = AllScalableAndFixedWidenableIntVectors in {
926926
defvar vti = vtiToWti.Vti;
927927
defvar wti = vtiToWti.Wti;
928928
let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
@@ -933,6 +933,7 @@ foreach vtiToWti = AllWidenableIntVectors in {
933933
(wti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs1, vti.RegClass:$rs1,
934934
vti.AVL, vti.Log2SEW, TA_MA)>;
935935
def : Pat<(shl (wti.Vector (zext_oneuse (vti.Vector vti.RegClass:$rs1))),
936+
// TODO: Need to make this splat 1 generic over fixed and scalable types
936937
(wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), 1, (XLenVT srcvalue)))),
937938
(!cast<Instruction>("PseudoVWADDU_VV_"#vti.LMul.MX)
938939
(wti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs1, vti.RegClass:$rs1,
@@ -942,6 +943,7 @@ foreach vtiToWti = AllWidenableIntVectors in {
942943
(!cast<Instruction>("PseudoVWADDU_VV_"#vti.LMul.MX)
943944
(wti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs1, vti.RegClass:$rs1,
944945
vti.AVL, vti.Log2SEW, TA_MA)>;
946+
if wti.Vector.isScalable then {
945947
def : Pat<(shl (wti.Vector (riscv_sext_vl_oneuse (vti.Vector vti.RegClass:$rs1), (vti.Mask V0), VLOpFrag)),
946948
(wti.Vector (riscv_vmv_v_x_vl (wti.Vector undef), 1, (XLenVT srcvalue)))),
947949
(!cast<Instruction>("PseudoVWADD_VV_"#vti.LMul.MX#"_MASK")
@@ -952,22 +954,23 @@ foreach vtiToWti = AllWidenableIntVectors in {
952954
(!cast<Instruction>("PseudoVWADDU_VV_"#vti.LMul.MX#"_MASK")
953955
(wti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs1, vti.RegClass:$rs1,
954956
(vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
957+
}
955958
}
956959
}
957960

958961
// 11.3. Vector Integer Extension
959962
defm : VPatExtendSDNode_V<[zext, anyext], "PseudoVZEXT", "VF2",
960-
AllFractionableVF2IntVectors>;
963+
AllFixedAndScalableFractionableVF2IntVectors>;
961964
defm : VPatExtendSDNode_V<[sext], "PseudoVSEXT", "VF2",
962-
AllFractionableVF2IntVectors>;
965+
AllFixedAndScalableFractionableVF2IntVectors>;
963966
defm : VPatExtendSDNode_V<[zext, anyext], "PseudoVZEXT", "VF4",
964-
AllFractionableVF4IntVectors>;
967+
AllFixedAndScalableFractionableVF4IntVectors>;
965968
defm : VPatExtendSDNode_V<[sext], "PseudoVSEXT", "VF4",
966-
AllFractionableVF4IntVectors>;
969+
AllFixedAndScalableFractionableVF4IntVectors>;
967970
defm : VPatExtendSDNode_V<[zext, anyext], "PseudoVZEXT", "VF8",
968-
AllFractionableVF8IntVectors>;
971+
AllFixedAndScalableFractionableVF8IntVectors>;
969972
defm : VPatExtendSDNode_V<[sext], "PseudoVSEXT", "VF8",
970-
AllFractionableVF8IntVectors>;
973+
AllFixedAndScalableFractionableVF8IntVectors>;
971974

972975
// 11.5. Vector Bitwise Logical Instructions
973976
defm : VPatBinarySDNode_VV_VX_VI<and, "PseudoVAND">;
@@ -1455,3 +1458,14 @@ foreach vti = NoGroupFloatVectors in {
14551458
def : Pat<(vti.Scalar (extractelt (vti.Vector vti.RegClass:$rs2), 0)),
14561459
(vfmv_f_s_inst vti.RegClass:$rs2, vti.Log2SEW)>;
14571460
}
1461+
1462+
foreach vti = FixedLengthVectors in {
1463+
defvar vslidedown_v_i_inst = !cast<Instruction>("PseudoVSLIDEDOWN_VI_" # vti.LMul.MX);
1464+
defvar vmv_x_s_inst = !cast<Instruction>(!strconcat("PseudoVMV_",
1465+
vti.ScalarSuffix,
1466+
"_S"));
1467+
// let Predicates = GetVTypePredicates<vti>.Predicates in
1468+
// def : Pat<(vti.Scalar (extractelt (vti.Vector vti.RegClass:$rs2), uimm5:$idx)),
1469+
// (vmv_x_s_inst (vslidedown_v_i_inst (IMPLICIT_DEF), vti.RegClass:$rs2, $idx, vti.AVL, vti.Log2SEW, 0),
1470+
// vti.Log2SEW)>;
1471+
}

llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2368,7 +2368,7 @@ foreach vti = AllVectors in {
23682368
vti.RegClass:$passthru, vti.RegClass:$rs2, GPR:$vl, vti.Log2SEW, TU_MU)>;
23692369
}
23702370

2371-
foreach vti = AllIntegerVectors in {
2371+
foreach vti = AllScalableAndFixedIntegerVectors in {
23722372
def : Pat<(vti.Vector (riscv_vmv_v_x_vl vti.RegClass:$passthru, GPR:$rs2, VLOpFrag)),
23732373
(!cast<Instruction>("PseudoVMV_V_X_"#vti.LMul.MX)
23742374
vti.RegClass:$passthru, GPR:$rs2, GPR:$vl, vti.Log2SEW, TU_MU)>;

0 commit comments

Comments
 (0)