@@ -4423,6 +4423,62 @@ static SDValue performSRLCombine(SDNode *N, SelectionDAG &DAG,
4423
4423
return SDValue ();
4424
4424
}
4425
4425
4426
+ // Helper to peek through bitops/trunc/setcc to determine size of source vector.
4427
+ // Allows BITCASTCombine to determine what size vector generated a <X x i1>.
4428
+ static bool checkBitcastSrcVectorSize (SDValue Src, unsigned Size,
4429
+ unsigned Depth) {
4430
+ // Limit recursion.
4431
+ if (Depth >= SelectionDAG::MaxRecursionDepth)
4432
+ return false ;
4433
+ switch (Src.getOpcode ()) {
4434
+ case ISD::SETCC:
4435
+ case ISD::TRUNCATE:
4436
+ return Src.getOperand (0 ).getValueSizeInBits () == Size;
4437
+ case ISD::FREEZE:
4438
+ return checkBitcastSrcVectorSize (Src.getOperand (0 ), Size, Depth + 1 );
4439
+ case ISD::AND:
4440
+ case ISD::XOR:
4441
+ case ISD::OR:
4442
+ return checkBitcastSrcVectorSize (Src.getOperand (0 ), Size, Depth + 1 ) &&
4443
+ checkBitcastSrcVectorSize (Src.getOperand (1 ), Size, Depth + 1 );
4444
+ case ISD::SELECT:
4445
+ case ISD::VSELECT:
4446
+ return Src.getOperand (0 ).getScalarValueSizeInBits () == 1 &&
4447
+ checkBitcastSrcVectorSize (Src.getOperand (1 ), Size, Depth + 1 ) &&
4448
+ checkBitcastSrcVectorSize (Src.getOperand (2 ), Size, Depth + 1 );
4449
+ case ISD::BUILD_VECTOR:
4450
+ return ISD::isBuildVectorAllZeros (Src.getNode ()) ||
4451
+ ISD::isBuildVectorAllOnes (Src.getNode ());
4452
+ }
4453
+ return false ;
4454
+ }
4455
+
4456
+ // Helper to push sign extension of vXi1 SETCC result through bitops.
4457
+ static SDValue signExtendBitcastSrcVector (SelectionDAG &DAG, EVT SExtVT,
4458
+ SDValue Src, const SDLoc &DL) {
4459
+ switch (Src.getOpcode ()) {
4460
+ case ISD::SETCC:
4461
+ case ISD::FREEZE:
4462
+ case ISD::TRUNCATE:
4463
+ case ISD::BUILD_VECTOR:
4464
+ return DAG.getNode (ISD::SIGN_EXTEND, DL, SExtVT, Src);
4465
+ case ISD::AND:
4466
+ case ISD::XOR:
4467
+ case ISD::OR:
4468
+ return DAG.getNode (
4469
+ Src.getOpcode (), DL, SExtVT,
4470
+ signExtendBitcastSrcVector (DAG, SExtVT, Src.getOperand (0 ), DL),
4471
+ signExtendBitcastSrcVector (DAG, SExtVT, Src.getOperand (1 ), DL));
4472
+ case ISD::SELECT:
4473
+ case ISD::VSELECT:
4474
+ return DAG.getSelect (
4475
+ DL, SExtVT, Src.getOperand (0 ),
4476
+ signExtendBitcastSrcVector (DAG, SExtVT, Src.getOperand (1 ), DL),
4477
+ signExtendBitcastSrcVector (DAG, SExtVT, Src.getOperand (2 ), DL));
4478
+ }
4479
+ llvm_unreachable (" Unexpected node type for vXi1 sign extension" );
4480
+ }
4481
+
4426
4482
static SDValue performBITCASTCombine (SDNode *N, SelectionDAG &DAG,
4427
4483
TargetLowering::DAGCombinerInfo &DCI,
4428
4484
const LoongArchSubtarget &Subtarget) {
@@ -4493,10 +4549,56 @@ static SDValue performBITCASTCombine(SDNode *N, SelectionDAG &DAG,
4493
4549
}
4494
4550
}
4495
4551
4496
- if (Opc == ISD::DELETED_NODE)
4497
- return SDValue ();
4552
+ // Generate vXi1 using [X]VMSKLTZ
4553
+ if (Opc == ISD::DELETED_NODE) {
4554
+ MVT SExtVT;
4555
+ bool UseLASX = false ;
4556
+ bool PropagateSExt = false ;
4557
+ switch (SrcVT.getSimpleVT ().SimpleTy ) {
4558
+ default :
4559
+ return SDValue ();
4560
+ case MVT::v2i1:
4561
+ SExtVT = MVT::v2i64;
4562
+ break ;
4563
+ case MVT::v4i1:
4564
+ SExtVT = MVT::v4i32;
4565
+ if (Subtarget.hasExtLASX () && checkBitcastSrcVectorSize (Src, 256 , 0 )) {
4566
+ SExtVT = MVT::v4i64;
4567
+ UseLASX = true ;
4568
+ PropagateSExt = true ;
4569
+ }
4570
+ break ;
4571
+ case MVT::v8i1:
4572
+ SExtVT = MVT::v8i16;
4573
+ if (Subtarget.hasExtLASX () && checkBitcastSrcVectorSize (Src, 256 , 0 )) {
4574
+ SExtVT = MVT::v8i32;
4575
+ UseLASX = true ;
4576
+ PropagateSExt = true ;
4577
+ }
4578
+ break ;
4579
+ case MVT::v16i1:
4580
+ SExtVT = MVT::v16i8;
4581
+ if (Subtarget.hasExtLASX () && checkBitcastSrcVectorSize (Src, 256 , 0 )) {
4582
+ SExtVT = MVT::v16i16;
4583
+ UseLASX = true ;
4584
+ PropagateSExt = true ;
4585
+ }
4586
+ break ;
4587
+ case MVT::v32i1:
4588
+ SExtVT = MVT::v32i8;
4589
+ UseLASX = true ;
4590
+ break ;
4591
+ };
4592
+ if (UseLASX && !Subtarget.has32S () && !Subtarget.hasExtLASX ())
4593
+ return SDValue ();
4594
+ Src = PropagateSExt ? signExtendBitcastSrcVector (DAG, SExtVT, Src, DL)
4595
+ : DAG.getNode (ISD::SIGN_EXTEND, DL, SExtVT, Src);
4596
+ Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
4597
+ } else {
4598
+ Src = Src.getOperand (0 );
4599
+ }
4498
4600
4499
- SDValue V = DAG.getNode (Opc, DL, MVT::i64 , Src. getOperand ( 0 ) );
4601
+ SDValue V = DAG.getNode (Opc, DL, MVT::i64 , Src);
4500
4602
EVT T = EVT::getIntegerVT (*DAG.getContext (), SrcVT.getVectorNumElements ());
4501
4603
V = DAG.getZExtOrTrunc (V, DL, T);
4502
4604
return DAG.getBitcast (VT, V);
0 commit comments