Skip to content

Commit 9f4a2a8

Browse files
committed
[RISCV] Separate lowering of constant build vector into a helper [nfc]
We have a bunch of special casing for constant vectors, and the costing is generally different. Separate out the logic so that it's easier to follow.
1 parent feafc2d commit 9f4a2a8

File tree

1 file changed

+194
-151
lines changed

1 file changed

+194
-151
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 194 additions & 151 deletions
Original file line numberDiff line numberDiff line change
@@ -2999,8 +2999,101 @@ static SDValue matchSplatAsGather(SDValue SplatVal, MVT VT, const SDLoc &DL,
29992999
return convertFromScalableVector(VT, Gather, DAG, Subtarget);
30003000
}
30013001

3002-
static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG,
3003-
const RISCVSubtarget &Subtarget) {
3002+
3003+
/// Try and optimize BUILD_VECTORs with "dominant values" - these are values
3004+
/// which constitute a large proportion of the elements. In such cases we can
3005+
/// splat a vector with the dominant element and make up the shortfall with
3006+
/// INSERT_VECTOR_ELTs. Returns SDValue if not profitable.
3007+
/// Note that this includes vectors of 2 elements by association. The
3008+
/// upper-most element is the "dominant" one, allowing us to use a splat to
3009+
/// "insert" the upper element, and an insert of the lower element at position
3010+
/// 0, which improves codegen.
3011+
static SDValue lowerBuildVectorViaDominantValues(SDValue Op, SelectionDAG &DAG,
3012+
const RISCVSubtarget &Subtarget) {
3013+
MVT VT = Op.getSimpleValueType();
3014+
assert(VT.isFixedLengthVector() && "Unexpected vector!");
3015+
3016+
SDLoc DL(Op);
3017+
3018+
MVT XLenVT = Subtarget.getXLenVT();
3019+
unsigned NumElts = Op.getNumOperands();
3020+
3021+
SDValue DominantValue;
3022+
unsigned MostCommonCount = 0;
3023+
DenseMap<SDValue, unsigned> ValueCounts;
3024+
unsigned NumUndefElts =
3025+
count_if(Op->op_values(), [](const SDValue &V) { return V.isUndef(); });
3026+
3027+
// Track the number of scalar loads we know we'd be inserting, estimated as
3028+
// any non-zero floating-point constant. Other kinds of element are either
3029+
// already in registers or are materialized on demand. The threshold at which
3030+
// a vector load is more desirable than several scalar materialization and
3031+
// vector-insertion instructions is not known.
3032+
unsigned NumScalarLoads = 0;
3033+
3034+
for (SDValue V : Op->op_values()) {
3035+
if (V.isUndef())
3036+
continue;
3037+
3038+
ValueCounts.insert(std::make_pair(V, 0));
3039+
unsigned &Count = ValueCounts[V];
3040+
if (0 == Count)
3041+
if (auto *CFP = dyn_cast<ConstantFPSDNode>(V))
3042+
NumScalarLoads += !CFP->isExactlyValue(+0.0);
3043+
3044+
// Is this value dominant? In case of a tie, prefer the highest element as
3045+
// it's cheaper to insert near the beginning of a vector than it is at the
3046+
// end.
3047+
if (++Count >= MostCommonCount) {
3048+
DominantValue = V;
3049+
MostCommonCount = Count;
3050+
}
3051+
}
3052+
3053+
assert(DominantValue && "Not expecting an all-undef BUILD_VECTOR");
3054+
unsigned NumDefElts = NumElts - NumUndefElts;
3055+
unsigned DominantValueCountThreshold = NumDefElts <= 2 ? 0 : NumDefElts - 2;
3056+
3057+
// Don't perform this optimization when optimizing for size, since
3058+
// materializing elements and inserting them tends to cause code bloat.
3059+
if (!DAG.shouldOptForSize() && NumScalarLoads < NumElts &&
3060+
(NumElts != 2 || ISD::isBuildVectorOfConstantSDNodes(Op.getNode())) &&
3061+
((MostCommonCount > DominantValueCountThreshold) ||
3062+
(ValueCounts.size() <= Log2_32(NumDefElts)))) {
3063+
// Start by splatting the most common element.
3064+
SDValue Vec = DAG.getSplatBuildVector(VT, DL, DominantValue);
3065+
3066+
DenseSet<SDValue> Processed{DominantValue};
3067+
MVT SelMaskTy = VT.changeVectorElementType(MVT::i1);
3068+
for (const auto &OpIdx : enumerate(Op->ops())) {
3069+
const SDValue &V = OpIdx.value();
3070+
if (V.isUndef() || !Processed.insert(V).second)
3071+
continue;
3072+
if (ValueCounts[V] == 1) {
3073+
Vec = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VT, Vec, V,
3074+
DAG.getConstant(OpIdx.index(), DL, XLenVT));
3075+
} else {
3076+
// Blend in all instances of this value using a VSELECT, using a
3077+
// mask where each bit signals whether that element is the one
3078+
// we're after.
3079+
SmallVector<SDValue> Ops;
3080+
transform(Op->op_values(), std::back_inserter(Ops), [&](SDValue V1) {
3081+
return DAG.getConstant(V == V1, DL, XLenVT);
3082+
});
3083+
Vec = DAG.getNode(ISD::VSELECT, DL, VT,
3084+
DAG.getBuildVector(SelMaskTy, DL, Ops),
3085+
DAG.getSplatBuildVector(VT, DL, V), Vec);
3086+
}
3087+
}
3088+
3089+
return Vec;
3090+
}
3091+
3092+
return SDValue();
3093+
}
3094+
3095+
static SDValue lowerBuildVectorOfConstants(SDValue Op, SelectionDAG &DAG,
3096+
const RISCVSubtarget &Subtarget) {
30043097
MVT VT = Op.getSimpleValueType();
30053098
assert(VT.isFixedLengthVector() && "Unexpected vector!");
30063099

@@ -3033,91 +3126,63 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG,
30333126
// codegen across RV32 and RV64.
30343127
unsigned NumViaIntegerBits = std::clamp(NumElts, 8u, Subtarget.getXLen());
30353128
NumViaIntegerBits = std::min(NumViaIntegerBits, Subtarget.getELEN());
3036-
if (ISD::isBuildVectorOfConstantSDNodes(Op.getNode())) {
3037-
// If we have to use more than one INSERT_VECTOR_ELT then this
3038-
// optimization is likely to increase code size; avoid peforming it in
3039-
// such a case. We can use a load from a constant pool in this case.
3040-
if (DAG.shouldOptForSize() && NumElts > NumViaIntegerBits)
3041-
return SDValue();
3042-
// Now we can create our integer vector type. Note that it may be larger
3043-
// than the resulting mask type: v4i1 would use v1i8 as its integer type.
3044-
unsigned IntegerViaVecElts = divideCeil(NumElts, NumViaIntegerBits);
3045-
MVT IntegerViaVecVT =
3046-
MVT::getVectorVT(MVT::getIntegerVT(NumViaIntegerBits),
3047-
IntegerViaVecElts);
3048-
3049-
uint64_t Bits = 0;
3050-
unsigned BitPos = 0, IntegerEltIdx = 0;
3051-
SmallVector<SDValue, 8> Elts(IntegerViaVecElts);
3052-
3053-
for (unsigned I = 0; I < NumElts;) {
3054-
SDValue V = Op.getOperand(I);
3055-
bool BitValue = !V.isUndef() && cast<ConstantSDNode>(V)->getZExtValue();
3056-
Bits |= ((uint64_t)BitValue << BitPos);
3057-
++BitPos;
3058-
++I;
3059-
3060-
// Once we accumulate enough bits to fill our scalar type or process the
3061-
// last element, insert into our vector and clear our accumulated data.
3062-
if (I % NumViaIntegerBits == 0 || I == NumElts) {
3063-
if (NumViaIntegerBits <= 32)
3064-
Bits = SignExtend64<32>(Bits);
3065-
SDValue Elt = DAG.getConstant(Bits, DL, XLenVT);
3066-
Elts[IntegerEltIdx] = Elt;
3067-
Bits = 0;
3068-
BitPos = 0;
3069-
IntegerEltIdx++;
3070-
}
3071-
}
3072-
3073-
SDValue Vec = DAG.getBuildVector(IntegerViaVecVT, DL, Elts);
3074-
3075-
if (NumElts < NumViaIntegerBits) {
3076-
// If we're producing a smaller vector than our minimum legal integer
3077-
// type, bitcast to the equivalent (known-legal) mask type, and extract
3078-
// our final mask.
3079-
assert(IntegerViaVecVT == MVT::v1i8 && "Unexpected mask vector type");
3080-
Vec = DAG.getBitcast(MVT::v8i1, Vec);
3081-
Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Vec,
3082-
DAG.getConstant(0, DL, XLenVT));
3083-
} else {
3084-
// Else we must have produced an integer type with the same size as the
3085-
// mask type; bitcast for the final result.
3086-
assert(VT.getSizeInBits() == IntegerViaVecVT.getSizeInBits());
3087-
Vec = DAG.getBitcast(VT, Vec);
3129+
// If we have to use more than one INSERT_VECTOR_ELT then this
3130+
// optimization is likely to increase code size; avoid performing it in
3131+
// such a case. We can use a load from a constant pool in this case.
3132+
if (DAG.shouldOptForSize() && NumElts > NumViaIntegerBits)
3133+
return SDValue();
3134+
// Now we can create our integer vector type. Note that it may be larger
3135+
// than the resulting mask type: v4i1 would use v1i8 as its integer type.
3136+
unsigned IntegerViaVecElts = divideCeil(NumElts, NumViaIntegerBits);
3137+
MVT IntegerViaVecVT =
3138+
MVT::getVectorVT(MVT::getIntegerVT(NumViaIntegerBits),
3139+
IntegerViaVecElts);
3140+
3141+
uint64_t Bits = 0;
3142+
unsigned BitPos = 0, IntegerEltIdx = 0;
3143+
SmallVector<SDValue, 8> Elts(IntegerViaVecElts);
3144+
3145+
for (unsigned I = 0; I < NumElts;) {
3146+
SDValue V = Op.getOperand(I);
3147+
bool BitValue = !V.isUndef() && cast<ConstantSDNode>(V)->getZExtValue();
3148+
Bits |= ((uint64_t)BitValue << BitPos);
3149+
++BitPos;
3150+
++I;
3151+
3152+
// Once we accumulate enough bits to fill our scalar type or process the
3153+
// last element, insert into our vector and clear our accumulated data.
3154+
if (I % NumViaIntegerBits == 0 || I == NumElts) {
3155+
if (NumViaIntegerBits <= 32)
3156+
Bits = SignExtend64<32>(Bits);
3157+
SDValue Elt = DAG.getConstant(Bits, DL, XLenVT);
3158+
Elts[IntegerEltIdx] = Elt;
3159+
Bits = 0;
3160+
BitPos = 0;
3161+
IntegerEltIdx++;
30883162
}
3089-
3090-
return Vec;
30913163
}
30923164

3093-
// A BUILD_VECTOR can be lowered as a SETCC. For each fixed-length mask
3094-
// vector type, we have a legal equivalently-sized i8 type, so we can use
3095-
// that.
3096-
MVT WideVecVT = VT.changeVectorElementType(MVT::i8);
3097-
SDValue VecZero = DAG.getConstant(0, DL, WideVecVT);
3165+
SDValue Vec = DAG.getBuildVector(IntegerViaVecVT, DL, Elts);
30983166

3099-
SDValue WideVec;
3100-
if (SDValue Splat = cast<BuildVectorSDNode>(Op)->getSplatValue()) {
3101-
// For a splat, perform a scalar truncate before creating the wider
3102-
// vector.
3103-
assert(Splat.getValueType() == XLenVT &&
3104-
"Unexpected type for i1 splat value");
3105-
Splat = DAG.getNode(ISD::AND, DL, XLenVT, Splat,
3106-
DAG.getConstant(1, DL, XLenVT));
3107-
WideVec = DAG.getSplatBuildVector(WideVecVT, DL, Splat);
3167+
if (NumElts < NumViaIntegerBits) {
3168+
// If we're producing a smaller vector than our minimum legal integer
3169+
// type, bitcast to the equivalent (known-legal) mask type, and extract
3170+
// our final mask.
3171+
assert(IntegerViaVecVT == MVT::v1i8 && "Unexpected mask vector type");
3172+
Vec = DAG.getBitcast(MVT::v8i1, Vec);
3173+
Vec = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Vec,
3174+
DAG.getConstant(0, DL, XLenVT));
31083175
} else {
3109-
SmallVector<SDValue, 8> Ops(Op->op_values());
3110-
WideVec = DAG.getBuildVector(WideVecVT, DL, Ops);
3111-
SDValue VecOne = DAG.getConstant(1, DL, WideVecVT);
3112-
WideVec = DAG.getNode(ISD::AND, DL, WideVecVT, WideVec, VecOne);
3176+
// Else we must have produced an integer type with the same size as the
3177+
// mask type; bitcast for the final result.
3178+
assert(VT.getSizeInBits() == IntegerViaVecVT.getSizeInBits());
3179+
Vec = DAG.getBitcast(VT, Vec);
31133180
}
31143181

3115-
return DAG.getSetCC(DL, VT, WideVec, VecZero, ISD::SETNE);
3182+
return Vec;
31163183
}
31173184

31183185
if (SDValue Splat = cast<BuildVectorSDNode>(Op)->getSplatValue()) {
3119-
if (auto Gather = matchSplatAsGather(Splat, VT, DL, DAG, Subtarget))
3120-
return Gather;
31213186
unsigned Opc = VT.isFloatingPoint() ? RISCVISD::VFMV_V_F_VL
31223187
: RISCVISD::VMV_V_X_VL;
31233188
Splat =
@@ -3246,91 +3311,69 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG,
32463311
}
32473312
}
32483313

