Skip to content

Commit 8aa3a77

Browse files
[RISCV][GISEL] Legalize G_ZEXT, G_SEXT, and G_ANYEXT, G_SPLAT_VECTOR, and G_ICMP for scalable vector types
This patch legalizes G_ZEXT, G_SEXT, and G_ANYEXT. If the type is a legal mask type, then the instruction is legalized as the element-wise select, where the condition on the select is the mask typed source operand, and the true and false values are 1 or -1 (for zero/any-extension and sign extension) and zero. If the type is a legal integer or vector integer type, then the instruction is marked as legal. The legalization of the extends may introduce a G_SPLAT_VECTOR, which needs to be legalized in this patch for the extend test cases to pass. A G_SPLAT_VECTOR is legal if the vector type is a legal integer or floating point vector type and the source operand is sXLen type. This is because the SelectionDAG patterns only support sXLen typed ISD::SPLAT_VECTORS, and we'd like to reuse those patterns. A G_SPLAT_VECTOR is cutom legalized if it has a legal s1 element vector type and s1 scalar operand. It is legalized to G_VMSET_VL or G_VMCLR_VL if the splat is all ones or all zeros respectivley. In the case of a non-constant mask splat, we legalize by promoting the scalar value to s8. In order to get the s8 element vector back into s1 vector, we use a G_ICMP. In order for the splat vector and extend tests to pass, we also need to legalize G_ICMP in this patch. A G_ICMP is legal if the destination type is a legal bool vector and the LHS and RHS are legal integer vector types.
1 parent 029e1d7 commit 8aa3a77

14 files changed

+7436
-16
lines changed

llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3006,6 +3006,15 @@ LegalizerHelper::widenScalar(MachineInstr &MI, unsigned TypeIdx, LLT WideTy) {
30063006
Observer.changedInstr(MI);
30073007
return Legalized;
30083008
}
3009+
case TargetOpcode::G_SPLAT_VECTOR: {
3010+
if (TypeIdx != 1)
3011+
return UnableToLegalize;
3012+
3013+
Observer.changingInstr(MI);
3014+
widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ANYEXT);
3015+
Observer.changedInstr(MI);
3016+
return Legalized;
3017+
}
30093018
}
30103019
}
30113020

llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1278,7 +1278,7 @@ MachineIRBuilder::buildInstr(unsigned Opc, ArrayRef<DstOp> DstOps,
12781278
return DstTy.isScalar();
12791279
else
12801280
return DstTy.isVector() &&
1281-
DstTy.getNumElements() == Op0Ty.getNumElements();
1281+
DstTy.getElementCount() == Op0Ty.getElementCount();
12821282
}() && "Type Mismatch");
12831283
break;
12841284
}

llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.cpp

Lines changed: 177 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -139,20 +139,21 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST)
139139
.clampScalar(0, s32, sXLen)
140140
.minScalarSameAs(1, 0);
141141

142+
auto &ExtActions =
143+
getActionDefinitionsBuilder({G_ZEXT, G_SEXT, G_ANYEXT})
144+
.legalIf(all(typeIsLegalIntOrFPVec(0, IntOrFPVecTys, ST),
145+
typeIsLegalIntOrFPVec(1, IntOrFPVecTys, ST)));
142146
if (ST.is64Bit()) {
143-
getActionDefinitionsBuilder({G_ZEXT, G_SEXT, G_ANYEXT})
144-
.legalFor({{sXLen, s32}})
145-
.maxScalar(0, sXLen);
146-
147+
ExtActions.legalFor({{sXLen, s32}});
147148
getActionDefinitionsBuilder(G_SEXT_INREG)
148149
.customFor({sXLen})
149150
.maxScalar(0, sXLen)
150151
.lower();
151152
} else {
152-
getActionDefinitionsBuilder({G_ZEXT, G_SEXT, G_ANYEXT}).maxScalar(0, sXLen);
153-
154153
getActionDefinitionsBuilder(G_SEXT_INREG).maxScalar(0, sXLen).lower();
155154
}
155+
ExtActions.customIf(typeIsLegalBoolVec(1, BoolVecTys, ST))
156+
.maxScalar(0, sXLen);
156157

157158
// Merge/Unmerge
158159
for (unsigned Op : {G_MERGE_VALUES, G_UNMERGE_VALUES}) {
@@ -235,7 +236,9 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST)
235236

