Skip to content

[RISCV][GISEL] Legalize G_ZEXT, G_SEXT, G_ANYEXT, G_SPLAT_VECTOR, and G_ICMP #85938

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3006,6 +3006,15 @@ LegalizerHelper::widenScalar(MachineInstr &MI, unsigned TypeIdx, LLT WideTy) {
Observer.changedInstr(MI);
return Legalized;
}
case TargetOpcode::G_SPLAT_VECTOR: {
if (TypeIdx != 1)
return UnableToLegalize;

Observer.changingInstr(MI);
widenScalarSrc(MI, WideTy, 1, TargetOpcode::G_ANYEXT);
Observer.changedInstr(MI);
return Legalized;
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1278,7 +1278,7 @@ MachineIRBuilder::buildInstr(unsigned Opc, ArrayRef<DstOp> DstOps,
return DstTy.isScalar();
else
return DstTy.isVector() &&
DstTy.getNumElements() == Op0Ty.getNumElements();
DstTy.getElementCount() == Op0Ty.getElementCount();
}() && "Type Mismatch");
break;
}
Expand Down
184 changes: 177 additions & 7 deletions llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,20 +139,21 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST)
.clampScalar(0, s32, sXLen)
.minScalarSameAs(1, 0);

auto &ExtActions =
getActionDefinitionsBuilder({G_ZEXT, G_SEXT, G_ANYEXT})
.legalIf(all(typeIsLegalIntOrFPVec(0, IntOrFPVecTys, ST),
typeIsLegalIntOrFPVec(1, IntOrFPVecTys, ST)));
if (ST.is64Bit()) {
getActionDefinitionsBuilder({G_ZEXT, G_SEXT, G_ANYEXT})
.legalFor({{sXLen, s32}})
.maxScalar(0, sXLen);

ExtActions.legalFor({{sXLen, s32}});
getActionDefinitionsBuilder(G_SEXT_INREG)
.customFor({sXLen})
.maxScalar(0, sXLen)
.lower();
} else {
getActionDefinitionsBuilder({G_ZEXT, G_SEXT, G_ANYEXT}).maxScalar(0, sXLen);

getActionDefinitionsBuilder(G_SEXT_INREG).maxScalar(0, sXLen).lower();
}
ExtActions.customIf(typeIsLegalBoolVec(1, BoolVecTys, ST))
.maxScalar(0, sXLen);

// Merge/Unmerge
for (unsigned Op : {G_MERGE_VALUES, G_UNMERGE_VALUES}) {
Expand Down Expand Up @@ -235,7 +236,9 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST)

getActionDefinitionsBuilder(G_ICMP)
.legalFor({{sXLen, sXLen}, {sXLen, p0}})
.widenScalarToNextPow2(1)
.legalIf(all(typeIsLegalBoolVec(0, BoolVecTys, ST),
typeIsLegalIntOrFPVec(1, IntOrFPVecTys, ST)))
.widenScalarOrEltToNextPow2OrMinSize(1, 8)
.clampScalar(1, sXLen, sXLen)
.clampScalar(0, sXLen, sXLen);

Expand Down Expand Up @@ -418,6 +421,29 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST)
.clampScalar(0, sXLen, sXLen)
.customFor({sXLen});

auto &SplatActions =
getActionDefinitionsBuilder(G_SPLAT_VECTOR)
.legalIf(all(typeIsLegalIntOrFPVec(0, IntOrFPVecTys, ST),
typeIs(1, sXLen)))
.customIf(all(typeIsLegalBoolVec(0, BoolVecTys, ST), typeIs(1, s1)));
// Handle case of s64 element vectors on RV32. If the subtarget does not have
// f64, then try to lower it to G_SPLAT_VECTOR_SPLIT_64_VL. If the subtarget
// does have f64, then we don't know whether the type is an f64 or an i64,
// so mark the G_SPLAT_VECTOR as legal and decide later what to do with it,
// depending on how the instructions it consumes are legalized. They are not
// legalized yet since legalization is in reverse postorder, so we cannot
// make the decision at this moment.
if (XLen == 32) {
if (ST.hasVInstructionsF64() && ST.hasStdExtD())
SplatActions.legalIf(all(
typeInSet(0, {nxv1s64, nxv2s64, nxv4s64, nxv8s64}), typeIs(1, s64)));
else if (ST.hasVInstructionsI64())
SplatActions.customIf(all(
typeInSet(0, {nxv1s64, nxv2s64, nxv4s64, nxv8s64}), typeIs(1, s64)));
}