3249-
// Try and optimize BUILD_VECTORs with "dominant values" - these are values
3250-
// which constitute a large proportion of the elements. In such cases we can
3251-
// splat a vector with the dominant element and make up the shortfall with
3252-
// INSERT_VECTOR_ELTs.
3253-
// Note that this includes vectors of 2 elements by association. The
3254-
// upper-most element is the "dominant" one, allowing us to use a splat to
3255-
// "insert" the upper element, and an insert of the lower element at position
3256-
// 0, which improves codegen.
3257-
SDValue DominantValue;
3258-
unsigned MostCommonCount = 0;
3259-
DenseMap<SDValue, unsigned> ValueCounts;
3260-
unsigned NumUndefElts =
3261-
count_if(Op->op_values(), [](const SDValue &V) { return V.isUndef(); });
3314+
if (SDValue Res = lowerBuildVectorViaDominantValues(Op, DAG, Subtarget))
3315+
return Res;
32623316

3263-
// Track the number of scalar loads we know we'd be inserting, estimated as
3264-
// any non-zero floating-point constant. Other kinds of element are either
3265-
// already in registers or are materialized on demand. The threshold at which
3266-
// a vector load is more desirable than several scalar materializion and
3267-
// vector-insertion instructions is not known.
3268-
unsigned NumScalarLoads = 0;
3317+
// For constant vectors, use generic constant pool lowering. Otherwise,
3318+
// we'd have to materialize constants in GPRs just to move them into the
3319+
// vector.
3320+
return SDValue();
3321+
}
32693322

