This repository was archived by the owner on Mar 28, 2020. It is now read-only.

Commit b52af07

[DAGCombiner] Match load by bytes idiom and fold it into a single load. Attempt #2.
The previous patch (https://reviews.llvm.org/rL289538) was reverted because of a bug. Chandler also requested some changes to the algorithm:
http://lists.llvm.org/pipermail/llvm-commits/Week-of-Mon-20161212/413479.html

This is an updated patch. The key difference is that collectBitProviders (renamed to calculateByteProvider) now collects the origin of a single byte, not of the whole value. This simplifies the implementation and allows the traversal to stop earlier when we know the result won't be used.

From the original commit:

Match a pattern where a wide type scalar value is loaded by several narrow loads and combined by shifts and ors. Fold it into a single load, or a load and a bswap if the target supports it.

Assuming a little endian target:

    i8 *a = ...
    i32 val = a[0] | (a[1] << 8) | (a[2] << 16) | (a[3] << 24)
  =>
    i32 val = *((i32*)a)

    i8 *a = ...
    i32 val = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3]
  =>
    i32 val = BSWAP(*((i32*)a))

This optimization was discussed on llvm-dev some time ago in the "Load combine pass" thread. We came to the conclusion that we want to do this transformation late in the pipeline because, in the presence of atomic loads, load widening is an irreversible transformation and it might hinder other optimizations.

Eventually we'd like to support folding patterns like this one, where the offset has a variable and a constant part:

    i32 val = a[i] | (a[i + 1] << 8) | (a[i + 2] << 16) | (a[i + 3] << 24)

Matching the pattern above is easier at the SelectionDAG level, since address reassociation has already happened and the fact that the loads are adjacent is clear. Recognizing that these loads are adjacent at the IR level would have required looking through geps/zexts/adds while examining the addresses.

The general scheme is to match OR expressions by recursively calculating the origin of the individual bytes which constitute the resulting OR value. If all the OR bytes come from memory, verify that they are adjacent and match the little or big endian encoding of a wider value. If so, and if the load of the wider type (and the bswap, if needed) is allowed by the target, generate a load and a bswap if needed.

Reviewed By: RKSimon, filcab, chandlerc

Differential Revision: https://reviews.llvm.org/D27861

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@293036 91177308-0d34-0410-b5e6-96231b3b80d8
1 parent 6bed410 commit b52af07

File tree

6 files changed: +1850 −0 lines


lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 268 additions & 0 deletions
@@ -377,6 +377,7 @@ namespace {
                        unsigned PosOpcode, unsigned NegOpcode,
                        const SDLoc &DL);
     SDNode *MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL);
+    SDValue MatchLoadCombine(SDNode *N);
     SDValue ReduceLoadWidth(SDNode *N);
     SDValue ReduceLoadOpStoreWidth(SDNode *N);
     SDValue splitMergedValStore(StoreSDNode *ST);
@@ -3985,6 +3986,9 @@ SDValue DAGCombiner::visitOR(SDNode *N) {
   if (SDNode *Rot = MatchRotate(N0, N1, SDLoc(N)))
     return SDValue(Rot, 0);
 
+  if (SDValue Load = MatchLoadCombine(N))
+    return Load;
+
   // Simplify the operands using demanded-bits information.
   if (!VT.isVector() &&
       SimplifyDemandedBits(SDValue(N, 0)))
@@ -4356,6 +4360,270 @@ struct BaseIndexOffset {
 };
 } // namespace
 
