Skip to content

TargetLibraryInfo: Use pointer index size to determine getSizeTSize(). #118747

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions llvm/include/llvm/Analysis/TargetLibraryInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
#define LLVM_ANALYSIS_TARGETLIBRARYINFO_H

#include "llvm/ADT/DenseMap.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Pass.h"
#include "llvm/TargetParser/Triple.h"
Expand Down Expand Up @@ -565,6 +567,16 @@ class TargetLibraryInfo {
/// \copydoc TargetLibraryInfoImpl::getSizeTSize()
unsigned getSizeTSize(const Module &M) const { return Impl->getSizeTSize(M); }

/// Returns an IntegerType corresponding to size_t.
IntegerType *getSizeTType(const Module &M) const {
return IntegerType::get(M.getContext(), getSizeTSize(M));
}

/// Returns a constant materialized as a size_t type.
ConstantInt *getAsSizeT(uint64_t V, const Module &M) const {
return ConstantInt::get(getSizeTType(M), V);
}

/// \copydoc TargetLibraryInfoImpl::getIntSize()
unsigned getIntSize() const {
return Impl->getIntSize();
Expand Down
23 changes: 10 additions & 13 deletions llvm/lib/Analysis/TargetLibraryInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1459,19 +1459,16 @@ unsigned TargetLibraryInfoImpl::getWCharSize(const Module &M) const {
}

unsigned TargetLibraryInfoImpl::getSizeTSize(const Module &M) const {
// There is really no guarantee that sizeof(size_t) is equal to sizeof(int*).
// If that isn't true then it should be possible to derive the SizeTTy from
// the target triple here instead and do an early return.

// Historically LLVM assume that size_t has same size as intptr_t (hence
// deriving the size from sizeof(int*) in address space zero). This should
// work for most targets. For future consideration: DataLayout also implement
// getIndexSizeInBits which might map better to size_t compared to
// getPointerSizeInBits. Hard coding address space zero here might be
// unfortunate as well. Maybe getDefaultGlobalsAddressSpace() or
// getAllocaAddrSpace() is better.
unsigned AddressSpace = 0;
return M.getDataLayout().getPointerSizeInBits(AddressSpace);
// There is really no guarantee that sizeof(size_t) is equal to the index
// size of the default address space. If that isn't true then it should be
// possible to derive the SizeTTy from the target triple here instead and do
// an early return.

// Hard coding address space zero may seem unfortunate, but a number of
// configurations of common targets (i386, x86-64 x32, aarch64 x32, possibly
// others) have larger-than-size_t index sizes on non-default address spaces,
// making this the best default.
return M.getDataLayout().getIndexSizeInBits(/*AddressSpace=*/0);
}

TargetLibraryInfoWrapperPass::TargetLibraryInfoWrapperPass()
Expand Down
82 changes: 29 additions & 53 deletions llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -397,9 +397,8 @@ Value *LibCallSimplifier::emitStrLenMemCpy(Value *Src, Value *Dst, uint64_t Len,

// We have enough information to now generate the memcpy call to do the
// concatenation for us. Make a memcpy to copy the nul byte with align = 1.
B.CreateMemCpy(
CpyDst, Align(1), Src, Align(1),
ConstantInt::get(DL.getIntPtrType(Src->getContext()), Len + 1));
B.CreateMemCpy(CpyDst, Align(1), Src, Align(1),
TLI->getAsSizeT(Len + 1, *B.GetInsertBlock()->getModule()));
return Dst;
}

Expand Down Expand Up @@ -590,26 +589,21 @@ Value *LibCallSimplifier::optimizeStrCmp(CallInst *CI, IRBuilderBase &B) {
if (Len1 && Len2) {
return copyFlags(
*CI, emitMemCmp(Str1P, Str2P,
ConstantInt::get(DL.getIntPtrType(CI->getContext()),
std::min(Len1, Len2)),
TLI->getAsSizeT(std::min(Len1, Len2), *CI->getModule()),
B, DL, TLI));
}

// strcmp to memcmp
if (!HasStr1 && HasStr2) {
if (canTransformToMemCmp(CI, Str1P, Len2, DL))
return copyFlags(
*CI,
emitMemCmp(Str1P, Str2P,
ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len2),
B, DL, TLI));
return copyFlags(*CI, emitMemCmp(Str1P, Str2P,
TLI->getAsSizeT(Len2, *CI->getModule()),
B, DL, TLI));
} else if (HasStr1 && !HasStr2) {
if (canTransformToMemCmp(CI, Str2P, Len1, DL))
return copyFlags(
*CI,
emitMemCmp(Str1P, Str2P,
ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len1),
B, DL, TLI));
return copyFlags(*CI, emitMemCmp(Str1P, Str2P,
TLI->getAsSizeT(Len1, *CI->getModule()),
B, DL, TLI));
}

annotateNonNullNoUndefBasedOnAccess(CI, {0, 1});
Expand Down Expand Up @@ -676,19 +670,15 @@ Value *LibCallSimplifier::optimizeStrNCmp(CallInst *CI, IRBuilderBase &B) {
if (!HasStr1 && HasStr2) {
Len2 = std::min(Len2, Length);
if (canTransformToMemCmp(CI, Str1P, Len2, DL))
return copyFlags(
*CI,
emitMemCmp(Str1P, Str2P,
ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len2),
B, DL, TLI));
return copyFlags(*CI, emitMemCmp(Str1P, Str2P,
TLI->getAsSizeT(Len2, *CI->getModule()),
B, DL, TLI));
} else if (HasStr1 && !HasStr2) {
Len1 = std::min(Len1, Length);
if (canTransformToMemCmp(CI, Str2P, Len1, DL))
return copyFlags(
*CI,
emitMemCmp(Str1P, Str2P,
ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len1),
B, DL, TLI));
return copyFlags(*CI, emitMemCmp(Str1P, Str2P,
TLI->getAsSizeT(Len1, *CI->getModule()),
B, DL, TLI));
}

