Skip to content

Commit f35e8b4

Browse files
committed
[GlobalISel] Handle div-by-pow2
This patch adds similar handling of div-by-pow2 as in `SelectionDAG`.
1 parent 962f3e4 commit f35e8b4

File tree

4 files changed

+2444
-14
lines changed

4 files changed

+2444
-14
lines changed

llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -673,6 +673,16 @@ class CombinerHelper {
673673
bool matchSDivByConst(MachineInstr &MI);
674674
void applySDivByConst(MachineInstr &MI);
675675

676+
/// Given an G_SDIV \p MI expressing a signed divided by a pow2 constant,
677+
/// return expressions that implements it by shifting.
678+
bool matchSDivByPow2(MachineInstr &MI);
679+
void applySDivByPow2(MachineInstr &MI);
680+
681+
/// Given an G_UDIV \p MI expressing an unsigned divided by a pow2 constant,
682+
/// return expressions that implements it by shifting.
683+
bool matchUDivByPow2(MachineInstr &MI);
684+
void applyUDivByPow2(MachineInstr &MI);
685+
676686
// G_UMULH x, (1 << c)) -> x >> (bitwidth - c)
677687
bool matchUMulHToLShr(MachineInstr &MI);
678688
void applyUMulHToLShr(MachineInstr &MI);

llvm/include/llvm/Target/GlobalISel/Combine.td

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ def combine_extracted_vector_load : GICombineRule<
264264
(match (wip_match_opcode G_EXTRACT_VECTOR_ELT):$root,
265265
[{ return Helper.matchCombineExtractedVectorLoad(*${root}, ${matchinfo}); }]),
266266
(apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;
267-
267+
268268
def combine_indexed_load_store : GICombineRule<
269269
(defs root:$root, indexed_load_store_matchdata:$matchinfo),
270270
(match (wip_match_opcode G_LOAD, G_SEXTLOAD, G_ZEXTLOAD, G_STORE):$root,
@@ -1036,7 +1036,20 @@ def sdiv_by_const : GICombineRule<
10361036
[{ return Helper.matchSDivByConst(*${root}); }]),
10371037
(apply [{ Helper.applySDivByConst(*${root}); }])>;
10381038

1039-
def intdiv_combines : GICombineGroup<[udiv_by_const, sdiv_by_const]>;
1039+
def sdiv_by_pow2 : GICombineRule<
1040+
(defs root:$root),
1041+
(match (wip_match_opcode G_SDIV):$root,
1042+
[{ return Helper.matchSDivByPow2(*${root}); }]),
1043+
(apply [{ Helper.applySDivByPow2(*${root}); }])>;
1044+
1045+
def udiv_by_pow2 : GICombineRule<
1046+
(defs root:$root),
1047+
(match (wip_match_opcode G_UDIV):$root,
1048+
[{ return Helper.matchUDivByPow2(*${root}); }]),
1049+
(apply [{ Helper.applyUDivByPow2(*${root}); }])>;
1050+
1051+
def intdiv_combines : GICombineGroup<[udiv_by_const, sdiv_by_const,
1052+
sdiv_by_pow2, udiv_by_pow2]>;
10401053

10411054
def reassoc_ptradd : GICombineRule<
10421055
(defs root:$root, build_fn_matchinfo:$matchinfo),
@@ -1356,7 +1369,7 @@ def constant_fold_binops : GICombineGroup<[constant_fold_binop,
13561369

13571370
def all_combines : GICombineGroup<[trivial_combines, insert_vec_elt_combines,
13581371
extract_vec_elt_combines, combines_for_extload, combine_extracted_vector_load,
1359-
undef_combines, identity_combines, phi_combines,
1372+
undef_combines, identity_combines, phi_combines,
13601373
simplify_add_to_sub, hoist_logic_op_with_same_opcode_hands, shifts_too_big,
13611374
reassocs, ptr_add_immed_chain,
13621375
shl_ashr_to_sext_inreg, sext_inreg_of_load,
@@ -1373,7 +1386,7 @@ def all_combines : GICombineGroup<[trivial_combines, insert_vec_elt_combines,
13731386
intdiv_combines, mulh_combines, redundant_neg_operands,
13741387
and_or_disjoint_mask, fma_combines, fold_binop_into_select,
13751388
sub_add_reg, select_to_minmax, redundant_binop_in_equality,
1376-
fsub_to_fneg, commute_constant_to_rhs, match_ands, match_ors,
1389+
fsub_to_fneg, commute_constant_to_rhs, match_ands, match_ors,
13771390
combine_concat_vector, double_icmp_zero_and_or_combine]>;
13781391

13791392
// A combine group used to for prelegalizer combiners at -O0. The combines in

llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp

Lines changed: 136 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1490,7 +1490,7 @@ void CombinerHelper::applyOptBrCondByInvertingCond(MachineInstr &MI,
14901490
Observer.changedInstr(*BrCond);
14911491
}
14921492

1493-
1493+
14941494
bool CombinerHelper::tryEmitMemcpyInline(MachineInstr &MI) {
14951495
MachineIRBuilder HelperBuilder(MI);
14961496
GISelObserverWrapper DummyObserver;
@@ -5286,6 +5286,141 @@ MachineInstr *CombinerHelper::buildSDivUsingMul(MachineInstr &MI) {
52865286
return MIB.buildMul(Ty, Res, Factor);
52875287
}
52885288

5289+
bool CombinerHelper::matchSDivByPow2(MachineInstr &MI) {
5290+
assert(MI.getOpcode() == TargetOpcode::G_SDIV && "Expected SDIV");
5291+
if (MI.getFlag(MachineInstr::MIFlag::IsExact))
5292+
return false;
5293+
auto &SDiv = cast<GenericMachineInstr>(MI);
5294+
Register RHS = SDiv.getReg(2);
5295+
auto MatchPow2 = [&](const Constant *C) {
5296+
if (auto *CI = dyn_cast<ConstantInt>(C))
5297+
return CI->getValue().isPowerOf2() || CI->getValue().isNegatedPowerOf2();
5298+
return false;
5299+
};
5300+
return matchUnaryPredicate(MRI, RHS, MatchPow2, /* AllowUndefs= */ false);
5301+
}
5302+
5303+
void CombinerHelper::applySDivByPow2(MachineInstr &MI) {
5304+
assert(MI.getOpcode() == TargetOpcode::G_SDIV && "Expected SDIV");
5305+
auto &SDiv = cast<GenericMachineInstr>(MI);
5306+
Register Dst = SDiv.getReg(0);
5307+
Register LHS = SDiv.getReg(1);
5308+
Register RHS = SDiv.getReg(2);
5309+
LLT Ty = MRI.getType(Dst);
5310+
LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
5311+
5312+
Builder.setInstrAndDebugLoc(MI);
5313+
5314+
unsigned Bitwidth = Ty.getScalarSizeInBits();
5315+
auto Zero = Builder.buildConstant(Ty, 0);
5316+
5317+
auto RHSC = getConstantOrConstantSplatVector(RHS);
5318+
if (RHSC.has_value()) {
5319+
auto RHSCV = *RHSC;
5320+
5321+
// Special case: (sdiv X, 1) -> X
5322+
if (RHSCV.isOne()) {
5323+
replaceSingleDefInstWithReg(MI, LHS);
5324+
return;
5325+
}
5326+
// Special Case: (sdiv X, -1) -> 0-X
5327+
if (RHSCV.isAllOnes()) {
5328+
auto Sub = Builder.buildSub(Ty, Zero, LHS);
5329+
replaceSingleDefInstWithReg(MI, Sub->getOperand(0).getReg());
5330+
return;
5331+
}
5332+
5333+
5334+
unsigned TrailingZeros = RHSCV.countTrailingZeros();
5335+
auto C1 = Builder.buildConstant(ShiftAmtTy, TrailingZeros);
5336+
auto Inexact = Builder.buildConstant(ShiftAmtTy, Bitwidth - TrailingZeros);
5337+
auto Sign = Builder.buildAShr(
5338+
Ty, LHS, Builder.buildConstant(ShiftAmtTy, Bitwidth - 1));
5339+
// Add (LHS < 0) ? abs2 - 1 : 0;
5340+
auto Srl = Builder.buildShl(Ty, Sign, Inexact);
5341+
auto Add = Builder.buildAdd(Ty, LHS, Srl);
5342+
auto Sra = Builder.buildAShr(Ty, Add, C1);
5343+
5344+
// If dividing by a positive value, we're done. Otherwise, the result must
5345+
// be negated.
5346+
auto Res = RHSCV.isNegative() ? Builder.buildSub(Ty, Zero, Sra) : Sra;
5347+
replaceSingleDefInstWithReg(MI, Res->getOperand(0).getReg());
5348+
return;
5349+
}
5350+
5351+
// RHS is not a splat vector. Build the above version with instructions.
5352+
auto Bits = Builder.buildConstant(ShiftAmtTy, Bitwidth);
5353+
auto C1 = Builder.buildCTTZ(Ty, RHS);
5354+
C1 = Builder.buildZExtOrTrunc(ShiftAmtTy, C1);
5355+
auto Inexact = Builder.buildSub(ShiftAmtTy, Bits, C1);
5356+
auto Sign = Builder.buildAShr(
5357+
Ty, LHS, Builder.buildConstant(ShiftAmtTy, Bitwidth - 1));
5358+
5359+
// Add (LHS < 0) ? abs2 - 1 : 0;
5360+
auto Srl = Builder.buildShl(Ty, Sign, Inexact);
5361+
auto Add = Builder.buildAdd(Ty, LHS, Srl);
5362+
auto Sra = Builder.buildAShr(Ty, Add, C1);
5363+
5364+
LLT CCVT = LLT::vector(Ty.getElementCount(), 1);
5365+
5366+
auto One = Builder.buildConstant(Ty, 1);
5367+
auto AllOnes =
5368+
Builder.buildConstant(Ty, APInt::getAllOnes(Ty.getScalarSizeInBits()));
5369+
auto IsOne = Builder.buildICmp(CmpInst::Predicate::ICMP_EQ, CCVT, RHS, One);
5370+
auto IsAllOnes =
5371+
Builder.buildICmp(CmpInst::Predicate::ICMP_EQ, CCVT, RHS, AllOnes);
5372+
auto IsOneOrAllOnes = Builder.buildOr(CCVT, IsOne, IsAllOnes);
5373+
Sra = Builder.buildSelect(Ty, IsOneOrAllOnes, LHS, Sra);
5374+
5375+
// If dividing by a positive value, we're done. Otherwise, the result must
5376+
// be negated.
5377+
auto Sub = Builder.buildSub(Ty, Zero, Sra);
5378+
auto IsNeg = Builder.buildICmp(CmpInst::Predicate::ICMP_SLT, CCVT, LHS, Zero);
5379+
auto Res = Builder.buildSelect(Ty, IsNeg, Sub, Sra);
5380+
replaceSingleDefInstWithReg(MI, Res->getOperand(0).getReg());
5381+
}
5382+
5383+
bool CombinerHelper::matchUDivByPow2(MachineInstr &MI) {
5384+
assert(MI.getOpcode() == TargetOpcode::G_UDIV && "Expected UDIV");
5385+
if (MI.getFlag(MachineInstr::MIFlag::IsExact))
5386+
return false;
5387+
auto &UDiv = cast<GenericMachineInstr>(MI);
5388+
Register RHS = UDiv.getReg(2);
5389+
auto MatchPow2 = [&](const Constant *C) {
5390+
if (auto *CI = dyn_cast<ConstantInt>(C))
5391+
return CI->getValue().isPowerOf2();
5392+
return false;
5393+
};
5394+
return matchUnaryPredicate(MRI, RHS, MatchPow2, /* AllowUndefs= */ false);
5395+
}
5396+
5397+
void CombinerHelper::applyUDivByPow2(MachineInstr &MI) {
5398+
assert(MI.getOpcode() == TargetOpcode::G_UDIV && "Expected SDIV");
5399+
auto &UDiv = cast<GenericMachineInstr>(MI);
5400+
Register Dst = UDiv.getReg(0);
5401+
Register LHS = UDiv.getReg(1);
5402+
Register RHS = UDiv.getReg(2);
5403+
LLT Ty = MRI.getType(Dst);
5404+
LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
5405+
5406+
Builder.setInstrAndDebugLoc(MI);
5407+
5408+
auto RHSC = getIConstantVRegValWithLookThrough(RHS, MRI);
5409+
assert(RHSC.has_value() && "RHS must be a constant");
5410+
auto RHSCV = RHSC->Value;
5411+
5412+
// Special case: (udiv X, 1) -> X
5413+
if (RHSCV.isOne()) {
5414+
replaceSingleDefInstWithReg(MI, LHS);
5415+
return;
5416+
}
5417+
5418+
unsigned TrailingZeros = RHSCV.countTrailingZeros();
5419+
auto C1 = Builder.buildConstant(ShiftAmtTy, TrailingZeros);
5420+
auto Res = Builder.buildLShr(Ty, LHS, C1);
5421+
replaceSingleDefInstWithReg(MI, Res->getOperand(0).getReg());
5422+
}
5423+
52895424
bool CombinerHelper::matchUMulHToLShr(MachineInstr &MI) {
52905425
assert(MI.getOpcode() == TargetOpcode::G_UMULH);
52915426
Register RHS = MI.getOperand(2).getReg();

0 commit comments

Comments
 (0)