Skip to content

Commit dfe9416

Browse files
committed
[GlobalISel] Handle div-by-pow2
This patch adds similar handling of div-by-pow2 as in `SelectionDAG`.
1 parent bbcfe6f commit dfe9416

File tree

10 files changed

+2708
-743
lines changed

10 files changed

+2708
-743
lines changed

llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -673,6 +673,14 @@ class CombinerHelper {
673673
bool matchSDivByConst(MachineInstr &MI);
674674
void applySDivByConst(MachineInstr &MI);
675675

676+
/// Given an G_SDIV \p MI expressing a signed divided by a pow2 constant,
677+
/// return expressions that implements it by shifting.
678+
bool matchDivByPow2(MachineInstr &MI, bool IsSigned);
679+
void applySDivByPow2(MachineInstr &MI);
680+
/// Given an G_UDIV \p MI expressing an unsigned divided by a pow2 constant,
681+
/// return expressions that implements it by shifting.
682+
void applyUDivByPow2(MachineInstr &MI);
683+
676684
// G_UMULH x, (1 << c)) -> x >> (bitwidth - c)
677685
bool matchUMulHToLShr(MachineInstr &MI);
678686
void applyUMulHToLShr(MachineInstr &MI);

llvm/include/llvm/CodeGen/GlobalISel/Utils.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,11 +308,20 @@ std::optional<APFloat> ConstantFoldIntToFloat(unsigned Opcode, LLT DstTy,
308308
Register Src,
309309
const MachineRegisterInfo &MRI);
310310

311+
std::optional<SmallVector<APInt>>
312+
ConstantFoldICmp(unsigned Pred, const Register Op1, const Register Op2,
313+
const MachineRegisterInfo &MRI);
314+
311315
/// Tries to constant fold a G_CTLZ operation on \p Src. If \p Src is a vector
312316
/// then it tries to do an element-wise constant fold.
313317
std::optional<SmallVector<unsigned>>
314318
ConstantFoldCTLZ(Register Src, const MachineRegisterInfo &MRI);
315319

320+
/// Tries to constant fold a G_CTTZ operation on \p Src. If \p Src is a vector
321+
/// then it tries to do an element-wise constant fold.
322+
std::optional<SmallVector<unsigned>>
323+
ConstantFoldCTTZ(Register Src, const MachineRegisterInfo &MRI);
324+
316325
/// Test if the given value is known to have exactly one bit set. This differs
317326
/// from computeKnownBits in that it doesn't necessarily determine which bit is
318327
/// set.

llvm/include/llvm/Target/GlobalISel/Combine.td

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ def FmArcp : MIFlagEnum<"FmArcp">;
179179
def FmContract : MIFlagEnum<"FmContract">;
180180
def FmAfn : MIFlagEnum<"FmAfn">;
181181
def FmReassoc : MIFlagEnum<"FmReassoc">;
182+
def IsExact : MIFlagEnum<"IsExact">;
182183

183184
def MIFlags;
184185
// def not; -> Already defined as a SDNode
@@ -1036,7 +1037,20 @@ def sdiv_by_const : GICombineRule<
10361037
[{ return Helper.matchSDivByConst(*${root}); }]),
10371038
(apply [{ Helper.applySDivByConst(*${root}); }])>;
10381039

1039-
def intdiv_combines : GICombineGroup<[udiv_by_const, sdiv_by_const]>;
1040+
def sdiv_by_pow2 : GICombineRule<
1041+
(defs root:$root),
1042+
(match (G_SDIV $dst, $x, $y, (MIFlags (not IsExact))):$root,
1043+
[{ return Helper.matchDivByPow2(*${root}, /*IsSigned=*/true); }]),
1044+
(apply [{ Helper.applySDivByPow2(*${root}); }])>;
1045+
1046+
def udiv_by_pow2 : GICombineRule<
1047+
(defs root:$root),
1048+
(match (G_UDIV $dst, $x, $y, (MIFlags (not IsExact))):$root,
1049+
[{ return Helper.matchDivByPow2(*${root}, /*IsSigned=*/false); }]),
1050+
(apply [{ Helper.applyUDivByPow2(*${root}); }])>;
1051+
1052+
def intdiv_combines : GICombineGroup<[udiv_by_const, sdiv_by_const,
1053+
sdiv_by_pow2, udiv_by_pow2]>;
10401054

10411055
def reassoc_ptradd : GICombineRule<
10421056
(defs root:$root, build_fn_matchinfo:$matchinfo),

