Skip to content

[AArch64] Optimise test of the LSB of a paired whileCC instruction #81141

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 47 additions & 16 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2727,6 +2727,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(AArch64ISD::INSR)
MAKE_CASE(AArch64ISD::PTEST)
MAKE_CASE(AArch64ISD::PTEST_ANY)
MAKE_CASE(AArch64ISD::PTEST_FIRST)
MAKE_CASE(AArch64ISD::PTRUE)
MAKE_CASE(AArch64ISD::LD1_MERGE_ZERO)
MAKE_CASE(AArch64ISD::LD1S_MERGE_ZERO)
Expand Down Expand Up @@ -18733,21 +18734,41 @@ static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
AArch64CC::CondCode Cond);

static bool isPredicateCCSettingOp(SDValue N) {
if ((N.getOpcode() == ISD::SETCC) ||
(N.getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
(N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilege ||
N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilegt ||
N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilehi ||
N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilehs ||
N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilele ||
N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilelo ||
N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilels ||
N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilelt ||
// get_active_lane_mask is lowered to a whilelo instruction.
N.getConstantOperandVal(0) == Intrinsic::get_active_lane_mask)))
if (N.getOpcode() == ISD::SETCC)
return true;

return false;
if (N.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
isNullConstant(N.getOperand(1)))
N = N.getOperand(0);

if (N.getOpcode() != ISD::INTRINSIC_WO_CHAIN)
return false;

switch (N.getConstantOperandVal(0)) {
default:
return false;
case Intrinsic::aarch64_sve_whilege_x2:
case Intrinsic::aarch64_sve_whilegt_x2:
case Intrinsic::aarch64_sve_whilehi_x2:
case Intrinsic::aarch64_sve_whilehs_x2:
case Intrinsic::aarch64_sve_whilele_x2:
case Intrinsic::aarch64_sve_whilelo_x2:
case Intrinsic::aarch64_sve_whilels_x2:
case Intrinsic::aarch64_sve_whilelt_x2:
if (N.getResNo() != 0)
return false;
[[fallthrough]];
case Intrinsic::aarch64_sve_whilege:
case Intrinsic::aarch64_sve_whilegt:
case Intrinsic::aarch64_sve_whilehi:
case Intrinsic::aarch64_sve_whilehs:
case Intrinsic::aarch64_sve_whilele:
case Intrinsic::aarch64_sve_whilelo:
case Intrinsic::aarch64_sve_whilels:
case Intrinsic::aarch64_sve_whilelt:
case Intrinsic::get_active_lane_mask:
return true;
}
}

// Materialize : i1 = extract_vector_elt t37, Constant:i64<0>
Expand Down Expand Up @@ -20666,9 +20687,19 @@ static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
}

// Set condition code (CC) flags.
SDValue Test = DAG.getNode(
Cond == AArch64CC::ANY_ACTIVE ? AArch64ISD::PTEST_ANY : AArch64ISD::PTEST,
DL, MVT::Other, Pg, Op);
AArch64ISD::NodeType NT;
switch (Cond) {
default:
NT = AArch64ISD::PTEST;
break;
case AArch64CC::ANY_ACTIVE:
NT = AArch64ISD::PTEST_ANY;
break;
case AArch64CC::FIRST_ACTIVE:
NT = AArch64ISD::PTEST_FIRST;
break;
}
SDValue Test = DAG.getNode(NT, DL, MVT::Other, Pg, Op);

