[CodeGen][ARM64EC] Add support for hybrid_patchable attribute. #92965

Merged
merged 1 commit into from
Jul 19, 2024

@cjacek cjacek commented May 21, 2024

This PR implements the LLVM part of hybrid_patchable support. A prototype of the clang part is here: cjacek@e4b1ffc. The attribute is mentioned in the official documentation: https://learn.microsoft.com/en-us/windows/arm/arm64ec-abi. I described in more detail how patchable functions work here: https://wiki.winehq.org/ARM64ECToolchain#Patchable_functions.

llvmbot commented May 21, 2024

@llvm/pr-subscribers-backend-aarch64
@llvm/pr-subscribers-llvm-ir

@llvm/pr-subscribers-llvm-transforms

Author: Jacek Caban (cjacek)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/92965.diff

8 Files Affected:

  • (modified) llvm/include/llvm/Bitcode/LLVMBitCodes.h (+1)
  • (modified) llvm/include/llvm/IR/Attributes.td (+3)
  • (modified) llvm/lib/Bitcode/Writer/BitcodeWriter.cpp (+2)
  • (modified) llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp (+115-2)
  • (modified) llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp (+12-4)
  • (modified) llvm/lib/Target/AArch64/AArch64CallingConvention.td (+1-1)
  • (modified) llvm/lib/Transforms/Utils/CodeExtractor.cpp (+1)
  • (added) llvm/test/CodeGen/AArch64/arm64ec-hybrid-patchable.ll (+77)
diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
index 909eb833c601a..1e6a1cbc856a7 100644
--- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h
+++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
@@ -744,6 +744,7 @@ enum AttributeKindCodes {
   ATTR_KIND_CORO_ONLY_DESTROY_WHEN_COMPLETE = 90,
   ATTR_KIND_DEAD_ON_UNWIND = 91,
   ATTR_KIND_RANGE = 92,
+  ATTR_KIND_HYBRID_PATCHABLE = 93,
 };
 
 enum ComdatSelectionKindCodes {
diff --git a/llvm/include/llvm/IR/Attributes.td b/llvm/include/llvm/IR/Attributes.td
index cef8b17769f0d..3eebd6d018730 100644
--- a/llvm/include/llvm/IR/Attributes.td
+++ b/llvm/include/llvm/IR/Attributes.td
@@ -109,6 +109,9 @@ def ElementType : TypeAttr<"elementtype", [ParamAttr]>;
 /// symbol.
 def FnRetThunkExtern : EnumAttr<"fn_ret_thunk_extern", [FnAttr]>;
 
+/// Function has a hybrid patchable thunk.
+def HybridPatchable : EnumAttr<"hybrid_patchable", [FnAttr]>;
+
 /// Pass structure in an alloca.
 def InAlloca : TypeAttr<"inalloca", [ParamAttr]>;
 
diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
index c4cea3d6eef2d..0816278ad040d 100644
--- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
+++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
@@ -707,6 +707,8 @@ static uint64_t getAttrKindEncoding(Attribute::AttrKind Kind) {
     return bitc::ATTR_KIND_HOT;
   case Attribute::ElementType:
     return bitc::ATTR_KIND_ELEMENTTYPE;
+  case Attribute::HybridPatchable:
+    return bitc::ATTR_KIND_HYBRID_PATCHABLE;
   case Attribute::InlineHint:
     return bitc::ATTR_KIND_INLINE_HINT;
   case Attribute::InReg:
diff --git a/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp b/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp
index 0ec15d34cd4a9..a8dd376e7416d 100644
--- a/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp
@@ -57,6 +57,7 @@ class AArch64Arm64ECCallLowering : public ModulePass {
   Function *buildEntryThunk(Function *F);
   void lowerCall(CallBase *CB);
   Function *buildGuestExitThunk(Function *F);
+  Function *buildPatchableThunk(Function *F);
   bool processFunction(Function &F, SetVector<Function *> &DirectCalledFns);
   bool runOnModule(Module &M) override;
 
@@ -64,8 +65,11 @@ class AArch64Arm64ECCallLowering : public ModulePass {
   int cfguard_module_flag = 0;
   FunctionType *GuardFnType = nullptr;
   PointerType *GuardFnPtrType = nullptr;
+  FunctionType *DispatchFnType = nullptr;
+  PointerType *DispatchFnPtrType = nullptr;
   Constant *GuardFnCFGlobal = nullptr;
   Constant *GuardFnGlobal = nullptr;
+  Constant *DispatchFnGlobal = nullptr;
   Module *M = nullptr;
 
   Type *PtrTy;
@@ -615,6 +619,78 @@ Function *AArch64Arm64ECCallLowering::buildGuestExitThunk(Function *F) {
   return GuestExit;
 }
 
+Function *AArch64Arm64ECCallLowering::buildPatchableThunk(Function *F) {
+  llvm::raw_null_ostream NullThunkName;
+  FunctionType *Arm64Ty, *X64Ty;
+  getThunkType(F->getFunctionType(), F->getAttributes(),
+               Arm64ECThunkType::GuestExit, NullThunkName, Arm64Ty, X64Ty);
+  auto MangledName = getArm64ECMangledFunctionName(F->getName().str());
+  assert(MangledName && "Can't guest exit to function that's already native");
+  std::string ThunkName = *MangledName;
+  if (ThunkName[0] == '?' && ThunkName.find("@") != std::string::npos) {
+    ThunkName.insert(ThunkName.find("@"), "$hybpatch_thunk");
+  } else {
+    ThunkName.append("$hybpatch_thunk");
+  }
+
+  Function *GuestExit =
+      Function::Create(Arm64Ty, GlobalValue::WeakODRLinkage, 0, ThunkName, M);
+  GuestExit->setComdat(M->getOrInsertComdat(ThunkName));
+  GuestExit->setSection(".wowthk$aa");
+  GuestExit->setMetadata(
+      "arm64ec_unmangled_name",
+      MDNode::get(M->getContext(),
+                  MDString::get(M->getContext(), F->getName())));
+  GuestExit->setMetadata(
+      "arm64ec_ecmangled_name",
+      MDNode::get(M->getContext(),
+                  MDString::get(M->getContext(), *MangledName)));
+  GuestExit->setMetadata(
+      "arm64ec_exp_name",
+      MDNode::get(M->getContext(),
+                  MDString::get(M->getContext(), "EXP+" + *MangledName)));
+  F->setMetadata("arm64ec_hasguestexit", MDNode::get(M->getContext(), {}));
+  BasicBlock *BB = BasicBlock::Create(M->getContext(), "", GuestExit);
+  IRBuilder<> B(BB);
+
+  // Load the global symbol as a pointer to the check function.
+  LoadInst *DispatchLoad = B.CreateLoad(DispatchFnPtrType, DispatchFnGlobal);
+  Value *TargetFn =
+      M->getOrInsertFunction(*MangledName + "$hp_target", F->getFunctionType())
+          .getCallee();
+
+  // Create new dispatch call instruction.
+  Function *Thunk = buildExitThunk(F->getFunctionType(), F->getAttributes());
+  CallInst *Dispatch = B.CreateCall(DispatchFnType, DispatchLoad,
+                                    {B.CreateBitCast(F, B.getPtrTy()),
+                                     B.CreateBitCast(Thunk, B.getPtrTy()),
+                                     B.CreateBitCast(TargetFn, B.getPtrTy())});
+
+  // Ensure that the first arguments are passed in the correct registers.
+  Dispatch->setCallingConv(CallingConv::CFGuard_Check);
+
+  Value *DispatchRetVal = B.CreateBitCast(Dispatch, PtrTy);
+  SmallVector<Value *> Args;
+  for (Argument &Arg : GuestExit->args())
+    Args.push_back(&Arg);
+  CallInst *Call = B.CreateCall(Arm64Ty, DispatchRetVal, Args);
+  Call->setTailCallKind(llvm::CallInst::TCK_MustTail);
+
+  if (Call->getType()->isVoidTy())
+    B.CreateRetVoid();
+  else
+    B.CreateRet(Call);
+
+  auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet);
+  auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg);
+  if (SRetAttr.isValid() && !InRegAttr.isValid()) {
+    GuestExit->addParamAttr(0, SRetAttr);
+    Call->addParamAttr(0, SRetAttr);
+  }
+
+  return GuestExit;
+}
+
 // Lower an indirect call with inline code.
 void AArch64Arm64ECCallLowering::lowerCall(CallBase *CB) {
   assert(Triple(CB->getModule()->getTargetTriple()).isOSWindows() &&
@@ -670,10 +746,40 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
 
   GuardFnType = FunctionType::get(PtrTy, {PtrTy, PtrTy}, false);
   GuardFnPtrType = PointerType::get(GuardFnType, 0);
+  DispatchFnType = FunctionType::get(PtrTy, {PtrTy, PtrTy, PtrTy}, false);
+  DispatchFnPtrType = PointerType::get(DispatchFnType, 0);
   GuardFnCFGlobal =
       M->getOrInsertGlobal("__os_arm64x_check_icall_cfg", GuardFnPtrType);
   GuardFnGlobal =
       M->getOrInsertGlobal("__os_arm64x_check_icall", GuardFnPtrType);
+  DispatchFnGlobal =
+      M->getOrInsertGlobal("__os_arm64x_dispatch_call", DispatchFnPtrType);
+
+  // Rename hybrid patchable functions and change callers to use an external
+  // linkage function call instead.
+  SetVector<Function *> PatchableFns;
+  for (Function &F : Mod) {
+    if (!F.hasFnAttribute(Attribute::HybridPatchable) ||
+        F.getName().ends_with("$hp_target"))
+      continue;
+
+    if (F.isDeclaration() || F.hasLocalLinkage()) {
+      F.removeFnAttr(Attribute::HybridPatchable);
+      continue;
+    }
+
+    if (std::optional<std::string> MangledName =
+            getArm64ECMangledFunctionName(F.getName().str())) {
+      std::string OrigName(F.getName());
+      F.setName(MangledName.value() + "$hp_target");
+
+      Function *EF = Function::Create(
+          F.getFunctionType(), GlobalValue::ExternalLinkage, 0, OrigName, M);
+      EF->copyAttributesFrom(&F);
+      F.replaceAllUsesWith(EF);
+      PatchableFns.insert(EF);
+    }
+  }
 
   SetVector<Function *> DirectCalledFns;
   for (Function &F : Mod)
@@ -702,10 +808,15 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
     ThunkMapping.push_back(
         {F, buildExitThunk(F->getFunctionType(), F->getAttributes()),
          Arm64ECThunkType::Exit});
+    assert(!F->hasFnAttribute(Attribute::HybridPatchable));
     if (!F->hasDLLImportStorageClass())
       ThunkMapping.push_back(
           {buildGuestExitThunk(F), F, Arm64ECThunkType::GuestExit});
   }
+  for (Function *F : PatchableFns) {
+    Function *Thunk = buildPatchableThunk(F);
+    ThunkMapping.push_back({Thunk, F, Arm64ECThunkType::GuestExit});
+  }
 
   if (!ThunkMapping.empty()) {
     SmallVector<Constant *> ThunkMappingArrayElems;
@@ -738,7 +849,8 @@ bool AArch64Arm64ECCallLowering::processFunction(
   // name (emitting the definition) can grab it from the metadata.
   //
   // FIXME: Handle functions with weak linkage?
-  if (!F.hasLocalLinkage() || F.hasAddressTaken()) {
+  if ((!F.hasLocalLinkage() || F.hasAddressTaken()) &&
+      !F.hasFnAttribute(Attribute::HybridPatchable)) {
     if (std::optional<std::string> MangledName =
             getArm64ECMangledFunctionName(F.getName().str())) {
       F.setMetadata("arm64ec_unmangled_name",
@@ -773,7 +885,8 @@ bool AArch64Arm64ECCallLowering::processFunction(
       // unprototyped functions in C)
       if (Function *F = CB->getCalledFunction()) {
         if (!LowerDirectToIndirect || F->hasLocalLinkage() ||
-            F->isIntrinsic() || !F->isDeclaration())
+            F->isIntrinsic() || !F->isDeclaration() ||
+            F->hasFnAttribute(Attribute::HybridPatchable))
           continue;
 
         DirectCalledFns.insert(F);
diff --git a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
index bdc3fc630a4e3..d4b88a35a8286 100644
--- a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
+++ b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
@@ -1167,8 +1167,9 @@ void AArch64AsmPrinter::emitFunctionEntryLabel() {
       !MF->getFunction().hasLocalLinkage()) {
     // For ARM64EC targets, a function definition's name is mangled differently
     // from the normal symbol, emit required aliases here.
-    auto emitFunctionAlias = [&](MCSymbol *Src, MCSymbol *Dst) {
-      OutStreamer->emitSymbolAttribute(Src, MCSA_WeakAntiDep);
+    auto emitFunctionAlias = [&](MCSymbol *Src, MCSymbol *Dst,
+                                 MCSymbolAttr Attr = MCSA_WeakAntiDep) {
+      OutStreamer->emitSymbolAttribute(Src, Attr);
       OutStreamer->emitAssignment(
           Src, MCSymbolRefExpr::create(Dst, MCSymbolRefExpr::VK_None,
                                        MMI->getContext()));
@@ -1186,8 +1187,15 @@ void AArch64AsmPrinter::emitFunctionEntryLabel() {
     if (MCSymbol *UnmangledSym =
             getSymbolFromMetadata("arm64ec_unmangled_name")) {
       MCSymbol *ECMangledSym = getSymbolFromMetadata("arm64ec_ecmangled_name");
-
-      if (ECMangledSym) {
+      MCSymbol *ExpSym = getSymbolFromMetadata("arm64ec_exp_name");
+
+      if (ExpSym) {
+        // A hybrid patchable function, emit the alias from the unmangled
+        // symbol to x64 thunk and the alias from the mangled symbol to
+        // patchable guest exit thunk.
+        emitFunctionAlias(ECMangledSym, CurrentFnSym, MCSA_Weak);
+        emitFunctionAlias(UnmangledSym, ExpSym, MCSA_Weak);
+      } else if (ECMangledSym) {
         // An external function, emit the alias from the unmangled symbol to
         // mangled symbol name and the alias from the mangled symbol to guest
         // exit thunk.
diff --git a/llvm/lib/Target/AArch64/AArch64CallingConvention.td b/llvm/lib/Target/AArch64/AArch64CallingConvention.td
index 8e67f0f5c8815..5061605364c21 100644
--- a/llvm/lib/Target/AArch64/AArch64CallingConvention.td
+++ b/llvm/lib/Target/AArch64/AArch64CallingConvention.td
@@ -333,7 +333,7 @@ def CC_AArch64_Win64_CFGuard_Check : CallingConv<[
 
 let Entry = 1 in
 def CC_AArch64_Arm64EC_CFGuard_Check : CallingConv<[
-  CCIfType<[i64], CCAssignToReg<[X11, X10]>>
+  CCIfType<[i64], CCAssignToReg<[X11, X10, X9]>>
 ]>;
 
 let Entry = 1 in
diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
index f2672b8e9118f..e49309014ca71 100644
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -932,6 +932,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
       case Attribute::DisableSanitizerInstrumentation:
       case Attribute::FnRetThunkExtern:
       case Attribute::Hot:
+      case Attribute::HybridPatchable:
       case Attribute::NoRecurse:
       case Attribute::InlineHint:
       case Attribute::MinSize:
diff --git a/llvm/test/CodeGen/AArch64/arm64ec-hybrid-patchable.ll b/llvm/test/CodeGen/AArch64/arm64ec-hybrid-patchable.ll
new file mode 100644
index 0000000000000..5b1cac173a82a
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/arm64ec-hybrid-patchable.ll
@@ -0,0 +1,77 @@
+; RUN: llc -mtriple=arm64ec-pc-windows-msvc < %s | FileCheck %s
+
+define dso_local i32 @func() hybrid_patchable nounwind {
+; CHECK-LABEL:     .def    "#func$hp_target";
+; CHECK:           .section        .text,"xr",discard,"#func$hp_target"
+; CHECK-NEXT:      .globl  "#func$hp_target"               // -- Begin function #func$hp_target
+; CHECK-NEXT:      .p2align        2
+; CHECK-NEXT:  "#func$hp_target":                      // @"#func$hp_target"
+; CHECK-NEXT:      // %bb.0:
+; CHECK-NEXT:      mov     w0, #1                          // =0x1
+; CHECK-NEXT:      ret
+  ret i32 1
+}
+
+; hybrid_patchable attribute is ignored on internal functions
+define internal i32 @static_func() hybrid_patchable nounwind {
+; CHECK-LABEL:     .def    static_func;
+; CHECK:       static_func:                            // @static_func
+; CHECK-NEXT:      // %bb.0:
+; CHECK-NEXT:      mov     w0, #2                          // =0x2
+; CHECK-NEXT:      ret
+  ret i32 2
+}
+
+define dso_local void @caller() nounwind {
+; CHECK-LABEL:     .def    "#caller";
+; CHECK:           .section        .text,"xr",discard,"#caller"
+; CHECK-NEXT:      .globl  "#caller"                       // -- Begin function #caller
+; CHECK-NEXT:      .p2align        2
+; CHECK-NEXT:  "#caller":                              // @"#caller"
+; CHECK-NEXT:      .weak_anti_dep  caller
+; CHECK-NEXT:  .set caller, "#caller"{{$}}
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:      str     x30, [sp, #-16]!                // 8-byte Folded Spill
+; CHECK-NEXT:      bl      "#func"
+; CHECK-NEXT:      bl      static_func
+; CHECK-NEXT:      ldr     x30, [sp], #16                  // 8-byte Folded Reload
+; CHECK-NEXT:      ret
+  %1 = call i32 @func()
+  %2 = call i32 @static_func()
+  ret void
+}
+
+; CHECK: .def    $ientry_thunk$cdecl$i8$v;
+; CHECK: .def    $ientry_thunk$cdecl$v$v;
+; CHECK: .def    $iexit_thunk$cdecl$i8$v;
+
+; CHECK-LABEL:       def    "#func$hybpatch_thunk";
+; CHECK:            .section        .wowthk$aa,"xr",discard,"#func$hybpatch_thunk"
+; CHECK-NEXT:       .globl  "#func$hybpatch_thunk"          // -- Begin function #func$hybpatch_thunk
+; CHECK-NEXT:       .p2align        2
+; CHECK-NEXT:   "#func$hybpatch_thunk":                 // @"#func$hybpatch_thunk"
+; CHECK-NEXT:       .weak  "#func"
+; CHECK-NEXT:   .set "#func", "#func$hybpatch_thunk"{{$}}
+; CHECK-NEXT:       .weak  func
+; CHECK-NEXT:   .set func, "EXP+#func"{{$}}
+; CHECK-NEXT:   .seh_proc "#func$hybpatch_thunk"
+; CHECK-NEXT:   // %bb.0:
+; CHECK-NEXT:       str     x30, [sp, #-16]!                // 8-byte Folded Spill
+; CHECK-NEXT:       .seh_save_reg_x x30, 16
+; CHECK-NEXT:       .seh_endprologue
+; CHECK-NEXT:       adrp    x8, __os_arm64x_dispatch_call
+; CHECK-NEXT:       adrp    x11, func
+; CHECK-NEXT:       add     x11, x11, :lo12:func
+; CHECK-NEXT:       ldr     x8, [x8, :lo12:__os_arm64x_dispatch_call]
+; CHECK-NEXT:       adrp    x10, ($iexit_thunk$cdecl$i8$v)
+; CHECK-NEXT:       add     x10, x10, :lo12:($iexit_thunk$cdecl$i8$v)
+; CHECK-NEXT:       adrp    x9, "#func$hp_target"
+; CHECK-NEXT:       add     x9, x9, :lo12:"#func$hp_target"
+; CHECK-NEXT:       blr     x8
+; CHECK-NEXT:       .seh_startepilogue
+; CHECK-NEXT:       ldr     x30, [sp], #16                  // 8-byte Folded Reload
+; CHECK-NEXT:       .seh_save_reg_x x30, 16
+; CHECK-NEXT:       .seh_endepilogue
+; CHECK-NEXT:       br      x11
+; CHECK-NEXT:       .seh_endfunclet
+; CHECK-NEXT:       .seh_endproc
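
The symbol naming used by `buildPatchableThunk` and `runOnModule` in the diff above can be tried out without LLVM. The following is a minimal sketch that mirrors only the string handling; the helper names `hybridPatchThunkName` and `hybridPatchTargetName` are hypothetical and do not exist in LLVM, and the input is assumed to be a name that has already been ARM64EC-mangled (in the patch this comes from `getArm64ECMangledFunctionName`).

```cpp
#include <cassert>
#include <cstddef>
#include <string>

// Hypothetical helper (not an LLVM API): derive the thunk symbol name from
// an already ARM64EC-mangled name, following the logic in the diff.
std::string hybridPatchThunkName(const std::string &mangled) {
  const std::string suffix = "$hybpatch_thunk";
  std::string name = mangled;
  const std::size_t at = name.find('@');
  if (!name.empty() && name[0] == '?' && at != std::string::npos) {
    // MSVC-style C++ name: the suffix goes right before the first '@'.
    name.insert(at, suffix);
  } else {
    // C-style name such as "#func": the suffix is simply appended.
    name.append(suffix);
  }
  return name;
}

// Hypothetical helper (not an LLVM API): the name the original function
// body is renamed to before the thunk takes over the mangled symbol.
std::string hybridPatchTargetName(const std::string &mangled) {
  return mangled + "$hp_target";
}
```

For the `#func` case these helpers produce `#func$hybpatch_thunk` and `#func$hp_target`, which are the symbol names the FileCheck lines in `arm64ec-hybrid-patchable.ll` expect. The C++ branch is only exercised here with an illustrative name; the exact mangled form is left to `getArm64ECMangledFunctionName`.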

@efriedma-quic efriedma-quic left a comment

Maybe also add tests with varargs and sret to make sure the musttail call correctly preserves all the relevant registers.

std::string OrigName(F.getName());
F.setName(MangledName.value() + "$hp_target");

Function *EF = Function::Create(
Collaborator

I think I'd like to avoid the situation where we emit a declaration for the function at the IR level as if it isn't defined, but then we actually end up defining the symbol in the assembly. I'm not sure it really has any particular effect here: I think the symbol only ends up being used for the thunk table. But I'd prefer to reduce the gap between the IR and the actual generated code as much as possible.

Not sure if this actually works, but can we emit a GlobalAlias at the IR level? (It doesn't work for anti-dep symbols because we currently can't represent those, but I think maybe a normal weak alias lowers to the right thing?)

Contributor Author

GlobalAlias seems to work fine out of the box for the mangled alias (which points to the thunk). However, global aliases need to point to a function definition, while we need to emit a weak alias to an undefined symbol for the unmangled alias. I worked around it by overriding AsmPrinter::emitGlobalAlias and special-casing those aliases based on the aliasee's metadata. emitGlobalAlias was already a virtual function, but I needed to change it to protected in AsmPrinter to be able to fall back to it in the AArch64 code. This reduces the gap between the IR and the generated code, but does not eliminate it entirely. (FWIW, I think it would be possible to similarly define aliases for anti-dependencies if we want to.)

Collaborator

The current version seems okay. The whole EXP+ thing would require a significant IR extension to properly represent, which doesn't seem worth it. (We could theoretically define a Constant specifically to refer to the linker-generated hybrid_patchable thunk of a function, and allow aliases to refer to such a constant, but that seems like overkill if it would only be used in this one place.)

Maybe worth adding a comment in the code describing the intentional gap between a normal alias and what we're emitting here.

Not sure about anti-dep aliases; I still haven't quite wrapped my head around the semantics of antidep. I guess from the compiler's perspective, it might be close enough to a weak alias that we could just get away with setting the linkage to "weak", and adding some metadata to mark it as anti-dep?

@cjacek cjacek force-pushed the llvm-hybrid-patchable branch 2 times, most recently from a346bfd to 6c0308f Compare May 31, 2024 16:14

cjacek commented May 31, 2024

The new version fixes a cast in AsmPrinter. It also no longer iterates over the map to process patchable thunks, since the map did not guarantee a predictable ordering. I also rebased it, as CI seemed to have additional unrelated failures.
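
The ordering concern can be illustrated with a standalone sketch. It assumes a pointer-keyed map, where iteration order depends on addresses or hashing rather than on the order in which entries were added; pairing the map with an insertion-ordered vector keeps lookups cheap while making iteration reproducible. `Alias`, `PatchableAliasTable`, and `thunkEmissionOrder` are hypothetical names for illustration only and are not what the patch uses (LLVM itself provides `SetVector` and `MapVector` for this purpose).

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical stand-in for an IR object such as a GlobalAlias.
struct Alias {
  std::string name;
};

// Maps an unmangled alias to its mangled alias. Lookups go through the hash
// map; iteration goes through the vector, so it follows insertion order.
class PatchableAliasTable {
public:
  using Entry = std::pair<const Alias *, const Alias *>;

  // Records a pair; returns false if the unmangled alias is already present.
  bool insert(const Alias *unmangled, const Alias *mangled) {
    if (!index_.emplace(unmangled, order_.size()).second)
      return false;
    order_.emplace_back(unmangled, mangled);
    return true;
  }

  // Returns the mangled alias, or nullptr if the key was never inserted.
  const Alias *lookup(const Alias *unmangled) const {
    auto it = index_.find(unmangled);
    return it == index_.end() ? nullptr : order_[it->second].second;
  }

  const std::vector<Entry> &inOrder() const { return order_; }

private:
  std::unordered_map<const Alias *, std::size_t> index_;
  std::vector<Entry> order_;
};

// Thunk names in the order the aliases were registered, independent of
// pointer values.
std::vector<std::string> thunkEmissionOrder(const PatchableAliasTable &table) {
  std::vector<std::string> names;
  for (const auto &entry : table.inOrder())
    names.push_back(entry.second->name + "$hybpatch_thunk");
  return names;
}
```

With this layout, two runs that register the same functions in the same order emit thunks in the same order, which is what keeps compiler output and test expectations stable.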

@efriedma-quic
Collaborator

Please add a testcase for taking the address of a hybrid_patchable function. (Not a complete review; I'll look again next week.)

cjacek commented Jun 3, 2024

I added tests for taking the address. I also added a dllexport test. In that case the output looks weird, but it is compatible with MSVC. I think we could just as well choose not to be compatible here and make it work with something like cjacek@4febef2 on top of this PR.

Function *buildPatchableThunk(GlobalAlias *UnmangledAlias,
GlobalAlias *MangledAlias);
bool processFunction(Function &F, SetVector<GlobalValue *> &DirectCalledFns,
std::map<GlobalAlias *, GlobalAlias *> &FnsMap);
Collaborator

Can you use DenseMap here?

buildExitThunk(F->getFunctionType(), F->getAttributes());
CallInst *Dispatch = B.CreateCall(
DispatchFnType, DispatchLoad,
{B.CreateBitCast(UnmangledAlias, B.getPtrTy()),
Collaborator

These bitcasts aren't necessary anymore with opaque pointers; probably I should go through and clean them all up.

@@ -780,6 +889,17 @@ bool AArch64Arm64ECCallLowering::processFunction(
continue;
}

// Use mangled global alias for direct calls to patchable functions.
if (GlobalAlias *A =
dyn_cast_or_null<GlobalAlias>(CB->getCalledOperand())) {
Collaborator

getCalledOperand() never returns null.

dyn_cast_or_null<GlobalAlias>(CB->getCalledOperand())) {
auto I = FnsMap.find(A);
if (I != FnsMap.end()) {
CB->setCalledOperand(I->second);
Collaborator

It feels a little weird to rewrite all references to the function to aliases, then rewrite them again, but I guess it works.

cjacek commented Jun 17, 2024

Thanks for the review. The new version uses DenseMap, removes the redundant null check and bitcasts, and adds a comment explaining the gap.

BTW, for anti-dep aliases, I'm still unsure if I'm missing something, but the only general differences from weak aliases that I have spotted so far are that it's legal to define them multiple times and that, if a regular weak alias is present, it overrides the anti-deps. Other than that, there is a very specific case where the linker seems to use them to detect an unmangled -> mangled alias and apply some special-casing.
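
The two differences described above can be written down as a toy model. This is only a sketch of the behavior as stated in this comment, not a description of any real linker: `AliasKind`, `AliasDef`, and `resolveAlias` are hypothetical names, and picking the first anti-dep definition as the fallback is an assumption made purely to keep the example deterministic.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Toy model only: the two alias flavors discussed in this thread.
enum class AliasKind { Weak, WeakAntiDep };

struct AliasDef {
  AliasKind kind;
  std::string target;
};

// Picks the target for one symbol from all alias definitions seen for it.
// Multiple anti-dep definitions are accepted, and a regular weak alias
// overrides any anti-dep alias. Returns nullopt if there is no definition.
std::optional<std::string> resolveAlias(const std::vector<AliasDef> &defs) {
  std::optional<std::string> antiDepFallback;
  for (const AliasDef &def : defs) {
    if (def.kind == AliasKind::Weak)
      return def.target; // a regular weak alias wins over anti-deps
    if (!antiDepFallback)
      antiDepFallback = def.target; // assumption: first anti-dep is kept
  }
  return antiDepFallback;
}
```

In the patch, the hybrid patchable aliases are emitted as regular weak symbols (`MCSA_Weak`), while ordinary ARM64EC function aliases keep using `MCSA_WeakAntiDep`. The linker special-casing of unmangled -> mangled aliases mentioned above is not modeled here.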

@cjacek cjacek force-pushed the llvm-hybrid-patchable branch from 3b3739e to ad6a8b3 Compare June 18, 2024 15:04

cjacek commented Jun 18, 2024

I rebased to resolve the conflict.

@cjacek cjacek force-pushed the llvm-hybrid-patchable branch from ad6a8b3 to 2445632 Compare June 24, 2024 12:19

cjacek commented Jun 24, 2024

Rebased to resolve conflicts with 2c9c22c.

@efriedma-quic efriedma-quic left a comment

LGTM

(Sorry about the delay; forgot this was still waiting for me.)

@cjacek cjacek force-pushed the llvm-hybrid-patchable branch from 2445632 to 35123aa Compare July 18, 2024 22:07
@cjacek cjacek force-pushed the llvm-hybrid-patchable branch from 35123aa to f9b73eb Compare July 19, 2024 09:37
@cjacek cjacek merged commit 6cc8774 into llvm:main Jul 19, 2024
4 of 7 checks passed
@cjacek cjacek deleted the llvm-hybrid-patchable branch July 19, 2024 09:43
yuxuanchen1997 pushed a commit that referenced this pull request Jul 25, 2024
Differential Revision: https://phabricator.intern.facebook.com/D60251384