return nullptr;
Expand Down Expand Up @@ -722,15 +712,13 @@ Value *LibCallSimplifier::optimizeStrCpy(CallInst *CI, IRBuilderBase &B) {

// We have enough information to now generate the memcpy call to do the
// copy for us. Make a memcpy to copy the nul byte with align = 1.
CallInst *NewCI =
B.CreateMemCpy(Dst, Align(1), Src, Align(1),
ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len));
CallInst *NewCI = B.CreateMemCpy(Dst, Align(1), Src, Align(1),
TLI->getAsSizeT(Len, *CI->getModule()));
mergeAttributesAndFlags(NewCI, *CI);
return Dst;
}

Value *LibCallSimplifier::optimizeStpCpy(CallInst *CI, IRBuilderBase &B) {
Function *Callee = CI->getCalledFunction();
Value *Dst = CI->getArgOperand(0), *Src = CI->getArgOperand(1);

// stpcpy(d,s) -> strcpy(d,s) if the result is not used.
Expand All @@ -749,10 +737,9 @@ Value *LibCallSimplifier::optimizeStpCpy(CallInst *CI, IRBuilderBase &B) {
else
return nullptr;

Type *PT = Callee->getFunctionType()->getParamType(0);
Value *LenV = ConstantInt::get(DL.getIntPtrType(PT), Len);
Value *LenV = TLI->getAsSizeT(Len, *CI->getModule());
Value *DstEnd = B.CreateInBoundsGEP(
B.getInt8Ty(), Dst, ConstantInt::get(DL.getIntPtrType(PT), Len - 1));
B.getInt8Ty(), Dst, TLI->getAsSizeT(Len - 1, *CI->getModule()));

// We have enough information to now generate the memcpy call to do the
// copy for us. Make a memcpy to copy the nul byte with align = 1.
Expand Down Expand Up @@ -819,13 +806,11 @@ Value *LibCallSimplifier::optimizeStrLCpy(CallInst *CI, IRBuilderBase &B) {
return ConstantInt::get(CI->getType(), 0);
}

Function *Callee = CI->getCalledFunction();
Type *PT = Callee->getFunctionType()->getParamType(0);
// Transform strlcpy(D, S, N) to memcpy(D, S, N') where N' is the lower
// bound on strlen(S) + 1 and N, optionally followed by a nul store to
// D[N' - 1] if necessary.
CallInst *NewCI = B.CreateMemCpy(Dst, Align(1), Src, Align(1),
ConstantInt::get(DL.getIntPtrType(PT), NBytes));
TLI->getAsSizeT(NBytes, *CI->getModule()));
mergeAttributesAndFlags(NewCI, *CI);

