Decompose gep of complex type struct to its element type #107848

Conversation
Similar to PR #96606, but this PR addresses the scenario where all of the GEP's indices are zero except the last one, which is not a constant value. We usually widen the index of a GEP to the same width as the pointer, so the index of a getelementptr is often extended to i64 on AArch64, for example. Vectorization chooses the VL according to the data type, so it may be <vscale x 4 x float> for float. When the address comes from a struct.std::complex, as in the following IR, codegen needs to multiply the index by 2, so we can't assume that (<vscale x 4 x i64> %3 * 2) can be held in <vscale x 4 x i32>, and the llvm.masked.gather.nxv4f32.nxv4p0 node gets split.

```
%4 = getelementptr inbounds [10000 x %"struct.std::complex"], ptr @mdlComplex, i64 0, <vscale x 4 x i64> %3
%wide.masked.gather = tail call <vscale x 4 x float> @llvm.masked.gather.nxv4f32.nxv4p0(<vscale x 4 x ptr> %4, i32 4, <vscale x 4 x i1> %active.lane.mask, <vscale x 4 x float> poison)
```

This PR decomposes a GEP of a complex struct type to its element type, so the GEP index no longer needs the multiply and the llvm.masked.gather.nxv4f32.nxv4p0 doesn't need to be split when we know its offset was extended from i32; i.e., it changes (sext a) << 1 into sext (a << 1).

Fixes llvm#107825
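Concretely, a minimal sketch of the intended post-transform IR for the motivating gather, assuming the i64 index vector %3 was originally sign-extended from an i32 vector (the %idx32, %scaled, and %ext names are hypothetical):

```
%scaled = shl nsw <vscale x 4 x i32> %idx32, splat (i32 1)      ; index * 2, done in i32
%ext = sext <vscale x 4 x i32> %scaled to <vscale x 4 x i64>
%4 = getelementptr inbounds float, ptr @mdlComplex, <vscale x 4 x i64> %ext
%wide.masked.gather = tail call <vscale x 4 x float> @llvm.masked.gather.nxv4f32.nxv4p0(<vscale x 4 x ptr> %4, i32 4, <vscale x 4 x i1> %active.lane.mask, <vscale x 4 x float> poison)
```

With the shift done in i32 before the extension, the per-lane offset is known to fit in 32 bits and the gather can stay whole.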
@llvm/pr-subscribers-llvm-transforms @llvm/pr-subscribers-llvm-ir

Author: Allen (vfdff)

Full diff: https://github.com/llvm/llvm-project/pull/107848.diff — 4 files affected:
diff --git a/llvm/include/llvm/IR/Instructions.h b/llvm/include/llvm/IR/Instructions.h
index ab3321ee755717..c5a2e2cb442d29 100644
--- a/llvm/include/llvm/IR/Instructions.h
+++ b/llvm/include/llvm/IR/Instructions.h
@@ -1078,6 +1078,8 @@ class GetElementPtrInst : public Instruction {
/// a constant offset between them.
bool hasAllConstantIndices() const;
+ bool hasAllZeroIndicesExceptLast() const;
+
/// Set nowrap flags for GEP instruction.
void setNoWrapFlags(GEPNoWrapFlags NW);
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 19da1f60d424d2..8363ff2c070d69 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -1555,6 +1555,17 @@ bool GetElementPtrInst::hasAllConstantIndices() const {
return true;
}
+/// hasAllZeroIndicesExceptLast - Return true if all of the indices of this GEP
+/// are zero except the last indice.
+bool GetElementPtrInst::hasAllZeroIndicesExceptLast() const {
+ for (unsigned i = 1, e = getNumOperands() - 1; i != e; ++i) {
+ if (!isa<ConstantInt>(getOperand(i)) ||
+ !cast<ConstantInt>(getOperand(i))->isZero())
+ return false;
+ }
+ return true;
+}
+
void GetElementPtrInst::setNoWrapFlags(GEPNoWrapFlags NW) {
SubclassOptionalData = NW.getRaw();
}
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 8195e0539305cc..1de6d189ea7ae0 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -2805,6 +2805,38 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) {
GEP.getNoWrapFlags()));
}
+ // For complex type: %"struct.std::complex" = type { { float, float } }
+ // Canonicalize
+ // - %idxprom = sext i32 %Off to i64
+ // - inbounds [100 x %"struct.std::complex"], ptr @p, i64 0, i64 %idx
+ // into
+ // - %idxprom.scale = shl nsw i32 %Off, 1
+ // - %1 = sext i32 %idxprom.scale to i64
+ // - getelementptr inbounds float, ptr @p, i64 %1
+ auto *GepResElTy = GEP.getResultElementType();
+ if (GepResElTy->isStructTy() && GepResElTy->getStructNumElements() == 1)
+ GepResElTy = GepResElTy->getStructElementType(0);
+ if (GepResElTy->isStructTy() && GepResElTy->getStructNumElements() == 2 &&
+ GepResElTy->getStructElementType(0) ==
+ GepResElTy->getStructElementType(1) &&
+ GEP.hasAllZeroIndicesExceptLast()) {
+ unsigned LastIndice = GEP.getNumOperands() - 1;
+ Value *LastOp = GEP.getOperand(LastIndice);
+ if (auto *SExtI = dyn_cast<SExtInst>(LastOp)) {
+ GEPOperator *GEPOp = cast<GEPOperator>(&GEP);
+ bool NSW = GEPOp->hasNoUnsignedSignedWrap();
+ bool NUW = GEPOp->hasNoUnsignedWrap();
+ // We'll let instcombine(mul) convert this to a shl if possible.
+ auto IntTy = SExtI->getOperand(0)->getType();
+ Value *Offset =
+ Builder.CreateMul(SExtI->getOperand(0), ConstantInt::get(IntTy, 2),
+ SExtI->getName() + ".scale", NUW, NSW);
+ return replaceInstUsesWith(
+ GEP, Builder.CreateGEP(GepResElTy->getStructElementType(0), PtrOp,
+ Offset, "", GEP.getNoWrapFlags()));
+ }
+ }
+
// Canonicalize
// - scalable GEPs to an explicit offset using the llvm.vscale intrinsic.
// This has better support in BasicAA.
diff --git a/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll b/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
index 7c5d6a1edf0b4b..e55a7931818939 100644
--- a/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
+++ b/llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
@@ -1316,8 +1316,9 @@ define i32 @predicated_or_dominates_reduction(ptr %b) {
; CHECK-NEXT: [[TMP20:%.*]] = extractelement <4 x i1> [[TMP19]], i64 0
; CHECK-NEXT: br i1 [[TMP20]], label [[PRED_LOAD_IF:%.*]], label [[PRED_LOAD_CONTINUE:%.*]]
; CHECK: pred.load.if:
-; CHECK-NEXT: [[TMP21:%.*]] = sext i32 [[INDEX]] to i64
-; CHECK-NEXT: [[TMP22:%.*]] = getelementptr inbounds [0 x %struct.e], ptr [[B]], i64 0, i64 [[TMP21]]
+; CHECK-NEXT: [[DOTSCALE:%.*]] = shl nsw i32 [[INDEX]], 1
+; CHECK-NEXT: [[TMP21:%.*]] = sext i32 [[DOTSCALE]] to i64
+; CHECK-NEXT: [[TMP22:%.*]] = getelementptr inbounds i32, ptr [[B]], i64 [[TMP21]]
; CHECK-NEXT: [[TMP23:%.*]] = load i32, ptr [[TMP22]], align 4
; CHECK-NEXT: [[TMP24:%.*]] = insertelement <4 x i32> poison, i32 [[TMP23]], i64 0
; CHECK-NEXT: br label [[PRED_LOAD_CONTINUE]]
@@ -1326,8 +1327,9 @@ define i32 @predicated_or_dominates_reduction(ptr %b) {
; CHECK-NEXT: [[TMP26:%.*]] = extractelement <4 x i1> [[TMP19]], i64 1
; CHECK-NEXT: br i1 [[TMP26]], label [[PRED_LOAD_IF1:%.*]], label [[PRED_LOAD_CONTINUE2:%.*]]
; CHECK: pred.load.if1:
-; CHECK-NEXT: [[TMP27:%.*]] = sext i32 [[TMP0]] to i64
-; CHECK-NEXT: [[TMP28:%.*]] = getelementptr inbounds [0 x %struct.e], ptr [[B]], i64 0, i64 [[TMP27]]
+; CHECK-NEXT: [[DOTSCALE7:%.*]] = shl nsw i32 [[TMP0]], 1
+; CHECK-NEXT: [[TMP27:%.*]] = sext i32 [[DOTSCALE7]] to i64
+; CHECK-NEXT: [[TMP28:%.*]] = getelementptr inbounds i32, ptr [[B]], i64 [[TMP27]]
; CHECK-NEXT: [[TMP29:%.*]] = load i32, ptr [[TMP28]], align 4
; CHECK-NEXT: [[TMP30:%.*]] = insertelement <4 x i32> [[TMP25]], i32 [[TMP29]], i64 1
; CHECK-NEXT: br label [[PRED_LOAD_CONTINUE2]]
@@ -1336,8 +1338,9 @@ define i32 @predicated_or_dominates_reduction(ptr %b) {
; CHECK-NEXT: [[TMP32:%.*]] = extractelement <4 x i1> [[TMP19]], i64 2
; CHECK-NEXT: br i1 [[TMP32]], label [[PRED_LOAD_IF3:%.*]], label [[PRED_LOAD_CONTINUE4:%.*]]
; CHECK: pred.load.if3:
-; CHECK-NEXT: [[TMP33:%.*]] = sext i32 [[TMP1]] to i64
-; CHECK-NEXT: [[TMP34:%.*]] = getelementptr inbounds [0 x %struct.e], ptr [[B]], i64 0, i64 [[TMP33]]
+; CHECK-NEXT: [[DOTSCALE8:%.*]] = shl nsw i32 [[TMP1]], 1
+; CHECK-NEXT: [[TMP33:%.*]] = sext i32 [[DOTSCALE8]] to i64
+; CHECK-NEXT: [[TMP34:%.*]] = getelementptr inbounds i32, ptr [[B]], i64 [[TMP33]]
; CHECK-NEXT: [[TMP35:%.*]] = load i32, ptr [[TMP34]], align 4
; CHECK-NEXT: [[TMP36:%.*]] = insertelement <4 x i32> [[TMP31]], i32 [[TMP35]], i64 2
; CHECK-NEXT: br label [[PRED_LOAD_CONTINUE4]]
@@ -1346,8 +1349,9 @@ define i32 @predicated_or_dominates_reduction(ptr %b) {
; CHECK-NEXT: [[TMP38:%.*]] = extractelement <4 x i1> [[TMP19]], i64 3
; CHECK-NEXT: br i1 [[TMP38]], label [[PRED_LOAD_IF5:%.*]], label [[PRED_LOAD_CONTINUE6]]
; CHECK: pred.load.if5:
-; CHECK-NEXT: [[TMP39:%.*]] = sext i32 [[TMP2]] to i64
-; CHECK-NEXT: [[TMP40:%.*]] = getelementptr inbounds [0 x %struct.e], ptr [[B]], i64 0, i64 [[TMP39]]
+; CHECK-NEXT: [[DOTSCALE9:%.*]] = shl nsw i32 [[TMP2]], 1
+; CHECK-NEXT: [[TMP39:%.*]] = sext i32 [[DOTSCALE9]] to i64
+; CHECK-NEXT: [[TMP40:%.*]] = getelementptr inbounds i32, ptr [[B]], i64 [[TMP39]]
; CHECK-NEXT: [[TMP41:%.*]] = load i32, ptr [[TMP40]], align 4
; CHECK-NEXT: [[TMP42:%.*]] = insertelement <4 x i32> [[TMP37]], i32 [[TMP41]], i64 3
; CHECK-NEXT: br label [[PRED_LOAD_CONTINUE6]]
@@ -1355,8 +1359,8 @@ define i32 @predicated_or_dominates_reduction(ptr %b) {
; CHECK-NEXT: [[TMP43:%.*]] = phi <4 x i32> [ [[TMP37]], [[PRED_LOAD_CONTINUE4]] ], [ [[TMP42]], [[PRED_LOAD_IF5]] ]
; CHECK-NEXT: [[TMP44:%.*]] = icmp ne <4 x i32> [[TMP43]], zeroinitializer
; CHECK-NEXT: [[NOT_:%.*]] = xor <4 x i1> [[TMP19]], <i1 true, i1 true, i1 true, i1 true>
-; CHECK-NEXT: [[DOTNOT7:%.*]] = select <4 x i1> [[NOT_]], <4 x i1> <i1 true, i1 true, i1 true, i1 true>, <4 x i1> [[TMP44]]
-; CHECK-NEXT: [[TMP45:%.*]] = bitcast <4 x i1> [[DOTNOT7]] to i4
+; CHECK-NEXT: [[DOTNOT10:%.*]] = select <4 x i1> [[NOT_]], <4 x i1> <i1 true, i1 true, i1 true, i1 true>, <4 x i1> [[TMP44]]
+; CHECK-NEXT: [[TMP45:%.*]] = bitcast <4 x i1> [[DOTNOT10]] to i4
; CHECK-NEXT: [[TMP46:%.*]] = call range(i4 0, 5) i4 @llvm.ctpop.i4(i4 [[TMP45]])
; CHECK-NEXT: [[TMP47:%.*]] = zext nneg i4 [[TMP46]] to i32
; CHECK-NEXT: [[TMP48]] = add i32 [[VEC_PHI]], [[TMP47]]
ping :)
llvm/lib/IR/Instructions.cpp (outdated)

```
if (!isa<ConstantInt>(getOperand(i)) ||
    !cast<ConstantInt>(getOperand(i))->isZero())
```

Reviewer: dyn_cast

Author: Applied your comment, thanks.
llvm/include/llvm/IR/Instructions.h

```
@@ -1078,6 +1078,8 @@ class GetElementPtrInst : public Instruction {
   /// a constant offset between them.
   bool hasAllConstantIndices() const;

+  bool hasAllZeroIndicesExceptLast() const;
```

Reviewer: Needs doc string

Author: Done
llvm/lib/IR/Instructions.cpp (outdated)

```
/// are zero except the last indice.
bool GetElementPtrInst::hasAllZeroIndicesExceptLast() const {
  for (unsigned i = 1, e = getNumOperands() - 1; i != e; ++i) {
    ConstantInt *Val = dyn_cast<ConstantInt>(getOperand(i));
```

Reviewer: Should this handle the vector case too?

Author: OK, also added a vector case.

Reviewer: Your vector test is not a vector test. This code will not work for a vector-of-pointers-typed getelementptr.

Author: Do you mean something like the following one?

```
define <4 x ptr> @decompose_complex_vector1(ptr %array, <4 x ptr> %baseptrs) {
  %val = load <4 x i32>, ptr %array, align 4
  %sextVal = sext <4 x i32> %val to <4 x i64>
  %arrayidx = getelementptr inbounds [10000 x %"class.std::__1::complex"], <4 x ptr> %baseptrs, <4 x i64> zeroinitializer, <4 x i64> %sextVal
  ret <4 x ptr> %arrayidx
}
```

Reviewer: Yes

Author: The vector case is supported now, thanks.
Reviewer (nikic): This patch doesn't make sense to me. It partially decomposes the GEP, so that we have both an explicit multiplication in the offset and a GEP scale multiply. We should have only one or the other.

On top of that, this is driven by a heuristic based on the GEP source element type, which is not allowed.

Author: Hi nikic, thanks for your comment.

Reviewer (nikic): If you're referencing PRs, please use a proper link and don't make me reassemble the URL by hand: #96606. That PR does something very reasonable, which is to convert the GEP into ptradd representation if this allows merging some of the offset arithmetic. It leaves behind only explicit offset arithmetic, not a mix of offset arithmetic and GEP scaling. Basically, the end result for your case would be an i8 GEP with …

Author: Thank you for your advice, so I need to adjust to an i8 GEP directly here too?
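For illustration, a minimal sketch of the ptradd-style result being suggested, assuming a %"struct.std::complex" of two floats (an 8-byte stride); the %off32, %off, and %addr names are hypothetical:

```
%off32 = shl nsw i32 %i, 3                ; scale by the 8-byte struct stride, still in i32
%off = sext i32 %off32 to i64
%addr = getelementptr inbounds i8, ptr @p, i64 %off
```

All of the scaling then lives in explicit offset arithmetic, and the GEP itself is a plain byte-addressed ptradd.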
llvm/lib/Transforms/InstCombine/InstructionCombining.cpp (outdated)

```
                               Offset, "", GEP.getNoWrapFlags()));
    unsigned EleTypeBytes =
        GepResElTy->getStructElementType(0)->getScalarSizeInBits() / 8;
    if (EleTypeBytes > 0) {
```

Reviewer: Missing test with { i1, i1 }?
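A hedged sketch of the kind of negative test being requested: for { i1, i1 }, getScalarSizeInBits() / 8 is 0, so the EleTypeBytes > 0 guard should leave the GEP untouched (type and function names are hypothetical):

```
%"struct.pair_i1" = type { i1, i1 }

define ptr @no_decompose_i1_pair(ptr %p, i32 %i) {
  %idx = sext i32 %i to i64
  ; element size is sub-byte (EleTypeBytes == 0), so this GEP must not be decomposed
  %addr = getelementptr inbounds [8 x %"struct.pair_i1"], ptr %p, i64 0, i64 %idx
  ret ptr %addr
}
```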