
AMDGPU: Convert vector 64-bit shl to 32-bit if shift amt >= 32 #132964

Merged
merged 2 commits into llvm:main
Mar 28, 2025

Conversation


@LU-JOHN (Contributor) commented on Mar 25, 2025

Convert vector 64-bit shl to 32-bit if shift amt is known to be >= 32.
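For illustration only (this sketch is mine, not part of the patch): when the shift amount of a 64-bit shl is known to be at least 32, the low 32 bits of the result are always zero, so a single 32-bit shift of the low half suffices. In LLVM IR terms, with hypothetical `%x` and `%amt`:

```llvm
; Sketch: assumes %amt is known to lie in [32, 63], e.g. via !range
; metadata on the load that produced it.
define i64 @shl64_reduced_sketch(i64 %x, i64 %amt) {
  %lo = trunc i64 %x to i32      ; only the low half can reach the result
  %amt32 = trunc i64 %amt to i32
  %amt.hi = sub i32 %amt32, 32   ; shift distance within the high half
  %hi = shl i32 %lo, %amt.hi     ; the single 32-bit shift
  %hi64 = zext i32 %hi to i64
  %res = shl i64 %hi64, 32       ; place into the high half; low half is 0
  ret i64 %res
}
```

At the DAG level the combine materializes this as (build_pair 0, (shl lo, amt - 32)) rather than the zext/shl pair above, but the arithmetic is the same.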


@llvmbot (Member) commented on Mar 25, 2025

@llvm/pr-subscribers-backend-amdgpu

Author: None (LU-JOHN)

Changes

Convert vector 64-bit shl to 32-bit if shift amt is known to be >= 32.


Full diff: https://github.com/llvm/llvm-project/pull/132964.diff

2 Files Affected:

  • (modified) llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp (+32-12)
  • (modified) llvm/test/CodeGen/AMDGPU/shl64_reduce.ll (+20-14)
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
index dbbd67cea27e5..22621bea2a40f 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -4084,7 +4084,7 @@ SDValue AMDGPUTargetLowering::performShlCombine(SDNode *N,
     }
   }
 
-  if (VT != MVT::i64)
+  if (VT.getScalarType() != MVT::i64)
     return SDValue();
 
   // i64 (shl x, C) -> (build_pair 0, (shl x, C -32))
@@ -4092,21 +4092,24 @@ SDValue AMDGPUTargetLowering::performShlCombine(SDNode *N,
   // On some subtargets, 64-bit shift is a quarter rate instruction. In the
   // common case, splitting this into a move and a 32-bit shift is faster and
   // the same code size.
-  EVT TargetType = VT.getHalfSizedIntegerVT(*DAG.getContext());
-  EVT TargetVecPairType = EVT::getVectorVT(*DAG.getContext(), TargetType, 2);
   KnownBits Known = DAG.computeKnownBits(RHS);
 
-  if (Known.getMinValue().getZExtValue() < TargetType.getSizeInBits())
+  EVT ElementType = VT.getScalarType();
+  EVT TargetScalarType = ElementType.getHalfSizedIntegerVT(*DAG.getContext());
+  EVT TargetType = (VT.isVector() ? VT.changeVectorElementType(TargetScalarType)
+                                  : TargetScalarType);
+
+  if (Known.getMinValue().getZExtValue() < TargetScalarType.getSizeInBits())
     return SDValue();
   SDValue ShiftAmt;
 
   if (CRHS) {
-    ShiftAmt =
-        DAG.getConstant(RHSVal - TargetType.getSizeInBits(), SL, TargetType);
+    ShiftAmt = DAG.getConstant(RHSVal - TargetScalarType.getSizeInBits(), SL,
+                               TargetType);
   } else {
     SDValue truncShiftAmt = DAG.getNode(ISD::TRUNCATE, SL, TargetType, RHS);
     const SDValue ShiftMask =
-        DAG.getConstant(TargetType.getSizeInBits() - 1, SL, TargetType);
+        DAG.getConstant(TargetScalarType.getSizeInBits() - 1, SL, TargetType);
     // This AND instruction will clamp out of bounds shift values.
     // It will also be removed during later instruction selection.
     ShiftAmt = DAG.getNode(ISD::AND, SL, TargetType, truncShiftAmt, ShiftMask);
@@ -4116,9 +4119,21 @@ SDValue AMDGPUTargetLowering::performShlCombine(SDNode *N,
   SDValue NewShift =
       DAG.getNode(ISD::SHL, SL, TargetType, Lo, ShiftAmt, N->getFlags());
 
-  const SDValue Zero = DAG.getConstant(0, SL, TargetType);
-
-  SDValue Vec = DAG.getBuildVector(TargetVecPairType, SL, {Zero, NewShift});
+  const SDValue Zero = DAG.getConstant(0, SL, TargetScalarType);
+  SDValue Vec;
+
+  if (VT.isVector()) {
+    EVT ConcatType = TargetType.getDoubleNumVectorElementsVT(*DAG.getContext());
+    unsigned NElts = TargetType.getVectorNumElements();
+    SmallVector<SDValue, 8> Ops;
+    DAG.ExtractVectorElements(NewShift, Ops, 0, NElts);
+    for (unsigned I = 0; I != NElts; ++I)
+      Ops.insert(Ops.begin() + 2 * I, Zero);
+    Vec = DAG.getNode(ISD::BUILD_VECTOR, SL, ConcatType, Ops);
+  } else {
+    EVT ConcatType = EVT::getVectorVT(*DAG.getContext(), TargetType, 2);
+    Vec = DAG.getBuildVector(ConcatType, SL, {Zero, NewShift});
+  }
   return DAG.getNode(ISD::BITCAST, SL, VT, Vec);
 }
 
@@ -5182,9 +5197,14 @@ SDValue AMDGPUTargetLowering::PerformDAGCombine(SDNode *N,
     break;
   }
   case ISD::SHL: {
-    if (DCI.getDAGCombineLevel() < AfterLegalizeDAG)
+    // Range metadata can be invalidated when loads are converted to legal types
+    // (e.g. v2i64 -> v4i32).
+    // Try to convert vector shl before type legalization so that range metadata
+    // can be utilized.
+    if (!(N->getValueType(0).isVector() &&
+          DCI.getDAGCombineLevel() == BeforeLegalizeTypes) &&
+        DCI.getDAGCombineLevel() < AfterLegalizeDAG)
       break;
-
     return performShlCombine(N, DCI);
   }
   case ISD::SRL: {
diff --git a/llvm/test/CodeGen/AMDGPU/shl64_reduce.ll b/llvm/test/CodeGen/AMDGPU/shl64_reduce.ll
index 69242f4e44840..21b7ed4d6b779 100644
--- a/llvm/test/CodeGen/AMDGPU/shl64_reduce.ll
+++ b/llvm/test/CodeGen/AMDGPU/shl64_reduce.ll
@@ -72,39 +72,41 @@ define i64 @shl_metadata_cant_be_narrowed_to_i32(i64 %arg0, ptr %arg1.ptr) {
   ret i64 %shl
 }
 
-; FIXME: This case should be reduced
 define <2 x i64> @shl_v2_metadata(<2 x i64> %arg0, ptr %arg1.ptr) {
 ; CHECK-LABEL: shl_v2_metadata:
 ; CHECK:       ; %bb.0:
 ; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
-; CHECK-NEXT:    flat_load_dwordx4 v[4:7], v[4:5]
+; CHECK-NEXT:    flat_load_dwordx4 v[3:6], v[4:5]
 ; CHECK-NEXT:    s_waitcnt vmcnt(0) lgkmcnt(0)
-; CHECK-NEXT:    v_lshlrev_b64 v[0:1], v4, v[0:1]
-; CHECK-NEXT:    v_lshlrev_b64 v[2:3], v6, v[2:3]
+; CHECK-NEXT:    v_lshlrev_b32_e32 v1, v3, v0
+; CHECK-NEXT:    v_lshlrev_b32_e32 v3, v5, v2
+; CHECK-NEXT:    v_mov_b32_e32 v0, 0
+; CHECK-NEXT:    v_mov_b32_e32 v2, 0
 ; CHECK-NEXT:    s_setpc_b64 s[30:31]
   %shift.amt = load <2 x i64>, ptr %arg1.ptr, !range !0, !noundef !{}
   %shl = shl <2 x i64> %arg0, %shift.amt
   ret <2 x i64> %shl
 }
 
-; FIXME: This case should be reduced
 define <3 x i64> @shl_v3_metadata(<3 x i64> %arg0, ptr %arg1.ptr) {
 ; CHECK-LABEL: shl_v3_metadata:
 ; CHECK:       ; %bb.0:
 ; CHECK-NEXT:    s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
-; CHECK-NEXT:    flat_load_dword v12, v[6:7] offset:16
+; CHECK-NEXT:    flat_load_dword v1, v[6:7] offset:16
 ; CHECK-NEXT:    flat_load_dwordx4 v[8:11], v[6:7]
 ; CHECK-NEXT:    s_waitcnt vmcnt(0) lgkmcnt(0)
-; CHECK-NEXT:    v_lshlrev_b64 v[4:5], v12, v[4:5]
-; CHECK-NEXT:    v_lshlrev_b64 v[0:1], v8, v[0:1]
-; CHECK-NEXT:    v_lshlrev_b64 v[2:3], v10, v[2:3]
+; CHECK-NEXT:    v_lshlrev_b32_e32 v5, v1, v4
+; CHECK-NEXT:    v_lshlrev_b32_e32 v1, v8, v0
+; CHECK-NEXT:    v_lshlrev_b32_e32 v3, v10, v2
+; CHECK-NEXT:    v_mov_b32_e32 v0, 0
+; CHECK-NEXT:    v_mov_b32_e32 v2, 0
+; CHECK-NEXT:    v_mov_b32_e32 v4, 0
 ; CHECK-NEXT:    s_setpc_b64 s[30:31]
   %shift.amt = load <3 x i64>, ptr %arg1.ptr, !range !0, !noundef !{}
   %shl = shl <3 x i64> %arg0, %shift.amt
   ret <3 x i64> %shl
 }
 
-; FIXME: This case should be reduced
 define <4 x i64> @shl_v4_metadata(<4 x i64> %arg0, ptr %arg1.ptr) {
 ; CHECK-LABEL: shl_v4_metadata:
 ; CHECK:       ; %bb.0:
@@ -113,11 +115,15 @@ define <4 x i64> @shl_v4_metadata(<4 x i64> %arg0, ptr %arg1.ptr) {
 ; CHECK-NEXT:    s_waitcnt vmcnt(0) lgkmcnt(0)
 ; CHECK-NEXT:    flat_load_dwordx4 v[13:16], v[8:9] offset:16
 ; CHECK-NEXT:    ; kill: killed $vgpr8 killed $vgpr9
-; CHECK-NEXT:    v_lshlrev_b64 v[0:1], v10, v[0:1]
-; CHECK-NEXT:    v_lshlrev_b64 v[2:3], v12, v[2:3]
+; CHECK-NEXT:    v_lshlrev_b32_e32 v1, v10, v0
+; CHECK-NEXT:    v_lshlrev_b32_e32 v3, v12, v2
 ; CHECK-NEXT:    s_waitcnt vmcnt(0) lgkmcnt(0)
-; CHECK-NEXT:    v_lshlrev_b64 v[4:5], v13, v[4:5]
-; CHECK-NEXT:    v_lshlrev_b64 v[6:7], v15, v[6:7]
+; CHECK-NEXT:    v_lshlrev_b32_e32 v5, v13, v4
+; CHECK-NEXT:    v_lshlrev_b32_e32 v7, v15, v6
+; CHECK-NEXT:    v_mov_b32_e32 v0, 0
+; CHECK-NEXT:    v_mov_b32_e32 v2, 0
+; CHECK-NEXT:    v_mov_b32_e32 v4, 0
+; CHECK-NEXT:    v_mov_b32_e32 v6, 0
 ; CHECK-NEXT:    s_setpc_b64 s[30:31]
   %shift.amt = load <4 x i64>, ptr %arg1.ptr, !range !0, !noundef !{}
   %shl = shl <4 x i64> %arg0, %shift.amt

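The vector path added by this PR does the same per lane and then interleaves zeros into the even (low) half-lanes before bitcasting back, matching the BUILD_VECTOR construction in the C++ above. A hand-written <2 x i64> sketch of that equivalence (hypothetical names, little-endian lane layout assumed):

```llvm
; Sketch: assumes every lane of %amt is known to be in [32, 63].
define <2 x i64> @shl_v2i64_reduced_sketch(<2 x i64> %x, <2 x i64> %amt) {
  %lo = trunc <2 x i64> %x to <2 x i32>
  %amt32 = trunc <2 x i64> %amt to <2 x i32>
  %amt.hi = sub <2 x i32> %amt32, <i32 32, i32 32>
  %hi = shl <2 x i32> %lo, %amt.hi   ; one 32-bit shift per lane
  ; Interleave zeros into the even lanes, giving <0, hi0, 0, hi1>,
  ; as the combine's Ops.insert(Ops.begin() + 2 * I, Zero) loop does.
  %vec = shufflevector <2 x i32> zeroinitializer, <2 x i32> %hi,
                       <4 x i32> <i32 0, i32 2, i32 1, i32 3>
  %res = bitcast <4 x i32> %vec to <2 x i64>
  ret <2 x i64> %res
}
```

In the tests above, the "known >= 32" fact comes from !range metadata on the load of the shift amount; that is why the combine now also runs before type legalization, while the metadata is still attached to the original 64-bit vector load (legalizing, e.g., v2i64 to v4i32 would drop it).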
@LU-JOHN changed the title from "Convert vector 64-bit shl to 32-bit if shift amt >= 32" to "AMDGPU: Convert vector 64-bit shl to 32-bit if shift amt >= 32" on Mar 25, 2025
@arsenm merged commit 827f2ad into llvm:main on Mar 28, 2025
11 checks passed