if (!NulTerm) {
Expand All @@ -844,7 +829,6 @@ Value *LibCallSimplifier::optimizeStrLCpy(CallInst *CI, IRBuilderBase &B) {
// otherwise.
Value *LibCallSimplifier::optimizeStringNCpy(CallInst *CI, bool RetEnd,
IRBuilderBase &B) {
Function *Callee = CI->getCalledFunction();
Value *Dst = CI->getArgOperand(0);
Value *Src = CI->getArgOperand(1);
Value *Size = CI->getArgOperand(2);
Expand Down Expand Up @@ -922,11 +906,10 @@ Value *LibCallSimplifier::optimizeStringNCpy(CallInst *CI, bool RetEnd,
/*M=*/nullptr, /*AddNull=*/false);
}

Type *PT = Callee->getFunctionType()->getParamType(0);
// st{p,r}ncpy(D, S, N) -> memcpy(align 1 D, align 1 S, N) when both
// S and N are constant.
CallInst *NewCI = B.CreateMemCpy(Dst, Align(1), Src, Align(1),
ConstantInt::get(DL.getIntPtrType(PT), N));
TLI->getAsSizeT(N, *CI->getModule()));
mergeAttributesAndFlags(NewCI, *CI);
if (!RetEnd)
return Dst;
Expand Down Expand Up @@ -3438,10 +3421,9 @@ Value *LibCallSimplifier::optimizeSPrintFString(CallInst *CI,
return nullptr; // we found a format specifier, bail out.

// sprintf(str, fmt) -> llvm.memcpy(align 1 str, align 1 fmt, strlen(fmt)+1)
B.CreateMemCpy(
Dest, Align(1), CI->getArgOperand(1), Align(1),
ConstantInt::get(DL.getIntPtrType(CI->getContext()),
FormatStr.size() + 1)); // Copy the null byte.
B.CreateMemCpy(Dest, Align(1), CI->getArgOperand(1), Align(1),
// Copy the null byte.
TLI->getAsSizeT(FormatStr.size() + 1, *CI->getModule()));
return ConstantInt::get(CI->getType(), FormatStr.size());
}

Expand Down Expand Up @@ -3476,9 +3458,8 @@ Value *LibCallSimplifier::optimizeSPrintFString(CallInst *CI,

uint64_t SrcLen = GetStringLength(CI->getArgOperand(2));
if (SrcLen) {
B.CreateMemCpy(
Dest, Align(1), CI->getArgOperand(2), Align(1),
ConstantInt::get(DL.getIntPtrType(CI->getContext()), SrcLen));
B.CreateMemCpy(Dest, Align(1), CI->getArgOperand(2), Align(1),
TLI->getAsSizeT(SrcLen, *CI->getModule()));
// Returns total number of characters written without null-character.
return ConstantInt::get(CI->getType(), SrcLen - 1);
} else if (Value *V = emitStpCpy(Dest, CI->getArgOperand(2), B, TLI)) {
Expand Down Expand Up @@ -3576,11 +3557,8 @@ Value *LibCallSimplifier::emitSnPrintfMemCpy(CallInst *CI, Value *StrArg,
Value *DstArg = CI->getArgOperand(0);
if (NCopy && StrArg)
// Transform the call to lvm.memcpy(dst, fmt, N).
copyFlags(
*CI,
B.CreateMemCpy(
DstArg, Align(1), StrArg, Align(1),
ConstantInt::get(DL.getIntPtrType(CI->getContext()), NCopy)));
copyFlags(*CI, B.CreateMemCpy(DstArg, Align(1), StrArg, Align(1),
TLI->getAsSizeT(NCopy, *CI->getModule())));

if (N > Str.size())
// Return early when the whole format string, including the final nul,
Expand Down Expand Up @@ -3696,11 +3674,9 @@ Value *LibCallSimplifier::optimizeFPrintFString(CallInst *CI,
if (FormatStr.contains('%'))
return nullptr; // We found a format specifier.

unsigned SizeTBits = TLI->getSizeTSize(*CI->getModule());
Type *SizeTTy = IntegerType::get(CI->getContext(), SizeTBits);
return copyFlags(
*CI, emitFWrite(CI->getArgOperand(1),
ConstantInt::get(SizeTTy, FormatStr.size()),
TLI->getAsSizeT(FormatStr.size(), *CI->getModule()),
CI->getArgOperand(0), B, DL, TLI));
}

Expand Down
5 changes: 3 additions & 2 deletions llvm/test/Transforms/InstCombine/stdio-custom-dl.ll
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@ target datalayout = "e-m:o-p:40:64:64:32-i64:64-f80:128-n8:16:32:64-S128"
@.str.1 = private unnamed_addr constant [2 x i8] c"w\00", align 1
@.str.2 = private unnamed_addr constant [4 x i8] c"str\00", align 1

; Check fwrite is generated with arguments of ptr size, not index size
;; Check fwrite is generated with arguments of index size, not ptr size

define internal void @fputs_test_custom_dl() {
; CHECK-LABEL: @fputs_test_custom_dl(
; CHECK-NEXT: [[CALL:%.*]] = call ptr @fopen(ptr nonnull @.str, ptr nonnull @.str.1)
; CHECK-NEXT: [[TMP1:%.*]] = call i40 @fwrite(ptr nonnull @.str.2, i40 3, i40 1, ptr [[CALL]])
; CHECK-NEXT: [[TMP1:%.*]] = call i32 @fwrite(ptr nonnull @.str.2, i32 3, i32 1, ptr %call)
; CHECK-NEXT: ret void
;
%call = call ptr @fopen(ptr @.str, ptr @.str.1)
Expand Down
4 changes: 2 additions & 2 deletions llvm/test/Transforms/InstCombine/strcpy-nonzero-as.ll
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ define void @test_strncpy_to_memcpy(ptr addrspace(200) %dst) addrspace(200) noun
; CHECK-LABEL: define {{[^@]+}}@test_strncpy_to_memcpy
; CHECK-SAME: (ptr addrspace(200) [[DST:%.*]]) addrspace(200) #[[ATTR1]] {
; CHECK-NEXT: entry:
; CHECK-NEXT: call addrspace(200) void @llvm.memcpy.p200.p200.i128(ptr addrspace(200) noundef align 1 dereferenceable(17) [[DST]], ptr addrspace(200) noundef align 1 dereferenceable(17) @str, i128 17, i1 false)
; CHECK-NEXT: call addrspace(200) void @llvm.memcpy.p200.p200.i64(ptr addrspace(200) noundef align 1 dereferenceable(17) [[DST]], ptr addrspace(200) noundef align 1 dereferenceable(17) @str, i64 17, i1 false)
; CHECK-NEXT: ret void
;
entry:
Expand All @@ -64,7 +64,7 @@ define void @test_stpncpy_to_memcpy(ptr addrspace(200) %dst) addrspace(200) noun
; CHECK-LABEL: define {{[^@]+}}@test_stpncpy_to_memcpy
; CHECK-SAME: (ptr addrspace(200) [[DST:%.*]]) addrspace(200) #[[ATTR1]] {
; CHECK-NEXT: entry:
; CHECK-NEXT: call addrspace(200) void @llvm.memcpy.p200.p200.i128(ptr addrspace(200) noundef align 1 dereferenceable(17) [[DST]], ptr addrspace(200) noundef align 1 dereferenceable(17) @str, i128 17, i1 false)
; CHECK-NEXT: call addrspace(200) void @llvm.memcpy.p200.p200.i64(ptr addrspace(200) noundef align 1 dereferenceable(17) [[DST]], ptr addrspace(200) noundef align 1 dereferenceable(17) @str, i64 17, i1 false)
; CHECK-NEXT: ret void
;
entry:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ target triple = "x86_64"
target datalayout = "e-p:64:64:64:32"

; Define a cunstom data layout that has index width < pointer width
; and make sure that doesn't mreak anything
; and make sure that doesn't break anything
define void @fat_ptrs(ptr dereferenceable(16) %a, ptr dereferenceable(16) %b) {
; CHECK-LABEL: @fat_ptrs(
; CHECK-NEXT: bb0:
; CHECK-NEXT: [[PTR_A1:%.*]] = getelementptr inbounds [2 x i64], ptr [[A:%.*]], i32 0, i32 1
; CHECK-NEXT: [[PTR_B1:%.*]] = getelementptr inbounds [2 x i64], ptr [[B:%.*]], i32 0, i32 1
; CHECK-NEXT: br label %"bb1+bb2"
; CHECK: "bb1+bb2":
; CHECK-NEXT: [[MEMCMP:%.*]] = call i32 @memcmp(ptr [[A]], ptr [[B]], i64 16)
; CHECK-NEXT: [[MEMCMP:%.*]] = call i32 @memcmp(ptr [[A]], ptr [[B]], i32 16)
; CHECK-NEXT: [[TMP0:%.*]] = icmp eq i32 [[MEMCMP]], 0
; CHECK-NEXT: br label [[BB3:%.*]]
; CHECK: bb3:
Expand Down
Loading