Skip to content

Commit 0ee1db2

Browse files
committed
[X86] Try to avoid casts around logical vector ops recursively.
Currently PromoteMaskArithmetic only looks at a single operation to skip casts. This means we miss cases where we combine multiple masks. This patch updates PromoteMaskArithmetic to try to recursively promote AND/OR/XOR nodes that terminate in truncates of the right size or constant vectors. Reviewers: craig.topper, RKSimon, spatel Reviewed By: RKSimon Differential Revision: https://reviews.llvm.org/D72524
1 parent 886d2c2 commit 0ee1db2

File tree

2 files changed

+272
-605
lines changed

2 files changed

+272
-605
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 64 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -39898,6 +39898,65 @@ static SDValue combineANDXORWithAllOnesIntoANDNP(SDNode *N, SelectionDAG &DAG) {
3989839898
return DAG.getNode(X86ISD::ANDNP, SDLoc(N), VT, X, Y);
3989939899
}
3990039900

39901+
// Try to widen AND, OR and XOR nodes to VT in order to remove casts around
39902+
// logical operations, like in the example below.
39903+
// or (and (truncate x), (truncate y)),
39904+
// (xor (truncate z), (build_vector (constants)))
39905+
// Given a target type \p VT, we generate
39906+
// or (and x, y), (xor z, zext(build_vector (constants)))
39907+
// given x, y and z are of type \p VT. We can do so, if operands are either
39908+
// truncates from VT types, the second operand is a vector of constants or can
39909+
// be recursively promoted.
39910+
// \param N      Node to widen. If it is not an AND, OR or XOR the function
//               returns an empty SDValue.
// \param VT     Wide type in which the logical operation is rebuilt.
// \param DAG    SelectionDAG used to create the widened nodes.
// \param Depth  Current recursion depth; each recursive call passes Depth + 1.
// \returns the rebuilt operation of type \p VT, or an empty SDValue when the
//          recursion limit is hit, the opcode is not legal-or-promotable for
//          \p VT, or an operand is neither widenable, a truncate from \p VT,
//          nor (right-hand side only) a build_vector of constants.
static SDValue PromoteMaskArithmetic(SDNode *N, EVT VT, SelectionDAG &DAG,
39911+
unsigned Depth) {
39912+
// Limit recursion to avoid excessive compile times.
39913+
if (Depth >= SelectionDAG::MaxRecursionDepth)
39914+
return SDValue();
39915+
39916+
// Only the bitwise logical opcodes are handled; anything else ends the walk.
if (N->getOpcode() != ISD::XOR && N->getOpcode() != ISD::AND &&
39917+
N->getOpcode() != ISD::OR)
39918+
return SDValue();
39919+
39920+
SDValue N0 = N->getOperand(0);
39921+
SDValue N1 = N->getOperand(1);
39922+
SDLoc DL(N);
39923+
39924+
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
39925+
// Give up before building anything if this opcode is not legal (or
// promotable) on the wide type.
if (!TLI.isOperationLegalOrPromote(N->getOpcode(), VT))
39926+
return SDValue();
39927+
39928+
// Left operand: first try to widen it recursively; only if that fails fall
// back to peeling a truncate whose source is already of type VT.
if (SDValue NN0 = PromoteMaskArithmetic(N0.getNode(), VT, DAG, Depth + 1))
39929+
N0 = NN0;
39930+
else {
39931+
// The Left side has to be a trunc.
39932+
if (N0.getOpcode() != ISD::TRUNCATE)
39933+
return SDValue();
39934+
39935+
// The truncate's source must already have the wide type VT.
39936+
if (N0.getOperand(0).getValueType() != VT)
39937+
return SDValue();
39938+
39939+
N0 = N0.getOperand(0);
39940+
}
39941+
39942+
// Right operand: same strategy as the left, except that a build_vector of
// constants is also accepted.
// NOTE(review): if the left operand was widened recursively above and this
// side then fails, the nodes already created for the left side are not
// returned — presumably they are cleaned up as dead nodes; confirm.
if (SDValue NN1 = PromoteMaskArithmetic(N1.getNode(), VT, DAG, Depth + 1))
39943+
N1 = NN1;
39944+
else {
39945+
// The right side has to be a 'trunc' or a constant vector.
39946+
bool RHSTrunc = N1.getOpcode() == ISD::TRUNCATE &&
39947+
N1.getOperand(0).getValueType() == VT;
39948+
if (!RHSTrunc && !ISD::isBuildVectorOfConstantSDNodes(N1.getNode()))
39949+
return SDValue();
39950+
39951+
// Use the truncate's wide source directly; a constant vector is instead
// zero-extended to VT.
if (RHSTrunc)
39952+
N1 = N1.getOperand(0);
39953+
else
39954+
N1 = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N1);
39955+
}
39956+
39957+
// Re-emit the same logical opcode on the widened operands.
return DAG.getNode(N->getOpcode(), DL, VT, N0, N1);
39958+
}
39959+
3990139960
// On AVX/AVX2 the type v8i1 is legalized to v8i16, which is an XMM sized
3990239961
// register. In most cases we actually compare or select YMM-sized registers
3990339962
// and mixing the two types creates horrible code. This method optimizes
@@ -39909,53 +39968,19 @@ static SDValue PromoteMaskArithmetic(SDNode *N, SelectionDAG &DAG,
3990939968
EVT VT = N->getValueType(0);
3991039969
assert(VT.isVector() && "Expected vector type");
3991139970

39971+
SDLoc DL(N);
3991239972
assert((N->getOpcode() == ISD::ANY_EXTEND ||
3991339973
N->getOpcode() == ISD::ZERO_EXTEND ||
3991439974
N->getOpcode() == ISD::SIGN_EXTEND) && "Invalid Node");
3991539975

3991639976
SDValue Narrow = N->getOperand(0);
3991739977
EVT NarrowVT = Narrow.getValueType();
3991839978

39919-
if (Narrow->getOpcode() != ISD::XOR &&
39920-
Narrow->getOpcode() != ISD::AND &&
39921-
Narrow->getOpcode() != ISD::OR)
39922-
return SDValue();
39923-
39924-
SDValue N0 = Narrow->getOperand(0);
39925-
SDValue N1 = Narrow->getOperand(1);
39926-
SDLoc DL(Narrow);
39927-
39928-
// The Left side has to be a trunc.
39929-
if (N0.getOpcode() != ISD::TRUNCATE)
39930-
return SDValue();
39931-
39932-
// The type of the truncated inputs.
39933-
if (N0.getOperand(0).getValueType() != VT)
39934-
return SDValue();
39935-
39936-
// The right side has to be a 'trunc' or a constant vector.
39937-
bool RHSTrunc = N1.getOpcode() == ISD::TRUNCATE &&
39938-
N1.getOperand(0).getValueType() == VT;
39939-
if (!RHSTrunc &&
39940-
!ISD::isBuildVectorOfConstantSDNodes(N1.getNode()))
39941-
return SDValue();
39942-
39943-
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
39944-
39945-
if (!TLI.isOperationLegalOrPromote(Narrow->getOpcode(), VT))
39946-
return SDValue();
39947-
39948-
// Set N0 and N1 to hold the inputs to the new wide operation.
39949-
N0 = N0.getOperand(0);
39950-
if (RHSTrunc)
39951-
N1 = N1.getOperand(0);
39952-
else
39953-
N1 = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N1);
39954-
3995539979
// Generate the wide operation.
39956-
SDValue Op = DAG.getNode(Narrow->getOpcode(), DL, VT, N0, N1);
39957-
unsigned Opcode = N->getOpcode();
39958-
switch (Opcode) {
39980+
SDValue Op = PromoteMaskArithmetic(Narrow.getNode(), VT, DAG, 0);
39981+
if (!Op)
39982+
return SDValue();
39983+
switch (N->getOpcode()) {
3995939984
default: llvm_unreachable("Unexpected opcode");
3996039985
case ISD::ANY_EXTEND:
3996139986
return Op;

0 commit comments

Comments
 (0)