
[AArch64] Generalize integer FPR lane stores for all types #134117

Merged (12 commits) on Apr 17, 2025
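Since no description is rendered here, below is a minimal sketch (mine, not taken from the PR) of the source-level pattern this change is about: storing a single integer lane of a NEON vector. The function names are invented, and the instruction mnemonics in the comments are illustrative assumptions rather than claims about the PR's test output.

```cpp
// Hedged sketch: integer lane stores. The previous combine in
// performSTORECombine bailed out for 8-bit elements; this patch generalizes
// the fold so an integer lane can be stored straight from the FP/SIMD
// register file (e.g. "str s0, [x0]" or "st1 {v0.h}[1], [x0]") instead of
// first being moved to a w/x general-purpose register.
#include <arm_neon.h>
#include <cstdint>

void store_lane0_s32(int32_t *dst, int32x4_t v) {
  *dst = vgetq_lane_s32(v, 0); // lane 0: store the s-subregister directly
}

void store_lane1_s16(int16_t *dst, int16x8_t v) {
  *dst = vgetq_lane_s16(v, 1); // non-zero lane: st1 lane store, or dup + str
}
```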
2 changes: 2 additions & 0 deletions llvm/include/llvm/CodeGen/ValueTypes.td
@@ -338,6 +338,8 @@ def amdgpuBufferFatPointer : ValueType<160, 234>;
// FIXME: Remove this and the getPointerType() override if MVT::i82 is added.
def amdgpuBufferStridedPointer : ValueType<192, 235>;

def aarch64mfp8 : ValueType<8, 236>; // 8-bit value in FPR (AArch64)

let isNormalValueType = false in {
def token : ValueType<0, 504>; // TokenTy
def MetadataVT : ValueType<0, 505> { // Metadata
4 changes: 4 additions & 0 deletions llvm/lib/CodeGen/ValueTypes.cpp
@@ -198,6 +198,8 @@ std::string EVT::getEVTString() const {
return "amdgpuBufferFatPointer";
case MVT::amdgpuBufferStridedPointer:
return "amdgpuBufferStridedPointer";
case MVT::aarch64mfp8:
return "aarch64mfp8";
}
}

@@ -221,6 +223,8 @@ Type *EVT::getTypeForEVT(LLVMContext &Context) const {
case MVT::x86mmx: return llvm::FixedVectorType::get(llvm::IntegerType::get(Context, 64), 1);
case MVT::aarch64svcount:
return TargetExtType::get(Context, "aarch64.svcount");
case MVT::aarch64mfp8:
return FixedVectorType::get(IntegerType::get(Context, 8), 1);
case MVT::x86amx: return Type::getX86_AMXTy(Context);
case MVT::i64x8: return IntegerType::get(Context, 512);
case MVT::amdgpuBufferFatPointer: return IntegerType::get(Context, 160);
79 changes: 54 additions & 25 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -401,6 +401,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
}

if (Subtarget->hasFPARMv8()) {
addRegisterClass(MVT::aarch64mfp8, &AArch64::FPR8RegClass);
addRegisterClass(MVT::f16, &AArch64::FPR16RegClass);
addRegisterClass(MVT::bf16, &AArch64::FPR16RegClass);
addRegisterClass(MVT::f32, &AArch64::FPR32RegClass);
@@ -23932,6 +23933,8 @@ static SDValue combineI8TruncStore(StoreSDNode *ST, SelectionDAG &DAG,
static unsigned getFPSubregForVT(EVT VT) {
assert(VT.isSimple() && "Expected simple VT");
switch (VT.getSimpleVT().SimpleTy) {
case MVT::aarch64mfp8:
return AArch64::bsub;
case MVT::f16:
return AArch64::hsub;
case MVT::f32:
@@ -24021,39 +24024,65 @@ static SDValue performSTORECombine(SDNode *N,
SDValue ExtIdx = Value.getOperand(1);
EVT VectorVT = Vector.getValueType();
EVT ElemVT = VectorVT.getVectorElementType();
if (!ValueVT.isInteger() || ElemVT == MVT::i8 || MemVT == MVT::i8)

if (!ValueVT.isInteger())
return SDValue();

// Propagate zero constants (applying this fold may miss optimizations).
if (ISD::isConstantSplatVectorAllZeros(Vector.getNode())) {
SDValue ZeroElt = DAG.getConstant(0, DL, ValueVT);
DAG.ReplaceAllUsesWith(Value, ZeroElt);
return SDValue();
}

if (ValueVT != MemVT && !ST->isTruncatingStore())
return SDValue();

// Heuristic: If there are other users of integer scalars extracted from
// this vector that won't fold into the store -- abandon folding. Applying
// this fold may extend the vector lifetime and disrupt paired stores.
for (const auto &Use : Vector->uses()) {
if (Use.getResNo() != Vector.getResNo())
continue;
const SDNode *User = Use.getUser();
if (User->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
(!User->hasOneUse() ||
(*User->user_begin())->getOpcode() != ISD::STORE))
return SDValue();
}
// This could generate an additional extract if the index is non-zero and
// the extracted value has multiple uses.
auto *ExtCst = dyn_cast<ConstantSDNode>(ExtIdx);
if ((!ExtCst || !ExtCst->isZero()) && !Value.hasOneUse())
return SDValue();

EVT FPElemVT = EVT::getFloatingPointVT(ElemVT.getSizeInBits());
EVT FPVectorVT = VectorVT.changeVectorElementType(FPElemVT);
SDValue Cast = DAG.getNode(ISD::BITCAST, DL, FPVectorVT, Vector);
SDValue Ext =
DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, FPElemVT, Cast, ExtIdx);
// These can lower to st1, which is preferable if we're unlikely to fold the
// addressing into the store.
if (Subtarget->isNeonAvailable() && ElemVT == MemVT &&
(VectorVT.is64BitVector() || VectorVT.is128BitVector()) && ExtCst &&
!ExtCst->isZero() && ST->getBasePtr().getOpcode() != ISD::ADD)
return SDValue();

EVT FPMemVT = EVT::getFloatingPointVT(MemVT.getSizeInBits());
if (ST->isTruncatingStore() && FPMemVT != FPElemVT) {
SDValue Trunc = DAG.getTargetExtractSubreg(getFPSubregForVT(FPMemVT), DL,
FPMemVT, Ext);
return DAG.getStore(ST->getChain(), DL, Trunc, ST->getBasePtr(),
ST->getMemOperand());
if (MemVT == MVT::i64 || MemVT == MVT::i32) {
// Heuristic: If there are other users of w/x integer scalars extracted
// from this vector that won't fold into the store -- abandon folding.
// Applying this fold may disrupt paired stores.
for (const auto &Use : Vector->uses()) {
if (Use.getResNo() != Vector.getResNo())
continue;
const SDNode *User = Use.getUser();
if (User->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
(!User->hasOneUse() ||
(*User->user_begin())->getOpcode() != ISD::STORE))
return SDValue();
}
}

return DAG.getStore(ST->getChain(), DL, Ext, ST->getBasePtr(),
SDValue ExtVector = Vector;
if (!ExtCst || !ExtCst->isZero()) {
// Handle extracting from lanes != 0.
SDValue Ext = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL,
Value.getValueType(), Vector, ExtIdx);
SDValue Zero = DAG.getVectorIdxConstant(0, DL);
ExtVector = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VectorVT,
DAG.getUNDEF(VectorVT), Ext, Zero);
}

EVT FPMemVT = MemVT == MVT::i8
? MVT::aarch64mfp8
: EVT::getFloatingPointVT(MemVT.getSizeInBits());
SDValue FPSubreg = DAG.getTargetExtractSubreg(getFPSubregForVT(FPMemVT), DL,
FPMemVT, ExtVector);

return DAG.getStore(ST->getChain(), DL, FPSubreg, ST->getBasePtr(),
ST->getMemOperand());
}

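Two hedged examples of the new bail-out heuristics in performSTORECombine above; these are my own illustrations of the intent of the checks, not reproductions of the PR's tests, and the described codegen is an assumption.

```cpp
#include <arm_neon.h>
#include <cstdint>

// Non-zero lane with a second, non-store use: the fold is skipped because it
// would need an extra extract (the "!Value.hasOneUse()" bail-out).
int32_t store_and_use(int32_t *dst, int32x4_t v) {
  int32_t x = vgetq_lane_s32(v, 2);
  *dst = x;
  return x + 1;
}

// Base-plus-offset addressing: the address is an ISD::ADD, so the offset can
// fold into an FPR store and the combine proceeds instead of deferring to a
// plain st1 lane store.
void store_lane_offset(int32_t *base, int32x4_t v) {
  base[7] = vgetq_lane_s32(v, 1);
}
```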
44 changes: 31 additions & 13 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -3587,7 +3587,7 @@ defm LDRW : LoadUI<0b10, 0, 0b01, GPR32z, uimm12s4, "ldr",
(load (am_indexed32 GPR64sp:$Rn, uimm12s4:$offset)))]>;
let Predicates = [HasFPARMv8] in {
defm LDRB : LoadUI<0b00, 1, 0b01, FPR8Op, uimm12s1, "ldr",
[(set FPR8Op:$Rt,
[(set (i8 FPR8Op:$Rt),
(load (am_indexed8 GPR64sp:$Rn, uimm12s1:$offset)))]>;
defm LDRH : LoadUI<0b01, 1, 0b01, FPR16Op, uimm12s2, "ldr",
[(set (f16 FPR16Op:$Rt),
@@ -3775,7 +3775,7 @@ defm LDURW : LoadUnscaled<0b10, 0, 0b01, GPR32z, "ldur",
(load (am_unscaled32 GPR64sp:$Rn, simm9:$offset)))]>;
let Predicates = [HasFPARMv8] in {
defm LDURB : LoadUnscaled<0b00, 1, 0b01, FPR8Op, "ldur",
[(set FPR8Op:$Rt,
[(set (i8 FPR8Op:$Rt),
(load (am_unscaled8 GPR64sp:$Rn, simm9:$offset)))]>;
defm LDURH : LoadUnscaled<0b01, 1, 0b01, FPR16Op, "ldur",
[(set (f16 FPR16Op:$Rt),
@@ -4345,7 +4345,7 @@ defm STRW : StoreUIz<0b10, 0, 0b00, GPR32z, uimm12s4, "str",
(am_indexed32 GPR64sp:$Rn, uimm12s4:$offset))]>;
let Predicates = [HasFPARMv8] in {
defm STRB : StoreUI<0b00, 1, 0b00, FPR8Op, uimm12s1, "str",
[(store FPR8Op:$Rt,
[(store (i8 FPR8Op:$Rt),
(am_indexed8 GPR64sp:$Rn, uimm12s1:$offset))]>;
defm STRH : StoreUI<0b01, 1, 0b00, FPR16Op, uimm12s2, "str",
[(store (f16 FPR16Op:$Rt),
@@ -4481,7 +4481,7 @@ defm STURW : StoreUnscaled<0b10, 0, 0b00, GPR32z, "stur",
(am_unscaled32 GPR64sp:$Rn, simm9:$offset))]>;
let Predicates = [HasFPARMv8] in {
defm STURB : StoreUnscaled<0b00, 1, 0b00, FPR8Op, "stur",
[(store FPR8Op:$Rt,
[(store (i8 FPR8Op:$Rt),
(am_unscaled8 GPR64sp:$Rn, simm9:$offset))]>;
defm STURH : StoreUnscaled<0b01, 1, 0b00, FPR16Op, "stur",
[(store (f16 FPR16Op:$Rt),
@@ -4601,6 +4601,12 @@ def : Pat<(truncstorei16 GPR64:$Rt, (am_unscaled16 GPR64sp:$Rn, simm9:$offset)),
def : Pat<(truncstorei8 GPR64:$Rt, (am_unscaled8 GPR64sp:$Rn, simm9:$offset)),
(STURBBi (EXTRACT_SUBREG GPR64:$Rt, sub_32), GPR64sp:$Rn, simm9:$offset)>;

// aarch64mfp8 (bsub) stores
def : Pat<(store aarch64mfp8:$Rt, (am_unscaled8 GPR64sp:$Rn, simm9:$offset)),
(STURBi FPR8:$Rt, GPR64sp:$Rn, simm9:$offset)>;
def : Pat<(store aarch64mfp8:$Rt, (am_indexed8 GPR64sp:$Rn, uimm12s1:$offset)),
(STRBui FPR8:$Rt, GPR64sp:$Rn, uimm12s1:$offset)>;

// Match stores from lane 0 to the appropriate subreg's store.
multiclass VecStoreULane0Pat<SDPatternOperator StoreOp,
ValueType VTy, ValueType STy,
@@ -7242,8 +7248,15 @@ def : Pat<(v2i64 (int_aarch64_neon_vcopy_lane

// Move elements between vectors
multiclass Neon_INS_elt_pattern<ValueType VT128, ValueType VT64, ValueType VTSVE,
ValueType VTScal, Operand SVEIdxTy, Instruction INS> {
ValueType VTScal, Operand SVEIdxTy, Instruction INS, Instruction DUP, SubRegIndex DUPSub> {
// Extracting from the lowest 128-bits of an SVE vector
def : Pat<(VT128 (vector_insert undef,
(VTScal (vector_extract VTSVE:$Rm, (i64 SVEIdxTy:$Immn))),
(i64 0))),
(INSERT_SUBREG (VT128 (IMPLICIT_DEF)),
(DUP (VT128 (EXTRACT_SUBREG VTSVE:$Rm, zsub)), SVEIdxTy:$Immn),
DUPSub)>;

def : Pat<(VT128 (vector_insert VT128:$Rn,
(VTScal (vector_extract VTSVE:$Rm, (i64 SVEIdxTy:$Immn))),
(i64 imm:$Immd))),
@@ -7262,6 +7275,11 @@ multiclass Neon_INS_elt_pattern<ValueType VT128, ValueType VT64, ValueType VTSVE
(i64 imm:$Immd))),
(INS V128:$src, imm:$Immd, V128:$Rn, imm:$Immn)>;

def : Pat<(VT128 (vector_insert undef,
(VTScal (vector_extract (VT128 V128:$Rn), (i64 imm:$Immn))),
(i64 0))),
(INSERT_SUBREG (VT128 (IMPLICIT_DEF)), (DUP V128:$Rn, imm:$Immn), DUPSub)>;

def : Pat<(VT128 (vector_insert V128:$src,
(VTScal (vector_extract (VT64 V64:$Rn), (i64 imm:$Immn))),
(i64 imm:$Immd))),
@@ -7284,15 +7302,15 @@ multiclass Neon_INS_elt_pattern<ValueType VT128, ValueType VT64, ValueType VTSVE
dsub)>;
}

defm : Neon_INS_elt_pattern<v8f16, v4f16, nxv8f16, f16, VectorIndexH, INSvi16lane>;
defm : Neon_INS_elt_pattern<v8bf16, v4bf16, nxv8bf16, bf16, VectorIndexH, INSvi16lane>;
defm : Neon_INS_elt_pattern<v4f32, v2f32, nxv4f32, f32, VectorIndexS, INSvi32lane>;
defm : Neon_INS_elt_pattern<v2f64, v1f64, nxv2f64, f64, VectorIndexD, INSvi64lane>;
defm : Neon_INS_elt_pattern<v8f16, v4f16, nxv8f16, f16, VectorIndexH, INSvi16lane, DUPi16, hsub>;
defm : Neon_INS_elt_pattern<v8bf16, v4bf16, nxv8bf16, bf16, VectorIndexH, INSvi16lane, DUPi16, hsub>;
defm : Neon_INS_elt_pattern<v4f32, v2f32, nxv4f32, f32, VectorIndexS, INSvi32lane, DUPi32, ssub>;
defm : Neon_INS_elt_pattern<v2f64, v1f64, nxv2f64, f64, VectorIndexD, INSvi64lane, DUPi64, dsub>;

defm : Neon_INS_elt_pattern<v16i8, v8i8, nxv16i8, i32, VectorIndexB, INSvi8lane>;
defm : Neon_INS_elt_pattern<v8i16, v4i16, nxv8i16, i32, VectorIndexH, INSvi16lane>;
defm : Neon_INS_elt_pattern<v4i32, v2i32, nxv4i32, i32, VectorIndexS, INSvi32lane>;
defm : Neon_INS_elt_pattern<v2i64, v1i64, nxv2i64, i64, VectorIndexD, INSvi64lane>;
defm : Neon_INS_elt_pattern<v16i8, v8i8, nxv16i8, i32, VectorIndexB, INSvi8lane, DUPi8, bsub>;
defm : Neon_INS_elt_pattern<v8i16, v4i16, nxv8i16, i32, VectorIndexH, INSvi16lane, DUPi16, hsub>;
defm : Neon_INS_elt_pattern<v4i32, v2i32, nxv4i32, i32, VectorIndexS, INSvi32lane, DUPi32, ssub>;
defm : Neon_INS_elt_pattern<v2i64, v1i64, nxv2i64, i64, VectorIndexD, INSvi64lane, DUPi64, dsub>;

// Insert from bitcast
// vector_insert(bitcast(f32 src), n, lane) -> INSvi32lane(src, lane, INSERT_SUBREG(-, n), 0)
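For completeness, a hedged sketch of the 8-bit case that the aarch64mfp8 (bsub) store patterns added above enable; whether the byte is stored via STRBui/STURBi from the b-subregister or as an st1 byte-lane store depends on the lane index and addressing, and the exact choice here is an assumption.

```cpp
#include <arm_neon.h>
#include <cstdint>

// An i8 lane store: previously excluded from the combine (ElemVT == MVT::i8),
// now representable via the aarch64mfp8 value type and the bsub store patterns.
void store_byte_lane(uint8_t *dst, uint8x16_t v) {
  dst[3] = vgetq_lane_u8(v, 5);
}
```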
2 changes: 1 addition & 1 deletion llvm/lib/Target/AArch64/AArch64RegisterInfo.td
@@ -497,7 +497,7 @@ def Q30 : AArch64Reg<30, "q30", [D30, D30_HI], ["v30", ""]>, DwarfRegAlias<B30
def Q31 : AArch64Reg<31, "q31", [D31, D31_HI], ["v31", ""]>, DwarfRegAlias<B31>;
}

def FPR8 : RegisterClass<"AArch64", [i8], 8, (sequence "B%u", 0, 31)> {
def FPR8 : RegisterClass<"AArch64", [i8, aarch64mfp8], 8, (sequence "B%u", 0, 31)> {
let Size = 8;
let DecoderMethod = "DecodeSimpleRegisterClass<AArch64::FPR8RegClassID, 0, 32>";
}
16 changes: 16 additions & 0 deletions llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -3475,6 +3475,22 @@ let Predicates = [HasSVE_or_SME] in {
(EXTRACT_SUBREG ZPR:$Zs, dsub)>;
}

multiclass sve_insert_extract_elt<ValueType VT, ValueType VTScalar, Instruction DUP, Operand IdxTy> {
// NOP pattern (needed to avoid pointless DUPs being added by the second pattern).
def : Pat<(VT (vector_insert undef,
(VTScalar (vector_extract VT:$vec, (i64 0))), (i64 0))),
(VT $vec)>;

def : Pat<(VT (vector_insert undef,
(VTScalar (vector_extract VT:$vec, (i64 IdxTy:$Idx))), (i64 0))),
(DUP ZPR:$vec, IdxTy:$Idx)>;
}

defm : sve_insert_extract_elt<nxv16i8, i32, DUP_ZZI_B, sve_elm_idx_extdup_b>;
defm : sve_insert_extract_elt<nxv8i16, i32, DUP_ZZI_H, sve_elm_idx_extdup_h>;
defm : sve_insert_extract_elt<nxv4i32, i32, DUP_ZZI_S, sve_elm_idx_extdup_s>;
defm : sve_insert_extract_elt<nxv2i64, i64, DUP_ZZI_D, sve_elm_idx_extdup_d>;

multiclass sve_predicated_add<SDNode extend, int value> {
def : Pat<(nxv16i8 (add ZPR:$op, (extend nxv16i1:$pred))),
(ADD_ZPmZ_B PPR:$pred, ZPR:$op, (DUP_ZI_B value, 0))>;