21
21
#include " llvm/ADT/SmallVector.h"
22
22
#include " llvm/ADT/Statistic.h"
23
23
#include " llvm/IR/CallingConv.h"
24
+ #include " llvm/IR/GlobalAlias.h"
24
25
#include " llvm/IR/IRBuilder.h"
25
26
#include " llvm/IR/Instruction.h"
26
27
#include " llvm/IR/Mangler.h"
29
30
#include " llvm/Pass.h"
30
31
#include " llvm/Support/CommandLine.h"
31
32
#include " llvm/TargetParser/Triple.h"
33
+ #include < map>
32
34
33
35
using namespace llvm ;
34
36
using namespace llvm ::COFF;
@@ -57,15 +59,21 @@ class AArch64Arm64ECCallLowering : public ModulePass {
57
59
Function *buildEntryThunk (Function *F);
58
60
void lowerCall (CallBase *CB);
59
61
Function *buildGuestExitThunk (Function *F);
60
- bool processFunction (Function &F, SetVector<Function *> &DirectCalledFns);
62
+ Function *buildPatchableThunk (GlobalAlias *UnmangledAlias,
63
+ GlobalAlias *MangledAlias);
64
+ bool processFunction (Function &F, SetVector<GlobalValue *> &DirectCalledFns,
65
+ std::map<GlobalAlias *, GlobalAlias *> &FnsMap);
61
66
bool runOnModule (Module &M) override ;
62
67
63
68
private:
64
69
int cfguard_module_flag = 0 ;
65
70
FunctionType *GuardFnType = nullptr ;
66
71
PointerType *GuardFnPtrType = nullptr ;
72
+ FunctionType *DispatchFnType = nullptr ;
73
+ PointerType *DispatchFnPtrType = nullptr ;
67
74
Constant *GuardFnCFGlobal = nullptr ;
68
75
Constant *GuardFnGlobal = nullptr ;
76
+ Constant *DispatchFnGlobal = nullptr ;
69
77
Module *M = nullptr ;
70
78
71
79
Type *PtrTy;
@@ -615,6 +623,66 @@ Function *AArch64Arm64ECCallLowering::buildGuestExitThunk(Function *F) {
615
623
return GuestExit;
616
624
}
617
625
626
+ Function *
627
+ AArch64Arm64ECCallLowering::buildPatchableThunk (GlobalAlias *UnmangledAlias,
628
+ GlobalAlias *MangledAlias) {
629
+ llvm::raw_null_ostream NullThunkName;
630
+ FunctionType *Arm64Ty, *X64Ty;
631
+ Function *F = cast<Function>(MangledAlias->getAliasee ());
632
+ getThunkType (F->getFunctionType (), F->getAttributes (),
633
+ Arm64ECThunkType::GuestExit, NullThunkName, Arm64Ty, X64Ty);
634
+ std::string ThunkName (MangledAlias->getName ());
635
+ if (ThunkName[0 ] == ' ?' && ThunkName.find (" @" ) != std::string::npos) {
636
+ ThunkName.insert (ThunkName.find (" @" ), " $hybpatch_thunk" );
637
+ } else {
638
+ ThunkName.append (" $hybpatch_thunk" );
639
+ }
640
+
641
+ Function *GuestExit =
642
+ Function::Create (Arm64Ty, GlobalValue::WeakODRLinkage, 0 , ThunkName, M);
643
+ GuestExit->setComdat (M->getOrInsertComdat (ThunkName));
644
+ GuestExit->setSection (" .wowthk$aa" );
645
+ BasicBlock *BB = BasicBlock::Create (M->getContext (), " " , GuestExit);
646
+ IRBuilder<> B (BB);
647
+
648
+ // Load the global symbol as a pointer to the check function.
649
+ LoadInst *DispatchLoad = B.CreateLoad (DispatchFnPtrType, DispatchFnGlobal);
650
+
651
+ // Create new dispatch call instruction.
652
+ Function *ExitThunk =
653
+ buildExitThunk (F->getFunctionType (), F->getAttributes ());
654
+ CallInst *Dispatch = B.CreateCall (
655
+ DispatchFnType, DispatchLoad,
656
+ {B.CreateBitCast (UnmangledAlias, B.getPtrTy ()),
657
+ B.CreateBitCast (ExitThunk, B.getPtrTy ()),
658
+ B.CreateBitCast (UnmangledAlias->getAliasee (), B.getPtrTy ())});
659
+
660
+ // Ensure that the first arguments are passed in the correct registers.
661
+ Dispatch->setCallingConv (CallingConv::CFGuard_Check);
662
+
663
+ Value *DispatchRetVal = B.CreateBitCast (Dispatch, PtrTy);
664
+ SmallVector<Value *> Args;
665
+ for (Argument &Arg : GuestExit->args ())
666
+ Args.push_back (&Arg);
667
+ CallInst *Call = B.CreateCall (Arm64Ty, DispatchRetVal, Args);
668
+ Call->setTailCallKind (llvm::CallInst::TCK_MustTail);
669
+
670
+ if (Call->getType ()->isVoidTy ())
671
+ B.CreateRetVoid ();
672
+ else
673
+ B.CreateRet (Call);
674
+
675
+ auto SRetAttr = F->getAttributes ().getParamAttr (0 , Attribute::StructRet);
676
+ auto InRegAttr = F->getAttributes ().getParamAttr (0 , Attribute::InReg);
677
+ if (SRetAttr.isValid () && !InRegAttr.isValid ()) {
678
+ GuestExit->addParamAttr (0 , SRetAttr);
679
+ Call->addParamAttr (0 , SRetAttr);
680
+ }
681
+
682
+ MangledAlias->setAliasee (GuestExit);
683
+ return GuestExit;
684
+ }
685
+
618
686
// Lower an indirect call with inline code.
619
687
void AArch64Arm64ECCallLowering::lowerCall (CallBase *CB) {
620
688
assert (Triple (CB->getModule ()->getTargetTriple ()).isOSWindows () &&
@@ -670,17 +738,51 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
670
738
671
739
GuardFnType = FunctionType::get (PtrTy, {PtrTy, PtrTy}, false );
672
740
GuardFnPtrType = PointerType::get (GuardFnType, 0 );
741
+ DispatchFnType = FunctionType::get (PtrTy, {PtrTy, PtrTy, PtrTy}, false );
742
+ DispatchFnPtrType = PointerType::get (DispatchFnType, 0 );
673
743
GuardFnCFGlobal =
674
744
M->getOrInsertGlobal (" __os_arm64x_check_icall_cfg" , GuardFnPtrType);
675
745
GuardFnGlobal =
676
746
M->getOrInsertGlobal (" __os_arm64x_check_icall" , GuardFnPtrType);
747
+ DispatchFnGlobal =
748
+ M->getOrInsertGlobal (" __os_arm64x_dispatch_call" , DispatchFnPtrType);
749
+
750
+ std::map<GlobalAlias *, GlobalAlias *> FnsMap;
751
+ SetVector<GlobalAlias *> PatchableFns;
677
752
678
- SetVector<Function *> DirectCalledFns;
753
+ for (Function &F : Mod) {
754
+ if (!F.hasFnAttribute (Attribute::HybridPatchable) || F.isDeclaration () ||
755
+ F.hasLocalLinkage () || F.getName ().ends_with (" $hp_target" ))
756
+ continue ;
757
+
758
+ // Rename hybrid patchable functions and change callers to use a global
759
+ // alias instead.
760
+ if (std::optional<std::string> MangledName =
761
+ getArm64ECMangledFunctionName (F.getName ().str ())) {
762
+ std::string OrigName (F.getName ());
763
+ F.setName (MangledName.value () + " $hp_target" );
764
+
765
+ auto *A =
766
+ GlobalAlias::create (GlobalValue::LinkOnceODRLinkage, OrigName, &F);
767
+ F.replaceAllUsesWith (A);
768
+ F.setMetadata (" arm64ec_exp_name" ,
769
+ MDNode::get (M->getContext (),
770
+ MDString::get (M->getContext (),
771
+ " EXP+" + MangledName.value ())));
772
+ A->setAliasee (&F);
773
+
774
+ FnsMap[A] = GlobalAlias::create (GlobalValue::LinkOnceODRLinkage,
775
+ MangledName.value (), &F);
776
+ PatchableFns.insert (A);
777
+ }
778
+ }
779
+
780
+ SetVector<GlobalValue *> DirectCalledFns;
679
781
for (Function &F : Mod)
680
782
if (!F.isDeclaration () &&
681
783
F.getCallingConv () != CallingConv::ARM64EC_Thunk_Native &&
682
784
F.getCallingConv () != CallingConv::ARM64EC_Thunk_X64)
683
- processFunction (F, DirectCalledFns);
785
+ processFunction (F, DirectCalledFns, FnsMap );
684
786
685
787
struct ThunkInfo {
686
788
Constant *Src;
@@ -698,14 +800,20 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
698
800
{&F, buildEntryThunk (&F), Arm64ECThunkType::Entry});
699
801
}
700
802
}
701
- for (Function *F : DirectCalledFns) {
803
+ for (GlobalValue *O : DirectCalledFns) {
804
+ auto GA = dyn_cast<GlobalAlias>(O);
805
+ auto F = dyn_cast<Function>(GA ? GA->getAliasee () : O);
702
806
ThunkMapping.push_back (
703
- {F , buildExitThunk (F->getFunctionType (), F->getAttributes ()),
807
+ {O , buildExitThunk (F->getFunctionType (), F->getAttributes ()),
704
808
Arm64ECThunkType::Exit});
705
- if (!F->hasDLLImportStorageClass ())
809
+ if (!GA && ! F->hasDLLImportStorageClass ())
706
810
ThunkMapping.push_back (
707
811
{buildGuestExitThunk (F), F, Arm64ECThunkType::GuestExit});
708
812
}
813
+ for (GlobalAlias *A : PatchableFns) {
814
+ Function *Thunk = buildPatchableThunk (A, FnsMap[A]);
815
+ ThunkMapping.push_back ({Thunk, A, Arm64ECThunkType::GuestExit});
816
+ }
709
817
710
818
if (!ThunkMapping.empty ()) {
711
819
SmallVector<Constant *> ThunkMappingArrayElems;
@@ -728,7 +836,8 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
728
836
}
729
837
730
838
bool AArch64Arm64ECCallLowering::processFunction (
731
- Function &F, SetVector<Function *> &DirectCalledFns) {
839
+ Function &F, SetVector<GlobalValue *> &DirectCalledFns,
840
+ std::map<GlobalAlias *, GlobalAlias *> &FnsMap) {
732
841
SmallVector<CallBase *, 8 > IndirectCalls;
733
842
734
843
// For ARM64EC targets, a function definition's name is mangled differently
@@ -780,6 +889,17 @@ bool AArch64Arm64ECCallLowering::processFunction(
780
889
continue ;
781
890
}
782
891
892
+ // Use mangled global alias for direct calls to patchable functions.
893
+ if (GlobalAlias *A =
894
+ dyn_cast_or_null<GlobalAlias>(CB->getCalledOperand ())) {
895
+ auto I = FnsMap.find (A);
896
+ if (I != FnsMap.end ()) {
897
+ CB->setCalledOperand (I->second );
898
+ DirectCalledFns.insert (I->first );
899
+ continue ;
900
+ }
901
+ }
902
+
783
903
IndirectCalls.push_back (CB);
784
904
++Arm64ECCallsLowered;
785
905
}
0 commit comments