SplatActions.clampScalar(1, sXLen, sXLen);

getLegacyLegalizerInfo().computeTables();
}

Expand Down Expand Up @@ -576,7 +602,145 @@ bool RISCVLegalizerInfo::legalizeVScale(MachineInstr &MI,
auto VScale = MIB.buildLShr(XLenTy, VLENB, MIB.buildConstant(XLenTy, 3));
MIB.buildMul(Dst, VScale, MIB.buildConstant(XLenTy, Val));
}
MI.eraseFromParent();
return true;
}

// Custom-lower extensions from mask vectors as a vector select between two
// splat constants:
//   (vXiN = (s|z)ext vXi1:vmask) -> (vXiN = vselect vmask, (-1 or 1), 0)
// Sign-extension selects -1 for true lanes; zero- and any-extension are
// lowered identically and select 1.
bool RISCVLegalizerInfo::legalizeExt(MachineInstr &MI,
                                     MachineIRBuilder &MIB) const {
  unsigned ExtOpc = MI.getOpcode();
  assert(ExtOpc == TargetOpcode::G_ZEXT || ExtOpc == TargetOpcode::G_SEXT ||
         ExtOpc == TargetOpcode::G_ANYEXT);

  MachineRegisterInfo &MRI = *MIB.getMRI();
  Register Dst = MI.getOperand(0).getReg();
  Register Mask = MI.getOperand(1).getReg();

  LLT DstTy = MRI.getType(Dst);
  LLT EltTy = DstTy.getElementType();
  // All-ones for sign-extension, 1 for zero/any-extension.
  int64_t TrueVal = (ExtOpc == TargetOpcode::G_SEXT) ? -1 : 1;

  auto ZeroSplat = MIB.buildSplatVector(DstTy, MIB.buildConstant(EltTy, 0));
  auto TrueSplat =
      MIB.buildSplatVector(DstTy, MIB.buildConstant(EltTy, TrueVal));
  MIB.buildSelect(Dst, Mask, TrueSplat, ZeroSplat);

  MI.eraseFromParent();
  return true;
}

/// Return the mask type suitable for masking the provided vector type:
/// an i1-element vector with the same (possibly scalable) element count
/// as \p VecTy.
static LLT getMaskTypeFor(LLT VecTy) {
  assert(VecTy.isVector());
  return LLT::vector(VecTy.getElementCount(), LLT::scalar(1));
}

/// Creates an all ones mask suitable for masking a vector of type VecTy with
/// vector length VL, by emitting a G_VMSET_VL of the matching mask type.
/// The MRI parameter is currently unused; its name is commented out to avoid
/// an -Wunused-parameter warning while keeping the call sites unchanged.
static MachineInstrBuilder buildAllOnesMask(LLT VecTy, const SrcOp &VL,
                                            MachineIRBuilder &MIB,
                                            MachineRegisterInfo & /*MRI*/) {
  LLT MaskTy = getMaskTypeFor(VecTy);
  return MIB.buildInstr(RISCV::G_VMSET_VL, {MaskTy}, {VL});
}

/// Gets the two common "VL" operands: an all-ones mask and the vector length.
/// VecTy is a scalable vector type.
static std::pair<MachineInstrBuilder, Register>
buildDefaultVLOps(const DstOp &Dst, MachineIRBuilder &MIB,
MachineRegisterInfo &MRI) {
LLT VecTy = Dst.getLLTTy(MRI);
assert(VecTy.isScalableVector() && "Expecting scalable container type");
Register VL(RISCV::X0);
MachineInstrBuilder Mask = buildAllOnesMask(VecTy, VL, MIB, MRI);
return {Mask, VL};
}

