Skip to content

[AArch64][SelectionDAG] Vector splitting and promotion for histogram intrinsic #103037

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1241,6 +1241,11 @@ void SelectionDAGLegalize::LegalizeOp(SDNode *Node) {
Action = TLI.getOperationAction(Node->getOpcode(),
Node->getOperand(0).getValueType());
break;
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
Action = TLI.getOperationAction(
Node->getOpcode(),
cast<MaskedHistogramSDNode>(Node)->getIndex().getValueType());
break;
default:
if (Node->getOpcode() >= ISD::BUILTIN_OP_END) {
Action = TLI.getCustomOperationAction(*Node);
Expand Down
11 changes: 11 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2037,6 +2037,9 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
case ISD::EXPERIMENTAL_VP_SPLICE:
Res = PromoteIntOp_VP_SPLICE(N, OpNo);
break;
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
Res = PromoteIntOp_VECTOR_HISTOGRAM(N, OpNo);
break;
}

// If the result is null, the sub-method took care of registering results etc.
Expand Down Expand Up @@ -2749,6 +2752,14 @@ SDValue DAGTypeLegalizer::PromoteIntOp_VP_SPLICE(SDNode *N, unsigned OpNo) {
return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
}

/// Promote an integer operand of an EXPERIMENTAL_VECTOR_HISTOGRAM node.
///
/// Only operand 1 is expected to require promotion (asserted below). That
/// operand is swapped for its promoted equivalent while every other operand is
/// carried over untouched.
///
/// \param N    The histogram node whose operand is being promoted.
/// \param OpNo Index of the operand to promote; must be 1.
/// \returns Value 0 of the node handed back by UpdateNodeOperands.
SDValue DAGTypeLegalizer::PromoteIntOp_VECTOR_HISTOGRAM(SDNode *N,
                                                        unsigned OpNo) {
  assert(OpNo == 1 && "Unexpected operand for promotion");

  // Start from a copy of the current operand list and patch the one entry.
  SmallVector<SDValue, 7> Operands(N->op_begin(), N->op_end());
  Operands[1] = GetPromotedInteger(N->getOperand(1));

  SDNode *Updated = DAG.UpdateNodeOperands(N, Operands);
  return SDValue(Updated, 0);
}

//===----------------------------------------------------------------------===//
// Integer Result Expansion
//===----------------------------------------------------------------------===//
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue PromoteIntOp_PATCHPOINT(SDNode *N, unsigned OpNo);
SDValue PromoteIntOp_VP_STRIDED(SDNode *N, unsigned OpNo);
SDValue PromoteIntOp_VP_SPLICE(SDNode *N, unsigned OpNo);
SDValue PromoteIntOp_VECTOR_HISTOGRAM(SDNode *N, unsigned OpNo);

void SExtOrZExtPromotedOperands(SDValue &LHS, SDValue &RHS);
void PromoteSetCCOperands(SDValue &LHS,SDValue &RHS, ISD::CondCode Code);
Expand Down Expand Up @@ -972,6 +973,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue SplitVecOp_CMP(SDNode *N);
SDValue SplitVecOp_FP_TO_XINT_SAT(SDNode *N);
SDValue SplitVecOp_VP_CttzElements(SDNode *N);
SDValue SplitVecOp_VECTOR_HISTOGRAM(SDNode *N);

//===--------------------------------------------------------------------===//
// Vector Widening Support: LegalizeVectorTypes.cpp
Expand Down
25 changes: 25 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3264,6 +3264,9 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
case ISD::VP_CTTZ_ELTS_ZERO_UNDEF:
Res = SplitVecOp_VP_CttzElements(N);
break;
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
Res = SplitVecOp_VECTOR_HISTOGRAM(N);
break;
}

// If the result is null, the sub-method took care of registering results etc.
Expand Down Expand Up @@ -4274,6 +4277,28 @@ SDValue DAGTypeLegalizer::SplitVecOp_VP_CttzElements(SDNode *N) {
DAG.getNode(ISD::ADD, DL, ResVT, VLo, ResHi));
}