236237
getActionDefinitionsBuilder(G_ICMP)
237238
.legalFor({{sXLen, sXLen}, {sXLen, p0}})
238-
.widenScalarToNextPow2(1)
239+
.legalIf(all(typeIsLegalBoolVec(0, BoolVecTys, ST),
240+
typeIsLegalIntOrFPVec(1, IntOrFPVecTys, ST)))
241+
.widenScalarOrEltToNextPow2OrMinSize(1, 8)
239242
.clampScalar(1, sXLen, sXLen)
240243
.clampScalar(0, sXLen, sXLen);
241244

@@ -418,6 +421,29 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST)
418421
.clampScalar(0, sXLen, sXLen)
419422
.customFor({sXLen});
420423

424+
auto &SplatActions =
425+
getActionDefinitionsBuilder(G_SPLAT_VECTOR)
426+
.legalIf(all(typeIsLegalIntOrFPVec(0, IntOrFPVecTys, ST),
427+
typeIs(1, sXLen)))
428+
.customIf(all(typeIsLegalBoolVec(0, BoolVecTys, ST), typeIs(1, s1)));
429+
// Handle case of s64 element vectors on RV32. If the subtarget does not have
430+
// f64, then try to lower it to G_SPLAT_VECTOR_SPLIT_64_VL. If the subtarget
431+
// does have f64, then we don't know whether the type is an f64 or an i64,
432+
// so mark the G_SPLAT_VECTOR as legal and decide later what to do with it,
433+
// depending on how the instructions it consumes are legalized. They are not
434+
// legalized yet since legalization is in reverse postorder, so we cannot
435+
// make the decision at this moment.
436+
if (XLen == 32) {
437+
if (ST.hasVInstructionsF64() && ST.hasStdExtD())
438+
SplatActions.legalIf(all(
439+
typeInSet(0, {nxv1s64, nxv2s64, nxv4s64, nxv8s64}), typeIs(1, s64)));
440+
else if (ST.hasVInstructionsI64())
441+
SplatActions.customIf(all(
442+
typeInSet(0, {nxv1s64, nxv2s64, nxv4s64, nxv8s64}), typeIs(1, s64)));
443+
}
444+
445+
SplatActions.clampScalar(1, sXLen, sXLen);
446+
421447
getLegacyLegalizerInfo().computeTables();
422448
}
423449

