@@ -419,6 +419,17 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST)
419
419
.clampScalar (0 , sXLen , sXLen )
420
420
.customFor ({sXLen });
421
421
422
+ auto &SplatActions =
423
+ getActionDefinitionsBuilder (G_SPLAT_VECTOR)
424
+ .legalIf (all (typeIsLegalIntOrFPVec (0 , IntOrFPVecTys, ST),
425
+ typeIs (1 , sXLen )))
426
+ .customIf (all (typeIsLegalBoolVec (0 , BoolVecTys, ST), typeIs (1 , s1)));
427
+ // s64 splat on RV32 should be lowered to RISCV::G_SPLAT_VECTOR_PARTS_I64
428
+ if (XLen == 32 )
429
+ SplatActions.customIf (
430
+ all (typeIsLegalIntOrFPVec (0 , IntOrFPVecTys, ST), typeIs (1 , s64)));
431
+ SplatActions.clampScalar (1 , sXLen , sXLen );
432
+
422
433
getLegacyLegalizerInfo ().computeTables ();
423
434
}
424
435
@@ -609,6 +620,118 @@ bool RISCVLegalizerInfo::legalizeExt(MachineInstr &MI,
609
620
return true ;
610
621
}
611
622
623
+ // / Return the type of the mask type suitable for masking the provided
624
+ // / vector type. This is simply an i1 element type vector of the same
625
+ // / (possibly scalable) length.
626
+ static LLT getMaskTypeFor (LLT VecTy) {
627
+ assert (VecTy.isVector ());
628
+ ElementCount EC = VecTy.getElementCount ();
629
+ return LLT::vector (EC, LLT::scalar (1 ));
630
+ }
631
+
632
+ // / Creates an all ones mask suitable for masking a vector of type VecTy with
633
+ // / vector length VL.
634
+ static MachineInstrBuilder buildAllOnesMask (LLT VecTy, const SrcOp &VL,
635
+ MachineIRBuilder &MIB,
636
+ MachineRegisterInfo &MRI) {
637
+ LLT MaskTy = getMaskTypeFor (VecTy);
638
+ return MIB.buildInstr (RISCV::G_VMSET_VL, {MaskTy}, {VL});
639
+ }
640
+
641
+ // / Gets the two common "VL" operands: an all-ones mask and the vector length.
642
+ // / VecTy is a scalable vector type.
643
+ static std::pair<MachineInstrBuilder, Register>
644
+ buildDefaultVLOps (const DstOp &Dst, MachineIRBuilder &MIB,
645
+ MachineRegisterInfo &MRI) {
646
+ LLT VecTy = Dst.getLLTTy (MRI);
647
+ assert (VecTy.isScalableVector () && " Expecting scalable container type" );
648
+ Register VL (RISCV::X0);
649
+ MachineInstrBuilder Mask = buildAllOnesMask (VecTy, VL, MIB, MRI);
650
+ return {Mask, VL};
651
+ }
652
+
653
+ static MachineInstrBuilder
654
+ buildSplatPartsS64WithVL (const DstOp &Dst, const SrcOp &Passthru, Register Lo,
655
+ Register Hi, Register VL, MachineIRBuilder &MIB,
656
+ MachineRegisterInfo &MRI) {
657
+ // TODO: If the Hi bits of the splat are undefined, then it's fine to just
658
+ // splat Lo even if it might be sign extended. I don't think we have
659
+ // introduced a case where we're build a s64 where the upper bits are undef
660
+ // yet.
661
+
662
+ // Fall back to a stack store and stride x0 vector load.
663
+ // TODO: need to lower G_SPLAT_VECTOR_SPLIT_I64. This is done in
664
+ // preprocessDAG in SDAG.
665
+ return MIB.buildInstr (RISCV::G_SPLAT_VECTOR_SPLIT_I64_VL, {Dst},
666
+ {Passthru, Lo, Hi, VL});
667
+ }
668
+
669
+ static MachineInstrBuilder
670
+ buildSplatSplitS64WithVL (const DstOp &Dst, const SrcOp &Passthru,
671
+ const SrcOp &Scalar, Register VL,
672
+ MachineIRBuilder &MIB, MachineRegisterInfo &MRI) {
673
+ assert (Scalar.getLLTTy (MRI) == LLT::scalar (64 ) && " Unexpected VecTy!" );
674
+ auto Unmerge = MIB.buildUnmerge (LLT::scalar (32 ), Scalar);
675
+ return buildSplatPartsS64WithVL (Dst, Passthru, Unmerge.getReg (0 ),
676
+ Unmerge.getReg (1 ), VL, MIB, MRI);
677
+ }
678
+
679
+ // Lower splats of s1 types to G_ICMP. For each mask vector type, we have a
680
+ // legal equivalently-sized i8 type, so we can use that as a go-between.
681
+ // Splats of s1 types that have constant value can be legalized as VMSET_VL or
682
+ // VMCLR_VL.
683
+ bool RISCVLegalizerInfo::legalizeSplatVector (MachineInstr &MI,
684
+ MachineIRBuilder &MIB) const {
685
+ assert (MI.getOpcode () == TargetOpcode::G_SPLAT_VECTOR);
686
+
687
+ MachineRegisterInfo &MRI = *MIB.getMRI ();
688
+
689
+ Register Dst = MI.getOperand (0 ).getReg ();
690
+ Register SplatVal = MI.getOperand (1 ).getReg ();
691
+
692
+ LLT VecTy = MRI.getType (Dst);
693
+ LLT XLenTy (STI.getXLenVT ());
694
+
695
+ // Handle case of s64 element vectors on rv32
696
+ if (XLenTy.getSizeInBits () == 32 &&
697
+ VecTy.getElementType ().getSizeInBits () == 64 ) {
698
+ auto [_, VL] = buildDefaultVLOps (Dst, MIB, MRI);
699
+ buildSplatSplitS64WithVL (Dst, MIB.buildUndef (VecTy), SplatVal, VL, MIB,
700
+ MRI);
701
+ MI.eraseFromParent ();
702
+ return true ;
703
+ }
704
+
705
+ // All-zeros or all-ones splats are handled specially.
706
+ MachineInstr &SplatValMI = *MRI.getVRegDef (SplatVal);
707
+ if (isAllOnesOrAllOnesSplat (SplatValMI, MRI)) {
708
+ auto VL = buildDefaultVLOps (VecTy, MIB, MRI).second ;
709
+ MIB.buildInstr (RISCV::G_VMSET_VL, {Dst}, {VL});
710
+ MI.eraseFromParent ();
711
+ return true ;
712
+ }
713
+ if (isNullOrNullSplat (SplatValMI, MRI)) {
714
+ auto VL = buildDefaultVLOps (VecTy, MIB, MRI).second ;
715
+ MIB.buildInstr (RISCV::G_VMCLR_VL, {Dst}, {VL});
716
+ MI.eraseFromParent ();
717
+ return true ;
718
+ }
719
+
720
+ // Handle non-constant mask splat (i.e. not sure if it's all zeros or all
721
+ // ones) by promoting it to an s8 splat.
722
+ LLT InterEltTy = LLT::scalar (8 );
723
+ LLT InterTy = VecTy.changeElementType (InterEltTy);
724
+ auto ZExtSplatVal = MIB.buildZExt (InterEltTy, SplatVal);
725
+ auto And =
726
+ MIB.buildAnd (InterEltTy, ZExtSplatVal, MIB.buildConstant (InterEltTy, 1 ));
727
+ auto LHS = MIB.buildSplatVector (InterTy, And);
728
+ auto ZeroSplat =
729
+ MIB.buildSplatVector (InterTy, MIB.buildConstant (InterEltTy, 0 ));
730
+ MIB.buildICmp (CmpInst::Predicate::ICMP_NE, Dst, LHS, ZeroSplat);
731
+ MI.eraseFromParent ();
732
+ return true ;
733
+ }
734
+
612
735
bool RISCVLegalizerInfo::legalizeCustom (
613
736
LegalizerHelper &Helper, MachineInstr &MI,
614
737
LostDebugLocObserver &LocObserver) const {
@@ -672,6 +795,8 @@ bool RISCVLegalizerInfo::legalizeCustom(
672
795
case TargetOpcode::G_SEXT:
673
796
case TargetOpcode::G_ANYEXT:
674
797
return legalizeExt (MI, MIRBuilder);
798
+ case TargetOpcode::G_SPLAT_VECTOR:
799
+ return legalizeSplatVector (MI, MIRBuilder);
675
800
}
676
801
677
802
llvm_unreachable (" expected switch to return" );
0 commit comments