/// Split an EXPERIMENTAL_VECTOR_HISTOGRAM node into two histogram nodes, one
/// for the low half and one for the high half of its index and mask vectors.
///
/// The increment, base pointer, scale, intrinsic ID, memory VT, memory operand
/// and index type are passed unchanged to both halves. The low half is chained
/// from the original node's chain; the high half is chained from the low half.
///
/// \param N The histogram node to split (must be a MaskedHistogramSDNode).
/// \returns The high-half histogram node, whose only result type is MVT::Other.
SDValue DAGTypeLegalizer::SplitVecOp_VECTOR_HISTOGRAM(SDNode *N) {
  auto *Hist = cast<MaskedHistogramSDNode>(N);
  SDLoc dl(Hist);

  // Split the per-lane operands: index first, then mask.
  auto [LoIndex, HiIndex] = DAG.SplitVector(Hist->getIndex(), dl);
  auto [LoMask, HiMask] = DAG.SplitVector(Hist->getMask(), dl);

  // Properties that are identical for both halves.
  const SDVTList ChainOnlyVTs = DAG.getVTList(MVT::Other);
  const EVT MemVT = Hist->getMemoryVT();
  MachineMemOperand *const MMO = Hist->getMemOperand();
  const ISD::MemIndexType IdxTy = Hist->getIndexType();

  // Emit one half-width histogram node on the given chain.
  auto BuildHalf = [&](SDValue Chain, SDValue Mask, SDValue Index) {
    SDValue Ops[] = {Chain,
                     Hist->getInc(),
                     Mask,
                     Hist->getBasePtr(),
                     Index,
                     Hist->getScale(),
                     Hist->getIntID()};
    return DAG.getMaskedHistogram(ChainOnlyVTs, MemVT, dl, Ops, MMO, IdxTy);
  };

  // The low half must be created first: its result feeds the high half's
  // chain operand.
  SDValue LoHalf = BuildHalf(Hist->getChain(), LoMask, LoIndex);
  return BuildHalf(LoHalf, HiMask, HiIndex);
}

//===----------------------------------------------------------------------===//
// Result Vector Widening
//===----------------------------------------------------------------------===//
Expand Down
9 changes: 4 additions & 5 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1777,10 +1777,10 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,

// Histcnt is SVE2 only
if (Subtarget->hasSVE2()) {
setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::Other,
setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::nxv4i32,
Custom);
setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::nxv2i64,
Custom);
setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::i8, Custom);
setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::i16, Custom);
}
}

Expand Down Expand Up @@ -28176,11 +28176,10 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
assert(CID->getZExtValue() == Intrinsic::experimental_vector_histogram_add &&
"Unexpected histogram update operation");

