[AArch64][GlobalISel] Legalize Insert vector element #81453

Merged · merged 1 commit · Apr 8, 2024
7 changes: 5 additions & 2 deletions llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h
@@ -17,6 +17,7 @@
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/IR/DebugLoc.h"
#include "llvm/IR/Module.h"
@@ -1300,8 +1301,10 @@ class MachineIRBuilder {
MachineInstrBuilder buildExtractVectorElementConstant(const DstOp &Res,
const SrcOp &Val,
const int Idx) {
-   return buildExtractVectorElement(Res, Val,
-                                    buildConstant(LLT::scalar(64), Idx));
+   auto TLI = getMF().getSubtarget().getTargetLowering();
+   unsigned VecIdxWidth = TLI->getVectorIdxTy(getDataLayout()).getSizeInBits();
+   return buildExtractVectorElement(
+       Res, Val, buildConstant(LLT::scalar(VecIdxWidth), Idx));
}

/// Build and insert \p Res = G_EXTRACT_VECTOR_ELT \p Val, \p Idx
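Note: the helper above stops hard-coding a 64-bit index type and instead derives the width from TargetLowering::getVectorIdxTy, matching the verifier rules added later in this patch. A minimal sketch of the pattern for downstream builder code (fragment only: MIB is assumed to be a MachineIRBuilder already attached to a MachineFunction, and VecReg/EltTy stand in for a real vector register and element type):

// Build a constant vector index at the target's preferred width rather
// than assuming s64.
const TargetLowering *TLI = MIB.getMF().getSubtarget().getTargetLowering();
unsigned IdxWidth = TLI->getVectorIdxTy(MIB.getDataLayout()).getSizeInBits();
auto Idx = MIB.buildConstant(LLT::scalar(IdxWidth), 0);
// The extract now carries an index of the width the verifier expects.
auto Elt = MIB.buildExtractVectorElement(EltTy, VecReg, Idx);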
1 change: 1 addition & 0 deletions llvm/include/llvm/Target/GlobalISel/SelectionDAGCompat.td
@@ -142,6 +142,7 @@ def : GINodeEquiv<G_CTLZ_ZERO_UNDEF, ctlz_zero_undef>;
def : GINodeEquiv<G_CTTZ_ZERO_UNDEF, cttz_zero_undef>;
def : GINodeEquiv<G_CTPOP, ctpop>;
def : GINodeEquiv<G_EXTRACT_VECTOR_ELT, extractelt>;
+ def : GINodeEquiv<G_INSERT_VECTOR_ELT, vector_insert>;
def : GINodeEquiv<G_CONCAT_VECTORS, concat_vectors>;
def : GINodeEquiv<G_BUILD_VECTOR, build_vector>;
def : GINodeEquiv<G_FCEIL, fceil>;
16 changes: 15 additions & 1 deletion llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
@@ -3087,7 +3087,21 @@ bool IRTranslator::translateInsertElement(const User &U,
Register Res = getOrCreateVReg(U);
Register Val = getOrCreateVReg(*U.getOperand(0));
Register Elt = getOrCreateVReg(*U.getOperand(1));
-   Register Idx = getOrCreateVReg(*U.getOperand(2));
+   unsigned PreferredVecIdxWidth = TLI->getVectorIdxTy(*DL).getSizeInBits();
+   Register Idx;
+   if (auto *CI = dyn_cast<ConstantInt>(U.getOperand(2))) {
+     if (CI->getBitWidth() != PreferredVecIdxWidth) {
+       APInt NewIdx = CI->getValue().zextOrTrunc(PreferredVecIdxWidth);
+       auto *NewIdxCI = ConstantInt::get(CI->getContext(), NewIdx);
+       Idx = getOrCreateVReg(*NewIdxCI);
+     }
+   }
+   if (!Idx)
+     Idx = getOrCreateVReg(*U.getOperand(2));
+   if (MRI->getType(Idx).getSizeInBits() != PreferredVecIdxWidth) {
+     const LLT VecIdxTy = LLT::scalar(PreferredVecIdxWidth);
+     Idx = MIRBuilder.buildZExtOrTrunc(VecIdxTy, Idx).getReg(0);
+   }
MIRBuilder.buildInsertVectorElement(Res, Val, Elt, Idx);
return true;
}
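Note: translateInsertElement now canonicalizes every insertelement index to the target's preferred vector-index width: constant indices are rebuilt as fresh ConstantInts so getOrCreateVReg hands back one canonical vreg, and any remaining width mismatch is patched with G_ZEXT/G_TRUNC. The same normalization written as a standalone helper (hypothetical name and factoring; the patch inlines this logic directly):

// Hedged sketch: normalize an index vreg to the preferred scalar width.
static Register normalizeVecIdx(Register Idx, unsigned PreferredWidth,
                                MachineRegisterInfo &MRI,
                                MachineIRBuilder &MIRBuilder) {
  // Already the right width; nothing to do.
  if (MRI.getType(Idx).getSizeInBits() == PreferredWidth)
    return Idx;
  // Zero-extend or truncate to the width the verifier will demand.
  return MIRBuilder.buildZExtOrTrunc(LLT::scalar(PreferredWidth), Idx)
      .getReg(0);
}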
55 changes: 55 additions & 0 deletions llvm/lib/CodeGen/MachineVerifier.cpp
@@ -55,6 +55,7 @@
#include "llvm/CodeGen/SlotIndexes.h"
#include "llvm/CodeGen/StackMaps.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
@@ -1788,6 +1789,60 @@ void MachineVerifier::verifyPreISelGenericInstruction(const MachineInstr *MI) {

break;
}
+ case TargetOpcode::G_EXTRACT_VECTOR_ELT: {
+   LLT DstTy = MRI->getType(MI->getOperand(0).getReg());
+   LLT SrcTy = MRI->getType(MI->getOperand(1).getReg());
+   LLT IdxTy = MRI->getType(MI->getOperand(2).getReg());
+
+   if (!DstTy.isScalar() && !DstTy.isPointer()) {
+     report("Destination type must be a scalar or pointer", MI);
+     break;
+   }
+
+   if (!SrcTy.isVector()) {
+     report("First source must be a vector", MI);
+     break;
+   }
+
+   auto TLI = MF->getSubtarget().getTargetLowering();
+   if (IdxTy.getSizeInBits() !=
+       TLI->getVectorIdxTy(MF->getDataLayout()).getFixedSizeInBits()) {
+     report("Index type must match VectorIdxTy", MI);
+     break;
+   }
+
+   break;
+ }
+ case TargetOpcode::G_INSERT_VECTOR_ELT: {
+   LLT DstTy = MRI->getType(MI->getOperand(0).getReg());
+   LLT VecTy = MRI->getType(MI->getOperand(1).getReg());
+   LLT ScaTy = MRI->getType(MI->getOperand(2).getReg());
+   LLT IdxTy = MRI->getType(MI->getOperand(3).getReg());
+
+   if (!DstTy.isVector()) {
+     report("Destination type must be a vector", MI);
+     break;
+   }
+
+   if (VecTy != DstTy) {
+     report("Destination type and vector type must match", MI);
+     break;
+   }
+
+   if (!ScaTy.isScalar() && !ScaTy.isPointer()) {
+     report("Inserted element must be a scalar or pointer", MI);
+     break;
+   }
+
+   auto TLI = MF->getSubtarget().getTargetLowering();
+   if (IdxTy.getSizeInBits() !=
+       TLI->getVectorIdxTy(MF->getDataLayout()).getFixedSizeInBits()) {
+     report("Index type must match VectorIdxTy", MI);
+     break;
+   }
+
+   break;
+ }
case TargetOpcode::G_DYN_STACKALLOC: {
const MachineOperand &DstOp = MI->getOperand(0);
const MachineOperand &AllocOp = MI->getOperand(1);
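Note: these two cases pin down the operand types of both element opcodes at the MIR level, so ill-typed instructions fail verification early instead of misbehaving during selection. As an illustration (not a test from this patch), builder code like the following would now be rejected on a target whose getVectorIdxTy is 64-bit:

// Illustrative fragment, assuming MIB/DstReg/VecReg/EltReg exist: the s32
// index trips "Index type must match VectorIdxTy" in the verifier.
auto BadIdx = MIB.buildConstant(LLT::scalar(32), 1);
MIB.buildInsertVectorElement(DstReg, VecReg, EltReg, BadIdx);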
12 changes: 9 additions & 3 deletions llvm/lib/Target/AArch64/AArch64Combine.td
@@ -114,6 +114,13 @@ def ext: GICombineRule <
(apply [{ applyEXT(*${root}, ${matchinfo}); }])
>;

+ def insertelt_nonconst: GICombineRule <
+   (defs root:$root, shuffle_matchdata:$matchinfo),
+   (match (wip_match_opcode G_INSERT_VECTOR_ELT):$root,
+          [{ return matchNonConstInsert(*${root}, MRI); }]),
+   (apply [{ applyNonConstInsert(*${root}, MRI, B); }])
+ >;
+
def shuf_to_ins_matchdata : GIDefMatchData<"std::tuple<Register, int, Register, int>">;
def shuf_to_ins: GICombineRule <
(defs root:$root, shuf_to_ins_matchdata:$matchinfo),
@@ -140,8 +147,7 @@ def form_duplane : GICombineRule <
>;

def shuffle_vector_lowering : GICombineGroup<[dup, rev, ext, zip, uzp, trn,
-                                          form_duplane,
-                                          shuf_to_ins]>;
+                                          form_duplane, shuf_to_ins]>;

// Turn G_UNMERGE_VALUES -> G_EXTRACT_VECTOR_ELT's
def vector_unmerge_lowering : GICombineRule <
@@ -269,7 +275,7 @@ def AArch64PostLegalizerLowering
lower_vector_fcmp, form_truncstore,
vector_sext_inreg_to_shift,
unmerge_ext_to_unmerge, lower_mull,
-                        vector_unmerge_lowering]> {
+                        vector_unmerge_lowering, insertelt_nonconst]> {
}

// Post-legalization combines which are primarily optimizations.
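Note: insertelt_nonconst catches G_INSERT_VECTOR_ELT whose lane index is not a compile-time constant; constant-index inserts keep flowing through the imported vector_insert patterns, and this combine replaces the hand-written selectInsertElt deleted below. matchNonConstInsert/applyNonConstInsert themselves are outside this diff, so the following is only a hedged sketch of the usual expansion for a variable lane index (a stack round-trip), not a claim about the exact code emitted here:

// Hedged sketch: spill the vector, store the element at a computed address,
// reload. All names (MIB, VecReg, EltReg, IdxReg, DstReg, VecBytes, EltBytes)
// are assumptions; a production lowering would also clamp IdxReg into range.
MachineFunction &MF = MIB.getMF();
int FI = MF.getFrameInfo().CreateStackObject(VecBytes, Align(16), false);
auto Slot = MIB.buildFrameIndex(LLT::pointer(0, 64), FI);
MachinePointerInfo PtrInfo = MachinePointerInfo::getFixedStack(MF, FI);
MIB.buildStore(VecReg, Slot, PtrInfo, Align(16));         // spill vector
auto Scale = MIB.buildConstant(LLT::scalar(64), EltBytes);
auto Off = MIB.buildMul(LLT::scalar(64), IdxReg, Scale);  // lane byte offset
auto Addr = MIB.buildPtrAdd(LLT::pointer(0, 64), Slot, Off);
MIB.buildStore(EltReg, Addr, PtrInfo, Align(1));          // overwrite lane
MIB.buildLoad(DstReg, Slot, PtrInfo, Align(16));          // reload result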
4 changes: 2 additions & 2 deletions llvm/lib/Target/AArch64/AArch64InstrAtomics.td
@@ -547,10 +547,10 @@ let Predicates = [HasLSE] in {
let Predicates = [HasRCPC3, HasNEON] in {
// LDAP1 loads
def : Pat<(vector_insert (v2i64 VecListOne128:$Rd),
-             (i64 (acquiring_load<atomic_load_64> GPR64sp:$Rn)), VectorIndexD:$idx),
+             (i64 (acquiring_load<atomic_load_64> GPR64sp:$Rn)), (i64 VectorIndexD:$idx)),
(LDAP1 VecListOne128:$Rd, VectorIndexD:$idx, GPR64sp:$Rn)>;
def : Pat<(vector_insert (v2f64 VecListOne128:$Rd),
-             (f64 (bitconvert (i64 (acquiring_load<atomic_load_64> GPR64sp:$Rn)))), VectorIndexD:$idx),
+             (f64 (bitconvert (i64 (acquiring_load<atomic_load_64> GPR64sp:$Rn)))), (i64 VectorIndexD:$idx)),
(LDAP1 VecListOne128:$Rd, VectorIndexD:$idx, GPR64sp:$Rn)>;
def : Pat<(v1i64 (scalar_to_vector
(i64 (acquiring_load<atomic_load_64> GPR64sp:$Rn)))),
6 changes: 3 additions & 3 deletions llvm/lib/Target/AArch64/AArch64InstrFormats.td
@@ -7983,7 +7983,7 @@ class SIMDInsFromMain<string size, ValueType vectype,
"|" # size # "\t$Rd$idx, $Rn}",
"$Rd = $dst",
[(set V128:$dst,
-            (vector_insert (vectype V128:$Rd), regtype:$Rn, idxtype:$idx))]> {
+            (vector_insert (vectype V128:$Rd), regtype:$Rn, (i64 idxtype:$idx)))]> {
let Inst{14-11} = 0b0011;
}

@@ -7997,8 +7997,8 @@ class SIMDInsFromElement<string size, ValueType vectype,
[(set V128:$dst,
(vector_insert
(vectype V128:$Rd),
-            (elttype (vector_extract (vectype V128:$Rn), idxtype:$idx2)),
-            idxtype:$idx))]>;
+            (elttype (vector_extract (vectype V128:$Rn), (i64 idxtype:$idx2))),
+            (i64 idxtype:$idx)))]>;

class SIMDInsMainMovAlias<string size, Instruction inst,
RegisterClass regtype, Operand idxtype>
39 changes: 24 additions & 15 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.td
Expand Up @@ -6601,6 +6601,15 @@ def : Pat<(v8i8 (vector_insert (v8i8 V64:$Rn), (i32 GPR32:$Rm), (i64 VectorIndex
VectorIndexB:$imm, GPR32:$Rm),
dsub)>;

+ def : Pat<(v8i8 (vector_insert (v8i8 V64:$Rn), (i8 FPR8:$Rm), (i64 VectorIndexB:$imm))),
+           (EXTRACT_SUBREG
+             (INSvi8lane (v16i8 (INSERT_SUBREG (v16i8 (IMPLICIT_DEF)), V64:$Rn, dsub)),
+                         VectorIndexB:$imm, (v16i8 (INSERT_SUBREG (v16i8 (IMPLICIT_DEF)), FPR8:$Rm, bsub)), (i64 0)),
+             dsub)>;
+ def : Pat<(v16i8 (vector_insert (v16i8 V128:$Rn), (i8 FPR8:$Rm), (i64 VectorIndexB:$imm))),
+           (INSvi8lane V128:$Rn, VectorIndexB:$imm,
+                       (v16i8 (INSERT_SUBREG (v16i8 (IMPLICIT_DEF)), FPR8:$Rm, bsub)), (i64 0))>;
+
// Copy an element at a constant index in one vector into a constant indexed
// element of another.
// FIXME refactor to a shared class/dev parameterized on vector type, vector
@@ -6633,26 +6642,26 @@ def : Pat<(v2i64 (int_aarch64_neon_vcopy_lane
multiclass Neon_INS_elt_pattern<ValueType VT128, ValueType VT64,
ValueType VTScal, Instruction INS> {
def : Pat<(VT128 (vector_insert V128:$src,
-                    (VTScal (vector_extract (VT128 V128:$Rn), imm:$Immn)),
-                    imm:$Immd)),
+                    (VTScal (vector_extract (VT128 V128:$Rn), (i64 imm:$Immn))),
+                    (i64 imm:$Immd))),
(INS V128:$src, imm:$Immd, V128:$Rn, imm:$Immn)>;

def : Pat<(VT128 (vector_insert V128:$src,
-                    (VTScal (vector_extract (VT64 V64:$Rn), imm:$Immn)),
-                    imm:$Immd)),
+                    (VTScal (vector_extract (VT64 V64:$Rn), (i64 imm:$Immn))),
+                    (i64 imm:$Immd))),
(INS V128:$src, imm:$Immd,
(SUBREG_TO_REG (i64 0), V64:$Rn, dsub), imm:$Immn)>;

def : Pat<(VT64 (vector_insert V64:$src,
-                    (VTScal (vector_extract (VT128 V128:$Rn), imm:$Immn)),
-                    imm:$Immd)),
+                    (VTScal (vector_extract (VT128 V128:$Rn), (i64 imm:$Immn))),
+                    (i64 imm:$Immd))),
(EXTRACT_SUBREG (INS (SUBREG_TO_REG (i64 0), V64:$src, dsub),
imm:$Immd, V128:$Rn, imm:$Immn),
dsub)>;

def : Pat<(VT64 (vector_insert V64:$src,
-                    (VTScal (vector_extract (VT64 V64:$Rn), imm:$Immn)),
-                    imm:$Immd)),
+                    (VTScal (vector_extract (VT64 V64:$Rn), (i64 imm:$Immn))),
+                    (i64 imm:$Immd))),
(EXTRACT_SUBREG
(INS (SUBREG_TO_REG (i64 0), V64:$src, dsub), imm:$Immd,
(SUBREG_TO_REG (i64 0), V64:$Rn, dsub), imm:$Immn),
@@ -6671,14 +6680,14 @@ defm : Neon_INS_elt_pattern<v2i64, v1i64, i64, INSvi64lane>;

// Insert from bitcast
// vector_insert(bitcast(f32 src), n, lane) -> INSvi32lane(src, lane, INSERT_SUBREG(-, n), 0)
- def : Pat<(v4i32 (vector_insert v4i32:$src, (i32 (bitconvert (f32 FPR32:$Sn))), imm:$Immd)),
+ def : Pat<(v4i32 (vector_insert v4i32:$src, (i32 (bitconvert (f32 FPR32:$Sn))), (i64 imm:$Immd))),
(INSvi32lane V128:$src, imm:$Immd, (INSERT_SUBREG (IMPLICIT_DEF), FPR32:$Sn, ssub), 0)>;
- def : Pat<(v2i32 (vector_insert v2i32:$src, (i32 (bitconvert (f32 FPR32:$Sn))), imm:$Immd)),
+ def : Pat<(v2i32 (vector_insert v2i32:$src, (i32 (bitconvert (f32 FPR32:$Sn))), (i64 imm:$Immd))),
(EXTRACT_SUBREG
(INSvi32lane (v4i32 (INSERT_SUBREG (v4i32 (IMPLICIT_DEF)), V64:$src, dsub)),
imm:$Immd, (INSERT_SUBREG (IMPLICIT_DEF), FPR32:$Sn, ssub), 0),
dsub)>;
- def : Pat<(v2i64 (vector_insert v2i64:$src, (i64 (bitconvert (f64 FPR64:$Sn))), imm:$Immd)),
+ def : Pat<(v2i64 (vector_insert v2i64:$src, (i64 (bitconvert (f64 FPR64:$Sn))), (i64 imm:$Immd))),
(INSvi64lane V128:$src, imm:$Immd, (INSERT_SUBREG (IMPLICIT_DEF), FPR64:$Sn, dsub), 0)>;

// bitcast of an extract
@@ -8100,7 +8109,7 @@ def : Pat<(v8bf16 (AArch64dup (bf16 (load GPR64sp:$Rn)))),
class Ld1Lane128Pat<SDPatternOperator scalar_load, Operand VecIndex,
ValueType VTy, ValueType STy, Instruction LD1>
: Pat<(vector_insert (VTy VecListOne128:$Rd),
-        (STy (scalar_load GPR64sp:$Rn)), VecIndex:$idx),
+        (STy (scalar_load GPR64sp:$Rn)), (i64 VecIndex:$idx)),
(LD1 VecListOne128:$Rd, VecIndex:$idx, GPR64sp:$Rn)>;

def : Ld1Lane128Pat<extloadi8, VectorIndexB, v16i8, i32, LD1i8>;
@@ -8123,14 +8132,14 @@ class Ld1Lane128IdxOpPat<SDPatternOperator scalar_load, Operand
VecIndex, ValueType VTy, ValueType STy,
Instruction LD1, SDNodeXForm IdxOp>
: Pat<(vector_insert (VTy VecListOne128:$Rd),
-        (STy (scalar_load GPR64sp:$Rn)), VecIndex:$idx),
+        (STy (scalar_load GPR64sp:$Rn)), (i64 VecIndex:$idx)),
(LD1 VecListOne128:$Rd, (IdxOp VecIndex:$idx), GPR64sp:$Rn)>;

class Ld1Lane64IdxOpPat<SDPatternOperator scalar_load, Operand VecIndex,
ValueType VTy, ValueType STy, Instruction LD1,
SDNodeXForm IdxOp>
: Pat<(vector_insert (VTy VecListOne64:$Rd),
-        (STy (scalar_load GPR64sp:$Rn)), VecIndex:$idx),
+        (STy (scalar_load GPR64sp:$Rn)), (i64 VecIndex:$idx)),
(EXTRACT_SUBREG
(LD1 (SUBREG_TO_REG (i32 0), VecListOne64:$Rd, dsub),
(IdxOp VecIndex:$idx), GPR64sp:$Rn),
@@ -8170,7 +8179,7 @@ let Predicates = [IsNeonAvailable] in {
class Ld1Lane64Pat<SDPatternOperator scalar_load, Operand VecIndex,
ValueType VTy, ValueType STy, Instruction LD1>
: Pat<(vector_insert (VTy VecListOne64:$Rd),
-        (STy (scalar_load GPR64sp:$Rn)), VecIndex:$idx),
+        (STy (scalar_load GPR64sp:$Rn)), (i64 VecIndex:$idx)),
(EXTRACT_SUBREG
(LD1 (SUBREG_TO_REG (i32 0), VecListOne64:$Rd, dsub),
VecIndex:$idx, GPR64sp:$Rn),
62 changes: 0 additions & 62 deletions llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp
@@ -191,7 +191,6 @@ class AArch64InstructionSelector : public InstructionSelector {
MachineInstr *tryAdvSIMDModImmFP(Register Dst, unsigned DstSize, APInt Bits,
MachineIRBuilder &MIRBuilder);

-   bool selectInsertElt(MachineInstr &I, MachineRegisterInfo &MRI);
bool tryOptConstantBuildVec(MachineInstr &MI, LLT DstTy,
MachineRegisterInfo &MRI);
/// \returns true if a G_BUILD_VECTOR instruction \p MI can be selected as a
@@ -3498,8 +3497,6 @@ bool AArch64InstructionSelector::select(MachineInstr &I) {
return selectShuffleVector(I, MRI);
case TargetOpcode::G_EXTRACT_VECTOR_ELT:
return selectExtractElt(I, MRI);
-   case TargetOpcode::G_INSERT_VECTOR_ELT:
-     return selectInsertElt(I, MRI);
case TargetOpcode::G_CONCAT_VECTORS:
return selectConcatVectors(I, MRI);
case TargetOpcode::G_JUMP_TABLE:
@@ -5330,65 +5327,6 @@ bool AArch64InstructionSelector::selectUSMovFromExtend(
return true;
}

- bool AArch64InstructionSelector::selectInsertElt(MachineInstr &I,
-                                                  MachineRegisterInfo &MRI) {
-   assert(I.getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT);
-
-   // Get information on the destination.
-   Register DstReg = I.getOperand(0).getReg();
-   const LLT DstTy = MRI.getType(DstReg);
-   unsigned VecSize = DstTy.getSizeInBits();
-
-   // Get information on the element we want to insert into the destination.
-   Register EltReg = I.getOperand(2).getReg();
-   const LLT EltTy = MRI.getType(EltReg);
-   unsigned EltSize = EltTy.getSizeInBits();
-   if (EltSize < 8 || EltSize > 64)
-     return false;
-
-   // Find the definition of the index. Bail out if it's not defined by a
-   // G_CONSTANT.
-   Register IdxReg = I.getOperand(3).getReg();
-   auto VRegAndVal = getIConstantVRegValWithLookThrough(IdxReg, MRI);
-   if (!VRegAndVal)
-     return false;
-   unsigned LaneIdx = VRegAndVal->Value.getSExtValue();
-
-   // Perform the lane insert.
-   Register SrcReg = I.getOperand(1).getReg();
-   const RegisterBank &EltRB = *RBI.getRegBank(EltReg, MRI, TRI);
-
-   if (VecSize < 128) {
-     // If the vector we're inserting into is smaller than 128 bits, widen it
-     // to 128 to do the insert.
-     MachineInstr *ScalarToVec =
-         emitScalarToVector(VecSize, &AArch64::FPR128RegClass, SrcReg, MIB);
-     if (!ScalarToVec)
-       return false;
-     SrcReg = ScalarToVec->getOperand(0).getReg();
-   }
-
-   // Create an insert into a new FPR128 register.
-   // Note that if our vector is already 128 bits, we end up emitting an extra
-   // register.
-   MachineInstr *InsMI =
-       emitLaneInsert(std::nullopt, SrcReg, EltReg, LaneIdx, EltRB, MIB);
-
-   if (VecSize < 128) {
-     // If we had to widen to perform the insert, then we have to demote back to
-     // the original size to get the result we want.
-     if (!emitNarrowVector(DstReg, InsMI->getOperand(0).getReg(), MIB, MRI))
-       return false;
-   } else {
-     // No widening needed.
-     InsMI->getOperand(0).setReg(DstReg);
-     constrainSelectedInstRegOperands(*InsMI, TII, TRI, RBI);
-   }
-
-   I.eraseFromParent();
-   return true;
- }
-
MachineInstr *AArch64InstructionSelector::tryAdvSIMDModImm8(
Register Dst, unsigned DstSize, APInt Bits, MachineIRBuilder &Builder) {
unsigned int Op;
10 changes: 8 additions & 2 deletions llvm/lib/Target/AArch64/GISel/AArch64LegalizerInfo.cpp
@@ -886,9 +886,15 @@ AArch64LegalizerInfo::AArch64LegalizerInfo(const AArch64Subtarget &ST)
.clampMaxNumElements(1, p0, 2);

getActionDefinitionsBuilder(G_INSERT_VECTOR_ELT)
-       .legalIf(typeInSet(0, {v16s8, v8s8, v8s16, v4s16, v4s32, v2s32, v2s64}))
+       .legalIf(
+           typeInSet(0, {v16s8, v8s8, v8s16, v4s16, v4s32, v2s32, v2s64, v2p0}))
        .moreElementsToNextPow2(0)
-       .widenVectorEltsToVectorMinSize(0, 64);
+       .widenVectorEltsToVectorMinSize(0, 64)
+       .clampNumElements(0, v8s8, v16s8)
+       .clampNumElements(0, v4s16, v8s16)
+       .clampNumElements(0, v2s32, v4s32)
+       .clampMaxNumElements(0, s64, 2)
+       .clampMaxNumElements(0, p0, 2);

getActionDefinitionsBuilder(G_BUILD_VECTOR)
.legalFor({{v8s8, s8},
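Note: besides admitting v2p0, the rule set now clamps over-wide vectors instead of leaving them without a rule, so an insert into, say, v4s64 is narrowed rather than failing to legalize. A hedged sketch of how one might probe that (hypothetical harness; LegInfo stands for the constructed AArch64 rules):

// Query: what happens to G_INSERT_VECTOR_ELT on v4s64 (wider than one
// 128-bit Q register)? Per clampMaxNumElements(0, s64, 2) above, the
// expected step is FewerElements toward v2s64, not Unsupported.
LegalityQuery Query(TargetOpcode::G_INSERT_VECTOR_ELT,
                    {LLT::fixed_vector(4, 64), LLT::fixed_vector(4, 64),
                     LLT::scalar(64), LLT::scalar(64)});
LegalizeActionStep Step = LegInfo.getAction(Query);
// Step.Action is expected to be LegalizeActions::FewerElements here.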