@@ -576,7 +602,145 @@ bool RISCVLegalizerInfo::legalizeVScale(MachineInstr &MI,
576602
auto VScale = MIB.buildLShr(XLenTy, VLENB, MIB.buildConstant(XLenTy, 3));
577603
MIB.buildMul(Dst, VScale, MIB.buildConstant(XLenTy, Val));
578604
}
605+
MI.eraseFromParent();
606+
return true;
607+
}
608+
609+
// Custom-lower extensions from mask vectors by using a vselect either with 1
610+
// for zero/any-extension or -1 for sign-extension:
611+
// (vXiN = (s|z)ext vXi1:vmask) -> (vXiN = vselect vmask, (-1 or 1), 0)
612+
// Note that any-extension is lowered identically to zero-extension.
613+
bool RISCVLegalizerInfo::legalizeExt(MachineInstr &MI,
614+
MachineIRBuilder &MIB) const {
615+
616+
unsigned Opc = MI.getOpcode();
617+
assert(Opc == TargetOpcode::G_ZEXT || Opc == TargetOpcode::G_SEXT ||
618+
Opc == TargetOpcode::G_ANYEXT);
619+
620+
MachineRegisterInfo &MRI = *MIB.getMRI();
621+
Register Dst = MI.getOperand(0).getReg();
622+
Register Src = MI.getOperand(1).getReg();
623+
624+
LLT DstTy = MRI.getType(Dst);
625+
int64_t ExtTrueVal = Opc == TargetOpcode::G_SEXT ? -1 : 1;
626+
LLT DstEltTy = DstTy.getElementType();
627+
auto SplatZero = MIB.buildSplatVector(DstTy, MIB.buildConstant(DstEltTy, 0));
628+
auto SplatTrue =
629+
MIB.buildSplatVector(DstTy, MIB.buildConstant(DstEltTy, ExtTrueVal));
630+
MIB.buildSelect(Dst, Src, SplatTrue, SplatZero);
631+
632+
MI.eraseFromParent();
633+
return true;
634+
}
635+
636+
/// Return the type of the mask type suitable for masking the provided
637+
/// vector type. This is simply an i1 element type vector of the same
638+
/// (possibly scalable) length.
639+
static LLT getMaskTypeFor(LLT VecTy) {
640+
assert(VecTy.isVector());
641+
ElementCount EC = VecTy.getElementCount();
642+
return LLT::vector(EC, LLT::scalar(1));
643+
}
644+
645+
/// Creates an all ones mask suitable for masking a vector of type VecTy with
646+
/// vector length VL.
647+
static MachineInstrBuilder buildAllOnesMask(LLT VecTy, const SrcOp &VL,
648+
MachineIRBuilder &MIB,
649+
MachineRegisterInfo &MRI) {
650+
LLT MaskTy = getMaskTypeFor(VecTy);
651+
return MIB.buildInstr(RISCV::G_VMSET_VL, {MaskTy}, {VL});
652+
}
653+
654+
/// Gets the two common "VL" operands: an all-ones mask and the vector length.
655+
/// VecTy is a scalable vector type.
656+
static std::pair<MachineInstrBuilder, Register>
657+
buildDefaultVLOps(const DstOp &Dst, MachineIRBuilder &MIB,
658+
MachineRegisterInfo &MRI) {
659+
LLT VecTy = Dst.getLLTTy(MRI);
660+
assert(VecTy.isScalableVector() && "Expecting scalable container type");
661+
Register VL(RISCV::X0);
662+
MachineInstrBuilder Mask = buildAllOnesMask(VecTy, VL, MIB, MRI);
663+
return {Mask, VL};
664+
}
665+
666+
static MachineInstrBuilder
667+
buildSplatPartsS64WithVL(const DstOp &Dst, const SrcOp &Passthru, Register Lo,
668+
Register Hi, Register VL, MachineIRBuilder &MIB,
669+
MachineRegisterInfo &MRI) {
670+
// TODO: If the Hi bits of the splat are undefined, then it's fine to just
671+
// splat Lo even if it might be sign extended. I don't think we have
672+
// introduced a case where we're build a s64 where the upper bits are undef
673+
// yet.
674+
675+
// Fall back to a stack store and stride x0 vector load.
676+
// TODO: need to lower G_SPLAT_VECTOR_SPLIT_I64. This is done in
677+
// preprocessDAG in SDAG.
678+
return MIB.buildInstr(RISCV::G_SPLAT_VECTOR_SPLIT_I64_VL, {Dst},
679+
{Passthru, Lo, Hi, VL});
680+
}
681+
682+
static MachineInstrBuilder
683+
buildSplatSplitS64WithVL(const DstOp &Dst, const SrcOp &Passthru,
684+
const SrcOp &Scalar, Register VL,
685+
MachineIRBuilder &MIB, MachineRegisterInfo &MRI) {
686+
assert(Scalar.getLLTTy(MRI) == LLT::scalar(64) && "Unexpected VecTy!");
687+
auto Unmerge = MIB.buildUnmerge(LLT::scalar(32), Scalar);
688+
return buildSplatPartsS64WithVL(Dst, Passthru, Unmerge.getReg(0),
689+
Unmerge.getReg(1), VL, MIB, MRI);
690+
}
691+
692+
// Lower splats of s1 types to G_ICMP. For each mask vector type, we have a
693+
// legal equivalently-sized i8 type, so we can use that as a go-between.
694+
// Splats of s1 types that have constant value can be legalized as VMSET_VL or
695+
// VMCLR_VL.
696+
bool RISCVLegalizerInfo::legalizeSplatVector(MachineInstr &MI,
697+
MachineIRBuilder &MIB) const {
698+
assert(MI.getOpcode() == TargetOpcode::G_SPLAT_VECTOR);
699+
700+
MachineRegisterInfo &MRI = *MIB.getMRI();
701+
702+
Register Dst = MI.getOperand(0).getReg();
703+
Register SplatVal = MI.getOperand(1).getReg();
704+
705+
LLT VecTy = MRI.getType(Dst);
706+
LLT XLenTy(STI.getXLenVT());
707+
708+
// Handle case of s64 element vectors on rv32
709+
if (XLenTy.getSizeInBits() == 32 &&
710+
VecTy.getElementType().getSizeInBits() == 64) {
711+
auto [_, VL] = buildDefaultVLOps(Dst, MIB, MRI);
712+
buildSplatSplitS64WithVL(Dst, MIB.buildUndef(VecTy), SplatVal, VL, MIB,
713+
MRI);
714+
MI.eraseFromParent();
715+
return true;
716+
}
717+
718+
// All-zeros or all-ones splats are handled specially.
719+
MachineInstr &SplatValMI = *MRI.getVRegDef(SplatVal);
720+
if (isAllOnesOrAllOnesSplat(SplatValMI, MRI)) {
721+
auto VL = buildDefaultVLOps(VecTy, MIB, MRI).second;
722+
MIB.buildInstr(RISCV::G_VMSET_VL, {Dst}, {VL});
723+
MI.eraseFromParent();
724+
return true;
725+
}
726+
if (isNullOrNullSplat(SplatValMI, MRI)) {
727+
auto VL = buildDefaultVLOps(VecTy, MIB, MRI).second;
728+
MIB.buildInstr(RISCV::G_VMCLR_VL, {Dst}, {VL});
729+
MI.eraseFromParent();
730+
return true;
731+
}
579732

733+
// Handle non-constant mask splat (i.e. not sure if it's all zeros or all
734+
// ones) by promoting it to an s8 splat.
735+
LLT InterEltTy = LLT::scalar(8);
736+
LLT InterTy = VecTy.changeElementType(InterEltTy);
737+
auto ZExtSplatVal = MIB.buildZExt(InterEltTy, SplatVal);
738+
auto And =
739+
MIB.buildAnd(InterEltTy, ZExtSplatVal, MIB.buildConstant(InterEltTy, 1));
740+
auto LHS = MIB.buildSplatVector(InterTy, And);
741+
auto ZeroSplat =
742+
MIB.buildSplatVector(InterTy, MIB.buildConstant(InterEltTy, 0));
743+
MIB.buildICmp(CmpInst::Predicate::ICMP_NE, Dst, LHS, ZeroSplat);
580744
MI.eraseFromParent();
581745
return true;
582746
}
@@ -640,6 +804,12 @@ bool RISCVLegalizerInfo::legalizeCustom(
640804
return legalizeVAStart(MI, MIRBuilder);
641805
case TargetOpcode::G_VSCALE:
642806
return legalizeVScale(MI, MIRBuilder);
807+
case TargetOpcode::G_ZEXT:
808+
case TargetOpcode::G_SEXT:
809+
case TargetOpcode::G_ANYEXT:
810+
return legalizeExt(MI, MIRBuilder);
811+
case TargetOpcode::G_SPLAT_VECTOR:
812+
return legalizeSplatVector(MI, MIRBuilder);
643813
}
644814