// Convert CC to integer based on requested condition.
// NOTE: Cond is inverted to promote CSEL's removal when it feeds a compare.
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ enum NodeType : unsigned {
INSR,
PTEST,
PTEST_ANY,
PTEST_FIRST,
PTRUE,

CTTZ_ELTS,
Expand Down
52 changes: 35 additions & 17 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1184,6 +1184,7 @@ bool AArch64InstrInfo::analyzeCompare(const MachineInstr &MI, Register &SrcReg,
break;
case AArch64::PTEST_PP:
case AArch64::PTEST_PP_ANY:
case AArch64::PTEST_PP_FIRST:
SrcReg = MI.getOperand(0).getReg();
SrcReg2 = MI.getOperand(1).getReg();
// Not sure about the mask and value for now...
Expand Down Expand Up @@ -1355,12 +1356,25 @@ static bool areCFlagsAccessedBetweenInstrs(
return false;
}

std::optional<unsigned>
std::optional<std::pair<unsigned, MachineInstr *>>
AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
MachineInstr *Pred,
const MachineRegisterInfo *MRI) const {
unsigned MaskOpcode = Mask->getOpcode();
unsigned PredOpcode = Pred->getOpcode();

// Handle a COPY from the LSB of the results of paired WHILEcc instruction.
if ((PredOpcode == TargetOpcode::COPY &&
Pred->getOperand(1).getSubReg() == AArch64::psub0) ||
// Handle unpack of the LSB of the result of a WHILEcc instruction.
PredOpcode == AArch64::PUNPKLO_PP) {
MachineInstr *MI = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg());
if (MI && isWhileOpcode(MI->getOpcode())) {
Pred = MI;
PredOpcode = MI->getOpcode();
}
}

bool PredIsPTestLike = isPTestLikeOpcode(PredOpcode);
bool PredIsWhileLike = isWhileOpcode(PredOpcode);

Expand All @@ -1369,15 +1383,16 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
// instruction and the condition is "any" since WHILEcc does an implicit
// PTEST(ALL, PG) check and PG is always a subset of ALL.
if ((Mask == Pred) && PTest->getOpcode() == AArch64::PTEST_PP_ANY)
return PredOpcode;
return std::make_pair(PredOpcode, Pred);

// For PTEST(PTRUE_ALL, WHILE), if the element size matches, the PTEST is
// redundant since WHILE performs an implicit PTEST with an all active
// mask.
// For PTEST(PTRUE_ALL, WHILE), since WHILE performs an implicit PTEST
// with an all active mask, the PTEST is redundant if either the element
// size matches or the PTEST condition is "first".
if (isPTrueOpcode(MaskOpcode) && Mask->getOperand(1).getImm() == 31 &&
getElementSizeForOpcode(MaskOpcode) ==
getElementSizeForOpcode(PredOpcode))
return PredOpcode;
(PTest->getOpcode() == AArch64::PTEST_PP_FIRST ||
getElementSizeForOpcode(MaskOpcode) ==
getElementSizeForOpcode(PredOpcode)))
return std::make_pair(PredOpcode, Pred);

return {};
}
Expand All @@ -1388,7 +1403,7 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
// "any" since PG is always a subset of the governing predicate of the
// ptest-like instruction.
if ((Mask == Pred) && PTest->getOpcode() == AArch64::PTEST_PP_ANY)
return PredOpcode;
return std::make_pair(PredOpcode, Pred);

// For PTEST(PTRUE_ALL, PTEST_LIKE), the PTEST is redundant if the
// element size matches and either the PTEST_LIKE instruction uses
Expand All @@ -1398,7 +1413,7 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
getElementSizeForOpcode(PredOpcode)) {
auto PTestLikeMask = MRI->getUniqueVRegDef(Pred->getOperand(1).getReg());
if (Mask == PTestLikeMask || PTest->getOpcode() == AArch64::PTEST_PP_ANY)
return PredOpcode;
return std::make_pair(PredOpcode, Pred);
}

// For PTEST(PG, PTEST_LIKE(PG, ...)), the PTEST is redundant since the
Expand Down Expand Up @@ -1427,7 +1442,7 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
uint64_t PredElementSize = getElementSizeForOpcode(PredOpcode);
if (Mask == PTestLikeMask && (PredElementSize == AArch64::ElementSizeB ||
PTest->getOpcode() == AArch64::PTEST_PP_ANY))
return PredOpcode;
return std::make_pair(PredOpcode, Pred);

return {};
}
Expand Down Expand Up @@ -1471,7 +1486,7 @@ AArch64InstrInfo::canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
return {};
}

return convertToFlagSettingOpc(PredOpcode);
return std::make_pair(convertToFlagSettingOpc(PredOpcode), Pred);
}

