Skip to content

Commit 3045357

Browse files
[RISCV][GISEL] Legalize G_SPLAT_VECTOR
1 parent 14693ad commit 3045357

File tree

9 files changed

+1854
-45
lines changed

9 files changed

+1854
-45
lines changed

llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3006,6 +3006,15 @@ LegalizerHelper::widenScalar(MachineInstr &MI, unsigned TypeIdx, LLT WideTy) {
30063006
Observer.changedInstr(MI);
30073007
return Legalized;
30083008
}
3009+
case TargetOpcode::G_SPLAT_VECTOR: {
3010+
if (TypeIdx != 1)
3011+
return UnableToLegalize;
3012+
3013+
Observer.changingInstr(MI);
3014+
widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ANYEXT);
3015+
Observer.changedInstr(MI);
3016+
return Legalized;
3017+
}
30093018
}
30103019
}
30113020

llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1278,7 +1278,7 @@ MachineIRBuilder::buildInstr(unsigned Opc, ArrayRef<DstOp> DstOps,
12781278
return DstTy.isScalar();
12791279
else
12801280
return DstTy.isVector() &&
1281-
DstTy.getNumElements() == Op0Ty.getNumElements();
1281+
DstTy.getElementCount() == Op0Ty.getElementCount();
12821282
}() && "Type Mismatch");
12831283
break;
12841284
}

llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.cpp

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,18 @@ typeIsLegalBoolVec(unsigned TypeIdx, std::initializer_list<LLT> BoolVecTys,
6666
return all(typeInSet(TypeIdx, BoolVecTys), P);
6767
}
6868

69+
static LegalityPredicate hasStdExtD(const RISCVSubtarget &ST) {
70+
return [=, &ST](const LegalityQuery &Query) { return ST.hasStdExtD(); };
71+
}
72+
static LegalityPredicate hasVInstructionsI64(const RISCVSubtarget &ST) {
73+
return
74+
[=, &ST](const LegalityQuery &Query) { return ST.hasVInstructionsI64(); };
75+
}
76+
static LegalityPredicate hasVInstructionsF64(const RISCVSubtarget &ST) {
77+
return
78+
[=, &ST](const LegalityQuery &Query) { return ST.hasVInstructionsI64(); };
79+
}
80+
6981
RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST)
7082
: STI(ST), XLen(STI.getXLen()), sXLen(LLT::scalar(XLen)) {
7183
const LLT sDoubleXLen = LLT::scalar(2 * XLen);
@@ -413,6 +425,27 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST)
413425
.clampScalar(0, sXLen, sXLen)
414426
.customFor({sXLen});
415427

428+
auto &SplatActions =
429+
getActionDefinitionsBuilder(G_SPLAT_VECTOR)
430+
.legalIf(all(typeIsLegalIntOrFPVec(0, IntOrFPVecTys, ST),
431+
typeIs(1, sXLen)))
432+
.customIf(all(typeIsLegalBoolVec(0, BoolVecTys, ST), typeIs(1, s1)));
433+
// Handle case of s64 element vectors on RV32. If the subtarget does not have
434+
// f64, then try to lower it to G_SPLAT_VECTOR_SPLIT_64_VL. If the subtarget
435+
// does have f64, then we don't know whether the type is an f64 or an i64,
436+
// so mark the G_SPLAT_VECTOR as legal and decide later what to do with it,
437+
// depending on how the instructions it consumes are legalized. They are not
438+
// legalized yet since legalization is in reverse postorder, so we cannot
439+
// make the decision at this moment.
440+
if (XLen == 32)
441+
SplatActions
442+
.legalIf(all(typeInSet(0, {nxv1s64, nxv2s64, nxv4s64, nxv8s64}),
443+
typeIs(1, s64), hasVInstructionsF64(ST), hasStdExtD(ST)))
444+
.customIf(all(typeInSet(0, {nxv1s64, nxv2s64, nxv4s64, nxv8s64}),
445+
hasVInstructionsI64(ST), typeIs(1, s64)));
446+
447+
SplatActions.clampScalar(1, sXLen, sXLen);
448+
416449
getLegacyLegalizerInfo().computeTables();
417450
}
418451

@@ -603,6 +636,118 @@ bool RISCVLegalizerInfo::legalizeExt(MachineInstr &MI,
603636
return true;
604637
}
605638

