Skip to content

Commit d9f165d

Browse files
authored
[SDAG] Add an ISD node to help lower vector.extract.last.active (#118810)
Based on feedback from the clastb codegen PR, I'm refactoring basic codegen for the vector.extract.last.active intrinsic to lower to an ISD node in SelectionDAGBuilder then expand in LegalizeVectorOps, instead of doing everything in the builder. The new ISD node (vector_find_last_active) only covers finding the index of the last active element of the mask, and extracting the element + handling passthru is left to existing ISD nodes.
1 parent 5ce271e commit d9f165d

File tree

11 files changed

+169
-141
lines changed

11 files changed

+169
-141
lines changed

llvm/include/llvm/CodeGen/ISDOpcodes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1480,6 +1480,10 @@ enum NodeType {
14801480
// Output: Output Chain
14811481
EXPERIMENTAL_VECTOR_HISTOGRAM,
14821482

1483+
// Finds the index of the last active mask element
1484+
// Operands: Mask
1485+
VECTOR_FIND_LAST_ACTIVE,
1486+
14831487
// llvm.clear_cache intrinsic
14841488
// Operands: Input Chain, Start Addres, End Address
14851489
// Outputs: Output Chain

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5368,6 +5368,11 @@ class TargetLowering : public TargetLoweringBase {
53685368
/// \returns The expansion result or SDValue() if it fails.
53695369
SDValue expandVPCTTZElements(SDNode *N, SelectionDAG &DAG) const;
53705370

5371+
/// Expand VECTOR_FIND_LAST_ACTIVE nodes
5372+
/// \param N Node to expand
5373+
/// \returns The expansion result or SDValue() if it fails.
5374+
SDValue expandVectorFindLastActive(SDNode *N, SelectionDAG &DAG) const;
5375+
53715376
/// Expand ABS nodes. Expands vector/scalar ABS nodes,
53725377
/// vector nodes can only succeed if all operations are legal/custom.
53735378
/// (ABS x) -> (XOR (ADD x, (SRA x, type_size)), (SRA x, type_size))

llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,10 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
155155
case ISD::ZERO_EXTEND_VECTOR_INREG:
156156
Res = PromoteIntRes_EXTEND_VECTOR_INREG(N); break;
157157

158+
case ISD::VECTOR_FIND_LAST_ACTIVE:
159+
Res = PromoteIntRes_VECTOR_FIND_LAST_ACTIVE(N);
160+
break;
161+
158162
case ISD::SIGN_EXTEND:
159163
case ISD::VP_SIGN_EXTEND:
160164
case ISD::ZERO_EXTEND:
@@ -2069,6 +2073,9 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
20692073
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
20702074
Res = PromoteIntOp_VECTOR_HISTOGRAM(N, OpNo);
20712075
break;
2076+
case ISD::VECTOR_FIND_LAST_ACTIVE:
2077+
Res = PromoteIntOp_VECTOR_FIND_LAST_ACTIVE(N, OpNo);
2078+
break;
20722079
}
20732080

20742081
// If the result is null, the sub-method took care of registering results etc.
@@ -2810,6 +2817,13 @@ SDValue DAGTypeLegalizer::PromoteIntOp_VECTOR_HISTOGRAM(SDNode *N,
28102817
return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
28112818
}
28122819

2820+
SDValue DAGTypeLegalizer::PromoteIntOp_VECTOR_FIND_LAST_ACTIVE(SDNode *N,
2821+
unsigned OpNo) {
2822+
SmallVector<SDValue, 1> NewOps(N->ops());
2823+
NewOps[OpNo] = GetPromotedInteger(N->getOperand(OpNo));
2824+
return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
2825+
}
2826+
28132827
//===----------------------------------------------------------------------===//
28142828
// Integer Result Expansion
28152829
//===----------------------------------------------------------------------===//
@@ -6120,6 +6134,12 @@ SDValue DAGTypeLegalizer::PromoteIntRes_EXTEND_VECTOR_INREG(SDNode *N) {
61206134
return DAG.getNode(N->getOpcode(), dl, NVT, N->getOperand(0));
61216135
}
61226136

6137+
SDValue DAGTypeLegalizer::PromoteIntRes_VECTOR_FIND_LAST_ACTIVE(SDNode *N) {
6138+
EVT VT = N->getValueType(0);
6139+
EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
6140+
return DAG.getNode(ISD::VECTOR_FIND_LAST_ACTIVE, SDLoc(N), NVT, N->ops());
6141+
}
6142+
61236143
SDValue DAGTypeLegalizer::PromoteIntRes_INSERT_VECTOR_ELT(SDNode *N) {
61246144
EVT OutVT = N->getValueType(0);
61256145
EVT NOutVT = TLI.getTypeToTransformTo(*DAG.getContext(), OutVT);

llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
378378
SDValue PromoteIntRes_VPFunnelShift(SDNode *N);
379379
SDValue PromoteIntRes_IS_FPCLASS(SDNode *N);
380380
SDValue PromoteIntRes_PATCHPOINT(SDNode *N);
381+
SDValue PromoteIntRes_VECTOR_FIND_LAST_ACTIVE(SDNode *N);
381382

382383
// Integer Operand Promotion.
383384
bool PromoteIntegerOperand(SDNode *N, unsigned OpNo);
@@ -428,6 +429,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
428429
SDValue PromoteIntOp_VP_STRIDED(SDNode *N, unsigned OpNo);
429430
SDValue PromoteIntOp_VP_SPLICE(SDNode *N, unsigned OpNo);
430431
SDValue PromoteIntOp_VECTOR_HISTOGRAM(SDNode *N, unsigned OpNo);
432+
SDValue PromoteIntOp_VECTOR_FIND_LAST_ACTIVE(SDNode *N, unsigned OpNo);
431433

432434
void SExtOrZExtPromotedOperands(SDValue &LHS, SDValue &RHS);
433435
void PromoteSetCCOperands(SDValue &LHS,SDValue &RHS, ISD::CondCode Code);

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,7 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
503503
case ISD::VECREDUCE_FMIN:
504504
case ISD::VECREDUCE_FMAXIMUM:
505505
case ISD::VECREDUCE_FMINIMUM:
506+
case ISD::VECTOR_FIND_LAST_ACTIVE:
506507
Action = TLI.getOperationAction(Node->getOpcode(),
507508
Node->getOperand(0).getValueType());
508509
break;
@@ -1225,6 +1226,9 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
12251226
case ISD::VECTOR_COMPRESS:
12261227
Results.push_back(TLI.expandVECTOR_COMPRESS(Node, DAG));
12271228
return;
1229+
case ISD::VECTOR_FIND_LAST_ACTIVE:
1230+
Results.push_back(TLI.expandVectorFindLastActive(Node, DAG));
1231+
return;
12281232
case ISD::SCMP:
12291233
case ISD::UCMP:
12301234
Results.push_back(TLI.expandCMP(Node, DAG));

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 15 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6427,42 +6427,25 @@ void SelectionDAGBuilder::visitVectorExtractLastActive(const CallInst &I,
64276427
assert(Intrinsic == Intrinsic::experimental_vector_extract_last_active &&
64286428
"Tried lowering invalid vector extract last");
64296429
SDLoc sdl = getCurSDLoc();
6430+
const DataLayout &Layout = DAG.getDataLayout();
64306431
SDValue Data = getValue(I.getOperand(0));
64316432
SDValue Mask = getValue(I.getOperand(1));
6432-
SDValue PassThru = getValue(I.getOperand(2));
64336433

6434-
EVT DataVT = Data.getValueType();
6435-
EVT ScalarVT = PassThru.getValueType();
6436-
EVT BoolVT = Mask.getValueType().getScalarType();
6437-
6438-
// Find a suitable type for a stepvector.
6439-
ConstantRange VScaleRange(1, /*isFullSet=*/true); // Dummy value.
6440-
if (DataVT.isScalableVector())
6441-
VScaleRange = getVScaleRange(I.getCaller(), 64);
64426434
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
6443-
unsigned EltWidth = TLI.getBitWidthForCttzElements(
6444-
I.getType(), DataVT.getVectorElementCount(), /*ZeroIsPoison=*/true,
6445-
&VScaleRange);
6446-
MVT StepVT = MVT::getIntegerVT(EltWidth);
6447-
EVT StepVecVT = DataVT.changeVectorElementType(StepVT);
6448-
6449-
// Zero out lanes with inactive elements, then find the highest remaining
6450-
// value from the stepvector.
6451-
SDValue Zeroes = DAG.getConstant(0, sdl, StepVecVT);
6452-
SDValue StepVec = DAG.getStepVector(sdl, StepVecVT);
6453-
SDValue ActiveElts = DAG.getSelect(sdl, StepVecVT, Mask, StepVec, Zeroes);
6454-
SDValue HighestIdx =
6455-
DAG.getNode(ISD::VECREDUCE_UMAX, sdl, StepVT, ActiveElts);
6456-
6457-
// Extract the corresponding lane from the data vector
6458-
EVT ExtVT = TLI.getVectorIdxTy(DAG.getDataLayout());
6459-
SDValue Idx = DAG.getZExtOrTrunc(HighestIdx, sdl, ExtVT);
6460-
SDValue Extract =
6461-
DAG.getNode(ISD::EXTRACT_VECTOR_ELT, sdl, ScalarVT, Data, Idx);
6462-
6463-
// If all mask lanes were inactive, choose the passthru value instead.
6464-
SDValue AnyActive = DAG.getNode(ISD::VECREDUCE_OR, sdl, BoolVT, Mask);
6465-
SDValue Result = DAG.getSelect(sdl, ScalarVT, AnyActive, Extract, PassThru);
6435+
EVT ResVT = TLI.getValueType(Layout, I.getType());
6436+
6437+
EVT ExtVT = TLI.getVectorIdxTy(Layout);
6438+
SDValue Idx = DAG.getNode(ISD::VECTOR_FIND_LAST_ACTIVE, sdl, ExtVT, Mask);
6439+
SDValue Result = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, sdl, ResVT, Data, Idx);
6440+
6441+
Value *Default = I.getOperand(2);
6442+
if (!isa<PoisonValue>(Default) && !isa<UndefValue>(Default)) {
6443+
SDValue PassThru = getValue(Default);
6444+
EVT BoolVT = Mask.getValueType().getScalarType();
6445+
SDValue AnyActive = DAG.getNode(ISD::VECREDUCE_OR, sdl, BoolVT, Mask);
6446+
Result = DAG.getSelect(sdl, ResVT, AnyActive, Result, PassThru);
6447+
}
6448+
64666449
setValue(&I, Result);
64676450
}
64686451

llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,9 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
567567
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
568568
return "histogram";
569569

570+
case ISD::VECTOR_FIND_LAST_ACTIVE:
571+
return "find_last_active";
572+
570573
// Vector Predication
571574
#define BEGIN_REGISTER_VP_SDNODE(SDID, LEGALARG, NAME, ...) \
572575
case ISD::SDID: \

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "llvm/CodeGen/TargetLowering.h"
1414
#include "llvm/ADT/STLExtras.h"
15+
#include "llvm/Analysis/ValueTracking.h"
1516
#include "llvm/Analysis/VectorUtils.h"
1617
#include "llvm/CodeGen/CallingConvLower.h"
1718
#include "llvm/CodeGen/CodeGenCommonISel.h"
@@ -9451,6 +9452,43 @@ SDValue TargetLowering::expandVPCTTZElements(SDNode *N,
94519452
return DAG.getNode(ISD::VP_REDUCE_UMIN, DL, ResVT, ExtEVL, Select, Mask, EVL);
94529453
}
94539454

9455+
SDValue TargetLowering::expandVectorFindLastActive(SDNode *N,
9456+
SelectionDAG &DAG) const {
9457+
SDLoc DL(N);
9458+
SDValue Mask = N->getOperand(0);
9459+
EVT MaskVT = Mask.getValueType();
9460+
EVT BoolVT = MaskVT.getScalarType();
9461+
9462+
// Find a suitable type for a stepvector.
9463+
ConstantRange VScaleRange(1, /*isFullSet=*/true); // Fixed length default.
9464+
if (MaskVT.isScalableVector())
9465+
VScaleRange = getVScaleRange(&DAG.getMachineFunction().getFunction(), 64);
9466+
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
9467+
unsigned EltWidth = TLI.getBitWidthForCttzElements(
9468+
BoolVT.getTypeForEVT(*DAG.getContext()), MaskVT.getVectorElementCount(),
9469+
/*ZeroIsPoison=*/true, &VScaleRange);
9470+
EVT StepVT = MVT::getIntegerVT(EltWidth);
9471+
EVT StepVecVT = MaskVT.changeVectorElementType(StepVT);
9472+
9473+
// If promotion is required to make the type legal, do it here; promotion
9474+
// of integers within LegalizeVectorOps is looking for types of the same
9475+
// size but with a smaller number of larger elements, not the usual larger
9476+
// size with the same number of larger elements.
9477+
if (TLI.getTypeAction(StepVecVT.getSimpleVT()) ==
9478+
TargetLowering::TypePromoteInteger) {
9479+
StepVecVT = TLI.getTypeToTransformTo(*DAG.getContext(), StepVecVT);
9480+
StepVT = StepVecVT.getVectorElementType();
9481+
}
9482+
9483+
// Zero out lanes with inactive elements, then find the highest remaining
9484+
// value from the stepvector.
9485+
SDValue Zeroes = DAG.getConstant(0, DL, StepVecVT);
9486+
SDValue StepVec = DAG.getStepVector(DL, StepVecVT);
9487+
SDValue ActiveElts = DAG.getSelect(DL, StepVecVT, Mask, StepVec, Zeroes);
9488+
SDValue HighestIdx = DAG.getNode(ISD::VECREDUCE_UMAX, DL, StepVT, ActiveElts);
9489+
return DAG.getZExtOrTrunc(HighestIdx, DL, N->getValueType(0));
9490+
}
9491+
94549492
SDValue TargetLowering::expandABS(SDNode *N, SelectionDAG &DAG,
94559493
bool IsNegative) const {
94569494
SDLoc dl(N);

llvm/lib/CodeGen/TargetLoweringBase.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -818,6 +818,9 @@ void TargetLoweringBase::initActions() {
818818
setOperationAction(ISD::SDOPC, VT, Expand);
819819
#include "llvm/IR/VPIntrinsics.def"
820820

821+
// Masked vector extracts default to expand.
822+
setOperationAction(ISD::VECTOR_FIND_LAST_ACTIVE, VT, Expand);
823+
821824
// FP environment operations default to expand.
822825
setOperationAction(ISD::GET_FPENV, VT, Expand);
823826
setOperationAction(ISD::SET_FPENV, VT, Expand);

llvm/test/CodeGen/AArch64/vector-extract-last-active.ll

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ define i16 @extract_last_i16_scalable(<vscale x 8 x i16> %data, <vscale x 8 x i1
318318
; CHECK-NEXT: sel z1.h, p0, z1.h, z2.h
319319
; CHECK-NEXT: umaxv h1, p1, z1.h
320320
; CHECK-NEXT: fmov w8, s1
321-
; CHECK-NEXT: and x8, x8, #0xff
321+
; CHECK-NEXT: and x8, x8, #0xffff
322322
; CHECK-NEXT: whilels p2.h, xzr, x8
323323
; CHECK-NEXT: ptest p1, p0.b
324324
; CHECK-NEXT: lastb w8, p2, z0.h
@@ -337,7 +337,7 @@ define i32 @extract_last_i32_scalable(<vscale x 4 x i32> %data, <vscale x 4 x i1
337337
; CHECK-NEXT: sel z1.s, p0, z1.s, z2.s
338338
; CHECK-NEXT: umaxv s1, p1, z1.s
339339
; CHECK-NEXT: fmov w8, s1
340-
; CHECK-NEXT: and x8, x8, #0xff
340+
; CHECK-NEXT: mov w8, w8
341341
; CHECK-NEXT: whilels p2.s, xzr, x8
342342
; CHECK-NEXT: ptest p1, p0.b
343343
; CHECK-NEXT: lastb w8, p2, z0.s
@@ -356,7 +356,6 @@ define i64 @extract_last_i64_scalable(<vscale x 2 x i64> %data, <vscale x 2 x i1
356356
; CHECK-NEXT: sel z1.d, p0, z1.d, z2.d
357357
; CHECK-NEXT: umaxv d1, p1, z1.d
358358
; CHECK-NEXT: fmov x8, d1
359-
; CHECK-NEXT: and x8, x8, #0xff
360359
; CHECK-NEXT: whilels p2.d, xzr, x8
361360
; CHECK-NEXT: ptest p1, p0.b
362361
; CHECK-NEXT: lastb x8, p2, z0.d
@@ -375,7 +374,7 @@ define float @extract_last_float_scalable(<vscale x 4 x float> %data, <vscale x
375374
; CHECK-NEXT: sel z2.s, p0, z2.s, z3.s
376375
; CHECK-NEXT: umaxv s2, p1, z2.s
377376
; CHECK-NEXT: fmov w8, s2
378-
; CHECK-NEXT: and x8, x8, #0xff
377+
; CHECK-NEXT: mov w8, w8
379378
; CHECK-NEXT: whilels p2.s, xzr, x8
380379
; CHECK-NEXT: ptest p1, p0.b
381380
; CHECK-NEXT: lastb s0, p2, z0.s
@@ -394,7 +393,6 @@ define double @extract_last_double_scalable(<vscale x 2 x double> %data, <vscale
394393
; CHECK-NEXT: sel z2.d, p0, z2.d, z3.d
395394
; CHECK-NEXT: umaxv d2, p1, z2.d
396395
; CHECK-NEXT: fmov x8, d2
397-
; CHECK-NEXT: and x8, x8, #0xff
398396
; CHECK-NEXT: whilels p2.d, xzr, x8
399397
; CHECK-NEXT: ptest p1, p0.b
400398
; CHECK-NEXT: lastb d0, p2, z0.d
@@ -404,6 +402,24 @@ define double @extract_last_double_scalable(<vscale x 2 x double> %data, <vscale
404402
ret double %res
405403
}
406404

405+
;; If the passthru parameter is poison, we shouldn't see a select at the end.
406+
define i8 @extract_last_i8_scalable_poison_passthru(<vscale x 16 x i8> %data, <vscale x 16 x i1> %mask) #0 {
407+
; CHECK-LABEL: extract_last_i8_scalable_poison_passthru:
408+
; CHECK: // %bb.0:
409+
; CHECK-NEXT: index z1.b, #0, #1
410+
; CHECK-NEXT: mov z2.b, #0 // =0x0
411+
; CHECK-NEXT: sel z1.b, p0, z1.b, z2.b
412+
; CHECK-NEXT: ptrue p0.b
413+
; CHECK-NEXT: umaxv b1, p0, z1.b
414+
; CHECK-NEXT: fmov w8, s1
415+
; CHECK-NEXT: and x8, x8, #0xff
416+
; CHECK-NEXT: whilels p0.b, xzr, x8
417+
; CHECK-NEXT: lastb w0, p0, z0.b
418+
; CHECK-NEXT: ret
419+
%res = call i8 @llvm.experimental.vector.extract.last.active.nxv16i8(<vscale x 16 x i8> %data, <vscale x 16 x i1> %mask, i8 poison)
420+
ret i8 %res
421+
}
422+
407423
declare i8 @llvm.experimental.vector.extract.last.active.v16i8(<16 x i8>, <16 x i1>, i8)
408424
declare i16 @llvm.experimental.vector.extract.last.active.v8i16(<8 x i16>, <8 x i1>, i16)
409425
declare i32 @llvm.experimental.vector.extract.last.active.v4i32(<4 x i32>, <4 x i1>, i32)

0 commit comments

Comments
 (0)