Skip to content

Decompose gep of complex type struct to its element type #107848

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions llvm/include/llvm/IR/Instructions.h
Original file line number Diff line number Diff line change
Expand Up @@ -1078,6 +1078,10 @@ class GetElementPtrInst : public Instruction {
/// a constant offset between them.
bool hasAllConstantIndices() const;

/// Return true if all of the indices of this GEP are constant integer zero
/// except the last index.
bool hasAllZeroIndicesExceptLast() const;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs doc string

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


/// Set nowrap flags for GEP instruction.
void setNoWrapFlags(GEPNoWrapFlags NW);

Expand Down
11 changes: 11 additions & 0 deletions llvm/lib/IR/Instructions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1555,6 +1555,17 @@ bool GetElementPtrInst::hasAllConstantIndices() const {
return true;
}

/// hasAllZeroIndicesExceptLast - Return true if every index of this GEP other
/// than the last one is a constant zero. The last index is not inspected and
/// may be any value.
///
/// The operands examined are 1 .. getNumOperands() - 2. When there is nothing
/// to examine (the GEP has zero or one index) the result is true.
bool GetElementPtrInst::hasAllZeroIndicesExceptLast() const {
  // Compare with `<` rather than `!=`: for a GEP without indices
  // getNumOperands() is 1, so the bound is 0 and a `!=` test starting at 1
  // would run past the end of the operand list.
  for (unsigned I = 1, E = getNumOperands() - 1; I < E; ++I) {
    const auto *C = dyn_cast<Constant>(getOperand(I));
    if (!C || !C->isZeroValue())
      return false;
  }
  return true;
}

/// Store the raw encoding of \p NW in SubclassOptionalData.
void GetElementPtrInst::setNoWrapFlags(GEPNoWrapFlags NW) {
  const auto RawFlags = NW.getRaw();
  SubclassOptionalData = RawFlags;
}
Expand Down
35 changes: 35 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2805,6 +2805,41 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) {
GEP.getNoWrapFlags()));
}

// For complex type: %"struct.std::complex" = type { { float, float } }
// Canonicalize
// - %idxprom = sext i32 %Off to i64
// - getelementptr inbounds [100 x %"struct.std::complex"], ptr @p, i64 0, i64 %idxprom
// into
// - %idxprom.scale = shl nsw i32 %Off, 3
// - %1 = sext i32 %idxprom.scale to i64
// - getelementptr inbounds i8, ptr @p, i64 %1
auto *GepResElTy = GEP.getResultElementType();
if (GepResElTy->isStructTy() && GepResElTy->getStructNumElements() == 1)
GepResElTy = GepResElTy->getStructElementType(0);
if (GepResElTy->isStructTy() && GepResElTy->getStructNumElements() == 2 &&
GepResElTy->getStructElementType(0) ==
GepResElTy->getStructElementType(1) &&
GEP.hasAllZeroIndicesExceptLast()) {
unsigned LastIndice = GEP.getNumOperands() - 1;
Value *LastOp = GEP.getOperand(LastIndice);
if (auto *SExtI = dyn_cast<SExtInst>(LastOp)) {
GEPOperator *GEPOp = cast<GEPOperator>(&GEP);
bool NSW = GEPOp->hasNoUnsignedSignedWrap();
bool NUW = GEPOp->hasNoUnsignedWrap();
// We'll let instcombine(mul) convert this to a shl if possible.
unsigned EleTypeBytes =
GepResElTy->getStructElementType(0)->getScalarSizeInBits() / 8;
if (EleTypeBytes > 0) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing test with { i1, i1 }?

auto IntTy = SExtI->getOperand(0)->getType();
Value *Offset = Builder.CreateMul(
SExtI->getOperand(0), ConstantInt::get(IntTy, 2 * EleTypeBytes),
SExtI->getName() + ".scale", NUW, NSW);
return replaceInstUsesWith(
GEP, Builder.CreatePtrAdd(PtrOp, Offset, "", GEP.getNoWrapFlags()));
}
}
}

// Canonicalize
// - scalable GEPs to an explicit offset using the llvm.vscale intrinsic.
// This has better support in BasicAA.
Expand Down
36 changes: 36 additions & 0 deletions llvm/test/Transforms/InstCombine/gep-complex.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt -S -passes=instcombine < %s | FileCheck %s

%"class.std::__1::complex" = type { float, float }
@mdlComplex = dso_local global [10000 x %"class.std::__1::complex"] zeroinitializer, align 4