639+
/// Return the type of the mask type suitable for masking the provided
640+
/// vector type. This is simply an i1 element type vector of the same
641+
/// (possibly scalable) length.
642+
static LLT getMaskTypeFor(LLT VecTy) {
643+
assert(VecTy.isVector());
644+
ElementCount EC = VecTy.getElementCount();
645+
return LLT::vector(EC, LLT::scalar(1));
646+
}
647+
648+
/// Creates an all ones mask suitable for masking a vector of type VecTy with
649+
/// vector length VL.
650+
static MachineInstrBuilder buildAllOnesMask(LLT VecTy, const SrcOp &VL,
651+
MachineIRBuilder &MIB,
652+
MachineRegisterInfo &MRI) {
653+
LLT MaskTy = getMaskTypeFor(VecTy);
654+
return MIB.buildInstr(RISCV::G_VMSET_VL, {MaskTy}, {VL});
655+
}
656+
657+
/// Gets the two common "VL" operands: an all-ones mask and the vector length.
658+
/// VecTy is a scalable vector type.
659+
static std::pair<MachineInstrBuilder, Register>
660+
buildDefaultVLOps(const DstOp &Dst, MachineIRBuilder &MIB,
661+
MachineRegisterInfo &MRI) {
662+
LLT VecTy = Dst.getLLTTy(MRI);
663+
assert(VecTy.isScalableVector() && "Expecting scalable container type");
664+
Register VL(RISCV::X0);
665+
MachineInstrBuilder Mask = buildAllOnesMask(VecTy, VL, MIB, MRI);
666+
return {Mask, VL};
667+
}
668+
669+
static MachineInstrBuilder
670+
buildSplatPartsS64WithVL(const DstOp &Dst, const SrcOp &Passthru, Register Lo,
671+
Register Hi, Register VL, MachineIRBuilder &MIB,
672+
MachineRegisterInfo &MRI) {
673+
// TODO: If the Hi bits of the splat are undefined, then it's fine to just
674+
// splat Lo even if it might be sign extended. I don't think we have
675+
// introduced a case where we're build a s64 where the upper bits are undef
676+
// yet.
677+
678+
// Fall back to a stack store and stride x0 vector load.
679+
// TODO: need to lower G_SPLAT_VECTOR_SPLIT_I64. This is done in
680+
// preprocessDAG in SDAG.
681+
return MIB.buildInstr(RISCV::G_SPLAT_VECTOR_SPLIT_I64_VL, {Dst},
682+
{Passthru, Lo, Hi, VL});
683+
}
684+
685+
static MachineInstrBuilder
686+
buildSplatSplitS64WithVL(const DstOp &Dst, const SrcOp &Passthru,
687+
const SrcOp &Scalar, Register VL,
688+
MachineIRBuilder &MIB, MachineRegisterInfo &MRI) {
689+
assert(Scalar.getLLTTy(MRI) == LLT::scalar(64) && "Unexpected VecTy!");
690+
auto Unmerge = MIB.buildUnmerge(LLT::scalar(32), Scalar);
691+
return buildSplatPartsS64WithVL(Dst, Passthru, Unmerge.getReg(0),
692+
Unmerge.getReg(1), VL, MIB, MRI);
693+
}
694+
695+
// Lower splats of s1 types to G_ICMP. For each mask vector type, we have a
696+
// legal equivalently-sized i8 type, so we can use that as a go-between.
697+
// Splats of s1 types that have constant value can be legalized as VMSET_VL or
698+
// VMCLR_VL.
699+
bool RISCVLegalizerInfo::legalizeSplatVector(MachineInstr &MI,
700+
MachineIRBuilder &MIB) const {
701+
assert(MI.getOpcode() == TargetOpcode::G_SPLAT_VECTOR);
702+
703+
MachineRegisterInfo &MRI = *MIB.getMRI();
704+
705+
Register Dst = MI.getOperand(0).getReg();
706+
Register SplatVal = MI.getOperand(1).getReg();
707+
708+
LLT VecTy = MRI.getType(Dst);
709+
LLT XLenTy(STI.getXLenVT());
710+
711+
// Handle case of s64 element vectors on rv32
712+
if (XLenTy.getSizeInBits() == 32 &&
713+
VecTy.getElementType().getSizeInBits() == 64) {
714+
auto [_, VL] = buildDefaultVLOps(Dst, MIB, MRI);
715+
buildSplatSplitS64WithVL(Dst, MIB.buildUndef(VecTy), SplatVal, VL, MIB,
716+
MRI);
717+
MI.eraseFromParent();
718+
return true;
719+
}
720+
721+
// All-zeros or all-ones splats are handled specially.
722+
MachineInstr &SplatValMI = *MRI.getVRegDef(SplatVal);
723+
if (isAllOnesOrAllOnesSplat(SplatValMI, MRI)) {
724+
auto VL = buildDefaultVLOps(VecTy, MIB, MRI).second;
725+
MIB.buildInstr(RISCV::G_VMSET_VL, {Dst}, {VL});
726+
MI.eraseFromParent();
727+
return true;
728+
}
729+
if (isNullOrNullSplat(SplatValMI, MRI)) {
730+
auto VL = buildDefaultVLOps(VecTy, MIB, MRI).second;
731+
MIB.buildInstr(RISCV::G_VMCLR_VL, {Dst}, {VL});
732+
MI.eraseFromParent();
733+
return true;
734+
}
735+
736+
// Handle non-constant mask splat (i.e. not sure if it's all zeros or all
737+
// ones) by promoting it to an s8 splat.
738+
LLT InterEltTy = LLT::scalar(8);
739+
LLT InterTy = VecTy.changeElementType(InterEltTy);
740+
auto ZExtSplatVal = MIB.buildZExt(InterEltTy, SplatVal);
741+
auto And =
742+
MIB.buildAnd(InterEltTy, ZExtSplatVal, MIB.buildConstant(InterEltTy, 1));
743+
auto LHS = MIB.buildSplatVector(InterTy, And);
744+
auto ZeroSplat =
745+
MIB.buildSplatVector(InterTy, MIB.buildConstant(InterEltTy, 0));
746+
MIB.buildICmp(CmpInst::Predicate::ICMP_NE, Dst, LHS, ZeroSplat);
747+
MI.eraseFromParent();
748+
return true;
749+
}
750+
606751
bool RISCVLegalizerInfo::legalizeCustom(
607752
LegalizerHelper &Helper, MachineInstr &MI,
608753
LostDebugLocObserver &LocObserver) const {
@@ -666,6 +811,8 @@ bool RISCVLegalizerInfo::legalizeCustom(
666811
case TargetOpcode::G_SEXT:
667812
case TargetOpcode::G_ANYEXT:
668813
return legalizeExt(MI, MIRBuilder);
814+
case TargetOpcode::G_SPLAT_VECTOR:
815+
return legalizeSplatVector(MI, MIRBuilder);
669816
}
670817

671818
llvm_unreachable("expected switch to return");

llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ class RISCVLegalizerInfo : public LegalizerInfo {
4444
bool legalizeVAStart(MachineInstr &MI, MachineIRBuilder &MIRBuilder) const;
4545
bool legalizeVScale(MachineInstr &MI, MachineIRBuilder &MIB) const;
4646
bool legalizeExt(MachineInstr &MI, MachineIRBuilder &MIRBuilder) const;
47+
bool legalizeSplatVector(MachineInstr &MI, MachineIRBuilder &MIB) const;
4748
};
4849
} // end namespace llvm
4950
#endif

