Skip to content

[LegalizeDAG] Optimize CodeGen for ISD::CTLZ_ZERO_UNDEF #83039

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2461,13 +2461,22 @@ LegalizerHelper::widenScalar(MachineInstr &MI, unsigned TypeIdx, LLT WideTy) {
NewOpc = TargetOpcode::G_CTTZ_ZERO_UNDEF;
}

unsigned SizeDiff = WideTy.getSizeInBits() - CurTy.getSizeInBits();

if (MI.getOpcode() == TargetOpcode::G_CTLZ_ZERO_UNDEF) {
// An optimization where the result is the CTLZ after the left shift by
// (Difference in widety and current ty), that is,
// MIBSrc = MIBSrc << (sizeinbits(WideTy) - sizeinbits(CurTy))
// Result = ctlz MIBSrc
MIBSrc = MIRBuilder.buildShl(WideTy, MIBSrc,
MIRBuilder.buildConstant(WideTy, SizeDiff));
}

// Perform the operation at the larger size.
auto MIBNewOp = MIRBuilder.buildInstr(NewOpc, {WideTy}, {MIBSrc});
// This is already the correct result for CTPOP and CTTZs
if (MI.getOpcode() == TargetOpcode::G_CTLZ ||
MI.getOpcode() == TargetOpcode::G_CTLZ_ZERO_UNDEF) {
if (MI.getOpcode() == TargetOpcode::G_CTLZ) {
// The correct result is NewOp - (Difference in widety and current ty).
unsigned SizeDiff = WideTy.getSizeInBits() - CurTy.getSizeInBits();
MIBNewOp = MIRBuilder.buildSub(
WideTy, MIBNewOp, MIRBuilder.buildConstant(WideTy, SizeDiff));
}
Expand Down
22 changes: 20 additions & 2 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5083,7 +5083,6 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
case ISD::CTTZ:
case ISD::CTTZ_ZERO_UNDEF:
case ISD::CTLZ:
case ISD::CTLZ_ZERO_UNDEF:
case ISD::CTPOP: {
// Zero extend the argument unless its cttz, then use any_extend.
if (Node->getOpcode() == ISD::CTTZ ||
Expand All @@ -5106,7 +5105,7 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
// Perform the larger operation. For CTPOP and CTTZ_ZERO_UNDEF, this is
// already the correct result.
Tmp1 = DAG.getNode(NewOpc, dl, NVT, Tmp1);
if (NewOpc == ISD::CTLZ || NewOpc == ISD::CTLZ_ZERO_UNDEF) {
if (NewOpc == ISD::CTLZ) {
// Tmp1 = Tmp1 - (sizeinbits(NVT) - sizeinbits(Old VT))
Tmp1 = DAG.getNode(ISD::SUB, dl, NVT, Tmp1,
DAG.getConstant(NVT.getSizeInBits() -
Expand All @@ -5115,6 +5114,25 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
Results.push_back(DAG.getNode(ISD::TRUNCATE, dl, OVT, Tmp1));
break;
}
case ISD::CTLZ_ZERO_UNDEF: {
// We know that the argument is unlikely to be zero, hence we can take a
// different approach as compared to ISD::CTLZ

// Any Extend the argument
auto AnyExtendedNode =
DAG.getNode(ISD::ANY_EXTEND, dl, NVT, Node->getOperand(0));

// Tmp1 = Tmp1 << (sizeinbits(NVT) - sizeinbits(Old VT))
auto ShiftConstant = DAG.getShiftAmountConstant(
NVT.getSizeInBits() - OVT.getSizeInBits(), NVT, dl);
auto LeftShiftResult =
DAG.getNode(ISD::SHL, dl, NVT, AnyExtendedNode, ShiftConstant);

// Perform the larger operation
auto CTLZResult = DAG.getNode(Node->getOpcode(), dl, NVT, LeftShiftResult);
Results.push_back(DAG.getNode(ISD::TRUNCATE, dl, OVT, CTLZResult));
break;
}
case ISD::BITREVERSE:
case ISD::BSWAP: {
unsigned DiffBits = NVT.getSizeInBits() - OVT.getSizeInBits();
Expand Down
54 changes: 38 additions & 16 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -655,24 +655,46 @@ SDValue DAGTypeLegalizer::PromoteIntRes_CTLZ(SDNode *N) {
}
}

// Subtract off the extra leading bits in the bigger type.
SDValue ExtractLeadingBits = DAG.getConstant(
NVT.getScalarSizeInBits() - OVT.getScalarSizeInBits(), dl, NVT);
if (!N->isVPOpcode()) {
unsigned CtlzOpcode = N->getOpcode();
if (CtlzOpcode == ISD::CTLZ || CtlzOpcode == ISD::VP_CTLZ) {
// Subtract off the extra leading bits in the bigger type.
SDValue ExtractLeadingBits = DAG.getConstant(
NVT.getScalarSizeInBits() - OVT.getScalarSizeInBits(), dl, NVT);

if (!N->isVPOpcode()) {
// Zero extend to the promoted type and do the count there.
SDValue Op = ZExtPromotedInteger(N->getOperand(0));
return DAG.getNode(ISD::SUB, dl, NVT,
DAG.getNode(N->getOpcode(), dl, NVT, Op),
ExtractLeadingBits);
}
SDValue Mask = N->getOperand(1);
SDValue EVL = N->getOperand(2);
// Zero extend to the promoted type and do the count there.
SDValue Op = ZExtPromotedInteger(N->getOperand(0));
return DAG.getNode(ISD::SUB, dl, NVT,
DAG.getNode(N->getOpcode(), dl, NVT, Op),
ExtractLeadingBits);
}
SDValue Op = VPZExtPromotedInteger(N->getOperand(0), Mask, EVL);
return DAG.getNode(ISD::VP_SUB, dl, NVT,
DAG.getNode(N->getOpcode(), dl, NVT, Op, Mask, EVL),
ExtractLeadingBits, Mask, EVL);
}
if (CtlzOpcode == ISD::CTLZ_ZERO_UNDEF ||
CtlzOpcode == ISD::VP_CTLZ_ZERO_UNDEF) {
// Any Extend the argument
SDValue Op = GetPromotedInteger(N->getOperand(0));
// Op = Op << (sizeinbits(NVT) - sizeinbits(Old VT))
unsigned SHLAmount = NVT.getScalarSizeInBits() - OVT.getScalarSizeInBits();
auto ShiftConst =
DAG.getShiftAmountConstant(SHLAmount, Op.getValueType(), dl);
if (!N->isVPOpcode()) {
Op = DAG.getNode(ISD::SHL, dl, NVT, Op, ShiftConst);
return DAG.getNode(CtlzOpcode, dl, NVT, Op);
}

SDValue Mask = N->getOperand(1);
SDValue EVL = N->getOperand(2);
// Zero extend to the promoted type and do the count there.
SDValue Op = VPZExtPromotedInteger(N->getOperand(0), Mask, EVL);
return DAG.getNode(ISD::VP_SUB, dl, NVT,
DAG.getNode(N->getOpcode(), dl, NVT, Op, Mask, EVL),
ExtractLeadingBits, Mask, EVL);
SDValue Mask = N->getOperand(1);
SDValue EVL = N->getOperand(2);
Op = DAG.getNode(ISD::VP_SHL, dl, NVT, Op, ShiftConst, Mask, EVL);
return DAG.getNode(CtlzOpcode, dl, NVT, Op, Mask, EVL);
}
llvm_unreachable("Invalid CTLZ Opcode");
}

SDValue DAGTypeLegalizer::PromoteIntRes_CTPOP_PARITY(SDNode *N) {
Expand Down
44 changes: 44 additions & 0 deletions llvm/test/CodeGen/AArch64/ctlz_zero_undef.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
; RUN: llc < %s --mtriple=aarch64 | FileCheck %s

declare i8 @llvm.ctlz.i8(i8, i1 immarg)
declare <8 x i8> @llvm.ctlz.v8i8(<8 x i8>, i1 immarg)
declare i11 @llvm.ctlz.i11(i11, i1 immarg)

define i32 @clz_nzu8(i8 %self) {
; CHECK-LABEL: clz_nzu8:
; CHECK: // %bb.0: // %start
; CHECK-NEXT: lsl w8, w0, #24
; CHECK-NEXT: clz w0, w8
; CHECK-NEXT: ret
start:
%ctlz_res = call i8 @llvm.ctlz.i8(i8 %self, i1 true)
%ret = zext i8 %ctlz_res to i32
ret i32 %ret
}

; non standard bit size argument to ctlz
define i32 @clz_nzu11(i11 %self) {
; CHECK-LABEL: clz_nzu11:
; CHECK: // %bb.0:
; CHECK-NEXT: lsl w8, w0, #21
; CHECK-NEXT: clz w0, w8
; CHECK-NEXT: ret
%ctlz_res = call i11 @llvm.ctlz.i11(i11 %self, i1 true)
%ret = zext i11 %ctlz_res to i32
ret i32 %ret
}

; vector type argument to ctlz intrinsic
define <8 x i32> @clz_vec_nzu8(<8 x i8> %self) {
; CHECK-LABEL: clz_vec_nzu8:
; CHECK: // %bb.0:
; CHECK-NEXT: clz v0.8b, v0.8b
; CHECK-NEXT: ushll v0.8h, v0.8b, #0
; CHECK-NEXT: ushll2 v1.4s, v0.8h, #0
; CHECK-NEXT: ushll v0.4s, v0.4h, #0
; CHECK-NEXT: ret
%ctlz_res = call <8 x i8> @llvm.ctlz.v8i8(<8 x i8> %self, i1 true)
%ret = zext <8 x i8> %ctlz_res to <8 x i32>
ret <8 x i32> %ret
}
Original file line number Diff line number Diff line change
Expand Up @@ -200,11 +200,10 @@ body: |
; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s64) = COPY $vgpr0_vgpr1
; CHECK-NEXT: [[C:%[0-9]+]]:_(s64) = G_CONSTANT i64 8589934591
; CHECK-NEXT: [[AND:%[0-9]+]]:_(s64) = G_AND [[COPY]], [[C]]
; CHECK-NEXT: [[CTLZ_ZERO_UNDEF:%[0-9]+]]:_(s32) = G_CTLZ_ZERO_UNDEF [[AND]](s64)
; CHECK-NEXT: [[C1:%[0-9]+]]:_(s64) = G_CONSTANT i64 31
; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[C1]](s64)
; CHECK-NEXT: [[USUBO:%[0-9]+]]:_(s32), [[USUBO1:%[0-9]+]]:_(s1) = G_USUBO [[CTLZ_ZERO_UNDEF]], [[UV]]
; CHECK-NEXT: [[ZEXT:%[0-9]+]]:_(s64) = G_ZEXT [[USUBO]](s32)
; CHECK-NEXT: [[C1:%[0-9]+]]:_(s32) = G_CONSTANT i32 31
; CHECK-NEXT: [[SHL:%[0-9]+]]:_(s64) = G_SHL [[AND]], [[C1]](s32)
; CHECK-NEXT: [[CTLZ_ZERO_UNDEF:%[0-9]+]]:_(s32) = G_CTLZ_ZERO_UNDEF [[SHL]](s64)
; CHECK-NEXT: [[ZEXT:%[0-9]+]]:_(s64) = G_ZEXT [[CTLZ_ZERO_UNDEF]](s32)
; CHECK-NEXT: $vgpr0_vgpr1 = COPY [[ZEXT]](s64)
%0:_(s64) = COPY $vgpr0_vgpr1
%1:_(s33) = G_TRUNC %0
Expand Down
Loading
Loading