Skip to content

Commit adf02ae

Browse files
authored
[AMDGPU] Simplify lowerBUILD_VECTOR (#109094)
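Simplify `lowerBUILD_VECTOR` by using a single common code path to split vectors: every vector wider than two elements is now split into 2-element chunks by one loop, replacing the separate per-width splitting cases. Also reorder the checks so that the 2-element case is handled first, which avoids a long condition inside the `if`.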
1 parent 2e3c7db commit adf02ae

File tree

2 files changed

+32
-89
lines changed

2 files changed

+32
-89
lines changed

llvm/lib/Target/AMDGPU/SIISelLowering.cpp

Lines changed: 31 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -7443,98 +7443,49 @@ SDValue SITargetLowering::lowerBUILD_VECTOR(SDValue Op,
74437443
SDLoc SL(Op);
74447444
EVT VT = Op.getValueType();
74457445

7446-
if (VT == MVT::v4i16 || VT == MVT::v4f16 || VT == MVT::v8i16 ||
7447-
VT == MVT::v8f16 || VT == MVT::v4bf16 || VT == MVT::v8bf16) {
7448-
EVT HalfVT = MVT::getVectorVT(VT.getVectorElementType().getSimpleVT(),
7449-
VT.getVectorNumElements() / 2);
7450-
MVT HalfIntVT = MVT::getIntegerVT(HalfVT.getSizeInBits());
7446+
if (VT == MVT::v2f16 || VT == MVT::v2i16 || VT == MVT::v2bf16) {
7447+
assert(!Subtarget->hasVOP3PInsts() && "this should be legal");
74517448

7452-
// Turn into pair of packed build_vectors.
7453-
// TODO: Special case for constants that can be materialized with s_mov_b64.
7454-
SmallVector<SDValue, 4> LoOps, HiOps;
7455-
for (unsigned I = 0, E = VT.getVectorNumElements() / 2; I != E; ++I) {
7456-
LoOps.push_back(Op.getOperand(I));
7457-
HiOps.push_back(Op.getOperand(I + E));
7458-
}
7459-
SDValue Lo = DAG.getBuildVector(HalfVT, SL, LoOps);
7460-
SDValue Hi = DAG.getBuildVector(HalfVT, SL, HiOps);
7461-
7462-
SDValue CastLo = DAG.getNode(ISD::BITCAST, SL, HalfIntVT, Lo);
7463-
SDValue CastHi = DAG.getNode(ISD::BITCAST, SL, HalfIntVT, Hi);
7464-
7465-
SDValue Blend = DAG.getBuildVector(MVT::getVectorVT(HalfIntVT, 2), SL,
7466-
{ CastLo, CastHi });
7467-
return DAG.getNode(ISD::BITCAST, SL, VT, Blend);
7468-
}
7449+
SDValue Lo = Op.getOperand(0);
7450+
SDValue Hi = Op.getOperand(1);
74697451

7470-
if (VT == MVT::v16i16 || VT == MVT::v16f16 || VT == MVT::v16bf16) {
7471-
EVT QuarterVT = MVT::getVectorVT(VT.getVectorElementType().getSimpleVT(),
7472-
VT.getVectorNumElements() / 4);
7473-
MVT QuarterIntVT = MVT::getIntegerVT(QuarterVT.getSizeInBits());
7474-
7475-
SmallVector<SDValue, 4> Parts[4];
7476-
for (unsigned I = 0, E = VT.getVectorNumElements() / 4; I != E; ++I) {
7477-
for (unsigned P = 0; P < 4; ++P)
7478-
Parts[P].push_back(Op.getOperand(I + P * E));
7479-
}
7480-
SDValue Casts[4];
7481-
for (unsigned P = 0; P < 4; ++P) {
7482-
SDValue Vec = DAG.getBuildVector(QuarterVT, SL, Parts[P]);
7483-
Casts[P] = DAG.getNode(ISD::BITCAST, SL, QuarterIntVT, Vec);
7452+
// Avoid adding defined bits with the zero_extend.
7453+
if (Hi.isUndef()) {
7454+
Lo = DAG.getNode(ISD::BITCAST, SL, MVT::i16, Lo);
7455+
SDValue ExtLo = DAG.getNode(ISD::ANY_EXTEND, SL, MVT::i32, Lo);
7456+
return DAG.getNode(ISD::BITCAST, SL, VT, ExtLo);
74847457
}
74857458

7486-
SDValue Blend =
7487-
DAG.getBuildVector(MVT::getVectorVT(QuarterIntVT, 4), SL, Casts);
7488-
return DAG.getNode(ISD::BITCAST, SL, VT, Blend);
7489-
}
7459+
Hi = DAG.getNode(ISD::BITCAST, SL, MVT::i16, Hi);
7460+
Hi = DAG.getNode(ISD::ZERO_EXTEND, SL, MVT::i32, Hi);
74907461

7491-
if (VT == MVT::v32i16 || VT == MVT::v32f16 || VT == MVT::v32bf16) {
7492-
EVT QuarterVT = MVT::getVectorVT(VT.getVectorElementType().getSimpleVT(),
7493-
VT.getVectorNumElements() / 8);
7494-
MVT QuarterIntVT = MVT::getIntegerVT(QuarterVT.getSizeInBits());
7462+
SDValue ShlHi = DAG.getNode(ISD::SHL, SL, MVT::i32, Hi,
7463+
DAG.getConstant(16, SL, MVT::i32));
7464+
if (Lo.isUndef())
7465+
return DAG.getNode(ISD::BITCAST, SL, VT, ShlHi);
74957466

7496-
SmallVector<SDValue, 8> Parts[8];
7497-
for (unsigned I = 0, E = VT.getVectorNumElements() / 8; I != E; ++I) {
7498-
for (unsigned P = 0; P < 8; ++P)
7499-
Parts[P].push_back(Op.getOperand(I + P * E));
7500-
}
7501-
SDValue Casts[8];
7502-
for (unsigned P = 0; P < 8; ++P) {
7503-
SDValue Vec = DAG.getBuildVector(QuarterVT, SL, Parts[P]);
7504-
Casts[P] = DAG.getNode(ISD::BITCAST, SL, QuarterIntVT, Vec);
7505-
}
7467+
Lo = DAG.getNode(ISD::BITCAST, SL, MVT::i16, Lo);
7468+
Lo = DAG.getNode(ISD::ZERO_EXTEND, SL, MVT::i32, Lo);
75067469

7507-
SDValue Blend =
7508-
DAG.getBuildVector(MVT::getVectorVT(QuarterIntVT, 8), SL, Casts);
7509-
return DAG.getNode(ISD::BITCAST, SL, VT, Blend);
7470+
SDValue Or = DAG.getNode(ISD::OR, SL, MVT::i32, Lo, ShlHi);
7471+
return DAG.getNode(ISD::BITCAST, SL, VT, Or);
75107472
}
75117473

7512-
assert(VT == MVT::v2f16 || VT == MVT::v2i16 || VT == MVT::v2bf16);
7513-
assert(!Subtarget->hasVOP3PInsts() && "this should be legal");
7474+
// Split into 2-element chunks.
7475+
const unsigned NumParts = VT.getVectorNumElements() / 2;
7476+
EVT PartVT = MVT::getVectorVT(VT.getVectorElementType().getSimpleVT(), 2);
7477+
MVT PartIntVT = MVT::getIntegerVT(PartVT.getSizeInBits());
75147478

7515-
SDValue Lo = Op.getOperand(0);
7516-
SDValue Hi = Op.getOperand(1);
7517-
7518-
// Avoid adding defined bits with the zero_extend.
7519-
if (Hi.isUndef()) {
7520-
Lo = DAG.getNode(ISD::BITCAST, SL, MVT::i16, Lo);
7521-
SDValue ExtLo = DAG.getNode(ISD::ANY_EXTEND, SL, MVT::i32, Lo);
7522-
return DAG.getNode(ISD::BITCAST, SL, VT, ExtLo);
7479+
SmallVector<SDValue> Casts;
7480+
for (unsigned P = 0; P < NumParts; ++P) {
7481+
SDValue Vec = DAG.getBuildVector(
7482+
PartVT, SL, {Op.getOperand(P * 2), Op.getOperand(P * 2 + 1)});
7483+
Casts.push_back(DAG.getNode(ISD::BITCAST, SL, PartIntVT, Vec));
75237484
}
75247485

7525-
Hi = DAG.getNode(ISD::BITCAST, SL, MVT::i16, Hi);
7526-
Hi = DAG.getNode(ISD::ZERO_EXTEND, SL, MVT::i32, Hi);
7527-
7528-
SDValue ShlHi = DAG.getNode(ISD::SHL, SL, MVT::i32, Hi,
7529-
DAG.getConstant(16, SL, MVT::i32));
7530-
if (Lo.isUndef())
7531-
return DAG.getNode(ISD::BITCAST, SL, VT, ShlHi);
7532-
7533-
Lo = DAG.getNode(ISD::BITCAST, SL, MVT::i16, Lo);
7534-
Lo = DAG.getNode(ISD::ZERO_EXTEND, SL, MVT::i32, Lo);
7535-
7536-
SDValue Or = DAG.getNode(ISD::OR, SL, MVT::i32, Lo, ShlHi);
7537-
return DAG.getNode(ISD::BITCAST, SL, VT, Or);
7486+
SDValue Blend =
7487+
DAG.getBuildVector(MVT::getVectorVT(PartIntVT, NumParts), SL, Casts);
7488+
return DAG.getNode(ISD::BITCAST, SL, VT, Blend);
75387489
}
75397490

75407491
bool

llvm/test/CodeGen/AMDGPU/insert_vector_elt.v2bf16.ll

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -965,11 +965,7 @@ define amdgpu_kernel void @v_insertelement_v8bf16_3(ptr addrspace(1) %out, ptr a
965965
; GFX900-NEXT: v_mov_b32_e32 v5, 0x5040100
966966
; GFX900-NEXT: s_waitcnt lgkmcnt(0)
967967
; GFX900-NEXT: global_load_dwordx4 v[0:3], v4, s[2:3]
968-
; GFX900-NEXT: s_mov_b32 s2, 0xffff
969968
; GFX900-NEXT: s_waitcnt vmcnt(0)
970-
; GFX900-NEXT: v_bfi_b32 v3, s2, v3, v3
971-
; GFX900-NEXT: v_bfi_b32 v2, s2, v2, v2
972-
; GFX900-NEXT: v_bfi_b32 v0, s2, v0, v0
973969
; GFX900-NEXT: v_perm_b32 v1, s4, v1, v5
974970
; GFX900-NEXT: global_store_dwordx4 v4, v[0:3], s[0:1]
975971
; GFX900-NEXT: s_endpgm
@@ -980,14 +976,10 @@ define amdgpu_kernel void @v_insertelement_v8bf16_3(ptr addrspace(1) %out, ptr a
980976
; GFX940-NEXT: s_load_dword s0, s[2:3], 0x10
981977
; GFX940-NEXT: v_and_b32_e32 v0, 0x3ff, v0
982978
; GFX940-NEXT: v_lshlrev_b32_e32 v4, 4, v0
983-
; GFX940-NEXT: s_mov_b32 s1, 0xffff
979+
; GFX940-NEXT: v_mov_b32_e32 v5, 0x5040100
984980
; GFX940-NEXT: s_waitcnt lgkmcnt(0)
985981
; GFX940-NEXT: global_load_dwordx4 v[0:3], v4, s[6:7]
986-
; GFX940-NEXT: v_mov_b32_e32 v5, 0x5040100
987982
; GFX940-NEXT: s_waitcnt vmcnt(0)
988-
; GFX940-NEXT: v_bfi_b32 v3, s1, v3, v3
989-
; GFX940-NEXT: v_bfi_b32 v2, s1, v2, v2
990-
; GFX940-NEXT: v_bfi_b32 v0, s1, v0, v0
991983
; GFX940-NEXT: v_perm_b32 v1, s0, v1, v5
992984
; GFX940-NEXT: global_store_dwordx4 v4, v[0:3], s[4:5] sc0 sc1
993985
; GFX940-NEXT: s_endpgm

0 commit comments

Comments
 (0)