Skip to content

[IR][RISCV] Add llvm.vector.(de)interleave3/5/7 #124825

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions llvm/include/llvm/IR/DerivedTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,15 @@ class VectorType : public Type {
EltCnt.divideCoefficientBy(2));
}

static VectorType *getOneNthElementsVectorType(VectorType *VTy,
unsigned Denominator) {
auto EltCnt = VTy->getElementCount();
assert(EltCnt.isKnownMultipleOf(Denominator) &&
"Cannot take one-nth of a vector");
return VectorType::get(VTy->getScalarType(),
EltCnt.divideCoefficientBy(Denominator));
}

/// This static method returns a VectorType with twice as many elements as the
/// input type and the same element type.
static VectorType *getDoubleElementsVectorType(VectorType *VTy) {
Expand Down
16 changes: 12 additions & 4 deletions llvm/include/llvm/IR/Intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,9 @@ namespace Intrinsic {
ExtendArgument,
TruncArgument,
HalfVecArgument,
OneThirdVecArgument,
OneFifthVecArgument,
OneSeventhVecArgument,
SameVecWidthArgument,
VecOfAnyPtrsToElt,
VecElementArgument,
Expand All @@ -159,6 +162,9 @@ namespace Intrinsic {
AArch64Svcount,
} Kind;

// These three have to be contiguous.
static_assert(OneFifthVecArgument == OneThirdVecArgument + 1 &&
OneSeventhVecArgument == OneFifthVecArgument + 1);
union {
unsigned Integer_Width;
unsigned Float_Width;
Expand All @@ -178,15 +184,17 @@ namespace Intrinsic {
unsigned getArgumentNumber() const {
assert(Kind == Argument || Kind == ExtendArgument ||
Kind == TruncArgument || Kind == HalfVecArgument ||
Kind == SameVecWidthArgument || Kind == VecElementArgument ||
Kind == Subdivide2Argument || Kind == Subdivide4Argument ||
Kind == VecOfBitcastsToInt);
Kind == OneThirdVecArgument || Kind == OneFifthVecArgument ||
Kind == OneSeventhVecArgument || Kind == SameVecWidthArgument ||
Kind == VecElementArgument || Kind == Subdivide2Argument ||
Kind == Subdivide4Argument || Kind == VecOfBitcastsToInt);
return Argument_Info >> 3;
}
ArgKind getArgumentKind() const {
assert(Kind == Argument || Kind == ExtendArgument ||
Kind == TruncArgument || Kind == HalfVecArgument ||
Kind == SameVecWidthArgument ||
Kind == OneThirdVecArgument || Kind == OneFifthVecArgument ||
Kind == OneSeventhVecArgument || Kind == SameVecWidthArgument ||
Kind == VecElementArgument || Kind == Subdivide2Argument ||
Kind == Subdivide4Argument || Kind == VecOfBitcastsToInt);
return (ArgKind)(Argument_Info & 7);
Expand Down
60 changes: 60 additions & 0 deletions llvm/include/llvm/IR/Intrinsics.td
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,9 @@ def IIT_I4 : IIT_Int<4, 58>;
def IIT_AARCH64_SVCOUNT : IIT_VT<aarch64svcount, 59>;
def IIT_V6 : IIT_Vec<6, 60>;
def IIT_V10 : IIT_Vec<10, 61>;
def IIT_ONE_THIRD_VEC_ARG : IIT_Base<62>;
def IIT_ONE_FIFTH_VEC_ARG : IIT_Base<63>;
def IIT_ONE_SEVENTH_VEC_ARG : IIT_Base<64>;
}

defvar IIT_all_FixedTypes = !filter(iit, IIT_all,
Expand Down Expand Up @@ -467,6 +470,15 @@ class LLVMVectorElementType<int num> : LLVMMatchType<num, IIT_VEC_ELEMENT>;
class LLVMHalfElementsVectorType<int num>
: LLVMMatchType<num, IIT_HALF_VEC_ARG>;

class LLVMOneThirdElementsVectorType<int num>
: LLVMMatchType<num, IIT_ONE_THIRD_VEC_ARG>;

class LLVMOneFifthElementsVectorType<int num>
: LLVMMatchType<num, IIT_ONE_FIFTH_VEC_ARG>;

class LLVMOneSeventhElementsVectorType<int num>
: LLVMMatchType<num, IIT_ONE_SEVENTH_VEC_ARG>;

// Match the type of another intrinsic parameter that is expected to be a
// vector type (i.e. <N x iM>) but with each element subdivided to
// form a vector with more elements that are smaller than the original.
Expand Down Expand Up @@ -2728,6 +2740,54 @@ def int_vector_deinterleave2 : DefaultAttrsIntrinsic<[LLVMHalfElementsVectorType
[llvm_anyvector_ty],
[IntrNoMem]>;

def int_vector_interleave3 : DefaultAttrsIntrinsic<[llvm_anyvector_ty],
[LLVMOneThirdElementsVectorType<0>,
LLVMOneThirdElementsVectorType<0>,
LLVMOneThirdElementsVectorType<0>],
[IntrNoMem]>;

def int_vector_deinterleave3 : DefaultAttrsIntrinsic<[LLVMOneThirdElementsVectorType<0>,
LLVMOneThirdElementsVectorType<0>,
LLVMOneThirdElementsVectorType<0>],
[llvm_anyvector_ty],
[IntrNoMem]>;

def int_vector_interleave5 : DefaultAttrsIntrinsic<[llvm_anyvector_ty],
[LLVMOneFifthElementsVectorType<0>,
LLVMOneFifthElementsVectorType<0>,
LLVMOneFifthElementsVectorType<0>,
LLVMOneFifthElementsVectorType<0>,
LLVMOneFifthElementsVectorType<0>],
[IntrNoMem]>;

def int_vector_deinterleave5 : DefaultAttrsIntrinsic<[LLVMOneFifthElementsVectorType<0>,
LLVMOneFifthElementsVectorType<0>,
LLVMOneFifthElementsVectorType<0>,
LLVMOneFifthElementsVectorType<0>,
LLVMOneFifthElementsVectorType<0>],
[llvm_anyvector_ty],
[IntrNoMem]>;

def int_vector_interleave7 : DefaultAttrsIntrinsic<[llvm_anyvector_ty],
[LLVMOneSeventhElementsVectorType<0>,
LLVMOneSeventhElementsVectorType<0>,
LLVMOneSeventhElementsVectorType<0>,
LLVMOneSeventhElementsVectorType<0>,
LLVMOneSeventhElementsVectorType<0>,
LLVMOneSeventhElementsVectorType<0>,
LLVMOneSeventhElementsVectorType<0>],
[IntrNoMem]>;

def int_vector_deinterleave7 : DefaultAttrsIntrinsic<[LLVMOneSeventhElementsVectorType<0>,
LLVMOneSeventhElementsVectorType<0>,
LLVMOneSeventhElementsVectorType<0>,
LLVMOneSeventhElementsVectorType<0>,
LLVMOneSeventhElementsVectorType<0>,
LLVMOneSeventhElementsVectorType<0>,
LLVMOneSeventhElementsVectorType<0>],
[llvm_anyvector_ty],
[IntrNoMem]>;

//===-------------- Intrinsics to perform partial reduction ---------------===//

def int_experimental_vector_partial_reduce_add : DefaultAttrsIntrinsic<[LLVMMatchType<0>],
Expand Down
20 changes: 12 additions & 8 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5825,15 +5825,19 @@ SDValue DAGTypeLegalizer::PromoteIntRes_VECTOR_SPLICE(SDNode *N) {
}

SDValue DAGTypeLegalizer::PromoteIntRes_VECTOR_INTERLEAVE_DEINTERLEAVE(SDNode *N) {
SDLoc dl(N);
SDLoc DL(N);
unsigned Factor = N->getNumOperands();

SmallVector<SDValue, 8> Ops(Factor);
for (unsigned i = 0; i != Factor; i++)
Ops[i] = GetPromotedInteger(N->getOperand(i));

SmallVector<EVT, 8> ResVTs(Factor, Ops[0].getValueType());
SDValue Res = DAG.getNode(N->getOpcode(), DL, DAG.getVTList(ResVTs), Ops);

for (unsigned i = 0; i != Factor; i++)
SetPromotedInteger(SDValue(N, i), Res.getValue(i));

SDValue V0 = GetPromotedInteger(N->getOperand(0));
SDValue V1 = GetPromotedInteger(N->getOperand(1));
EVT ResVT = V0.getValueType();
SDValue Res = DAG.getNode(N->getOpcode(), dl,
DAG.getVTList(ResVT, ResVT), V0, V1);
SetPromotedInteger(SDValue(N, 0), Res.getValue(0));
SetPromotedInteger(SDValue(N, 1), Res.getValue(1));
return SDValue();
}

Expand Down
68 changes: 48 additions & 20 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1668,6 +1668,15 @@ void DAGTypeLegalizer::SplitVecRes_INSERT_SUBVECTOR(SDNode *N, SDValue &Lo,
return;
}

if (getTypeAction(SubVecVT) == TargetLowering::TypeWidenVector &&
Vec.isUndef() && SubVecVT.getVectorElementType() == MVT::i1) {
SDValue WideSubVec = GetWidenedVector(SubVec);
if (WideSubVec.getValueType() == VecVT) {
std::tie(Lo, Hi) = DAG.SplitVector(WideSubVec, SDLoc(WideSubVec));
return;
}
}

// Spill the vector to the stack.
// In cases where the vector is illegal it will be broken down into parts
// and stored in parts - we should use the alignment for the smallest part.
Expand Down Expand Up @@ -3183,34 +3192,53 @@ void DAGTypeLegalizer::SplitVecRes_VP_REVERSE(SDNode *N, SDValue &Lo,
}

void DAGTypeLegalizer::SplitVecRes_VECTOR_DEINTERLEAVE(SDNode *N) {
unsigned Factor = N->getNumOperands();

SmallVector<SDValue, 8> Ops(Factor * 2);
for (unsigned i = 0; i != Factor; ++i) {
SDValue OpLo, OpHi;
GetSplitVector(N->getOperand(i), OpLo, OpHi);
Ops[i * 2] = OpLo;
Ops[i * 2 + 1] = OpHi;
}

SmallVector<EVT, 8> VTs(Factor, Ops[0].getValueType());

SDValue Op0Lo, Op0Hi, Op1Lo, Op1Hi;
GetSplitVector(N->getOperand(0), Op0Lo, Op0Hi);
GetSplitVector(N->getOperand(1), Op1Lo, Op1Hi);
EVT VT = Op0Lo.getValueType();
SDLoc DL(N);
SDValue ResLo = DAG.getNode(ISD::VECTOR_DEINTERLEAVE, DL,
DAG.getVTList(VT, VT), Op0Lo, Op0Hi);
SDValue ResHi = DAG.getNode(ISD::VECTOR_DEINTERLEAVE, DL,
DAG.getVTList(VT, VT), Op1Lo, Op1Hi);
SDValue ResLo = DAG.getNode(ISD::VECTOR_DEINTERLEAVE, DL, VTs,
ArrayRef(Ops).slice(0, Factor));
SDValue ResHi = DAG.getNode(ISD::VECTOR_DEINTERLEAVE, DL, VTs,
ArrayRef(Ops).slice(Factor, Factor));

SetSplitVector(SDValue(N, 0), ResLo.getValue(0), ResHi.getValue(0));
SetSplitVector(SDValue(N, 1), ResLo.getValue(1), ResHi.getValue(1));
for (unsigned i = 0; i != Factor; ++i)
SetSplitVector(SDValue(N, i), ResLo.getValue(i), ResHi.getValue(i));
}

void DAGTypeLegalizer::SplitVecRes_VECTOR_INTERLEAVE(SDNode *N) {
SDValue Op0Lo, Op0Hi, Op1Lo, Op1Hi;
GetSplitVector(N->getOperand(0), Op0Lo, Op0Hi);
GetSplitVector(N->getOperand(1), Op1Lo, Op1Hi);
EVT VT = Op0Lo.getValueType();
unsigned Factor = N->getNumOperands();

SmallVector<SDValue, 8> Ops(Factor * 2);
for (unsigned i = 0; i != Factor; ++i) {
SDValue OpLo, OpHi;
GetSplitVector(N->getOperand(i), OpLo, OpHi);
Ops[i] = OpLo;
Ops[i + Factor] = OpHi;
}

SmallVector<EVT, 8> VTs(Factor, Ops[0].getValueType());

SDLoc DL(N);
SDValue Res[] = {DAG.getNode(ISD::VECTOR_INTERLEAVE, DL,
DAG.getVTList(VT, VT), Op0Lo, Op1Lo),
DAG.getNode(ISD::VECTOR_INTERLEAVE, DL,
DAG.getVTList(VT, VT), Op0Hi, Op1Hi)};
SDValue Res[] = {DAG.getNode(ISD::VECTOR_INTERLEAVE, DL, VTs,
ArrayRef(Ops).slice(0, Factor)),
DAG.getNode(ISD::VECTOR_INTERLEAVE, DL, VTs,
ArrayRef(Ops).slice(Factor, Factor))};

SetSplitVector(SDValue(N, 0), Res[0].getValue(0), Res[0].getValue(1));
SetSplitVector(SDValue(N, 1), Res[1].getValue(0), Res[1].getValue(1));
for (unsigned i = 0; i != Factor; ++i) {
unsigned IdxLo = 2 * i;
unsigned IdxHi = 2 * i + 1;
SetSplitVector(SDValue(N, i), Res[IdxLo / Factor].getValue(IdxLo % Factor),
Res[IdxHi / Factor].getValue(IdxHi % Factor));
}
}

//===----------------------------------------------------------------------===//
Expand Down
90 changes: 62 additions & 28 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8251,10 +8251,28 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
visitCallBrLandingPad(I);
return;
case Intrinsic::vector_interleave2:
visitVectorInterleave(I);
visitVectorInterleave(I, 2);
return;
case Intrinsic::vector_interleave3:
visitVectorInterleave(I, 3);
return;
case Intrinsic::vector_interleave5:
visitVectorInterleave(I, 5);
return;
case Intrinsic::vector_interleave7:
visitVectorInterleave(I, 7);
return;
case Intrinsic::vector_deinterleave2:
visitVectorDeinterleave(I);
visitVectorDeinterleave(I, 2);
return;
case Intrinsic::vector_deinterleave3:
visitVectorDeinterleave(I, 3);
return;
case Intrinsic::vector_deinterleave5:
visitVectorDeinterleave(I, 5);
return;
case Intrinsic::vector_deinterleave7:
visitVectorDeinterleave(I, 7);
return;
case Intrinsic::experimental_vector_compress:
setValue(&I, DAG.getNode(ISD::VECTOR_COMPRESS, sdl,
Expand Down Expand Up @@ -12565,59 +12583,75 @@ void SelectionDAGBuilder::visitVectorReverse(const CallInst &I) {
setValue(&I, DAG.getVectorShuffle(VT, DL, V, DAG.getUNDEF(VT), Mask));
}

void SelectionDAGBuilder::visitVectorDeinterleave(const CallInst &I) {
void SelectionDAGBuilder::visitVectorDeinterleave(const CallInst &I,
unsigned Factor) {
auto DL = getCurSDLoc();
SDValue InVec = getValue(I.getOperand(0));
EVT OutVT =
InVec.getValueType().getHalfNumVectorElementsVT(*DAG.getContext());

SmallVector<EVT, 4> ValueVTs;
ComputeValueVTs(DAG.getTargetLoweringInfo(), DAG.getDataLayout(), I.getType(),
ValueVTs);

EVT OutVT = ValueVTs[0];
unsigned OutNumElts = OutVT.getVectorMinNumElements();

// ISD Node needs the input vectors split into two equal parts
SDValue Lo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, OutVT, InVec,
DAG.getVectorIdxConstant(0, DL));
SDValue Hi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, OutVT, InVec,
DAG.getVectorIdxConstant(OutNumElts, DL));
SmallVector<SDValue, 4> SubVecs(Factor);
for (unsigned i = 0; i != Factor; ++i) {
assert(ValueVTs[i] == OutVT && "Expected VTs to be the same");
SubVecs[i] = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, OutVT, InVec,
DAG.getVectorIdxConstant(OutNumElts * i, DL));
}

// Use VECTOR_SHUFFLE for fixed-length vectors to benefit from existing
// legalisation and combines.
if (OutVT.isFixedLengthVector()) {
SDValue Even = DAG.getVectorShuffle(OutVT, DL, Lo, Hi,
// Use VECTOR_SHUFFLE for fixed-length vectors with factor of 2 to benefit
// from existing legalisation and combines.
if (OutVT.isFixedLengthVector() && Factor == 2) {
SDValue Even = DAG.getVectorShuffle(OutVT, DL, SubVecs[0], SubVecs[1],
createStrideMask(0, 2, OutNumElts));
SDValue Odd = DAG.getVectorShuffle(OutVT, DL, Lo, Hi,
SDValue Odd = DAG.getVectorShuffle(OutVT, DL, SubVecs[0], SubVecs[1],
createStrideMask(1, 2, OutNumElts));
SDValue Res = DAG.getMergeValues({Even, Odd}, getCurSDLoc());
setValue(&I, Res);
return;
}

SDValue Res = DAG.getNode(ISD::VECTOR_DEINTERLEAVE, DL,
DAG.getVTList(OutVT, OutVT), Lo, Hi);
DAG.getVTList(ValueVTs), SubVecs);
setValue(&I, Res);
}

void SelectionDAGBuilder::visitVectorInterleave(const CallInst &I) {
void SelectionDAGBuilder::visitVectorInterleave(const CallInst &I,
unsigned Factor) {
auto DL = getCurSDLoc();
EVT InVT = getValue(I.getOperand(0)).getValueType();
SDValue InVec0 = getValue(I.getOperand(0));
SDValue InVec1 = getValue(I.getOperand(1));
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
EVT InVT = getValue(I.getOperand(0)).getValueType();
EVT OutVT = TLI.getValueType(DAG.getDataLayout(), I.getType());

// Use VECTOR_SHUFFLE for fixed-length vectors to benefit from existing
// legalisation and combines.
if (OutVT.isFixedLengthVector()) {
SmallVector<SDValue, 8> InVecs(Factor);
for (unsigned i = 0; i < Factor; ++i) {
InVecs[i] = getValue(I.getOperand(i));
assert(InVecs[i].getValueType() == InVecs[0].getValueType() &&
"Expected VTs to be the same");
}

// Use VECTOR_SHUFFLE for fixed-length vectors with factor of 2 to benefit
// from existing legalisation and combines.
if (OutVT.isFixedLengthVector() && Factor == 2) {
unsigned NumElts = InVT.getVectorMinNumElements();
SDValue V = DAG.getNode(ISD::CONCAT_VECTORS, DL, OutVT, InVec0, InVec1);
SDValue V = DAG.getNode(ISD::CONCAT_VECTORS, DL, OutVT, InVecs);
setValue(&I, DAG.getVectorShuffle(OutVT, DL, V, DAG.getUNDEF(OutVT),
createInterleaveMask(NumElts, 2)));
return;
}

SDValue Res = DAG.getNode(ISD::VECTOR_INTERLEAVE, DL,
DAG.getVTList(InVT, InVT), InVec0, InVec1);
Res = DAG.getNode(ISD::CONCAT_VECTORS, DL, OutVT, Res.getValue(0),
Res.getValue(1));
SmallVector<EVT, 8> ValueVTs(Factor, InVT);
SDValue Res =
DAG.getNode(ISD::VECTOR_INTERLEAVE, DL, DAG.getVTList(ValueVTs), InVecs);

SmallVector<SDValue, 8> Results(Factor);
for (unsigned i = 0; i < Factor; ++i)
Results[i] = Res.getValue(i);

Res = DAG.getNode(ISD::CONCAT_VECTORS, DL, OutVT, Results);
setValue(&I, Res);
}

Expand Down
Loading