@@ -139,20 +139,21 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST)
139
139
.clampScalar (0 , s32, sXLen )
140
140
.minScalarSameAs (1 , 0 );
141
141
142
+ auto &ExtActions =
143
+ getActionDefinitionsBuilder ({G_ZEXT, G_SEXT, G_ANYEXT})
144
+ .legalIf (all (typeIsLegalIntOrFPVec (0 , IntOrFPVecTys, ST),
145
+ typeIsLegalIntOrFPVec (1 , IntOrFPVecTys, ST)));
142
146
if (ST.is64Bit ()) {
143
- getActionDefinitionsBuilder ({G_ZEXT, G_SEXT, G_ANYEXT})
144
- .legalFor ({{sXLen , s32}})
145
- .maxScalar (0 , sXLen );
146
-
147
+ ExtActions.legalFor ({{sXLen , s32}});
147
148
getActionDefinitionsBuilder (G_SEXT_INREG)
148
149
.customFor ({sXLen })
149
150
.maxScalar (0 , sXLen )
150
151
.lower ();
151
152
} else {
152
- getActionDefinitionsBuilder ({G_ZEXT, G_SEXT, G_ANYEXT}).maxScalar (0 , sXLen );
153
-
154
153
getActionDefinitionsBuilder (G_SEXT_INREG).maxScalar (0 , sXLen ).lower ();
155
154
}
155
+ ExtActions.customIf (typeIsLegalBoolVec (1 , BoolVecTys, ST))
156
+ .maxScalar (0 , sXLen );
156
157
157
158
// Merge/Unmerge
158
159
for (unsigned Op : {G_MERGE_VALUES, G_UNMERGE_VALUES}) {
@@ -570,6 +571,33 @@ bool RISCVLegalizerInfo::legalizeVScale(MachineInstr &MI,
570
571
auto VScale = MIB.buildLShr (XLenTy, VLENB, MIB.buildConstant (XLenTy, 3 ));
571
572
MIB.buildMul (Dst, VScale, MIB.buildConstant (XLenTy, Val));
572
573
}
574
+ MI.eraseFromParent ();
575
+ return true ;
576
+ }
577
+
578
+ // Custom-lower extensions from mask vectors by using a vselect either with 1
579
+ // for zero/any-extension or -1 for sign-extension:
580
+ // (vXiN = (s|z)ext vXi1:vmask) -> (vXiN = vselect vmask, (-1 or 1), 0)
581
+ // Note that any-extension is lowered identically to zero-extension.
582
+ bool RISCVLegalizerInfo::legalizeExt (MachineInstr &MI,
583
+ MachineIRBuilder &MIB) const {
584
+
585
+ unsigned Opc = MI.getOpcode ();
586
+ assert (Opc == TargetOpcode::G_ZEXT || Opc == TargetOpcode::G_SEXT ||
587
+ Opc == TargetOpcode::G_ANYEXT);
588
+
589
+ MachineRegisterInfo &MRI = *MIB.getMRI ();
590
+ Register Dst = MI.getOperand (0 ).getReg ();
591
+ Register Src = MI.getOperand (1 ).getReg ();
592
+
593
+ LLT DstTy = MRI.getType (Dst);
594
+ int64_t ExtTrueVal =
595
+ Opc == TargetOpcode::G_ZEXT || Opc == TargetOpcode::G_ANYEXT ? 1 : -1 ;
596
+ LLT DstEltTy = DstTy.getElementType ();
597
+ auto SplatZero = MIB.buildSplatVector (DstTy, MIB.buildConstant (DstEltTy, 0 ));
598
+ auto SplatTrue =
599
+ MIB.buildSplatVector (DstTy, MIB.buildConstant (DstEltTy, ExtTrueVal));
600
+ MIB.buildSelect (Dst, Src, SplatTrue, SplatZero);
573
601
574
602
MI.eraseFromParent ();
575
603
return true ;
@@ -634,6 +662,10 @@ bool RISCVLegalizerInfo::legalizeCustom(
634
662
return legalizeVAStart (MI, MIRBuilder);
635
663
case TargetOpcode::G_VSCALE:
636
664
return legalizeVScale (MI, MIRBuilder);
665
+ case TargetOpcode::G_ZEXT:
666
+ case TargetOpcode::G_SEXT:
667
+ case TargetOpcode::G_ANYEXT:
668
+ return legalizeExt (MI, MIRBuilder);
637
669
}
638
670
639
671
llvm_unreachable (" expected switch to return" );
0 commit comments