@@ -4606,7 +4606,33 @@ bool AArch64DAGToDAGISel::trySelectXAR(SDNode *N) {
4606
4606
return false ;
4607
4607
}
4608
4608
4609
- if (!Subtarget->hasSHA3 ())
4609
+ // We have Neon SHA3 XAR operation for v2i64 but for types
4610
+ // v4i32, v8i16, v16i8 we can use SVE operations when SVE2-SHA3
4611
+ // is available.
4612
+ EVT SVT;
4613
+ switch (VT.getSimpleVT ().SimpleTy ) {
4614
+ case MVT::v4i32:
4615
+ case MVT::v2i32:
4616
+ SVT = MVT::nxv4i32;
4617
+ break ;
4618
+ case MVT::v8i16:
4619
+ case MVT::v4i16:
4620
+ SVT = MVT::nxv8i16;
4621
+ break ;
4622
+ case MVT::v16i8:
4623
+ case MVT::v8i8:
4624
+ SVT = MVT::nxv16i8;
4625
+ break ;
4626
+ case MVT::v2i64:
4627
+ case MVT::v1i64:
4628
+ SVT = Subtarget->hasSHA3 () ? MVT::v2i64 : MVT::nxv2i64;
4629
+ break ;
4630
+ default :
4631
+ return false ;
4632
+ }
4633
+
4634
+ if ((!SVT.isScalableVector () && !Subtarget->hasSHA3 ()) ||
4635
+ (SVT.isScalableVector () && !Subtarget->hasSVE2 ()))
4610
4636
return false ;
4611
4637
4612
4638
if (N0->getOpcode () != AArch64ISD::VSHL ||
@@ -4632,41 +4658,97 @@ bool AArch64DAGToDAGISel::trySelectXAR(SDNode *N) {
4632
4658
SDValue Imm = CurDAG->getTargetConstant (
4633
4659
ShAmt, DL, N0.getOperand (1 ).getValueType (), false );
4634
4660
4635
- if (ShAmt + HsAmt != 64 )
4661
+ unsigned VTSizeInBits = VT.getScalarSizeInBits ();
4662
+ if (ShAmt + HsAmt != VTSizeInBits)
4636
4663
return false ;
4637
4664
4638
4665
if (!IsXOROperand) {
4639
4666
SDValue Zero = CurDAG->getTargetConstant (0 , DL, MVT::i64 );
4640
- SDNode *MOV =
4641
- CurDAG->getMachineNode (AArch64::MOVIv2d_ns, DL, MVT::v2i64, Zero);
4667
+ SDNode *MOV = CurDAG->getMachineNode (AArch64::MOVIv2d_ns, DL, SVT, Zero);
4642
4668
SDValue MOVIV = SDValue (MOV, 0 );
4669
+
4643
4670
R1 = N1->getOperand (0 );
4644
4671
R2 = MOVIV;
4645
4672
}
4646
4673
4647
- // If the input is a v1i64, widen to a v2i64 to use XAR.
4648
- assert ((VT == MVT::v1i64 || VT == MVT::v2i64) && " Unexpected XAR type!" );
4649
- if (VT == MVT::v1i64) {
4650
- EVT SVT = MVT::v2i64;
4674
+ if (SVT.isScalableVector ()) {
4675
+ SDValue Undef =
4676
+ SDValue (CurDAG->getMachineNode (TargetOpcode::IMPLICIT_DEF, DL, SVT), 0 );
4677
+
4678
+ if (VT.is64BitVector ()) {
4679
+ EVT QVT = VT.getDoubleNumVectorElementsVT (*CurDAG->getContext ());
4680
+
4681
+ SDValue UndefQ = SDValue (
4682
+ CurDAG->getMachineNode (TargetOpcode::IMPLICIT_DEF, DL, QVT), 0 );
4683
+ SDValue DSub = CurDAG->getTargetConstant (AArch64::dsub, DL, MVT::i32 );
4684
+
4685
+ R1 = SDValue (CurDAG->getMachineNode (AArch64::INSERT_SUBREG, DL, QVT,
4686
+ Undef, R1, DSub),
4687
+ 0 );
4688
+ if (R2.getValueType () == VT)
4689
+ R2 = SDValue (CurDAG->getMachineNode (AArch64::INSERT_SUBREG, DL, QVT,
4690
+ Undef, R2, DSub),
4691
+ 0 );
4692
+ }
4693
+
4694
+ SDValue ZSub = CurDAG->getTargetConstant (AArch64::zsub, DL, MVT::i32 );
4695
+
4696
+ R1 = SDValue (CurDAG->getMachineNode (AArch64::INSERT_SUBREG, DL, SVT, Undef,
4697
+ R1, ZSub),
4698
+ 0 );
4699
+ R2 = SDValue (CurDAG->getMachineNode (AArch64::INSERT_SUBREG, DL, SVT, Undef,
4700
+ R2, ZSub),
4701
+ 0 );
4702
+ }
4703
+
4704
+ if (!SVT.isScalableVector () && SVT != VT) {
4651
4705
SDValue Undef =
4652
4706
SDValue (CurDAG->getMachineNode (AArch64::IMPLICIT_DEF, DL, SVT), 0 );
4653
4707
SDValue DSub = CurDAG->getTargetConstant (AArch64::dsub, DL, MVT::i32 );
4708
+
4654
4709
R1 = SDValue (CurDAG->getMachineNode (AArch64::INSERT_SUBREG, DL, SVT, Undef,
4655
4710
R1, DSub),
4656
4711
0 );
4657
- if (R2.getValueType () == MVT::v1i64 )
4712
+ if (R2.getValueType () != SVT )
4658
4713
R2 = SDValue (CurDAG->getMachineNode (AArch64::INSERT_SUBREG, DL, SVT,
4659
4714
Undef, R2, DSub),
4660
4715
0 );
4661
4716
}
4662
4717
4663
4718
SDValue Ops[] = {R1, R2, Imm};
4664
- SDNode *XAR = CurDAG-> getMachineNode (AArch64::XAR, DL, MVT::v2i64, Ops) ;
4719
+ SDNode *XAR = nullptr ;
4665
4720
4666
- if (VT == MVT::v1i64) {
4721
+ if (SVT.isScalableVector ()) {
4722
+ if (auto Opc = SelectOpcodeFromVT<SelectTypeKind::Int>(
4723
+ SVT, {AArch64::XAR_ZZZI_B, AArch64::XAR_ZZZI_H, AArch64::XAR_ZZZI_S,
4724
+ AArch64::XAR_ZZZI_D}))
4725
+ XAR = CurDAG->getMachineNode (Opc, DL, SVT, Ops);
4726
+ } else {
4727
+ XAR = CurDAG->getMachineNode (AArch64::XAR, DL, SVT, Ops);
4728
+ }
4729
+
4730
+ assert (XAR && " Unexpected NULL value for XAR instruction in DAG" );
4731
+
4732
+ if (!SVT.isScalableVector () && SVT != VT) {
4667
4733
SDValue DSub = CurDAG->getTargetConstant (AArch64::dsub, DL, MVT::i32 );
4668
4734
XAR = CurDAG->getMachineNode (AArch64::EXTRACT_SUBREG, DL, VT,
4669
4735
SDValue (XAR, 0 ), DSub);
4736
+ } else if (SVT.isScalableVector ()) {
4737
+ if (VT.is64BitVector ()) {
4738
+ EVT QVT = VT.getDoubleNumVectorElementsVT (*CurDAG->getContext ());
4739
+
4740
+ SDValue ZSub = CurDAG->getTargetConstant (AArch64::zsub, DL, MVT::i32 );
4741
+ SDNode *Q = CurDAG->getMachineNode (AArch64::EXTRACT_SUBREG, DL, QVT,
4742
+ SDValue (XAR, 0 ), ZSub);
4743
+
4744
+ SDValue DSub = CurDAG->getTargetConstant (AArch64::dsub, DL, MVT::i32 );
4745
+ XAR = CurDAG->getMachineNode (AArch64::EXTRACT_SUBREG, DL, VT,
4746
+ SDValue (Q, 0 ), DSub);
4747
+ } else {
4748
+ SDValue ZSub = CurDAG->getTargetConstant (AArch64::zsub, DL, MVT::i32 );
4749
+ XAR = CurDAG->getMachineNode (AArch64::EXTRACT_SUBREG, DL, VT,
4750
+ SDValue (XAR, 0 ), ZSub);
4751
+ }
4670
4752
}
4671
4753
ReplaceNode (N, XAR);
4672
4754
return true ;
0 commit comments