645815
llvm_unreachable("expected switch to return");

llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ class RISCVLegalizerInfo : public LegalizerInfo {
4343

4444
bool legalizeVAStart(MachineInstr &MI, MachineIRBuilder &MIRBuilder) const;
4545
bool legalizeVScale(MachineInstr &MI, MachineIRBuilder &MIB) const;
46+
bool legalizeExt(MachineInstr &MI, MachineIRBuilder &MIRBuilder) const;
47+
bool legalizeSplatVector(MachineInstr &MI, MachineIRBuilder &MIB) const;
4648
};
4749
} // end namespace llvm
4850
#endif

llvm/lib/Target/RISCV/RISCVInstrGISel.td

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,28 @@ def G_READ_VLENB : RISCVGenericInstruction {
3232
let hasSideEffects = false;
3333
}
3434
def : GINodeEquiv<G_READ_VLENB, riscv_read_vlenb>;
35+
36+
// Pseudo equivalent to a RISCVISD::VMCLR_VL
37+
def G_VMCLR_VL : RISCVGenericInstruction {
38+
let OutOperandList = (outs type0:$dst);
39+
let InOperandList = (ins type1:$vl);
40+
let hasSideEffects = false;
41+
}
42+
def : GINodeEquiv<G_VMCLR_VL, riscv_vmclr_vl>;
43+
44+
// Pseudo equivalent to a RISCVISD::VMSET_VL
45+
def G_VMSET_VL : RISCVGenericInstruction {
46+
let OutOperandList = (outs type0:$dst);
47+
let InOperandList = (ins type1:$vl);
48+
let hasSideEffects = false;
49+
}
50+
def : GINodeEquiv<G_VMSET_VL, riscv_vmset_vl>;
51+
52+
// Pseudo equivalent to a RISCVISD::SPLAT_VECTOR_SPLIT_I64_VL. There is no
53+
// record to mark as equivalent to using GINodeEquiv because it gets lowered
54+
// before instruction selection.
55+
def G_SPLAT_VECTOR_SPLIT_I64_VL : RISCVGenericInstruction {
56+
let OutOperandList = (outs type0:$dst);
57+
let InOperandList = (ins type0:$passthru, type1:$hi, type1:$lo, type2:$vl);
58+
let hasSideEffects = false;
59+
}

0 commit comments

Comments
 (0)