@@ -66,6 +66,18 @@ typeIsLegalBoolVec(unsigned TypeIdx, std::initializer_list<LLT> BoolVecTys,
66
66
return all (typeInSet (TypeIdx, BoolVecTys), P);
67
67
}
68
68
69
+ static LegalityPredicate hasStdExtD (const RISCVSubtarget &ST) {
70
+ return [=, &ST](const LegalityQuery &Query) { return ST.hasStdExtD (); };
71
+ }
72
+ static LegalityPredicate hasVInstructionsI64 (const RISCVSubtarget &ST) {
73
+ return
74
+ [=, &ST](const LegalityQuery &Query) { return ST.hasVInstructionsI64 (); };
75
+ }
76
+ static LegalityPredicate hasVInstructionsF64 (const RISCVSubtarget &ST) {
77
+ return
78
+ [=, &ST](const LegalityQuery &Query) { return ST.hasVInstructionsI64 (); };
79
+ }
80
+
69
81
RISCVLegalizerInfo::RISCVLegalizerInfo (const RISCVSubtarget &ST)
70
82
: STI(ST), XLen(STI.getXLen()), sXLen(LLT::scalar(XLen)) {
71
83
const LLT sDoubleXLen = LLT::scalar (2 * XLen);
@@ -413,6 +425,27 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST)
413
425
.clampScalar (0 , sXLen , sXLen )
414
426
.customFor ({sXLen });
415
427
428
+ auto &SplatActions =
429
+ getActionDefinitionsBuilder (G_SPLAT_VECTOR)
430
+ .legalIf (all (typeIsLegalIntOrFPVec (0 , IntOrFPVecTys, ST),
431
+ typeIs (1 , sXLen )))
432
+ .customIf (all (typeIsLegalBoolVec (0 , BoolVecTys, ST), typeIs (1 , s1)));
433
+ // Handle case of s64 element vectors on RV32. If the subtarget does not have
434
+ // f64, then try to lower it to G_SPLAT_VECTOR_SPLIT_64_VL. If the subtarget
435
+ // does have f64, then we don't know whether the type is an f64 or an i64,
436
+ // so mark the G_SPLAT_VECTOR as legal and decide later what to do with it,
437
+ // depending on how the instructions it consumes are legalized. They are not
438
+ // legalized yet since legalization is in reverse postorder, so we cannot
439
+ // make the decision at this moment.
440
+ if (XLen == 32 )
441
+ SplatActions
442
+ .legalIf (all (typeInSet (0 , {nxv1s64, nxv2s64, nxv4s64, nxv8s64}),
443
+ typeIs (1 , s64), hasVInstructionsF64 (ST), hasStdExtD (ST)))
444
+ .customIf (all (typeInSet (0 , {nxv1s64, nxv2s64, nxv4s64, nxv8s64}),
445
+ hasVInstructionsI64 (ST), typeIs (1 , s64)));
446
+
447
+ SplatActions.clampScalar (1 , sXLen , sXLen );
448
+
416
449
getLegacyLegalizerInfo ().computeTables ();
417
450
}
418
451
@@ -603,6 +636,118 @@ bool RISCVLegalizerInfo::legalizeExt(MachineInstr &MI,
603
636
return true ;
604
637
}
605
638
639
+ // / Return the type of the mask type suitable for masking the provided
640
+ // / vector type. This is simply an i1 element type vector of the same
641
+ // / (possibly scalable) length.
642
+ static LLT getMaskTypeFor (LLT VecTy) {
643
+ assert (VecTy.isVector ());
644
+ ElementCount EC = VecTy.getElementCount ();
645
+ return LLT::vector (EC, LLT::scalar (1 ));
646
+ }
647
+
648
+ // / Creates an all ones mask suitable for masking a vector of type VecTy with
649
+ // / vector length VL.
650
+ static MachineInstrBuilder buildAllOnesMask (LLT VecTy, const SrcOp &VL,
651
+ MachineIRBuilder &MIB,
652
+ MachineRegisterInfo &MRI) {
653
+ LLT MaskTy = getMaskTypeFor (VecTy);
654
+ return MIB.buildInstr (RISCV::G_VMSET_VL, {MaskTy}, {VL});
655
+ }
656
+
657
+ // / Gets the two common "VL" operands: an all-ones mask and the vector length.
658
+ // / VecTy is a scalable vector type.
659
+ static std::pair<MachineInstrBuilder, Register>
660
+ buildDefaultVLOps (const DstOp &Dst, MachineIRBuilder &MIB,
661
+ MachineRegisterInfo &MRI) {
662
+ LLT VecTy = Dst.getLLTTy (MRI);
663
+ assert (VecTy.isScalableVector () && " Expecting scalable container type" );
664
+ Register VL (RISCV::X0);
665
+ MachineInstrBuilder Mask = buildAllOnesMask (VecTy, VL, MIB, MRI);
666
+ return {Mask, VL};
667
+ }
668
+
669
+ static MachineInstrBuilder
670
+ buildSplatPartsS64WithVL (const DstOp &Dst, const SrcOp &Passthru, Register Lo,
671
+ Register Hi, Register VL, MachineIRBuilder &MIB,
672
+ MachineRegisterInfo &MRI) {
673
+ // TODO: If the Hi bits of the splat are undefined, then it's fine to just
674
+ // splat Lo even if it might be sign extended. I don't think we have
675
+ // introduced a case where we're build a s64 where the upper bits are undef
676
+ // yet.
677
+
678
+ // Fall back to a stack store and stride x0 vector load.
679
+ // TODO: need to lower G_SPLAT_VECTOR_SPLIT_I64. This is done in
680
+ // preprocessDAG in SDAG.
681
+ return MIB.buildInstr (RISCV::G_SPLAT_VECTOR_SPLIT_I64_VL, {Dst},
682
+ {Passthru, Lo, Hi, VL});
683
+ }
684
+
685
+ static MachineInstrBuilder
686
+ buildSplatSplitS64WithVL (const DstOp &Dst, const SrcOp &Passthru,
687
+ const SrcOp &Scalar, Register VL,
688
+ MachineIRBuilder &MIB, MachineRegisterInfo &MRI) {
689
+ assert (Scalar.getLLTTy (MRI) == LLT::scalar (64 ) && " Unexpected VecTy!" );
690
+ auto Unmerge = MIB.buildUnmerge (LLT::scalar (32 ), Scalar);
691
+ return buildSplatPartsS64WithVL (Dst, Passthru, Unmerge.getReg (0 ),
692
+ Unmerge.getReg (1 ), VL, MIB, MRI);
693
+ }
694
+
695
// Lower splats of s1 types to G_ICMP. For each mask vector type, we have a
// legal equivalently-sized i8 type, so we can use that as a go-between.
// Splats of s1 types that have constant value can be legalized as VMSET_VL or
// VMCLR_VL.
//
// Also handles splats of s64 element vectors on rv32 by splitting the scalar
// into two s32 halves (G_SPLAT_VECTOR_SPLIT_I64_VL). Returns true and erases
// MI on success.
bool RISCVLegalizerInfo::legalizeSplatVector(MachineInstr &MI,
                                             MachineIRBuilder &MIB) const {
  assert(MI.getOpcode() == TargetOpcode::G_SPLAT_VECTOR);

  MachineRegisterInfo &MRI = *MIB.getMRI();

  Register Dst = MI.getOperand(0).getReg();
  Register SplatVal = MI.getOperand(1).getReg();

  LLT VecTy = MRI.getType(Dst);
  LLT XLenTy(STI.getXLenVT());

  // Handle case of s64 element vectors on rv32: unmerge the scalar into two
  // s32 halves and emit the split-splat pseudo with a VLMAX vector length.
  if (XLenTy.getSizeInBits() == 32 &&
      VecTy.getElementType().getSizeInBits() == 64) {
    auto [_, VL] = buildDefaultVLOps(Dst, MIB, MRI);
    buildSplatSplitS64WithVL(Dst, MIB.buildUndef(VecTy), SplatVal, VL, MIB,
                             MRI);
    MI.eraseFromParent();
    return true;
  }

  // All-zeros or all-ones splats are handled specially.
  MachineInstr &SplatValMI = *MRI.getVRegDef(SplatVal);
  if (isAllOnesOrAllOnesSplat(SplatValMI, MRI)) {
    // Constant all-ones mask splat -> G_VMSET_VL with VLMAX.
    auto VL = buildDefaultVLOps(VecTy, MIB, MRI).second;
    MIB.buildInstr(RISCV::G_VMSET_VL, {Dst}, {VL});
    MI.eraseFromParent();
    return true;
  }
  if (isNullOrNullSplat(SplatValMI, MRI)) {
    // Constant all-zeros mask splat -> G_VMCLR_VL with VLMAX.
    auto VL = buildDefaultVLOps(VecTy, MIB, MRI).second;
    MIB.buildInstr(RISCV::G_VMCLR_VL, {Dst}, {VL});
    MI.eraseFromParent();
    return true;
  }

  // Handle non-constant mask splat (i.e. not sure if it's all zeros or all
  // ones) by promoting it to an s8 splat: zero-extend the s1 value, mask it
  // to bit 0, splat as s8, and compare NE against an all-zeros s8 splat so
  // the result is again an s1 mask vector.
  LLT InterEltTy = LLT::scalar(8);
  LLT InterTy = VecTy.changeElementType(InterEltTy);
  auto ZExtSplatVal = MIB.buildZExt(InterEltTy, SplatVal);
  auto And =
      MIB.buildAnd(InterEltTy, ZExtSplatVal, MIB.buildConstant(InterEltTy, 1));
  auto LHS = MIB.buildSplatVector(InterTy, And);
  auto ZeroSplat =
      MIB.buildSplatVector(InterTy, MIB.buildConstant(InterEltTy, 0));
  MIB.buildICmp(CmpInst::Predicate::ICMP_NE, Dst, LHS, ZeroSplat);
  MI.eraseFromParent();
  return true;
}
750
+
606
751
bool RISCVLegalizerInfo::legalizeCustom (
607
752
LegalizerHelper &Helper, MachineInstr &MI,
608
753
LostDebugLocObserver &LocObserver) const {
@@ -666,6 +811,8 @@ bool RISCVLegalizerInfo::legalizeCustom(
666
811
case TargetOpcode::G_SEXT:
667
812
case TargetOpcode::G_ANYEXT:
668
813
return legalizeExt (MI, MIRBuilder);
814
+ case TargetOpcode::G_SPLAT_VECTOR:
815
+ return legalizeSplatVector (MI, MIRBuilder);
669
816
}
670
817
671
818
llvm_unreachable (" expected switch to return" );
0 commit comments