/// Emits a G_SPLAT_VECTOR_SPLIT_I64_VL splatting into \p Dst the 64-bit
/// value formed from the two 32-bit halves \p Lo and \p Hi.
static MachineInstrBuilder
buildSplatPartsS64WithVL(const DstOp &Dst, const SrcOp &Passthru, Register Lo,
                         Register Hi, Register VL, MachineIRBuilder &MIB,
                         MachineRegisterInfo &MRI) {
  // TODO: If the Hi bits of the splat are undefined, splatting just Lo would
  // be fine even if it might be sign extended; no s64 with undef upper bits
  // is built yet, though.

  // TODO: G_SPLAT_VECTOR_SPLIT_I64_VL itself still needs to be lowered (a
  // stack store plus a stride-x0 vector load; SDAG does this in
  // preprocessDAG).
  return MIB.buildInstr(RISCV::G_SPLAT_VECTOR_SPLIT_I64_VL, {Dst},
                        {Passthru, Lo, Hi, VL});
}

/// Splits the s64 \p Scalar into two s32 halves and splats the resulting
/// pair into \p Dst via buildSplatPartsS64WithVL.
static MachineInstrBuilder
buildSplatSplitS64WithVL(const DstOp &Dst, const SrcOp &Passthru,
                         const SrcOp &Scalar, Register VL,
                         MachineIRBuilder &MIB, MachineRegisterInfo &MRI) {
  assert(Scalar.getLLTTy(MRI) == LLT::scalar(64) && "Unexpected VecTy!");
  auto Halves = MIB.buildUnmerge(LLT::scalar(32), Scalar);
  return buildSplatPartsS64WithVL(Dst, Passthru, Halves.getReg(0),
                                  Halves.getReg(1), VL, MIB, MRI);
}

// Lower splats of s1 types to G_ICMP. For each mask vector type there is a
// legal, equivalently-sized i8 vector type to use as a go-between. Splats of
// s1 types with a constant value are legalized directly to G_VMSET_VL or
// G_VMCLR_VL.
bool RISCVLegalizerInfo::legalizeSplatVector(MachineInstr &MI,
                                             MachineIRBuilder &MIB) const {
  assert(MI.getOpcode() == TargetOpcode::G_SPLAT_VECTOR);

  MachineRegisterInfo &MRI = *MIB.getMRI();

  Register Dst = MI.getOperand(0).getReg();
  Register SplatVal = MI.getOperand(1).getReg();

  LLT VecTy = MRI.getType(Dst);
  LLT XLenTy(STI.getXLenVT());

  // s64-element vectors on RV32: split the scalar into two s32 halves and
  // splat those instead.
  if (XLenTy.getSizeInBits() == 32 &&
      VecTy.getElementType().getSizeInBits() == 64) {
    Register VL = buildDefaultVLOps(Dst, MIB, MRI).second;
    buildSplatSplitS64WithVL(Dst, MIB.buildUndef(VecTy), SplatVal, VL, MIB,
                             MRI);
    MI.eraseFromParent();
    return true;
  }

  // Constant mask splats map directly onto VMSET_VL / VMCLR_VL.
  MachineInstr &SplatValMI = *MRI.getVRegDef(SplatVal);
  if (isAllOnesOrAllOnesSplat(SplatValMI, MRI)) {
    Register VL = buildDefaultVLOps(VecTy, MIB, MRI).second;
    MIB.buildInstr(RISCV::G_VMSET_VL, {Dst}, {VL});
    MI.eraseFromParent();
    return true;
  }
  if (isNullOrNullSplat(SplatValMI, MRI)) {
    Register VL = buildDefaultVLOps(VecTy, MIB, MRI).second;
    MIB.buildInstr(RISCV::G_VMCLR_VL, {Dst}, {VL});
    MI.eraseFromParent();
    return true;
  }

  // Non-constant mask splat (i.e. not known to be all zeros or all ones):
  // promote (SplatVal & 1) to an s8 splat and compare it against an
  // all-zeros s8 splat.
  LLT PromEltTy = LLT::scalar(8);
  LLT PromTy = VecTy.changeElementType(PromEltTy);
  auto Ext = MIB.buildZExt(PromEltTy, SplatVal);
  auto LowBit =
      MIB.buildAnd(PromEltTy, Ext, MIB.buildConstant(PromEltTy, 1));
  auto ValSplat = MIB.buildSplatVector(PromTy, LowBit);
  auto ZeroSplat =
      MIB.buildSplatVector(PromTy, MIB.buildConstant(PromEltTy, 0));
  MIB.buildICmp(CmpInst::Predicate::ICMP_NE, Dst, ValSplat, ZeroSplat);
  MI.eraseFromParent();
  return true;
}
Expand Down Expand Up @@ -640,6 +804,12 @@ bool RISCVLegalizerInfo::legalizeCustom(
return legalizeVAStart(MI, MIRBuilder);
case TargetOpcode::G_VSCALE:
return legalizeVScale(MI, MIRBuilder);
case TargetOpcode::G_ZEXT:
case TargetOpcode::G_SEXT:
case TargetOpcode::G_ANYEXT:
return legalizeExt(MI, MIRBuilder);
case TargetOpcode::G_SPLAT_VECTOR:
return legalizeSplatVector(MI, MIRBuilder);
}

