@@ -547,6 +547,39 @@ struct BinaryOpc_match {
547
547
}
548
548
};
549
549
550
+ // / Matching while capturing mask
551
+ template <typename T0, typename T1, typename T2> struct SDShuffle_match {
552
+ T0 Op1;
553
+ T1 Op2;
554
+ T2 Mask;
555
+
556
+ SDShuffle_match (const T0 &Op1, const T1 &Op2, const T2 &Mask)
557
+ : Op1(Op1), Op2(Op2), Mask(Mask) {}
558
+
559
+ template <typename MatchContext>
560
+ bool match (const MatchContext &Ctx, SDValue N) {
561
+ if (auto *I = dyn_cast<ShuffleVectorSDNode>(N)) {
562
+ return Op1.match (Ctx, I->getOperand (0 )) &&
563
+ Op2.match (Ctx, I->getOperand (1 )) && Mask.match (I->getMask ());
564
+ }
565
+ return false ;
566
+ }
567
+ };
568
+ struct m_Mask {
569
+ ArrayRef<int > &MaskRef;
570
+ m_Mask (ArrayRef<int > &MaskRef) : MaskRef(MaskRef) {}
571
+ bool match (ArrayRef<int > Mask) {
572
+ MaskRef = Mask;
573
+ return true ;
574
+ }
575
+ };
576
+
577
+ struct m_SpecificMask {
578
+ ArrayRef<int > MaskRef;
579
+ m_SpecificMask (ArrayRef<int > MaskRef) : MaskRef(MaskRef) {}
580
+ bool match (ArrayRef<int > Mask) { return MaskRef == Mask; }
581
+ };
582
+
550
583
template <typename LHS_P, typename RHS_P, typename Pred_t,
551
584
bool Commutable = false , bool ExcludeChain = false >
552
585
struct MaxMin_match {
@@ -797,6 +830,17 @@ inline BinaryOpc_match<LHS, RHS> m_FRem(const LHS &L, const RHS &R) {
797
830
return BinaryOpc_match<LHS, RHS>(ISD::FREM, L, R);
798
831
}
799
832
833
+ template <typename V1_t, typename V2_t>
834
+ inline BinaryOpc_match<V1_t, V2_t> m_Shuffle (const V1_t &v1, const V2_t &v2) {
835
+ return BinaryOpc_match<V1_t, V2_t>(ISD::VECTOR_SHUFFLE, v1, v2);
836
+ }
837
+
838
+ template <typename V1_t, typename V2_t, typename Mask_t>
839
+ inline SDShuffle_match<V1_t, V2_t, Mask_t>
840
+ m_Shuffle (const V1_t &v1, const V2_t &v2, const Mask_t &mask) {
841
+ return SDShuffle_match<V1_t, V2_t, Mask_t>(v1, v2, mask);
842
+ }
843
+
800
844
template <typename LHS, typename RHS>
801
845
inline BinaryOpc_match<LHS, RHS> m_ExtractElt (const LHS &Vec, const RHS &Idx) {
802
846
return BinaryOpc_match<LHS, RHS>(ISD::EXTRACT_VECTOR_ELT, Vec, Idx);
0 commit comments