@@ -39898,6 +39898,65 @@ static SDValue combineANDXORWithAllOnesIntoANDNP(SDNode *N, SelectionDAG &DAG) {
39898
39898
return DAG.getNode(X86ISD::ANDNP, SDLoc(N), VT, X, Y);
39899
39899
}
39900
39900
39901
+ // Try to widen AND, OR and XOR nodes to VT in order to remove casts around
39902
+ // logical operations, like in the example below.
39903
+ // or (and (truncate x, truncate y)),
39904
+ // (xor (truncate z, build_vector (constants)))
39905
+ // Given a target type \p VT, we generate
39906
+ // or (and x, y), (xor z, zext(build_vector (constants)))
39907
+ // given x, y and z are of type \p VT. We can do so, if operands are either
39908
+ // truncates from VT types, the second operand is a vector of constants or can
39909
+ // be recursively promoted.
39910
+ static SDValue PromoteMaskArithmetic(SDNode *N, EVT VT, SelectionDAG &DAG,
39911
+ unsigned Depth) {
39912
+ // Limit recursion to avoid excessive compile times.
39913
+ if (Depth >= SelectionDAG::MaxRecursionDepth)
39914
+ return SDValue();
39915
+
39916
+ if (N->getOpcode() != ISD::XOR && N->getOpcode() != ISD::AND &&
39917
+ N->getOpcode() != ISD::OR)
39918
+ return SDValue();
39919
+
39920
+ SDValue N0 = N->getOperand(0);
39921
+ SDValue N1 = N->getOperand(1);
39922
+ SDLoc DL(N);
39923
+
39924
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
39925
+ if (!TLI.isOperationLegalOrPromote(N->getOpcode(), VT))
39926
+ return SDValue();
39927
+
39928
+ if (SDValue NN0 = PromoteMaskArithmetic(N0.getNode(), VT, DAG, Depth + 1))
39929
+ N0 = NN0;
39930
+ else {
39931
+ // The Left side has to be a trunc.
39932
+ if (N0.getOpcode() != ISD::TRUNCATE)
39933
+ return SDValue();
39934
+
39935
+ // The type of the truncated inputs.
39936
+ if (N0.getOperand(0).getValueType() != VT)
39937
+ return SDValue();
39938
+
39939
+ N0 = N0.getOperand(0);
39940
+ }
39941
+
39942
+ if (SDValue NN1 = PromoteMaskArithmetic(N1.getNode(), VT, DAG, Depth + 1))
39943
+ N1 = NN1;
39944
+ else {
39945
+ // The right side has to be a 'trunc' or a constant vector.
39946
+ bool RHSTrunc = N1.getOpcode() == ISD::TRUNCATE &&
39947
+ N1.getOperand(0).getValueType() == VT;
39948
+ if (!RHSTrunc && !ISD::isBuildVectorOfConstantSDNodes(N1.getNode()))
39949
+ return SDValue();
39950
+
39951
+ if (RHSTrunc)
39952
+ N1 = N1.getOperand(0);
39953
+ else
39954
+ N1 = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N1);
39955
+ }
39956
+
39957
+ return DAG.getNode(N->getOpcode(), DL, VT, N0, N1);
39958
+ }
39959
+
39901
39960
// On AVX/AVX2 the type v8i1 is legalized to v8i16, which is an XMM sized
39902
39961
// register. In most cases we actually compare or select YMM-sized registers
39903
39962
// and mixing the two types creates horrible code. This method optimizes
@@ -39909,53 +39968,19 @@ static SDValue PromoteMaskArithmetic(SDNode *N, SelectionDAG &DAG,
39909
39968
EVT VT = N->getValueType(0);
39910
39969
assert(VT.isVector() && "Expected vector type");
39911
39970
39971
+ SDLoc DL(N);
39912
39972
assert((N->getOpcode() == ISD::ANY_EXTEND ||
39913
39973
N->getOpcode() == ISD::ZERO_EXTEND ||
39914
39974
N->getOpcode() == ISD::SIGN_EXTEND) && "Invalid Node");
39915
39975
39916
39976
SDValue Narrow = N->getOperand(0);
39917
39977
EVT NarrowVT = Narrow.getValueType();
39918
39978
39919
- if (Narrow->getOpcode() != ISD::XOR &&
39920
- Narrow->getOpcode() != ISD::AND &&
39921
- Narrow->getOpcode() != ISD::OR)
39922
- return SDValue();
39923
-
39924
- SDValue N0 = Narrow->getOperand(0);
39925
- SDValue N1 = Narrow->getOperand(1);
39926
- SDLoc DL(Narrow);
39927
-
39928
- // The Left side has to be a trunc.
39929
- if (N0.getOpcode() != ISD::TRUNCATE)
39930
- return SDValue();
39931
-
39932
- // The type of the truncated inputs.
39933
- if (N0.getOperand(0).getValueType() != VT)
39934
- return SDValue();
39935
-
39936
- // The right side has to be a 'trunc' or a constant vector.
39937
- bool RHSTrunc = N1.getOpcode() == ISD::TRUNCATE &&
39938
- N1.getOperand(0).getValueType() == VT;
39939
- if (!RHSTrunc &&
39940
- !ISD::isBuildVectorOfConstantSDNodes(N1.getNode()))
39941
- return SDValue();
39942
-
39943
- const TargetLowering &TLI = DAG.getTargetLoweringInfo();
39944
-
39945
- if (!TLI.isOperationLegalOrPromote(Narrow->getOpcode(), VT))
39946
- return SDValue();
39947
-
39948
- // Set N0 and N1 to hold the inputs to the new wide operation.
39949
- N0 = N0.getOperand(0);
39950
- if (RHSTrunc)
39951
- N1 = N1.getOperand(0);
39952
- else
39953
- N1 = DAG.getNode(ISD::ZERO_EXTEND, DL, VT, N1);
39954
-
39955
39979
// Generate the wide operation.
39956
- SDValue Op = DAG.getNode(Narrow->getOpcode(), DL, VT, N0, N1);
39957
- unsigned Opcode = N->getOpcode();
39958
- switch (Opcode) {
39980
+ SDValue Op = PromoteMaskArithmetic(Narrow.getNode(), VT, DAG, 0);
39981
+ if (!Op)
39982
+ return SDValue();
39983
+ switch (N->getOpcode()) {
39959
39984
default: llvm_unreachable("Unexpected opcode");
39960
39985
case ISD::ANY_EXTEND:
39961
39986
return Op;
0 commit comments