
SimplifyLibCalls: Use the correct address space when computing integer widths. #118586

Closed
wants to merge 3 commits into from

resistor
Collaborator

@resistor resistor commented Dec 4, 2024

No description provided.

@llvmbot
Member

llvmbot commented Dec 4, 2024

@llvm/pr-subscribers-llvm-transforms

Author: Owen Anderson (resistor)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/118586.diff

1 Files Affected:

  • (modified) llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp (+53-29)
diff --git a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
index d85e0d99466022..7cc3920ed8b65f 100644
--- a/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp
@@ -397,9 +397,11 @@ Value *LibCallSimplifier::emitStrLenMemCpy(Value *Src, Value *Dst, uint64_t Len,
 
   // We have enough information to now generate the memcpy call to do the
   // concatenation for us.  Make a memcpy to copy the nul byte with align = 1.
-  B.CreateMemCpy(
-      CpyDst, Align(1), Src, Align(1),
-      ConstantInt::get(DL.getIntPtrType(Src->getContext()), Len + 1));
+  B.CreateMemCpy(CpyDst, Align(1), Src, Align(1),
+                 ConstantInt::get(
+                     DL.getIntPtrType(Src->getContext(),
+                                      Src->getType()->getPointerAddressSpace()),
+                     Len + 1));
   return Dst;
 }
 
@@ -590,8 +592,11 @@ Value *LibCallSimplifier::optimizeStrCmp(CallInst *CI, IRBuilderBase &B) {
   if (Len1 && Len2) {
     return copyFlags(
         *CI, emitMemCmp(Str1P, Str2P,
-                        ConstantInt::get(DL.getIntPtrType(CI->getContext()),
-                                         std::min(Len1, Len2)),
+                        ConstantInt::get(
+                            DL.getIntPtrType(
+                                CI->getContext(),
+                                Str1P->getType()->getPointerAddressSpace()),
+                            std::min(Len1, Len2)),
                         B, DL, TLI));
   }
 
@@ -599,17 +604,23 @@ Value *LibCallSimplifier::optimizeStrCmp(CallInst *CI, IRBuilderBase &B) {
   if (!HasStr1 && HasStr2) {
     if (canTransformToMemCmp(CI, Str1P, Len2, DL))
       return copyFlags(
-          *CI,
-          emitMemCmp(Str1P, Str2P,
-                     ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len2),
-                     B, DL, TLI));
+          *CI, emitMemCmp(Str1P, Str2P,
+                          ConstantInt::get(
+                              DL.getIntPtrType(
+                                  CI->getContext(),
+                                  Str1P->getType()->getPointerAddressSpace()),
+                              Len2),
+                          B, DL, TLI));
   } else if (HasStr1 && !HasStr2) {
     if (canTransformToMemCmp(CI, Str2P, Len1, DL))
       return copyFlags(
-          *CI,
-          emitMemCmp(Str1P, Str2P,
-                     ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len1),
-                     B, DL, TLI));
+          *CI, emitMemCmp(Str1P, Str2P,
+                          ConstantInt::get(
+                              DL.getIntPtrType(
+                                  CI->getContext(),
+                                  Str1P->getType()->getPointerAddressSpace()),
+                              Len1),
+                          B, DL, TLI));
   }
 
   annotateNonNullNoUndefBasedOnAccess(CI, {0, 1});
@@ -677,18 +688,24 @@ Value *LibCallSimplifier::optimizeStrNCmp(CallInst *CI, IRBuilderBase &B) {
     Len2 = std::min(Len2, Length);
     if (canTransformToMemCmp(CI, Str1P, Len2, DL))
       return copyFlags(
-          *CI,
-          emitMemCmp(Str1P, Str2P,
-                     ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len2),
-                     B, DL, TLI));
+          *CI, emitMemCmp(Str1P, Str2P,
+                          ConstantInt::get(
+                              DL.getIntPtrType(
+                                  CI->getContext(),
+                                  Str1P->getType()->getPointerAddressSpace()),
+                              Len2),
+                          B, DL, TLI));
   } else if (HasStr1 && !HasStr2) {
     Len1 = std::min(Len1, Length);
     if (canTransformToMemCmp(CI, Str2P, Len1, DL))
       return copyFlags(
-          *CI,
-          emitMemCmp(Str1P, Str2P,
-                     ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len1),
-                     B, DL, TLI));
+          *CI, emitMemCmp(Str1P, Str2P,
+                          ConstantInt::get(
+                              DL.getIntPtrType(
+                                  CI->getContext(),
+                                  Str1P->getType()->getPointerAddressSpace()),
+                              Len1),
+                          B, DL, TLI));
   }
 
   return nullptr;
@@ -724,7 +741,7 @@ Value *LibCallSimplifier::optimizeStrCpy(CallInst *CI, IRBuilderBase &B) {
   // copy for us.  Make a memcpy to copy the nul byte with align = 1.
   CallInst *NewCI =
       B.CreateMemCpy(Dst, Align(1), Src, Align(1),
-                     ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len));
+                     ConstantInt::get(DL.getIndexType(Dst->getType()), Len));
   mergeAttributesAndFlags(NewCI, *CI);
   return Dst;
 }
@@ -3357,7 +3374,9 @@ Value *LibCallSimplifier::optimizePrintFString(CallInst *CI, IRBuilderBase &B) {
     // Create a string literal with no \n on it.  We expect the constant merge
     // pass to be run after this pass, to merge duplicate strings.
     FormatStr = FormatStr.drop_back();
-    Value *GV = B.CreateGlobalString(FormatStr, "str");
+    Value *GV = B.CreateGlobalString(
+        FormatStr, "str",
+        CI->getArgOperand(1)->getType()->getPointerAddressSpace());
     return copyFlags(*CI, emitPutS(GV, B, TLI));
   }
 
@@ -3434,8 +3453,10 @@ Value *LibCallSimplifier::optimizeSPrintFString(CallInst *CI,
     // sprintf(str, fmt) -> llvm.memcpy(align 1 str, align 1 fmt, strlen(fmt)+1)
     B.CreateMemCpy(
         Dest, Align(1), CI->getArgOperand(1), Align(1),
-        ConstantInt::get(DL.getIntPtrType(CI->getContext()),
-                         FormatStr.size() + 1)); // Copy the null byte.
+        ConstantInt::get(
+            DL.getIntPtrType(CI->getContext(),
+                             Dest->getType()->getPointerAddressSpace()),
+            FormatStr.size() + 1)); // Copy the null byte.
     return ConstantInt::get(CI->getType(), FormatStr.size());
   }
 
@@ -3571,10 +3592,13 @@ Value *LibCallSimplifier::emitSnPrintfMemCpy(CallInst *CI, Value *StrArg,
   if (NCopy && StrArg)
     // Transform the call to lvm.memcpy(dst, fmt, N).
     copyFlags(
-         *CI,
-          B.CreateMemCpy(
-                         DstArg, Align(1), StrArg, Align(1),
-              ConstantInt::get(DL.getIntPtrType(CI->getContext()), NCopy)));
+        *CI,
+        B.CreateMemCpy(
+            DstArg, Align(1), StrArg, Align(1),
+            ConstantInt::get(
+                DL.getIntPtrType(CI->getContext(),
+                                 DstArg->getType()->getPointerAddressSpace()),
+                NCopy)));
 
   if (N > Str.size())
     // Return early when the whole format string, including the final nul,

@llvmbot added the llvm:instcombine label (Covers the InstCombine, InstSimplify and AggressiveInstCombine passes) Dec 4, 2024
@resistor force-pushed the simplifylibcalls branch 2 times, most recently from 65aabd6 to 4abfe66, December 4, 2024 06:22
Comment on lines +591 to +595
ConstantInt::get(
DL.getIntPtrType(
CI->getContext(),
Str1P->getType()->getPointerAddressSpace()),
std::min(Len1, Len2)),
Contributor

the memcmp case should probably have a similar wrapper

Contributor

If you're going to change this, please change it to use the correct value -- which is TLI::getSizeTSize().

Collaborator Author

Unfortunately, TLI::getSizeTSize() contains a nice comment explaining that it always uses addrspace(0), and noting that maybe it should consider alternatives, which is exactly the goal here.

Collaborator Author

There are a few alternatives here, none of them great:

  • We can use the current state of this patch, which will have every target always use i64. As efriedma points out, this will probably work out because of backend legalization.
  • We can thread TLI into IRBuilder::CreateMemCpy so that it can pick the right size. This requires updating a number of other passes throughout the tree, many of which don't depend on TargetLibraryInfo today.
  • Same as above, but we add it in parallel to the existing overloads so that we can migrate users gradually.
  • Build the types in SimplifyLibCalls as in the original version of this diff. Do some local cleanup to make it more readable per arsenm's feedback.

Thoughts?

Collaborator

The reason getSizeTSize() unconditionally queries addrspace(0) is just that there aren't any in-tree targets where it needs to return anything else. Fixes welcome.

Passing in an integer to emitMemCmp where the size isn't getSizeTSize() just miscompiles, as far as I know. (CreateMemCpy works fine in any case because it's an intrinsic.)
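
The distinction drawn in this comment can be illustrated with a small self-contained C++ model. This is a sketch, not LLVM code: both function names are hypothetical. It assumes the usual opaque-pointer intrinsic mangling, `llvm.memcpy.p<dst AS>.p<src AS>.i<length bits>`, as seen in the test's CHECK lines, and it treats a C library call such as `memcmp` as having a single prototype whose length parameter is `size_t`.

```cpp
#include <cassert>
#include <string>

// Hypothetical model, not an LLVM API: an overloaded intrinsic encodes the
// length type in its name, so any integer width selects a valid overload.
std::string memcpyIntrinsicName(unsigned DstAS, unsigned SrcAS,
                                unsigned LenBits) {
  assert(LenBits > 0 && "length type must have a nonzero width");
  return "llvm.memcpy.p" + std::to_string(DstAS) + ".p" +
         std::to_string(SrcAS) + ".i" + std::to_string(LenBits);
}

// Hypothetical model: a C library function has one fixed prototype, so the
// length argument is only well-typed when it is exactly size_t wide.
bool lengthMatchesLibcall(unsigned LenBits, unsigned SizeTBits) {
  return LenBits == SizeTBits;
}
```

Under this model, a 64-bit length on a target with a 32-bit `size_t` still names a legitimate `llvm.memcpy` overload, while the same length passed to a `memcmp` libcall fails the width check.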

Collaborator Author

> The reason getSizeTSize() unconditionally queries addrspace(0) is just that there aren't any in-tree targets where it needs to return anything else. Fixes welcome.

Honestly, I don't think fixing getSizeTSize() makes sense. It's just a wrapper around DataLayout::getPointerSizeInBits, which is already more widely accessible throughout the middle-end than TargetLibraryInfo is.

If we really want the convenience name, then I think we should move getSizeTSize to DataLayout, and add an addrspace argument such that it is a pure wrapper around getPointerSizeInBits.

Collaborator

getSizeTSize() is supposed to be the bitwidth of the type size_t. There's only one size_t. This is useful because we're talking about C library functions, which are defined in terms of size_t.

If you want to query information about a specific address-space, we have appropriate datalayout APIs.

Collaborator Author

Would you be amenable to moving getSizeTSize() onto DataLayout? That would make it much easier to update all callers of CreateMemCpy.

I'll look at ensuring getSizeTSize() is correct for CHERI targets in a separate PR.

Contributor

I don't think getSizeTSize belongs on the data layout. It's a property related to the C library, rather than the target itself.

For CHERI's purposes, shouldn't it be sufficient to consistently make use of getSizeTSize() in this file and patch getSizeTSize() to return the correct value? Maybe just making the default value be the index type size instead of the pointer type size would be sufficient?

The BuildLibCalls code itself already uses the correct types, we're just not passing in matching types here.
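
The difference between the pointer type size and the index type size can be sketched with a toy layout table. Everything below is hypothetical and is not an LLVM API. The concrete numbers (address space 200 with 128-bit pointers and 64-bit indices) are illustrative assumptions for a capability-style target, not values taken from this PR.

```cpp
#include <cassert>
#include <map>

// Hypothetical stand-in for the per-address-space part of a data layout.
struct PointerSpec {
  unsigned PointerBits; // full width of a pointer value
  unsigned IndexBits;   // width used for address arithmetic and sizes
};

struct ToyLayout {
  std::map<unsigned, PointerSpec> Specs;

  const PointerSpec &spec(unsigned AS) const {
    auto It = Specs.find(AS);
    assert(It != Specs.end() && "address space not described");
    return It->second;
  }
  unsigned pointerBits(unsigned AS) const { return spec(AS).PointerBits; }
  unsigned indexBits(unsigned AS) const { return spec(AS).IndexBits; }
};

// Assumed, illustrative values: a conventional 64-bit address space 0 and a
// capability-style address space 200 whose pointers are wider than their
// addresses.
inline ToyLayout makeCapabilityLikeLayout() {
  ToyLayout L;
  L.Specs[0] = PointerSpec{64, 64};
  L.Specs[200] = PointerSpec{128, 64};
  return L;
}
```

In this model, deriving a `size_t` width from `pointerBits(200)` would give 128, while `indexBits(200)` gives 64, which is why defaulting to the index width is the more plausible choice for such a target.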

Collaborator Author

> For CHERI's purposes, shouldn't it be sufficient to consistently make use of getSizeTSize() in this file and patch getSizeTSize() to return the correct value? Maybe just making the default value be the index type size instead of the pointer type size would be sufficient?

I've split such a change into #118747
There seems to be at least one test that concretely tests for the opposite behavior today, so we'll need to sort that out. If we can make that change, then most of the changes in this PR become irrelevant.

@@ -16,7 +16,7 @@ define arm_aapcscc void @test_simplify1() {


 call arm_aapcscc ptr @strcpy(ptr @a, ptr @hello)
-; CHECK: @llvm.memcpy.p0.p0.i32
+; CHECK: @llvm.memcpy.p0.p0.i64
Contributor

I'm a bit confused by these changes -- this test uses a p:32:32 data layout, so wasn't the previous i32 overload the correct one?

Collaborator

The CreateMemCpy overload in question just calls getInt64().

I'm not sure it's much of a practical issue: backend legalization will truncate the integer to the appropriate size. But maybe we should canonicalize somewhere.
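
The arithmetic behind "truncation is harmless here" can be checked with a minimal sketch. It models only the integer truncation itself, under the assumption that the length being copied already fits in the narrower type. It does not model what backend legalization actually does, and both helper names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Keep only the low Bits bits of V, as an integer truncation would.
std::uint64_t truncateToBits(std::uint64_t V, unsigned Bits) {
  assert(Bits >= 1 && Bits <= 64 && "unsupported width");
  if (Bits == 64)
    return V;
  return V & ((std::uint64_t{1} << Bits) - 1);
}

// True when V survives truncation to Bits unchanged.
bool fitsInBits(std::uint64_t V, unsigned Bits) {
  return truncateToBits(V, Bits) == V;
}
```

A small constant length such as a string size passes `fitsInBits(V, 32)`, so widening it to i64 and truncating it back yields the same value. A value of 2^32 or more would not.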

-    Value *GV = B.CreateGlobalString(FormatStr, "str");
+    Value *GV = B.CreateGlobalString(
+        FormatStr, "str",
+        CI->getArgOperand(0)->getType()->getPointerAddressSpace());
Collaborator

If you're going to touch CreateGlobalString, please fix all the calls in this file at the same time.

On targets with multiple address-spaces, is it actually legal to create a global like this? I thought the usual convention was to create a global in some specific address-space, then addrspacecast it to the correct address-space.

Contributor

Generally we should try to use an address space taken from context. For a newly synthesized global, there's DL.getDefaultGlobalsAddressSpace().

Collaborator Author

In this specific line we're replacing one global string with another, so it makes sense to use the original address space. I'll look into other uses within this file.

Collaborator

getConstantStringInfo looks through addrspacecasts, so CI->getArgOperand(0)->getType()->getPointerAddressSpace() isn't necessarily the "original" address-space.
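
The point about looking through casts can be pictured with a toy value chain. This is a hypothetical miniature, not LLVM's `Value` hierarchy: a value is either a global in some address space or an address-space cast of another value, and `stripCasts` stands in for any analysis that ignores such casts when locating the underlying string.

```cpp
#include <cassert>

// Hypothetical miniature of an IR value: either a global living in some
// address space, or an address-space cast of another value.
struct ToyValue {
  unsigned AddrSpace;     // address space of this value's pointer type
  const ToyValue *CastOf; // non-null when this value is a cast
};

// Build a cast of Src into DestAS; a cast that keeps the address space
// unchanged is rejected.
ToyValue makeCast(const ToyValue &Src, unsigned DestAS) {
  assert(Src.AddrSpace != DestAS && "a cast must change the address space");
  return ToyValue{DestAS, &Src};
}

// Walk through casts to the underlying object, the way a string-constant
// lookup that ignores casts would.
const ToyValue &stripCasts(const ToyValue &V) {
  const ToyValue *Cur = &V;
  while (Cur->CastOf)
    Cur = Cur->CastOf;
  return *Cur;
}
```

If the call operand is a cast into address space 0 of a global that lives in address space 4, the operand's type reports 0 while the stripped global reports 4, so the operand type alone does not identify where the original string was placed.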

Collaborator Author

Alright, let's just use the default globals address space. I'm going to split this component into a separate PR.

Collaborator Author

Split off into #118729

Writing the test for this exposed more places where we were failing to propagate address spaces correctly, which are addressed in that PR as well.

@resistor resistor closed this Dec 9, 2024
@resistor
Collaborator Author

Tagging @jrtc27 and @arichardson for awareness on CHERI-derived PRs

Labels
llvm:instcombine Covers the InstCombine, InstSimplify and AggressiveInstCombine passes llvm:transforms
5 participants