Skip to content

[AMDGPU] Simplify lowerBUILD_VECTOR #109094

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 18, 2024
Merged

[AMDGPU] Simplify lowerBUILD_VECTOR #109094

merged 3 commits into from
Sep 18, 2024

Conversation

piotrAMD
Copy link
Collaborator

@piotrAMD piotrAMD commented Sep 18, 2024

Simplify lowerBUILD_VECTOR by commoning up the way the vectors are split.
Also reorder the checks to avoid a long condition inside if.

@llvmbot
Copy link
Member

llvmbot commented Sep 18, 2024

@llvm/pr-subscribers-backend-amdgpu

Author: Piotr Sobczak (piotrAMD)

Changes

Simplify lowerBUILD_VECTOR by commoning up the way the vectors are split.
Also reorder the checks to avoid a long condition inside if.


Full diff: https://github.com/llvm/llvm-project/pull/109094.diff

2 Files Affected:

  • (modified) llvm/lib/Target/AMDGPU/SIISelLowering.cpp (+31-80)
  • (modified) llvm/test/CodeGen/AMDGPU/insert_vector_elt.v2bf16.ll (+1-9)
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index 4a861f0c03a0c5..10108866a7005a 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -7443,98 +7443,49 @@ SDValue SITargetLowering::lowerBUILD_VECTOR(SDValue Op,
   SDLoc SL(Op);
   EVT VT = Op.getValueType();
 
-  if (VT == MVT::v4i16 || VT == MVT::v4f16 || VT == MVT::v8i16 ||
-      VT == MVT::v8f16 || VT == MVT::v4bf16 || VT == MVT::v8bf16) {
-    EVT HalfVT = MVT::getVectorVT(VT.getVectorElementType().getSimpleVT(),
-                                  VT.getVectorNumElements() / 2);
-    MVT HalfIntVT = MVT::getIntegerVT(HalfVT.getSizeInBits());
+  if (VT == MVT::v2f16 || VT == MVT::v2i16 || VT == MVT::v2bf16) {
+    assert(!Subtarget->hasVOP3PInsts() && "this should be legal");
 
-    // Turn into pair of packed build_vectors.
-    // TODO: Special case for constants that can be materialized with s_mov_b64.
-    SmallVector<SDValue, 4> LoOps, HiOps;
-    for (unsigned I = 0, E = VT.getVectorNumElements() / 2; I != E; ++I) {
-      LoOps.push_back(Op.getOperand(I));
-      HiOps.push_back(Op.getOperand(I + E));
-    }
-    SDValue Lo = DAG.getBuildVector(HalfVT, SL, LoOps);
-    SDValue Hi = DAG.getBuildVector(HalfVT, SL, HiOps);
-
-    SDValue CastLo = DAG.getNode(ISD::BITCAST, SL, HalfIntVT, Lo);
-    SDValue CastHi = DAG.getNode(ISD::BITCAST, SL, HalfIntVT, Hi);
-
-    SDValue Blend = DAG.getBuildVector(MVT::getVectorVT(HalfIntVT, 2), SL,
-                                       { CastLo, CastHi });
-    return DAG.getNode(ISD::BITCAST, SL, VT, Blend);
-  }
+    SDValue Lo = Op.getOperand(0);
+    SDValue Hi = Op.getOperand(1);
 
-  if (VT == MVT::v16i16 || VT == MVT::v16f16 || VT == MVT::v16bf16) {
-    EVT QuarterVT = MVT::getVectorVT(VT.getVectorElementType().getSimpleVT(),
-                                     VT.getVectorNumElements() / 4);
-    MVT QuarterIntVT = MVT::getIntegerVT(QuarterVT.getSizeInBits());
-
-    SmallVector<SDValue, 4> Parts[4];
-    for (unsigned I = 0, E = VT.getVectorNumElements() / 4; I != E; ++I) {
-      for (unsigned P = 0; P < 4; ++P)
-        Parts[P].push_back(Op.getOperand(I + P * E));
-    }
-    SDValue Casts[4];
-    for (unsigned P = 0; P < 4; ++P) {
-      SDValue Vec = DAG.getBuildVector(QuarterVT, SL, Parts[P]);
-      Casts[P] = DAG.getNode(ISD::BITCAST, SL, QuarterIntVT, Vec);
+    // Avoid adding defined bits with the zero_extend.
+    if (Hi.isUndef()) {
+      Lo = DAG.getNode(ISD::BITCAST, SL, MVT::i16, Lo);
+      SDValue ExtLo = DAG.getNode(ISD::ANY_EXTEND, SL, MVT::i32, Lo);
+      return DAG.getNode(ISD::BITCAST, SL, VT, ExtLo);
     }
 
-    SDValue Blend =
-        DAG.getBuildVector(MVT::getVectorVT(QuarterIntVT, 4), SL, Casts);
-    return DAG.getNode(ISD::BITCAST, SL, VT, Blend);
-  }
+    Hi = DAG.getNode(ISD::BITCAST, SL, MVT::i16, Hi);
+    Hi = DAG.getNode(ISD::ZERO_EXTEND, SL, MVT::i32, Hi);
 
-  if (VT == MVT::v32i16 || VT == MVT::v32f16 || VT == MVT::v32bf16) {
-    EVT QuarterVT = MVT::getVectorVT(VT.getVectorElementType().getSimpleVT(),
-                                     VT.getVectorNumElements() / 8);
-    MVT QuarterIntVT = MVT::getIntegerVT(QuarterVT.getSizeInBits());
+    SDValue ShlHi = DAG.getNode(ISD::SHL, SL, MVT::i32, Hi,
+                                DAG.getConstant(16, SL, MVT::i32));
+    if (Lo.isUndef())
+      return DAG.getNode(ISD::BITCAST, SL, VT, ShlHi);
 
-    SmallVector<SDValue, 8> Parts[8];
-    for (unsigned I = 0, E = VT.getVectorNumElements() / 8; I != E; ++I) {
-      for (unsigned P = 0; P < 8; ++P)
-        Parts[P].push_back(Op.getOperand(I + P * E));
-    }
-    SDValue Casts[8];
-    for (unsigned P = 0; P < 8; ++P) {
-      SDValue Vec = DAG.getBuildVector(QuarterVT, SL, Parts[P]);
-      Casts[P] = DAG.getNode(ISD::BITCAST, SL, QuarterIntVT, Vec);
-    }
+    Lo = DAG.getNode(ISD::BITCAST, SL, MVT::i16, Lo);
+    Lo = DAG.getNode(ISD::ZERO_EXTEND, SL, MVT::i32, Lo);
 
-    SDValue Blend =
-        DAG.getBuildVector(MVT::getVectorVT(QuarterIntVT, 8), SL, Casts);
-    return DAG.getNode(ISD::BITCAST, SL, VT, Blend);
+    SDValue Or = DAG.getNode(ISD::OR, SL, MVT::i32, Lo, ShlHi);
+    return DAG.getNode(ISD::BITCAST, SL, VT, Or);
   }
 
-  assert(VT == MVT::v2f16 || VT == MVT::v2i16 || VT == MVT::v2bf16);
-  assert(!Subtarget->hasVOP3PInsts() && "this should be legal");
+  // Split into 2-element chunks.
+  const unsigned NumParts = VT.getVectorNumElements() / 2;
+  EVT PartVT = MVT::getVectorVT(VT.getVectorElementType().getSimpleVT(), 2);
+  MVT PartIntVT = MVT::getIntegerVT(PartVT.getSizeInBits());
 
-  SDValue Lo = Op.getOperand(0);
-  SDValue Hi = Op.getOperand(1);
-
-  // Avoid adding defined bits with the zero_extend.
-  if (Hi.isUndef()) {
-    Lo = DAG.getNode(ISD::BITCAST, SL, MVT::i16, Lo);
-    SDValue ExtLo = DAG.getNode(ISD::ANY_EXTEND, SL, MVT::i32, Lo);
-    return DAG.getNode(ISD::BITCAST, SL, VT, ExtLo);
+  SmallVector<SDValue> Casts;
+  for (unsigned P = 0; P < NumParts; ++P) {
+    SDValue Vec = DAG.getBuildVector(
+        PartVT, SL, {Op.getOperand(P * 2), Op.getOperand(P * 2 + 1)});
+    Casts.push_back(DAG.getNode(ISD::BITCAST, SL, PartIntVT, Vec));
   }
 
-  Hi = DAG.getNode(ISD::BITCAST, SL, MVT::i16, Hi);
-  Hi = DAG.getNode(ISD::ZERO_EXTEND, SL, MVT::i32, Hi);
-
-  SDValue ShlHi = DAG.getNode(ISD::SHL, SL, MVT::i32, Hi,
-                              DAG.getConstant(16, SL, MVT::i32));
-  if (Lo.isUndef())
-    return DAG.getNode(ISD::BITCAST, SL, VT, ShlHi);
-
-  Lo = DAG.getNode(ISD::BITCAST, SL, MVT::i16, Lo);
-  Lo = DAG.getNode(ISD::ZERO_EXTEND, SL, MVT::i32, Lo);
-
-  SDValue Or = DAG.getNode(ISD::OR, SL, MVT::i32, Lo, ShlHi);
-  return DAG.getNode(ISD::BITCAST, SL, VT, Or);
+  SDValue Blend =
+      DAG.getBuildVector(MVT::getVectorVT(PartIntVT, NumParts), SL, Casts);
+  return DAG.getNode(ISD::BITCAST, SL, VT, Blend);
 }
 
 bool
diff --git a/llvm/test/CodeGen/AMDGPU/insert_vector_elt.v2bf16.ll b/llvm/test/CodeGen/AMDGPU/insert_vector_elt.v2bf16.ll
index 3135addec16183..c68138acc9b2bf 100644
--- a/llvm/test/CodeGen/AMDGPU/insert_vector_elt.v2bf16.ll
+++ b/llvm/test/CodeGen/AMDGPU/insert_vector_elt.v2bf16.ll
@@ -965,11 +965,7 @@ define amdgpu_kernel void @v_insertelement_v8bf16_3(ptr addrspace(1) %out, ptr a
 ; GFX900-NEXT:    v_mov_b32_e32 v5, 0x5040100
 ; GFX900-NEXT:    s_waitcnt lgkmcnt(0)
 ; GFX900-NEXT:    global_load_dwordx4 v[0:3], v4, s[2:3]
-; GFX900-NEXT:    s_mov_b32 s2, 0xffff
 ; GFX900-NEXT:    s_waitcnt vmcnt(0)
-; GFX900-NEXT:    v_bfi_b32 v3, s2, v3, v3
-; GFX900-NEXT:    v_bfi_b32 v2, s2, v2, v2
-; GFX900-NEXT:    v_bfi_b32 v0, s2, v0, v0
 ; GFX900-NEXT:    v_perm_b32 v1, s4, v1, v5
 ; GFX900-NEXT:    global_store_dwordx4 v4, v[0:3], s[0:1]
 ; GFX900-NEXT:    s_endpgm
@@ -980,14 +976,10 @@ define amdgpu_kernel void @v_insertelement_v8bf16_3(ptr addrspace(1) %out, ptr a
 ; GFX940-NEXT:    s_load_dword s0, s[2:3], 0x10
 ; GFX940-NEXT:    v_and_b32_e32 v0, 0x3ff, v0
 ; GFX940-NEXT:    v_lshlrev_b32_e32 v4, 4, v0
-; GFX940-NEXT:    s_mov_b32 s1, 0xffff
+; GFX940-NEXT:    v_mov_b32_e32 v5, 0x5040100
 ; GFX940-NEXT:    s_waitcnt lgkmcnt(0)
 ; GFX940-NEXT:    global_load_dwordx4 v[0:3], v4, s[6:7]
-; GFX940-NEXT:    v_mov_b32_e32 v5, 0x5040100
 ; GFX940-NEXT:    s_waitcnt vmcnt(0)
-; GFX940-NEXT:    v_bfi_b32 v3, s1, v3, v3
-; GFX940-NEXT:    v_bfi_b32 v2, s1, v2, v2
-; GFX940-NEXT:    v_bfi_b32 v0, s1, v0, v0
 ; GFX940-NEXT:    v_perm_b32 v1, s0, v1, v5
 ; GFX940-NEXT:    global_store_dwordx4 v4, v[0:3], s[4:5] sc0 sc1
 ; GFX940-NEXT:    s_endpgm

VT == MVT::v8i16 || VT == MVT::v8f16 || VT == MVT::v8bf16 ||
VT == MVT::v16i16 || VT == MVT::v16f16 || VT == MVT::v16bf16 ||
VT == MVT::v32i16 || VT == MVT::v32f16 || VT == MVT::v32bf16) {
if (VT == MVT::v2f16 || VT == MVT::v2i16 || VT == MVT::v2bf16) {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This commit just reorders the code, so v2* code goes at the beginning.

@piotrAMD piotrAMD merged commit adf02ae into llvm:main Sep 18, 2024
10 checks passed
tmsri pushed a commit to tmsri/llvm-project that referenced this pull request Sep 19, 2024
Simplify `lowerBUILD_VECTOR` by commoning up the way the vectors
are split.
Also reorder the checks to avoid a long condition inside `if`.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants