@@ -111,18 +111,20 @@ RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST)
111
111
.clampScalar (0 , s32, sXLen )
112
112
.minScalarSameAs (1 , 0 );
113
113
114
+ auto &ExtActions =
115
+ getActionDefinitionsBuilder ({G_ZEXT, G_SEXT, G_ANYEXT})
116
+ .customIf (typeIsLegalBoolVec (1 , BoolVecTys, ST))
117
+ .legalIf (all (typeIsLegalIntOrFPVec (0 , IntOrFPVecTys, ST),
118
+ typeIsLegalIntOrFPVec (1 , IntOrFPVecTys, ST)))
119
+ .maxScalar (0 , sXLen );
114
120
if (ST.is64Bit ()) {
115
- getActionDefinitionsBuilder ({G_ZEXT, G_SEXT, G_ANYEXT})
116
- .legalFor ({{sXLen , s32}})
117
- .maxScalar (0 , sXLen );
121
+ ExtActions.legalFor ({{sXLen , s32}});
118
122
119
123
getActionDefinitionsBuilder (G_SEXT_INREG)
120
124
.customFor ({sXLen })
121
125
.maxScalar (0 , sXLen )
122
126
.lower ();
123
127
} else {
124
- getActionDefinitionsBuilder ({G_ZEXT, G_SEXT, G_ANYEXT}).maxScalar (0 , sXLen );
125
-
126
128
getActionDefinitionsBuilder (G_SEXT_INREG).maxScalar (0 , sXLen ).lower ();
127
129
}
128
130
@@ -495,6 +497,44 @@ bool RISCVLegalizerInfo::shouldBeInConstantPool(APInt APImm,
495
497
return !(!SeqLo.empty () && (SeqLo.size () + 2 ) <= STI.getMaxBuildIntsCost ());
496
498
}
497
499
500
+ // Custom-lower extensions from mask vectors by using a vselect either with 1
501
+ // for zero/any-extension or -1 for sign-extension:
502
+ // (vXiN = (s|z|any)ext vXi1:vmask) -> (vXiN = vselect vmask, (-1 or 1), 0)
503
+ // Note that any-extension is lowered identically to zero-extension. On success MI is erased and true is returned; the early exits below return false before anything is built.
504
+ bool RISCVLegalizerInfo::legalizeExt (MachineInstr &MI,
505
+ MachineIRBuilder &MIB) const {
506
+
507
+ unsigned Opc = MI.getOpcode ();
508
+ assert (Opc == TargetOpcode::G_ZEXT || Opc == TargetOpcode::G_SEXT ||
509
+ Opc == TargetOpcode::G_ANYEXT);
510
+
511
+ MachineRegisterInfo &MRI = *MIB.getMRI ();
512
+ Register Dst = MI.getOperand (0 ).getReg ();
513
+ Register Src = MI.getOperand (1 ).getReg ();
514
+
515
+ LLT DstTy = MRI.getType (Dst);
516
+ LLT SrcTy = MRI.getType (Src);
517
+
518
+ // Only vector-to-vector extends are custom-legalized here; return false if either the destination or the source type is not a vector.
519
+ if (!DstTy.isVector () || !SrcTy.isVector ())
520
+ return false ;
521
+
522
+ // Of those, only extends from mask vectors are handled, i.e. the source element type must be exactly 1 bit wide; return false otherwise.
523
+ if (SrcTy.getElementType ().getSizeInBits () != 1 )
524
+ return false ;
525
+
526
+ int64_t ExtTrueVal =
527
+ Opc == TargetOpcode::G_ZEXT || Opc == TargetOpcode::G_ANYEXT ? 1 : -1 ;
528
+ LLT DstEltTy = DstTy.getElementType ();
529
+ auto SplatZero = MIB.buildSplatVector (DstTy, MIB.buildConstant (DstEltTy, 0 ));
530
+ auto SplatTrue =
531
+ MIB.buildSplatVector (DstTy, MIB.buildConstant (DstEltTy, ExtTrueVal));
532
+ MIB.buildSelect (Dst, Src, SplatTrue, SplatZero);
533
+ MI.eraseFromParent ();
534
+ return true ;
535
+ }
536
+
537
+
498
538
bool RISCVLegalizerInfo::legalizeCustom (
499
539
LegalizerHelper &Helper, MachineInstr &MI,
500
540
LostDebugLocObserver &LocObserver) const {
@@ -552,6 +592,10 @@ bool RISCVLegalizerInfo::legalizeCustom(
552
592
}
553
593
case TargetOpcode::G_VASTART:
554
594
return legalizeVAStart (MI, MIRBuilder);
595
+ case TargetOpcode::G_ZEXT:
596
+ case TargetOpcode::G_SEXT:
597
+ case TargetOpcode::G_ANYEXT:
598
+ return legalizeExt (MI, MIRBuilder);
555
599
}
556
600
557
601
llvm_unreachable (" expected switch to return" );
0 commit comments