+namespace {
+/// Represents known origin of an individual byte in load combine pattern. The
+/// value of the byte is either constant zero or comes from memory.
+struct ByteProvider {
+  // For constant zero providers Load is set to nullptr. For memory providers
+  // Load represents the node which loads the byte from memory.
+  // ByteOffset is the offset of the byte in the value produced by the load.
+  LoadSDNode *Load;
+  unsigned ByteOffset;
+
+  ByteProvider() : Load(nullptr), ByteOffset(0) {}
+
+  static ByteProvider getMemory(LoadSDNode *Load, unsigned ByteOffset) {
+    return ByteProvider(Load, ByteOffset);
+  }
+  static ByteProvider getConstantZero() { return ByteProvider(nullptr, 0); }
+
+  bool isConstantZero() { return !Load; }
+  bool isMemory() { return Load; }
+
+  bool operator==(const ByteProvider &Other) const {
+    return Other.Load == Load && Other.ByteOffset == ByteOffset;
+  }
+
+private:
+  ByteProvider(LoadSDNode *Load, unsigned ByteOffset)
+      : Load(Load), ByteOffset(ByteOffset) {}
+};
+
+/// Recursively traverses the expression calculating the origin of the requested
+/// byte of the given value. Returns None if the provider can't be calculated.
+///
+/// For all the values except the root of the expression verifies that the value
+/// has exactly one use and if it's not true return None. This way if the origin
+/// of the byte is returned it's guaranteed that the values which contribute to
+/// the byte are not used outside of this expression.
+///
+/// Because the parts of the expression are not allowed to have more than one
+/// use this function iterates over trees, not DAGs. So it never visits the same
+/// node more than once.
+const Optional<ByteProvider> calculateByteProvider(SDValue Op, unsigned Index,
+                                                   unsigned Depth,
+                                                   bool Root = false) {
+  // Typical i64 by i8 pattern requires recursion up to 8 calls depth
+  if (Depth == 10)
+    return None;
+
+  if (!Root && !Op.hasOneUse())
+    return None;
+
+  assert(Op.getValueType().isScalarInteger() && "can't handle other types");
+  unsigned BitWidth = Op.getValueSizeInBits();
+  if (BitWidth % 8 != 0)
+    return None;
+  unsigned ByteWidth = BitWidth / 8;
+  assert(Index < ByteWidth && "invalid index requested");
+
+  switch (Op.getOpcode()) {
+  case ISD::OR: {
+    auto LHS = calculateByteProvider(Op->getOperand(0), Index, Depth + 1);
+    if (!LHS)
+      return None;
+    auto RHS = calculateByteProvider(Op->getOperand(1), Index, Depth + 1);
+    if (!RHS)
+      return None;
+
+    if (LHS->isConstantZero())
+      return RHS;
+    else if (RHS->isConstantZero())
+      return LHS;
+    else
+      return None;
+  }
+  case ISD::SHL: {
+    auto ShiftOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
+    if (!ShiftOp)
+      return None;
+
+    uint64_t BitShift = ShiftOp->getZExtValue();
+    if (BitShift % 8 != 0)
+      return None;
+    uint64_t ByteShift = BitShift / 8;
+
+    return Index < ByteShift
+               ? ByteProvider::getConstantZero()
+               : calculateByteProvider(Op->getOperand(0), Index - ByteShift,
+                                       Depth + 1);
+  }
+  case ISD::ZERO_EXTEND: {
+    SDValue NarrowOp = Op->getOperand(0);
+    unsigned NarrowBitWidth = NarrowOp.getScalarValueSizeInBits();
+    if (NarrowBitWidth % 8 != 0)
+      return None;
+    uint64_t NarrowByteWidth = NarrowBitWidth / 8;
+
+    return Index >= NarrowByteWidth
+               ? ByteProvider::getConstantZero()
+               : calculateByteProvider(NarrowOp, Index, Depth + 1);
+  }
+  case ISD::LOAD: {
+    auto L = cast<LoadSDNode>(Op.getNode());
+
+    // TODO: support ext loads
+    if (L->isVolatile() || L->isIndexed() ||
+        L->getExtensionType() != ISD::NON_EXTLOAD)
+      return None;
+
+    return ByteProvider::getMemory(L, Index);
+  }
+  }
+
+  return None;
+}
+} // namespace
+
+/// Match a pattern where a wide type scalar value is loaded by several narrow
+/// loads and combined by shifts and ors. Fold it into a single load or a load
+/// and a BSWAP if the targets supports it.
+///
+/// Assuming little endian target:
+///  i8 *a = ...
+///  i32 val = a[0] | (a[1] << 8) | (a[2] << 16) | (a[3] << 24)
+/// =>
+///  i32 val = *((i32)a)
+///
+///  i8 *a = ...
+///  i32 val = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3]
+/// =>
+///  i32 val = BSWAP(*((i32)a))
+///
+/// TODO: This rule matches complex patterns with OR node roots and doesn't
+/// interact well with the worklist mechanism. When a part of the pattern is
+/// updated (e.g. one of the loads) its direct users are put into the worklist,
+/// but the root node of the pattern which triggers the load combine is not
+/// necessarily a direct user of the changed node. For example, once the address
+/// of t28 load is reassociated load combine won't be triggered:
+///             t25: i32 = add t4, Constant:i32<2>
+///           t26: i64 = sign_extend t25
+///         t27: i64 = add t2, t26
+///       t28: i8,ch = load<LD1[%tmp9]> t0, t27, undef:i64
+///     t29: i32 = zero_extend t28
+///   t32: i32 = shl t29, Constant:i8<8>
+/// t33: i32 = or t23, t32
+/// As a possible fix visitLoad can check if the load can be a part of a load
+/// combine pattern and add corresponding OR roots to the worklist.
+SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
+  assert(N->getOpcode() == ISD::OR &&
+         "Can only match load combining against OR nodes");
+
+  // Handles simple types only
+  EVT VT = N->getValueType(0);
+  if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
+    return SDValue();
+  unsigned ByteWidth = VT.getSizeInBits() / 8;
+
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  // Before legalize we can introduce too wide illegal loads which will be later
+  // split into legal sized loads. This enables us to combine i64 load by i8
+  // patterns to a couple of i32 loads on 32 bit targets.
+  if (LegalOperations && !TLI.isOperationLegal(ISD::LOAD, VT))
+    return SDValue();
+
+  auto LittleEndianByteAt = [](unsigned BW, unsigned i) { return i; };
+  auto BigEndianByteAt = [](unsigned BW, unsigned i) { return BW - i - 1; };
+
+  Optional<BaseIndexOffset> Base;
+  SDValue Chain;
+
+  SmallSet<LoadSDNode *, 8> Loads;
+  LoadSDNode *FirstLoad = nullptr;
+
+  bool IsBigEndianTarget = DAG.getDataLayout().isBigEndian();
+  auto ByteAt = IsBigEndianTarget ? BigEndianByteAt : LittleEndianByteAt;
+
+  // Check if all the bytes of the OR we are looking at are loaded from the same
+  // base address. Collect bytes offsets from Base address in ByteOffsets.
+  SmallVector<int64_t, 4> ByteOffsets(ByteWidth);
+  for (unsigned i = 0; i < ByteWidth; i++) {
+    auto P = calculateByteProvider(SDValue(N, 0), i, 0, /*Root=*/true);
+    if (!P || !P->isMemory()) // All the bytes must be loaded from memory
+      return SDValue();
+
+    LoadSDNode *L = P->Load;
+    assert(L->hasNUsesOfValue(1, 0) && !L->isVolatile() && !L->isIndexed() &&
+           (L->getExtensionType() == ISD::NON_EXTLOAD) &&
+           "Must be enforced by calculateByteProvider");
+    assert(L->getOffset().isUndef() && "Unindexed load must have undef offset");
+
+    // All loads must share the same chain
+    SDValue LChain = L->getChain();
+    if (!Chain)
+      Chain = LChain;
+    else if (Chain != LChain)
+      return SDValue();
+
+    // Loads must share the same base address
+    BaseIndexOffset Ptr = BaseIndexOffset::match(L->getBasePtr(), DAG);
+    if (!Base)
+      Base = Ptr;
+    else if (!Base->equalBaseIndex(Ptr))
+      return SDValue();
+
+    // Calculate the offset of the current byte from the base address
+    unsigned LoadBitWidth = L->getMemoryVT().getSizeInBits();
+    assert(LoadBitWidth % 8 == 0 &&
+           "can only analyze providers for individual bytes not bit");
+    unsigned LoadByteWidth = LoadBitWidth / 8;
+    int64_t MemoryByteOffset = ByteAt(LoadByteWidth, P->ByteOffset);
+    int64_t ByteOffsetFromBase = Ptr.Offset + MemoryByteOffset;
+    ByteOffsets[i] = ByteOffsetFromBase;
+
+    // Remember the first byte load
+    if (ByteOffsetFromBase == 0)
+      FirstLoad = L;
+
+    Loads.insert(L);
+  }
+  assert(Loads.size() > 0 && "All the bytes of the value must be loaded from "
+         "memory, so there must be at least one load which produces the value");
+  assert(Base && "Base address of the accessed memory location must be set");
+
+  // Check if the bytes of the OR we are looking at match with either big or
+  // little endian value load
+  bool BigEndian = true, LittleEndian = true;
+  for (unsigned i = 0; i < ByteWidth; i++) {
+    LittleEndian &= ByteOffsets[i] == LittleEndianByteAt(ByteWidth, i);
+    BigEndian &= ByteOffsets[i] == BigEndianByteAt(ByteWidth, i);
+    if (!BigEndian && !LittleEndian)
+      return SDValue();
+  }
+  assert((BigEndian != LittleEndian) && "should be either or");
+  assert(FirstLoad && "must be set");
+
+  // The node we are looking at matches with the pattern, check if we can
+  // replace it with a single load and bswap if needed.
+
+  // If the load needs byte swap check if the target supports it
+  bool NeedsBswap = IsBigEndianTarget != BigEndian;
+
+  // Before legalize we can introduce illegal bswaps which will be later
+  // converted to an explicit bswap sequence. This way we end up with a single
+  // load and byte shuffling instead of several loads and byte shuffling.
+  if (NeedsBswap && LegalOperations && !TLI.isOperationLegal(ISD::BSWAP, VT))
+    return SDValue();
+
+  // Check that a load of the wide type is both allowed and fast on the target
+  bool Fast = false;
+  bool Allowed = TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(),
+                                        VT, FirstLoad->getAddressSpace(),
+                                        FirstLoad->getAlignment(), &Fast);
+  if (!Allowed || !Fast)
+    return SDValue();
+
+  SDValue NewLoad =
+      DAG.getLoad(VT, SDLoc(N), Chain, FirstLoad->getBasePtr(),
+                  FirstLoad->getPointerInfo(), FirstLoad->getAlignment());
+
+  // Transfer chain users from old loads to the new load.
+  for (LoadSDNode *L : Loads)
+    DAG.ReplaceAllUsesOfValueWith(SDValue(L, 1), SDValue(NewLoad.getNode(), 1));
+
+  return NeedsBswap ? DAG.getNode(ISD::BSWAP, SDLoc(N), VT, NewLoad) : NewLoad;
+}
+
 SDValue DAGCombiner::visitXOR(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
