21
21
#include " llvm/ADT/SmallVector.h"
22
22
#include " llvm/ADT/Statistic.h"
23
23
#include " llvm/IR/CallingConv.h"
24
+ #include " llvm/IR/GlobalAlias.h"
24
25
#include " llvm/IR/IRBuilder.h"
25
26
#include " llvm/IR/Instruction.h"
26
27
#include " llvm/IR/Mangler.h"
@@ -57,15 +58,21 @@ class AArch64Arm64ECCallLowering : public ModulePass {
57
58
Function *buildEntryThunk (Function *F);
58
59
void lowerCall (CallBase *CB);
59
60
Function *buildGuestExitThunk (Function *F);
60
- bool processFunction (Function &F, SetVector<Function *> &DirectCalledFns);
61
+ Function *buildPatchableThunk (GlobalAlias *UnmangledAlias,
62
+ GlobalAlias *MangledAlias);
63
+ bool processFunction (Function &F, SetVector<GlobalValue *> &DirectCalledFns,
64
+ DenseMap<GlobalAlias *, GlobalAlias *> &FnsMap);
61
65
bool runOnModule (Module &M) override ;
62
66
63
67
private:
64
68
int cfguard_module_flag = 0 ;
65
69
FunctionType *GuardFnType = nullptr ;
66
70
PointerType *GuardFnPtrType = nullptr ;
71
+ FunctionType *DispatchFnType = nullptr ;
72
+ PointerType *DispatchFnPtrType = nullptr ;
67
73
Constant *GuardFnCFGlobal = nullptr ;
68
74
Constant *GuardFnGlobal = nullptr ;
75
+ Constant *DispatchFnGlobal = nullptr ;
69
76
Module *M = nullptr ;
70
77
71
78
Type *PtrTy;
@@ -615,6 +622,64 @@ Function *AArch64Arm64ECCallLowering::buildGuestExitThunk(Function *F) {
615
622
return GuestExit;
616
623
}
617
624
625
+ Function *
626
+ AArch64Arm64ECCallLowering::buildPatchableThunk (GlobalAlias *UnmangledAlias,
627
+ GlobalAlias *MangledAlias) {
628
+ llvm::raw_null_ostream NullThunkName;
629
+ FunctionType *Arm64Ty, *X64Ty;
630
+ Function *F = cast<Function>(MangledAlias->getAliasee ());
631
+ getThunkType (F->getFunctionType (), F->getAttributes (),
632
+ Arm64ECThunkType::GuestExit, NullThunkName, Arm64Ty, X64Ty);
633
+ std::string ThunkName (MangledAlias->getName ());
634
+ if (ThunkName[0 ] == ' ?' && ThunkName.find (" @" ) != std::string::npos) {
635
+ ThunkName.insert (ThunkName.find (" @" ), " $hybpatch_thunk" );
636
+ } else {
637
+ ThunkName.append (" $hybpatch_thunk" );
638
+ }
639
+
640
+ Function *GuestExit =
641
+ Function::Create (Arm64Ty, GlobalValue::WeakODRLinkage, 0 , ThunkName, M);
642
+ GuestExit->setComdat (M->getOrInsertComdat (ThunkName));
643
+ GuestExit->setSection (" .wowthk$aa" );
644
+ BasicBlock *BB = BasicBlock::Create (M->getContext (), " " , GuestExit);
645
+ IRBuilder<> B (BB);
646
+
647
+ // Load the global symbol as a pointer to the check function.
648
+ LoadInst *DispatchLoad = B.CreateLoad (DispatchFnPtrType, DispatchFnGlobal);
649
+
650
+ // Create new dispatch call instruction.
651
+ Function *ExitThunk =
652
+ buildExitThunk (F->getFunctionType (), F->getAttributes ());
653
+ CallInst *Dispatch =
654
+ B.CreateCall (DispatchFnType, DispatchLoad,
655
+ {UnmangledAlias, ExitThunk, UnmangledAlias->getAliasee ()});
656
+
657
+ // Ensure that the first arguments are passed in the correct registers.
658
+ Dispatch->setCallingConv (CallingConv::CFGuard_Check);
659
+
660
+ Value *DispatchRetVal = B.CreateBitCast (Dispatch, PtrTy);
661
+ SmallVector<Value *> Args;
662
+ for (Argument &Arg : GuestExit->args ())
663
+ Args.push_back (&Arg);
664
+ CallInst *Call = B.CreateCall (Arm64Ty, DispatchRetVal, Args);
665
+ Call->setTailCallKind (llvm::CallInst::TCK_MustTail);
666
+
667
+ if (Call->getType ()->isVoidTy ())
668
+ B.CreateRetVoid ();
669
+ else
670
+ B.CreateRet (Call);
671
+
672
+ auto SRetAttr = F->getAttributes ().getParamAttr (0 , Attribute::StructRet);
673
+ auto InRegAttr = F->getAttributes ().getParamAttr (0 , Attribute::InReg);
674
+ if (SRetAttr.isValid () && !InRegAttr.isValid ()) {
675
+ GuestExit->addParamAttr (0 , SRetAttr);
676
+ Call->addParamAttr (0 , SRetAttr);
677
+ }
678
+
679
+ MangledAlias->setAliasee (GuestExit);
680
+ return GuestExit;
681
+ }
682
+
618
683
// Lower an indirect call with inline code.
619
684
void AArch64Arm64ECCallLowering::lowerCall (CallBase *CB) {
620
685
assert (Triple (CB->getModule ()->getTargetTriple ()).isOSWindows () &&
@@ -670,17 +735,57 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
670
735
671
736
GuardFnType = FunctionType::get (PtrTy, {PtrTy, PtrTy}, false );
672
737
GuardFnPtrType = PointerType::get (GuardFnType, 0 );
738
+ DispatchFnType = FunctionType::get (PtrTy, {PtrTy, PtrTy, PtrTy}, false );
739
+ DispatchFnPtrType = PointerType::get (DispatchFnType, 0 );
673
740
GuardFnCFGlobal =
674
741
M->getOrInsertGlobal (" __os_arm64x_check_icall_cfg" , GuardFnPtrType);
675
742
GuardFnGlobal =
676
743
M->getOrInsertGlobal (" __os_arm64x_check_icall" , GuardFnPtrType);
744
+ DispatchFnGlobal =
745
+ M->getOrInsertGlobal (" __os_arm64x_dispatch_call" , DispatchFnPtrType);
746
+
747
+ DenseMap<GlobalAlias *, GlobalAlias *> FnsMap;
748
+ SetVector<GlobalAlias *> PatchableFns;
677
749
678
- SetVector<Function *> DirectCalledFns;
750
+ for (Function &F : Mod) {
751
+ if (!F.hasFnAttribute (Attribute::HybridPatchable) || F.isDeclaration () ||
752
+ F.hasLocalLinkage () || F.getName ().ends_with (" $hp_target" ))
753
+ continue ;
754
+
755
+ // Rename hybrid patchable functions and change callers to use a global
756
+ // alias instead.
757
+ if (std::optional<std::string> MangledName =
758
+ getArm64ECMangledFunctionName (F.getName ().str ())) {
759
+ std::string OrigName (F.getName ());
760
+ F.setName (MangledName.value () + " $hp_target" );
761
+
762
+ // The unmangled symbol is a weak alias to an undefined symbol with the
763
+ // "EXP+" prefix. This undefined symbol is resolved by the linker by
764
+ // creating an x86 thunk that jumps back to the actual EC target. Since we
765
+ // can't represent that in IR, we create an alias to the target instead.
766
+ // The "EXP+" symbol is set as metadata, which is then used by
767
+ // emitGlobalAlias to emit the right alias.
768
+ auto *A =
769
+ GlobalAlias::create (GlobalValue::LinkOnceODRLinkage, OrigName, &F);
770
+ F.replaceAllUsesWith (A);
771
+ F.setMetadata (" arm64ec_exp_name" ,
772
+ MDNode::get (M->getContext (),
773
+ MDString::get (M->getContext (),
774
+ " EXP+" + MangledName.value ())));
775
+ A->setAliasee (&F);
776
+
777
+ FnsMap[A] = GlobalAlias::create (GlobalValue::LinkOnceODRLinkage,
778
+ MangledName.value (), &F);
779
+ PatchableFns.insert (A);
780
+ }
781
+ }
782
+
783
+ SetVector<GlobalValue *> DirectCalledFns;
679
784
for (Function &F : Mod)
680
785
if (!F.isDeclaration () &&
681
786
F.getCallingConv () != CallingConv::ARM64EC_Thunk_Native &&
682
787
F.getCallingConv () != CallingConv::ARM64EC_Thunk_X64)
683
- processFunction (F, DirectCalledFns);
788
+ processFunction (F, DirectCalledFns, FnsMap );
684
789
685
790
struct ThunkInfo {
686
791
Constant *Src;
@@ -698,14 +803,20 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
698
803
{&F, buildEntryThunk (&F), Arm64ECThunkType::Entry});
699
804
}
700
805
}
701
- for (Function *F : DirectCalledFns) {
806
+ for (GlobalValue *O : DirectCalledFns) {
807
+ auto GA = dyn_cast<GlobalAlias>(O);
808
+ auto F = dyn_cast<Function>(GA ? GA->getAliasee () : O);
702
809
ThunkMapping.push_back (
703
- {F , buildExitThunk (F->getFunctionType (), F->getAttributes ()),
810
+ {O , buildExitThunk (F->getFunctionType (), F->getAttributes ()),
704
811
Arm64ECThunkType::Exit});
705
- if (!F->hasDLLImportStorageClass ())
812
+ if (!GA && ! F->hasDLLImportStorageClass ())
706
813
ThunkMapping.push_back (
707
814
{buildGuestExitThunk (F), F, Arm64ECThunkType::GuestExit});
708
815
}
816
+ for (GlobalAlias *A : PatchableFns) {
817
+ Function *Thunk = buildPatchableThunk (A, FnsMap[A]);
818
+ ThunkMapping.push_back ({Thunk, A, Arm64ECThunkType::GuestExit});
819
+ }
709
820
710
821
if (!ThunkMapping.empty ()) {
711
822
SmallVector<Constant *> ThunkMappingArrayElems;
@@ -728,7 +839,8 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
728
839
}
729
840
730
841
bool AArch64Arm64ECCallLowering::processFunction (
731
- Function &F, SetVector<Function *> &DirectCalledFns) {
842
+ Function &F, SetVector<GlobalValue *> &DirectCalledFns,
843
+ DenseMap<GlobalAlias *, GlobalAlias *> &FnsMap) {
732
844
SmallVector<CallBase *, 8 > IndirectCalls;
733
845
734
846
// For ARM64EC targets, a function definition's name is mangled differently
@@ -780,6 +892,16 @@ bool AArch64Arm64ECCallLowering::processFunction(
780
892
continue ;
781
893
}
782
894
895
+ // Use mangled global alias for direct calls to patchable functions.
896
+ if (GlobalAlias *A = dyn_cast<GlobalAlias>(CB->getCalledOperand ())) {
897
+ auto I = FnsMap.find (A);
898
+ if (I != FnsMap.end ()) {
899
+ CB->setCalledOperand (I->second );
900
+ DirectCalledFns.insert (I->first );
901
+ continue ;
902
+ }
903
+ }
904
+
783
905
IndirectCalls.push_back (CB);
784
906
++Arm64ECCallsLowered;
785
907
}
0 commit comments