[AARCH64][Neon] switch to using bitcasts in arm_neon.h where appropriate #127043

Merged: 1 commit, Apr 1, 2025
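This patch switches the NEON TableGen definitions from the value-level "cast" operator to "bitcast" wherever the intent is to reinterpret a vector's bit pattern rather than convert its values. A minimal standalone sketch of that distinction, using real arm_neon.h intrinsics (illustrative only; this is not the emitter's output):

  #include <arm_neon.h>

  int32x2_t value_convert(float32x2_t v) {
    return vcvt_s32_f32(v);          // converts values: 1.0f -> 1
  }
  int32x2_t bit_reinterpret(float32x2_t v) {
    return vreinterpret_s32_f32(v);  // keeps the bits: 1.0f -> 0x3F800000
  }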
clang/include/clang/Basic/TargetBuiltins.h (4 additions, 0 deletions)
@@ -263,6 +263,10 @@ namespace clang {
     EltType ET = getEltType();
     return ET == Poly8 || ET == Poly16 || ET == Poly64;
   }
+  bool isFloatingPoint() const {
+    EltType ET = getEltType();
+    return ET == Float16 || ET == Float32 || ET == Float64 || ET == BFloat16;
+  }
   bool isUnsigned() const { return (Flags & UnsignedFlag) != 0; }
   bool isQuad() const { return (Flags & QuadFlag) != 0; }
   unsigned getEltSizeInBits() const {
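The new NeonTypeFlags::isFloatingPoint() predicate mirrors the existing isPoly() helper and lets code generation resolve a comparison predicate from the builtin's type flags instead of inspecting previously generated IR. A hedged sketch of the intended selection; the real call sites follow in ARM.cpp below (pickEqPredicate is a hypothetical name):

  #include "clang/Basic/TargetBuiltins.h"
  #include "llvm/IR/InstrTypes.h"

  // Assumed illustration: resolve the predicate family from the type flags.
  static llvm::CmpInst::Predicate pickEqPredicate(clang::NeonTypeFlags Type) {
    return Type.isFloatingPoint() ? llvm::CmpInst::FCMP_OEQ  // ordered fp ==
                                  : llvm::CmpInst::ICMP_EQ;  // integer ==
  }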
clang/include/clang/Basic/arm_neon.td (34 additions, 34 deletions)
@@ -31,8 +31,8 @@ def OP_MLAL : Op<(op "+", $p0, (call "vmull", $p1, $p2))>;
 def OP_MULLHi : Op<(call "vmull", (call "vget_high", $p0),
                                   (call "vget_high", $p1))>;
 def OP_MULLHi_P64 : Op<(call "vmull",
-                        (cast "poly64_t", (call "vget_high", $p0)),
-                        (cast "poly64_t", (call "vget_high", $p1)))>;
+                        (bitcast "poly64_t", (call "vget_high", $p0)),
+                        (bitcast "poly64_t", (call "vget_high", $p1)))>;
 def OP_MULLHi_N : Op<(call "vmull_n", (call "vget_high", $p0), $p1)>;
 def OP_MLALHi : Op<(call "vmlal", $p0, (call "vget_high", $p1),
                     (call "vget_high", $p2))>;
@@ -95,11 +95,11 @@ def OP_TRN2 : Op<(shuffle $p0, $p1, (interleave
 def OP_ZIP2 : Op<(shuffle $p0, $p1, (highhalf (interleave mask0, mask1)))>;
 def OP_UZP2 : Op<(shuffle $p0, $p1, (add (decimate (rotl mask0, 1), 2),
                                          (decimate (rotl mask1, 1), 2)))>;
-def OP_EQ : Op<(cast "R", (op "==", $p0, $p1))>;
-def OP_GE : Op<(cast "R", (op ">=", $p0, $p1))>;
-def OP_LE : Op<(cast "R", (op "<=", $p0, $p1))>;
-def OP_GT : Op<(cast "R", (op ">", $p0, $p1))>;
-def OP_LT : Op<(cast "R", (op "<", $p0, $p1))>;
+def OP_EQ : Op<(bitcast "R", (op "==", $p0, $p1))>;
+def OP_GE : Op<(bitcast "R", (op ">=", $p0, $p1))>;
+def OP_LE : Op<(bitcast "R", (op "<=", $p0, $p1))>;
+def OP_GT : Op<(bitcast "R", (op ">", $p0, $p1))>;
+def OP_LT : Op<(bitcast "R", (op "<", $p0, $p1))>;
 def OP_NEG : Op<(op "-", $p0)>;
 def OP_NOT : Op<(op "~", $p0)>;
 def OP_AND : Op<(op "&", $p0, $p1)>;
@@ -108,33 +108,33 @@ def OP_XOR : Op<(op "^", $p0, $p1)>;
 def OP_ANDN : Op<(op "&", $p0, (op "~", $p1))>;
 def OP_ORN : Op<(op "|", $p0, (op "~", $p1))>;
 def OP_CAST : LOp<[(save_temp $promote, $p0),
-                   (cast "R", $promote)]>;
+                   (bitcast "R", $promote)]>;
 def OP_HI : Op<(shuffle $p0, $p0, (highhalf mask0))>;
 def OP_LO : Op<(shuffle $p0, $p0, (lowhalf mask0))>;
 def OP_CONC : Op<(shuffle $p0, $p1, (add mask0, mask1))>;
 def OP_DUP : Op<(dup $p0)>;
 def OP_DUP_LN : Op<(call_mangled "splat_lane", $p0, $p1)>;
-def OP_SEL : Op<(cast "R", (op "|",
-                            (op "&", $p0, (cast $p0, $p1)),
-                            (op "&", (op "~", $p0), (cast $p0, $p2))))>;
+def OP_SEL : Op<(bitcast "R", (op "|",
+                               (op "&", $p0, (bitcast $p0, $p1)),
+                               (op "&", (op "~", $p0), (bitcast $p0, $p2))))>;
 def OP_REV16 : Op<(shuffle $p0, $p0, (rev 16, mask0))>;
 def OP_REV32 : Op<(shuffle $p0, $p0, (rev 32, mask0))>;
 def OP_REV64 : Op<(shuffle $p0, $p0, (rev 64, mask0))>;
 def OP_XTN : Op<(call "vcombine", $p0, (call "vmovn", $p1))>;
-def OP_SQXTUN : Op<(call "vcombine", (cast $p0, "U", $p0),
+def OP_SQXTUN : Op<(call "vcombine", (bitcast $p0, "U", $p0),
                     (call "vqmovun", $p1))>;
 def OP_QXTN : Op<(call "vcombine", $p0, (call "vqmovn", $p1))>;
 def OP_VCVT_NA_HI_F16 : Op<(call "vcombine", $p0, (call "vcvt_f16_f32", $p1))>;
 def OP_VCVT_NA_HI_F32 : Op<(call "vcombine", $p0, (call "vcvt_f32_f64", $p1))>;
 def OP_VCVT_EX_HI_F32 : Op<(call "vcvt_f32_f16", (call "vget_high", $p0))>;
 def OP_VCVT_EX_HI_F64 : Op<(call "vcvt_f64_f32", (call "vget_high", $p0))>;
 def OP_VCVTX_HI : Op<(call "vcombine", $p0, (call "vcvtx_f32", $p1))>;
-def OP_REINT : Op<(cast "R", $p0)>;
+def OP_REINT : Op<(bitcast "R", $p0)>;
 def OP_ADDHNHi : Op<(call "vcombine", $p0, (call "vaddhn", $p1, $p2))>;
 def OP_RADDHNHi : Op<(call "vcombine", $p0, (call "vraddhn", $p1, $p2))>;
 def OP_SUBHNHi : Op<(call "vcombine", $p0, (call "vsubhn", $p1, $p2))>;
 def OP_RSUBHNHi : Op<(call "vcombine", $p0, (call "vrsubhn", $p1, $p2))>;
-def OP_ABDL : Op<(cast "R", (call "vmovl", (cast $p0, "U",
+def OP_ABDL : Op<(bitcast "R", (call "vmovl", (bitcast $p0, "U",
                                             (call "vabd", $p0, $p1))))>;
 def OP_ABDLHi : Op<(call "vabdl", (call "vget_high", $p0),
                     (call "vget_high", $p1))>;
@@ -152,15 +152,15 @@ def OP_QDMLSLHi : Op<(call "vqdmlsl", $p0, (call "vget_high", $p1),
(call "vget_high", $p2))>;
def OP_QDMLSLHi_N : Op<(call "vqdmlsl_n", $p0, (call "vget_high", $p1), $p2)>;
def OP_DIV : Op<(op "/", $p0, $p1)>;
def OP_LONG_HI : Op<(cast "R", (call (name_replace "_high_", "_"),
def OP_LONG_HI : Op<(bitcast "R", (call (name_replace "_high_", "_"),
(call "vget_high", $p0), $p1))>;
def OP_NARROW_HI : Op<(cast "R", (call "vcombine",
(cast "R", "H", $p0),
(cast "R", "H",
def OP_NARROW_HI : Op<(bitcast "R", (call "vcombine",
(bitcast "R", "H", $p0),
(bitcast "R", "H",
(call (name_replace "_high_", "_"),
$p1, $p2))))>;
def OP_MOVL_HI : LOp<[(save_temp $a1, (call "vget_high", $p0)),
(cast "R",
(bitcast "R",
(call "vshll_n", $a1, (literal "int32_t", "0")))]>;
def OP_COPY_LN : Op<(call "vset_lane", (call "vget_lane", $p2, $p3), $p0, $p1)>;
def OP_SCALAR_MUL_LN : Op<(op "*", $p0, (call "vget_lane", $p1, $p2))>;
@@ -221,18 +221,18 @@ def OP_FMLSL_LN_Hi : Op<(call "vfmlsl_high", $p0, $p1,

 def OP_USDOT_LN
   : Op<(call "vusdot", $p0, $p1,
-        (cast "8", "S", (call_mangled "splat_lane", (bitcast "int32x2_t", $p2), $p3)))>;
+        (bitcast "8", "S", (call_mangled "splat_lane", (bitcast "int32x2_t", $p2), $p3)))>;
 def OP_USDOT_LNQ
   : Op<(call "vusdot", $p0, $p1,
-        (cast "8", "S", (call_mangled "splat_lane", (bitcast "int32x4_t", $p2), $p3)))>;
+        (bitcast "8", "S", (call_mangled "splat_lane", (bitcast "int32x4_t", $p2), $p3)))>;

 // sudot splats the second vector and then calls vusdot
 def OP_SUDOT_LN
   : Op<(call "vusdot", $p0,
-        (cast "8", "U", (call_mangled "splat_lane", (bitcast "int32x2_t", $p2), $p3)), $p1)>;
+        (bitcast "8", "U", (call_mangled "splat_lane", (bitcast "int32x2_t", $p2), $p3)), $p1)>;
 def OP_SUDOT_LNQ
   : Op<(call "vusdot", $p0,
-        (cast "8", "U", (call_mangled "splat_lane", (bitcast "int32x4_t", $p2), $p3)), $p1)>;
+        (bitcast "8", "U", (call_mangled "splat_lane", (bitcast "int32x4_t", $p2), $p3)), $p1)>;

 def OP_BFDOT_LN
   : Op<(call "vbfdot", $p0, $p1,
@@ -263,7 +263,7 @@ def OP_VCVT_BF16_F32_A32
   : Op<(call "__a32_vcvt_bf16", $p0)>;

 def OP_VCVT_BF16_F32_LO_A32
-  : Op<(call "vcombine", (cast "bfloat16x4_t", (literal "uint64_t", "0ULL")),
+  : Op<(call "vcombine", (bitcast "bfloat16x4_t", (literal "uint64_t", "0ULL")),
        (call "__a32_vcvt_bf16", $p0))>;
 def OP_VCVT_BF16_F32_HI_A32
   : Op<(call "vcombine", (call "__a32_vcvt_bf16", $p1),
@@ -924,12 +924,12 @@ def CFMLE : SOpInst<"vcle", "U..", "lUldQdQlQUl", OP_LE>;
 def CFMGT : SOpInst<"vcgt", "U..", "lUldQdQlQUl", OP_GT>;
 def CFMLT : SOpInst<"vclt", "U..", "lUldQdQlQUl", OP_LT>;

-def CMEQ : SInst<"vceqz", "U.",
+def CMEQ : SInst<"vceqz", "U(.!)",
                  "csilfUcUsUiUlPcPlQcQsQiQlQfQUcQUsQUiQUlQPcdQdQPl">;
-def CMGE : SInst<"vcgez", "U.", "csilfdQcQsQiQlQfQd">;
-def CMLE : SInst<"vclez", "U.", "csilfdQcQsQiQlQfQd">;
-def CMGT : SInst<"vcgtz", "U.", "csilfdQcQsQiQlQfQd">;
-def CMLT : SInst<"vcltz", "U.", "csilfdQcQsQiQlQfQd">;
+def CMGE : SInst<"vcgez", "U(.!)", "csilfdQcQsQiQlQfQd">;
+def CMLE : SInst<"vclez", "U(.!)", "csilfdQcQsQiQlQfQd">;
+def CMGT : SInst<"vcgtz", "U(.!)", "csilfdQcQsQiQlQfQd">;
+def CMLT : SInst<"vcltz", "U(.!)", "csilfdQcQsQiQlQfQd">;

 ////////////////////////////////////////////////////////////////////////////////
 // Max/Min Integer
@@ -1667,11 +1667,11 @@ let TargetGuard = "fullfp16,neon" in {
 // ARMv8.2-A FP16 one-operand vector intrinsics.

 // Comparison
-def CMEQH : SInst<"vceqz", "U.", "hQh">;
-def CMGEH : SInst<"vcgez", "U.", "hQh">;
-def CMGTH : SInst<"vcgtz", "U.", "hQh">;
-def CMLEH : SInst<"vclez", "U.", "hQh">;
-def CMLTH : SInst<"vcltz", "U.", "hQh">;
+def CMEQH : SInst<"vceqz", "U(.!)", "hQh">;
+def CMGEH : SInst<"vcgez", "U(.!)", "hQh">;
+def CMGTH : SInst<"vcgtz", "U(.!)", "hQh">;
+def CMLEH : SInst<"vclez", "U(.!)", "hQh">;
+def CMLTH : SInst<"vcltz", "U(.!)", "hQh">;

 // Vector conversion
 def VCVT_F16 : SInst<"vcvt_f16", "F(.!)", "sUsQsQUs">;
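For reference, the regenerated vceqz/vcgez/vclez/vcgtz/vcltz intrinsics keep their documented behavior: each lane is tested against zero, and true lanes become all-ones unsigned masks. A small usage sketch (standard Neon semantics, not specific to this patch):

  #include <arm_neon.h>

  void masks(void) {
    // Each lane is compared against zero; true lanes become all-ones.
    uint32x2_t m = vceqz_f32(vdup_n_f32(0.0f)); // both lanes 0xFFFFFFFF
    uint32x2_t n = vcltz_s32(vdup_n_s32(-1));   // both lanes 0xFFFFFFFF
    (void)m; (void)n;
  }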
clang/lib/CodeGen/CodeGenFunction.h (4 additions, 4 deletions)
@@ -4694,10 +4694,10 @@ class CodeGenFunction : public CodeGenTypeCache {
   llvm::Value *EmitTargetBuiltinExpr(unsigned BuiltinID, const CallExpr *E,
                                      ReturnValueSlot ReturnValue);

-  llvm::Value *EmitAArch64CompareBuiltinExpr(llvm::Value *Op, llvm::Type *Ty,
-                                             const llvm::CmpInst::Predicate Fp,
-                                             const llvm::CmpInst::Predicate Ip,
-                                             const llvm::Twine &Name = "");
+  llvm::Value *
+  EmitAArch64CompareBuiltinExpr(llvm::Value *Op, llvm::Type *Ty,
+                                const llvm::CmpInst::Predicate Pred,
+                                const llvm::Twine &Name = "");
   llvm::Value *EmitARMBuiltinExpr(unsigned BuiltinID, const CallExpr *E,
                                   ReturnValueSlot ReturnValue,
                                   llvm::Triple::ArchType Arch);
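Collapsing the two predicate parameters into a single resolved one is what allows removing the FIXME in the old implementation below: TableGen lowers the float and integer variants to the same overloaded builtin, so the callee previously had to guess the element type from earlier generated bitcasts. A sketch of that ambiguity at the source level (assumed example; function names are illustrative):

  #include <arm_neon.h>

  // Both calls reach EmitAArch64CompareBuiltinExpr through the same
  // overloaded __builtin_neon_vceqz_v, so only the type flags can tell
  // fcmp from icmp.
  uint32x2_t cmp_float(float32x2_t v) { return vceqz_f32(v); } // fcmp oeq
  uint32x2_t cmp_int(int32x2_t v)     { return vceqz_s32(v); } // icmp eq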
clang/lib/CodeGen/TargetBuiltins/ARM.cpp (66 additions, 36 deletions)
@@ -1750,8 +1750,9 @@ Value *CodeGenFunction::EmitCommonNeonBuiltinExpr(

   // Determine the type of this overloaded NEON intrinsic.
   NeonTypeFlags Type(NeonTypeConst->getZExtValue());
-  bool Usgn = Type.isUnsigned();
-  bool Quad = Type.isQuad();
+  const bool Usgn = Type.isUnsigned();
+  const bool Quad = Type.isQuad();
+  const bool Floating = Type.isFloatingPoint();
   const bool HasLegalHalfType = getTarget().hasLegalHalfType();
   const bool AllowBFloatArgsAndRet =
       getTargetHooks().getABIInfo().allowBFloatArgsAndRet();
@@ -1852,24 +1853,28 @@ Value *CodeGenFunction::EmitCommonNeonBuiltinExpr(
   }
   case NEON::BI__builtin_neon_vceqz_v:
   case NEON::BI__builtin_neon_vceqzq_v:
-    return EmitAArch64CompareBuiltinExpr(Ops[0], Ty, ICmpInst::FCMP_OEQ,
-                                         ICmpInst::ICMP_EQ, "vceqz");
+    return EmitAArch64CompareBuiltinExpr(
+        Ops[0], Ty, Floating ? ICmpInst::FCMP_OEQ : ICmpInst::ICMP_EQ, "vceqz");
   case NEON::BI__builtin_neon_vcgez_v:
   case NEON::BI__builtin_neon_vcgezq_v:
-    return EmitAArch64CompareBuiltinExpr(Ops[0], Ty, ICmpInst::FCMP_OGE,
-                                         ICmpInst::ICMP_SGE, "vcgez");
+    return EmitAArch64CompareBuiltinExpr(
+        Ops[0], Ty, Floating ? ICmpInst::FCMP_OGE : ICmpInst::ICMP_SGE,
+        "vcgez");
   case NEON::BI__builtin_neon_vclez_v:
   case NEON::BI__builtin_neon_vclezq_v:
-    return EmitAArch64CompareBuiltinExpr(Ops[0], Ty, ICmpInst::FCMP_OLE,
-                                         ICmpInst::ICMP_SLE, "vclez");
+    return EmitAArch64CompareBuiltinExpr(
+        Ops[0], Ty, Floating ? ICmpInst::FCMP_OLE : ICmpInst::ICMP_SLE,
+        "vclez");
   case NEON::BI__builtin_neon_vcgtz_v:
   case NEON::BI__builtin_neon_vcgtzq_v:
-    return EmitAArch64CompareBuiltinExpr(Ops[0], Ty, ICmpInst::FCMP_OGT,
-                                         ICmpInst::ICMP_SGT, "vcgtz");
+    return EmitAArch64CompareBuiltinExpr(
+        Ops[0], Ty, Floating ? ICmpInst::FCMP_OGT : ICmpInst::ICMP_SGT,
+        "vcgtz");
   case NEON::BI__builtin_neon_vcltz_v:
   case NEON::BI__builtin_neon_vcltzq_v:
-    return EmitAArch64CompareBuiltinExpr(Ops[0], Ty, ICmpInst::FCMP_OLT,
-                                         ICmpInst::ICMP_SLT, "vcltz");
+    return EmitAArch64CompareBuiltinExpr(
+        Ops[0], Ty, Floating ? ICmpInst::FCMP_OLT : ICmpInst::ICMP_SLT,
+        "vcltz");
   case NEON::BI__builtin_neon_vclz_v:
   case NEON::BI__builtin_neon_vclzq_v:
   // We generate target-independent intrinsic, which needs a second argument
@@ -2432,28 +2437,32 @@ Value *CodeGenFunction::EmitCommonNeonBuiltinExpr(
   return Builder.CreateBitCast(Result, ResultType, NameHint);
 }

-Value *CodeGenFunction::EmitAArch64CompareBuiltinExpr(
-    Value *Op, llvm::Type *Ty, const CmpInst::Predicate Fp,
-    const CmpInst::Predicate Ip, const Twine &Name) {
-  llvm::Type *OTy = Op->getType();
-
-  // FIXME: this is utterly horrific. We should not be looking at previous
-  // codegen context to find out what needs doing. Unfortunately TableGen
-  // currently gives us exactly the same calls for vceqz_f32 and vceqz_s32
-  // (etc).
-  if (BitCastInst *BI = dyn_cast<BitCastInst>(Op))
-    OTy = BI->getOperand(0)->getType();
-
-  Op = Builder.CreateBitCast(Op, OTy);
-  if (OTy->getScalarType()->isFloatingPointTy()) {
-    if (Fp == CmpInst::FCMP_OEQ)
-      Op = Builder.CreateFCmp(Fp, Op, Constant::getNullValue(OTy));
+Value *
+CodeGenFunction::EmitAArch64CompareBuiltinExpr(Value *Op, llvm::Type *Ty,
+                                               const CmpInst::Predicate Pred,
+                                               const Twine &Name) {
+
+  if (isa<FixedVectorType>(Ty)) {
+    // Vector types are cast to i8 vectors. Recover original type.
+    Op = Builder.CreateBitCast(Op, Ty);
+  }
+
+  if (CmpInst::isFPPredicate(Pred)) {
+    if (Pred == CmpInst::FCMP_OEQ)
+      Op = Builder.CreateFCmp(Pred, Op, Constant::getNullValue(Op->getType()));
     else
-      Op = Builder.CreateFCmpS(Fp, Op, Constant::getNullValue(OTy));
+      Op = Builder.CreateFCmpS(Pred, Op, Constant::getNullValue(Op->getType()));
   } else {
-    Op = Builder.CreateICmp(Ip, Op, Constant::getNullValue(OTy));
+    Op = Builder.CreateICmp(Pred, Op, Constant::getNullValue(Op->getType()));
   }
-  return Builder.CreateSExt(Op, Ty, Name);
+
+  llvm::Type *ResTy = Ty;
+  if (auto *VTy = dyn_cast<FixedVectorType>(Ty))
+    ResTy = FixedVectorType::get(
+        IntegerType::get(getLLVMContext(), VTy->getScalarSizeInBits()),
+        VTy->getNumElements());
+
+  return Builder.CreateSExt(Op, ResTy, Name);
 }

 static Value *packTBLDVectorList(CodeGenFunction &CGF, ArrayRef<Value *> Ops,
@@ -5955,45 +5964,66 @@ Value *CodeGenFunction::EmitAArch64BuiltinExpr(unsigned BuiltinID,
     return Builder.CreateFAdd(Op0, Op1, "vpaddd");
   }
   case NEON::BI__builtin_neon_vceqzd_s64:
+    Ops.push_back(EmitScalarExpr(E->getArg(0)));
+    return EmitAArch64CompareBuiltinExpr(
+        Ops[0], ConvertType(E->getCallReturnType(getContext())),
+        ICmpInst::ICMP_EQ, "vceqz");
   case NEON::BI__builtin_neon_vceqzd_f64:
   case NEON::BI__builtin_neon_vceqzs_f32:
   case NEON::BI__builtin_neon_vceqzh_f16:
     Ops.push_back(EmitScalarExpr(E->getArg(0)));
     return EmitAArch64CompareBuiltinExpr(
         Ops[0], ConvertType(E->getCallReturnType(getContext())),
-        ICmpInst::FCMP_OEQ, ICmpInst::ICMP_EQ, "vceqz");
+        ICmpInst::FCMP_OEQ, "vceqz");
   case NEON::BI__builtin_neon_vcgezd_s64:
+    Ops.push_back(EmitScalarExpr(E->getArg(0)));
+    return EmitAArch64CompareBuiltinExpr(
+        Ops[0], ConvertType(E->getCallReturnType(getContext())),
+        ICmpInst::ICMP_SGE, "vcgez");
   case NEON::BI__builtin_neon_vcgezd_f64:
   case NEON::BI__builtin_neon_vcgezs_f32:
   case NEON::BI__builtin_neon_vcgezh_f16:
     Ops.push_back(EmitScalarExpr(E->getArg(0)));
     return EmitAArch64CompareBuiltinExpr(
         Ops[0], ConvertType(E->getCallReturnType(getContext())),
-        ICmpInst::FCMP_OGE, ICmpInst::ICMP_SGE, "vcgez");
+        ICmpInst::FCMP_OGE, "vcgez");
   case NEON::BI__builtin_neon_vclezd_s64:
+    Ops.push_back(EmitScalarExpr(E->getArg(0)));
+    return EmitAArch64CompareBuiltinExpr(
+        Ops[0], ConvertType(E->getCallReturnType(getContext())),
+        ICmpInst::ICMP_SLE, "vclez");
   case NEON::BI__builtin_neon_vclezd_f64:
   case NEON::BI__builtin_neon_vclezs_f32:
   case NEON::BI__builtin_neon_vclezh_f16:
     Ops.push_back(EmitScalarExpr(E->getArg(0)));
     return EmitAArch64CompareBuiltinExpr(
         Ops[0], ConvertType(E->getCallReturnType(getContext())),
-        ICmpInst::FCMP_OLE, ICmpInst::ICMP_SLE, "vclez");
+        ICmpInst::FCMP_OLE, "vclez");
   case NEON::BI__builtin_neon_vcgtzd_s64:
+    Ops.push_back(EmitScalarExpr(E->getArg(0)));
+    return EmitAArch64CompareBuiltinExpr(
+        Ops[0], ConvertType(E->getCallReturnType(getContext())),
+        ICmpInst::ICMP_SGT, "vcgtz");
   case NEON::BI__builtin_neon_vcgtzd_f64:
   case NEON::BI__builtin_neon_vcgtzs_f32:
   case NEON::BI__builtin_neon_vcgtzh_f16:
     Ops.push_back(EmitScalarExpr(E->getArg(0)));
     return EmitAArch64CompareBuiltinExpr(
         Ops[0], ConvertType(E->getCallReturnType(getContext())),
-        ICmpInst::FCMP_OGT, ICmpInst::ICMP_SGT, "vcgtz");
+        ICmpInst::FCMP_OGT, "vcgtz");
   case NEON::BI__builtin_neon_vcltzd_s64:
+    Ops.push_back(EmitScalarExpr(E->getArg(0)));
+    return EmitAArch64CompareBuiltinExpr(
+        Ops[0], ConvertType(E->getCallReturnType(getContext())),
+        ICmpInst::ICMP_SLT, "vcltz");
+
   case NEON::BI__builtin_neon_vcltzd_f64:
   case NEON::BI__builtin_neon_vcltzs_f32:
   case NEON::BI__builtin_neon_vcltzh_f16:
     Ops.push_back(EmitScalarExpr(E->getArg(0)));
     return EmitAArch64CompareBuiltinExpr(
         Ops[0], ConvertType(E->getCallReturnType(getContext())),
-        ICmpInst::FCMP_OLT, ICmpInst::ICMP_SLT, "vcltz");
+        ICmpInst::FCMP_OLT, "vcltz");

   case NEON::BI__builtin_neon_vceqzd_u64: {
     Ops.push_back(EmitScalarExpr(E->getArg(0)));
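The scalar vceqzd/vcgezd/vclezd/vcgtzd/vcltzd cases above likewise now pass a single already-resolved predicate. For reference, the expected user-level behavior of a few of these builtins (standard ACLE semantics; a sketch, not part of the patch):

  #include <arm_neon.h>

  void scalar_compares(void) {
    uint64_t r1 = vceqzd_s64(0);    // 0xFFFFFFFFFFFFFFFF
    uint64_t r2 = vceqzd_f64(1.0);  // 0
    uint32_t r3 = vceqzs_f32(0.0f); // 0xFFFFFFFF
    (void)r1; (void)r2; (void)r3;
  }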