EVT IncVT = Inc.getValueType();
EVT IndexVT = Index.getValueType();
LLVMContext &Ctx = *DAG.getContext();
ElementCount EC = IndexVT.getVectorElementCount();
EVT MemVT = EVT::getVectorVT(Ctx, IncVT, EC);
EVT MemVT = EVT::getVectorVT(Ctx, HG->getMemoryVT(), EC);
EVT IncExtVT =
EVT::getIntegerVT(Ctx, AArch64::SVEBitsPerBlock / EC.getKnownMinValue());
EVT IncSplatVT = EVT::getVectorVT(Ctx, IncExtVT, EC);
Expand Down
103 changes: 101 additions & 2 deletions llvm/test/CodeGen/AArch64/sve2-histcnt.ll
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,10 @@ define void @histogram_i16_literal_1(ptr %base, <vscale x 4 x i32> %indices, <vs
; CHECK-LABEL: histogram_i16_literal_1:
; CHECK: // %bb.0:
; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s
; CHECK-NEXT: mov z3.s, #1 // =0x1
; CHECK-NEXT: ld1h { z2.s }, p0/z, [x0, z0.s, sxtw #1]
; CHECK-NEXT: add z1.s, z2.s, z1.s
; CHECK-NEXT: ptrue p1.s
; CHECK-NEXT: mad z1.s, p1/m, z3.s, z2.s
; CHECK-NEXT: st1h { z1.s }, p0, [x0, z0.s, sxtw #1]
; CHECK-NEXT: ret
%buckets = getelementptr i16, ptr %base, <vscale x 4 x i32> %indices
Expand All @@ -145,8 +147,10 @@ define void @histogram_i16_literal_2(ptr %base, <vscale x 4 x i32> %indices, <vs
; CHECK-LABEL: histogram_i16_literal_2:
; CHECK: // %bb.0:
; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s
; CHECK-NEXT: mov z3.s, #2 // =0x2
; CHECK-NEXT: ld1h { z2.s }, p0/z, [x0, z0.s, sxtw #1]
; CHECK-NEXT: adr z1.s, [z2.s, z1.s, lsl #1]
; CHECK-NEXT: ptrue p1.s
; CHECK-NEXT: mad z1.s, p1/m, z3.s, z2.s
; CHECK-NEXT: st1h { z1.s }, p0, [x0, z0.s, sxtw #1]
; CHECK-NEXT: ret
%buckets = getelementptr i16, ptr %base, <vscale x 4 x i32> %indices
Expand All @@ -169,4 +173,99 @@ define void @histogram_i16_literal_3(ptr %base, <vscale x 4 x i32> %indices, <vs
ret void
}

; Histogram add on i64 buckets addressed by a <vscale x 4 x ptr> vector.
; The expected code unpacks the mask into low/high halves (punpklo/punpkhi)
; and emits one histcnt + ld1d/mad/st1d sequence per half.
define void @histogram_i64_4_lane(<vscale x 4 x ptr> %buckets, i64 %inc, <vscale x 4 x i1> %mask) #0 {
; CHECK-LABEL: histogram_i64_4_lane:
; CHECK: // %bb.0:
; CHECK-NEXT: punpklo p1.h, p0.b
; CHECK-NEXT: mov z4.d, x0
; CHECK-NEXT: ptrue p2.d
; CHECK-NEXT: histcnt z2.d, p1/z, z0.d, z0.d
; CHECK-NEXT: ld1d { z3.d }, p1/z, [z0.d]
; CHECK-NEXT: punpkhi p0.h, p0.b
; CHECK-NEXT: mad z2.d, p2/m, z4.d, z3.d
; CHECK-NEXT: st1d { z2.d }, p1, [z0.d]
; CHECK-NEXT: histcnt z0.d, p0/z, z1.d, z1.d
; CHECK-NEXT: ld1d { z2.d }, p0/z, [z1.d]
; CHECK-NEXT: mad z0.d, p2/m, z4.d, z2.d
; CHECK-NEXT: st1d { z0.d }, p0, [z1.d]
; CHECK-NEXT: ret
call void @llvm.experimental.vector.histogram.add.nxv4p0.i64(<vscale x 4 x ptr> %buckets, i64 %inc, <vscale x 4 x i1> %mask)
ret void
}

; Histogram add on i64 buckets addressed by a <vscale x 8 x ptr> vector.
; The expected code unpacks the mask twice (halves, then quarters) and emits
; four histcnt + ld1d/mad/st1d sequences, one per pointer register z0-z3.
define void @histogram_i64_8_lane(<vscale x 8 x ptr> %buckets, i64 %inc, <vscale x 8 x i1> %mask) #0 {
; CHECK-LABEL: histogram_i64_8_lane:
; CHECK: // %bb.0:
; CHECK-NEXT: punpklo p2.h, p0.b
; CHECK-NEXT: mov z6.d, x0
; CHECK-NEXT: ptrue p1.d
; CHECK-NEXT: punpklo p3.h, p2.b
; CHECK-NEXT: punpkhi p2.h, p2.b
; CHECK-NEXT: histcnt z4.d, p3/z, z0.d, z0.d
; CHECK-NEXT: ld1d { z5.d }, p3/z, [z0.d]
; CHECK-NEXT: punpkhi p0.h, p0.b
; CHECK-NEXT: mad z4.d, p1/m, z6.d, z5.d
; CHECK-NEXT: st1d { z4.d }, p3, [z0.d]
; CHECK-NEXT: histcnt z0.d, p2/z, z1.d, z1.d
; CHECK-NEXT: ld1d { z4.d }, p2/z, [z1.d]
; CHECK-NEXT: mad z0.d, p1/m, z6.d, z4.d
; CHECK-NEXT: st1d { z0.d }, p2, [z1.d]
; CHECK-NEXT: punpklo p2.h, p0.b
; CHECK-NEXT: punpkhi p0.h, p0.b
; CHECK-NEXT: histcnt z0.d, p2/z, z2.d, z2.d
; CHECK-NEXT: ld1d { z1.d }, p2/z, [z2.d]
; CHECK-NEXT: mad z0.d, p1/m, z6.d, z1.d
; CHECK-NEXT: st1d { z0.d }, p2, [z2.d]
; CHECK-NEXT: histcnt z0.d, p0/z, z3.d, z3.d
; CHECK-NEXT: ld1d { z1.d }, p0/z, [z3.d]
; CHECK-NEXT: mad z0.d, p1/m, z6.d, z1.d
; CHECK-NEXT: st1d { z0.d }, p0, [z3.d]
; CHECK-NEXT: ret
call void @llvm.experimental.vector.histogram.add.nxv8p0.i64(<vscale x 8 x ptr> %buckets, i64 %inc, <vscale x 8 x i1> %mask)
ret void
}

; Histogram add on i32 buckets addressed as %base plus a <vscale x 8 x i32>
; index vector (via getelementptr). The expected code unpacks the mask into
; low/high halves and emits one histcnt + ld1w/mad/st1w sequence per half,
; using the [x0, z.s, sxtw #2] addressing form.
define void @histogram_i32_8_lane(ptr %base, <vscale x 8 x i32> %indices, i32 %inc, <vscale x 8 x i1> %mask) #0 {
; CHECK-LABEL: histogram_i32_8_lane:
; CHECK: // %bb.0:
; CHECK-NEXT: punpklo p1.h, p0.b
; CHECK-NEXT: mov z4.s, w1
; CHECK-NEXT: ptrue p2.s
; CHECK-NEXT: histcnt z2.s, p1/z, z0.s, z0.s
; CHECK-NEXT: ld1w { z3.s }, p1/z, [x0, z0.s, sxtw #2]
; CHECK-NEXT: punpkhi p0.h, p0.b
; CHECK-NEXT: mad z2.s, p2/m, z4.s, z3.s
; CHECK-NEXT: st1w { z2.s }, p1, [x0, z0.s, sxtw #2]
; CHECK-NEXT: histcnt z0.s, p0/z, z1.s, z1.s
; CHECK-NEXT: ld1w { z2.s }, p0/z, [x0, z1.s, sxtw #2]
; CHECK-NEXT: mad z0.s, p2/m, z4.s, z2.s
; CHECK-NEXT: st1w { z0.s }, p0, [x0, z1.s, sxtw #2]
; CHECK-NEXT: ret
%buckets = getelementptr i32, ptr %base, <vscale x 8 x i32> %indices
call void @llvm.experimental.vector.histogram.add.nxv8p0.i32(<vscale x 8 x ptr> %buckets, i32 %inc, <vscale x 8 x i1> %mask)
ret void
}

; Histogram add on i16 buckets addressed as %base plus a <vscale x 8 x i32>
; index vector (via getelementptr), with an i16 increment. The expected code
; unpacks the mask into low/high halves and emits one histcnt + ld1h/mad/st1h
; sequence per half, using the [x0, z.s, sxtw #1] addressing form.
define void @histogram_i16_8_lane(ptr %base, <vscale x 8 x i32> %indices, i16 %inc, <vscale x 8 x i1> %mask) #0 {
; CHECK-LABEL: histogram_i16_8_lane:
; CHECK: // %bb.0:
; CHECK-NEXT: punpklo p1.h, p0.b
; CHECK-NEXT: mov z4.s, w1
; CHECK-NEXT: ptrue p2.s
; CHECK-NEXT: histcnt z2.s, p1/z, z0.s, z0.s
; CHECK-NEXT: ld1h { z3.s }, p1/z, [x0, z0.s, sxtw #1]
; CHECK-NEXT: punpkhi p0.h, p0.b
; CHECK-NEXT: mad z2.s, p2/m, z4.s, z3.s
; CHECK-NEXT: st1h { z2.s }, p1, [x0, z0.s, sxtw #1]
; CHECK-NEXT: histcnt z0.s, p0/z, z1.s, z1.s
; CHECK-NEXT: ld1h { z2.s }, p0/z, [x0, z1.s, sxtw #1]
; CHECK-NEXT: mad z0.s, p2/m, z4.s, z2.s
; CHECK-NEXT: st1h { z0.s }, p0, [x0, z1.s, sxtw #1]
; CHECK-NEXT: ret
%buckets = getelementptr i16, ptr %base, <vscale x 8 x i32> %indices
call void @llvm.experimental.vector.histogram.add.nxv8p0.i16(<vscale x 8 x ptr> %buckets, i16 %inc, <vscale x 8 x i1> %mask)
ret void
}


attributes #0 = { "target-features"="+sve2" vscale_range(1, 16) }
Loading