@@ -418,6 +418,19 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST)
418
418
.clampScalar (0 , sXLen , sXLen )
419
419
.customFor ({sXLen });
420
420
421
+ auto &SplatActions =
422
+ getActionDefinitionsBuilder (G_SPLAT_VECTOR)
423
+ .legalIf (all (typeIsLegalIntOrFPVec (0 , IntOrFPVecTys, ST),
424
+ typeIs (1 , sXLen )))
425
+ .customIf (all (typeIsLegalBoolVec (0 , BoolVecTys, ST), typeIs (1 , s1)));
426
+ // Handle case of s64 element vectors on RV32. We don't know whether the type
427
+ // is an f64 or an i64. As a result mark it as legal here and lower to
428
+ // G_SPLAT_VECTOR_SPLIT_64_VL or G_VFMV_VL later.
429
+ if (XLen == 32 )
430
+ SplatActions.legalIf (
431
+ all (typeIsLegalIntOrFPVec (0 , IntOrFPVecTys, ST), typeIs (1 , s64)));
432
+ SplatActions.clampScalar (1 , sXLen , sXLen );
433
+
421
434
getLegacyLegalizerInfo ().computeTables ();
422
435
}
423
436
@@ -608,6 +621,82 @@ bool RISCVLegalizerInfo::legalizeExt(MachineInstr &MI,
608
621
return true ;
609
622
}
610
623
624
+ // / Return the type of the mask type suitable for masking the provided
625
+ // / vector type. This is simply an i1 element type vector of the same
626
+ // / (possibly scalable) length.
627
+ static LLT getMaskTypeFor (LLT VecTy) {
628
+ assert (VecTy.isVector ());
629
+ ElementCount EC = VecTy.getElementCount ();
630
+ return LLT::vector (EC, LLT::scalar (1 ));
631
+ }
632
+
633
+ // / Creates an all ones mask suitable for masking a vector of type VecTy with
634
+ // / vector length VL.
635
+ static MachineInstrBuilder buildAllOnesMask (LLT VecTy, const SrcOp &VL,
636
+ MachineIRBuilder &MIB,
637
+ MachineRegisterInfo &MRI) {
638
+ LLT MaskTy = getMaskTypeFor (VecTy);
639
+ return MIB.buildInstr (RISCV::G_VMSET_VL, {MaskTy}, {VL});
640
+ }
641
+
642
+ // / Gets the two common "VL" operands: an all-ones mask and the vector length.
643
+ // / VecTy is a scalable vector type.
644
+ static std::pair<MachineInstrBuilder, Register>
645
+ buildDefaultVLOps (const DstOp &Dst, MachineIRBuilder &MIB,
646
+ MachineRegisterInfo &MRI) {
647
+ LLT VecTy = Dst.getLLTTy (MRI);
648
+ assert (VecTy.isScalableVector () && " Expecting scalable container type" );
649
+ Register VL (RISCV::X0);
650
+ MachineInstrBuilder Mask = buildAllOnesMask (VecTy, VL, MIB, MRI);
651
+ return {Mask, VL};
652
+ }
653
+
654
+ // Lower splats of s1 types to G_ICMP. For each mask vector type, we have a
655
+ // legal equivalently-sized i8 type, so we can use that as a go-between.
656
+ // Splats of s1 types that have constant value can be legalized as VMSET_VL or
657
+ // VMCLR_VL.
658
+ bool RISCVLegalizerInfo::legalizeSplatVector (MachineInstr &MI,
659
+ MachineIRBuilder &MIB) const {
660
+ assert (MI.getOpcode () == TargetOpcode::G_SPLAT_VECTOR);
661
+
662
+ MachineRegisterInfo &MRI = *MIB.getMRI ();
663
+
664
+ Register Dst = MI.getOperand (0 ).getReg ();
665
+ Register SplatVal = MI.getOperand (1 ).getReg ();
666
+
667
+ LLT VecTy = MRI.getType (Dst);
668
+ LLT XLenTy (STI.getXLenVT ());
669
+
670
+ // All-zeros or all-ones splats are handled specially.
671
+ MachineInstr &SplatValMI = *MRI.getVRegDef (SplatVal);
672
+ if (isAllOnesOrAllOnesSplat (SplatValMI, MRI)) {
673
+ auto VL = buildDefaultVLOps (VecTy, MIB, MRI).second ;
674
+ MIB.buildInstr (RISCV::G_VMSET_VL, {Dst}, {VL});
675
+ MI.eraseFromParent ();
676
+ return true ;
677
+ }
678
+ if (isNullOrNullSplat (SplatValMI, MRI)) {
679
+ auto VL = buildDefaultVLOps (VecTy, MIB, MRI).second ;
680
+ MIB.buildInstr (RISCV::G_VMCLR_VL, {Dst}, {VL});
681
+ MI.eraseFromParent ();
682
+ return true ;
683
+ }
684
+
685
+ // Handle non-constant mask splat (i.e. not sure if it's all zeros or all
686
+ // ones) by promoting it to an s8 splat.
687
+ LLT InterEltTy = LLT::scalar (8 );
688
+ LLT InterTy = VecTy.changeElementType (InterEltTy);
689
+ auto ZExtSplatVal = MIB.buildZExt (InterEltTy, SplatVal);
690
+ auto And =
691
+ MIB.buildAnd (InterEltTy, ZExtSplatVal, MIB.buildConstant (InterEltTy, 1 ));
692
+ auto LHS = MIB.buildSplatVector (InterTy, And);
693
+ auto ZeroSplat =
694
+ MIB.buildSplatVector (InterTy, MIB.buildConstant (InterEltTy, 0 ));
695
+ MIB.buildICmp (CmpInst::Predicate::ICMP_NE, Dst, LHS, ZeroSplat);
696
+ MI.eraseFromParent ();
697
+ return true ;
698
+ }
699
+
611
700
bool RISCVLegalizerInfo::legalizeCustom (
612
701
LegalizerHelper &Helper, MachineInstr &MI,
613
702
LostDebugLocObserver &LocObserver) const {
@@ -671,6 +760,8 @@ bool RISCVLegalizerInfo::legalizeCustom(
671
760
case TargetOpcode::G_SEXT:
672
761
case TargetOpcode::G_ANYEXT:
673
762
return legalizeExt (MI, MIRBuilder);
763
+ case TargetOpcode::G_SPLAT_VECTOR:
764
+ return legalizeSplatVector (MI, MIRBuilder);
674
765
}
675
766
676
767
llvm_unreachable (" expected switch to return" );
0 commit comments