
[AArch64][SVE] Teach compiler to use information that there are no ac… #68698


Closed
126 changes: 88 additions & 38 deletions llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -1169,6 +1169,8 @@ instCombineSVEVectorFuseMulAddSub(InstCombiner &IC, IntrinsicInst &II,
AddendOp = II.getOperand(2);
Mul = II.getOperand(1);
}
if (match(P, m_ZeroInt()))
Contributor:
Are we sure that this is going to have the LLVM intrinsic as [llvm-ir-intrinsic]_u?
Only for the _u LLVM IR intrinsics can the compiler be sure that it is allowed to return undefined values.
I agree with Paul, and maybe here is not the best place to do this check.
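For context, a minimal sketch of the distinction this comment relies on, using the ACLE builtins from arm_sve.h (illustrative only; the function name is hypothetical, and the _x-to-_u lowering is the usual clang behaviour rather than something this patch guarantees):

```cpp
#include <arm_sve.h>

// With no active lanes, the merging (_m) form is defined to return its
// first data operand unchanged, so it must never be folded to undef.
// The "don't care" (_x) form -- normally lowered to the _u LLVM IR
// intrinsic -- leaves inactive lanes unspecified, so an all-false
// predicate legitimately makes its whole result undef.
svint32_t merging_vs_dont_care(svint32_t a, svint32_t b) {
  svbool_t none = svpfalse_b();                // all-false predicate
  svint32_t defined = svadd_s32_m(none, a, b); // == a on every lane
  svint32_t loose = svadd_s32_x(none, a, b);   // every lane unspecified
  (void)loose;
  return defined; // always a
}
```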

return IC.replaceInstUsesWith(II, UndefValue::get(II.getType()));
Comment on lines +1172 to +1173
Collaborator:
Is this safe? I can see instCombineSVEVectorFuseMulAddSub is called by instCombineSVEVectorFAdd, whose result for an all-inactive predicate is defined.

As with calling instCombineSVEAllActive, you might need to move this functionality into the relevant places that call this function.
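To make the concern concrete, a small hypothetical example (ACLE builtins from arm_sve.h; not code from this patch). This shape reaches instCombineSVEVectorFuseMulAddSub via instCombineSVEVectorFAdd, yet its result is fully defined:

```cpp
#include <arm_sve.h>

// A merging fadd is defined to return its first operand when the
// predicate has no active lanes, even though the feeding multiply is
// predicated out as well. Folding the fused mul-add to undef inside
// the shared helper would therefore miscompile this function.
svfloat32_t must_return_a(svfloat32_t a, svfloat32_t b, svfloat32_t c) {
  svbool_t none = svpfalse_b();
  svfloat32_t prod = svmul_f32_m(none, b, c); // inactive lanes keep b
  return svadd_f32_m(none, a, prod);          // defined result: a
}
```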


if (!match(Mul, m_Intrinsic<MulOpc>(m_Specific(P), m_Value(MulOp0),
m_Value(MulOp1))))
@@ -1301,9 +1303,26 @@ static std::optional<Instruction *> instCombineSVEAllActive(IntrinsicInst &II,
return &II;
}

// Optimize operations that take an all-false predicate, or send them for
// all-active canonicalization.
static std::optional<Instruction *>
instCombineSVEAllOrNoActive(InstCombiner &IC, IntrinsicInst &II,
Intrinsic::ID IID) {
if (match(II.getOperand(0), m_ZeroInt())) {
if (II.getIntrinsicID() != IID)
Contributor:
nit: maybe add a comment:
// llvm_ir, pred(0), op1, op2 - Spec says to return op1 when all lanes are inactive for sv[func]_m or sv[func]_z

return IC.replaceInstUsesWith(II, II.getOperand(1));
else
return IC.replaceInstUsesWith(II, UndefValue::get(II.getType()));
Contributor:
nit: maybe add a comment:
// llvm_ir_u, pred(0), op1, op2 - Spec says to return undef when all lanes are inactive for sv[func]_x

}
if (II.getIntrinsicID() != IID)
return instCombineSVEAllActive(II, IID);
return std::nullopt;
}
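For reference, a sketch of the three outcomes this helper encodes, phrased against the ACLE builtins (illustrative assumptions: the predicates are visible to the combine as pfalse/ptrue, and the _x builtins lower to the _u intrinsics):

```cpp
#include <arm_sve.h>

// Conceptual folds enabled by instCombineSVEAllOrNoActive:
svint32_t folds(svint32_t a, svint32_t b) {
  // Merging form, no active lanes: result is defined to be op1.
  svint32_t keep = svsub_s32_m(svpfalse_b(), a, b); // folds to a
  // "Don't care" (_u) form, no active lanes: result may become undef.
  svint32_t dead = svsub_s32_x(svpfalse_b(), a, b); // folds to undef
  (void)dead;
  // All-active merging form: canonicalised to the _u intrinsic.
  return svsub_s32_m(svptrue_b32(), keep, b);       // -> aarch64.sve.sub.u
}
```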

static std::optional<Instruction *> instCombineSVEVectorAdd(InstCombiner &IC,
IntrinsicInst &II) {
if (auto II_U = instCombineSVEAllActive(II, Intrinsic::aarch64_sve_add_u))
if (auto II_U =
instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_add_u))
return II_U;
if (auto MLA = instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul,
Intrinsic::aarch64_sve_mla>(
@@ -1318,7 +1337,8 @@ static std::optional<Instruction *> instCombineSVEVectorAdd(InstCombiner &IC,

static std::optional<Instruction *>
instCombineSVEVectorFAdd(InstCombiner &IC, IntrinsicInst &II) {
if (auto II_U = instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fadd_u))
if (auto II_U =
instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fadd_u))
return II_U;
if (auto FMLA =
instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
@@ -1360,7 +1380,8 @@ instCombineSVEVectorFAddU(InstCombiner &IC, IntrinsicInst &II) {

static std::optional<Instruction *>
instCombineSVEVectorFSub(InstCombiner &IC, IntrinsicInst &II) {
if (auto II_U = instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fsub_u))
if (auto II_U =
instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fsub_u))
return II_U;
if (auto FMLS =
instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
@@ -1402,7 +1423,8 @@ instCombineSVEVectorFSubU(InstCombiner &IC, IntrinsicInst &II) {

static std::optional<Instruction *> instCombineSVEVectorSub(InstCombiner &IC,
IntrinsicInst &II) {
if (auto II_U = instCombineSVEAllActive(II, Intrinsic::aarch64_sve_sub_u))
if (auto II_U =
instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_sub_u))
return II_U;
if (auto MLS = instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul,
Intrinsic::aarch64_sve_mls>(
@@ -1418,10 +1440,8 @@ static std::optional<Instruction *> instCombineSVEVectorMul(InstCombiner &IC,
auto *OpMultiplicand = II.getOperand(1);
auto *OpMultiplier = II.getOperand(2);

// Canonicalise a non _u intrinsic only.
if (II.getIntrinsicID() != IID)
if (auto II_U = instCombineSVEAllActive(II, IID))
return II_U;
if (auto II_U = instCombineSVEAllOrNoActive(IC, II, IID))
return II_U;

// Return true if a given instruction is a unit splat value, false otherwise.
auto IsUnitSplat = [](auto *I) {
@@ -1786,34 +1806,45 @@ AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
case Intrinsic::aarch64_sve_ptest_last:
return instCombineSVEPTest(IC, II);
case Intrinsic::aarch64_sve_fabd:
return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fabd_u);
case Intrinsic::aarch64_sve_fabd_u:
return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fabd_u);
case Intrinsic::aarch64_sve_fadd:
return instCombineSVEVectorFAdd(IC, II);
case Intrinsic::aarch64_sve_fadd_u:
return instCombineSVEVectorFAddU(IC, II);
case Intrinsic::aarch64_sve_fdiv:
return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fdiv_u);
case Intrinsic::aarch64_sve_fdiv_u:
return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fdiv_u);
case Intrinsic::aarch64_sve_fmax:
return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fmax_u);
case Intrinsic::aarch64_sve_fmax_u:
return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fmax_u);
case Intrinsic::aarch64_sve_fmaxnm:
return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fmaxnm_u);
case Intrinsic::aarch64_sve_fmaxnm_u:
return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fmaxnm_u);
case Intrinsic::aarch64_sve_fmin:
return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fmin_u);
case Intrinsic::aarch64_sve_fmin_u:
return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fmin_u);
case Intrinsic::aarch64_sve_fminnm:
return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fminnm_u);
case Intrinsic::aarch64_sve_fminnm_u:
return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fminnm_u);
case Intrinsic::aarch64_sve_fmla:
return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fmla_u);
case Intrinsic::aarch64_sve_fmla_u:
return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fmla_u);
case Intrinsic::aarch64_sve_fmls:
return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fmls_u);
case Intrinsic::aarch64_sve_fmls_u:
return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fmls_u);
case Intrinsic::aarch64_sve_fmul:
case Intrinsic::aarch64_sve_fmul_u:
return instCombineSVEVectorMul(IC, II, Intrinsic::aarch64_sve_fmul_u);
case Intrinsic::aarch64_sve_fmulx:
return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fmulx_u);
case Intrinsic::aarch64_sve_fmulx_u:
return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fmulx_u);
case Intrinsic::aarch64_sve_fnmla:
return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fnmla_u);
case Intrinsic::aarch64_sve_fnmla_u:
return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fnmla_u);
case Intrinsic::aarch64_sve_fnmls:
return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fnmls_u);
case Intrinsic::aarch64_sve_fnmls_u:
return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fnmls_u);
case Intrinsic::aarch64_sve_fsub:
return instCombineSVEVectorFSub(IC, II);
case Intrinsic::aarch64_sve_fsub_u:
@@ -1825,52 +1856,71 @@ AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
Intrinsic::aarch64_sve_mla_u>(
IC, II, true);
case Intrinsic::aarch64_sve_mla:
return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_mla_u);
case Intrinsic::aarch64_sve_mla_u:
return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_mla_u);
case Intrinsic::aarch64_sve_mls:
return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_mls_u);
case Intrinsic::aarch64_sve_mls_u:
return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_mls_u);
case Intrinsic::aarch64_sve_mul:
case Intrinsic::aarch64_sve_mul_u:
return instCombineSVEVectorMul(IC, II, Intrinsic::aarch64_sve_mul_u);
case Intrinsic::aarch64_sve_sabd:
return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_sabd_u);
case Intrinsic::aarch64_sve_sabd_u:
return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_sabd_u);
case Intrinsic::aarch64_sve_smax:
return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_smax_u);
case Intrinsic::aarch64_sve_smax_u:
return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_smax_u);
case Intrinsic::aarch64_sve_smin:
return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_smin_u);
case Intrinsic::aarch64_sve_smin_u:
return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_smin_u);
case Intrinsic::aarch64_sve_smulh:
return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_smulh_u);
case Intrinsic::aarch64_sve_smulh_u:
return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_smulh_u);
case Intrinsic::aarch64_sve_sub:
return instCombineSVEVectorSub(IC, II);
case Intrinsic::aarch64_sve_sub_u:
return instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul_u,
Intrinsic::aarch64_sve_mls_u>(
IC, II, true);
case Intrinsic::aarch64_sve_uabd:
return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_uabd_u);
case Intrinsic::aarch64_sve_uabd_u:
return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_uabd_u);
case Intrinsic::aarch64_sve_umax:
return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_umax_u);
case Intrinsic::aarch64_sve_umax_u:
return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_umax_u);
case Intrinsic::aarch64_sve_umin:
return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_umin_u);
case Intrinsic::aarch64_sve_umin_u:
return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_umin_u);
case Intrinsic::aarch64_sve_umulh:
return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_umulh_u);
case Intrinsic::aarch64_sve_umulh_u:
return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_umulh_u);
case Intrinsic::aarch64_sve_asr:
return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_asr_u);
case Intrinsic::aarch64_sve_asr_u:
return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_asr_u);
case Intrinsic::aarch64_sve_lsl:
return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_lsl_u);
case Intrinsic::aarch64_sve_lsl_u:
return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_lsl_u);
case Intrinsic::aarch64_sve_lsr:
return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_lsr_u);
case Intrinsic::aarch64_sve_lsr_u:
return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_lsr_u);
case Intrinsic::aarch64_sve_and:
return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_and_u);
case Intrinsic::aarch64_sve_and_u:
return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_and_u);
case Intrinsic::aarch64_sve_bic:
return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_bic_u);
case Intrinsic::aarch64_sve_bic_u:
return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_bic_u);
case Intrinsic::aarch64_sve_eor:
return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_eor_u);
case Intrinsic::aarch64_sve_eor_u:
return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_eor_u);
case Intrinsic::aarch64_sve_orr:
return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_orr_u);
case Intrinsic::aarch64_sve_orr_u:
return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_orr_u);
case Intrinsic::aarch64_sve_sqsub:
return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_sqsub_u);
case Intrinsic::aarch64_sve_sqsub_u:
return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_sqsub_u);
case Intrinsic::aarch64_sve_uqsub:
return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_uqsub_u);
case Intrinsic::aarch64_sve_uqsub_u:
return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_uqsub_u);
case Intrinsic::aarch64_sve_tbl:
return instCombineSVETBL(IC, II);
case Intrinsic::aarch64_sve_uunpkhi: