Skip to content

[GlobalIsel] combine ext of trunc with flags #87115

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -599,10 +599,6 @@ class CombinerHelper {
/// This variant does not erase \p MI after calling the build function.
void applyBuildFnNoErase(MachineInstr &MI, BuildFnTy &MatchInfo);

/// Use a function which takes in a MachineIRBuilder to perform a combine.
/// \p MatchInfo is the build function to run. The instruction that defines
/// the register of \p MO (looked up ignoring copies) is erased from the
/// function afterwards.
void applyBuildFnMO(const MachineOperand &MO, BuildFnTy &MatchInfo);

bool matchOrShiftToFunnelShift(MachineInstr &MI, BuildFnTy &MatchInfo);
bool matchFunnelShiftToRotate(MachineInstr &MI);
void applyFunnelShiftToRotate(MachineInstr &MI);
Expand Down Expand Up @@ -814,6 +810,12 @@ class CombinerHelper {
/// Match constant LHS ops that should be commuted.
bool matchCommuteConstantToRHS(MachineInstr &MI);

/// Combine sext of trunc.
/// \p MO is the def operand of a G_SEXT whose source is defined by a G_TRUNC.
/// On success, \p MatchInfo is set to a build function that defines the
/// G_SEXT result directly from the G_TRUNC source, using a COPY, a G_TRUNC
/// or a G_SEXT depending on how the two types compare.
/// \returns false if the required replacement opcode is not legal.
bool matchSextOfTrunc(const MachineOperand &MO, BuildFnTy &MatchInfo);

/// Combine zext of trunc.
/// \p MO is the def operand of a G_ZEXT whose source is defined by a G_TRUNC.
/// On success, \p MatchInfo is set to a build function that defines the
/// G_ZEXT result directly from the G_TRUNC source, using a COPY, a G_TRUNC
/// or a G_ZEXT depending on how the two types compare.
/// \returns false if the required replacement opcode is not legal.
bool matchZextOfTrunc(const MachineOperand &MO, BuildFnTy &MatchInfo);

/// Match constant LHS FP ops that should be commuted.
bool matchCommuteFPConstantToRHS(MachineInstr &MI);

Expand Down Expand Up @@ -857,6 +859,9 @@ class CombinerHelper {
/// register and different indices.
bool matchExtractVectorElementWithDifferentIndices(const MachineOperand &MO,
BuildFnTy &MatchInfo);
/// Use a function which takes in a MachineIRBuilder to perform a combine.
/// \p MatchInfo is the build function to run. Once it has run, the
/// instruction def'd on \p MO (looked up ignoring copies) is erased from the
/// function.
void applyBuildFnMO(const MachineOperand &MO, BuildFnTy &MatchInfo);

/// Combine insert vector element OOB.
bool matchInsertVectorElementOOB(MachineInstr &MI, BuildFnTy &MatchInfo);
Expand Down
53 changes: 53 additions & 0 deletions llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -792,6 +792,59 @@ class GFreeze : public GenericMachineInstr {
}
};

/// Common base class for generic cast instructions.
/// It follows the llvm::CastInst concept, with one exception: G_BITCAST is
/// not treated as a GCastOp.
class GCastOp : public GenericMachineInstr {
public:
  /// The register that is being cast, taken from operand 1.
  Register getSrcReg() const { return getOperand(1).getReg(); }

  /// Returns true when the opcode of \p MI is one of the cast opcodes
  /// enumerated below.
  static bool classof(const MachineInstr *MI) {
    const unsigned Opcode = MI->getOpcode();
    switch (Opcode) {
    case TargetOpcode::G_ADDRSPACE_CAST:
    case TargetOpcode::G_ANYEXT:
    case TargetOpcode::G_FPEXT:
    case TargetOpcode::G_FPTOSI:
    case TargetOpcode::G_FPTOUI:
    case TargetOpcode::G_FPTRUNC:
    case TargetOpcode::G_INTTOPTR:
    case TargetOpcode::G_PTRTOINT:
    case TargetOpcode::G_SEXT:
    case TargetOpcode::G_SITOFP:
    case TargetOpcode::G_TRUNC:
    case TargetOpcode::G_UITOFP:
    case TargetOpcode::G_ZEXT:
      return true;
    default:
      break;
    }
    return false;
  }
};

/// Typed view of a G_SEXT instruction.
class GSext : public GCastOp {
public:
  /// Returns true when \p MI has the G_SEXT opcode.
  static bool classof(const MachineInstr *MI) {
    const unsigned Opcode = MI->getOpcode();
    return Opcode == TargetOpcode::G_SEXT;
  }
};

/// Typed view of a G_ZEXT instruction.
class GZext : public GCastOp {
public:
  /// Returns true when \p MI has the G_ZEXT opcode.
  static bool classof(const MachineInstr *MI) {
    const unsigned Opcode = MI->getOpcode();
    return Opcode == TargetOpcode::G_ZEXT;
  }
};

/// Typed view of a G_TRUNC instruction.
class GTrunc : public GCastOp {
public:
  /// Returns true when \p MI has the G_TRUNC opcode.
  static bool classof(const MachineInstr *MI) {
    const unsigned Opcode = MI->getOpcode();
    return Opcode == TargetOpcode::G_TRUNC;
  }
};

} // namespace llvm

#endif // LLVM_CODEGEN_GLOBALISEL_GENERICMACHINEINSTRS_H
6 changes: 4 additions & 2 deletions llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -746,7 +746,8 @@ class MachineIRBuilder {
/// \pre \p Op must be smaller than \p Res
///
/// \return The newly created instruction.
MachineInstrBuilder buildZExt(const DstOp &Res, const SrcOp &Op);
MachineInstrBuilder buildZExt(const DstOp &Res, const SrcOp &Op,
std::optional<unsigned> Flags = std::nullopt);

/// Build and insert \p Res = G_SEXT \p Op, \p Res = G_TRUNC \p Op, or
/// \p Res = COPY \p Op depending on the differing sizes of \p Res and \p Op.
Expand Down Expand Up @@ -1231,7 +1232,8 @@ class MachineIRBuilder {
/// \pre \p Res must be smaller than \p Op
///
/// \return The newly created instruction.
MachineInstrBuilder buildTrunc(const DstOp &Res, const SrcOp &Op);
MachineInstrBuilder buildTrunc(const DstOp &Res, const SrcOp &Op,
std::optional<unsigned> Flags = std::nullopt);

/// Build and insert a \p Res = G_ICMP \p Pred, \p Op0, \p Op1
///
Expand Down
18 changes: 17 additions & 1 deletion llvm/include/llvm/Target/GlobalISel/Combine.td
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,8 @@ def FmContract : MIFlagEnum<"FmContract">;
def FmAfn : MIFlagEnum<"FmAfn">;
def FmReassoc : MIFlagEnum<"FmReassoc">;
def IsExact : MIFlagEnum<"IsExact">;
def NoSWrap : MIFlagEnum<"NoSWrap">;
def NoUWrap : MIFlagEnum<"NoUWrap">;

def MIFlags;
// def not; -> Already defined as a SDNode
Expand Down Expand Up @@ -1501,6 +1503,20 @@ def extract_vector_element_freeze : GICombineRule<
[{ return Helper.matchExtractVectorElementWithFreeze(${root}, ${matchinfo}); }]),
(apply [{ Helper.applyBuildFnMO(${root}, ${matchinfo}); }])>;

// Combine a G_SEXT ($root) whose source is a G_TRUNC carrying the NoSWrap
// flag. Helper.matchSextOfTrunc decides on the replacement and stores it in
// $matchinfo; Helper.applyBuildFnMO then runs that build function.
def sext_trunc : GICombineRule<
(defs root:$root, build_fn_matchinfo:$matchinfo),
(match (G_TRUNC $src, $x, (MIFlags NoSWrap)),
(G_SEXT $root, $src),
[{ return Helper.matchSextOfTrunc(${root}, ${matchinfo}); }]),
(apply [{ Helper.applyBuildFnMO(${root}, ${matchinfo}); }])>;

// Combine a G_ZEXT ($root) whose source is a G_TRUNC carrying the NoUWrap
// flag. Helper.matchZextOfTrunc decides on the replacement and stores it in
// $matchinfo; Helper.applyBuildFnMO then runs that build function.
def zext_trunc : GICombineRule<
(defs root:$root, build_fn_matchinfo:$matchinfo),
(match (G_TRUNC $src, $x, (MIFlags NoUWrap)),
(G_ZEXT $root, $src),
[{ return Helper.matchZextOfTrunc(${root}, ${matchinfo}); }]),
(apply [{ Helper.applyBuildFnMO(${root}, ${matchinfo}); }])>;

def extract_vector_element_shuffle_vector : GICombineRule<
(defs root:$root, build_fn_matchinfo:$matchinfo),
(match (G_SHUFFLE_VECTOR $src, $src1, $src2, $mask),
Expand Down Expand Up @@ -1666,7 +1682,7 @@ def all_combines : GICombineGroup<[trivial_combines, vector_ops_combines,
sub_add_reg, select_to_minmax, redundant_binop_in_equality,
fsub_to_fneg, commute_constant_to_rhs, match_ands, match_ors,
combine_concat_vector, double_icmp_zero_and_or_combine, match_addos,
combine_shuffle_concat]>;
sext_trunc, zext_trunc, combine_shuffle_concat]>;

// A combine group used for prelegalizer combiners at -O0. The combines in
// this group have been selected based on experiments to balance code size and
Expand Down
83 changes: 75 additions & 8 deletions llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4137,14 +4137,6 @@ void CombinerHelper::applyBuildFn(
MI.eraseFromParent();
}

/// Runs the build function \p MatchInfo and then erases the instruction that
/// defines the register of \p MO.
void CombinerHelper::applyBuildFnMO(const MachineOperand &MO,
BuildFnTy &MatchInfo) {
// Find the instruction defining MO's register, ignoring copies.
MachineInstr *Root = getDefIgnoringCopies(MO.getReg(), MRI);
// Position the builder at the root before the build function runs.
Builder.setInstrAndDebugLoc(*Root);
MatchInfo(Builder);
// The root is erased only after the build function has run.
Root->eraseFromParent();
}

void CombinerHelper::applyBuildFnNoErase(
MachineInstr &MI, std::function<void(MachineIRBuilder &)> &MatchInfo) {
MatchInfo(Builder);
Expand Down Expand Up @@ -7252,3 +7244,78 @@ bool CombinerHelper::matchAddOverflow(MachineInstr &MI, BuildFnTy &MatchInfo) {

return false;
}

void CombinerHelper::applyBuildFnMO(const MachineOperand &MO,
BuildFnTy &MatchInfo) {
MachineInstr *Root = getDefIgnoringCopies(MO.getReg(), MRI);
MatchInfo(Builder);
Root->eraseFromParent();
}

bool CombinerHelper::matchSextOfTrunc(const MachineOperand &MO,
BuildFnTy &MatchInfo) {
GSext *Sext = cast<GSext>(getDefIgnoringCopies(MO.getReg(), MRI));
GTrunc *Trunc = cast<GTrunc>(getDefIgnoringCopies(Sext->getSrcReg(), MRI));

Register Dst = Sext->getReg(0);
Register Src = Trunc->getSrcReg();

LLT DstTy = MRI.getType(Dst);
LLT SrcTy = MRI.getType(Src);

if (DstTy == SrcTy) {
MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Dst, Src); };
return true;
}

if (DstTy.getScalarSizeInBits() < SrcTy.getScalarSizeInBits() &&
isLegalOrBeforeLegalizer({TargetOpcode::G_TRUNC, {DstTy, SrcTy}})) {
MatchInfo = [=](MachineIRBuilder &B) {
B.buildTrunc(Dst, Src, MachineInstr::MIFlag::NoSWrap);
};
return true;
}

if (DstTy.getScalarSizeInBits() > SrcTy.getScalarSizeInBits() &&
isLegalOrBeforeLegalizer({TargetOpcode::G_SEXT, {DstTy, SrcTy}})) {
MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Dst, Src); };
return true;
}

return false;
}

bool CombinerHelper::matchZextOfTrunc(const MachineOperand &MO,
BuildFnTy &MatchInfo) {
GZext *Zext = cast<GZext>(getDefIgnoringCopies(MO.getReg(), MRI));
GTrunc *Trunc = cast<GTrunc>(getDefIgnoringCopies(Zext->getSrcReg(), MRI));

Register Dst = Zext->getReg(0);
Register Src = Trunc->getSrcReg();

LLT DstTy = MRI.getType(Dst);
LLT SrcTy = MRI.getType(Src);

if (DstTy == SrcTy) {
MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Dst, Src); };
return true;
}

if (DstTy.getScalarSizeInBits() < SrcTy.getScalarSizeInBits() &&
isLegalOrBeforeLegalizer({TargetOpcode::G_TRUNC, {DstTy, SrcTy}})) {
MatchInfo = [=](MachineIRBuilder &B) {
B.buildTrunc(Dst, Src, MachineInstr::MIFlag::NoUWrap);
};
return true;
}

if (DstTy.getScalarSizeInBits() > SrcTy.getScalarSizeInBits() &&
isLegalOrBeforeLegalizer({TargetOpcode::G_ZEXT, {DstTy, SrcTy}})) {
MatchInfo = [=](MachineIRBuilder &B) {
B.buildZExt(Dst, Src, MachineInstr::MIFlag::NonNeg);
};
return true;
}

