Commit 425e1f1

mshockwave and npanchen committed
[IA][RISCV] Support VP intrinsics in InterleavedAccessPass
Teach InterleavedAccessPass to recognize the following patterns:

- vp.store of an interleaved scalable vector
- Deinterleaving a scalable vector loaded from a vp.load
- Deinterleaving a scalable vector loaded from a vp.strided.load

Upon recognizing these patterns, IA collects the interleaved / deinterleaved operands and delegates them to their respective newly added TLI hooks. For RISC-V, these patterns are lowered into segmented loads/stores, except when we're interleaving constant splats, in which case a unit-stride store is generated.

Right now we only recognize power-of-two (de)interleave cases, in which (de)interleave4/8 are synthesized from a tree of (de)interleave2 operations.

Co-authored-by: Nikolay Panchenko <[email protected]>
1 parent 4ad7a62 commit 425e1f1
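
For illustration, here is a minimal sketch of the factor-2 load pattern the pass now recognizes (hypothetical values and types, not taken from this commit's tests). On RISC-V this can become a segmented load (e.g. vlseg2e32.v) instead of a wide load plus shuffles:

  %wide = call <vscale x 4 x i32> @llvm.vp.load.nxv4i32.p0(ptr %p,
              <vscale x 4 x i1> splat (i1 true), i32 %evl)
  %di = call { <vscale x 2 x i32>, <vscale x 2 x i32> }
            @llvm.vector.deinterleave2.nxv4i32(<vscale x 4 x i32> %wide)
  %even = extractvalue { <vscale x 2 x i32>, <vscale x 2 x i32> } %di, 0
  %odd  = extractvalue { <vscale x 2 x i32>, <vscale x 2 x i32> } %di, 1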

File tree

7 files changed

+1591 −0 lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 42 additions & 0 deletions

@@ -94,6 +94,7 @@ class TargetRegisterClass;
 class TargetRegisterInfo;
 class TargetTransformInfo;
 class Value;
+class VPIntrinsic;

 namespace Sched {

@@ -3152,6 +3153,47 @@ class TargetLoweringBase {
     return false;
   }

+  /// Lower an interleaved load to target specific intrinsics. Return
+  /// true on success.
+  ///
+  /// \p Load is a vp.load instruction.
+  /// \p Mask is a mask value.
+  /// \p DeinterleaveIntrin is the vector.deinterleave intrinsic.
+  /// \p Factor is the interleave factor.
+  /// \p DeinterleaveRes is a list of deinterleaved results.
+  virtual bool lowerInterleavedScalableLoad(
+      VPIntrinsic *Load, Value *Mask, IntrinsicInst *DeinterleaveIntrin,
+      unsigned Factor, ArrayRef<Value *> DeinterleaveRes) const {
+    return false;
+  }
+
+  /// Lower an interleaved store to target specific intrinsics. Return
+  /// true on success.
+  ///
+  /// \p Store is the vp.store instruction.
+  /// \p Mask is a mask value.
+  /// \p InterleaveIntrin is the vector.interleave intrinsic.
+  /// \p Factor is the interleave factor.
+  /// \p InterleaveOps is a list of values being interleaved.
+  virtual bool lowerInterleavedScalableStore(
+      VPIntrinsic *Store, Value *Mask, IntrinsicInst *InterleaveIntrin,
+      unsigned Factor, ArrayRef<Value *> InterleaveOps) const {
+    return false;
+  }
+
+  /// Lower a deinterleave intrinsic to a target specific strided load
+  /// intrinsic. Return true on success.
+  ///
+  /// \p StridedLoad is the vp.strided.load instruction.
+  /// \p DI is the deinterleave intrinsic.
+  /// \p Factor is the interleave factor.
+  /// \p DeinterleaveRes is a list of deinterleaved results.
+  virtual bool lowerDeinterleaveIntrinsicToStridedLoad(
+      VPIntrinsic *StridedLoad, IntrinsicInst *DI, unsigned Factor,
+      ArrayRef<Value *> DeinterleaveRes) const {
+    return false;
+  }
+
   /// Lower a deinterleave intrinsic to a target specific load intrinsic.
   /// Return true on success. Currently only supports
   /// llvm.vector.deinterleave2
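
As a concrete illustration of what the store hook receives, here is a minimal IR sketch (hypothetical names and types, not from this commit) of a factor-2 interleaved vp.store:

  ; %a and %b are interleaved into one wide vector, then stored via vp.store.
  ; IA hands {%a, %b}, the (deinterleaved or null) mask, and Factor = 2 to
  ; lowerInterleavedScalableStore.
  %v = call <vscale x 4 x i32> @llvm.vector.interleave2.nxv4i32(
           <vscale x 2 x i32> %a, <vscale x 2 x i32> %b)
  call void @llvm.vp.store.nxv4i32.p0(<vscale x 4 x i32> %v, ptr %p,
           <vscale x 4 x i1> splat (i1 true), i32 %evl)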

llvm/lib/CodeGen/InterleavedAccessPass.cpp

Lines changed: 283 additions & 0 deletions

@@ -60,6 +60,7 @@
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/PatternMatch.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/Casting.h"
@@ -248,6 +249,186 @@ static bool isReInterleaveMask(ShuffleVectorInst *SVI, unsigned &Factor,
   return false;
 }

+// For a (de)interleave tree like this:
+//
+//   A   C B   D
+//   |___| |___|
+//     |_____|
+//        |
+//     A B C D
+//
+// We will get ABCD at the end while the leaf operands/results
+// are ACBD, which are also what we initially collected in
+// getVectorInterleaveFactor / getVectorDeInterleaveFactor. But TLI
+// hooks (e.g. lowerInterleavedScalableLoad) expect ABCD, so we need
+// to reorder them by interleaving these values.
+static void interleaveLeafValues(SmallVectorImpl<Value *> &Leaves) {
+  unsigned Factor = Leaves.size();
+  assert(isPowerOf2_32(Factor) && Factor <= 8 && Factor > 1);
+
+  if (Factor == 2)
+    return;
+
+  SmallVector<Value *, 8> Buffer;
+  if (Factor == 4) {
+    for (unsigned SrcIdx : {0, 2, 1, 3})
+      Buffer.push_back(Leaves[SrcIdx]);
+  } else {
+    // Factor of 8.
+    //
+    //   A E C G B F D H
+    //   |_| |_| |_| |_|
+    //    |___|   |___|
+    //      |_______|
+    //          |
+    //   A B C D E F G H
+    for (unsigned SrcIdx : {0, 4, 2, 6, 1, 5, 3, 7})
+      Buffer.push_back(Leaves[SrcIdx]);
+  }
+
+  llvm::copy(Buffer, Leaves.begin());
+}
+static unsigned getVectorInterleaveFactor(IntrinsicInst *II,
+                                          SmallVectorImpl<Value *> &Operands) {
+  if (II->getIntrinsicID() != Intrinsic::vector_interleave2)
+    return 0;
+
+  unsigned Factor = 0;
+
+  // Visit with BFS.
+  SmallVector<IntrinsicInst *, 8> Queue;
+  Queue.push_back(II);
+  while (!Queue.empty()) {
+    IntrinsicInst *Current = Queue.front();
+    Queue.erase(Queue.begin());
+
+    for (unsigned I = 0; I < 2; ++I) {
+      Value *Op = Current->getOperand(I);
+      if (auto *OpII = dyn_cast<IntrinsicInst>(Op))
+        if (OpII->getIntrinsicID() == Intrinsic::vector_interleave2) {
+          Queue.push_back(OpII);
+          continue;
+        }
+
+      ++Factor;
+      Operands.push_back(Op);
+    }
+  }
+
+  // Currently we only recognize power-of-two factors.
+  // FIXME: should we assert here instead?
+  if (Factor > 1 && isPowerOf2_32(Factor)) {
+    interleaveLeafValues(Operands);
+    return Factor;
+  }
+  return 0;
+}
+/// Check the interleaved mask.
+///
+/// - if a value within the optional is non-nullptr, the value corresponds to
+///   a deinterleaved mask
+/// - if a value within the optional is nullptr, the value corresponds to an
+///   all-true mask
+/// - return std::nullopt if the mask cannot be deinterleaved
+static std::optional<Value *> getMask(Value *WideMask, unsigned Factor) {
+  using namespace llvm::PatternMatch;
+  if (auto *IMI = dyn_cast<IntrinsicInst>(WideMask)) {
+    SmallVector<Value *, 8> Operands;
+    if (unsigned MaskFactor = getVectorInterleaveFactor(IMI, Operands)) {
+      assert(!Operands.empty());
+      // All the deinterleaved mask operands have to be identical.
+      if (MaskFactor == Factor && llvm::all_equal(Operands))
+        return Operands.front();
+    }
+  }
+  if (match(WideMask, m_AllOnes()))
+    return nullptr;
+  return std::nullopt;
+}
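
A sketch (hypothetical values) of the two mask shapes getMask accepts for Factor = 2:

  ; Interleaved mask: both deinterleaved halves are the same value %m,
  ; so getMask returns %m.
  %wide.mask = call <vscale x 4 x i1> @llvm.vector.interleave2.nxv4i1(
                   <vscale x 2 x i1> %m, <vscale x 2 x i1> %m)

  ; All-true mask: an all-ones constant such as splat (i1 true) matches
  ; m_AllOnes(), and getMask returns nullptr, meaning "all lanes active".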
+static unsigned getVectorDeInterleaveFactor(IntrinsicInst *II,
+                                            SmallVectorImpl<Value *> &Results) {
+  using namespace PatternMatch;
+  if (II->getIntrinsicID() != Intrinsic::vector_deinterleave2 ||
+      !II->hasNUses(2))
+    return 0;
+
+  unsigned Factor = 0;
+
+  // Visit with BFS.
+  SmallVector<IntrinsicInst *, 8> Queue;
+  Queue.push_back(II);
+  while (!Queue.empty()) {
+    IntrinsicInst *Current = Queue.front();
+    Queue.erase(Queue.begin());
+    assert(Current->hasNUses(2));
+
+    unsigned VisitedIdx = 0;
+    for (User *Usr : Current->users()) {
+      // We're playing it safe here and match only expressions consisting
+      // of a perfectly balanced binary tree in which all intermediate
+      // values are used only once.
+      if (!Usr->hasOneUse() || !isa<ExtractValueInst>(Usr))
+        return 0;
+
+      auto *EV = cast<ExtractValueInst>(Usr);
+      ArrayRef<unsigned> Indices = EV->getIndices();
+      if (Indices.size() != 1 || Indices[0] >= 2)
+        return 0;
+
+      // The idea is that we don't want to have two extractvalues
+      // on the same index. So we XOR (index + 1) onto VisitedIdx
+      // such that if there is any duplication, VisitedIdx will be
+      // zero.
+      VisitedIdx ^= Indices[0] + 1;
+      if (!VisitedIdx)
+        return 0;
+      // We have a legal index. At this point we're either going
+      // to continue the traversal or push the leaf values into Results.
+      // But in either case we need to follow the order imposed by
+      // ExtractValue's indices and swap with the last element pushed
+      // into Queue/Results if necessary (this is also one of the main
+      // reasons for using BFS instead of DFS here).
+
+      // When VisitedIdx equals 0b11, we're visiting the last ExtractValue.
+      // So if the current index is 0, we need to swap. Conversely, when
+      // we're either the first visited ExtractValue or the last operand
+      // in Queue/Results has index 0, there is no need to swap.
+      bool SwapWithLast = VisitedIdx == 0b11 && Indices[0] == 0;
+
+      // Continue the traversal.
+      if (match(EV->user_back(),
+                m_Intrinsic<Intrinsic::vector_deinterleave2>()) &&
+          EV->user_back()->hasNUses(2)) {
+        auto *EVUsr = cast<IntrinsicInst>(EV->user_back());
+        if (SwapWithLast)
+          Queue.insert(Queue.end() - 1, EVUsr);
+        else
+          Queue.push_back(EVUsr);
+        continue;
+      }
+
+      // Save the leaf value.
+      if (SwapWithLast)
+        Results.insert(Results.end() - 1, EV);
+      else
+        Results.push_back(EV);
+
+      ++Factor;
+    }
+  }
+
+  // Currently we only recognize power-of-two factors.
+  // FIXME: should we assert here instead?
+  if (Factor > 1 && isPowerOf2_32(Factor)) {
+    interleaveLeafValues(Results);
+    return Factor;
+  }
+  return 0;
+}
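
For reference, a factor-4 deinterleave tree of the shape this BFS accepts (a sketch with hypothetical values):

  ; Each deinterleave2 result feeds exactly two extractvalues, one per index.
  %d0 = call { <vscale x 8 x i32>, <vscale x 8 x i32> }
            @llvm.vector.deinterleave2.nxv16i32(<vscale x 16 x i32> %wide)
  %evens = extractvalue { <vscale x 8 x i32>, <vscale x 8 x i32> } %d0, 0
  %odds  = extractvalue { <vscale x 8 x i32>, <vscale x 8 x i32> } %d0, 1
  %d1 = call { <vscale x 4 x i32>, <vscale x 4 x i32> }
            @llvm.vector.deinterleave2.nxv8i32(<vscale x 8 x i32> %evens)
  %d2 = call { <vscale x 4 x i32>, <vscale x 4 x i32> }
            @llvm.vector.deinterleave2.nxv8i32(<vscale x 8 x i32> %odds)
  ; The four leaf extractvalues of %d1/%d2 are collected as A, C, B, D
  ; and reordered to A, B, C, D by interleaveLeafValues.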
 bool InterleavedAccessImpl::lowerInterleavedLoad(
     LoadInst *LI, SmallSetVector<Instruction *, 32> &DeadInsts) {
   if (!LI->isSimple() || isa<ScalableVectorType>(LI->getType()))
@@ -480,6 +661,81 @@ bool InterleavedAccessImpl::lowerInterleavedStore(

 bool InterleavedAccessImpl::lowerDeinterleaveIntrinsic(
     IntrinsicInst *DI, SmallSetVector<Instruction *, 32> &DeadInsts) {
+  using namespace PatternMatch;
+  SmallVector<Value *, 8> DeInterleaveResults;
+  unsigned Factor = getVectorDeInterleaveFactor(DI, DeInterleaveResults);
+
+  if (auto *VPLoad = dyn_cast<VPIntrinsic>(DI->getOperand(0));
+      Factor && VPLoad) {
+    if (!match(VPLoad, m_OneUse(m_Intrinsic<Intrinsic::vp_load>())))
+      return false;
+
+    // Check the mask operand. Handle both all-true and interleaved masks.
+    Value *WideMask = VPLoad->getOperand(1);
+    std::optional<Value *> Mask = getMask(WideMask, Factor);
+    if (!Mask)
+      return false;
+
+    LLVM_DEBUG(dbgs() << "IA: Found a deinterleave intrinsic: " << *DI << "\n");
+
+    // Since lowerInterleavedLoad expects Shuffles and a LoadInst, use the
+    // special TLI hook to emit a target-specific interleaved instruction.
+    if (!TLI->lowerInterleavedScalableLoad(VPLoad, *Mask, DI, Factor,
+                                           DeInterleaveResults))
+      return false;
+
+    DeadInsts.insert(DI);
+    DeadInsts.insert(VPLoad);
+    return true;
+  }
+
+  // Match
+  //   %x = vp.strided.load  ;; VPStridedLoad
+  //   %y = bitcast %x       ;; BitCast
+  //   %y' = inttoptr %y
+  //   %z = deinterleave %y  ;; DI
+  if (Factor && isa<BitCastInst, IntToPtrInst>(DI->getOperand(0))) {
+    auto *BitCast = cast<Instruction>(DI->getOperand(0));
+    if (!BitCast->hasOneUse())
+      return false;
+
+    Instruction *IntToPtrCast = nullptr;
+    if (auto *BC = dyn_cast<BitCastInst>(BitCast->getOperand(0))) {
+      IntToPtrCast = BitCast;
+      BitCast = BC;
+    }
+
+    // Match that the type is
+    //   <VF x (factor * elementTy)> bitcast to <(VF * factor) x elementTy>
+    Value *BitCastSrc = BitCast->getOperand(0);
+    auto *BitCastSrcTy = dyn_cast<VectorType>(BitCastSrc->getType());
+    auto *BitCastDstTy = cast<VectorType>(BitCast->getType());
+    if (!BitCastSrcTy || (BitCastSrcTy->getElementCount() * Factor !=
+                          BitCastDstTy->getElementCount()))
+      return false;
+
+    if (auto *VPStridedLoad = dyn_cast<VPIntrinsic>(BitCast->getOperand(0))) {
+      if (VPStridedLoad->getIntrinsicID() !=
+              Intrinsic::experimental_vp_strided_load ||
+          !VPStridedLoad->hasOneUse())
+        return false;
+
+      LLVM_DEBUG(dbgs() << "IA: Found a deinterleave intrinsic: " << *DI
+                        << "\n");
+
+      if (!TLI->lowerDeinterleaveIntrinsicToStridedLoad(
+              VPStridedLoad, DI, Factor, DeInterleaveResults))
+        return false;
+
+      DeadInsts.insert(DI);
+      if (IntToPtrCast)
+        DeadInsts.insert(IntToPtrCast);
+      DeadInsts.insert(BitCast);
+      DeadInsts.insert(VPStridedLoad);
+      return true;
+    }
+  }
+
   LoadInst *LI = dyn_cast<LoadInst>(DI->getOperand(0));

   if (!LI || !LI->hasOneUse() || !LI->isSimple())
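
A factor-2 instance of the strided-load pattern matched above (hypothetical values; sketch only): a wide-element strided load is bitcast down to the narrow element type, then deinterleaved.

  %x = call <vscale x 2 x i64> @llvm.experimental.vp.strided.load.nxv2i64.p0.i64(
           ptr %p, i64 %stride, <vscale x 2 x i1> splat (i1 true), i32 %evl)
  %y = bitcast <vscale x 2 x i64> %x to <vscale x 4 x i32>
  %z = call { <vscale x 2 x i32>, <vscale x 2 x i32> }
           @llvm.vector.deinterleave2.nxv4i32(<vscale x 4 x i32> %y)

Note how the bitcast source element count (vscale x 2) times Factor equals the destination element count (vscale x 4), which is exactly the type check performed above.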
@@ -504,6 +760,33 @@ bool InterleavedAccessImpl::lowerInterleaveIntrinsic(
   if (!II->hasOneUse())
     return false;

+  if (auto *VPStore = dyn_cast<VPIntrinsic>(*(II->users().begin()))) {
+    if (VPStore->getIntrinsicID() != Intrinsic::vp_store)
+      return false;
+
+    SmallVector<Value *, 8> InterleaveOperands;
+    unsigned Factor = getVectorInterleaveFactor(II, InterleaveOperands);
+    if (!Factor)
+      return false;
+
+    Value *WideMask = VPStore->getOperand(2);
+    std::optional<Value *> Mask = getMask(WideMask, Factor);
+    if (!Mask)
+      return false;
+
+    LLVM_DEBUG(dbgs() << "IA: Found an interleave intrinsic: " << *II << "\n");
+
+    // Since lowerInterleavedStore expects Shuffles and a StoreInst, use the
+    // special TLI hook to emit a target-specific interleaved instruction.
+    if (!TLI->lowerInterleavedScalableStore(VPStore, *Mask, II, Factor,
+                                            InterleaveOperands))
+      return false;
+
+    DeadInsts.insert(VPStore);
+    DeadInsts.insert(II);
+    return true;
+  }
+
   StoreInst *SI = dyn_cast<StoreInst>(*(II->users().begin()));

   if (!SI || !SI->isSimple())