@@ -419,6 +419,29 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST)
419
419
.clampScalar (0 , sXLen , sXLen )
420
420
.customFor ({sXLen });
421
421
422
+ auto &SplatActions =
423
+ getActionDefinitionsBuilder (G_SPLAT_VECTOR)
424
+ .legalIf (all (typeIsLegalIntOrFPVec (0 , IntOrFPVecTys, ST),
425
+ typeIs (1 , sXLen )))
426
+ .customIf (all (typeIsLegalBoolVec (0 , BoolVecTys, ST), typeIs (1 , s1)));
427
+ // Handle case of s64 element vectors on RV32. If the subtarget does not have
428
+ // f64, then try to lower it to G_SPLAT_VECTOR_SPLIT_64_VL. If the subtarget
429
+ // does have f64, then we don't know whether the type is an f64 or an i64,
430
+ // so mark the G_SPLAT_VECTOR as legal and decide later what to do with it,
431
+ // depending on how the instructions it consumes are legalized. They are not
432
+ // legalized yet since legalization is in reverse postorder, so we cannot
433
+ // make the decision at this moment.
434
+ if (XLen == 32 ) {
435
+ if (ST.hasVInstructionsF64 () && ST.hasStdExtD ())
436
+ SplatActions.legalIf (all (
437
+ typeInSet (0 , {nxv1s64, nxv2s64, nxv4s64, nxv8s64}), typeIs (1 , s64)));
438
+ else if (ST.hasVInstructionsI64 ())
439
+ SplatActions.customIf (all (
440
+ typeInSet (0 , {nxv1s64, nxv2s64, nxv4s64, nxv8s64}), typeIs (1 , s64)));
441
+ }
442
+
443
+ SplatActions.clampScalar (1 , sXLen , sXLen );
444
+
422
445
getLegacyLegalizerInfo ().computeTables ();
423
446
}
424
447
@@ -608,6 +631,118 @@ bool RISCVLegalizerInfo::legalizeExt(MachineInstr &MI,
608
631
return true ;
609
632
}
610
633
634
+ // / Return the type of the mask type suitable for masking the provided
635
+ // / vector type. This is simply an i1 element type vector of the same
636
+ // / (possibly scalable) length.
637
+ static LLT getMaskTypeFor (LLT VecTy) {
638
+ assert (VecTy.isVector ());
639
+ ElementCount EC = VecTy.getElementCount ();
640
+ return LLT::vector (EC, LLT::scalar (1 ));
641
+ }
642
+
643
+ // / Creates an all ones mask suitable for masking a vector of type VecTy with
644
+ // / vector length VL.
645
+ static MachineInstrBuilder buildAllOnesMask (LLT VecTy, const SrcOp &VL,
646
+ MachineIRBuilder &MIB,
647
+ MachineRegisterInfo &MRI) {
648
+ LLT MaskTy = getMaskTypeFor (VecTy);
649
+ return MIB.buildInstr (RISCV::G_VMSET_VL, {MaskTy}, {VL});
650
+ }
651
+
652
+ // / Gets the two common "VL" operands: an all-ones mask and the vector length.
653
+ // / VecTy is a scalable vector type.
654
+ static std::pair<MachineInstrBuilder, Register>
655
+ buildDefaultVLOps (const DstOp &Dst, MachineIRBuilder &MIB,
656
+ MachineRegisterInfo &MRI) {
657
+ LLT VecTy = Dst.getLLTTy (MRI);
658
+ assert (VecTy.isScalableVector () && " Expecting scalable container type" );
659
+ Register VL (RISCV::X0);
660
+ MachineInstrBuilder Mask = buildAllOnesMask (VecTy, VL, MIB, MRI);
661
+ return {Mask, VL};
662
+ }
663
+
664
+ static MachineInstrBuilder
665
+ buildSplatPartsS64WithVL (const DstOp &Dst, const SrcOp &Passthru, Register Lo,
666
+ Register Hi, Register VL, MachineIRBuilder &MIB,
667
+ MachineRegisterInfo &MRI) {
668
+ // TODO: If the Hi bits of the splat are undefined, then it's fine to just
669
+ // splat Lo even if it might be sign extended. I don't think we have
670
+ // introduced a case where we're build a s64 where the upper bits are undef
671
+ // yet.
672
+
673
+ // Fall back to a stack store and stride x0 vector load.
674
+ // TODO: need to lower G_SPLAT_VECTOR_SPLIT_I64. This is done in
675
+ // preprocessDAG in SDAG.
676
+ return MIB.buildInstr (RISCV::G_SPLAT_VECTOR_SPLIT_I64_VL, {Dst},
677
+ {Passthru, Lo, Hi, VL});
678
+ }
679
+
680
+ static MachineInstrBuilder
681
+ buildSplatSplitS64WithVL (const DstOp &Dst, const SrcOp &Passthru,
682
+ const SrcOp &Scalar, Register VL,
683
+ MachineIRBuilder &MIB, MachineRegisterInfo &MRI) {
684
+ assert (Scalar.getLLTTy (MRI) == LLT::scalar (64 ) && " Unexpected VecTy!" );
685
+ auto Unmerge = MIB.buildUnmerge (LLT::scalar (32 ), Scalar);
686
+ return buildSplatPartsS64WithVL (Dst, Passthru, Unmerge.getReg (0 ),
687
+ Unmerge.getReg (1 ), VL, MIB, MRI);
688
+ }
689
+
690
// Lower splats of s1 types to G_ICMP. For each mask vector type, we have a
// legal equivalently-sized i8 type, so we can use that as a go-between.
// Splats of s1 types that have constant value can be legalized as VMSET_VL or
// VMCLR_VL.
bool RISCVLegalizerInfo::legalizeSplatVector(MachineInstr &MI,
                                             MachineIRBuilder &MIB) const {
  assert(MI.getOpcode() == TargetOpcode::G_SPLAT_VECTOR);

  MachineRegisterInfo &MRI = *MIB.getMRI();

  // Operand 0 is the destination vector, operand 1 the scalar being splatted.
  Register Dst = MI.getOperand(0).getReg();
  Register SplatVal = MI.getOperand(1).getReg();

  LLT VecTy = MRI.getType(Dst);
  LLT XLenTy(STI.getXLenVT());

  // Handle case of s64 element vectors on rv32: the scalar does not fit in
  // one GPR, so split it into two s32 halves and emit the split-splat pseudo.
  if (XLenTy.getSizeInBits() == 32 &&
      VecTy.getElementType().getSizeInBits() == 64) {
    // NOTE(review): the all-ones mask built by buildDefaultVLOps is discarded
    // here (only VL is used), leaving a dead G_VMSET_VL — presumably removed
    // by later dead-code cleanup; confirm.
    auto [_, VL] = buildDefaultVLOps(Dst, MIB, MRI);
    buildSplatSplitS64WithVL(Dst, MIB.buildUndef(VecTy), SplatVal, VL, MIB,
                             MRI);
    MI.eraseFromParent();
    return true;
  }

  // All-zeros or all-ones splats are handled specially: they map directly to
  // the mask-set/mask-clear pseudos with the default (VLMAX) vector length.
  MachineInstr &SplatValMI = *MRI.getVRegDef(SplatVal);
  if (isAllOnesOrAllOnesSplat(SplatValMI, MRI)) {
    auto VL = buildDefaultVLOps(VecTy, MIB, MRI).second;
    MIB.buildInstr(RISCV::G_VMSET_VL, {Dst}, {VL});
    MI.eraseFromParent();
    return true;
  }
  if (isNullOrNullSplat(SplatValMI, MRI)) {
    auto VL = buildDefaultVLOps(VecTy, MIB, MRI).second;
    MIB.buildInstr(RISCV::G_VMCLR_VL, {Dst}, {VL});
    MI.eraseFromParent();
    return true;
  }

  // Handle non-constant mask splat (i.e. not sure if it's all zeros or all
  // ones) by promoting it to an s8 splat: zero-extend the s1 value, AND it
  // with 1 to isolate the low bit, splat that as an s8 vector, and compare it
  // against an all-zero s8 splat. The ICMP_NE result is the desired s1 mask.
  LLT InterEltTy = LLT::scalar(8);
  LLT InterTy = VecTy.changeElementType(InterEltTy);
  auto ZExtSplatVal = MIB.buildZExt(InterEltTy, SplatVal);
  auto And =
      MIB.buildAnd(InterEltTy, ZExtSplatVal, MIB.buildConstant(InterEltTy, 1));
  auto LHS = MIB.buildSplatVector(InterTy, And);
  auto ZeroSplat =
      MIB.buildSplatVector(InterTy, MIB.buildConstant(InterEltTy, 0));
  MIB.buildICmp(CmpInst::Predicate::ICMP_NE, Dst, LHS, ZeroSplat);
  MI.eraseFromParent();
  return true;
}
745
+
611
746
bool RISCVLegalizerInfo::legalizeCustom (
612
747
LegalizerHelper &Helper, MachineInstr &MI,
613
748
LostDebugLocObserver &LocObserver) const {
@@ -671,6 +806,8 @@ bool RISCVLegalizerInfo::legalizeCustom(
671
806
case TargetOpcode::G_SEXT:
672
807
case TargetOpcode::G_ANYEXT:
673
808
return legalizeExt (MI, MIRBuilder);
809
+ case TargetOpcode::G_SPLAT_VECTOR:
810
+ return legalizeSplatVector (MI, MIRBuilder);
674
811
}
675
812
676
813
llvm_unreachable (" expected switch to return" );
0 commit comments