return false;
}
5 changes: 4 additions & 1 deletion llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,11 @@ KnownBits GISelKnownBits::getKnownBits(MachineInstr &MI) {

KnownBits GISelKnownBits::getKnownBits(Register R) {
const LLT Ty = MRI.getType(R);
// Since the number of lanes in a scalable vector is unknown at compile time,
// we track one bit which is implicitly broadcast to all lanes. This means
// that all lanes in a scalable vector are considered demanded.
APInt DemandedElts =
Ty.isVector() ? APInt::getAllOnes(Ty.getNumElements()) : APInt(1, 1);
Ty.isFixedVector() ? APInt::getAllOnes(Ty.getNumElements()) : APInt(1, 1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be a separate patch? What's the impact?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a G_SPLAT_VECTOR test and a second combine zext_trunc_fold_matchinfo that relies on known bits to prove that zext(trunc(x)) is a noop. It crashed.

Copy link
Author

@tschuett tschuett Apr 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The changes are copy-pasted from the DAG to support scalable vectors in known bits.

return getKnownBits(R, DemandedElts);
}

Expand Down
12 changes: 7 additions & 5 deletions llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -490,8 +490,9 @@ MachineInstrBuilder MachineIRBuilder::buildSExt(const DstOp &Res,
}

MachineInstrBuilder MachineIRBuilder::buildZExt(const DstOp &Res,
const SrcOp &Op) {
return buildInstr(TargetOpcode::G_ZEXT, Res, Op);
const SrcOp &Op,
std::optional<unsigned> Flags) {
return buildInstr(TargetOpcode::G_ZEXT, Res, Op, Flags);
}

unsigned MachineIRBuilder::getBoolExtOp(bool IsVec, bool IsFP) const {
Expand Down Expand Up @@ -869,9 +870,10 @@ MachineInstrBuilder MachineIRBuilder::buildIntrinsic(Intrinsic::ID ID,
return buildIntrinsic(ID, Results, HasSideEffects, isConvergent);
}

MachineInstrBuilder MachineIRBuilder::buildTrunc(const DstOp &Res,
const SrcOp &Op) {
return buildInstr(TargetOpcode::G_TRUNC, Res, Op);
MachineInstrBuilder
MachineIRBuilder::buildTrunc(const DstOp &Res, const SrcOp &Op,
std::optional<unsigned> Flags) {
return buildInstr(TargetOpcode::G_TRUNC, Res, Op, Flags);
}

MachineInstrBuilder
Expand Down
Loading