@@ -12357,6 +12357,193 @@ SDValue SITargetLowering::tryFoldToMad64_32(SDNode *N,
12357
12357
return Accum;
12358
12358
}
12359
12359
12360
+ // Collect the ultimate src of each of the mul24 node's operands, and confirm
12361
+ // each operand is 8 bytes.
12362
+ static std::optional<ByteProvider<SDValue>>
12363
+ handleMulOperand(const SDValue &MulOperand) {
12364
+ auto Byte0 = calculateByteProvider(MulOperand, 0, 0);
12365
+ if (!Byte0 || Byte0->isConstantZero()) {
12366
+ return std::nullopt;
12367
+ }
12368
+ auto Byte1 = calculateByteProvider(MulOperand, 1, 0);
12369
+ if (Byte1 && !Byte1->isConstantZero()) {
12370
+ return std::nullopt;
12371
+ }
12372
+ return Byte0;
12373
+ }
12374
+
12375
// Merge two byte-permute masks into one.
//
// Each of the four byte lanes holds a selector. The 0x0c bits of a lane are
// handled separately from the remaining bits:
//   * the non-0x0c bits of both masks are ORed together;
//   * the 0x0c bits survive only where they are set in both masks.
//
// Precondition (asserted): in every byte lane, at least one of the two masks
// has some of its 0x0c bits set.
static unsigned addPermMasks(unsigned First, unsigned Second) {
  const unsigned ZeroSelBits = 0x0c0c0c0c;

  const unsigned FirstZeroBits = First & ZeroSelBits;
  const unsigned SecondZeroBits = Second & ZeroSelBits;

  // Per-lane check that one side or the other carries 0x0c bits.
  const unsigned EitherZeroBits = FirstZeroBits | SecondZeroBits;
  (void)EitherZeroBits; // Only read by the asserts below.
  assert(EitherZeroBits & 0xFF);
  assert(EitherZeroBits & 0xFF00);
  assert(EitherZeroBits & 0xFF0000);
  assert(EitherZeroBits & 0xFF000000);

  const unsigned MergedSelectors = (First | Second) & ~ZeroSelBits;
  const unsigned SharedZeroBits = FirstZeroBits & SecondZeroBits;
  return MergedSelectors | SharedZeroBits;
}
12388
+
12389
+ static void placeSources(ByteProvider<SDValue> &Src0,
12390
+ ByteProvider<SDValue> &Src1,
12391
+ SmallVectorImpl<std::pair<SDValue, unsigned>> &Src0s,
12392
+ SmallVectorImpl<std::pair<SDValue, unsigned>> &Src1s,
12393
+ int Step) {
12394
+
12395
+ assert(Src0.Src.has_value() && Src1.Src.has_value());
12396
+ // Src0s and Src1s are empty, just place arbitrarily
12397
+ if (Step == 0) {
12398
+ Src0s.push_back({*Src0.Src, (Src0.SrcOffset << 24) + 0x0c0c0c});
12399
+ Src1s.push_back({*Src1.Src, (Src1.SrcOffset << 24) + 0x0c0c0c});
12400
+ return;
12401
+ }
12402
+
12403
+ for (int BPI = 0; BPI < 2; BPI++) {
12404
+ std::pair<ByteProvider<SDValue>, ByteProvider<SDValue>> BPP = {Src0, Src1};
12405
+ if (BPI == 1) {
12406
+ BPP = {Src1, Src0};
12407
+ }
12408
+ unsigned ZeroMask = 0x0c0c0c0c;
12409
+ unsigned FMask = 0xFF << (8 * (3 - Step));
12410
+
12411
+ unsigned FirstMask =
12412
+ BPP.first.SrcOffset << (8 * (3 - Step)) | (ZeroMask & ~FMask);
12413
+ unsigned SecondMask =
12414
+ BPP.second.SrcOffset << (8 * (3 - Step)) | (ZeroMask & ~FMask);
12415
+ // Attempt to find Src vector which contains our SDValue, if so, add our
12416
+ // perm mask to the existing one. If we are unable to find a match for the
12417
+ // first SDValue, attempt to find match for the second.
12418
+ int FirstGroup = -1;
12419
+ for (int I = 0; I < 2; I++) {
12420
+ SmallVectorImpl<std::pair<SDValue, unsigned>> &Srcs =
12421
+ I == 0 ? Src0s : Src1s;
12422
+ auto MatchesFirst = [&BPP](std::pair<SDValue, unsigned> IterElt) {
12423
+ return IterElt.first == *BPP.first.Src;
12424
+ };
12425
+
12426
+ auto Match = std::find_if(Srcs.begin(), Srcs.end(), MatchesFirst);
12427
+ if (Match != Srcs.end()) {
12428
+ Match->second = addPermMasks(FirstMask, Match->second);
12429
+ FirstGroup = I;
12430
+ break;
12431
+ }
12432
+ }
12433
+ if (FirstGroup != -1) {
12434
+ SmallVectorImpl<std::pair<SDValue, unsigned>> &Srcs =
12435
+ FirstGroup == 1 ? Src0s : Src1s;
12436
+ auto MatchesSecond = [&BPP](std::pair<SDValue, unsigned> IterElt) {
12437
+ return IterElt.first == *BPP.second.Src;
12438
+ };
12439
+ auto Match = std::find_if(Srcs.begin(), Srcs.end(), MatchesSecond);
12440
+ if (Match != Srcs.end()) {
12441
+ Match->second = addPermMasks(SecondMask, Match->second);
12442
+ } else
12443
+ Srcs.push_back({*BPP.second.Src, SecondMask});
12444
+ return;
12445
+ }
12446
+ }
12447
+
12448
+ // If we have made it here, then we could not find a match in Src0s or Src1s
12449
+ // for either Src0 or Src1, so just place them arbitrarily.
12450
+
12451
+ unsigned ZeroMask = 0x0c0c0c0c;
12452
+ unsigned FMask = 0xFF << (8 * (3 - Step));
12453
+
12454
+ Src0s.push_back(
12455
+ {*Src0.Src, (Src0.SrcOffset << (8 * (3 - Step)) | (ZeroMask & ~FMask))});
12456
+ Src1s.push_back(
12457
+ {*Src1.Src, (Src1.SrcOffset << (8 * (3 - Step)) | (ZeroMask & ~FMask))});
12458
+
12459
+ return;
12460
+ }
12461
+
12462
+ static SDValue
12463
+ resolveSources(SelectionDAG &DAG, SDLoc SL,
12464
+ SmallVectorImpl<std::pair<SDValue, unsigned>> &Srcs,
12465
+ bool IsSigned, bool IsAny) {
12466
+
12467
+ // If we just have one source, just permute it accordingly.
12468
+ if (Srcs.size() == 1) {
12469
+ auto Elt = Srcs.begin();
12470
+ auto EltVal = DAG.getBitcastedAnyExtOrTrunc(Elt->first, SL, MVT::i32);
12471
+
12472
+ // v_perm will produce the original value
12473
+ if (Elt->second == 0x3020100)
12474
+ return EltVal;
12475
+
12476
+ return DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, EltVal, EltVal,
12477
+ DAG.getConstant(Elt->second, SL, MVT::i32));
12478
+ }
12479
+
12480
+ auto FirstElt = Srcs.begin();
12481
+ auto SecondElt = std::next(FirstElt);
12482
+
12483
+ SmallVector<SDValue, 2> Perms;
12484
+
12485
+ // If we have multiple sources in the chain, combine them via perms (using
12486
+ // calculated perm mask) and Ors.
12487
+ while (true) {
12488
+ auto FirstMask = FirstElt->second;
12489
+ auto SecondMask = SecondElt->second;
12490
+
12491
+ unsigned FirstCs = FirstMask & 0x0c0c0c0c;
12492
+ unsigned FirstPlusFour = FirstMask | 0x04040404;
12493
+ // 0x0c + 0x04 = 0x10, so anding with 0x0F will produced 0x00 for any
12494
+ // original 0x0C
12495
+ FirstMask = (FirstPlusFour & 0x0F0F0F0F) | FirstCs;
12496
+
12497
+ auto PermMask = addPermMasks(FirstMask, SecondMask);
12498
+ auto FirstVal =
12499
+ DAG.getBitcastedAnyExtOrTrunc(FirstElt->first, SL, MVT::i32);
12500
+ auto SecondVal =
12501
+ DAG.getBitcastedAnyExtOrTrunc(SecondElt->first, SL, MVT::i32);
12502
+
12503
+ Perms.push_back(DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, FirstVal,
12504
+ SecondVal,
12505
+ DAG.getConstant(PermMask, SL, MVT::i32)));
12506
+
12507
+ FirstElt = std::next(SecondElt);
12508
+ if (FirstElt == Srcs.end())
12509
+ break;
12510
+
12511
+ SecondElt = std::next(FirstElt);
12512
+ // If we only have a FirstElt, then just combine that into the cumulative
12513
+ // source node
12514
+ if (SecondElt == Srcs.end()) {
12515
+ auto EltVal =
12516
+ DAG.getBitcastedAnyExtOrTrunc(FirstElt->first, SL, MVT::i32);
12517
+
12518
+ Perms.push_back(
12519
+ DAG.getNode(AMDGPUISD::PERM, SL, MVT::i32, EltVal, EltVal,
12520
+ DAG.getConstant(FirstElt->second, SL, MVT::i32)));
12521
+ break;
12522
+ }
12523
+ }
12524
+
12525
+ assert(Perms.size() == 1 || Perms.size() == 2);
12526
+ return Perms.size() == 2
12527
+ ? DAG.getNode(ISD::OR, SL, MVT::i32, Perms[0], Perms[1])
12528
+ : Perms[0];
12529
+ }
12530
+
12531
+ static void fixMasks(SmallVectorImpl<std::pair<SDValue, unsigned>> &Srcs,
12532
+ unsigned ChainLength) {
12533
+ for (auto &[EntryVal, EntryMask] : Srcs) {
12534
+ EntryMask = EntryMask >> ((4 - ChainLength) * 8);
12535
+ auto ZeroMask = ChainLength == 2 ? 0x0c0c0000 : 0x0c000000;
12536
+ EntryMask += ZeroMask;
12537
+ }
12538
+ }
12539
+
12540
+ static bool isMul(const SDValue Op) {
12541
+ auto Opcode = Op.getOpcode();
12542
+
12543
+ return (Opcode == ISD::MUL || Opcode == AMDGPUISD::MUL_U24 ||
12544
+ Opcode == AMDGPUISD::MUL_I24);
12545
+ }
12546
+
12360
12547
SDValue SITargetLowering::performAddCombine(SDNode *N,
12361
12548
DAGCombinerInfo &DCI) const {
12362
12549
SelectionDAG &DAG = DCI.DAG;
@@ -12370,14 +12557,140 @@ SDValue SITargetLowering::performAddCombine(SDNode *N,
12370
12557
if (SDValue Folded = tryFoldToMad64_32(N, DCI))
12371
12558
return Folded;
12372
12559
}
12373
-
12374
- return SDValue();
12375
12560
}
12376
12561
12377
12562
if (SDValue V = reassociateScalarOps(N, DAG)) {
12378
12563
return V;
12379
12564
}
12380
12565
12566
+ if ((isMul(LHS) || isMul(RHS)) && Subtarget->hasDot7Insts() &&
12567
+ (Subtarget->hasDot1Insts() || Subtarget->hasDot8Insts())) {
12568
+ SDValue TempNode(N, 0);
12569
+ auto MulIdx = isMul(LHS) ? 0 : 1;
12570
+
12571
+ auto MulOpcode = TempNode.getOperand(MulIdx).getOpcode();
12572
+ bool IsSigned =
12573
+ MulOpcode == AMDGPUISD::MUL_I24 ||
12574
+ (MulOpcode == ISD::MUL &&
12575
+ TempNode->getOperand(MulIdx)->getFlags().hasNoSignedWrap() &&
12576
+ !TempNode->getOperand(MulIdx)->getFlags().hasNoUnsignedWrap());
12577
+ SmallVector<std::pair<SDValue, unsigned>, 4> Src0s;
12578
+ SmallVector<std::pair<SDValue, unsigned>, 4> Src1s;
12579
+ SmallVector<SDValue, 4> Src2s;
12580
+
12581
+ // Match the v_dot4 tree, while collecting src nodes.
12582
+ int ChainLength = 0;
12583
+ for (int I = 0; I < 4; I++) {
12584
+ auto MulIdx = isMul(LHS) ? 0 : isMul(RHS) ? 1 : -1;
12585
+ if (MulIdx == -1)
12586
+ break;
12587
+ auto IterIsSigned =
12588
+ MulOpcode == AMDGPUISD::MUL_I24 ||
12589
+ (MulOpcode == ISD::MUL &&
12590
+ TempNode->getOperand(MulIdx)->getFlags().hasNoSignedWrap() &&
12591
+ !TempNode->getOperand(MulIdx)->getFlags().hasNoUnsignedWrap());
12592
+ if (IterIsSigned != IsSigned) {
12593
+ break;
12594
+ }
12595
+ auto Src0 = handleMulOperand(TempNode->getOperand(MulIdx)->getOperand(0));
12596
+ if (!Src0)
12597
+ break;
12598
+ auto Src1 = handleMulOperand(TempNode->getOperand(MulIdx)->getOperand(1));
12599
+ if (!Src1)
12600
+ break;
12601
+ placeSources(*Src0, *Src1, Src0s, Src1s, I);
12602
+ auto AddIdx = 1 - MulIdx;
12603
+ // Allow the special case where add (add (mul24, 0), mul24) became ->
12604
+ // add (mul24, mul24)
12605
+ if (I == 2 && isMul(TempNode->getOperand(AddIdx))) {
12606
+ Src2s.push_back(TempNode->getOperand(AddIdx));
12607
+ auto Src0 =
12608
+ handleMulOperand(TempNode->getOperand(AddIdx)->getOperand(0));
12609
+ if (!Src0)
12610
+ break;
12611
+ auto Src1 =
12612
+ handleMulOperand(TempNode->getOperand(AddIdx)->getOperand(1));
12613
+ if (!Src1)
12614
+ break;
12615
+ placeSources(*Src0, *Src1, Src0s, Src1s, I + 1);
12616
+ Src2s.push_back(DAG.getConstant(0, SL, MVT::i32));
12617
+ ChainLength = I + 2;
12618
+ break;
12619
+ }
12620
+
12621
+ TempNode = TempNode->getOperand(AddIdx);
12622
+ Src2s.push_back(TempNode);
12623
+ ChainLength = I + 1;
12624
+ if (TempNode->getNumOperands() < 2)
12625
+ break;
12626
+ LHS = TempNode->getOperand(0);
12627
+ RHS = TempNode->getOperand(1);
12628
+ }
12629
+
12630
+ if (ChainLength < 2)
12631
+ return SDValue();
12632
+
12633
+ // Masks were constructed with assumption that we would find a chain of
12634
+ // length 4. If not, then we need to 0 out the MSB bits (via perm mask of
12635
+ // 0x0c) so they do not affect dot calculation.
12636
+ if (ChainLength < 4) {
12637
+ fixMasks(Src0s, ChainLength);
12638
+ fixMasks(Src1s, ChainLength);
12639
+ }
12640
+
12641
+ SDValue Src0, Src1;
12642
+
12643
+ // If we are just using a single source for both, and have permuted the
12644
+ // bytes consistently, we can just use the sources without permuting
12645
+ // (commutation)
12646
+ bool UseOriginalSrc = false;
12647
+ if (ChainLength == 4 && Src0s.size() == 1 && Src1s.size() == 1 &&
12648
+ Src0s.begin()->second == Src1s.begin()->second &&
12649
+ Src0s.begin()->first.getValueSizeInBits() == 32 &&
12650
+ Src1s.begin()->first.getValueSizeInBits() == 32) {
12651
+ SmallVector<unsigned, 4> SrcBytes;
12652
+ auto Src0Mask = Src0s.begin()->second;
12653
+ SrcBytes.push_back(Src0Mask & 0xFF000000);
12654
+ bool UniqueEntries = true;
12655
+ for (auto I = 1; I < 4; I++) {
12656
+ auto NextByte = Src0Mask & (0xFF << ((3 - I) * 8));
12657
+
12658
+ if (is_contained(SrcBytes, NextByte)) {
12659
+ UniqueEntries = false;
12660
+ break;
12661
+ }
12662
+ SrcBytes.push_back(NextByte);
12663
+ }
12664
+
12665
+ if (UniqueEntries) {
12666
+ UseOriginalSrc = true;
12667
+ // Must be 32 bits to enter above conditional
12668
+ assert(Src0s.begin()->first.getValueSizeInBits() == 32);
12669
+ assert(Src1s.begin()->first.getValueSizeInBits() == 32);
12670
+ Src0 = DAG.getBitcast(MVT::getIntegerVT(32), Src0s.begin()->first);
12671
+ Src1 = DAG.getBitcast(MVT::getIntegerVT(32), Src1s.begin()->first);
12672
+ }
12673
+ }
12674
+
12675
+ if (!UseOriginalSrc) {
12676
+ Src0 = resolveSources(DAG, SL, Src0s, false, true);
12677
+ Src1 = resolveSources(DAG, SL, Src1s, false, true);
12678
+ }
12679
+
12680
+ SDValue Src2 =
12681
+ DAG.getExtOrTrunc(IsSigned, Src2s[ChainLength - 1], SL, MVT::i32);
12682
+
12683
+ SDValue IID = DAG.getTargetConstant(IsSigned ? Intrinsic::amdgcn_sdot4
12684
+ : Intrinsic::amdgcn_udot4,
12685
+ SL, MVT::i64);
12686
+
12687
+ assert(!VT.isVector());
12688
+ auto Dot = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, SL, MVT::i32, IID, Src0,
12689
+ Src1, Src2, DAG.getTargetConstant(0, SL, MVT::i1));
12690
+
12691
+ return DAG.getExtOrTrunc(IsSigned, Dot, SL, VT);
12692
+ }
12693
+
12381
12694
if (VT != MVT::i32 || !DCI.isAfterLegalizeDAG())
12382
12695
return SDValue();
12383
12696
0 commit comments