[InstCombine] Allow load to store forwarding for scalable structs #123908
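
This change teaches load-to-store forwarding to handle struct types whose members are pairwise bitcastable even though the aggregates themselves are not directly bitcast-compatible; the motivating case is structs of scalable vectors, for which no whole-aggregate bitcast exists. A minimal sketch of the pattern that can now be forwarded, distilled from the tests in this PR:

    store {<vscale x 4 x i32>, <vscale x 4 x i32>} %x, ptr %p
    %r = load {<vscale x 16 x i8>, <vscale x 16 x i8>}, ptr %p

Instead of reloading from %p, InstCombine now rebuilds %r from %x one member at a time with extractvalue, bitcast and insertvalue.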

3 changes: 2 additions & 1 deletion llvm/include/llvm/Analysis/Loads.h
@@ -156,7 +156,8 @@ Value *FindAvailableLoadedValue(LoadInst *Load, BasicBlock *ScanBB,
 /// This overload cannot be used to scan across multiple blocks.
 Value *FindAvailableLoadedValue(LoadInst *Load, BatchAAResults &AA,
                                 bool *IsLoadCSE,
-                                unsigned MaxInstsToScan = DefMaxInstsToScan);
+                                unsigned MaxInstsToScan = DefMaxInstsToScan,
+                                bool AllowPartwiseBitcastStructs = false);

 /// Scan backwards to see if we have the value of the given pointer available
 /// locally within a small number of instructions.
25 changes: 20 additions & 5 deletions llvm/lib/Analysis/Loads.cpp
@@ -531,7 +531,8 @@ static bool areNonOverlapSameBaseLoadAndStore(const Value *LoadPtr,

 static Value *getAvailableLoadStore(Instruction *Inst, const Value *Ptr,
                                     Type *AccessTy, bool AtLeastAtomic,
-                                    const DataLayout &DL, bool *IsLoadCSE) {
+                                    const DataLayout &DL, bool *IsLoadCSE,
+                                    bool AllowPartwiseBitcastStructs = false) {
   // If this is a load of Ptr, the loaded value is available.
   // (This is true even if the load is volatile or atomic, although
   // those cases are unlikely.)
@@ -572,6 +573,19 @@ static Value *getAvailableLoadStore(Instruction *Inst, const Value *Ptr,
     if (CastInst::isBitOrNoopPointerCastable(Val->getType(), AccessTy, DL))
       return Val;

+    if (AllowPartwiseBitcastStructs) {
+      if (StructType *SrcStructTy = dyn_cast<StructType>(Val->getType())) {
+        if (StructType *DestStructTy = dyn_cast<StructType>(AccessTy)) {
+          if (SrcStructTy->getNumElements() == DestStructTy->getNumElements() &&
+              all_of_zip(SrcStructTy->elements(), DestStructTy->elements(),
+                         [](Type *T1, Type *T2) {
+                           return CastInst::isBitCastable(T1, T2);
+                         }))
+            return Val;
+        }
+      }
+    }
+
Contributor

I think this may incorrectly handle the case where the struct members are bitcastable but end up at different offsets because of differing member alignment?


     TypeSize StoreSize = DL.getTypeSizeInBits(Val->getType());
     TypeSize LoadSize = DL.getTypeSizeInBits(AccessTy);
     if (TypeSize::isKnownLE(LoadSize, StoreSize))
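
To make the review concern above concrete, here is a hypothetical illustration; the datalayout and struct types are invented for this sketch and do not appear in the patch:

    ; Under a (hypothetical) datalayout "e-i64:64-f64:32:32", f64 is only
    ; 4-byte aligned while i64 is 8-byte aligned. The members of these two
    ; structs are pairwise bitcastable (i32<->i32, i64<->double), yet:
    %src = type { i32, i64 }    ; i64 lives at offset 8 (4 bytes of padding)
    %dst = type { i32, double } ; double lives at offset 4 (no padding)

A load of %dst from memory holding a stored %src would read its second member from bytes 4-11, while element-wise forwarding would hand back a bitcast of the full i64 stored at bytes 8-15, so the two can disagree.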
@@ -704,8 +718,8 @@ Value *llvm::findAvailablePtrLoadStore(
 }

 Value *llvm::FindAvailableLoadedValue(LoadInst *Load, BatchAAResults &AA,
-                                      bool *IsLoadCSE,
-                                      unsigned MaxInstsToScan) {
+                                      bool *IsLoadCSE, unsigned MaxInstsToScan,
+                                      bool AllowPartwiseBitcastStructs) {
   const DataLayout &DL = Load->getDataLayout();
   Value *StrippedPtr = Load->getPointerOperand()->stripPointerCasts();
   BasicBlock *ScanBB = Load->getParent();
@@ -727,8 +741,9 @@ Value *llvm::FindAvailableLoadedValue(LoadInst *Load, BatchAAResults &AA,
     if (MaxInstsToScan-- == 0)
       return nullptr;

-    Available = getAvailableLoadStore(&Inst, StrippedPtr, AccessTy,
-                                      AtLeastAtomic, DL, IsLoadCSE);
+    Available =
+        getAvailableLoadStore(&Inst, StrippedPtr, AccessTy, AtLeastAtomic, DL,
+                              IsLoadCSE, AllowPartwiseBitcastStructs);
     if (Available)
       break;

16 changes: 15 additions & 1 deletion llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
@@ -1010,10 +1010,24 @@ Instruction *InstCombinerImpl::visitLoadInst(LoadInst &LI) {
   // separated by a few arithmetic operations.
   bool IsLoadCSE = false;
   BatchAAResults BatchAA(*AA);
-  if (Value *AvailableVal = FindAvailableLoadedValue(&LI, BatchAA, &IsLoadCSE)) {
+  if (Value *AvailableVal =
+          FindAvailableLoadedValue(&LI, BatchAA, &IsLoadCSE, DefMaxInstsToScan,
+                                   /*AllowPartwiseBitcastStructs=*/true)) {
     if (IsLoadCSE)
       combineMetadataForCSE(cast<LoadInst>(AvailableVal), &LI, false);

+    if (AvailableVal->getType() != LI.getType() &&
+        isa<StructType>(LI.getType())) {
+      StructType *DstST = cast<StructType>(LI.getType());
+      Value *R = PoisonValue::get(LI.getType());
+      for (unsigned I = 0, E = DstST->getNumElements(); I < E; I++) {
+        Value *Ext = Builder.CreateExtractValue(AvailableVal, I);
+        Value *BC =
+            Builder.CreateBitOrPointerCast(Ext, DstST->getElementType(I));
+        R = Builder.CreateInsertValue(R, BC, I);
+      }
+      return replaceInstUsesWith(LI, R);
+    }
     return replaceInstUsesWith(
         LI, Builder.CreateBitOrPointerCast(AvailableVal, LI.getType(),
                                            LI.getName() + ".cast"));
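Concretely, for a store of { <vscale x 4 x i32>, <vscale x 4 x i32> } forwarded to a load of { <vscale x 16 x i8>, <vscale x 16 x i8> }, the loop above produces an extract/bitcast/insert chain along these lines (value names are illustrative; the autogenerated tests below use numbered temporaries):

    %e0 = extractvalue { <vscale x 4 x i32>, <vscale x 4 x i32> } %x, 0
    %b0 = bitcast <vscale x 4 x i32> %e0 to <vscale x 16 x i8>
    %t0 = insertvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } poison, <vscale x 16 x i8> %b0, 0
    %e1 = extractvalue { <vscale x 4 x i32>, <vscale x 4 x i32> } %x, 1
    %b1 = bitcast <vscale x 4 x i32> %e1 to <vscale x 16 x i8>
    %r = insertvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %t0, <vscale x 16 x i8> %b1, 1
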
86 changes: 86 additions & 0 deletions llvm/test/Transforms/InstCombine/availableloadstruct.ll
@@ -0,0 +1,86 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -S -passes=instcombine < %s | FileCheck %s

target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128-n8:16:32:64"

define {<16 x i8>, <16 x i8>} @check_v16i8_v4i32({<4 x i32>, <4 x i32>} %x, ptr %p) nounwind {
; CHECK-LABEL: define { <16 x i8>, <16 x i8> } @check_v16i8_v4i32(
; CHECK-SAME: { <4 x i32>, <4 x i32> } [[X:%.*]], ptr [[P:%.*]]) #[[ATTR0:[0-9]+]] {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[X_ELT:%.*]] = extractvalue { <4 x i32>, <4 x i32> } [[X]], 0
; CHECK-NEXT: store <4 x i32> [[X_ELT]], ptr [[P]], align 16
; CHECK-NEXT: [[P_REPACK1:%.*]] = getelementptr inbounds nuw i8, ptr [[P]], i64 16
; CHECK-NEXT: [[X_ELT2:%.*]] = extractvalue { <4 x i32>, <4 x i32> } [[X]], 1
; CHECK-NEXT: store <4 x i32> [[X_ELT2]], ptr [[P_REPACK1]], align 16
; CHECK-NEXT: [[R_UNPACK_CAST:%.*]] = bitcast <4 x i32> [[X_ELT]] to <16 x i8>
; CHECK-NEXT: [[TMP0:%.*]] = insertvalue { <16 x i8>, <16 x i8> } poison, <16 x i8> [[R_UNPACK_CAST]], 0
; CHECK-NEXT: [[R_UNPACK4_CAST:%.*]] = bitcast <4 x i32> [[X_ELT2]] to <16 x i8>
; CHECK-NEXT: [[R5:%.*]] = insertvalue { <16 x i8>, <16 x i8> } [[TMP0]], <16 x i8> [[R_UNPACK4_CAST]], 1
; CHECK-NEXT: ret { <16 x i8>, <16 x i8> } [[R5]]
;
entry:
store {<4 x i32>, <4 x i32>} %x, ptr %p
%r = load {<16 x i8>, <16 x i8>}, ptr %p
ret {<16 x i8>, <16 x i8>} %r
}

define {<vscale x 16 x i8>, <vscale x 16 x i8>} @check_nxv16i8_nxv4i32({<vscale x 4 x i32>, <vscale x 4 x i32>} %x, ptr %p) nounwind {
; CHECK-LABEL: define { <vscale x 16 x i8>, <vscale x 16 x i8> } @check_nxv16i8_nxv4i32(
; CHECK-SAME: { <vscale x 4 x i32>, <vscale x 4 x i32> } [[X:%.*]], ptr [[P:%.*]]) #[[ATTR0]] {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: store { <vscale x 4 x i32>, <vscale x 4 x i32> } [[X]], ptr [[P]], align 16
; CHECK-NEXT: [[TMP0:%.*]] = extractvalue { <vscale x 4 x i32>, <vscale x 4 x i32> } [[X]], 0
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <vscale x 4 x i32> [[TMP0]] to <vscale x 16 x i8>
; CHECK-NEXT: [[TMP2:%.*]] = insertvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } poison, <vscale x 16 x i8> [[TMP1]], 0
; CHECK-NEXT: [[TMP3:%.*]] = extractvalue { <vscale x 4 x i32>, <vscale x 4 x i32> } [[X]], 1
; CHECK-NEXT: [[TMP4:%.*]] = bitcast <vscale x 4 x i32> [[TMP3]] to <vscale x 16 x i8>
; CHECK-NEXT: [[R:%.*]] = insertvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[TMP2]], <vscale x 16 x i8> [[TMP4]], 1
; CHECK-NEXT: ret { <vscale x 16 x i8>, <vscale x 16 x i8> } [[R]]
;
entry:
store {<vscale x 4 x i32>, <vscale x 4 x i32>} %x, ptr %p
%r = load {<vscale x 16 x i8>, <vscale x 16 x i8>}, ptr %p
ret {<vscale x 16 x i8>, <vscale x 16 x i8>} %r
}

define {<vscale x 16 x i8>, <vscale x 16 x i8>} @alloca_nxv16i8_nxv4i32({<vscale x 4 x i32>, <vscale x 4 x i32>} %x) nounwind {
; CHECK-LABEL: define { <vscale x 16 x i8>, <vscale x 16 x i8> } @alloca_nxv16i8_nxv4i32(
; CHECK-SAME: { <vscale x 4 x i32>, <vscale x 4 x i32> } [[X:%.*]]) #[[ATTR0]] {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[TMP0:%.*]] = extractvalue { <vscale x 4 x i32>, <vscale x 4 x i32> } [[X]], 0
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <vscale x 4 x i32> [[TMP0]] to <vscale x 16 x i8>
; CHECK-NEXT: [[TMP2:%.*]] = insertvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } poison, <vscale x 16 x i8> [[TMP1]], 0
; CHECK-NEXT: [[TMP3:%.*]] = extractvalue { <vscale x 4 x i32>, <vscale x 4 x i32> } [[X]], 1
; CHECK-NEXT: [[TMP4:%.*]] = bitcast <vscale x 4 x i32> [[TMP3]] to <vscale x 16 x i8>
; CHECK-NEXT: [[R:%.*]] = insertvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[TMP2]], <vscale x 16 x i8> [[TMP4]], 1
; CHECK-NEXT: ret { <vscale x 16 x i8>, <vscale x 16 x i8> } [[R]]
;
entry:
%p = alloca {<vscale x 4 x i32>, <vscale x 4 x i32>}
store {<vscale x 4 x i32>, <vscale x 4 x i32>} %x, ptr %p
%r = load {<vscale x 16 x i8>, <vscale x 16 x i8>}, ptr %p
ret {<vscale x 16 x i8>, <vscale x 16 x i8>} %r
}

define { <16 x i8>, <32 x i8> } @differenttypes({ <4 x i32>, <8 x i32> } %a, ptr %p) {
; CHECK-LABEL: define { <16 x i8>, <32 x i8> } @differenttypes(
; CHECK-SAME: { <4 x i32>, <8 x i32> } [[A:%.*]], ptr [[P:%.*]]) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: call void @llvm.lifetime.start.p0(i64 -1, ptr nonnull [[P]])
; CHECK-NEXT: store { <4 x i32>, <8 x i32> } [[A]], ptr [[P]], align 16
; CHECK-NEXT: [[TMP5:%.*]] = extractvalue { <4 x i32>, <8 x i32> } [[A]], 0
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i32> [[TMP5]] to <16 x i8>
; CHECK-NEXT: [[TMP2:%.*]] = insertvalue { <16 x i8>, <32 x i8> } poison, <16 x i8> [[TMP1]], 0
; CHECK-NEXT: [[TMP3:%.*]] = extractvalue { <4 x i32>, <8 x i32> } [[A]], 1
; CHECK-NEXT: [[TMP4:%.*]] = bitcast <8 x i32> [[TMP3]] to <32 x i8>
; CHECK-NEXT: [[TMP0:%.*]] = insertvalue { <16 x i8>, <32 x i8> } [[TMP2]], <32 x i8> [[TMP4]], 1
; CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 -1, ptr nonnull [[P]])
; CHECK-NEXT: ret { <16 x i8>, <32 x i8> } [[TMP0]]
;
entry:
call void @llvm.lifetime.start.p0(i64 -1, ptr nonnull %p) #5
store { <4 x i32>, <8 x i32> } %a, ptr %p, align 16
%2 = load { <16 x i8>, <32 x i8> }, ptr %p, align 16
call void @llvm.lifetime.end.p0(i64 -1, ptr nonnull %p) #5
ret { <16 x i8>, <32 x i8> } %2
}