/// optimizePTestInstr - Attempt to remove a ptest of a predicate-generating
Expand All @@ -1481,10 +1496,12 @@ bool AArch64InstrInfo::optimizePTestInstr(
const MachineRegisterInfo *MRI) const {
auto *Mask = MRI->getUniqueVRegDef(MaskReg);
auto *Pred = MRI->getUniqueVRegDef(PredReg);
unsigned NewOp;
unsigned PredOpcode = Pred->getOpcode();
auto NewOp = canRemovePTestInstr(PTest, Mask, Pred, MRI);
if (!NewOp)
auto canRemove = canRemovePTestInstr(PTest, Mask, Pred, MRI);
if (!canRemove)
return false;
std::tie(NewOp, Pred) = *canRemove;

const TargetRegisterInfo *TRI = &getRegisterInfo();

Expand All @@ -1498,8 +1515,8 @@ bool AArch64InstrInfo::optimizePTestInstr(
// operand to be replaced with an equivalent instruction that also sets the
// flags.
PTest->eraseFromParent();
if (*NewOp != PredOpcode) {
Pred->setDesc(get(*NewOp));
if (NewOp != PredOpcode) {
Pred->setDesc(get(NewOp));
bool succeeded = UpdateOperandRegClass(*Pred);
(void)succeeded;
assert(succeeded && "Operands have incompatible register classes!");
Expand Down Expand Up @@ -1560,7 +1577,8 @@ bool AArch64InstrInfo::optimizeCompareInstr(
}

if (CmpInstr.getOpcode() == AArch64::PTEST_PP ||
CmpInstr.getOpcode() == AArch64::PTEST_PP_ANY)
CmpInstr.getOpcode() == AArch64::PTEST_PP_ANY ||
CmpInstr.getOpcode() == AArch64::PTEST_PP_FIRST)
return optimizePTestInstr(&CmpInstr, SrcReg, SrcReg2, MRI);

if (SrcReg2 != 0)
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Target/AArch64/AArch64InstrInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,8 @@ class AArch64InstrInfo final : public AArch64GenInstrInfo {
bool optimizePTestInstr(MachineInstr *PTest, unsigned MaskReg,
unsigned PredReg,
const MachineRegisterInfo *MRI) const;
std::optional<unsigned>

std::optional<std::pair<unsigned, MachineInstr *>>
canRemovePTestInstr(MachineInstr *PTest, MachineInstr *Mask,
MachineInstr *Pred, const MachineRegisterInfo *MRI) const;
};
Expand Down
9 changes: 5 additions & 4 deletions llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -373,9 +373,10 @@ def AArch64fadda_p : PatFrags<(ops node:$op1, node:$op2, node:$op3),
(AArch64fadda_p_node (SVEAllActive), node:$op2,
(vselect node:$op1, node:$op3, (splat_vector (f64 fpimm_minus0))))]>;

def SDT_AArch64PTest : SDTypeProfile<0, 2, [SDTCisVec<0>, SDTCisSameAs<0,1>]>;
def AArch64ptest : SDNode<"AArch64ISD::PTEST", SDT_AArch64PTest>;
def AArch64ptest_any : SDNode<"AArch64ISD::PTEST_ANY", SDT_AArch64PTest>;
def SDT_AArch64PTest : SDTypeProfile<0, 2, [SDTCisVec<0>, SDTCisSameAs<0,1>]>;
def AArch64ptest : SDNode<"AArch64ISD::PTEST", SDT_AArch64PTest>;
def AArch64ptest_any : SDNode<"AArch64ISD::PTEST_ANY", SDT_AArch64PTest>;
def AArch64ptest_first : SDNode<"AArch64ISD::PTEST_FIRST", SDT_AArch64PTest>;

def SDT_AArch64DUP_PRED : SDTypeProfile<1, 3,
[SDTCisVec<0>, SDTCisSameAs<0, 3>, SDTCisVec<1>, SDTCVecEltisVT<1,i1>, SDTCisSameNumEltsAs<0, 1>]>;
Expand Down Expand Up @@ -948,7 +949,7 @@ let Predicates = [HasSVEorSME] in {
defm BRKB_PPmP : sve_int_break_m<0b101, "brkb", int_aarch64_sve_brkb>;
defm BRKBS_PPzP : sve_int_break_z<0b110, "brkbs", null_frag>;

defm PTEST_PP : sve_int_ptest<0b010000, "ptest", AArch64ptest, AArch64ptest_any>;
defm PTEST_PP : sve_int_ptest<0b010000, "ptest", AArch64ptest, AArch64ptest_any, AArch64ptest_first>;
defm PFALSE : sve_int_pfalse<0b000000, "pfalse">;
defm PFIRST : sve_int_pfirst<0b00000, "pfirst", int_aarch64_sve_pfirst>;
defm PNEXT : sve_int_pnext<0b00110, "pnext", int_aarch64_sve_pnext>;
Expand Down
17 changes: 11 additions & 6 deletions llvm/lib/Target/AArch64/SVEInstrFormats.td
Original file line number Diff line number Diff line change
Expand Up @@ -784,13 +784,16 @@ class sve_int_ptest<bits<6> opc, string asm, SDPatternOperator op>
}

multiclass sve_int_ptest<bits<6> opc, string asm, SDPatternOperator op,
SDPatternOperator op_any> {
SDPatternOperator op_any, SDPatternOperator op_first> {
def NAME : sve_int_ptest<opc, asm, op>;

let hasNoSchedulingInfo = 1, isCompare = 1, Defs = [NZCV] in {
def _ANY : Pseudo<(outs), (ins PPRAny:$Pg, PPR8:$Pn),
[(op_any (nxv16i1 PPRAny:$Pg), (nxv16i1 PPR8:$Pn))]>,
PseudoInstExpansion<(!cast<Instruction>(NAME) PPRAny:$Pg, PPR8:$Pn)>;
def _FIRST : Pseudo<(outs), (ins PPRAny:$Pg, PPR8:$Pn),
[(op_first (nxv16i1 PPRAny:$Pg), (nxv16i1 PPR8:$Pn))]>,
PseudoInstExpansion<(!cast<Instruction>(NAME) PPRAny:$Pg, PPR8:$Pn)>;
}
}

Expand Down Expand Up @@ -9669,7 +9672,7 @@ multiclass sve2p1_int_while_rr_pn<string mnemonic, bits<3> opc> {

// SVE integer compare scalar count and limit (predicate pair)
class sve2p1_int_while_rr_pair<string mnemonic, bits<2> sz, bits<3> opc,
RegisterOperand ppr_ty>
RegisterOperand ppr_ty, ElementSizeEnum EltSz>
: I<(outs ppr_ty:$Pd), (ins GPR64:$Rn, GPR64:$Rm),
mnemonic, "\t$Pd, $Rn, $Rm",
"", []>, Sched<[]> {
Expand All @@ -9687,16 +9690,18 @@ class sve2p1_int_while_rr_pair<string mnemonic, bits<2> sz, bits<3> opc,
let Inst{3-1} = Pd;
let Inst{0} = opc{0};

let ElementSize = EltSz;
let Defs = [NZCV];
let hasSideEffects = 0;
let isWhile = 1;
}


multiclass sve2p1_int_while_rr_pair<string mnemonic, bits<3> opc> {
def _B : sve2p1_int_while_rr_pair<mnemonic, 0b00, opc, PP_b_mul_r>;
def _H : sve2p1_int_while_rr_pair<mnemonic, 0b01, opc, PP_h_mul_r>;
def _S : sve2p1_int_while_rr_pair<mnemonic, 0b10, opc, PP_s_mul_r>;
def _D : sve2p1_int_while_rr_pair<mnemonic, 0b11, opc, PP_d_mul_r>;
def _B : sve2p1_int_while_rr_pair<mnemonic, 0b00, opc, PP_b_mul_r, ElementSizeB>;
def _H : sve2p1_int_while_rr_pair<mnemonic, 0b01, opc, PP_h_mul_r, ElementSizeH>;
def _S : sve2p1_int_while_rr_pair<mnemonic, 0b10, opc, PP_s_mul_r, ElementSizeS>;
def _D : sve2p1_int_while_rr_pair<mnemonic, 0b11, opc, PP_d_mul_r, ElementSizeD>;
}


Expand Down
Loading
Loading