Skip to content

Commit 69192e0

Browse files
NirharManish Kausik H
andauthored
[LegalizeDAG] Optimize CodeGen for ISD::CTLZ_ZERO_UNDEF (llvm#83039)
Previously we had the same instructions being generated for `ISD::CTLZ` and `ISD::CTLZ_ZERO_UNDEF` which did not take advantage of the fact that zero is an invalid input for `ISD::CTLZ_ZERO_UNDEF`. This commit separates codegen for the two cases to allow for the optimization for the latter case. The details of the optimization are outlined in llvm#82075 Fixes llvm#82075 Co-authored-by: Manish Kausik H <[email protected]>
1 parent cb72aec commit 69192e0

File tree

18 files changed

+328
-287
lines changed

18 files changed

+328
-287
lines changed

llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2461,13 +2461,22 @@ LegalizerHelper::widenScalar(MachineInstr &MI, unsigned TypeIdx, LLT WideTy) {
24612461
NewOpc = TargetOpcode::G_CTTZ_ZERO_UNDEF;
24622462
}
24632463

2464+
unsigned SizeDiff = WideTy.getSizeInBits() - CurTy.getSizeInBits();
2465+
2466+
if (MI.getOpcode() == TargetOpcode::G_CTLZ_ZERO_UNDEF) {
2467+
// An optimization where the result is the CTLZ after the left shift by
2468+
// (Difference in widety and current ty), that is,
2469+
// MIBSrc = MIBSrc << (sizeinbits(WideTy) - sizeinbits(CurTy))
2470+
// Result = ctlz MIBSrc
2471+
MIBSrc = MIRBuilder.buildShl(WideTy, MIBSrc,
2472+
MIRBuilder.buildConstant(WideTy, SizeDiff));
2473+
}
2474+
24642475
// Perform the operation at the larger size.
24652476
auto MIBNewOp = MIRBuilder.buildInstr(NewOpc, {WideTy}, {MIBSrc});
24662477
// This is already the correct result for CTPOP and CTTZs
2467-
if (MI.getOpcode() == TargetOpcode::G_CTLZ ||
2468-
MI.getOpcode() == TargetOpcode::G_CTLZ_ZERO_UNDEF) {
2478+
if (MI.getOpcode() == TargetOpcode::G_CTLZ) {
24692479
// The correct result is NewOp - (Difference in widety and current ty).
2470-
unsigned SizeDiff = WideTy.getSizeInBits() - CurTy.getSizeInBits();
24712480
MIBNewOp = MIRBuilder.buildSub(
24722481
WideTy, MIBNewOp, MIRBuilder.buildConstant(WideTy, SizeDiff));
24732482
}

llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5083,7 +5083,6 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
50835083
case ISD::CTTZ:
50845084
case ISD::CTTZ_ZERO_UNDEF:
50855085
case ISD::CTLZ:
5086-
case ISD::CTLZ_ZERO_UNDEF:
50875086
case ISD::CTPOP: {
50885087
// Zero extend the argument unless its cttz, then use any_extend.
50895088
if (Node->getOpcode() == ISD::CTTZ ||
@@ -5106,7 +5105,7 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
51065105
// Perform the larger operation. For CTPOP and CTTZ_ZERO_UNDEF, this is
51075106
// already the correct result.
51085107
Tmp1 = DAG.getNode(NewOpc, dl, NVT, Tmp1);
5109-
if (NewOpc == ISD::CTLZ || NewOpc == ISD::CTLZ_ZERO_UNDEF) {
5108+
if (NewOpc == ISD::CTLZ) {
51105109
// Tmp1 = Tmp1 - (sizeinbits(NVT) - sizeinbits(Old VT))
51115110
Tmp1 = DAG.getNode(ISD::SUB, dl, NVT, Tmp1,
51125111
DAG.getConstant(NVT.getSizeInBits() -
@@ -5115,6 +5114,25 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) {
51155114
Results.push_back(DAG.getNode(ISD::TRUNCATE, dl, OVT, Tmp1));
51165115
break;
51175116
}
5117+
case ISD::CTLZ_ZERO_UNDEF: {
5118+
// We know that the argument is unlikely to be zero, hence we can take a
5119+
// different approach as compared to ISD::CTLZ
5120+
5121+
// Any Extend the argument
5122+
auto AnyExtendedNode =
5123+
DAG.getNode(ISD::ANY_EXTEND, dl, NVT, Node->getOperand(0));
5124+
5125+
// Tmp1 = Tmp1 << (sizeinbits(NVT) - sizeinbits(Old VT))
5126+
auto ShiftConstant = DAG.getShiftAmountConstant(
5127+
NVT.getSizeInBits() - OVT.getSizeInBits(), NVT, dl);
5128+
auto LeftShiftResult =
5129+
DAG.getNode(ISD::SHL, dl, NVT, AnyExtendedNode, ShiftConstant);
5130+
5131+
// Perform the larger operation
5132+
auto CTLZResult = DAG.getNode(Node->getOpcode(), dl, NVT, LeftShiftResult);
5133+
Results.push_back(DAG.getNode(ISD::TRUNCATE, dl, OVT, CTLZResult));
5134+
break;
5135+
}
51185136
case ISD::BITREVERSE:
51195137
case ISD::BSWAP: {
51205138
unsigned DiffBits = NVT.getSizeInBits() - OVT.getSizeInBits();

llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -655,24 +655,46 @@ SDValue DAGTypeLegalizer::PromoteIntRes_CTLZ(SDNode *N) {
655655
}
656656
}
657657

658-
// Subtract off the extra leading bits in the bigger type.
659-
SDValue ExtractLeadingBits = DAG.getConstant(
660-
NVT.getScalarSizeInBits() - OVT.getScalarSizeInBits(), dl, NVT);
661-
if (!N->isVPOpcode()) {
658+
unsigned CtlzOpcode = N->getOpcode();
659+
if (CtlzOpcode == ISD::CTLZ || CtlzOpcode == ISD::VP_CTLZ) {
660+
// Subtract off the extra leading bits in the bigger type.
661+
SDValue ExtractLeadingBits = DAG.getConstant(
662+
NVT.getScalarSizeInBits() - OVT.getScalarSizeInBits(), dl, NVT);
663+
664+
if (!N->isVPOpcode()) {
665+
// Zero extend to the promoted type and do the count there.
666+
SDValue Op = ZExtPromotedInteger(N->getOperand(0));
667+
return DAG.getNode(ISD::SUB, dl, NVT,
668+
DAG.getNode(N->getOpcode(), dl, NVT, Op),
669+
ExtractLeadingBits);
670+
}
671+
SDValue Mask = N->getOperand(1);
672+
SDValue EVL = N->getOperand(2);
662673
// Zero extend to the promoted type and do the count there.
663-
SDValue Op = ZExtPromotedInteger(N->getOperand(0));
664-
return DAG.getNode(ISD::SUB, dl, NVT,
665-
DAG.getNode(N->getOpcode(), dl, NVT, Op),
666-
ExtractLeadingBits);
667-
}
674+
SDValue Op = VPZExtPromotedInteger(N->getOperand(0), Mask, EVL);
675+
return DAG.getNode(ISD::VP_SUB, dl, NVT,
676+
DAG.getNode(N->getOpcode(), dl, NVT, Op, Mask, EVL),
677+
ExtractLeadingBits, Mask, EVL);
678+
}
679+
if (CtlzOpcode == ISD::CTLZ_ZERO_UNDEF ||
680+
CtlzOpcode == ISD::VP_CTLZ_ZERO_UNDEF) {
681+
// Any Extend the argument
682+
SDValue Op = GetPromotedInteger(N->getOperand(0));
683+
// Op = Op << (sizeinbits(NVT) - sizeinbits(Old VT))
684+
unsigned SHLAmount = NVT.getScalarSizeInBits() - OVT.getScalarSizeInBits();
685+
auto ShiftConst =
686+
DAG.getShiftAmountConstant(SHLAmount, Op.getValueType(), dl);
687+
if (!N->isVPOpcode()) {
688+
Op = DAG.getNode(ISD::SHL, dl, NVT, Op, ShiftConst);
689+
return DAG.getNode(CtlzOpcode, dl, NVT, Op);
690+
}
668691

669-
SDValue Mask = N->getOperand(1);
670-
SDValue EVL = N->getOperand(2);
671-
// Zero extend to the promoted type and do the count there.
672-
SDValue Op = VPZExtPromotedInteger(N->getOperand(0), Mask, EVL);
673-
return DAG.getNode(ISD::VP_SUB, dl, NVT,
674-
DAG.getNode(N->getOpcode(), dl, NVT, Op, Mask, EVL),
675-
ExtractLeadingBits, Mask, EVL);
692+
SDValue Mask = N->getOperand(1);
693+
SDValue EVL = N->getOperand(2);
694+
Op = DAG.getNode(ISD::VP_SHL, dl, NVT, Op, ShiftConst, Mask, EVL);
695+
return DAG.getNode(CtlzOpcode, dl, NVT, Op, Mask, EVL);
696+
}
697+
llvm_unreachable("Invalid CTLZ Opcode");
676698
}
677699

678700
SDValue DAGTypeLegalizer::PromoteIntRes_CTPOP_PARITY(SDNode *N) {
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
2+
; RUN: llc < %s --mtriple=aarch64 | FileCheck %s
3+
4+
declare i8 @llvm.ctlz.i8(i8, i1 immarg)
5+
declare <8 x i8> @llvm.ctlz.v8i8(<8 x i8>, i1 immarg)
6+
declare i11 @llvm.ctlz.i11(i11, i1 immarg)
7+
8+
define i32 @clz_nzu8(i8 %self) {
9+
; CHECK-LABEL: clz_nzu8:
10+
; CHECK: // %bb.0: // %start
11+
; CHECK-NEXT: lsl w8, w0, #24
12+
; CHECK-NEXT: clz w0, w8
13+
; CHECK-NEXT: ret
14+
start:
15+
%ctlz_res = call i8 @llvm.ctlz.i8(i8 %self, i1 true)
16+
%ret = zext i8 %ctlz_res to i32
17+
ret i32 %ret
18+
}
19+
20+
; non standard bit size argument to ctlz
21+
define i32 @clz_nzu11(i11 %self) {
22+
; CHECK-LABEL: clz_nzu11:
23+
; CHECK: // %bb.0:
24+
; CHECK-NEXT: lsl w8, w0, #21
25+
; CHECK-NEXT: clz w0, w8
26+
; CHECK-NEXT: ret
27+
%ctlz_res = call i11 @llvm.ctlz.i11(i11 %self, i1 true)
28+
%ret = zext i11 %ctlz_res to i32
29+
ret i32 %ret
30+
}
31+
32+
; vector type argument to ctlz intrinsic
33+
define <8 x i32> @clz_vec_nzu8(<8 x i8> %self) {
34+
; CHECK-LABEL: clz_vec_nzu8:
35+
; CHECK: // %bb.0:
36+
; CHECK-NEXT: clz v0.8b, v0.8b
37+
; CHECK-NEXT: ushll v0.8h, v0.8b, #0
38+
; CHECK-NEXT: ushll2 v1.4s, v0.8h, #0
39+
; CHECK-NEXT: ushll v0.4s, v0.4h, #0
40+
; CHECK-NEXT: ret
41+
%ctlz_res = call <8 x i8> @llvm.ctlz.v8i8(<8 x i8> %self, i1 true)
42+
%ret = zext <8 x i8> %ctlz_res to <8 x i32>
43+
ret <8 x i32> %ret
44+
}

llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-ctlz-zero-undef.mir

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -200,11 +200,10 @@ body: |
200200
; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s64) = COPY $vgpr0_vgpr1
201201
; CHECK-NEXT: [[C:%[0-9]+]]:_(s64) = G_CONSTANT i64 8589934591
202202
; CHECK-NEXT: [[AND:%[0-9]+]]:_(s64) = G_AND [[COPY]], [[C]]
203-
; CHECK-NEXT: [[CTLZ_ZERO_UNDEF:%[0-9]+]]:_(s32) = G_CTLZ_ZERO_UNDEF [[AND]](s64)
204-
; CHECK-NEXT: [[C1:%[0-9]+]]:_(s64) = G_CONSTANT i64 31
205-
; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[C1]](s64)
206-
; CHECK-NEXT: [[USUBO:%[0-9]+]]:_(s32), [[USUBO1:%[0-9]+]]:_(s1) = G_USUBO [[CTLZ_ZERO_UNDEF]], [[UV]]
207-
; CHECK-NEXT: [[ZEXT:%[0-9]+]]:_(s64) = G_ZEXT [[USUBO]](s32)
203+
; CHECK-NEXT: [[C1:%[0-9]+]]:_(s32) = G_CONSTANT i32 31
204+
; CHECK-NEXT: [[SHL:%[0-9]+]]:_(s64) = G_SHL [[AND]], [[C1]](s32)
205+
; CHECK-NEXT: [[CTLZ_ZERO_UNDEF:%[0-9]+]]:_(s32) = G_CTLZ_ZERO_UNDEF [[SHL]](s64)
206+
; CHECK-NEXT: [[ZEXT:%[0-9]+]]:_(s64) = G_ZEXT [[CTLZ_ZERO_UNDEF]](s32)
208207
; CHECK-NEXT: $vgpr0_vgpr1 = COPY [[ZEXT]](s64)
209208
%0:_(s64) = COPY $vgpr0_vgpr1
210209
%1:_(s33) = G_TRUNC %0

0 commit comments

Comments
 (0)