llvm/lib/CodeGen/GlobalISel/CSEMIRBuilder.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,20 @@ MachineInstrBuilder CSEMIRBuilder::buildInstr(unsigned Opc,
174174
switch (Opc) {
175175
default:
176176
break;
177+
case TargetOpcode::G_ICMP: {
178+
assert(SrcOps.size() == 3 && "Invalid sources");
179+
assert(DstOps.size() == 1 && "Invalid dsts");
180+
LLT SrcTy = SrcOps[1].getLLTTy(*getMRI());
181+
182+
if (std::optional<SmallVector<APInt>> Cst =
183+
ConstantFoldICmp(SrcOps[0].getPredicate(), SrcOps[1].getReg(),
184+
SrcOps[2].getReg(), *getMRI())) {
185+
if (SrcTy.isVector())
186+
return buildBuildVectorConstant(DstOps[0], *Cst);
187+
return buildConstant(DstOps[0], Cst->front());
188+
}
189+
break;
190+
}
177191
case TargetOpcode::G_ADD:
178192
case TargetOpcode::G_PTR_ADD:
179193
case TargetOpcode::G_AND:
@@ -272,6 +286,22 @@ MachineInstrBuilder CSEMIRBuilder::buildInstr(unsigned Opc,
272286
buildConstant(VecTy.getScalarType(), Cst).getReg(0));
273287
return buildBuildVector(DstOps[0], ConstantRegs);
274288
}
289+
case TargetOpcode::G_CTTZ: {
290+
assert(SrcOps.size() == 1 && "Expected one source");
291+
assert(DstOps.size() == 1 && "Expected one dest");
292+
auto MaybeCsts = ConstantFoldCTTZ(SrcOps[0].getReg(), *getMRI());
293+
if (!MaybeCsts)
294+
break;
295+
if (MaybeCsts->size() == 1)
296+
return buildConstant(DstOps[0], (*MaybeCsts)[0]);
297+
// This was a vector constant. Build a G_BUILD_VECTOR for them.
298+
SmallVector<Register> ConstantRegs;
299+
LLT VecTy = DstOps[0].getLLTTy(*getMRI());
300+
for (unsigned Cst : *MaybeCsts)
301+
ConstantRegs.emplace_back(
302+
buildConstant(VecTy.getScalarType(), Cst).getReg(0));
303+
return buildBuildVector(DstOps[0], ConstantRegs);
304+
}
275305
}
276306
bool CanCopy = checkCopyToDefsPossible(DstOps);
277307
if (!canPerformCSEForOpc(Opc))

llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5270,6 +5270,96 @@ MachineInstr *CombinerHelper::buildSDivUsingMul(MachineInstr &MI) {
52705270
return MIB.buildMul(Ty, Res, Factor);
52715271
}
52725272

