Decompose gep of complex type struct to its element type #107848


Open

wants to merge 5 commits into base: main

Conversation

vfdff
Contributor

@vfdff vfdff commented Sep 9, 2024

Similar to PR96606, but this PR addresses the scenario where all indices are zero except the last one, which is not a constant value.

We usually widen the index of a GEP to the same width as the pointer, so the index of a getelementptr is often extended to i64 on AArch64, for example. Vectorization chooses the VL according to the data type, so it may be <vscale x 4 x float> for float.
When the address comes from a struct.std::complex, as in the following IR, the index needs to be multiplied by 2 in codegen, so we can't assume that (<vscale x 4 x i64> %3 * 2) can be held by <vscale x 4 x i32>; therefore the node llvm.masked.gather.nxv4f32.nxv4p0 is split.

```
%4 = getelementptr inbounds [10000 x %"struct.std::complex"], ptr @mdlComplex, i64 0, <vscale x 4 x i64> %3
%wide.masked.gather = tail call <vscale x 4 x float> @llvm.masked.gather.nxv4f32.nxv4p0(<vscale x 4 x ptr> %4, i32 4, <vscale x 4 x i1> %active.lane.mask, <vscale x 4 x float> poison)
```

This PR decomposes a GEP of a complex-type struct into its element type; the getelementptr index then no longer needs the multiply, so the llvm.masked.gather.nxv4f32.nxv4p0 does not need to be split if we know its offset was extended from i32, i.e., it changes (sext a) << 1 into sext (a << 1).

Fixes #107825

@llvmbot
Member

llvmbot commented Sep 9, 2024

@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-llvm-ir

Author: Allen (vfdff)

Changes



Full diff: https://github.com/llvm/llvm-project/pull/107848.diff

4 Files Affected:

  • (modified) llvm/include/llvm/IR/Instructions.h (+2)
  • (modified) llvm/lib/IR/Instructions.cpp (+11)
  • (modified) llvm/lib/Transforms/InstCombine/InstructionCombining.cpp (+32)
  • (modified) llvm/test/Transforms/LoopVectorize/reduction-inloop.ll (+14-10)
diff --git a/llvm/include/llvm/IR/Instructions.h b/llvm/include/llvm/IR/Instructions.h
index ab3321ee755717..c5a2e2cb442d29 100644
--- a/llvm/include/llvm/IR/Instructions.h
+++ b/llvm/include/llvm/IR/Instructions.h
@@ -1078,6 +1078,8 @@ class GetElementPtrInst : public Instruction {
   /// a constant offset between them.
   bool hasAllConstantIndices() const;
 
+  bool hasAllZeroIndicesExceptLast() const;
+
   /// Set nowrap flags for GEP instruction.
   void setNoWrapFlags(GEPNoWrapFlags NW);
 
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 19da1f60d424d2..8363ff2c070d69 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -1555,6 +1555,17 @@ bool GetElementPtrInst::hasAllConstantIndices() const {
   return true;
 }
 
+/// hasAllZeroIndicesExceptLast - Return true if all of the indices of this GEP
+/// are zero except the last index.
+bool GetElementPtrInst::hasAllZeroIndicesExceptLast() const {
+  for (unsigned i = 1, e = getNumOperands() - 1; i != e; ++i) {
+    if (!isa<ConstantInt>(getOperand(i)) ||
+        !cast<ConstantInt>(getOperand(i))->isZero())
+      return false;
+  }
+  return true;
+}
+
 void GetElementPtrInst::setNoWrapFlags(GEPNoWrapFlags NW) {
   SubclassOptionalData = NW.getRaw();
 }
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 8195e0539305cc..1de6d189ea7ae0 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -2805,6 +2805,38 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) {
                                     GEP.getNoWrapFlags()));
   }
 
