@@ -377,6 +377,7 @@ namespace {
unsigned PosOpcode, unsigned NegOpcode,
const SDLoc &DL);
SDNode *MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL);
+ SDValue MatchLoadCombine(SDNode *N);
SDValue ReduceLoadWidth(SDNode *N);
SDValue ReduceLoadOpStoreWidth(SDNode *N);
SDValue splitMergedValStore(StoreSDNode *ST);
@@ -3985,6 +3986,9 @@ SDValue DAGCombiner::visitOR(SDNode *N) {
if (SDNode *Rot = MatchRotate(N0, N1, SDLoc(N)))
return SDValue(Rot, 0);

+ if (SDValue Load = MatchLoadCombine(N))
+ return Load;
+
// Simplify the operands using demanded-bits information.
if (!VT.isVector() &&
SimplifyDemandedBits(SDValue(N, 0)))
@@ -4356,6 +4360,270 @@ struct BaseIndexOffset {
};
} // namespace

+ namespace {
+ /// Represents known origin of an individual byte in load combine pattern. The
+ /// value of the byte is either constant zero or comes from memory.
+ struct ByteProvider {
+ // For constant zero providers Load is set to nullptr. For memory providers
+ // Load represents the node which loads the byte from memory.
+ // ByteOffset is the offset of the byte in the value produced by the load.
+ LoadSDNode *Load;
+ unsigned ByteOffset;
+
+ ByteProvider() : Load(nullptr), ByteOffset(0) {}
+
+ static ByteProvider getMemory(LoadSDNode *Load, unsigned ByteOffset) {
+ return ByteProvider(Load, ByteOffset);
+ }
+ static ByteProvider getConstantZero() { return ByteProvider(nullptr, 0); }
+
+ bool isConstantZero() { return !Load; }
+ bool isMemory() { return Load; }
+
+ bool operator==(const ByteProvider &Other) const {
+ return Other.Load == Load && Other.ByteOffset == ByteOffset;
+ }
+
+ private:
+ ByteProvider(LoadSDNode *Load, unsigned ByteOffset)
+ : Load(Load), ByteOffset(ByteOffset) {}
+ };
+
+ /// Recursively traverses the expression calculating the origin of the requested
+ /// byte of the given value. Returns None if the provider can't be calculated.
+ ///
+ /// For all the values except the root of the expression verifies that the value
+ /// has exactly one use and if it's not true return None. This way if the origin
+ /// of the byte is returned it's guaranteed that the values which contribute to
+ /// the byte are not used outside of this expression.
+ ///
+ /// Because the parts of the expression are not allowed to have more than one
+ /// use this function iterates over trees, not DAGs. So it never visits the same
+ /// node more than once.
+ const Optional<ByteProvider> calculateByteProvider(SDValue Op, unsigned Index,
+ unsigned Depth,
+ bool Root = false) {
+ // Typical i64 by i8 pattern requires recursion up to 8 calls depth
+ if (Depth == 10)
+ return None;
+
+ if (!Root && !Op.hasOneUse())
+ return None;
+
+ assert(Op.getValueType().isScalarInteger() && "can't handle other types");
+ unsigned BitWidth = Op.getValueSizeInBits();
+ if (BitWidth % 8 != 0)
+ return None;
+ unsigned ByteWidth = BitWidth / 8;
+ assert(Index < ByteWidth && "invalid index requested");
+
+ switch (Op.getOpcode()) {
+ case ISD::OR: {
+ auto LHS = calculateByteProvider(Op->getOperand(0), Index, Depth + 1);
+ if (!LHS)
+ return None;
+ auto RHS = calculateByteProvider(Op->getOperand(1), Index, Depth + 1);
+ if (!RHS)
+ return None;
+
+ if (LHS->isConstantZero())
+ return RHS;
+ else if (RHS->isConstantZero())
+ return LHS;
+ else
+ return None;
+ }
+ case ISD::SHL: {
+ auto ShiftOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
+ if (!ShiftOp)
+ return None;
+
+ uint64_t BitShift = ShiftOp->getZExtValue();
+ if (BitShift % 8 != 0)
+ return None;
+ uint64_t ByteShift = BitShift / 8;
+
+ return Index < ByteShift
+ ? ByteProvider::getConstantZero()
+ : calculateByteProvider(Op->getOperand(0), Index - ByteShift,
+ Depth + 1);
+ }
+ case ISD::ZERO_EXTEND: {
+ SDValue NarrowOp = Op->getOperand(0);
+ unsigned NarrowBitWidth = NarrowOp.getScalarValueSizeInBits();
+ if (NarrowBitWidth % 8 != 0)
+ return None;
+ uint64_t NarrowByteWidth = NarrowBitWidth / 8;
+
+ return Index >= NarrowByteWidth
+ ? ByteProvider::getConstantZero()
+ : calculateByteProvider(NarrowOp, Index, Depth + 1);
+ }
+ case ISD::LOAD: {
+ auto L = cast<LoadSDNode>(Op.getNode());
+
+ // TODO: support ext loads
+ if (L->isVolatile() || L->isIndexed() ||
+ L->getExtensionType() != ISD::NON_EXTLOAD)
+ return None;
+
+ return ByteProvider::getMemory(L, Index);
+ }
+ }
+
+ return None;
+ }
+ } // namespace
+
+ /// Match a pattern where a wide type scalar value is loaded by several narrow
+ /// loads and combined by shifts and ors. Fold it into a single load or a load
+ /// and a BSWAP if the target supports it.
+ ///
+ /// Assuming little endian target:
+ /// i8 *a = ...
+ /// i32 val = a[0] | (a[1] << 8) | (a[2] << 16) | (a[3] << 24)
+ /// =>
+ /// i32 val = *((i32)a)
+ ///
+ /// i8 *a = ...
+ /// i32 val = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3]
+ /// =>
+ /// i32 val = BSWAP(*((i32)a))
+ ///
+ /// TODO: This rule matches complex patterns with OR node roots and doesn't
+ /// interact well with the worklist mechanism. When a part of the pattern is
+ /// updated (e.g. one of the loads) its direct users are put into the worklist,
+ /// but the root node of the pattern which triggers the load combine is not
+ /// necessarily a direct user of the changed node. For example, once the address
+ /// of the t28 load is reassociated, load combine won't be triggered:
+ /// t25: i32 = add t4, Constant:i32<2>
+ /// t26: i64 = sign_extend t25
+ /// t27: i64 = add t2, t26
+ /// t28: i8,ch = load<LD1[%tmp9]> t0, t27, undef:i64
+ /// t29: i32 = zero_extend t28
+ /// t32: i32 = shl t29, Constant:i8<8>
+ /// t33: i32 = or t23, t32
+ /// As a possible fix visitLoad can check if the load can be a part of a load
+ /// combine pattern and add corresponding OR roots to the worklist.
+ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
+ assert(N->getOpcode() == ISD::OR &&
+ "Can only match load combining against OR nodes");
+
+ // Handles simple types only
+ EVT VT = N->getValueType(0);
+ if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
+ return SDValue();
+ unsigned ByteWidth = VT.getSizeInBits() / 8;
+
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+ // Before legalize we can introduce too wide illegal loads which will be later
+ // split into legal sized loads. This enables us to combine i64 load by i8
+ // patterns to a couple of i32 loads on 32 bit targets.
+ if (LegalOperations && !TLI.isOperationLegal(ISD::LOAD, VT))
+ return SDValue();
+
+ auto LittleEndianByteAt = [](unsigned BW, unsigned i) { return i; };
+ auto BigEndianByteAt = [](unsigned BW, unsigned i) { return BW - i - 1; };
+
+ Optional<BaseIndexOffset> Base;
+ SDValue Chain;
+
+ SmallSet<LoadSDNode *, 8> Loads;
+ LoadSDNode *FirstLoad = nullptr;
+
+ bool IsBigEndianTarget = DAG.getDataLayout().isBigEndian();
+ auto ByteAt = IsBigEndianTarget ? BigEndianByteAt : LittleEndianByteAt;
+
+ // Check if all the bytes of the OR we are looking at are loaded from the same
+ // base address. Collect bytes offsets from Base address in ByteOffsets.
+ SmallVector<int64_t, 4> ByteOffsets(ByteWidth);
+ for (unsigned i = 0; i < ByteWidth; i++) {
+ auto P = calculateByteProvider(SDValue(N, 0), i, 0, /*Root=*/true);
+ if (!P || !P->isMemory()) // All the bytes must be loaded from memory
+ return SDValue();
+
+ LoadSDNode *L = P->Load;
+ assert(L->hasNUsesOfValue(1, 0) && !L->isVolatile() && !L->isIndexed() &&
+ (L->getExtensionType() == ISD::NON_EXTLOAD) &&
+ "Must be enforced by calculateByteProvider");
+ assert(L->getOffset().isUndef() && "Unindexed load must have undef offset");
+
+ // All loads must share the same chain
+ SDValue LChain = L->getChain();
+ if (!Chain)
+ Chain = LChain;
+ else if (Chain != LChain)
+ return SDValue();
+
+ // Loads must share the same base address
+ BaseIndexOffset Ptr = BaseIndexOffset::match(L->getBasePtr(), DAG);
+ if (!Base)
+ Base = Ptr;
+ else if (!Base->equalBaseIndex(Ptr))
+ return SDValue();
+
+ // Calculate the offset of the current byte from the base address
+ unsigned LoadBitWidth = L->getMemoryVT().getSizeInBits();
+ assert(LoadBitWidth % 8 == 0 &&
+ "can only analyze providers for individual bytes not bit");
+ unsigned LoadByteWidth = LoadBitWidth / 8;
+ int64_t MemoryByteOffset = ByteAt(LoadByteWidth, P->ByteOffset);
+ int64_t ByteOffsetFromBase = Ptr.Offset + MemoryByteOffset;
+ ByteOffsets[i] = ByteOffsetFromBase;
+
+ // Remember the first byte load
+ if (ByteOffsetFromBase == 0)
+ FirstLoad = L;
+
+ Loads.insert(L);
+ }
+ assert(Loads.size() > 0 && "All the bytes of the value must be loaded from "
+ "memory, so there must be at least one load which produces the value");
+ assert(Base && "Base address of the accessed memory location must be set");
+
+ // Check if the bytes of the OR we are looking at match with either big or
+ // little endian value load
+ bool BigEndian = true, LittleEndian = true;
+ for (unsigned i = 0; i < ByteWidth; i++) {
+ LittleEndian &= ByteOffsets[i] == LittleEndianByteAt(ByteWidth, i);
+ BigEndian &= ByteOffsets[i] == BigEndianByteAt(ByteWidth, i);
+ if (!BigEndian && !LittleEndian)
+ return SDValue();
+ }
+ assert((BigEndian != LittleEndian) && "should be either or");
+ assert(FirstLoad && "must be set");
+
+ // The node we are looking at matches with the pattern, check if we can
+ // replace it with a single load and bswap if needed.
+
+ // If the load needs byte swap check if the target supports it
+ bool NeedsBswap = IsBigEndianTarget != BigEndian;
+
+ // Before legalize we can introduce illegal bswaps which will be later
+ // converted to an explicit bswap sequence. This way we end up with a single
+ // load and byte shuffling instead of several loads and byte shuffling.
+ if (NeedsBswap && LegalOperations && !TLI.isOperationLegal(ISD::BSWAP, VT))
+ return SDValue();
+
+ // Check that a load of the wide type is both allowed and fast on the target
+ bool Fast = false;
+ bool Allowed = TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(),
+ VT, FirstLoad->getAddressSpace(),
+ FirstLoad->getAlignment(), &Fast);
+ if (!Allowed || !Fast)
+ return SDValue();
+
+ SDValue NewLoad =
+ DAG.getLoad(VT, SDLoc(N), Chain, FirstLoad->getBasePtr(),
+ FirstLoad->getPointerInfo(), FirstLoad->getAlignment());
+
+ // Transfer chain users from old loads to the new load.
+ for (LoadSDNode *L : Loads)
+ DAG.ReplaceAllUsesOfValueWith(SDValue(L, 1), SDValue(NewLoad.getNode(), 1));
+
+ return NeedsBswap ? DAG.getNode(ISD::BSWAP, SDLoc(N), VT, NewLoad) : NewLoad;
+ }
+
SDValue DAGCombiner::visitXOR(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
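
The two folds described in the MatchLoadCombine doc comment correspond to source code along the following lines. This is only an illustrative, self-contained C++ sketch; the function names are not part of the patch, and it assumes `a` points to at least four readable bytes.

#include <cstdint>

// Illustrative example (not from the patch): on a little-endian target the
// four i8 loads, shifts and ors below are folded into a single i32 load.
uint32_t load32_le(const uint8_t *a) {
  return (uint32_t)a[0] | ((uint32_t)a[1] << 8) |
         ((uint32_t)a[2] << 16) | ((uint32_t)a[3] << 24);
}

// The same bytes composed in big-endian order become a single i32 load
// followed by a BSWAP, provided BSWAP is legal for i32 or we are running
// before legalization.
uint32_t load32_be(const uint8_t *a) {
  return ((uint32_t)a[0] << 24) | ((uint32_t)a[1] << 16) |
         ((uint32_t)a[2] << 8) | (uint32_t)a[3];
}

Which of the two shapes is matched, and whether a BSWAP is emitted, follows from comparing the collected byte offsets against LittleEndianByteAt/BigEndianByteAt for the target's endianness, as done in the loop over ByteOffsets above.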