5273+
bool CombinerHelper::matchDivByPow2(MachineInstr &MI, bool IsSigned) {
5274+
assert((MI.getOpcode() == TargetOpcode::G_SDIV ||
5275+
MI.getOpcode() == TargetOpcode::G_UDIV) &&
5276+
"Expected SDIV or UDIV");
5277+
auto &Div = cast<GenericMachineInstr>(MI);
5278+
Register RHS = Div.getReg(2);
5279+
auto MatchPow2 = [&](const Constant *C) {
5280+
auto *CI = dyn_cast<ConstantInt>(C);
5281+
return CI && (CI->getValue().isPowerOf2() ||
5282+
(IsSigned && CI->getValue().isNegatedPowerOf2()));
5283+
};
5284+
return matchUnaryPredicate(MRI, RHS, MatchPow2, /*AllowUndefs=*/false);
5285+
}
5286+
5287+
void CombinerHelper::applySDivByPow2(MachineInstr &MI) {
5288+
assert(MI.getOpcode() == TargetOpcode::G_SDIV && "Expected SDIV");
5289+
auto &SDiv = cast<GenericMachineInstr>(MI);
5290+
Register Dst = SDiv.getReg(0);
5291+
Register LHS = SDiv.getReg(1);
5292+
Register RHS = SDiv.getReg(2);
5293+
LLT Ty = MRI.getType(Dst);
5294+
LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
5295+
5296+
Builder.setInstrAndDebugLoc(MI);
5297+
5298+
// Effectively we want to lower G_SDIV %lhs, %rhs, where %rhs is a power of 2,
5299+
// to the following version:
5300+
//
5301+
// %c1 = G_CTTZ %rhs
5302+
// %inexact = G_SUB $bitwidth, %c1
5303+
// %sign = %G_ASHR %lhs, $(bitwidth - 1)
5304+
// %srl = G_SHR %sign, %inexact
5305+
// %add = G_ADD %lhs, %srl
5306+
// %sra = G_ASHR %add, %c1
5307+
// %sra = G_SELECT, %isoneorallones, %lhs, %sra
5308+
// %zero = G_CONSTANT $0
5309+
// %neg = G_NEG %sra
5310+
// %isneg = G_ICMP SLT %lhs, %zero
5311+
// %res = G_SELECT %isneg, %neg, %sra
5312+
5313+
unsigned Bitwidth = Ty.getScalarSizeInBits();
5314+
auto Zero = Builder.buildConstant(Ty, 0);
5315+
5316+
auto Bits = Builder.buildConstant(ShiftAmtTy, Bitwidth);
5317+
auto C1 = Builder.buildCTTZ(ShiftAmtTy, RHS);
5318+
auto Inexact = Builder.buildSub(ShiftAmtTy, Bits, C1);
5319+
auto Sign = Builder.buildAShr(
5320+
Ty, LHS, Builder.buildConstant(ShiftAmtTy, Bitwidth - 1));
5321+
5322+
// Add (LHS < 0) ? abs2 - 1 : 0;
5323+
auto Lshr = Builder.buildLShr(Ty, Sign, Inexact);
5324+
auto Add = Builder.buildAdd(Ty, LHS, Lshr);
5325+
auto Shr = Builder.buildAShr(Ty, Add, C1);
5326+
5327+
LLT CCVT =
5328+
Ty.isVector() ? LLT::vector(Ty.getElementCount(), 1) : LLT::scalar(1);
5329+
5330+
auto One = Builder.buildConstant(Ty, 1);
5331+
auto AllOnes =
5332+
Builder.buildConstant(Ty, APInt::getAllOnes(Ty.getScalarSizeInBits()));
5333+
auto IsOne = Builder.buildICmp(CmpInst::Predicate::ICMP_EQ, CCVT, RHS, One);
5334+
auto IsAllOnes =
5335+
Builder.buildICmp(CmpInst::Predicate::ICMP_EQ, CCVT, RHS, AllOnes);
5336+
auto IsOneOrAllOnes = Builder.buildOr(CCVT, IsOne, IsAllOnes);
5337+
Shr = Builder.buildSelect(Ty, IsOneOrAllOnes, LHS, Shr);
5338+
5339+
// If dividing by a positive value, we're done. Otherwise, the result must
5340+
// be negated.
5341+
auto Neg = Builder.buildNeg(Ty, Shr);
5342+
auto IsNeg = Builder.buildICmp(CmpInst::Predicate::ICMP_SLT, CCVT, LHS, Zero);
5343+
Builder.buildSelect(MI.getOperand(0).getReg(), IsNeg, Neg, Shr);
5344+
MI.eraseFromParent();
5345+
}
5346+
5347+
void CombinerHelper::applyUDivByPow2(MachineInstr &MI) {
5348+
assert(MI.getOpcode() == TargetOpcode::G_UDIV && "Expected UDIV");
5349+
auto &UDiv = cast<GenericMachineInstr>(MI);
5350+
Register Dst = UDiv.getReg(0);
5351+
Register LHS = UDiv.getReg(1);
5352+
Register RHS = UDiv.getReg(2);
5353+
LLT Ty = MRI.getType(Dst);
5354+
LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
5355+
5356+
Builder.setInstrAndDebugLoc(MI);
5357+
5358+
auto C1 = Builder.buildCTTZ(ShiftAmtTy, RHS);
5359+
Builder.buildLShr(MI.getOperand(0).getReg(), LHS, C1);
5360+
MI.eraseFromParent();
5361+
}
5362+
52735363
bool CombinerHelper::matchUMulHToLShr(MachineInstr &MI) {
52745364
assert(MI.getOpcode() == TargetOpcode::G_UMULH);
52755365
Register RHS = MI.getOperand(2).getReg();

llvm/lib/CodeGen/GlobalISel/Utils.cpp

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -996,6 +996,105 @@ llvm::ConstantFoldCTLZ(Register Src, const MachineRegisterInfo &MRI) {
996996
return std::nullopt;
997997
}
998998

999+
std::optional<SmallVector<unsigned>>
1000+
llvm::ConstantFoldCTTZ(Register Src, const MachineRegisterInfo &MRI) {
1001+
LLT Ty = MRI.getType(Src);
1002+
SmallVector<unsigned> FoldedCTTZs;
1003+
auto tryFoldScalar = [&](Register R) -> std::optional<unsigned> {
1004+
auto MaybeCst = getIConstantVRegVal(R, MRI);
1005+
if (!MaybeCst)
1006+
return std::nullopt;
1007+
return MaybeCst->countTrailingZeros();
1008+
};
1009+
if (Ty.isVector()) {
1010+
// Try to constant fold each element.
1011+
auto *BV = getOpcodeDef<GBuildVector>(Src, MRI);
1012+
if (!BV)
1013+
return std::nullopt;
1014+
for (unsigned SrcIdx = 0; SrcIdx < BV->getNumSources(); ++SrcIdx) {
1015+
if (auto MaybeFold = tryFoldScalar(BV->getSourceReg(SrcIdx))) {
1016+
FoldedCTTZs.emplace_back(*MaybeFold);
1017+
continue;
1018+
}
1019+
return std::nullopt;
1020+
}
1021+
return FoldedCTTZs;
1022+
}
1023+
if (auto MaybeCst = tryFoldScalar(Src)) {
1024+
FoldedCTTZs.emplace_back(*MaybeCst);
1025+
return FoldedCTTZs;
1026+
}
1027+
return std::nullopt;
1028+
}
1029+
1030+
std::optional<SmallVector<APInt>>
1031+
llvm::ConstantFoldICmp(unsigned Pred, const Register Op1, const Register Op2,
1032+
const MachineRegisterInfo &MRI) {
1033+
LLT Ty = MRI.getType(Op1);
1034+
if (Ty != MRI.getType(Op2))
1035+
return std::nullopt;
1036+
1037+
auto TryFoldScalar = [&MRI, Pred](Register LHS,
1038+
Register RHS) -> std::optional<APInt> {
1039+
auto LHSCst = getIConstantVRegVal(LHS, MRI);
1040+
auto RHSCst = getIConstantVRegVal(RHS, MRI);
1041+
if (!LHSCst || !RHSCst)
1042+
return std::nullopt;
1043+
1044+
switch (Pred) {
1045+
case CmpInst::Predicate::ICMP_EQ:
1046+
return APInt(1, LHSCst->eq(*RHSCst));
1047+
case CmpInst::Predicate::ICMP_NE:
1048+
return APInt(1, LHSCst->ne(*RHSCst));
1049+
case CmpInst::Predicate::ICMP_UGT:
1050+
return APInt(1, LHSCst->ugt(*RHSCst));
1051+
case CmpInst::Predicate::ICMP_UGE:
1052+
return APInt(1, LHSCst->uge(*RHSCst));
1053+
case CmpInst::Predicate::ICMP_ULT:
1054+
return APInt(1, LHSCst->ult(*RHSCst));
1055+
case CmpInst::Predicate::ICMP_ULE:
1056+
return APInt(1, LHSCst->ult(*RHSCst));
1057+
case CmpInst::Predicate::ICMP_SGT:
1058+
return APInt(1, LHSCst->sgt(*RHSCst));
1059+
case CmpInst::Predicate::ICMP_SGE:
1060+
return APInt(1, LHSCst->sge(*RHSCst));
1061+
case CmpInst::Predicate::ICMP_SLT:
1062+
return APInt(1, LHSCst->slt(*RHSCst));
1063+
case CmpInst::Predicate::ICMP_SLE:
1064+
return APInt(1, LHSCst->sle(*RHSCst));
1065+
default:
1066+
return std::nullopt;
1067+
}
1068+
};
1069+
1070+
SmallVector<APInt> FoldedICmps;
1071+
1072+
if (Ty.isVector()) {
1073+
// Try to constant fold each element.
1074+
auto *BV1 = getOpcodeDef<GBuildVector>(Op1, MRI);
1075+
auto *BV2 = getOpcodeDef<GBuildVector>(Op2, MRI);
1076+
if (!BV1 || !BV2)
1077+
return std::nullopt;
1078+
assert(BV1->getNumSources() == BV2->getNumSources() && "Invalid vectors");
1079+
for (unsigned I = 0; I < BV1->getNumSources(); ++I) {
1080+
if (auto MaybeFold =
1081+
TryFoldScalar(BV1->getSourceReg(I), BV2->getSourceReg(I))) {
1082+
FoldedICmps.emplace_back(*MaybeFold);
1083+
continue;
1084+
}
1085+
return std::nullopt;
1086+
}
1087+
return FoldedICmps;
1088+
}
1089+
1090+
if (auto MaybeCst = TryFoldScalar(Op1, Op2)) {
1091+
FoldedICmps.emplace_back(*MaybeCst);
1092+
return FoldedICmps;
1093+
}
1094+
1095+
return std::nullopt;
1096+
}
1097+
9991098
bool llvm::isKnownToBeAPowerOfTwo(Register Reg, const MachineRegisterInfo &MRI,
10001099
GISelKnownBits *KB) {
10011100
std::optional<DefinitionAndSourceRegister> DefSrcReg =

llvm/test/CodeGen/AMDGPU/GlobalISel/llvm.amdgcn.sbfe.ll

Lines changed: 12 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -670,36 +670,19 @@ define amdgpu_kernel void @bfe_sext_in_reg_i24(ptr addrspace(1) %out, ptr addrsp
670670
define amdgpu_kernel void @simplify_demanded_bfe_sdiv(ptr addrspace(1) %out, ptr addrspace(1) %in) #0 {
671671
; GFX6-LABEL: simplify_demanded_bfe_sdiv:
672672
; GFX6: ; %bb.0:
673-
; GFX6-NEXT: v_rcp_iflag_f32_e32 v0, 2.0
674-
; GFX6-NEXT: s_load_dwordx4 s[4:7], s[0:1], 0x0
675-
; GFX6-NEXT: v_mul_f32_e32 v0, 0x4f7ffffe, v0
676-
; GFX6-NEXT: v_cvt_u32_f32_e32 v0, v0
673+
; GFX6-NEXT: s_load_dwordx4 s[0:3], s[0:1], 0x0
677674
; GFX6-NEXT: s_waitcnt lgkmcnt(0)
678-
; GFX6-NEXT: s_load_dword s0, s[6:7], 0x0
679-
; GFX6-NEXT: s_mov_b32 s6, -1
680-
; GFX6-NEXT: s_mov_b32 s7, 0xf000
681-
; GFX6-NEXT: v_mul_lo_u32 v1, v0, -2
682-
; GFX6-NEXT: s_waitcnt lgkmcnt(0)
683-
; GFX6-NEXT: s_bfe_i32 s0, s0, 0x100001
684-
; GFX6-NEXT: s_ashr_i32 s2, s0, 31
685-
; GFX6-NEXT: v_mul_hi_u32 v1, v0, v1
686-
; GFX6-NEXT: s_add_i32 s0, s0, s2
687-
; GFX6-NEXT: s_xor_b32 s0, s0, s2
688-
; GFX6-NEXT: v_add_i32_e32 v0, vcc, v0, v1
689-
; GFX6-NEXT: v_mul_hi_u32 v0, s0, v0
690-
; GFX6-NEXT: v_lshlrev_b32_e32 v1, 1, v0
691-
; GFX6-NEXT: v_add_i32_e32 v2, vcc, 1, v0
692-
; GFX6-NEXT: v_sub_i32_e32 v1, vcc, s0, v1
693-
; GFX6-NEXT: v_cmp_le_u32_e32 vcc, 2, v1
694-
; GFX6-NEXT: v_cndmask_b32_e32 v0, v0, v2, vcc
695-
; GFX6-NEXT: v_subrev_i32_e64 v2, s[0:1], 2, v1
696-
; GFX6-NEXT: v_cndmask_b32_e32 v1, v1, v2, vcc
697-
; GFX6-NEXT: v_add_i32_e32 v2, vcc, 1, v0
698-
; GFX6-NEXT: v_cmp_le_u32_e32 vcc, 2, v1
699-
; GFX6-NEXT: v_cndmask_b32_e32 v0, v0, v2, vcc
700-
; GFX6-NEXT: v_xor_b32_e32 v0, s2, v0
701-
; GFX6-NEXT: v_subrev_i32_e32 v0, vcc, s2, v0
702-
; GFX6-NEXT: buffer_store_dword v0, off, s[4:7], 0
675+
; GFX6-NEXT: s_load_dword s3, s[2:3], 0x0
676+
; GFX6-NEXT: s_mov_b32 s2, -1
677+
; GFX6-NEXT: s_waitcnt lgkmcnt(0)
678+
; GFX6-NEXT: s_bfe_i32 s3, s3, 0x100001
679+
; GFX6-NEXT: s_ashr_i32 s4, s3, 31
680+
; GFX6-NEXT: s_lshr_b32 s4, s4, 31
681+
; GFX6-NEXT: s_add_i32 s3, s3, s4
682+
; GFX6-NEXT: s_ashr_i32 s3, s3, 1
683+
; GFX6-NEXT: v_mov_b32_e32 v0, s3
684+
; GFX6-NEXT: s_mov_b32 s3, 0xf000
685+
; GFX6-NEXT: buffer_store_dword v0, off, s[0:3], 0
703686
; GFX6-NEXT: s_endpgm
704687
%src = load i32, ptr addrspace(1) %in, align 4
705688
%bfe = call i32 @llvm.amdgcn.sbfe.i32(i32 %src, i32 1, i32 16)

0 commit comments

Comments
 (0)