Skip to content

Commit acc43db

Browse files
authored
[LoongArch] Convert vector mask to vXi1 using [X]VMSKLTZ (#142978)
This patch adds a DAG combine that handles `BITCAST` nodes whose source is a `vXi1` vector mask: the mask is sign-extended to a legal vector type and its bits are then collected with the `[X]VMSKLTZ` instructions.
1 parent 4fb81f1 commit acc43db

File tree

3 files changed

+341
-1848
lines changed

3 files changed

+341
-1848
lines changed

llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp

Lines changed: 105 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4423,6 +4423,62 @@ static SDValue performSRLCombine(SDNode *N, SelectionDAG &DAG,
44234423
return SDValue();
44244424
}
44254425

4426+
// Helper to peek through bitops/trunc/setcc to determine size of source vector.
4427+
// Allows BITCASTCombine to determine what size vector generated a <X x i1>.
4428+
static bool checkBitcastSrcVectorSize(SDValue Src, unsigned Size,
4429+
unsigned Depth) {
4430+
// Limit recursion.
4431+
if (Depth >= SelectionDAG::MaxRecursionDepth)
4432+
return false;
4433+
switch (Src.getOpcode()) {
4434+
case ISD::SETCC:
4435+
case ISD::TRUNCATE:
4436+
return Src.getOperand(0).getValueSizeInBits() == Size;
4437+
case ISD::FREEZE:
4438+
return checkBitcastSrcVectorSize(Src.getOperand(0), Size, Depth + 1);
4439+
case ISD::AND:
4440+
case ISD::XOR:
4441+
case ISD::OR:
4442+
return checkBitcastSrcVectorSize(Src.getOperand(0), Size, Depth + 1) &&
4443+
checkBitcastSrcVectorSize(Src.getOperand(1), Size, Depth + 1);
4444+
case ISD::SELECT:
4445+
case ISD::VSELECT:
4446+
return Src.getOperand(0).getScalarValueSizeInBits() == 1 &&
4447+
checkBitcastSrcVectorSize(Src.getOperand(1), Size, Depth + 1) &&
4448+
checkBitcastSrcVectorSize(Src.getOperand(2), Size, Depth + 1);
4449+
case ISD::BUILD_VECTOR:
4450+
return ISD::isBuildVectorAllZeros(Src.getNode()) ||
4451+
ISD::isBuildVectorAllOnes(Src.getNode());
4452+
}
4453+
return false;
4454+
}
4455+
4456+
// Helper to push sign extension of vXi1 SETCC result through bitops.
4457+
static SDValue signExtendBitcastSrcVector(SelectionDAG &DAG, EVT SExtVT,
4458+
SDValue Src, const SDLoc &DL) {
4459+
switch (Src.getOpcode()) {
4460+
case ISD::SETCC:
4461+
case ISD::FREEZE:
4462+
case ISD::TRUNCATE:
4463+
case ISD::BUILD_VECTOR:
4464+
return DAG.getNode(ISD::SIGN_EXTEND, DL, SExtVT, Src);
4465+
case ISD::AND:
4466+
case ISD::XOR:
4467+
case ISD::OR:
4468+
return DAG.getNode(
4469+
Src.getOpcode(), DL, SExtVT,
4470+
signExtendBitcastSrcVector(DAG, SExtVT, Src.getOperand(0), DL),
4471+
signExtendBitcastSrcVector(DAG, SExtVT, Src.getOperand(1), DL));
4472+
case ISD::SELECT:
4473+
case ISD::VSELECT:
4474+
return DAG.getSelect(
4475+
DL, SExtVT, Src.getOperand(0),
4476+
signExtendBitcastSrcVector(DAG, SExtVT, Src.getOperand(1), DL),
4477+
signExtendBitcastSrcVector(DAG, SExtVT, Src.getOperand(2), DL));
4478+
}
4479+
llvm_unreachable("Unexpected node type for vXi1 sign extension");
4480+
}
4481+
44264482
static SDValue performBITCASTCombine(SDNode *N, SelectionDAG &DAG,
44274483
TargetLowering::DAGCombinerInfo &DCI,
44284484
const LoongArchSubtarget &Subtarget) {
@@ -4493,10 +4549,56 @@ static SDValue performBITCASTCombine(SDNode *N, SelectionDAG &DAG,
44934549
}
44944550
}
44954551

4496-
if (Opc == ISD::DELETED_NODE)
4497-
return SDValue();
4552+
// Generate vXi1 using [X]VMSKLTZ
4553+
if (Opc == ISD::DELETED_NODE) {
4554+
MVT SExtVT;
4555+
bool UseLASX = false;
4556+
bool PropagateSExt = false;
4557+
switch (SrcVT.getSimpleVT().SimpleTy) {
4558+
default:
4559+
return SDValue();
4560+
case MVT::v2i1:
4561+
SExtVT = MVT::v2i64;
4562+
break;
4563+
case MVT::v4i1:
4564+
SExtVT = MVT::v4i32;
4565+
if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
4566+
SExtVT = MVT::v4i64;
4567+
UseLASX = true;
4568+
PropagateSExt = true;
4569+
}
4570+
break;
4571+
case MVT::v8i1:
4572+
SExtVT = MVT::v8i16;
4573+
if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
4574+
SExtVT = MVT::v8i32;
4575+
UseLASX = true;
4576+
PropagateSExt = true;
4577+
}
4578+
break;
4579+
case MVT::v16i1:
4580+
SExtVT = MVT::v16i8;
4581+
if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
4582+
SExtVT = MVT::v16i16;
4583+
UseLASX = true;
4584+
PropagateSExt = true;
4585+
}
4586+
break;
4587+
case MVT::v32i1:
4588+
SExtVT = MVT::v32i8;
4589+
UseLASX = true;
4590+
break;
4591+
};
4592+
if (UseLASX && !Subtarget.has32S() && !Subtarget.hasExtLASX())
4593+
return SDValue();
4594+
Src = PropagateSExt ? signExtendBitcastSrcVector(DAG, SExtVT, Src, DL)
4595+
: DAG.getNode(ISD::SIGN_EXTEND, DL, SExtVT, Src);
4596+
Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
4597+
} else {
4598+
Src = Src.getOperand(0);
4599+
}
44984600

4499-
SDValue V = DAG.getNode(Opc, DL, MVT::i64, Src.getOperand(0));
4601+
SDValue V = DAG.getNode(Opc, DL, MVT::i64, Src);
45004602
EVT T = EVT::getIntegerVT(*DAG.getContext(), SrcVT.getVectorNumElements());
45014603
V = DAG.getZExtOrTrunc(V, DL, T);
45024604
return DAG.getBitcast(VT, V);

0 commit comments

Comments
 (0)