Skip to content

Commit 9065fd4

Browse files
committed
Introduce helpers for materializing size_t values and propagate
proper usage of it through SimplifyLibCalls.
1 parent d87236d commit 9065fd4

File tree

4 files changed

+57
-55
lines changed

4 files changed

+57
-55
lines changed

llvm/include/llvm/Analysis/TargetLibraryInfo.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
namespace llvm {
2121

2222
template <typename T> class ArrayRef;
23+
class ConstantInt;
2324

2425
/// Provides info so a possible vectorization of a function can be
2526
/// computed. Function 'VectorFnName' is equivalent to 'ScalarFnName'
@@ -249,6 +250,12 @@ class TargetLibraryInfoImpl {
249250
/// Returns the size of the size_t type in bits.
250251
unsigned getSizeTSize(const Module &M) const;
251252

253+
/// Returns an IntegerType corresponding to size_t.
254+
IntegerType *getSizeTType(const Module &M) const;
255+
256+
/// Returns a constant materialized as a size_t type.
257+
ConstantInt *getAsSizeT(uint64_t V, const Module &M) const;
258+
252259
/// Get size of a C-level int or unsigned int, in bits.
253260
unsigned getIntSize() const {
254261
return SizeOfInt;
@@ -565,6 +572,16 @@ class TargetLibraryInfo {
565572
/// \copydoc TargetLibraryInfoImpl::getSizeTSize()
566573
unsigned getSizeTSize(const Module &M) const { return Impl->getSizeTSize(M); }
567574

575+
/// \copydoc TargetLibraryInfoImpl::getSizeTType()
576+
IntegerType *getSizeTType(const Module &M) const {
577+
return Impl->getSizeTType(M);
578+
}
579+
580+
/// \copydoc TargetLibraryInfoImpl::getAsSizeT()
581+
ConstantInt *getAsSizeT(uint64_t V, const Module &M) const {
582+
return Impl->getAsSizeT(V, M);
583+
}
584+
568585
/// \copydoc TargetLibraryInfoImpl::getIntSize()
569586
unsigned getIntSize() const {
570587
return Impl->getIntSize();

llvm/lib/Analysis/TargetLibraryInfo.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1471,6 +1471,15 @@ unsigned TargetLibraryInfoImpl::getSizeTSize(const Module &M) const {
14711471
return M.getDataLayout().getIndexSizeInBits(AddressSpace);
14721472
}
14731473

1474+
IntegerType *TargetLibraryInfoImpl::getSizeTType(const Module &M) const {
1475+
return IntegerType::get(M.getContext(), getSizeTSize(M));
1476+
}
1477+
1478+
ConstantInt *TargetLibraryInfoImpl::getAsSizeT(uint64_t V,
1479+
const Module &M) const {
1480+
return ConstantInt::get(getSizeTType(M), V);
1481+
}
1482+
14741483
TargetLibraryInfoWrapperPass::TargetLibraryInfoWrapperPass()
14751484
: ImmutablePass(ID), TLA(TargetLibraryInfoImpl()) {
14761485
initializeTargetLibraryInfoWrapperPassPass(*PassRegistry::getPassRegistry());

llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp

Lines changed: 29 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -397,9 +397,8 @@ Value *LibCallSimplifier::emitStrLenMemCpy(Value *Src, Value *Dst, uint64_t Len,
397397

398398
// We have enough information to now generate the memcpy call to do the
399399
// concatenation for us. Make a memcpy to copy the nul byte with align = 1.
400-
B.CreateMemCpy(
401-
CpyDst, Align(1), Src, Align(1),
402-
ConstantInt::get(DL.getIntPtrType(Src->getContext()), Len + 1));
400+
B.CreateMemCpy(CpyDst, Align(1), Src, Align(1),
401+
TLI->getAsSizeT(Len + 1, *B.GetInsertBlock()->getModule()));
403402
return Dst;
404403
}
405404

@@ -590,26 +589,21 @@ Value *LibCallSimplifier::optimizeStrCmp(CallInst *CI, IRBuilderBase &B) {
590589
if (Len1 && Len2) {
591590
return copyFlags(
592591
*CI, emitMemCmp(Str1P, Str2P,
593-
ConstantInt::get(DL.getIntPtrType(CI->getContext()),
594-
std::min(Len1, Len2)),
592+
TLI->getAsSizeT(std::min(Len1, Len2), *CI->getModule()),
595593
B, DL, TLI));
596594
}
597595

598596
// strcmp to memcmp
599597
if (!HasStr1 && HasStr2) {
600598
if (canTransformToMemCmp(CI, Str1P, Len2, DL))
601-
return copyFlags(
602-
*CI,
603-
emitMemCmp(Str1P, Str2P,
604-
ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len2),
605-
B, DL, TLI));
599+
return copyFlags(*CI, emitMemCmp(Str1P, Str2P,
600+
TLI->getAsSizeT(Len2, *CI->getModule()),
601+
B, DL, TLI));
606602
} else if (HasStr1 && !HasStr2) {
607603
if (canTransformToMemCmp(CI, Str2P, Len1, DL))
608-
return copyFlags(
609-
*CI,
610-
emitMemCmp(Str1P, Str2P,
611-
ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len1),
612-
B, DL, TLI));
604+
return copyFlags(*CI, emitMemCmp(Str1P, Str2P,
605+
TLI->getAsSizeT(Len1, *CI->getModule()),
606+
B, DL, TLI));
613607
}
614608

615609
annotateNonNullNoUndefBasedOnAccess(CI, {0, 1});
@@ -676,19 +670,15 @@ Value *LibCallSimplifier::optimizeStrNCmp(CallInst *CI, IRBuilderBase &B) {
676670
if (!HasStr1 && HasStr2) {
677671
Len2 = std::min(Len2, Length);
678672
if (canTransformToMemCmp(CI, Str1P, Len2, DL))
679-
return copyFlags(
680-
*CI,
681-
emitMemCmp(Str1P, Str2P,
682-
ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len2),
683-
B, DL, TLI));
673+
return copyFlags(*CI, emitMemCmp(Str1P, Str2P,
674+
TLI->getAsSizeT(Len2, *CI->getModule()),
675+
B, DL, TLI));
684676
} else if (HasStr1 && !HasStr2) {
685677
Len1 = std::min(Len1, Length);
686678
if (canTransformToMemCmp(CI, Str2P, Len1, DL))
687-
return copyFlags(
688-
*CI,
689-
emitMemCmp(Str1P, Str2P,
690-
ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len1),
691-
B, DL, TLI));
679+
return copyFlags(*CI, emitMemCmp(Str1P, Str2P,
680+
TLI->getAsSizeT(Len1, *CI->getModule()),
681+
B, DL, TLI));
692682
}
693683

694684
return nullptr;
@@ -722,15 +712,13 @@ Value *LibCallSimplifier::optimizeStrCpy(CallInst *CI, IRBuilderBase &B) {
722712

723713
// We have enough information to now generate the memcpy call to do the
724714
// copy for us. Make a memcpy to copy the nul byte with align = 1.
725-
CallInst *NewCI =
726-
B.CreateMemCpy(Dst, Align(1), Src, Align(1),
727-
ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len));
715+
CallInst *NewCI = B.CreateMemCpy(Dst, Align(1), Src, Align(1),
716+
TLI->getAsSizeT(Len, *CI->getModule()));
728717
mergeAttributesAndFlags(NewCI, *CI);
729718
return Dst;
730719
}
731720

732721
Value *LibCallSimplifier::optimizeStpCpy(CallInst *CI, IRBuilderBase &B) {
733-
Function *Callee = CI->getCalledFunction();
734722
Value *Dst = CI->getArgOperand(0), *Src = CI->getArgOperand(1);
735723

736724
// stpcpy(d,s) -> strcpy(d,s) if the result is not used.
@@ -749,10 +737,9 @@ Value *LibCallSimplifier::optimizeStpCpy(CallInst *CI, IRBuilderBase &B) {
749737
else
750738
return nullptr;
751739

752-
Type *PT = Callee->getFunctionType()->getParamType(0);
753-
Value *LenV = ConstantInt::get(DL.getIntPtrType(PT), Len);
740+
Value *LenV = TLI->getAsSizeT(Len, *CI->getModule());
754741
Value *DstEnd = B.CreateInBoundsGEP(
755-
B.getInt8Ty(), Dst, ConstantInt::get(DL.getIntPtrType(PT), Len - 1));
742+
B.getInt8Ty(), Dst, TLI->getAsSizeT(Len - 1, *CI->getModule()));
756743

757744
// We have enough information to now generate the memcpy call to do the
758745
// copy for us. Make a memcpy to copy the nul byte with align = 1.
@@ -819,13 +806,11 @@ Value *LibCallSimplifier::optimizeStrLCpy(CallInst *CI, IRBuilderBase &B) {
819806
return ConstantInt::get(CI->getType(), 0);
820807
}
821808

822-
Function *Callee = CI->getCalledFunction();
823-
Type *PT = Callee->getFunctionType()->getParamType(0);
824809
// Transform strlcpy(D, S, N) to memcpy(D, S, N') where N' is the lower
825810
// bound on strlen(S) + 1 and N, optionally followed by a nul store to
826811
// D[N' - 1] if necessary.
827812
CallInst *NewCI = B.CreateMemCpy(Dst, Align(1), Src, Align(1),
828-
ConstantInt::get(DL.getIntPtrType(PT), NBytes));
813+
TLI->getAsSizeT(NBytes, *CI->getModule()));
829814
mergeAttributesAndFlags(NewCI, *CI);
830815

831816
if (!NulTerm) {
@@ -844,7 +829,6 @@ Value *LibCallSimplifier::optimizeStrLCpy(CallInst *CI, IRBuilderBase &B) {
844829
// otherwise.
845830
Value *LibCallSimplifier::optimizeStringNCpy(CallInst *CI, bool RetEnd,
846831
IRBuilderBase &B) {
847-
Function *Callee = CI->getCalledFunction();
848832
Value *Dst = CI->getArgOperand(0);
849833
Value *Src = CI->getArgOperand(1);
850834
Value *Size = CI->getArgOperand(2);
@@ -922,11 +906,10 @@ Value *LibCallSimplifier::optimizeStringNCpy(CallInst *CI, bool RetEnd,
922906
/*M=*/nullptr, /*AddNull=*/false);
923907
}
924908

925-
Type *PT = Callee->getFunctionType()->getParamType(0);
926909
// st{p,r}ncpy(D, S, N) -> memcpy(align 1 D, align 1 S, N) when both
927910
// S and N are constant.
928911
CallInst *NewCI = B.CreateMemCpy(Dst, Align(1), Src, Align(1),
929-
ConstantInt::get(DL.getIntPtrType(PT), N));
912+
TLI->getAsSizeT(N, *CI->getModule()));
930913
mergeAttributesAndFlags(NewCI, *CI);
931914
if (!RetEnd)
932915
return Dst;
@@ -3438,10 +3421,9 @@ Value *LibCallSimplifier::optimizeSPrintFString(CallInst *CI,
34383421
return nullptr; // we found a format specifier, bail out.
34393422

34403423
// sprintf(str, fmt) -> llvm.memcpy(align 1 str, align 1 fmt, strlen(fmt)+1)
3441-
B.CreateMemCpy(
3442-
Dest, Align(1), CI->getArgOperand(1), Align(1),
3443-
ConstantInt::get(DL.getIntPtrType(CI->getContext()),
3444-
FormatStr.size() + 1)); // Copy the null byte.
3424+
B.CreateMemCpy(Dest, Align(1), CI->getArgOperand(1), Align(1),
3425+
// Copy the null byte.
3426+
TLI->getAsSizeT(FormatStr.size() + 1, *CI->getModule()));
34453427
return ConstantInt::get(CI->getType(), FormatStr.size());
34463428
}
34473429

@@ -3476,9 +3458,8 @@ Value *LibCallSimplifier::optimizeSPrintFString(CallInst *CI,
34763458

34773459
uint64_t SrcLen = GetStringLength(CI->getArgOperand(2));
34783460
if (SrcLen) {
3479-
B.CreateMemCpy(
3480-
Dest, Align(1), CI->getArgOperand(2), Align(1),
3481-
ConstantInt::get(DL.getIntPtrType(CI->getContext()), SrcLen));
3461+
B.CreateMemCpy(Dest, Align(1), CI->getArgOperand(2), Align(1),
3462+
TLI->getAsSizeT(SrcLen, *CI->getModule()));
34823463
// Returns total number of characters written without null-character.
34833464
return ConstantInt::get(CI->getType(), SrcLen - 1);
34843465
} else if (Value *V = emitStpCpy(Dest, CI->getArgOperand(2), B, TLI)) {
@@ -3576,11 +3557,8 @@ Value *LibCallSimplifier::emitSnPrintfMemCpy(CallInst *CI, Value *StrArg,
35763557
Value *DstArg = CI->getArgOperand(0);
35773558
if (NCopy && StrArg)
35783559
// Transform the call to lvm.memcpy(dst, fmt, N).
3579-
copyFlags(
3580-
*CI,
3581-
B.CreateMemCpy(
3582-
DstArg, Align(1), StrArg, Align(1),
3583-
ConstantInt::get(DL.getIntPtrType(CI->getContext()), NCopy)));
3560+
copyFlags(*CI, B.CreateMemCpy(DstArg, Align(1), StrArg, Align(1),
3561+
TLI->getAsSizeT(NCopy, *CI->getModule())));
35843562

35853563
if (N > Str.size())
35863564
// Return early when the whole format string, including the final nul,
@@ -3696,11 +3674,9 @@ Value *LibCallSimplifier::optimizeFPrintFString(CallInst *CI,
36963674
if (FormatStr.contains('%'))
36973675
return nullptr; // We found a format specifier.
36983676

3699-
unsigned SizeTBits = TLI->getSizeTSize(*CI->getModule());
3700-
Type *SizeTTy = IntegerType::get(CI->getContext(), SizeTBits);
37013677
return copyFlags(
37023678
*CI, emitFWrite(CI->getArgOperand(1),
3703-
ConstantInt::get(SizeTTy, FormatStr.size()),
3679+
TLI->getAsSizeT(FormatStr.size(), *CI->getModule()),
37043680
CI->getArgOperand(0), B, DL, TLI));
37053681
}
37063682

llvm/test/Transforms/InstCombine/strcpy-nonzero-as.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ define void @test_strncpy_to_memcpy(ptr addrspace(200) %dst) addrspace(200) noun
5252
; CHECK-LABEL: define {{[^@]+}}@test_strncpy_to_memcpy
5353
; CHECK-SAME: (ptr addrspace(200) [[DST:%.*]]) addrspace(200) #[[ATTR1]] {
5454
; CHECK-NEXT: entry:
55-
; CHECK-NEXT: call addrspace(200) void @llvm.memcpy.p200.p200.i128(ptr addrspace(200) noundef align 1 dereferenceable(17) [[DST]], ptr addrspace(200) noundef align 1 dereferenceable(17) @str, i128 17, i1 false)
55+
; CHECK-NEXT: call addrspace(200) void @llvm.memcpy.p200.p200.i64(ptr addrspace(200) noundef align 1 dereferenceable(17) [[DST]], ptr addrspace(200) noundef align 1 dereferenceable(17) @str, i64 17, i1 false)
5656
; CHECK-NEXT: ret void
5757
;
5858
entry:
@@ -64,7 +64,7 @@ define void @test_stpncpy_to_memcpy(ptr addrspace(200) %dst) addrspace(200) noun
6464
; CHECK-LABEL: define {{[^@]+}}@test_stpncpy_to_memcpy
6565
; CHECK-SAME: (ptr addrspace(200) [[DST:%.*]]) addrspace(200) #[[ATTR1]] {
6666
; CHECK-NEXT: entry:
67-
; CHECK-NEXT: call addrspace(200) void @llvm.memcpy.p200.p200.i128(ptr addrspace(200) noundef align 1 dereferenceable(17) [[DST]], ptr addrspace(200) noundef align 1 dereferenceable(17) @str, i128 17, i1 false)
67+
; CHECK-NEXT: call addrspace(200) void @llvm.memcpy.p200.p200.i64(ptr addrspace(200) noundef align 1 dereferenceable(17) [[DST]], ptr addrspace(200) noundef align 1 dereferenceable(17) @str, i64 17, i1 false)
6868
; CHECK-NEXT: ret void
6969
;
7070
entry:

0 commit comments

Comments
 (0)