[AArch64][SVE2] Lower OR to SLI/SRI #77555


Merged: 3 commits, Jan 12, 2024
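With SVE2 available, an unpredicated OR of a suitably masked vector and a shifted vector can now be selected as a single SLI/SRI for scalable types as well, mirroring the existing NEON combine. A rough source-level sketch of the kind of pattern this targets (not taken from the PR's tests; the intrinsic usage and the expected codegen are assumptions):

#include <arm_sve.h>

// or(and(x, 0x00FF), shl(y, 8)) on 16-bit elements: the AND keeps exactly the
// bits the left shift leaves zero, so with SVE2 this is expected to select to
// a single "sli z.h, z.h, #8" (assumed codegen, e.g. at -O2 -march=armv9-a+sve2).
svuint16_t merge_halves(svuint16_t x, svuint16_t y) {
  svbool_t all = svptrue_b16();
  svuint16_t lo = svand_n_u16_x(all, x, 0x00FF);
  svuint16_t hi = svlsl_n_u16_x(all, y, 8);
  return svorr_u16_x(all, lo, hi);
}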
158 changes: 93 additions & 65 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1358,6 +1358,10 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,

if (!Subtarget->isLittleEndian())
setOperationAction(ISD::BITCAST, VT, Expand);

if (Subtarget->hasSVE2orSME())
// For SLI/SRI.
setOperationAction(ISD::OR, VT, Custom);
}

// Illegal unpacked integer vector types.
@@ -5409,15 +5413,18 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
}

case Intrinsic::aarch64_neon_vsri:
case Intrinsic::aarch64_neon_vsli: {
case Intrinsic::aarch64_neon_vsli:
case Intrinsic::aarch64_sve_sri:
case Intrinsic::aarch64_sve_sli: {
EVT Ty = Op.getValueType();

if (!Ty.isVector())
report_fatal_error("Unexpected type for aarch64_neon_vsli");

assert(Op.getConstantOperandVal(3) <= Ty.getScalarSizeInBits());

bool IsShiftRight = IntNo == Intrinsic::aarch64_neon_vsri;
bool IsShiftRight = IntNo == Intrinsic::aarch64_neon_vsri ||
IntNo == Intrinsic::aarch64_sve_sri;
unsigned Opcode = IsShiftRight ? AArch64ISD::VSRI : AArch64ISD::VSLI;
return DAG.getNode(Opcode, dl, Ty, Op.getOperand(1), Op.getOperand(2),
Op.getOperand(3));
@@ -12542,6 +12549,53 @@ static bool isAllConstantBuildVector(const SDValue &PotentialBVec,
return true;
}

static bool isAllInactivePredicate(SDValue N) {
// Look through cast.
while (N.getOpcode() == AArch64ISD::REINTERPRET_CAST)
N = N.getOperand(0);

return ISD::isConstantSplatVectorAllZeros(N.getNode());
}

static bool isAllActivePredicate(SelectionDAG &DAG, SDValue N) {
unsigned NumElts = N.getValueType().getVectorMinNumElements();

// Look through cast.
while (N.getOpcode() == AArch64ISD::REINTERPRET_CAST) {
N = N.getOperand(0);
// When reinterpreting from a type with fewer elements the "new" elements
// are not active, so bail if they're likely to be used.
if (N.getValueType().getVectorMinNumElements() < NumElts)
return false;
}

if (ISD::isConstantSplatVectorAllOnes(N.getNode()))
return true;

// "ptrue p.<ty>, all" can be considered all active when <ty> is the same size
// or smaller than the implicit element type represented by N.
// NOTE: A larger element count implies a smaller element type.
if (N.getOpcode() == AArch64ISD::PTRUE &&
N.getConstantOperandVal(0) == AArch64SVEPredPattern::all)
return N.getValueType().getVectorMinNumElements() >= NumElts;

// If we're compiling for a specific vector-length, we can check if the
// pattern's VL equals that of the scalable vector at runtime.
if (N.getOpcode() == AArch64ISD::PTRUE) {
const auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
unsigned MinSVESize = Subtarget.getMinSVEVectorSizeInBits();
unsigned MaxSVESize = Subtarget.getMaxSVEVectorSizeInBits();
if (MaxSVESize && MinSVESize == MaxSVESize) {
unsigned VScale = MaxSVESize / AArch64::SVEBitsPerBlock;
unsigned PatNumElts =
getNumElementsFromSVEPredPattern(N.getConstantOperandVal(0));
return PatNumElts == (NumElts * VScale);
}
}

return false;
}

// Attempt to form a vector S[LR]I from (or (and X, BvecC1), (lsl Y, C2)),
// to (SLI X, Y, C2), where X and Y have matching vector types, BvecC1 is a
// BUILD_VECTORs with constant element C1, C2 is a constant, and:
@@ -12567,59 +12621,78 @@ static SDValue tryLowerToSLI(SDNode *N, SelectionDAG &DAG) {
// Is one of the operands an AND or a BICi? The AND may have been optimised to
// a BICi in order to use an immediate instead of a register.
// Is the other operand an shl or lshr? This will have been turned into:
// AArch64ISD::VSHL vector, #shift or AArch64ISD::VLSHR vector, #shift.
// AArch64ISD::VSHL vector, #shift or AArch64ISD::VLSHR vector, #shift
// or (AArch64ISD::SHL_PRED || AArch64ISD::SRL_PRED) mask, vector, #shiftVec.
if ((FirstOpc == ISD::AND || FirstOpc == AArch64ISD::BICi) &&
(SecondOpc == AArch64ISD::VSHL || SecondOpc == AArch64ISD::VLSHR)) {
(SecondOpc == AArch64ISD::VSHL || SecondOpc == AArch64ISD::VLSHR ||
SecondOpc == AArch64ISD::SHL_PRED ||
SecondOpc == AArch64ISD::SRL_PRED)) {
And = FirstOp;
Shift = SecondOp;

} else if ((SecondOpc == ISD::AND || SecondOpc == AArch64ISD::BICi) &&
(FirstOpc == AArch64ISD::VSHL || FirstOpc == AArch64ISD::VLSHR)) {
(FirstOpc == AArch64ISD::VSHL || FirstOpc == AArch64ISD::VLSHR ||
FirstOpc == AArch64ISD::SHL_PRED ||
FirstOpc == AArch64ISD::SRL_PRED)) {
And = SecondOp;
Shift = FirstOp;
} else
return SDValue();

bool IsAnd = And.getOpcode() == ISD::AND;
bool IsShiftRight = Shift.getOpcode() == AArch64ISD::VLSHR;

// Is the shift amount constant?
ConstantSDNode *C2node = dyn_cast<ConstantSDNode>(Shift.getOperand(1));
if (!C2node)
bool IsShiftRight = Shift.getOpcode() == AArch64ISD::VLSHR ||
Shift.getOpcode() == AArch64ISD::SRL_PRED;
bool ShiftHasPredOp = Shift.getOpcode() == AArch64ISD::SHL_PRED ||
Shift.getOpcode() == AArch64ISD::SRL_PRED;

// Is the shift amount constant and are all lanes active?
uint64_t C2;
if (ShiftHasPredOp) {
if (!isAllActivePredicate(DAG, Shift.getOperand(0)))
return SDValue();
APInt C;
if (!ISD::isConstantSplatVector(Shift.getOperand(2).getNode(), C))
return SDValue();
C2 = C.getZExtValue();
} else if (ConstantSDNode *C2node =
dyn_cast<ConstantSDNode>(Shift.getOperand(1)))
C2 = C2node->getZExtValue();
else
return SDValue();

uint64_t C1;
APInt C1AsAPInt;
unsigned ElemSizeInBits = VT.getScalarSizeInBits();
if (IsAnd) {
// Is the and mask vector all constant?
if (!isAllConstantBuildVector(And.getOperand(1), C1))
if (!ISD::isConstantSplatVector(And.getOperand(1).getNode(), C1AsAPInt))
return SDValue();
} else {
// Reconstruct the corresponding AND immediate from the two BICi immediates.
ConstantSDNode *C1nodeImm = dyn_cast<ConstantSDNode>(And.getOperand(1));
ConstantSDNode *C1nodeShift = dyn_cast<ConstantSDNode>(And.getOperand(2));
assert(C1nodeImm && C1nodeShift);
C1 = ~(C1nodeImm->getZExtValue() << C1nodeShift->getZExtValue());
C1AsAPInt = ~(C1nodeImm->getAPIntValue() << C1nodeShift->getAPIntValue());
C1AsAPInt = C1AsAPInt.zextOrTrunc(ElemSizeInBits);
}

// Is C1 == ~(Ones(ElemSizeInBits) << C2) or
// C1 == ~(Ones(ElemSizeInBits) >> C2), taking into account
// how much one can shift elements of a particular size?
uint64_t C2 = C2node->getZExtValue();
unsigned ElemSizeInBits = VT.getScalarSizeInBits();
if (C2 > ElemSizeInBits)
return SDValue();

APInt C1AsAPInt(ElemSizeInBits, C1);
APInt RequiredC1 = IsShiftRight ? APInt::getHighBitsSet(ElemSizeInBits, C2)
: APInt::getLowBitsSet(ElemSizeInBits, C2);
if (C1AsAPInt != RequiredC1)
return SDValue();

SDValue X = And.getOperand(0);
SDValue Y = Shift.getOperand(0);
SDValue Y = ShiftHasPredOp ? Shift.getOperand(1) : Shift.getOperand(0);
SDValue Imm = ShiftHasPredOp ? DAG.getTargetConstant(C2, DL, MVT::i32)
: Shift.getOperand(1);

unsigned Inst = IsShiftRight ? AArch64ISD::VSRI : AArch64ISD::VSLI;
SDValue ResultSLI = DAG.getNode(Inst, DL, VT, X, Y, Shift.getOperand(1));
SDValue ResultSLI = DAG.getNode(Inst, DL, VT, X, Y, Imm);

LLVM_DEBUG(dbgs() << "aarch64-lower: transformed: \n");
LLVM_DEBUG(N->dump(&DAG));
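The mask check above is the heart of the combine: the AND immediate must cover exactly the bits the shift leaves zero, otherwise the OR is not a pure insert. A standalone sketch of that relation (illustration only, plain integers instead of APInt):

#include <cassert>
#include <cstdint>

int main() {
  // 16-bit elements, shift amount C2 = 8.
  const unsigned C2 = 8;
  const uint16_t Ones = 0xFFFF;
  // For SLI (left shift) the AND must keep the low C2 bits, i.e.
  // C1 == ~(Ones << C2); for SRI (right shift) the high C2 bits, i.e.
  // C1 == ~(Ones >> C2). These match the getLowBitsSet/getHighBitsSet checks.
  const uint16_t SLIMask = uint16_t(~(Ones << C2)); // 0x00FF
  const uint16_t SRIMask = uint16_t(~(Ones >> C2)); // 0xFF00
  assert(SLIMask == 0x00FF && SRIMask == 0xFF00);
  return 0;
}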
@@ -12641,6 +12714,8 @@ SDValue AArch64TargetLowering::LowerVectorOR(SDValue Op,
return Res;

EVT VT = Op.getValueType();
if (VT.isScalableVector())
return Op;

SDValue LHS = Op.getOperand(0);
BuildVectorSDNode *BVN =
@@ -17432,53 +17507,6 @@ static bool isConstantSplatVectorMaskForType(SDNode *N, EVT MemVT) {
return false;
}

static bool isAllInactivePredicate(SDValue N) {
// Look through cast.
while (N.getOpcode() == AArch64ISD::REINTERPRET_CAST)
N = N.getOperand(0);

return ISD::isConstantSplatVectorAllZeros(N.getNode());
}

static bool isAllActivePredicate(SelectionDAG &DAG, SDValue N) {
unsigned NumElts = N.getValueType().getVectorMinNumElements();

// Look through cast.
while (N.getOpcode() == AArch64ISD::REINTERPRET_CAST) {
N = N.getOperand(0);
// When reinterpreting from a type with fewer elements the "new" elements
// are not active, so bail if they're likely to be used.
if (N.getValueType().getVectorMinNumElements() < NumElts)
return false;
}

if (ISD::isConstantSplatVectorAllOnes(N.getNode()))
return true;

// "ptrue p.<ty>, all" can be considered all active when <ty> is the same size
// or smaller than the implicit element type represented by N.
// NOTE: A larger element count implies a smaller element type.
if (N.getOpcode() == AArch64ISD::PTRUE &&
N.getConstantOperandVal(0) == AArch64SVEPredPattern::all)
return N.getValueType().getVectorMinNumElements() >= NumElts;

// If we're compiling for a specific vector-length, we can check if the
// pattern's VL equals that of the scalable vector at runtime.
if (N.getOpcode() == AArch64ISD::PTRUE) {
const auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
unsigned MinSVESize = Subtarget.getMinSVEVectorSizeInBits();
unsigned MaxSVESize = Subtarget.getMaxSVEVectorSizeInBits();
if (MaxSVESize && MinSVESize == MaxSVESize) {
unsigned VScale = MaxSVESize / AArch64::SVEBitsPerBlock;
unsigned PatNumElts =
getNumElementsFromSVEPredPattern(N.getConstantOperandVal(0));
return PatNumElts == (NumElts * VScale);
}
}

return false;
}

static SDValue performReinterpretCastCombine(SDNode *N) {
SDValue LeafOp = SDValue(N, 0);
SDValue Op = N->getOperand(0);
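Note the role of isAllActivePredicate in the predicated-shift path above: the fold is only sound when every lane of the shift is governed by an active predicate element. A hedged counter-example of the kind of case that check is meant to reject (my own sketch, relying on documented ACLE intrinsic behaviour):

#include <arm_sve.h>

// The zeroing shift produces shifted bits only in the active lanes, while an
// SLI would insert y's bits into every lane, so this must not become an SLI;
// the combine is expected to bail when the governing predicate is not
// provably all-active.
svuint16_t not_an_sli(svbool_t pg, svuint16_t x, svuint16_t y) {
  svuint16_t lo = svand_n_u16_x(svptrue_b16(), x, 0x00FF);
  svuint16_t hi = svlsl_n_u16_z(pg, y, 8); // inactive lanes become zero
  return svorr_u16_x(svptrue_b16(), lo, hi);
}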
4 changes: 2 additions & 2 deletions llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -3577,8 +3577,8 @@ let Predicates = [HasSVE2orSME] in {
defm PMULLT_ZZZ : sve2_pmul_long<0b1, "pmullt", int_aarch64_sve_pmullt_pair>;

// SVE2 bitwise shift and insert
defm SRI_ZZI : sve2_int_bin_shift_imm_right<0b0, "sri", int_aarch64_sve_sri>;
defm SLI_ZZI : sve2_int_bin_shift_imm_left< 0b1, "sli", int_aarch64_sve_sli>;
defm SRI_ZZI : sve2_int_bin_shift_imm_right<0b0, "sri", AArch64vsri>;
defm SLI_ZZI : sve2_int_bin_shift_imm_left< 0b1, "sli", AArch64vsli>;

// SVE2 bitwise shift right and accumulate
defm SSRA_ZZI : sve2_int_bin_accum_shift_imm_right<0b00, "ssra", AArch64ssra>;
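Pointing the SRI_ZZI/SLI_ZZI patterns at the AArch64vsri/AArch64vsli nodes rather than the raw intrinsics means both the new OR combine and the existing ACLE intrinsics feed the same selection path. A small usage sketch (assumed intrinsic names per the SVE2 ACLE; not from the PR's tests):

#include <arm_sve.h>

// Both of these are expected to keep selecting a single SRI/SLI instruction,
// now matched via the AArch64vsri/AArch64vsli nodes.
svuint16_t insert_left(svuint16_t x, svuint16_t y) {
  return svsli_n_u16(x, y, 8); // shift-left-insert by 8
}
svuint16_t insert_right(svuint16_t x, svuint16_t y) {
  return svsri_n_u16(x, y, 8); // shift-right-insert by 8
}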