3270-
for (SDValue V : Op->op_values()) {
3271-
if (V.isUndef())
3272-
continue;
3323+
static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG,
3324+
const RISCVSubtarget &Subtarget) {
3325+
MVT VT = Op.getSimpleValueType();
3326+
assert(VT.isFixedLengthVector() && "Unexpected vector!");
32733327

3274-
ValueCounts.insert(std::make_pair(V, 0));
3275-
unsigned &Count = ValueCounts[V];
3276-
if (0 == Count)
3277-
if (auto *CFP = dyn_cast<ConstantFPSDNode>(V))
3278-
NumScalarLoads += !CFP->isExactlyValue(+0.0);
3328+
if (ISD::isBuildVectorOfConstantSDNodes(Op.getNode()) ||
3329+
ISD::isBuildVectorOfConstantFPSDNodes(Op.getNode()))
3330+
return lowerBuildVectorOfConstants(Op, DAG, Subtarget);
32793331

3280-
// Is this value dominant? In case of a tie, prefer the highest element as
3281-
// it's cheaper to insert near the beginning of a vector than it is at the
3282-
// end.
3283-
if (++Count >= MostCommonCount) {
3284-
DominantValue = V;
3285-
MostCommonCount = Count;
3286-
}
3287-
}
3332+
MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
32883333

