@@ -419,6 +419,29 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST)
419
419
.clampScalar (0 , sXLen , sXLen )
420
420
.customFor ({sXLen });
421
421
422
+ auto &SplatActions =
423
+ getActionDefinitionsBuilder (G_SPLAT_VECTOR)
424
+ .legalIf (all (typeIsLegalIntOrFPVec (0 , IntOrFPVecTys, ST),
425
+ typeIs (1 , sXLen )))
426
+ .customIf (all (typeIsLegalBoolVec (0 , BoolVecTys, ST), typeIs (1 , s1)));
427
+ // Handle case of s64 element vectors on RV32. If the subtarget does not have
428
+ // f64, then try to lower it to G_SPLAT_VECTOR_SPLIT_64_VL. If the subtarget
429
+ // does have f64, then we don't know whether the type is an f64 or an i64,
430
+ // so mark the G_SPLAT_VECTOR as legal and decide later what to do with it,
431
+ // depending on how the instructions it consumes are legalized. They are not
432
+ // legalized yet since legalization is in reverse postorder, so we cannot
433
+ // make the decision at this moment.
434
+ if (XLen == 32 ) {
435
+ if (ST.hasVInstructionsF64 () && ST.hasStdExtD ())
436
+ SplatActions.legalIf (all (
437
+ typeInSet (0 , {nxv1s64, nxv2s64, nxv4s64, nxv8s64}), typeIs (1 , s64)));
438
+ else if (ST.hasVInstructionsI64 ())
439
+ SplatActions.customIf (all (
440
+ typeInSet (0 , {nxv1s64, nxv2s64, nxv4s64, nxv8s64}), typeIs (1 , s64)));
441
+ }
442
+
443
+ SplatActions.clampScalar (1 , sXLen , sXLen );
444
+
422
445
getLegacyLegalizerInfo ().computeTables ();
423
446
}
424
447
@@ -609,6 +632,118 @@ bool RISCVLegalizerInfo::legalizeExt(MachineInstr &MI,
609
632
return true ;
610
633
}
611
634
635
+ // / Return the type of the mask type suitable for masking the provided
636
+ // / vector type. This is simply an i1 element type vector of the same
637
+ // / (possibly scalable) length.
638
+ static LLT getMaskTypeFor (LLT VecTy) {
639
+ assert (VecTy.isVector ());
640
+ ElementCount EC = VecTy.getElementCount ();
641
+ return LLT::vector (EC, LLT::scalar (1 ));
642
+ }
643
+
644
+ // / Creates an all ones mask suitable for masking a vector of type VecTy with
645
+ // / vector length VL.
646
+ static MachineInstrBuilder buildAllOnesMask (LLT VecTy, const SrcOp &VL,
647
+ MachineIRBuilder &MIB,
648
+ MachineRegisterInfo &MRI) {
649
+ LLT MaskTy = getMaskTypeFor (VecTy);
650
+ return MIB.buildInstr (RISCV::G_VMSET_VL, {MaskTy}, {VL});
651
+ }
652
+
653
+ // / Gets the two common "VL" operands: an all-ones mask and the vector length.
654
+ // / VecTy is a scalable vector type.
655
+ static std::pair<MachineInstrBuilder, Register>
656
+ buildDefaultVLOps (const DstOp &Dst, MachineIRBuilder &MIB,
657
+ MachineRegisterInfo &MRI) {
658
+ LLT VecTy = Dst.getLLTTy (MRI);
659
+ assert (VecTy.isScalableVector () && " Expecting scalable container type" );
660
+ Register VL (RISCV::X0);
661
+ MachineInstrBuilder Mask = buildAllOnesMask (VecTy, VL, MIB, MRI);
662
+ return {Mask, VL};
663
+ }
664
+
665
+ static MachineInstrBuilder
666
+ buildSplatPartsS64WithVL (const DstOp &Dst, const SrcOp &Passthru, Register Lo,
667
+ Register Hi, Register VL, MachineIRBuilder &MIB,
668
+ MachineRegisterInfo &MRI) {
669
+ // TODO: If the Hi bits of the splat are undefined, then it's fine to just
670
+ // splat Lo even if it might be sign extended. I don't think we have
671
+ // introduced a case where we're build a s64 where the upper bits are undef
672
+ // yet.
673
+
674
+ // Fall back to a stack store and stride x0 vector load.
675
+ // TODO: need to lower G_SPLAT_VECTOR_SPLIT_I64. This is done in
676
+ // preprocessDAG in SDAG.
677
+ return MIB.buildInstr (RISCV::G_SPLAT_VECTOR_SPLIT_I64_VL, {Dst},
678
+ {Passthru, Lo, Hi, VL});
679
+ }
680
+
681
+ static MachineInstrBuilder
682
+ buildSplatSplitS64WithVL (const DstOp &Dst, const SrcOp &Passthru,
683
+ const SrcOp &Scalar, Register VL,
684
+ MachineIRBuilder &MIB, MachineRegisterInfo &MRI) {
685
+ assert (Scalar.getLLTTy (MRI) == LLT::scalar (64 ) && " Unexpected VecTy!" );
686
+ auto Unmerge = MIB.buildUnmerge (LLT::scalar (32 ), Scalar);
687
+ return buildSplatPartsS64WithVL (Dst, Passthru, Unmerge.getReg (0 ),
688
+ Unmerge.getReg (1 ), VL, MIB, MRI);
689
+ }
690
+
691
+ // Lower splats of s1 types to G_ICMP. For each mask vector type, we have a
692
+ // legal equivalently-sized i8 type, so we can use that as a go-between.
693
+ // Splats of s1 types that have constant value can be legalized as VMSET_VL or
694
+ // VMCLR_VL.
695
+ bool RISCVLegalizerInfo::legalizeSplatVector (MachineInstr &MI,
696
+ MachineIRBuilder &MIB) const {
697
+ assert (MI.getOpcode () == TargetOpcode::G_SPLAT_VECTOR);
698
+
699
+ MachineRegisterInfo &MRI = *MIB.getMRI ();
700
+
701
+ Register Dst = MI.getOperand (0 ).getReg ();
702
+ Register SplatVal = MI.getOperand (1 ).getReg ();
703
+
704
+ LLT VecTy = MRI.getType (Dst);
705
+ LLT XLenTy (STI.getXLenVT ());
706
+
707
+ // Handle case of s64 element vectors on rv32
708
+ if (XLenTy.getSizeInBits () == 32 &&
709
+ VecTy.getElementType ().getSizeInBits () == 64 ) {
710
+ auto [_, VL] = buildDefaultVLOps (Dst, MIB, MRI);
711
+ buildSplatSplitS64WithVL (Dst, MIB.buildUndef (VecTy), SplatVal, VL, MIB,
712
+ MRI);
713
+ MI.eraseFromParent ();
714
+ return true ;
715
+ }
716
+
717
+ // All-zeros or all-ones splats are handled specially.
718
+ MachineInstr &SplatValMI = *MRI.getVRegDef (SplatVal);
719
+ if (isAllOnesOrAllOnesSplat (SplatValMI, MRI)) {
720
+ auto VL = buildDefaultVLOps (VecTy, MIB, MRI).second ;
721
+ MIB.buildInstr (RISCV::G_VMSET_VL, {Dst}, {VL});
722
+ MI.eraseFromParent ();
723
+ return true ;
724
+ }
725
+ if (isNullOrNullSplat (SplatValMI, MRI)) {
726
+ auto VL = buildDefaultVLOps (VecTy, MIB, MRI).second ;
727
+ MIB.buildInstr (RISCV::G_VMCLR_VL, {Dst}, {VL});
728
+ MI.eraseFromParent ();
729
+ return true ;
730
+ }
731
+
732
+ // Handle non-constant mask splat (i.e. not sure if it's all zeros or all
733
+ // ones) by promoting it to an s8 splat.
734
+ LLT InterEltTy = LLT::scalar (8 );
735
+ LLT InterTy = VecTy.changeElementType (InterEltTy);
736
+ auto ZExtSplatVal = MIB.buildZExt (InterEltTy, SplatVal);
737
+ auto And =
738
+ MIB.buildAnd (InterEltTy, ZExtSplatVal, MIB.buildConstant (InterEltTy, 1 ));
739
+ auto LHS = MIB.buildSplatVector (InterTy, And);
740
+ auto ZeroSplat =
741
+ MIB.buildSplatVector (InterTy, MIB.buildConstant (InterEltTy, 0 ));
742
+ MIB.buildICmp (CmpInst::Predicate::ICMP_NE, Dst, LHS, ZeroSplat);
743
+ MI.eraseFromParent ();
744
+ return true ;
745
+ }
746
+
612
747
bool RISCVLegalizerInfo::legalizeCustom (
613
748
LegalizerHelper &Helper, MachineInstr &MI,
614
749
LostDebugLocObserver &LocObserver) const {
@@ -672,6 +807,8 @@ bool RISCVLegalizerInfo::legalizeCustom(
672
807
case TargetOpcode::G_SEXT:
673
808
case TargetOpcode::G_ANYEXT:
674
809
return legalizeExt (MI, MIRBuilder);
810
+ case TargetOpcode::G_SPLAT_VECTOR:
811
+ return legalizeSplatVector (MI, MIRBuilder);
675
812
}
676
813
677
814
llvm_unreachable (" expected switch to return" );
0 commit comments