+  // For complex type: %"struct.std::complex" = type { { float, float } }
+  // Canonicalize
+  //  - %idxprom = sext i32 %Off to i64
+  //  - inbounds [100 x %"struct.std::complex"], ptr @p, i64 0, i64 %idx
+  // into
+  //  - %idxprom.scale = shl nsw i32 %Off, 1
+  //  - %1 = sext i32 %idxprom.scale to i64
+  //  - getelementptr inbounds float, ptr @p, i64 %1
+  auto *GepResElTy = GEP.getResultElementType();
+  if (GepResElTy->isStructTy() && GepResElTy->getStructNumElements() == 1)
+    GepResElTy = GepResElTy->getStructElementType(0);
+  if (GepResElTy->isStructTy() && GepResElTy->getStructNumElements() == 2 &&
+      GepResElTy->getStructElementType(0) ==
+          GepResElTy->getStructElementType(1) &&
+      GEP.hasAllZeroIndicesExceptLast()) {
+    unsigned LastIndice = GEP.getNumOperands() - 1;
+    Value *LastOp = GEP.getOperand(LastIndice);
+    if (auto *SExtI = dyn_cast<SExtInst>(LastOp)) {
+      GEPOperator *GEPOp = cast<GEPOperator>(&GEP);
+      bool NSW = GEPOp->hasNoUnsignedSignedWrap();
+      bool NUW = GEPOp->hasNoUnsignedWrap();
+      // We'll let instcombine(mul) convert this to a shl if possible.
+      auto IntTy = SExtI->getOperand(0)->getType();
+      Value *Offset =
+          Builder.CreateMul(SExtI->getOperand(0), ConstantInt::get(IntTy, 2),
+                            SExtI->getName() + ".scale", NUW, NSW);
+      return replaceInstUsesWith(
+          GEP, Builder.CreateGEP(GepResElTy->getStructElementType(0), PtrOp,
+                                 Offset, "", GEP.getNoWrapFlags()));
+    }
+  }
+
   // Canonicalize
   //  - scalable GEPs to an explicit offset using the llvm.vscale intrinsic.
   //    This has better support in BasicAA.
