Skip to content

Commit b24436a

Browse files
committed
GlobalISel: Lower funnel shifts
1 parent 5949bd9 commit b24436a

File tree

10 files changed

+17927
-75
lines changed

10 files changed

+17927
-75
lines changed

llvm/include/llvm/CodeGen/GlobalISel/LegalizerHelper.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,9 @@ class LegalizerHelper {
345345
LegalizeResult lowerLoad(MachineInstr &MI);
346346
LegalizeResult lowerStore(MachineInstr &MI);
347347
LegalizeResult lowerBitCount(MachineInstr &MI);
348+
LegalizeResult lowerFunnelShiftWithInverse(MachineInstr &MI);
349+
LegalizeResult lowerFunnelShiftAsShifts(MachineInstr &MI);
350+
LegalizeResult lowerFunnelShift(MachineInstr &MI);
348351

349352
LegalizeResult lowerU64ToF32BitOps(MachineInstr &MI);
350353
LegalizeResult lowerUITOFP(MachineInstr &MI);

llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1444,6 +1444,13 @@ class MachineIRBuilder {
14441444
return buildInstr(TargetOpcode::G_SMULH, {Dst}, {Src0, Src1}, Flags);
14451445
}
14461446

1447+
/// Build and insert \p Res = G_UREM \p Op0, \p Op1
///
/// Thin convenience wrapper that forwards to buildInstr with the G_UREM
/// opcode (unsigned integer remainder).
///
/// \param Dst   Destination operand (result register/type).
/// \param Src0  Dividend.
/// \param Src1  Divisor.
/// \param Flags Optional MachineInstr flags to attach to the built
///              instruction; defaults to none.
///
/// \return A MachineInstrBuilder for the newly created instruction.
MachineInstrBuilder buildURem(const DstOp &Dst, const SrcOp &Src0,
                              const SrcOp &Src1,
                              Optional<unsigned> Flags = None) {
  return buildInstr(TargetOpcode::G_UREM, {Dst}, {Src0, Src1}, Flags);
}
1453+
14471454
MachineInstrBuilder buildFMul(const DstOp &Dst, const SrcOp &Src0,
14481455
const SrcOp &Src1,
14491456
Optional<unsigned> Flags = None) {

llvm/include/llvm/CodeGen/GlobalISel/Utils.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,13 @@ bool isBuildVectorAllOnes(const MachineInstr &MI,
328328
Optional<RegOrConstant> getVectorSplat(const MachineInstr &MI,
329329
const MachineRegisterInfo &MRI);
330330

331+
/// Attempt to match a unary predicate against a scalar/splat constant or every
332+
/// element of a constant G_BUILD_VECTOR. If \p ConstVal is null, the source
333+
/// value was undef.
334+
bool matchUnaryPredicate(const MachineRegisterInfo &MRI, Register Reg,
335+
std::function<bool(const Constant *ConstVal)> Match,
336+
bool AllowUndefs = false);
337+
331338
/// Returns true if given the TargetLowering's boolean contents information,
332339
/// the value \p Val contains a true value.
333340
bool isConstTrueVal(const TargetLowering &TLI, int64_t Val, bool IsVector,

llvm/lib/CodeGen/GlobalISel/LegalizerHelper.cpp

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3210,6 +3210,9 @@ LegalizerHelper::lower(MachineInstr &MI, unsigned TypeIdx, LLT LowerHintTy) {
32103210
case G_SDIVREM:
32113211
case G_UDIVREM:
32123212
return lowerDIVREM(MI);
3213+
case G_FSHL:
3214+
case G_FSHR:
3215+
return lowerFunnelShift(MI);
32133216
}
32143217
}
32153218

@@ -5207,6 +5210,132 @@ LegalizerHelper::lowerBitCount(MachineInstr &MI) {
52075210
}
52085211
}
52095212

5213+
// Check that (every element of) Reg is undef or not an exact multiple of BW.
5214+
static bool isNonZeroModBitWidthOrUndef(const MachineRegisterInfo &MRI,
5215+
Register Reg, unsigned BW) {
5216+
return matchUnaryPredicate(
5217+
MRI, Reg,
5218+
[=](const Constant *C) {
5219+
// Null constant here means an undef.
5220+
const ConstantInt *CI = dyn_cast_or_null<ConstantInt>(C);
5221+
return !CI || CI->getValue().urem(BW) != 0;
5222+
},
5223+
/*AllowUndefs*/ true);
5224+
}
5225+
5226+
LegalizerHelper::LegalizeResult
LegalizerHelper::lowerFunnelShiftWithInverse(MachineInstr &MI) {
  // Lower G_FSHL/G_FSHR in terms of the opposite-direction funnel shift:
  //   fshl X, Y, Z -> fshr X, Y, -Z     (when Z % BW is known non-zero)
  //   fshr X, Y, Z -> fshl X, Y, -Z
  // When the shift amount may be a multiple of BW, pre-shift one operand by
  // one so that a zero effective amount in the inverse shift is harmless:
  //   fshl X, Y, Z -> fshr (srl X, 1), (fshr X, Y, 1), ~Z
  //   fshr X, Y, Z -> fshl (fshl X, Y, 1), (shl Y, 1), ~Z
  Register Dst = MI.getOperand(0).getReg();
  Register X = MI.getOperand(1).getReg();
  Register Y = MI.getOperand(2).getReg();
  Register Z = MI.getOperand(3).getReg();
  LLT Ty = MRI.getType(Dst);
  LLT ShTy = MRI.getType(Z);

  unsigned BW = Ty.getScalarSizeInBits();
  const bool IsFSHL = MI.getOpcode() == TargetOpcode::G_FSHL;
  unsigned RevOpcode = IsFSHL ? TargetOpcode::G_FSHR : TargetOpcode::G_FSHL;

  if (isNonZeroModBitWidthOrUndef(MRI, Z, BW)) {
    // fshl X, Y, Z -> fshr X, Y, -Z
    // fshr X, Y, Z -> fshl X, Y, -Z
    //
    // The negated amount is itself a shift amount, so it must be built in
    // ShTy: Zero is an ShTy constant and the result feeds the inverse funnel
    // shift's amount operand. (The previous code built the G_SUB with Ty,
    // which is malformed MIR whenever Ty != ShTy.)
    auto Zero = MIRBuilder.buildConstant(ShTy, 0);
    Z = MIRBuilder.buildSub(ShTy, Zero, Z).getReg(0);
  } else {
    // fshl X, Y, Z -> fshr (srl X, 1), (fshr X, Y, 1), ~Z
    // fshr X, Y, Z -> fshl (fshl X, Y, 1), (shl Y, 1), ~Z
    auto One = MIRBuilder.buildConstant(ShTy, 1);
    if (IsFSHL) {
      Y = MIRBuilder.buildInstr(RevOpcode, {Ty}, {X, Y, One}).getReg(0);
      X = MIRBuilder.buildLShr(Ty, X, One).getReg(0);
    } else {
      X = MIRBuilder.buildInstr(RevOpcode, {Ty}, {X, Y, One}).getReg(0);
      Y = MIRBuilder.buildShl(Ty, Y, One).getReg(0);
    }

    // ~Z is congruent to BW - 1 - Z modulo BW, matching the adjusted inputs.
    Z = MIRBuilder.buildNot(ShTy, Z).getReg(0);
  }

  MIRBuilder.buildInstr(RevOpcode, {Dst}, {X, Y, Z});
  MI.eraseFromParent();
  return Legalized;
}
5263+
5264+
LegalizerHelper::LegalizeResult
LegalizerHelper::lowerFunnelShiftAsShifts(MachineInstr &MI) {
  // Expand G_FSHL/G_FSHR into plain shifts and an or, while never emitting a
  // shift by the full bit width BW (which would be poison):
  //   fshl: X << (Z % BW) | Y >> (BW - (Z % BW))
  //   fshr: X << (BW - (Z % BW)) | Y >> (Z % BW)
  Register Dst = MI.getOperand(0).getReg();
  Register X = MI.getOperand(1).getReg();
  Register Y = MI.getOperand(2).getReg();
  Register Z = MI.getOperand(3).getReg();
  LLT Ty = MRI.getType(Dst);
  LLT ShTy = MRI.getType(Z);

  const unsigned BW = Ty.getScalarSizeInBits();
  const bool IsFSHL = MI.getOpcode() == TargetOpcode::G_FSHL;

  // ShX | ShY forms the result; ShAmt is the effective amount, InvShAmt the
  // complementary amount applied to the other input.
  Register ShX, ShY;
  Register ShAmt, InvShAmt;

  // FIXME: Emit optimized urem by constant instead of letting it expand later.
  if (isNonZeroModBitWidthOrUndef(MRI, Z, BW)) {
    // fshl: X << C | Y >> (BW - C)
    // fshr: X << (BW - C) | Y >> C
    // where C = Z % BW is not zero
    // Since C != 0, both C and BW - C are in [1, BW-1], so the direct form is
    // safe with no extra guard shifts.
    auto BitWidthC = MIRBuilder.buildConstant(ShTy, BW);
    ShAmt = MIRBuilder.buildURem(ShTy, Z, BitWidthC).getReg(0);
    InvShAmt = MIRBuilder.buildSub(ShTy, BitWidthC, ShAmt).getReg(0);
    ShX = MIRBuilder.buildShl(Ty, X, IsFSHL ? ShAmt : InvShAmt).getReg(0);
    ShY = MIRBuilder.buildLShr(Ty, Y, IsFSHL ? InvShAmt : ShAmt).getReg(0);
  } else {
    // Z % BW may be zero, so use BW - 1 - (Z % BW) (always in [0, BW-1]) as
    // the complementary amount and fold the remaining 1 into an extra
    // constant shift of the input:
    // fshl: X << (Z % BW) | Y >> 1 >> (BW - 1 - (Z % BW))
    // fshr: X << 1 << (BW - 1 - (Z % BW)) | Y >> (Z % BW)
    auto Mask = MIRBuilder.buildConstant(ShTy, BW - 1);
    if (isPowerOf2_32(BW)) {
      // Z % BW -> Z & (BW - 1)
      ShAmt = MIRBuilder.buildAnd(ShTy, Z, Mask).getReg(0);
      // (BW - 1) - (Z % BW) -> ~Z & (BW - 1)
      auto NotZ = MIRBuilder.buildNot(ShTy, Z);
      InvShAmt = MIRBuilder.buildAnd(ShTy, NotZ, Mask).getReg(0);
    } else {
      // Non-power-of-2 width: need a real urem and subtraction.
      auto BitWidthC = MIRBuilder.buildConstant(ShTy, BW);
      ShAmt = MIRBuilder.buildURem(ShTy, Z, BitWidthC).getReg(0);
      InvShAmt = MIRBuilder.buildSub(ShTy, Mask, ShAmt).getReg(0);
    }

    auto One = MIRBuilder.buildConstant(ShTy, 1);
    if (IsFSHL) {
      ShX = MIRBuilder.buildShl(Ty, X, ShAmt).getReg(0);
      auto ShY1 = MIRBuilder.buildLShr(Ty, Y, One);
      ShY = MIRBuilder.buildLShr(Ty, ShY1, InvShAmt).getReg(0);
    } else {
      auto ShX1 = MIRBuilder.buildShl(Ty, X, One);
      ShX = MIRBuilder.buildShl(Ty, ShX1, InvShAmt).getReg(0);
      ShY = MIRBuilder.buildLShr(Ty, Y, ShAmt).getReg(0);
    }
  }

  MIRBuilder.buildOr(Dst, ShX, ShY);
  MI.eraseFromParent();
  return Legalized;
}
5321+
5322+
LegalizerHelper::LegalizeResult
LegalizerHelper::lowerFunnelShift(MachineInstr &MI) {
  // These operations approximately do the following (while avoiding undefined
  // shifts by BW):
  // G_FSHL: (X << (Z % BW)) | (Y >> (BW - (Z % BW)))
  // G_FSHR: (X << (BW - (Z % BW))) | (Y >> (Z % BW))
  const bool IsLeft = MI.getOpcode() == TargetOpcode::G_FSHL;
  const unsigned InverseOpc =
      IsLeft ? TargetOpcode::G_FSHR : TargetOpcode::G_FSHL;

  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
  LLT AmtTy = MRI.getType(MI.getOperand(3).getReg());

  // Prefer rewriting in terms of the opposite-direction funnel shift; only
  // expand directly to shifts when the inverse would itself be lowered.
  const bool InverseIsLowered =
      LI.getAction({InverseOpc, {DstTy, AmtTy}}).Action == Lower;
  return InverseIsLowered ? lowerFunnelShiftAsShifts(MI)
                          : lowerFunnelShiftWithInverse(MI);
}
5338+
52105339
// Expand s32 = G_UITOFP s64 using bit operations to an IEEE float
52115340
// representation.
52125341
LegalizerHelper::LegalizeResult

llvm/lib/CodeGen/GlobalISel/Utils.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -926,6 +926,38 @@ Optional<RegOrConstant> llvm::getVectorSplat(const MachineInstr &MI,
926926
return RegOrConstant(Reg);
927927
}
928928

929+
bool llvm::matchUnaryPredicate(
930+
const MachineRegisterInfo &MRI, Register Reg,
931+
std::function<bool(const Constant *ConstVal)> Match, bool AllowUndefs) {
932+
933+
const MachineInstr *Def = getDefIgnoringCopies(Reg, MRI);
934+
if (AllowUndefs && Def->getOpcode() == TargetOpcode::G_IMPLICIT_DEF)
935+
return Match(nullptr);
936+
937+
// TODO: Also handle fconstant
938+
if (Def->getOpcode() == TargetOpcode::G_CONSTANT)
939+
return Match(Def->getOperand(1).getCImm());
940+
941+
if (Def->getOpcode() != TargetOpcode::G_BUILD_VECTOR)
942+
return false;
943+
944+
for (unsigned I = 1, E = Def->getNumOperands(); I != E; ++I) {
945+
Register SrcElt = Def->getOperand(I).getReg();
946+
const MachineInstr *SrcDef = getDefIgnoringCopies(SrcElt, MRI);
947+
if (AllowUndefs && SrcDef->getOpcode() == TargetOpcode::G_IMPLICIT_DEF) {
948+
if (!Match(nullptr))
949+
return false;
950+
continue;
951+
}
952+
953+
if (SrcDef->getOpcode() != TargetOpcode::G_CONSTANT ||
954+
!Match(SrcDef->getOperand(1).getCImm()))
955+
return false;
956+
}
957+
958+
return true;
959+
}
960+
929961
bool llvm::isConstTrueVal(const TargetLowering &TLI, int64_t Val, bool IsVector,
930962
bool IsFP) {
931963
switch (TLI.getBooleanContents(IsVector, IsFP)) {

llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1595,11 +1595,26 @@ AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
15951595
.clampScalar(0, S32, S64)
15961596
.lower();
15971597

1598+
// TODO: Only try to form v2s16 with legal packed instructions.
15981599
getActionDefinitionsBuilder(G_FSHR)
15991600
.legalFor({{S32, S32}})
1601+
.lowerFor({{V2S16, V2S16}})
1602+
.fewerElementsIf(elementTypeIs(0, S16), changeTo(0, V2S16))
16001603
.scalarize(0)
16011604
.lower();
16021605

1606+
if (ST.hasVOP3PInsts()) {
1607+
getActionDefinitionsBuilder(G_FSHL)
1608+
.lowerFor({{V2S16, V2S16}})
1609+
.fewerElementsIf(elementTypeIs(0, S16), changeTo(0, V2S16))
1610+
.scalarize(0)
1611+
.lower();
1612+
} else {
1613+
getActionDefinitionsBuilder(G_FSHL)
1614+
.scalarize(0)
1615+
.lower();
1616+
}
1617+
16031618
getActionDefinitionsBuilder(G_READCYCLECOUNTER)
16041619
.legalFor({S64});
16051620

@@ -1624,9 +1639,7 @@ AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
16241639
G_SADDO, G_SSUBO,
16251640

16261641
// TODO: Implement
1627-
G_FMINIMUM, G_FMAXIMUM,
1628-
G_FSHL
1629-
}).lower();
1642+
G_FMINIMUM, G_FMAXIMUM}).lower();
16301643

16311644
getActionDefinitionsBuilder({G_VASTART, G_VAARG, G_BRJT, G_JUMP_TABLE,
16321645
G_INDEXED_LOAD, G_INDEXED_SEXTLOAD,

0 commit comments

Comments
 (0)