@@ -139,20 +139,21 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST)
       .clampScalar(0, s32, sXLen)
       .minScalarSameAs(1, 0);
 
+  auto &ExtActions =
+      getActionDefinitionsBuilder({G_ZEXT, G_SEXT, G_ANYEXT})
+          .legalIf(all(typeIsLegalIntOrFPVec(0, IntOrFPVecTys, ST),
+                       typeIsLegalIntOrFPVec(1, IntOrFPVecTys, ST)));
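+  // e.g. with +v, a G_SEXT from nxv4s16 to nxv4s32 matches this rule and
+  // stays legal, since both types are legal integer vectors (illustrative
+  // example, not from the patch's tests).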
   if (ST.is64Bit()) {
-    getActionDefinitionsBuilder({G_ZEXT, G_SEXT, G_ANYEXT})
-        .legalFor({{sXLen, s32}})
-        .maxScalar(0, sXLen);
-
+    ExtActions.legalFor({{sXLen, s32}});
     getActionDefinitionsBuilder(G_SEXT_INREG)
         .customFor({sXLen})
         .maxScalar(0, sXLen)
         .lower();
   } else {
-    getActionDefinitionsBuilder({G_ZEXT, G_SEXT, G_ANYEXT}).maxScalar(0, sXLen);
-
     getActionDefinitionsBuilder(G_SEXT_INREG).maxScalar(0, sXLen).lower();
   }
+  ExtActions.customIf(typeIsLegalBoolVec(1, BoolVecTys, ST))
+      .maxScalar(0, sXLen);
 
   // Merge/Unmerge
   for (unsigned Op : {G_MERGE_VALUES, G_UNMERGE_VALUES}) {
@@ -235,7 +236,9 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST)
 
   getActionDefinitionsBuilder(G_ICMP)
       .legalFor({{sXLen, sXLen}, {sXLen, p0}})
-      .widenScalarToNextPow2(1)
+      .legalIf(all(typeIsLegalBoolVec(0, BoolVecTys, ST),
+                   typeIsLegalIntOrFPVec(1, IntOrFPVecTys, ST)))
+      .widenScalarOrEltToNextPow2OrMinSize(1, 8)
      .clampScalar(1, sXLen, sXLen)
      .clampScalar(0, sXLen, sXLen);
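+  // (e.g. an icmp of two legal nxv4s32 vectors now yields a legal nxv4s1
+  // mask instead of being scalarized.)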
 
@@ -418,6 +421,29 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST)
       .clampScalar(0, sXLen, sXLen)
       .customFor({sXLen});
 
+  auto &SplatActions =
+      getActionDefinitionsBuilder(G_SPLAT_VECTOR)
+          .legalIf(all(typeIsLegalIntOrFPVec(0, IntOrFPVecTys, ST),
+                       typeIs(1, sXLen)))
+          .customIf(all(typeIsLegalBoolVec(0, BoolVecTys, ST), typeIs(1, s1)));
+  // Handle the case of s64 element vectors on RV32. If the subtarget does not
+  // have f64, then try to lower it to G_SPLAT_VECTOR_SPLIT_I64_VL. If the
+  // subtarget does have f64, then we don't know whether the type is an f64 or
+  // an i64, so mark the G_SPLAT_VECTOR as legal and decide later what to do
+  // with it, depending on how the instructions it consumes are legalized. They
+  // are not legalized yet since legalization is in reverse postorder, so we
+  // cannot make the decision at this moment.
+  if (XLen == 32) {
+    if (ST.hasVInstructionsF64() && ST.hasStdExtD())
+      SplatActions.legalIf(all(
+          typeInSet(0, {nxv1s64, nxv2s64, nxv4s64, nxv8s64}), typeIs(1, s64)));
+    else if (ST.hasVInstructionsI64())
+      SplatActions.customIf(all(
+          typeInSet(0, {nxv1s64, nxv2s64, nxv4s64, nxv8s64}), typeIs(1, s64)));
+  }
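+  // e.g. on RV32 with Zve64x but no D, a G_SPLAT_VECTOR of nxv1s64 takes the
+  // custom path above and is split into two s32 halves in
+  // legalizeSplatVector below.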
+
+  SplatActions.clampScalar(1, sXLen, sXLen);
+
   getLegacyLegalizerInfo().computeTables();
 }
@@ -576,7 +602,145 @@ bool RISCVLegalizerInfo::legalizeVScale(MachineInstr &MI,
     auto VScale = MIB.buildLShr(XLenTy, VLENB, MIB.buildConstant(XLenTy, 3));
     MIB.buildMul(Dst, VScale, MIB.buildConstant(XLenTy, Val));
   }
+  MI.eraseFromParent();
+  return true;
+}
+
+// Custom-lower extensions from mask vectors by using a vselect either with 1
+// for zero/any-extension or -1 for sign-extension:
+//   (vXiN = (s|z)ext vXi1:vmask) -> (vXiN = vselect vmask, (-1 or 1), 0)
+// Note that any-extension is lowered identically to zero-extension.
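+// For example, roughly:
+//   %1:_(<vscale x 2 x s32>) = G_SEXT %0:_(<vscale x 2 x s1>)
+// becomes
+//   %t:_(<vscale x 2 x s32>) = G_SPLAT_VECTOR (-1)
+//   %f:_(<vscale x 2 x s32>) = G_SPLAT_VECTOR (0)
+//   %1:_(<vscale x 2 x s32>) = G_SELECT %0, %t, %f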
+bool RISCVLegalizerInfo::legalizeExt(MachineInstr &MI,
+                                     MachineIRBuilder &MIB) const {
+
+  unsigned Opc = MI.getOpcode();
+  assert(Opc == TargetOpcode::G_ZEXT || Opc == TargetOpcode::G_SEXT ||
+         Opc == TargetOpcode::G_ANYEXT);
+
+  MachineRegisterInfo &MRI = *MIB.getMRI();
+  Register Dst = MI.getOperand(0).getReg();
+  Register Src = MI.getOperand(1).getReg();
+
+  LLT DstTy = MRI.getType(Dst);
+  int64_t ExtTrueVal = Opc == TargetOpcode::G_SEXT ? -1 : 1;
+  LLT DstEltTy = DstTy.getElementType();
+  auto SplatZero = MIB.buildSplatVector(DstTy, MIB.buildConstant(DstEltTy, 0));
+  auto SplatTrue =
+      MIB.buildSplatVector(DstTy, MIB.buildConstant(DstEltTy, ExtTrueVal));
+  MIB.buildSelect(Dst, Src, SplatTrue, SplatZero);
+
+  MI.eraseFromParent();
+  return true;
+}
+
+/// Return the mask type suitable for masking the provided vector type. This
+/// is simply an i1 element type vector of the same (possibly scalable)
+/// length.
+static LLT getMaskTypeFor(LLT VecTy) {
+  assert(VecTy.isVector());
+  ElementCount EC = VecTy.getElementCount();
+  return LLT::vector(EC, LLT::scalar(1));
+}
+
+/// Creates an all-ones mask suitable for masking a vector of type VecTy with
+/// vector length VL.
+static MachineInstrBuilder buildAllOnesMask(LLT VecTy, const SrcOp &VL,
+                                            MachineIRBuilder &MIB,
+                                            MachineRegisterInfo &MRI) {
+  LLT MaskTy = getMaskTypeFor(VecTy);
+  return MIB.buildInstr(RISCV::G_VMSET_VL, {MaskTy}, {VL});
+}
+
+/// Gets the two common "VL" operands: an all-ones mask and the vector length.
+/// VecTy is a scalable vector type.
+static std::pair<MachineInstrBuilder, Register>
+buildDefaultVLOps(const DstOp &Dst, MachineIRBuilder &MIB,
+                  MachineRegisterInfo &MRI) {
+  LLT VecTy = Dst.getLLTTy(MRI);
+  assert(VecTy.isScalableVector() && "Expecting scalable container type");
+  Register VL(RISCV::X0);
+  MachineInstrBuilder Mask = buildAllOnesMask(VecTy, VL, MIB, MRI);
+  return {Mask, VL};
+}
+
+static MachineInstrBuilder
+buildSplatPartsS64WithVL(const DstOp &Dst, const SrcOp &Passthru, Register Lo,
+                         Register Hi, Register VL, MachineIRBuilder &MIB,
+                         MachineRegisterInfo &MRI) {
+  // TODO: If the Hi bits of the splat are undefined, then it's fine to just
+  // splat Lo even if it might be sign extended. I don't think we have
+  // introduced a case where we're building an s64 with undef upper bits yet.
+
+  // Fall back to a stack store and stride x0 vector load.
+  // TODO: need to lower G_SPLAT_VECTOR_SPLIT_I64_VL. This is done in
+  // preprocessDAG in SDAG.
+  return MIB.buildInstr(RISCV::G_SPLAT_VECTOR_SPLIT_I64_VL, {Dst},
+                        {Passthru, Lo, Hi, VL});
+}
+
+static MachineInstrBuilder
+buildSplatSplitS64WithVL(const DstOp &Dst, const SrcOp &Passthru,
+                         const SrcOp &Scalar, Register VL,
+                         MachineIRBuilder &MIB, MachineRegisterInfo &MRI) {
+  assert(Scalar.getLLTTy(MRI) == LLT::scalar(64) && "Unexpected VecTy!");
+  auto Unmerge = MIB.buildUnmerge(LLT::scalar(32), Scalar);
+  return buildSplatPartsS64WithVL(Dst, Passthru, Unmerge.getReg(0),
+                                  Unmerge.getReg(1), VL, MIB, MRI);
+}
+
+// Lower splats of s1 types to G_ICMP. For each mask vector type, we have a
+// legal equivalently-sized i8 type, so we can use that as a go-between.
+// Splats of s1 types that have constant value can be legalized as VMSET_VL or
+// VMCLR_VL.
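+// For example, a splat of a known-true mask value such as
+//   %m:_(<vscale x 4 x s1>) = G_SPLAT_VECTOR %true:_(s1)
+// becomes a single G_VMSET_VL with VL = X0 (i.e. VLMAX), and the all-zeros
+// case likewise becomes G_VMCLR_VL.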
+bool RISCVLegalizerInfo::legalizeSplatVector(MachineInstr &MI,
+                                             MachineIRBuilder &MIB) const {
+  assert(MI.getOpcode() == TargetOpcode::G_SPLAT_VECTOR);
+
+  MachineRegisterInfo &MRI = *MIB.getMRI();
+
+  Register Dst = MI.getOperand(0).getReg();
+  Register SplatVal = MI.getOperand(1).getReg();
+
+  LLT VecTy = MRI.getType(Dst);
+  LLT XLenTy(STI.getXLenVT());
+
+  // Handle the case of s64 element vectors on RV32.
+  if (XLenTy.getSizeInBits() == 32 &&
+      VecTy.getElementType().getSizeInBits() == 64) {
+    auto [_, VL] = buildDefaultVLOps(Dst, MIB, MRI);
+    buildSplatSplitS64WithVL(Dst, MIB.buildUndef(VecTy), SplatVal, VL, MIB,
+                             MRI);
+    MI.eraseFromParent();
+    return true;
+  }
+
+  // All-zeros or all-ones splats are handled specially.
+  MachineInstr &SplatValMI = *MRI.getVRegDef(SplatVal);
+  if (isAllOnesOrAllOnesSplat(SplatValMI, MRI)) {
+    auto VL = buildDefaultVLOps(VecTy, MIB, MRI).second;
+    MIB.buildInstr(RISCV::G_VMSET_VL, {Dst}, {VL});
+    MI.eraseFromParent();
+    return true;
+  }
+  if (isNullOrNullSplat(SplatValMI, MRI)) {
+    auto VL = buildDefaultVLOps(VecTy, MIB, MRI).second;
+    MIB.buildInstr(RISCV::G_VMCLR_VL, {Dst}, {VL});
+    MI.eraseFromParent();
+    return true;
+  }
 
+  // Handle a non-constant mask splat (i.e. not sure if it's all zeros or all
+  // ones) by promoting it to an s8 splat.
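+  // Roughly:
+  //   %z:_(s8) = G_ZEXT %val:_(s1)
+  //   %a:_(s8) = G_AND %z, 1
+  //   %lhs:_(<vscale x 4 x s8>) = G_SPLAT_VECTOR %a
+  //   %m:_(<vscale x 4 x s1>) = G_ICMP ne, %lhs, (splat 0)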
+  LLT InterEltTy = LLT::scalar(8);
+  LLT InterTy = VecTy.changeElementType(InterEltTy);
+  auto ZExtSplatVal = MIB.buildZExt(InterEltTy, SplatVal);
+  auto And =
+      MIB.buildAnd(InterEltTy, ZExtSplatVal, MIB.buildConstant(InterEltTy, 1));
+  auto LHS = MIB.buildSplatVector(InterTy, And);
+  auto ZeroSplat =
+      MIB.buildSplatVector(InterTy, MIB.buildConstant(InterEltTy, 0));
+  MIB.buildICmp(CmpInst::Predicate::ICMP_NE, Dst, LHS, ZeroSplat);
   MI.eraseFromParent();
   return true;
 }
@@ -640,6 +804,12 @@ bool RISCVLegalizerInfo::legalizeCustom(
     return legalizeVAStart(MI, MIRBuilder);
   case TargetOpcode::G_VSCALE:
     return legalizeVScale(MI, MIRBuilder);
+  case TargetOpcode::G_ZEXT:
+  case TargetOpcode::G_SEXT:
+  case TargetOpcode::G_ANYEXT:
+    return legalizeExt(MI, MIRBuilder);
+  case TargetOpcode::G_SPLAT_VECTOR:
+    return legalizeSplatVector(MI, MIRBuilder);
   }
 
   llvm_unreachable("expected switch to return");