diff --git a/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll b/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
index 7c5d6a1edf0b4b..e55a7931818939 100644
--- a/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
+++ b/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
@@ -1316,8 +1316,9 @@ define i32 @predicated_or_dominates_reduction(ptr %b) {
 ; CHECK-NEXT:    [[TMP20:%.*]] = extractelement <4 x i1> [[TMP19]], i64 0
 ; CHECK-NEXT:    br i1 [[TMP20]], label [[PRED_LOAD_IF:%.*]], label [[PRED_LOAD_CONTINUE:%.*]]
 ; CHECK:       pred.load.if:
-; CHECK-NEXT:    [[TMP21:%.*]] = sext i32 [[INDEX]] to i64
-; CHECK-NEXT:    [[TMP22:%.*]] = getelementptr inbounds [0 x %struct.e], ptr [[B]], i64 0, i64 [[TMP21]]
+; CHECK-NEXT:    [[DOTSCALE:%.*]] = shl nsw i32 [[INDEX]], 1
+; CHECK-NEXT:    [[TMP21:%.*]] = sext i32 [[DOTSCALE]] to i64
+; CHECK-NEXT:    [[TMP22:%.*]] = getelementptr inbounds i32, ptr [[B]], i64 [[TMP21]]
 ; CHECK-NEXT:    [[TMP23:%.*]] = load i32, ptr [[TMP22]], align 4
 ; CHECK-NEXT:    [[TMP24:%.*]] = insertelement <4 x i32> poison, i32 [[TMP23]], i64 0
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE]]
@@ -1326,8 +1327,9 @@ define i32 @predicated_or_dominates_reduction(ptr %b) {
 ; CHECK-NEXT:    [[TMP26:%.*]] = extractelement <4 x i1> [[TMP19]], i64 1
 ; CHECK-NEXT:    br i1 [[TMP26]], label [[PRED_LOAD_IF1:%.*]], label [[PRED_LOAD_CONTINUE2:%.*]]
 ; CHECK:       pred.load.if1:
-; CHECK-NEXT:    [[TMP27:%.*]] = sext i32 [[TMP0]] to i64
-; CHECK-NEXT:    [[TMP28:%.*]] = getelementptr inbounds [0 x %struct.e], ptr [[B]], i64 0, i64 [[TMP27]]
+; CHECK-NEXT:    [[DOTSCALE7:%.*]] = shl nsw i32 [[TMP0]], 1
+; CHECK-NEXT:    [[TMP27:%.*]] = sext i32 [[DOTSCALE7]] to i64
+; CHECK-NEXT:    [[TMP28:%.*]] = getelementptr inbounds i32, ptr [[B]], i64 [[TMP27]]
 ; CHECK-NEXT:    [[TMP29:%.*]] = load i32, ptr [[TMP28]], align 4
 ; CHECK-NEXT:    [[TMP30:%.*]] = insertelement <4 x i32> [[TMP25]], i32 [[TMP29]], i64 1
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE2]]
@@ -1336,8 +1338,9 @@ define i32 @predicated_or_dominates_reduction(ptr %b) {
 ; CHECK-NEXT:    [[TMP32:%.*]] = extractelement <4 x i1> [[TMP19]], i64 2
 ; CHECK-NEXT:    br i1 [[TMP32]], label [[PRED_LOAD_IF3:%.*]], label [[PRED_LOAD_CONTINUE4:%.*]]
 ; CHECK:       pred.load.if3:
-; CHECK-NEXT:    [[TMP33:%.*]] = sext i32 [[TMP1]] to i64
-; CHECK-NEXT:    [[TMP34:%.*]] = getelementptr inbounds [0 x %struct.e], ptr [[B]], i64 0, i64 [[TMP33]]
+; CHECK-NEXT:    [[DOTSCALE8:%.*]] = shl nsw i32 [[TMP1]], 1
+; CHECK-NEXT:    [[TMP33:%.*]] = sext i32 [[DOTSCALE8]] to i64
+; CHECK-NEXT:    [[TMP34:%.*]] = getelementptr inbounds i32, ptr [[B]], i64 [[TMP33]]
 ; CHECK-NEXT:    [[TMP35:%.*]] = load i32, ptr [[TMP34]], align 4
 ; CHECK-NEXT:    [[TMP36:%.*]] = insertelement <4 x i32> [[TMP31]], i32 [[TMP35]], i64 2
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE4]]
@@ -1346,8 +1349,9 @@ define i32 @predicated_or_dominates_reduction(ptr %b) {
 ; CHECK-NEXT:    [[TMP38:%.*]] = extractelement <4 x i1> [[TMP19]], i64 3
 ; CHECK-NEXT:    br i1 [[TMP38]], label [[PRED_LOAD_IF5:%.*]], label [[PRED_LOAD_CONTINUE6]]
 ; CHECK:       pred.load.if5:
-; CHECK-NEXT:    [[TMP39:%.*]] = sext i32 [[TMP2]] to i64
-; CHECK-NEXT:    [[TMP40:%.*]] = getelementptr inbounds [0 x %struct.e], ptr [[B]], i64 0, i64 [[TMP39]]
+; CHECK-NEXT:    [[DOTSCALE9:%.*]] = shl nsw i32 [[TMP2]], 1
+; CHECK-NEXT:    [[TMP39:%.*]] = sext i32 [[DOTSCALE9]] to i64
+; CHECK-NEXT:    [[TMP40:%.*]] = getelementptr inbounds i32, ptr [[B]], i64 [[TMP39]]
 ; CHECK-NEXT:    [[TMP41:%.*]] = load i32, ptr [[TMP40]], align 4
 ; CHECK-NEXT:    [[TMP42:%.*]] = insertelement <4 x i32> [[TMP37]], i32 [[TMP41]], i64 3
 ; CHECK-NEXT:    br label [[PRED_LOAD_CONTINUE6]]
@@ -1355,8 +1359,8 @@ define i32 @predicated_or_dominates_reduction(ptr %b) {
 ; CHECK-NEXT:    [[TMP43:%.*]] = phi <4 x i32> [ [[TMP37]], [[PRED_LOAD_CONTINUE4]] ], [ [[TMP42]], [[PRED_LOAD_IF5]] ]
 ; CHECK-NEXT:    [[TMP44:%.*]] = icmp ne <4 x i32> [[TMP43]], zeroinitializer
 ; CHECK-NEXT:    [[NOT_:%.*]] = xor <4 x i1> [[TMP19]], <i1 true, i1 true, i1 true, i1 true>
-; CHECK-NEXT:    [[DOTNOT7:%.*]] = select <4 x i1> [[NOT_]], <4 x i1> <i1 true, i1 true, i1 true, i1 true>, <4 x i1> [[TMP44]]
-; CHECK-NEXT:    [[TMP45:%.*]] = bitcast <4 x i1> [[DOTNOT7]] to i4
+; CHECK-NEXT:    [[DOTNOT10:%.*]] = select <4 x i1> [[NOT_]], <4 x i1> <i1 true, i1 true, i1 true, i1 true>, <4 x i1> [[TMP44]]
+; CHECK-NEXT:    [[TMP45:%.*]] = bitcast <4 x i1> [[DOTNOT10]] to i4
 ; CHECK-NEXT:    [[TMP46:%.*]] = call range(i4 0, 5) i4 @llvm.ctpop.i4(i4 [[TMP45]])
 ; CHECK-NEXT:    [[TMP47:%.*]] = zext nneg i4 [[TMP46]] to i32
 ; CHECK-NEXT:    [[TMP48]] = add i32 [[VEC_PHI]], [[TMP47]]

@vfdff vfdff requested a review from dtcxzyw September 12, 2024 01:18
@dtcxzyw dtcxzyw changed the title Decompos gep of complex type struct to its element type Decompose gep of complex type struct to its element type Sep 12, 2024
@vfdff
Contributor Author

vfdff commented Sep 14, 2024

ping :)

dtcxzyw added a commit to dtcxzyw/llvm-opt-benchmark that referenced this pull request Sep 14, 2024
@vfdff vfdff requested a review from arsenm September 18, 2024 01:17
Comment on lines 1562 to 1563
if (!isa<ConstantInt>(getOperand(i)) ||
!cast<ConstantInt>(getOperand(i))->isZero())
Contributor

dyn_cast

Contributor Author

Applied your suggestion, thanks.

@@ -1078,6 +1078,8 @@ class GetElementPtrInst : public Instruction {
/// a constant offset between them.
bool hasAllConstantIndices() const;

bool hasAllZeroIndicesExceptLast() const;
Contributor

Needs doc string

Contributor Author

Done

/// are zero except the last index.
bool GetElementPtrInst::hasAllZeroIndicesExceptLast() const {
for (unsigned i = 1, e = getNumOperands() - 1; i != e; ++i) {
ConstantInt *Val = dyn_cast<ConstantInt>(getOperand(i));
Contributor

Should this handle the vector case too?

Contributor Author

OK, I also added a vector case.

Contributor

Your vector test is not a vector test. This code will not work for a vector-of-pointers typed getelementptr.

Contributor Author

Do you mean something like the following?

```
define <4 x ptr> @decompose_complex_vector1(ptr %array, <4 x ptr> %baseptrs) {
  %val = load <4 x i32>, ptr %array, align 4
  %sextVal = sext <4 x i32> %val to <4 x i64>
  %arrayidx = getelementptr inbounds [10000 x %"class.std::__1::complex"], <4 x ptr> %baseptrs, <4 x i64> zeroinitializer, <4 x i64> %sextVal
  ret <4 x ptr> %arrayidx
}
```

Contributor

Yes

Contributor Author

The vector case is supported now, thanks.

Contributor

@nikic nikic left a comment

This patch doesn't make sense to me. It partially decomposes the GEP, so that we end up with both an explicit multiplication in the offset and a GEP scale multiply. We should have only one or the other.

On top of that, this is driven by a GEP source element type based heuristic, which is not allowed.

@vfdff
Contributor Author

vfdff commented Sep 24, 2024

Hi nikic, thanks for your comment.

  1. For the first point, can I add a OneUse check to avoid the issue you mentioned?
  2. For the second point, it seems similar to pr96606; do you mean the transform should be driven by GEP.getSourceElementType() rather than GEP.getResultElementType()?

@nikic
Contributor

nikic commented Sep 25, 2024

hi, nikic, thanks for your comment.

1. For the 1st point, can I add  **OneUse** to avoid that issue you mentioned?

2. For the 2nd point, It seems similar to pr96606, so do you mean that we should also driven by **GEP.getSourceElementType()**, but not the **GEP.getResultElementType()** ?

If you're referencing PRs, please use a proper link and don't make me reassemble the URL by hand: #96606

That PR does something very reasonable, which is to convert the GEP into ptradd representation if this allows merging some of the offset arithmetic. It leaves behind only explicit offset arithmetic, not a mix of offset arithmetic and GEP scaling. Basically, the end result for your case would be an i8 GEP with sext (a<<3) rather than a float GEP with sext (a << 1).

@vfdff
Contributor Author

vfdff commented Sep 25, 2024

Thank you for your advice, so I need to convert to an i8 GEP directly here too?
I tried converting directly to an i8 GEP before, but I saw some test-case regressions; I'll look into why.

Offset, "", GEP.getNoWrapFlags()));
unsigned EleTypeBytes =
GepResElTy->getStructElementType(0)->getScalarSizeInBits() / 8;
if (EleTypeBytes > 0) {
Contributor

Missing test with { i1, i1 }?


Successfully merging this pull request may close these issues.

[AArch64] Decompose gep of complex type struct for better codegen
5 participants