[WIP][RFC] Implementation for SVE2 long operations #89310

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open: wants to merge 1 commit into main
292 changes: 292 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -49,6 +49,7 @@
#include "llvm/CodeGen/MachineMemOperand.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/RuntimeLibcalls.h"
#include "llvm/CodeGen/SDPatternMatch.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/CodeGen/TargetCallingConv.h"
@@ -104,6 +105,7 @@

using namespace llvm;
using namespace llvm::PatternMatch;
namespace sd = llvm::SDPatternMatch;

#define DEBUG_TYPE "aarch64-lower"

@@ -1416,6 +1418,15 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::OR, VT, Custom);
}

// Illegal wide integer scalable vector types.
if (Subtarget->hasSVE2orSME()) {
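// These are the result types of integer adds whose operands are
// zero-extended SVE vectors (e.g. nxv16i16 from add(zext nxv16i8,
// zext nxv16i8)); replaceIntOpWithSVE2LongOp lowers them to
// UADDLB/UADDLT pairs when the results are legalised.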
for (auto VT : {MVT::nxv16i16, MVT::nxv16i32, MVT::nxv16i64})
setOperationAction(ISD::ADD, VT, Custom);
for (auto VT : {MVT::nxv8i32, MVT::nxv8i64})
setOperationAction(ISD::ADD, VT, Custom);
setOperationAction(ISD::ADD, MVT::nxv4i64, Custom);
}

// Illegal unpacked integer vector types.
for (auto VT : {MVT::nxv8i8, MVT::nxv4i16, MVT::nxv2i32}) {
setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
@@ -2725,6 +2736,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(AArch64ISD::CTTZ_ELTS)
MAKE_CASE(AArch64ISD::CALL_ARM64EC_TO_X64)
MAKE_CASE(AArch64ISD::URSHR_I_PRED)
MAKE_CASE(AArch64ISD::UADDLB)
MAKE_CASE(AArch64ISD::UADDLT)
}
#undef MAKE_CASE
return nullptr;
@@ -25081,6 +25094,282 @@ void AArch64TargetLowering::ReplaceBITCASTResults(
Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, Op));
}

static bool matchUADDLOps(SDNode *N, SelectionDAG &DAG, SDValue &A, SDValue &B,
unsigned &BotOpc, unsigned &TopOpc) {
BotOpc = AArch64ISD::UADDLB;
TopOpc = AArch64ISD::UADDLT;
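// Plain zero-extended operands: add(zext(A), zext(B)) where both extends
// have a single use. A and B are returned as the unextended operands.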
if (sd_match(N, sd::m_Add(sd::m_OneUse(sd::m_ZExt(sd::m_Value(A))),
sd::m_OneUse(sd::m_ZExt(sd::m_Value(B))))))
return true;

#if 0
// Extended loads.
if (sd_match(N, sd::m_Add(sd::m_OneUse(sd::m_ZExt(sd::m_Value(A))),
sd::m_OneUse(sd::m_Value(B))))) {
auto *LDB = dyn_cast<LoadSDNode>(B);
if (LDB && LDB->getExtensionType() == ISD::ZEXTLOAD) {
B = DAG.getLoad(LDB->getMemoryVT(), SDLoc(LDB), LDB->getChain(),
LDB->getBasePtr(), LDB->getMemOperand());
return true;
}
} else if (sd_match(N, sd::m_Add(sd::m_OneUse(sd::m_Value(A)),
sd::m_OneUse(sd::m_Value(B)))) &&
isa<LoadSDNode>(A) && isa<LoadSDNode>(B)) {
auto *LDA = cast<LoadSDNode>(A);
auto *LDB = cast<LoadSDNode>(B);
if (LDA->getExtensionType() == ISD::ZEXTLOAD &&
LDB->getExtensionType() == ISD::ZEXTLOAD) {
A = DAG.getLoad(LDA->getMemoryVT(), SDLoc(LDA), LDA->getChain(),
LDA->getBasePtr(), LDA->getMemOperand());
B = DAG.getLoad(LDB->getMemoryVT(), SDLoc(LDB), LDB->getChain(),
LDB->getBasePtr(), LDB->getMemOperand());
return true;
}
}
#endif
return false;
}
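
// Replace an integer ADD whose operands are zero-extends of a narrower SVE
// vector with a UADDLB/UADDLT pair. A minimal sketch of the DAG this targets
// (value names are illustrative only):
//   %ea = zext <vscale x 16 x i8> %a to <vscale x 16 x i16>
//   %eb = zext <vscale x 16 x i8> %b to <vscale x 16 x i16>
//   %s  = add <vscale x 16 x i16> %ea, %eb
// The bottom/top long adds yield two deinterleaved <vscale x 8 x i16> halves
// that are zipped back together to recover the original lane order.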
static bool replaceIntOpWithSVE2LongOp(SDNode *N,
SmallVectorImpl<SDValue> &Results,
SelectionDAG &DAG,
const AArch64Subtarget *Subtarget) {
if (!Subtarget->hasSVE2orSME())
return false;

EVT VT = N->getValueType(0);
LLVMContext &Ctx = *DAG.getContext();
SDLoc DL(N);
SDValue LHS, RHS;
unsigned BotOpc, TopOpc;

auto CreateLongOpPair = [&](SDValue LHS,
SDValue RHS) -> std::pair<SDValue, SDValue> {
EVT WideResVT = LHS.getValueType()
.widenIntegerVectorElementType(Ctx)
.getHalfNumVectorElementsVT(Ctx);
SDValue Even = DAG.getNode(BotOpc, DL, WideResVT, LHS, RHS);
SDValue Odd = DAG.getNode(TopOpc, DL, WideResVT, LHS, RHS);
return std::make_pair(Even, Odd);
};
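// For example, with nxv16i8 operands WideResVT is nxv8i16: Even holds the
// widened sums of the even-numbered lanes (UADDLB) and Odd those of the
// odd-numbered lanes (UADDLT).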

bool MatchedLongOp = matchUADDLOps(N, DAG, LHS, RHS, BotOpc, TopOpc);
// Should also work for similar long instructions.
// if (!MatchedLongOp) MatchedLongOp = match<OtherLongInstr>Ops(...);
if (!MatchedLongOp || LHS.getValueType() != RHS.getValueType())
return false;
EVT UnExtVT = LHS.getValueType();

// 128-bit unextended operands.
if (UnExtVT == MVT::nxv16i8 || UnExtVT == MVT::nxv8i16 ||
UnExtVT == MVT::nxv4i32) {
auto [Even, Odd] = CreateLongOpPair(LHS, RHS);
EVT WideResVT = Even.getValueType();
// The widening operations deinterleave the results. Shuffle them back into
// their natural order.
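// E.g. for nxv16i8 inputs: Even = <a0+b0, a2+b2, ...>, Odd = <a1+b1, a3+b3, ...>,
// and interleaving restores <a0+b0, a1+b1, a2+b2, ...> across the two halves.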
SDValue Interleave =
DAG.getNode(ISD::VECTOR_INTERLEAVE, DL,
DAG.getVTList(WideResVT, WideResVT), Even, Odd);
SDValue Concat = DAG.getNode(
ISD::CONCAT_VECTORS, DL, WideResVT.getDoubleNumVectorElementsVT(Ctx),
Interleave.getValue(0), Interleave.getValue(1));
Results.push_back(DAG.getZExtOrTrunc(Concat, DL, VT));
return true;
}

// 256-bit unextended operands. Try to reduce the number of shuffles in
// cases where the operands are themselves interleaves of existing
// even/odd pairs.
if (UnExtVT == MVT::nxv16i16 || UnExtVT == MVT::nxv8i32) {
// For the pattern:
// (LHSBot, LHSTop) = vector_interleave(LHSEven, LHSOdd)
// (RHSBot, RHSTop) = vector_interleave(RHSEven, RHSOdd)
// LHS = concat(LHSBot, LHSTop)
// RHS = concat(RHSBot, RHSTop)
// op(zext(LHS), zext(RHS))
// We can use the pre-interleaved operands to create the longOp(b|t) and
// push the shuffles across the operation.
SDValue LHSBot, LHSTop, RHSBot, RHSTop;
SDValue LHSEven, LHSOdd, RHSEven, RHSOdd;

if (!sd_match(LHS, sd::m_Node(ISD::CONCAT_VECTORS, sd::m_Value(LHSBot),
sd::m_Value(LHSTop))))
return false;
if (LHSTop.getNode() != LHSBot.getNode() || LHSTop == LHSBot ||
!sd_match(LHSBot.getNode(),
sd::m_Node(ISD::VECTOR_INTERLEAVE, sd::m_Value(LHSEven),
sd::m_Value(LHSOdd))))
return false;

if (!sd_match(RHS, sd::m_Node(ISD::CONCAT_VECTORS, sd::m_Value(RHSBot),
sd::m_Value(RHSTop))))
return false;
if (RHSTop.getNode() != RHSBot.getNode() || RHSTop == RHSBot ||
!sd_match(RHSBot.getNode(),
sd::m_Node(ISD::VECTOR_INTERLEAVE, sd::m_Value(RHSEven),
sd::m_Value(RHSOdd))))
return false;

// Do the following:
// v0 = longOpb(LHSEven, RHSEven)
// v1 = longOpt(LHSEven, RHSEven)
// v2 = longOpb(LHSOdd, RHSOdd)
// v3 = longOpt(LHSOdd, RHSOdd)
// InterleaveEven = interleave(v0, v2)
// InterleaveOdd = interleave(v1, v3)
// concat(InterleaveEven[0], InterleaveOdd[0], InterleaveEven[1],
// InterleaveOdd[1])
auto [V0, V1] = CreateLongOpPair(LHSEven, RHSEven);
auto [V2, V3] = CreateLongOpPair(LHSOdd, RHSOdd);
EVT WideResVT = V0.getValueType();

SDValue InterleaveEven =
DAG.getNode(ISD::VECTOR_INTERLEAVE, DL,
DAG.getVTList(WideResVT, WideResVT), V0, V2);
SDValue InterleaveOdd =
DAG.getNode(ISD::VECTOR_INTERLEAVE, DL,
DAG.getVTList(WideResVT, WideResVT), V1, V3);

SDValue Concat0 = DAG.getNode(
ISD::CONCAT_VECTORS, DL, WideResVT.getDoubleNumVectorElementsVT(Ctx),
InterleaveEven.getValue(0), InterleaveOdd.getValue(0));
SDValue Concat1 = DAG.getNode(
ISD::CONCAT_VECTORS, DL, WideResVT.getDoubleNumVectorElementsVT(Ctx),
InterleaveEven.getValue(1), InterleaveOdd.getValue(1));
SDValue Concat =
DAG.getNode(ISD::CONCAT_VECTORS, DL,
Concat0.getValueType().getDoubleNumVectorElementsVT(Ctx),
Concat0, Concat1);
Results.push_back(DAG.getZExtOrTrunc(Concat, DL, VT));
return true;
}

if (UnExtVT == MVT::nxv16i32) {
// [LHS0, LHS2] = interleave(...)
// [LHS1, LHS3] = interleave(...)
// LHS = concat(concat(LHS0, LHS1), concat(LHS2, LHS3))
// See comments for 256-bit unextended operands to understand
// where this pattern comes from.
// Example:
// LHS = 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0
// LHS0 = 3, 2, 1, 0
// LHS1 = 7, 6, 5, 4
// LHS2 = 11, 10, 9, 8
// LHS3 = 15, 14, 13, 12
// After deinterleaving (i.e. the pre-interleaved values):
// LHS0 = 10, 8, 2, 0
// LHS1 = 14, 12, 6, 4
// LHS2 = 11, 9, 3, 1
// LHS3 = 15, 13, 7, 5

SDValue LHS0, LHS1, LHS2, LHS3;
SDValue RHS0, RHS1, RHS2, RHS3;
if (!sd_match(LHS,
sd::m_Node(ISD::CONCAT_VECTORS,
sd::m_Node(ISD::CONCAT_VECTORS, sd::m_Value(LHS0),
sd::m_Value(LHS1)),
sd::m_Node(ISD::CONCAT_VECTORS, sd::m_Value(LHS2),
sd::m_Value(LHS3)))))
return false;
if (!sd_match(RHS,
sd::m_Node(ISD::CONCAT_VECTORS,
sd::m_Node(ISD::CONCAT_VECTORS, sd::m_Value(RHS0),
sd::m_Value(RHS1)),
sd::m_Node(ISD::CONCAT_VECTORS, sd::m_Value(RHS2),
sd::m_Value(RHS3)))))
return false;

if (LHS0.getNode() != LHS2.getNode() || LHS0 == LHS2 ||
!sd_match(LHS0.getNode(),
sd::m_Node(ISD::VECTOR_INTERLEAVE, sd::m_Value(LHS0),
sd::m_Value(LHS2))))
return false;
if (LHS1.getNode() != LHS3.getNode() || LHS1 == LHS3 ||
!sd_match(LHS1.getNode(),
sd::m_Node(ISD::VECTOR_INTERLEAVE, sd::m_Value(LHS1),
sd::m_Value(LHS3))))
return false;

if (RHS0.getNode() != RHS2.getNode() || RHS0 == RHS2 ||
!sd_match(RHS0.getNode(),
sd::m_Node(ISD::VECTOR_INTERLEAVE, sd::m_Value(RHS0),
sd::m_Value(RHS2))))
return false;
if (RHS1.getNode() != RHS3.getNode() || RHS1 == RHS3 ||
!sd_match(RHS1.getNode(),
sd::m_Node(ISD::VECTOR_INTERLEAVE, sd::m_Value(RHS1),
sd::m_Value(RHS3))))
return false;

// After long operation:
// v0 = 8, 0
// v1 = 10, 2
//
// v2 = 12, 4
// v3 = 14, 6
//
// v4 = 9, 1
// v5 = 11, 3
//
// v6 = 13, 5
// v7 = 15, 7
auto [V0, V1] = CreateLongOpPair(LHS0, RHS0);
auto [V2, V3] = CreateLongOpPair(LHS1, RHS1);
auto [V4, V5] = CreateLongOpPair(LHS2, RHS2);
auto [V6, V7] = CreateLongOpPair(LHS3, RHS3);
EVT WideResVT = V0.getValueType();

// Now we can interleave and concat:
// i0 = interleave(v0, v4) ; i0 = [(1, 0), (9, 8)]
// i1 = interleave(v1, v5) ; i1 = [(3, 2), (11, 10)]
// i2 = interleave(v2, v6) ; i2 = [(5, 4), (13, 12)]
// i3 = interleave(v3, v7) ; i3 = [(7, 6), (15, 14)]
// res = concat(i0[0], i1[0]...i0[1], i1[1]...)
SDValue Interleave0 =
DAG.getNode(ISD::VECTOR_INTERLEAVE, DL,
DAG.getVTList(WideResVT, WideResVT), V0, V4);
SDValue Interleave1 =
DAG.getNode(ISD::VECTOR_INTERLEAVE, DL,
DAG.getVTList(WideResVT, WideResVT), V1, V5);
SDValue Interleave2 =
DAG.getNode(ISD::VECTOR_INTERLEAVE, DL,
DAG.getVTList(WideResVT, WideResVT), V2, V6);
SDValue Interleave3 =
DAG.getNode(ISD::VECTOR_INTERLEAVE, DL,
DAG.getVTList(WideResVT, WideResVT), V3, V7);

SDValue Concat0 = DAG.getNode(
ISD::CONCAT_VECTORS, DL, WideResVT.getDoubleNumVectorElementsVT(Ctx),
Interleave0.getValue(0), Interleave1.getValue(0));
SDValue Concat1 = DAG.getNode(
ISD::CONCAT_VECTORS, DL, WideResVT.getDoubleNumVectorElementsVT(Ctx),
Interleave2.getValue(0), Interleave3.getValue(0));
SDValue Concat2 = DAG.getNode(
ISD::CONCAT_VECTORS, DL, WideResVT.getDoubleNumVectorElementsVT(Ctx),
Interleave0.getValue(1), Interleave1.getValue(1));
SDValue Concat3 = DAG.getNode(
ISD::CONCAT_VECTORS, DL, WideResVT.getDoubleNumVectorElementsVT(Ctx),
Interleave2.getValue(1), Interleave3.getValue(1));
Concat0 =
DAG.getNode(ISD::CONCAT_VECTORS, DL,
Concat0.getValueType().getDoubleNumVectorElementsVT(Ctx),
Concat0, Concat1);
Concat2 =
DAG.getNode(ISD::CONCAT_VECTORS, DL,
Concat2.getValueType().getDoubleNumVectorElementsVT(Ctx),
Concat2, Concat3);
Concat0 =
DAG.getNode(ISD::CONCAT_VECTORS, DL,
Concat0.getValueType().getDoubleNumVectorElementsVT(Ctx),
Concat0, Concat2);

Results.push_back(DAG.getZExtOrTrunc(Concat0, DL, VT));
return true;
}

return false;
}

static void ReplaceAddWithADDP(SDNode *N, SmallVectorImpl<SDValue> &Results,
SelectionDAG &DAG,
const AArch64Subtarget *Subtarget) {
@@ -25429,6 +25718,9 @@ void AArch64TargetLowering::ReplaceNodeResults(
return;
case ISD::ADD:
case ISD::FADD:
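// The SVE2 long-op replacement only matches integer ISD::ADD, so FADD
// falls through to ReplaceAddWithADDP below.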
if (replaceIntOpWithSVE2LongOp(N, Results, DAG, Subtarget))
return;

ReplaceAddWithADDP(N, Results, DAG, Subtarget);
return;

3 changes: 3 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -220,6 +220,9 @@ enum NodeType : unsigned {
URSHR_I,
URSHR_I_PRED,

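// Unpredicated SVE2 add long, operating on the even-numbered (bottom) and
// odd-numbered (top) elements of the inputs.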
UADDLB,
UADDLT,

// Vector narrowing shift by immediate (bottom)
RSHRNB_I,

20 changes: 18 additions & 2 deletions llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -3556,6 +3556,22 @@ let Predicates = [HasSVE2orSME, UseExperimentalZeroingPseudos] in {
defm SQSHLU_ZPZI : sve_int_bin_pred_shift_imm_left_zeroing_bhsd<int_aarch64_sve_sqshlu>;
} // End HasSVE2orSME, UseExperimentalZeroingPseudos

def SDT_AArch64ArithLong_Unpred : SDTypeProfile<1, 2, [
SDTCisVec<0>, SDTCisVec<1>, SDTCisSameAs<1,2>,
SDTCisInt<0>, SDTCisInt<1>,
SDTCisOpSmallerThanOp<1, 0>
]>;
def AArch64uaddlb_node : SDNode<"AArch64ISD::UADDLB", SDT_AArch64ArithLong_Unpred>;
def AArch64uaddlt_node : SDNode<"AArch64ISD::UADDLT", SDT_AArch64ArithLong_Unpred>;

// TODO: Lower the intrinsic to the ISD node.
def AArch64uaddlb : PatFrags<(ops node:$op1, node:$op2),
[(int_aarch64_sve_uaddlb node:$op1, node:$op2),
(AArch64uaddlb_node node:$op1, node:$op2)]>;
def AArch64uaddlt : PatFrags<(ops node:$op1, node:$op2),
[(int_aarch64_sve_uaddlt node:$op1, node:$op2),
(AArch64uaddlt_node node:$op1, node:$op2)]>;
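// These PatFrags let the existing UADDLB_ZZZ/UADDLT_ZZZ patterns select
// either the SVE intrinsic or the new ISD nodes created during lowering.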

let Predicates = [HasSVE2orSME] in {
// SVE2 predicated shifts
defm SQSHL_ZPmI : sve_int_bin_pred_shift_imm_left_dup<0b0110, "sqshl", "SQSHL_ZPZI", int_aarch64_sve_sqshl>;
@@ -3567,8 +3583,8 @@ let Predicates = [HasSVE2orSME] in {
// SVE2 integer add/subtract long
defm SADDLB_ZZZ : sve2_wide_int_arith_long<0b00000, "saddlb", int_aarch64_sve_saddlb>;
defm SADDLT_ZZZ : sve2_wide_int_arith_long<0b00001, "saddlt", int_aarch64_sve_saddlt>;
defm UADDLB_ZZZ : sve2_wide_int_arith_long<0b00010, "uaddlb", int_aarch64_sve_uaddlb>;
defm UADDLT_ZZZ : sve2_wide_int_arith_long<0b00011, "uaddlt", int_aarch64_sve_uaddlt>;
defm UADDLB_ZZZ : sve2_wide_int_arith_long<0b00010, "uaddlb", AArch64uaddlb>;
defm UADDLT_ZZZ : sve2_wide_int_arith_long<0b00011, "uaddlt", AArch64uaddlt>;
defm SSUBLB_ZZZ : sve2_wide_int_arith_long<0b00100, "ssublb", int_aarch64_sve_ssublb>;
defm SSUBLT_ZZZ : sve2_wide_int_arith_long<0b00101, "ssublt", int_aarch64_sve_ssublt>;
defm USUBLB_ZZZ : sve2_wide_int_arith_long<0b00110, "usublb", int_aarch64_sve_usublb>;