llvm/lib/Target/RISCV/RISCVInstrGISel.td

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,28 @@ def G_READ_VLENB : RISCVGenericInstruction {
3232
let hasSideEffects = false;
3333
}
3434
def : GINodeEquiv<G_READ_VLENB, riscv_read_vlenb>;
35+
36+
// Pseudo equivalent to a RISCVISD::VMCLR_VL
37+
def G_VMCLR_VL : RISCVGenericInstruction {
38+
let OutOperandList = (outs type0:$dst);
39+
let InOperandList = (ins type1:$vl);
40+
let hasSideEffects = false;
41+
}
42+
def : GINodeEquiv<G_VMCLR_VL, riscv_vmclr_vl>;
43+
44+
// Pseudo equivalent to a RISCVISD::VMSET_VL
45+
def G_VMSET_VL : RISCVGenericInstruction {
46+
let OutOperandList = (outs type0:$dst);
47+
let InOperandList = (ins type1:$vl);
48+
let hasSideEffects = false;
49+
}
50+
def : GINodeEquiv<G_VMSET_VL, riscv_vmset_vl>;
51+
52+
// Pseudo equivalent to a RISCVISD::SPLAT_VECTOR_SPLIT_I64_VL. There is no
53+
// record to mark is equivalent to using GINodeEquiv because it gets lowered
54+
// before instruction selection.
55+
def G_SPLAT_VECTOR_SPLIT_I64_VL : RISCVGenericInstruction {
56+
let OutOperandList = (outs type0:$dst);
57+
let InOperandList = (ins type0:$passthru, type1:$hi, type1:$lo, type2:$vl);
58+
let hasSideEffects = false;
59+
}

0 commit comments

Comments
 (0)