@@ -139,17 +139,23 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST)
139
139
.clampScalar (0 , s32, sXLen )
140
140
.minScalarSameAs (1 , 0 );
141
141
142
+ auto &ExtActions =
143
+ getActionDefinitionsBuilder ({G_ZEXT, G_SEXT, G_ANYEXT})
144
+ .legalIf (all (typeIsLegalIntOrFPVec (0 , IntOrFPVecTys, ST),
145
+ typeIsLegalIntOrFPVec (1 , IntOrFPVecTys, ST)));
146
+
142
147
if (ST.is64Bit ()) {
143
- getActionDefinitionsBuilder ({G_ZEXT, G_SEXT, G_ANYEXT })
144
- .legalFor ({{ sXLen , s32}} )
148
+ ExtActions. legalFor ({{ sXLen , s32} })
149
+ .customIf ( typeIsLegalBoolVec ( 1 , BoolVecTys, ST) )
145
150
.maxScalar (0 , sXLen );
146
151
147
152
getActionDefinitionsBuilder (G_SEXT_INREG)
148
153
.customFor ({sXLen })
149
154
.maxScalar (0 , sXLen )
150
155
.lower ();
151
156
} else {
152
- getActionDefinitionsBuilder ({G_ZEXT, G_SEXT, G_ANYEXT}).maxScalar (0 , sXLen );
157
+ ExtActions.customIf (typeIsLegalBoolVec (1 , BoolVecTys, ST))
158
+ .maxScalar (0 , sXLen );
153
159
154
160
getActionDefinitionsBuilder (G_SEXT_INREG).maxScalar (0 , sXLen ).lower ();
155
161
}
@@ -570,6 +576,33 @@ bool RISCVLegalizerInfo::legalizeVScale(MachineInstr &MI,
570
576
auto VScale = MIB.buildLShr (XLenTy, VLENB, MIB.buildConstant (XLenTy, 3 ));
571
577
MIB.buildMul (Dst, VScale, MIB.buildConstant (XLenTy, Val));
572
578
}
579
+ MI.eraseFromParent ();
580
+ return true ;
581
+ }
582
+
583
+ // Custom-lower extensions from mask vectors by using a vselect either with 1
584
+ // for zero/any-extension or -1 for sign-extension:
585
+ // (vXiN = (s|z)ext vXi1:vmask) -> (vXiN = vselect vmask, (-1 or 1), 0)
586
+ // Note that any-extension is lowered identically to zero-extension.
587
+ bool RISCVLegalizerInfo::legalizeExt (MachineInstr &MI,
588
+ MachineIRBuilder &MIB) const {
589
+
590
+ unsigned Opc = MI.getOpcode ();
591
+ assert (Opc == TargetOpcode::G_ZEXT || Opc == TargetOpcode::G_SEXT ||
592
+ Opc == TargetOpcode::G_ANYEXT);
593
+
594
+ MachineRegisterInfo &MRI = *MIB.getMRI ();
595
+ Register Dst = MI.getOperand (0 ).getReg ();
596
+ Register Src = MI.getOperand (1 ).getReg ();
597
+
598
+ LLT DstTy = MRI.getType (Dst);
599
+ int64_t ExtTrueVal =
600
+ Opc == TargetOpcode::G_ZEXT || Opc == TargetOpcode::G_ANYEXT ? 1 : -1 ;
601
+ LLT DstEltTy = DstTy.getElementType ();
602
+ auto SplatZero = MIB.buildSplatVector (DstTy, MIB.buildConstant (DstEltTy, 0 ));
603
+ auto SplatTrue =
604
+ MIB.buildSplatVector (DstTy, MIB.buildConstant (DstEltTy, ExtTrueVal));
605
+ MIB.buildSelect (Dst, Src, SplatTrue, SplatZero);
573
606
574
607
MI.eraseFromParent ();
575
608
return true ;
@@ -634,6 +667,10 @@ bool RISCVLegalizerInfo::legalizeCustom(
634
667
return legalizeVAStart (MI, MIRBuilder);
635
668
case TargetOpcode::G_VSCALE:
636
669
return legalizeVScale (MI, MIRBuilder);
670
+ case TargetOpcode::G_ZEXT:
671
+ case TargetOpcode::G_SEXT:
672
+ case TargetOpcode::G_ANYEXT:
673
+ return legalizeExt (MI, MIRBuilder);
637
674
}
638
675
639
676
llvm_unreachable (" expected switch to return" );
0 commit comments