3289-
assert(DominantValue && "Not expecting an all-undef BUILD_VECTOR");
3290-
unsigned NumDefElts = NumElts - NumUndefElts;
3291-
unsigned DominantValueCountThreshold = NumDefElts <= 2 ? 0 : NumDefElts - 2;
3334+
SDLoc DL(Op);
3335+
auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
32923336

3293-
// Don't perform this optimization when optimizing for size, since
3294-
// materializing elements and inserting them tends to cause code bloat.
3295-
if (!DAG.shouldOptForSize() && NumScalarLoads < NumElts &&
3296-
(NumElts != 2 || ISD::isBuildVectorOfConstantSDNodes(Op.getNode())) &&
3297-
((MostCommonCount > DominantValueCountThreshold) ||
3298-
(ValueCounts.size() <= Log2_32(NumDefElts)))) {
3299-
// Start by splatting the most common element.
3300-
SDValue Vec = DAG.getSplatBuildVector(VT, DL, DominantValue);
3337+
MVT XLenVT = Subtarget.getXLenVT();
33013338

3302-
DenseSet<SDValue> Processed{DominantValue};
3303-
MVT SelMaskTy = VT.changeVectorElementType(MVT::i1);
3304-
for (const auto &OpIdx : enumerate(Op->ops())) {
3305-
const SDValue &V = OpIdx.value();
3306-
if (V.isUndef() || !Processed.insert(V).second)
3307-
continue;
3308-
if (ValueCounts[V] == 1) {
3309-
Vec = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VT, Vec, V,
3310-
DAG.getConstant(OpIdx.index(), DL, XLenVT));
3311-
} else {
3312-
// Blend in all instances of this value using a VSELECT, using a
3313-
// mask where each bit signals whether that element is the one
3314-
// we're after.
3315-
SmallVector<SDValue> Ops;
3316-
transform(Op->op_values(), std::back_inserter(Ops), [&](SDValue V1) {
3317-
return DAG.getConstant(V == V1, DL, XLenVT);
3318-
});
3319-
Vec = DAG.getNode(ISD::VSELECT, DL, VT,
3320-
DAG.getBuildVector(SelMaskTy, DL, Ops),
3321-
DAG.getSplatBuildVector(VT, DL, V), Vec);
3322-
}
3339+
if (VT.getVectorElementType() == MVT::i1) {
3340+
// A BUILD_VECTOR can be lowered as a SETCC. For each fixed-length mask
3341+
// vector type, we have a legal equivalently-sized i8 type, so we can use
3342+
// that.
3343+
MVT WideVecVT = VT.changeVectorElementType(MVT::i8);
3344+
SDValue VecZero = DAG.getConstant(0, DL, WideVecVT);
3345+
3346+
SDValue WideVec;
3347+
if (SDValue Splat = cast<BuildVectorSDNode>(Op)->getSplatValue()) {
3348+
// For a splat, perform a scalar truncate before creating the wider
3349+
// vector.
3350+
assert(Splat.getValueType() == XLenVT &&
3351+
"Unexpected type for i1 splat value");
3352+
Splat = DAG.getNode(ISD::AND, DL, XLenVT, Splat,
3353+
DAG.getConstant(1, DL, XLenVT));
3354+
WideVec = DAG.getSplatBuildVector(WideVecVT, DL, Splat);
3355+
} else {
3356+
SmallVector<SDValue, 8> Ops(Op->op_values());
3357+
WideVec = DAG.getBuildVector(WideVecVT, DL, Ops);
3358+
SDValue VecOne = DAG.getConstant(1, DL, WideVecVT);
3359+
WideVec = DAG.getNode(ISD::AND, DL, WideVecVT, WideVec, VecOne);
33233360
}
33243361

3325-
return Vec;
3362+
return DAG.getSetCC(DL, VT, WideVec, VecZero, ISD::SETNE);
33263363
}
33273364

3328-
// For constant vectors, use generic constant pool lowering. Otherwise,
3329-
// we'd have to materialize constants in GPRs just to move them into the
3330-
// vector.
3331-
if (ISD::isBuildVectorOfConstantSDNodes(Op.getNode()) ||
3332-
ISD::isBuildVectorOfConstantFPSDNodes(Op.getNode()))
3333-
return SDValue();
3365+
if (SDValue Splat = cast<BuildVectorSDNode>(Op)->getSplatValue()) {
3366+
if (auto Gather = matchSplatAsGather(Splat, VT, DL, DAG, Subtarget))
3367+
return Gather;
3368+
unsigned Opc = VT.isFloatingPoint() ? RISCVISD::VFMV_V_F_VL
3369+
: RISCVISD::VMV_V_X_VL;
3370+
Splat =
3371+
DAG.getNode(Opc, DL, ContainerVT, DAG.getUNDEF(ContainerVT), Splat, VL);
3372+
return convertFromScalableVector(VT, Splat, DAG, Subtarget);
3373+
}
3374+
3375+
if (SDValue Res = lowerBuildVectorViaDominantValues(Op, DAG, Subtarget))
3376+
return Res;
33343377

33353378
assert((!VT.isFloatingPoint() ||
33363379
VT.getVectorElementType().getSizeInBits() <= Subtarget.getFLen()) &&

0 commit comments

Comments
 (0)