llvm_unreachable("expected switch to return");
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ class RISCVLegalizerInfo : public LegalizerInfo {

bool legalizeVAStart(MachineInstr &MI, MachineIRBuilder &MIRBuilder) const;
bool legalizeVScale(MachineInstr &MI, MachineIRBuilder &MIB) const;
bool legalizeExt(MachineInstr &MI, MachineIRBuilder &MIRBuilder) const;
bool legalizeSplatVector(MachineInstr &MI, MachineIRBuilder &MIB) const;
};
} // end namespace llvm
#endif
25 changes: 25 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrGISel.td
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,28 @@ def G_READ_VLENB : RISCVGenericInstruction {
let hasSideEffects = false;
}
def : GINodeEquiv<G_READ_VLENB, riscv_read_vlenb>;

// Pseudo equivalent to a RISCVISD::VMCLR_VL. Produces an all-zeros mask
// vector (used e.g. by legalizeSplatVector for splats of a constant-false
// s1). type0 is the mask vector type; $vl is the vector-length operand.
def G_VMCLR_VL : RISCVGenericInstruction {
  let OutOperandList = (outs type0:$dst);
  let InOperandList = (ins type1:$vl);
  let hasSideEffects = false;
}
def : GINodeEquiv<G_VMCLR_VL, riscv_vmclr_vl>;

// Pseudo equivalent to a RISCVISD::VMSET_VL. Produces an all-ones mask
// vector (used e.g. by buildAllOnesMask and by legalizeSplatVector for
// splats of a constant-true s1). type0 is the mask vector type; $vl is the
// vector-length operand.
def G_VMSET_VL : RISCVGenericInstruction {
  let OutOperandList = (outs type0:$dst);
  let InOperandList = (ins type1:$vl);
  let hasSideEffects = false;
}
def : GINodeEquiv<G_VMSET_VL, riscv_vmset_vl>;

// Pseudo equivalent to a RISCVISD::SPLAT_VECTOR_SPLIT_I64_VL. There is no
// record to mark as equivalent to using GINodeEquiv because it gets lowered
// before instruction selection.
// Splats into $dst the 64-bit value formed from the 32-bit halves $lo/$hi.
def G_SPLAT_VECTOR_SPLIT_I64_VL : RISCVGenericInstruction {
  let OutOperandList = (outs type0:$dst);
  // Operand names follow the order the builder actually emits
  // (buildSplatPartsS64WithVL passes {Passthru, Lo, Hi, VL}); the previous
  // names had $hi and $lo swapped relative to that.
  let InOperandList = (ins type0:$passthru, type1:$lo, type1:$hi, type2:$vl);
  let hasSideEffects = false;
}
Loading