@@ -139,20 +139,21 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST)
139
139
.clampScalar (0 , s32, sXLen )
140
140
.minScalarSameAs (1 , 0 );
141
141
142
+ auto &ExtActions =
143
+ getActionDefinitionsBuilder ({G_ZEXT, G_SEXT, G_ANYEXT})
144
+ .legalIf (all (typeIsLegalIntOrFPVec (0 , IntOrFPVecTys, ST),
145
+ typeIsLegalIntOrFPVec (1 , IntOrFPVecTys, ST)));
142
146
if (ST.is64Bit ()) {
143
- getActionDefinitionsBuilder ({G_ZEXT, G_SEXT, G_ANYEXT})
144
- .legalFor ({{sXLen , s32}})
145
- .maxScalar (0 , sXLen );
146
-
147
+ ExtActions.legalFor ({{sXLen , s32}});
147
148
getActionDefinitionsBuilder (G_SEXT_INREG)
148
149
.customFor ({sXLen })
149
150
.maxScalar (0 , sXLen )
150
151
.lower ();
151
152
} else {
152
- getActionDefinitionsBuilder ({G_ZEXT, G_SEXT, G_ANYEXT}).maxScalar (0 , sXLen );
153
-
154
153
getActionDefinitionsBuilder (G_SEXT_INREG).maxScalar (0 , sXLen ).lower ();
155
154
}
155
+ ExtActions.customIf (typeIsLegalBoolVec (1 , BoolVecTys, ST))
156
+ .maxScalar (0 , sXLen );
156
157
157
158
// Merge/Unmerge
158
159
for (unsigned Op : {G_MERGE_VALUES, G_UNMERGE_VALUES}) {
@@ -576,6 +577,32 @@ bool RISCVLegalizerInfo::legalizeVScale(MachineInstr &MI,
576
577
auto VScale = MIB.buildLShr (XLenTy, VLENB, MIB.buildConstant (XLenTy, 3 ));
577
578
MIB.buildMul (Dst, VScale, MIB.buildConstant (XLenTy, Val));
578
579
}
580
+ MI.eraseFromParent ();
581
+ return true ;
582
+ }
583
+
584
+ // Custom-lower extensions from mask vectors by using a vselect either with 1
585
+ // for zero/any-extension or -1 for sign-extension:
586
+ // (vXiN = (s|z)ext vXi1:vmask) -> (vXiN = vselect vmask, (-1 or 1), 0)
587
+ // Note that any-extension is lowered identically to zero-extension.
588
+ bool RISCVLegalizerInfo::legalizeExt (MachineInstr &MI,
589
+ MachineIRBuilder &MIB) const {
590
+
591
+ unsigned Opc = MI.getOpcode ();
592
+ assert (Opc == TargetOpcode::G_ZEXT || Opc == TargetOpcode::G_SEXT ||
593
+ Opc == TargetOpcode::G_ANYEXT);
594
+
595
+ MachineRegisterInfo &MRI = *MIB.getMRI ();
596
+ Register Dst = MI.getOperand (0 ).getReg ();
597
+ Register Src = MI.getOperand (1 ).getReg ();
598
+
599
+ LLT DstTy = MRI.getType (Dst);
600
+ int64_t ExtTrueVal = Opc == TargetOpcode::G_SEXT ? -1 : 1 ;
601
+ LLT DstEltTy = DstTy.getElementType ();
602
+ auto SplatZero = MIB.buildSplatVector (DstTy, MIB.buildConstant (DstEltTy, 0 ));
603
+ auto SplatTrue =
604
+ MIB.buildSplatVector (DstTy, MIB.buildConstant (DstEltTy, ExtTrueVal));
605
+ MIB.buildSelect (Dst, Src, SplatTrue, SplatZero);
579
606
580
607
MI.eraseFromParent ();
581
608
return true ;
@@ -640,6 +667,10 @@ bool RISCVLegalizerInfo::legalizeCustom(
640
667
return legalizeVAStart (MI, MIRBuilder);
641
668
case TargetOpcode::G_VSCALE:
642
669
return legalizeVScale (MI, MIRBuilder);
670
+ case TargetOpcode::G_ZEXT:
671
+ case TargetOpcode::G_SEXT:
672
+ case TargetOpcode::G_ANYEXT:
673
+ return legalizeExt (MI, MIRBuilder);
643
674
}
644
675
645
676
llvm_unreachable (" expected switch to return" );
0 commit comments