|
94 | 94 | #include <bitset>
|
95 | 95 | #include <cassert>
|
96 | 96 | #include <cctype>
|
| 97 | +#include <cmath> |
97 | 98 | #include <cstdint>
|
98 | 99 | #include <cstdlib>
|
99 | 100 | #include <iterator>
|
@@ -1523,6 +1524,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
|
1523 | 1524 | setOperationAction(ISD::VECREDUCE_AND, VT, Custom);
|
1524 | 1525 | setOperationAction(ISD::VECREDUCE_OR, VT, Custom);
|
1525 | 1526 | setOperationAction(ISD::VECREDUCE_XOR, VT, Custom);
|
| 1527 | + setOperationAction(ISD::OR, VT, Custom); |
1526 | 1528 |
|
1527 | 1529 | setOperationAction(ISD::SELECT_CC, VT, Expand);
|
1528 | 1530 | setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom);
|
@@ -13782,8 +13784,88 @@ static SDValue tryLowerToSLI(SDNode *N, SelectionDAG &DAG) {
|
13782 | 13784 | return ResultSLI;
|
13783 | 13785 | }
|
13784 | 13786 |
|
| 13787 | +/// Try to lower the construction of a pointer alias mask to a WHILEWR. |
| 13788 | +/// The mask's enabled lanes represent the elements that will not overlap across one loop iteration. |
| 13789 | +/// This tries to match: |
| 13790 | +/// or (splat (setcc_lt (sub ptrA, ptrB), -(element_size - 1))), |
| 13791 | +/// (get_active_lane_mask 0, (div (sub ptrA, ptrB), element_size)) |
| 13792 | +SDValue tryWhileWRFromOR(SDValue Op, SelectionDAG &DAG) { |
| 13793 | + if (!DAG.getSubtarget<AArch64Subtarget>().hasSVE2()) |
| 13794 | + return SDValue(); |
| 13795 | + auto LaneMask = Op.getOperand(0); |
| 13796 | + auto Splat = Op.getOperand(1); |
| 13797 | + |
| 13798 | + if (LaneMask.getOpcode() != ISD::INTRINSIC_WO_CHAIN || |
| 13799 | + LaneMask.getConstantOperandVal(0) != Intrinsic::get_active_lane_mask || |
| 13800 | + Splat.getOpcode() != ISD::SPLAT_VECTOR) |
| 13801 | + return SDValue(); |
| 13802 | + |
| 13803 | + auto Cmp = Splat.getOperand(0); |
| 13804 | + if (Cmp.getOpcode() != ISD::SETCC) |
| 13805 | + return SDValue(); |
| 13806 | + |
| 13807 | + CondCodeSDNode *Cond = dyn_cast<CondCodeSDNode>(Cmp.getOperand(2)); |
| 13808 | + assert(Cond && "SETCC doesn't have a condition code"); |
| 13809 | + |
| 13810 | + auto ComparatorConst = dyn_cast<ConstantSDNode>(Cmp.getOperand(1)); |
| 13811 | + if (!ComparatorConst || ComparatorConst->getSExtValue() > 0 || |
| 13812 | + Cond->get() != ISD::CondCode::SETLT) |
| 13813 | + return SDValue(); |
| 13814 | + unsigned CompValue = std::abs(ComparatorConst->getSExtValue()); |
| 13815 | + unsigned EltSize = CompValue + 1; |
| 13816 | + if (!isPowerOf2_64(EltSize) || EltSize > 64) |
| 13817 | + return SDValue(); |
| 13818 | + |
| 13819 | + auto Diff = Cmp.getOperand(0); |
| 13820 | + if (Diff.getOpcode() != ISD::SUB || Diff.getValueType() != MVT::i64) |
| 13821 | + return SDValue(); |
| 13822 | + |
| 13823 | + auto LaneMaskConst = dyn_cast<ConstantSDNode>(LaneMask.getOperand(1)); |
| 13824 | + if (!LaneMaskConst || LaneMaskConst->getZExtValue() != 0 || |
| 13825 | + (EltSize != 1 && LaneMask.getOperand(2).getOpcode() != ISD::SRA)) |
| 13826 | + return SDValue(); |
| 13827 | + |
| 13828 | + // An alias mask for i8 elements omits the division because it would just divide by 1 |
| 13829 | + if (EltSize > 1) { |
| 13830 | + auto DiffDiv = LaneMask.getOperand(2); |
| 13831 | + auto DiffDivConst = dyn_cast<ConstantSDNode>(DiffDiv.getOperand(1)); |
| 13832 | + if (!DiffDivConst || DiffDivConst->getZExtValue() != std::log2(EltSize)) |
| 13833 | + return SDValue(); |
| 13834 | + } else if (LaneMask.getOperand(2) != Diff) |
| 13835 | + return SDValue(); |
| 13836 | + |
| 13837 | + auto StorePtr = Diff.getOperand(0); |
| 13838 | + auto ReadPtr = Diff.getOperand(1); |
| 13839 | + |
| 13840 | + unsigned IntrinsicID = 0; |
| 13841 | + switch (EltSize) { |
| 13842 | + case 1: |
| 13843 | + IntrinsicID = Intrinsic::aarch64_sve_whilewr_b; |
| 13844 | + break; |
| 13845 | + case 2: |
| 13846 | + IntrinsicID = Intrinsic::aarch64_sve_whilewr_h; |
| 13847 | + break; |
| 13848 | + case 4: |
| 13849 | + IntrinsicID = Intrinsic::aarch64_sve_whilewr_s; |
| 13850 | + break; |
| 13851 | + case 8: |
| 13852 | + IntrinsicID = Intrinsic::aarch64_sve_whilewr_d; |
| 13853 | + break; |
| 13854 | + default: |
| 13855 | + return SDValue(); |
| 13856 | + } |
| 13857 | + SDLoc DL(Op); |
| 13858 | + SDValue ID = DAG.getConstant(IntrinsicID, DL, MVT::i32); |
| 13859 | + auto N = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, Op.getValueType(), ID, |
| 13860 | + StorePtr, ReadPtr); |
| 13861 | + return N; |
| 13862 | +} |
| 13863 | + |
13785 | 13864 | SDValue AArch64TargetLowering::LowerVectorOR(SDValue Op,
|
13786 | 13865 | SelectionDAG &DAG) const {
|
| 13866 | + |
| 13867 | + if (SDValue SV = tryWhileWRFromOR(Op, DAG)) |
| 13868 | + return SV; |
13787 | 13869 | if (useSVEForFixedLengthVectorVT(Op.getValueType(),
|
13788 | 13870 | !Subtarget->isNeonAvailable()))
|
13789 | 13871 | return LowerToScalableOp(Op, DAG);
|
|
0 commit comments