; Scalar case: the GEP indexes @mdlComplex (an array of { float, float }) with
; a leading zero index and a last index that is a sext of an i32 load.
; The CHECK lines expect the i32 value to be scaled with `shl nsw ... 3`,
; sign-extended to i64, and used as the offset of an i8 getelementptr.
define float @decompose_complex_scalar(ptr %array) {
; CHECK-LABEL: @decompose_complex_scalar(
; CHECK-NEXT: [[VAL:%.*]] = load i32, ptr [[ARRAY:%.*]], align 4
; CHECK-NEXT: [[SEXTVAL_SCALE:%.*]] = shl nsw i32 [[VAL]], 3
; CHECK-NEXT: [[TMP1:%.*]] = sext i32 [[SEXTVAL_SCALE]] to i64
; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds i8, ptr @mdlComplex, i64 [[TMP1]]
; CHECK-NEXT: [[RES:%.*]] = load float, ptr [[ARRAYIDX]], align 4
; CHECK-NEXT: ret float [[RES]]
;
%val = load i32, ptr %array, align 4
%sextVal = sext i32 %val to i64
%arrayidx = getelementptr inbounds [10000 x %"class.std::__1::complex"], ptr @mdlComplex, i32 0, i64 %sextVal
%res = load float, ptr %arrayidx, align 4
ret float %res
}

; A vector of pointers typed getelementptr.
; Same pattern as the scalar test, but the base is <4 x ptr> %baseptrs and the
; indices are <4 x i64> vectors (a zeroinitializer followed by a sext of a
; <4 x i32> load). The CHECK lines expect a <4 x i32> `shl nsw` by 3, a sext to
; <4 x i64>, and an i8 getelementptr on the vector of pointers.
define <4 x ptr> @decompose_complex_vector(ptr %array, <4 x ptr> %baseptrs) {
; CHECK-LABEL: @decompose_complex_vector(
; CHECK-NEXT: [[VAL:%.*]] = load <4 x i32>, ptr [[ARRAY:%.*]], align 4
; CHECK-NEXT: [[SEXTVAL_SCALE:%.*]] = shl nsw <4 x i32> [[VAL]], <i32 3, i32 3, i32 3, i32 3>
; CHECK-NEXT: [[TMP1:%.*]] = sext <4 x i32> [[SEXTVAL_SCALE]] to <4 x i64>
; CHECK-NEXT: [[ARRAYIDX:%.*]] = getelementptr inbounds i8, <4 x ptr> [[BASEPTRS:%.*]], <4 x i64> [[TMP1]]
; CHECK-NEXT: ret <4 x ptr> [[ARRAYIDX]]
;
%val = load <4 x i32>, ptr %array, align 4
%sextVal = sext <4 x i32> %val to <4 x i64>
%arrayidx = getelementptr inbounds [10000 x %"class.std::__1::complex"], <4 x ptr> %baseptrs, <4 x i64> zeroinitializer, <4 x i64> %sextVal
ret <4 x ptr> %arrayidx
}
24 changes: 14 additions & 10 deletions llvm/test/Transforms/LoopVectorize/reduction-inloop.ll
Original file line number Diff line number Diff line change
Expand Up @@ -1316,8 +1316,9 @@ define i32 @predicated_or_dominates_reduction(ptr %b) {
; CHECK-NEXT: [[TMP20:%.*]] = extractelement <4 x i1> [[TMP19]], i64 0
; CHECK-NEXT: br i1 [[TMP20]], label [[PRED_LOAD_IF:%.*]], label [[PRED_LOAD_CONTINUE:%.*]]
; CHECK: pred.load.if:
; CHECK-NEXT: [[TMP21:%.*]] = sext i32 [[INDEX]] to i64
; CHECK-NEXT: [[TMP22:%.*]] = getelementptr inbounds [0 x %struct.e], ptr [[B]], i64 0, i64 [[TMP21]]
; CHECK-NEXT: [[DOTSCALE:%.*]] = shl nsw i32 [[INDEX]], 3
; CHECK-NEXT: [[TMP21:%.*]] = sext i32 [[DOTSCALE]] to i64
; CHECK-NEXT: [[TMP22:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[TMP21]]
; CHECK-NEXT: [[TMP23:%.*]] = load i32, ptr [[TMP22]], align 4
; CHECK-NEXT: [[TMP24:%.*]] = insertelement <4 x i32> poison, i32 [[TMP23]], i64 0
; CHECK-NEXT: br label [[PRED_LOAD_CONTINUE]]
Expand All @@ -1326,8 +1327,9 @@ define i32 @predicated_or_dominates_reduction(ptr %b) {
; CHECK-NEXT: [[TMP26:%.*]] = extractelement <4 x i1> [[TMP19]], i64 1
; CHECK-NEXT: br i1 [[TMP26]], label [[PRED_LOAD_IF1:%.*]], label [[PRED_LOAD_CONTINUE2:%.*]]
; CHECK: pred.load.if1:
; CHECK-NEXT: [[TMP27:%.*]] = sext i32 [[TMP0]] to i64
; CHECK-NEXT: [[TMP28:%.*]] = getelementptr inbounds [0 x %struct.e], ptr [[B]], i64 0, i64 [[TMP27]]
; CHECK-NEXT: [[DOTSCALE7:%.*]] = shl nsw i32 [[TMP0]], 3
; CHECK-NEXT: [[TMP27:%.*]] = sext i32 [[DOTSCALE7]] to i64
; CHECK-NEXT: [[TMP28:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[TMP27]]
; CHECK-NEXT: [[TMP29:%.*]] = load i32, ptr [[TMP28]], align 4
; CHECK-NEXT: [[TMP30:%.*]] = insertelement <4 x i32> [[TMP25]], i32 [[TMP29]], i64 1
; CHECK-NEXT: br label [[PRED_LOAD_CONTINUE2]]
Expand All @@ -1336,8 +1338,9 @@ define i32 @predicated_or_dominates_reduction(ptr %b) {
; CHECK-NEXT: [[TMP32:%.*]] = extractelement <4 x i1> [[TMP19]], i64 2
; CHECK-NEXT: br i1 [[TMP32]], label [[PRED_LOAD_IF3:%.*]], label [[PRED_LOAD_CONTINUE4:%.*]]
; CHECK: pred.load.if3:
; CHECK-NEXT: [[TMP33:%.*]] = sext i32 [[TMP1]] to i64
; CHECK-NEXT: [[TMP34:%.*]] = getelementptr inbounds [0 x %struct.e], ptr [[B]], i64 0, i64 [[TMP33]]
; CHECK-NEXT: [[DOTSCALE8:%.*]] = shl nsw i32 [[TMP1]], 3
; CHECK-NEXT: [[TMP33:%.*]] = sext i32 [[DOTSCALE8]] to i64
; CHECK-NEXT: [[TMP34:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[TMP33]]
; CHECK-NEXT: [[TMP35:%.*]] = load i32, ptr [[TMP34]], align 4
; CHECK-NEXT: [[TMP36:%.*]] = insertelement <4 x i32> [[TMP31]], i32 [[TMP35]], i64 2
; CHECK-NEXT: br label [[PRED_LOAD_CONTINUE4]]
Expand All @@ -1346,17 +1349,18 @@ define i32 @predicated_or_dominates_reduction(ptr %b) {
; CHECK-NEXT: [[TMP38:%.*]] = extractelement <4 x i1> [[TMP19]], i64 3
; CHECK-NEXT: br i1 [[TMP38]], label [[PRED_LOAD_IF5:%.*]], label [[PRED_LOAD_CONTINUE6]]
; CHECK: pred.load.if5:
; CHECK-NEXT: [[TMP39:%.*]] = sext i32 [[TMP2]] to i64
; CHECK-NEXT: [[TMP40:%.*]] = getelementptr inbounds [0 x %struct.e], ptr [[B]], i64 0, i64 [[TMP39]]
; CHECK-NEXT: [[DOTSCALE9:%.*]] = shl nsw i32 [[TMP2]], 3
; CHECK-NEXT: [[TMP39:%.*]] = sext i32 [[DOTSCALE9]] to i64
; CHECK-NEXT: [[TMP40:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[TMP39]]
; CHECK-NEXT: [[TMP41:%.*]] = load i32, ptr [[TMP40]], align 4
; CHECK-NEXT: [[TMP42:%.*]] = insertelement <4 x i32> [[TMP37]], i32 [[TMP41]], i64 3
; CHECK-NEXT: br label [[PRED_LOAD_CONTINUE6]]
; CHECK: pred.load.continue6:
; CHECK-NEXT: [[TMP43:%.*]] = phi <4 x i32> [ [[TMP37]], [[PRED_LOAD_CONTINUE4]] ], [ [[TMP42]], [[PRED_LOAD_IF5]] ]
; CHECK-NEXT: [[TMP44:%.*]] = icmp ne <4 x i32> [[TMP43]], zeroinitializer
; CHECK-NEXT: [[NOT_:%.*]] = xor <4 x i1> [[TMP19]], <i1 true, i1 true, i1 true, i1 true>
; CHECK-NEXT: [[DOTNOT7:%.*]] = select <4 x i1> [[NOT_]], <4 x i1> <i1 true, i1 true, i1 true, i1 true>, <4 x i1> [[TMP44]]
; CHECK-NEXT: [[TMP45:%.*]] = bitcast <4 x i1> [[DOTNOT7]] to i4
; CHECK-NEXT: [[DOTNOT10:%.*]] = select <4 x i1> [[NOT_]], <4 x i1> <i1 true, i1 true, i1 true, i1 true>, <4 x i1> [[TMP44]]
; CHECK-NEXT: [[TMP45:%.*]] = bitcast <4 x i1> [[DOTNOT10]] to i4
; CHECK-NEXT: [[TMP46:%.*]] = call range(i4 0, 5) i4 @llvm.ctpop.i4(i4 [[TMP45]])
; CHECK-NEXT: [[TMP47:%.*]] = zext nneg i4 [[TMP46]] to i32
; CHECK-NEXT: [[TMP48]] = add i32 [[VEC_PHI]], [[TMP47]]
Expand Down
Loading