Skip to content

[SelectionDAG] Add an ISD node for vector.extract.last.active #118810

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions llvm/include/llvm/CodeGen/ISDOpcodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -1480,6 +1480,10 @@ enum NodeType {
// Output: Output Chain
EXPERIMENTAL_VECTOR_HISTOGRAM,

// Finds the index of the last active mask element
// Operands: Mask
VECTOR_FIND_LAST_ACTIVE,

// llvm.clear_cache intrinsic
// Operands: Input Chain, Start Addres, End Address
// Outputs: Output Chain
Expand Down
5 changes: 5 additions & 0 deletions llvm/include/llvm/CodeGen/TargetLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -5368,6 +5368,11 @@ class TargetLowering : public TargetLoweringBase {
/// \returns The expansion result or SDValue() if it fails.
SDValue expandVPCTTZElements(SDNode *N, SelectionDAG &DAG) const;

/// Expand VECTOR_FIND_LAST_ACTIVE nodes
/// \param N Node to expand
/// \returns The expansion result or SDValue() if it fails.
SDValue expandVectorFindLastActive(SDNode *N, SelectionDAG &DAG) const;

/// Expand ABS nodes. Expands vector/scalar ABS nodes,
/// vector nodes can only succeed if all operations are legal/custom.
/// (ABS x) -> (XOR (ADD x, (SRA x, type_size)), (SRA x, type_size))
Expand Down
20 changes: 20 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,10 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
case ISD::ZERO_EXTEND_VECTOR_INREG:
Res = PromoteIntRes_EXTEND_VECTOR_INREG(N); break;

case ISD::VECTOR_FIND_LAST_ACTIVE:
Res = PromoteIntRes_VECTOR_FIND_LAST_ACTIVE(N);
break;

case ISD::SIGN_EXTEND:
case ISD::VP_SIGN_EXTEND:
case ISD::ZERO_EXTEND:
Expand Down Expand Up @@ -2069,6 +2073,9 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
Res = PromoteIntOp_VECTOR_HISTOGRAM(N, OpNo);
break;
case ISD::VECTOR_FIND_LAST_ACTIVE:
Res = PromoteIntOp_VECTOR_FIND_LAST_ACTIVE(N, OpNo);
break;
}

// If the result is null, the sub-method took care of registering results etc.
Expand Down Expand Up @@ -2810,6 +2817,13 @@ SDValue DAGTypeLegalizer::PromoteIntOp_VECTOR_HISTOGRAM(SDNode *N,
return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
}

SDValue DAGTypeLegalizer::PromoteIntOp_VECTOR_FIND_LAST_ACTIVE(SDNode *N,
unsigned OpNo) {
SmallVector<SDValue, 1> NewOps(N->ops());
NewOps[OpNo] = GetPromotedInteger(N->getOperand(OpNo));
return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
}

//===----------------------------------------------------------------------===//
// Integer Result Expansion
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -6124,6 +6138,12 @@ SDValue DAGTypeLegalizer::PromoteIntRes_EXTEND_VECTOR_INREG(SDNode *N) {
return DAG.getNode(N->getOpcode(), dl, NVT, N->getOperand(0));
}

SDValue DAGTypeLegalizer::PromoteIntRes_VECTOR_FIND_LAST_ACTIVE(SDNode *N) {
EVT VT = N->getValueType(0);
EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
return DAG.getNode(ISD::VECTOR_FIND_LAST_ACTIVE, SDLoc(N), NVT, N->ops());
}

SDValue DAGTypeLegalizer::PromoteIntRes_INSERT_VECTOR_ELT(SDNode *N) {
EVT OutVT = N->getValueType(0);
EVT NOutVT = TLI.getTypeToTransformTo(*DAG.getContext(), OutVT);
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue PromoteIntRes_VPFunnelShift(SDNode *N);
SDValue PromoteIntRes_IS_FPCLASS(SDNode *N);
SDValue PromoteIntRes_PATCHPOINT(SDNode *N);
SDValue PromoteIntRes_VECTOR_FIND_LAST_ACTIVE(SDNode *N);

// Integer Operand Promotion.
bool PromoteIntegerOperand(SDNode *N, unsigned OpNo);
Expand Down Expand Up @@ -428,6 +429,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue PromoteIntOp_VP_STRIDED(SDNode *N, unsigned OpNo);
SDValue PromoteIntOp_VP_SPLICE(SDNode *N, unsigned OpNo);
SDValue PromoteIntOp_VECTOR_HISTOGRAM(SDNode *N, unsigned OpNo);
SDValue PromoteIntOp_VECTOR_FIND_LAST_ACTIVE(SDNode *N, unsigned OpNo);

void SExtOrZExtPromotedOperands(SDValue &LHS, SDValue &RHS);
void PromoteSetCCOperands(SDValue &LHS,SDValue &RHS, ISD::CondCode Code);
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,7 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
case ISD::VECREDUCE_FMIN:
case ISD::VECREDUCE_FMAXIMUM:
case ISD::VECREDUCE_FMINIMUM:
case ISD::VECTOR_FIND_LAST_ACTIVE:
Action = TLI.getOperationAction(Node->getOpcode(),
Node->getOperand(0).getValueType());
break;
Expand Down Expand Up @@ -1208,6 +1209,9 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
case ISD::VECTOR_COMPRESS:
Results.push_back(TLI.expandVECTOR_COMPRESS(Node, DAG));
return;
case ISD::VECTOR_FIND_LAST_ACTIVE:
Results.push_back(TLI.expandVectorFindLastActive(Node, DAG));
return;
case ISD::SCMP:
case ISD::UCMP:
Results.push_back(TLI.expandCMP(Node, DAG));
Expand Down
47 changes: 15 additions & 32 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6427,42 +6427,25 @@ void SelectionDAGBuilder::visitVectorExtractLastActive(const CallInst &I,
assert(Intrinsic == Intrinsic::experimental_vector_extract_last_active &&
"Tried lowering invalid vector extract last");
SDLoc sdl = getCurSDLoc();
const DataLayout &Layout = DAG.getDataLayout();
SDValue Data = getValue(I.getOperand(0));
SDValue Mask = getValue(I.getOperand(1));
SDValue PassThru = getValue(I.getOperand(2));

EVT DataVT = Data.getValueType();
EVT ScalarVT = PassThru.getValueType();
EVT BoolVT = Mask.getValueType().getScalarType();

// Find a suitable type for a stepvector.
ConstantRange VScaleRange(1, /*isFullSet=*/true); // Dummy value.
if (DataVT.isScalableVector())
VScaleRange = getVScaleRange(I.getCaller(), 64);
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
unsigned EltWidth = TLI.getBitWidthForCttzElements(
I.getType(), DataVT.getVectorElementCount(), /*ZeroIsPoison=*/true,
&VScaleRange);
MVT StepVT = MVT::getIntegerVT(EltWidth);
EVT StepVecVT = DataVT.changeVectorElementType(StepVT);

// Zero out lanes with inactive elements, then find the highest remaining
// value from the stepvector.
SDValue Zeroes = DAG.getConstant(0, sdl, StepVecVT);
SDValue StepVec = DAG.getStepVector(sdl, StepVecVT);
SDValue ActiveElts = DAG.getSelect(sdl, StepVecVT, Mask, StepVec, Zeroes);
SDValue HighestIdx =
DAG.getNode(ISD::VECREDUCE_UMAX, sdl, StepVT, ActiveElts);

// Extract the corresponding lane from the data vector
EVT ExtVT = TLI.getVectorIdxTy(DAG.getDataLayout());
SDValue Idx = DAG.getZExtOrTrunc(HighestIdx, sdl, ExtVT);
SDValue Extract =
DAG.getNode(ISD::EXTRACT_VECTOR_ELT, sdl, ScalarVT, Data, Idx);

// If all mask lanes were inactive, choose the passthru value instead.
SDValue AnyActive = DAG.getNode(ISD::VECREDUCE_OR, sdl, BoolVT, Mask);
SDValue Result = DAG.getSelect(sdl, ScalarVT, AnyActive, Extract, PassThru);
EVT ResVT = TLI.getValueType(Layout, I.getType());

EVT ExtVT = TLI.getVectorIdxTy(Layout);
SDValue Idx = DAG.getNode(ISD::VECTOR_FIND_LAST_ACTIVE, sdl, ExtVT, Mask);
SDValue Result = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, sdl, ResVT, Data, Idx);

Value *Default = I.getOperand(2);
if (!isa<PoisonValue>(Default) && !isa<UndefValue>(Default)) {
SDValue PassThru = getValue(Default);
EVT BoolVT = Mask.getValueType().getScalarType();
SDValue AnyActive = DAG.getNode(ISD::VECREDUCE_OR, sdl, BoolVT, Mask);
Result = DAG.getSelect(sdl, ResVT, AnyActive, Result, PassThru);
}

setValue(&I, Result);
}

Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,9 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
return "histogram";

case ISD::VECTOR_FIND_LAST_ACTIVE:
return "find_last_active";

// Vector Predication
#define BEGIN_REGISTER_VP_SDNODE(SDID, LEGALARG, NAME, ...) \
case ISD::SDID: \
Expand Down
38 changes: 38 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/CallingConvLower.h"
#include "llvm/CodeGen/CodeGenCommonISel.h"
Expand Down Expand Up @@ -9453,6 +9454,43 @@ SDValue TargetLowering::expandVPCTTZElements(SDNode *N,
return DAG.getNode(ISD::VP_REDUCE_UMIN, DL, ResVT, ExtEVL, Select, Mask, EVL);
}

SDValue TargetLowering::expandVectorFindLastActive(SDNode *N,
SelectionDAG &DAG) const {
SDLoc DL(N);
SDValue Mask = N->getOperand(0);
EVT MaskVT = Mask.getValueType();
EVT BoolVT = MaskVT.getScalarType();

// Find a suitable type for a stepvector.
ConstantRange VScaleRange(1, /*isFullSet=*/true); // Fixed length default.
if (MaskVT.isScalableVector())
VScaleRange = getVScaleRange(&DAG.getMachineFunction().getFunction(), 64);
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
unsigned EltWidth = TLI.getBitWidthForCttzElements(
BoolVT.getTypeForEVT(*DAG.getContext()), MaskVT.getVectorElementCount(),
/*ZeroIsPoison=*/true, &VScaleRange);
EVT StepVT = MVT::getIntegerVT(EltWidth);
EVT StepVecVT = MaskVT.changeVectorElementType(StepVT);

// If promotion is required to make the type legal, do it here; promotion
// of integers within LegalizeVectorOps is looking for types of the same
// size but with a smaller number of larger elements, not the usual larger
// size with the same number of larger elements.
if (TLI.getTypeAction(StepVecVT.getSimpleVT()) ==
TargetLowering::TypePromoteInteger) {
StepVecVT = TLI.getTypeToTransformTo(*DAG.getContext(), StepVecVT);
StepVT = StepVecVT.getVectorElementType();
}

// Zero out lanes with inactive elements, then find the highest remaining
// value from the stepvector.
SDValue Zeroes = DAG.getConstant(0, DL, StepVecVT);
SDValue StepVec = DAG.getStepVector(DL, StepVecVT);
SDValue ActiveElts = DAG.getSelect(DL, StepVecVT, Mask, StepVec, Zeroes);
SDValue HighestIdx = DAG.getNode(ISD::VECREDUCE_UMAX, DL, StepVT, ActiveElts);
return DAG.getZExtOrTrunc(HighestIdx, DL, N->getValueType(0));
}

SDValue TargetLowering::expandABS(SDNode *N, SelectionDAG &DAG,
bool IsNegative) const {
SDLoc dl(N);
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/CodeGen/TargetLoweringBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -818,6 +818,9 @@ void TargetLoweringBase::initActions() {
setOperationAction(ISD::SDOPC, VT, Expand);
#include "llvm/IR/VPIntrinsics.def"

// Masked vector extracts default to expand.
setOperationAction(ISD::VECTOR_FIND_LAST_ACTIVE, VT, Expand);

// FP environment operations default to expand.
setOperationAction(ISD::GET_FPENV, VT, Expand);
setOperationAction(ISD::SET_FPENV, VT, Expand);
Expand Down
26 changes: 21 additions & 5 deletions llvm/test/CodeGen/AArch64/vector-extract-last-active.ll
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ define i16 @extract_last_i16_scalable(<vscale x 8 x i16> %data, <vscale x 8 x i1
; CHECK-NEXT: sel z1.h, p0, z1.h, z2.h
; CHECK-NEXT: umaxv h1, p1, z1.h
; CHECK-NEXT: fmov w8, s1
; CHECK-NEXT: and x8, x8, #0xff
; CHECK-NEXT: and x8, x8, #0xffff
; CHECK-NEXT: whilels p2.h, xzr, x8
; CHECK-NEXT: ptest p1, p0.b
; CHECK-NEXT: lastb w8, p2, z0.h
Expand All @@ -337,7 +337,7 @@ define i32 @extract_last_i32_scalable(<vscale x 4 x i32> %data, <vscale x 4 x i1
; CHECK-NEXT: sel z1.s, p0, z1.s, z2.s
; CHECK-NEXT: umaxv s1, p1, z1.s
; CHECK-NEXT: fmov w8, s1
; CHECK-NEXT: and x8, x8, #0xff
; CHECK-NEXT: mov w8, w8
; CHECK-NEXT: whilels p2.s, xzr, x8
; CHECK-NEXT: ptest p1, p0.b
; CHECK-NEXT: lastb w8, p2, z0.s
Expand All @@ -356,7 +356,6 @@ define i64 @extract_last_i64_scalable(<vscale x 2 x i64> %data, <vscale x 2 x i1
; CHECK-NEXT: sel z1.d, p0, z1.d, z2.d
; CHECK-NEXT: umaxv d1, p1, z1.d
; CHECK-NEXT: fmov x8, d1
; CHECK-NEXT: and x8, x8, #0xff
; CHECK-NEXT: whilels p2.d, xzr, x8
; CHECK-NEXT: ptest p1, p0.b
; CHECK-NEXT: lastb x8, p2, z0.d
Expand All @@ -375,7 +374,7 @@ define float @extract_last_float_scalable(<vscale x 4 x float> %data, <vscale x
; CHECK-NEXT: sel z2.s, p0, z2.s, z3.s
; CHECK-NEXT: umaxv s2, p1, z2.s
; CHECK-NEXT: fmov w8, s2
; CHECK-NEXT: and x8, x8, #0xff
; CHECK-NEXT: mov w8, w8
; CHECK-NEXT: whilels p2.s, xzr, x8
; CHECK-NEXT: ptest p1, p0.b
; CHECK-NEXT: lastb s0, p2, z0.s
Expand All @@ -394,7 +393,6 @@ define double @extract_last_double_scalable(<vscale x 2 x double> %data, <vscale
; CHECK-NEXT: sel z2.d, p0, z2.d, z3.d
; CHECK-NEXT: umaxv d2, p1, z2.d
; CHECK-NEXT: fmov x8, d2
; CHECK-NEXT: and x8, x8, #0xff
; CHECK-NEXT: whilels p2.d, xzr, x8
; CHECK-NEXT: ptest p1, p0.b
; CHECK-NEXT: lastb d0, p2, z0.d
Expand All @@ -404,6 +402,24 @@ define double @extract_last_double_scalable(<vscale x 2 x double> %data, <vscale
ret double %res
}

;; If the passthru parameter is poison, we shouldn't see a select at the end.
define i8 @extract_last_i8_scalable_poison_passthru(<vscale x 16 x i8> %data, <vscale x 16 x i1> %mask) #0 {
; CHECK-LABEL: extract_last_i8_scalable_poison_passthru:
; CHECK: // %bb.0:
; CHECK-NEXT: index z1.b, #0, #1
; CHECK-NEXT: mov z2.b, #0 // =0x0
; CHECK-NEXT: sel z1.b, p0, z1.b, z2.b
; CHECK-NEXT: ptrue p0.b
; CHECK-NEXT: umaxv b1, p0, z1.b
; CHECK-NEXT: fmov w8, s1
; CHECK-NEXT: and x8, x8, #0xff
; CHECK-NEXT: whilels p0.b, xzr, x8
; CHECK-NEXT: lastb w0, p0, z0.b
; CHECK-NEXT: ret
%res = call i8 @llvm.experimental.vector.extract.last.active.nxv16i8(<vscale x 16 x i8> %data, <vscale x 16 x i1> %mask, i8 poison)
ret i8 %res
}

declare i8 @llvm.experimental.vector.extract.last.active.v16i8(<16 x i8>, <16 x i1>, i8)
declare i16 @llvm.experimental.vector.extract.last.active.v8i16(<8 x i16>, <8 x i1>, i16)
declare i32 @llvm.experimental.vector.extract.last.active.v4i32(<4 x i32>, <4 x i1>, i32)
Expand Down
Loading
Loading