Skip to content

Commit 547ce3e

Browse files
ahmedbougachaojhunt
authored andcommitted
[PAC] Implement pointer authentication for C++ member function pointers.
Introduces type based signing of member function pointers. To support this discrimination schema we no longer emit member function pointer to virtual methods and indices into a vtable but migrate to using thunks. This does mean member function pointers are no longer necessarily directly comparable, however as such comparisons are UB this is acceptable. We derive the discriminator from the C++ mangling of the type of the pointer being authenticated. Co-Authored-By: Akira Hatanaka [email protected] Co-Authored-By: John McCall [email protected]
1 parent 4afdcd9 commit 547ce3e

File tree

11 files changed

+855
-117
lines changed

11 files changed

+855
-117
lines changed

clang/include/clang/AST/ASTContext.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1287,7 +1287,7 @@ class ASTContext : public RefCountedBase<ASTContext> {
12871287
getPointerAuthVTablePointerDiscriminator(const CXXRecordDecl *RD);
12881288

12891289
/// Return the "other" type-specific discriminator for the given type.
1290-
uint16_t getPointerAuthTypeDiscriminator(QualType T) const;
1290+
uint16_t getPointerAuthTypeDiscriminator(QualType T);
12911291

12921292
/// Apply Objective-C protocol qualifiers to the given type.
12931293
/// \param allowOnPointerType specifies if we can apply protocol

clang/include/clang/Basic/PointerAuthOptions.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,9 @@ struct PointerAuthOptions {
175175

176176
/// The ABI for variadic C++ virtual function pointers.
177177
PointerAuthSchema CXXVirtualVariadicFunctionPointers;
178+
179+
/// The ABI for C++ member function pointers.
180+
PointerAuthSchema CXXMemberFunctionPointers;
178181
};
179182

180183
} // end namespace clang

clang/lib/AST/ASTContext.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3404,7 +3404,7 @@ static void encodeTypeForFunctionPointerAuth(const ASTContext &Ctx,
34043404
}
34053405
}
34063406

3407-
uint16_t ASTContext::getPointerAuthTypeDiscriminator(QualType T) const {
3407+
uint16_t ASTContext::getPointerAuthTypeDiscriminator(QualType T) {
34083408
assert(!T->isDependentType() &&
34093409
"cannot compute type discriminator of a dependent type");
34103410

@@ -3414,11 +3414,13 @@ uint16_t ASTContext::getPointerAuthTypeDiscriminator(QualType T) const {
34143414
if (T->isFunctionPointerType() || T->isFunctionReferenceType())
34153415
T = T->getPointeeType();
34163416

3417-
if (T->isFunctionType())
3417+
if (T->isFunctionType()) {
34183418
encodeTypeForFunctionPointerAuth(*this, Out, T);
3419-
else
3420-
llvm_unreachable(
3421-
"type discrimination of non-function type not implemented yet");
3419+
} else {
3420+
T = T.getUnqualifiedType();
3421+
std::unique_ptr<MangleContext> MC(createMangleContext());
3422+
MC->mangleCanonicalTypeName(T, Out);
3423+
}
34223424

34233425
return llvm::getPointerAuthStableSipHash(Str);
34243426
}

clang/lib/CodeGen/CGCall.cpp

Lines changed: 114 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -5034,7 +5034,8 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
50345034
ReturnValueSlot ReturnValue,
50355035
const CallArgList &CallArgs,
50365036
llvm::CallBase **callOrInvoke, bool IsMustTail,
5037-
SourceLocation Loc) {
5037+
SourceLocation Loc,
5038+
bool IsVirtualFunctionPointerThunk) {
50385039
// FIXME: We no longer need the types from CallArgs; lift up and simplify.
50395040

50405041
assert(Callee.isOrdinary() || Callee.isVirtual());
@@ -5098,7 +5099,11 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
50985099
RawAddress SRetAlloca = RawAddress::invalid();
50995100
llvm::Value *UnusedReturnSizePtr = nullptr;
51005101
if (RetAI.isIndirect() || RetAI.isInAlloca() || RetAI.isCoerceAndExpand()) {
5101-
if (!ReturnValue.isNull()) {
5102+
if (IsVirtualFunctionPointerThunk && RetAI.isIndirect()) {
5103+
SRetPtr = makeNaturalAddressForPointer(CurFn->arg_begin() +
5104+
IRFunctionArgs.getSRetArgNo(),
5105+
RetTy, CharUnits::fromQuantity(1));
5106+
} else if (!ReturnValue.isNull()) {
51025107
SRetPtr = ReturnValue.getAddress();
51035108
} else {
51045109
SRetPtr = CreateMemTemp(RetTy, "tmp", &SRetAlloca);
@@ -5877,119 +5882,130 @@ RValue CodeGenFunction::EmitCall(const CGFunctionInfo &CallInfo,
58775882
CallArgs.freeArgumentMemory(*this);
58785883

58795884
// Extract the return value.
5880-
RValue Ret = [&] {
5881-
switch (RetAI.getKind()) {
5882-
case ABIArgInfo::CoerceAndExpand: {
5883-
auto coercionType = RetAI.getCoerceAndExpandType();
5884-
5885-
Address addr = SRetPtr.withElementType(coercionType);
5886-
5887-
assert(CI->getType() == RetAI.getUnpaddedCoerceAndExpandType());
5888-
bool requiresExtract = isa<llvm::StructType>(CI->getType());
5885+
RValue Ret;
58895886

5890-
unsigned unpaddedIndex = 0;
5891-
for (unsigned i = 0, e = coercionType->getNumElements(); i != e; ++i) {
5892-
llvm::Type *eltType = coercionType->getElementType(i);
5893-
if (ABIArgInfo::isPaddingForCoerceAndExpand(eltType)) continue;
5894-
Address eltAddr = Builder.CreateStructGEP(addr, i);
5895-
llvm::Value *elt = CI;
5896-
if (requiresExtract)
5897-
elt = Builder.CreateExtractValue(elt, unpaddedIndex++);
5898-
else
5899-
assert(unpaddedIndex == 0);
5900-
Builder.CreateStore(elt, eltAddr);
5887+
// If the current function is a virtual function pointer thunk, avoid copying
5888+
// the return value of the musttail call to a temporary.
5889+
if (IsVirtualFunctionPointerThunk)
5890+
Ret = RValue::get(CI);
5891+
else
5892+
Ret = [&] {
5893+
switch (RetAI.getKind()) {
5894+
case ABIArgInfo::CoerceAndExpand: {
5895+
auto coercionType = RetAI.getCoerceAndExpandType();
5896+
5897+
Address addr = SRetPtr.withElementType(coercionType);
5898+
5899+
assert(CI->getType() == RetAI.getUnpaddedCoerceAndExpandType());
5900+
bool requiresExtract = isa<llvm::StructType>(CI->getType());
5901+
5902+
unsigned unpaddedIndex = 0;
5903+
for (unsigned i = 0, e = coercionType->getNumElements(); i != e; ++i) {
5904+
llvm::Type *eltType = coercionType->getElementType(i);
5905+
if (ABIArgInfo::isPaddingForCoerceAndExpand(eltType))
5906+
continue;
5907+
Address eltAddr = Builder.CreateStructGEP(addr, i);
5908+
llvm::Value *elt = CI;
5909+
if (requiresExtract)
5910+
elt = Builder.CreateExtractValue(elt, unpaddedIndex++);
5911+
else
5912+
assert(unpaddedIndex == 0);
5913+
Builder.CreateStore(elt, eltAddr);
5914+
}
5915+
[[fallthrough]];
59015916
}
5902-
[[fallthrough]];
5903-
}
5904-
5905-
case ABIArgInfo::InAlloca:
5906-
case ABIArgInfo::Indirect: {
5907-
RValue ret = convertTempToRValue(SRetPtr, RetTy, SourceLocation());
5908-
if (UnusedReturnSizePtr)
5909-
PopCleanupBlock();
5910-
return ret;
5911-
}
59125917

5913-
case ABIArgInfo::Ignore:
5914-
// If we are ignoring an argument that had a result, make sure to
5915-
// construct the appropriate return value for our caller.
5916-
return GetUndefRValue(RetTy);
5918+
case ABIArgInfo::InAlloca:
5919+
case ABIArgInfo::Indirect: {
5920+
RValue ret = convertTempToRValue(SRetPtr, RetTy, SourceLocation());
5921+
if (UnusedReturnSizePtr)
5922+
PopCleanupBlock();
5923+
return ret;
5924+
}
59175925

5918-
case ABIArgInfo::Extend:
5919-
case ABIArgInfo::Direct: {
5920-
llvm::Type *RetIRTy = ConvertType(RetTy);
5921-
if (RetAI.getCoerceToType() == RetIRTy && RetAI.getDirectOffset() == 0) {
5922-
switch (getEvaluationKind(RetTy)) {
5923-
case TEK_Complex: {
5924-
llvm::Value *Real = Builder.CreateExtractValue(CI, 0);
5925-
llvm::Value *Imag = Builder.CreateExtractValue(CI, 1);
5926-
return RValue::getComplex(std::make_pair(Real, Imag));
5927-
}
5928-
case TEK_Aggregate: {
5929-
Address DestPtr = ReturnValue.getAddress();
5930-
bool DestIsVolatile = ReturnValue.isVolatile();
5926+
case ABIArgInfo::Ignore:
5927+
// If we are ignoring an argument that had a result, make sure to
5928+
// construct the appropriate return value for our caller.
5929+
return GetUndefRValue(RetTy);
5930+
5931+
case ABIArgInfo::Extend:
5932+
case ABIArgInfo::Direct: {
5933+
llvm::Type *RetIRTy = ConvertType(RetTy);
5934+
if (RetAI.getCoerceToType() == RetIRTy &&
5935+
RetAI.getDirectOffset() == 0) {
5936+
switch (getEvaluationKind(RetTy)) {
5937+
case TEK_Complex: {
5938+
llvm::Value *Real = Builder.CreateExtractValue(CI, 0);
5939+
llvm::Value *Imag = Builder.CreateExtractValue(CI, 1);
5940+
return RValue::getComplex(std::make_pair(Real, Imag));
5941+
}
5942+
case TEK_Aggregate: {
5943+
Address DestPtr = ReturnValue.getAddress();
5944+
bool DestIsVolatile = ReturnValue.isVolatile();
59315945

5932-
if (!DestPtr.isValid()) {
5933-
DestPtr = CreateMemTemp(RetTy, "agg.tmp");
5934-
DestIsVolatile = false;
5946+
if (!DestPtr.isValid()) {
5947+
DestPtr = CreateMemTemp(RetTy, "agg.tmp");
5948+
DestIsVolatile = false;
5949+
}
5950+
EmitAggregateStore(CI, DestPtr, DestIsVolatile);
5951+
return RValue::getAggregate(DestPtr);
5952+
}
5953+
case TEK_Scalar: {
5954+
// If the argument doesn't match, perform a bitcast to coerce it.
5955+
// This can happen due to trivial type mismatches.
5956+
llvm::Value *V = CI;
5957+
if (V->getType() != RetIRTy)
5958+
V = Builder.CreateBitCast(V, RetIRTy);
5959+
return RValue::get(V);
5960+
}
59355961
}
5936-
EmitAggregateStore(CI, DestPtr, DestIsVolatile);
5937-
return RValue::getAggregate(DestPtr);
5962+
llvm_unreachable("bad evaluation kind");
59385963
}
5939-
case TEK_Scalar: {
5940-
// If the argument doesn't match, perform a bitcast to coerce it. This
5941-
// can happen due to trivial type mismatches.
5964+
5965+
// If coercing a fixed vector from a scalable vector for ABI
5966+
// compatibility, and the types match, use the llvm.vector.extract
5967+
// intrinsic to perform the conversion.
5968+
if (auto *FixedDstTy = dyn_cast<llvm::FixedVectorType>(RetIRTy)) {
59425969
llvm::Value *V = CI;
5943-
if (V->getType() != RetIRTy)
5944-
V = Builder.CreateBitCast(V, RetIRTy);
5945-
return RValue::get(V);
5946-
}
5970+
if (auto *ScalableSrcTy =
5971+
dyn_cast<llvm::ScalableVectorType>(V->getType())) {
5972+
if (FixedDstTy->getElementType() ==
5973+
ScalableSrcTy->getElementType()) {
5974+
llvm::Value *Zero = llvm::Constant::getNullValue(CGM.Int64Ty);
5975+
V = Builder.CreateExtractVector(FixedDstTy, V, Zero,
5976+
"cast.fixed");
5977+
return RValue::get(V);
5978+
}
5979+
}
59475980
}
5948-
llvm_unreachable("bad evaluation kind");
5949-
}
59505981

5951-
// If coercing a fixed vector from a scalable vector for ABI
5952-
// compatibility, and the types match, use the llvm.vector.extract
5953-
// intrinsic to perform the conversion.
5954-
if (auto *FixedDstTy = dyn_cast<llvm::FixedVectorType>(RetIRTy)) {
5955-
llvm::Value *V = CI;
5956-
if (auto *ScalableSrcTy =
5957-
dyn_cast<llvm::ScalableVectorType>(V->getType())) {
5958-
if (FixedDstTy->getElementType() == ScalableSrcTy->getElementType()) {
5959-
llvm::Value *Zero = llvm::Constant::getNullValue(CGM.Int64Ty);
5960-
V = Builder.CreateExtractVector(FixedDstTy, V, Zero, "cast.fixed");
5961-
return RValue::get(V);
5962-
}
5982+
Address DestPtr = ReturnValue.getValue();
5983+
bool DestIsVolatile = ReturnValue.isVolatile();
5984+
5985+
if (!DestPtr.isValid()) {
5986+
DestPtr = CreateMemTemp(RetTy, "coerce");
5987+
DestIsVolatile = false;
59635988
}
5964-
}
59655989

5966-
Address DestPtr = ReturnValue.getValue();
5967-
bool DestIsVolatile = ReturnValue.isVolatile();
5990+
// An empty record can overlap other data (if declared with
5991+
// no_unique_address); omit the store for such types - as there is no
5992+
// actual data to store.
5993+
if (!isEmptyRecord(getContext(), RetTy, true)) {
5994+
// If the value is offset in memory, apply the offset now.
5995+
Address StorePtr = emitAddressAtOffset(*this, DestPtr, RetAI);
5996+
CreateCoercedStore(CI, StorePtr, DestIsVolatile, *this);
5997+
}
59685998

5969-
if (!DestPtr.isValid()) {
5970-
DestPtr = CreateMemTemp(RetTy, "coerce");
5971-
DestIsVolatile = false;
5999+
return convertTempToRValue(DestPtr, RetTy, SourceLocation());
59726000
}
59736001

5974-
// An empty record can overlap other data (if declared with
5975-
// no_unique_address); omit the store for such types - as there is no
5976-
// actual data to store.
5977-
if (!isEmptyRecord(getContext(), RetTy, true)) {
5978-
// If the value is offset in memory, apply the offset now.
5979-
Address StorePtr = emitAddressAtOffset(*this, DestPtr, RetAI);
5980-
CreateCoercedStore(CI, StorePtr, DestIsVolatile, *this);
6002+
case ABIArgInfo::Expand:
6003+
case ABIArgInfo::IndirectAliased:
6004+
llvm_unreachable("Invalid ABI kind for return argument");
59816005
}
59826006

5983-
return convertTempToRValue(DestPtr, RetTy, SourceLocation());
5984-
}
5985-
5986-
case ABIArgInfo::Expand:
5987-
case ABIArgInfo::IndirectAliased:
5988-
llvm_unreachable("Invalid ABI kind for return argument");
5989-
}
5990-
5991-
llvm_unreachable("Unhandled ABIArgInfo::Kind");
5992-
} ();
6007+
llvm_unreachable("Unhandled ABIArgInfo::Kind");
6008+
}();
59936009

59946010
// Emit the assume_aligned check on the return value.
59956011
if (Ret.isScalar() && TargetDecl) {

clang/lib/CodeGen/CGPointerAuth.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,40 @@ llvm::Constant *CodeGenModule::getFunctionPointer(GlobalDecl GD,
365365
return getFunctionPointer(getRawFunctionPointer(GD, Ty), FuncType);
366366
}
367367

368+
CGPointerAuthInfo CodeGenModule::getMemberFunctionPointerAuthInfo(QualType FT) {
369+
assert(FT->getAs<MemberPointerType>() && "MemberPointerType expected");
370+
auto &Schema = getCodeGenOpts().PointerAuth.CXXMemberFunctionPointers;
371+
if (!Schema)
372+
return CGPointerAuthInfo();
373+
374+
assert(!Schema.isAddressDiscriminated() &&
375+
"function pointers cannot use address-specific discrimination");
376+
377+
llvm::ConstantInt *Discriminator =
378+
getPointerAuthOtherDiscriminator(Schema, GlobalDecl(), FT);
379+
return CGPointerAuthInfo(Schema.getKey(), Schema.getAuthenticationMode(),
380+
/* IsIsaPointer */ false,
381+
/* AuthenticatesNullValues */ false, Discriminator);
382+
}
383+
384+
llvm::Constant *CodeGenModule::getMemberFunctionPointer(llvm::Constant *Pointer,
385+
QualType FT) {
386+
if (CGPointerAuthInfo PointerAuth = getMemberFunctionPointerAuthInfo(FT))
387+
return getConstantSignedPointer(
388+
Pointer, PointerAuth.getKey(), nullptr,
389+
cast_or_null<llvm::ConstantInt>(PointerAuth.getDiscriminator()));
390+
391+
return Pointer;
392+
}
393+
394+
llvm::Constant *CodeGenModule::getMemberFunctionPointer(const FunctionDecl *FD,
395+
llvm::Type *Ty) {
396+
QualType FT = FD->getType();
397+
FT = getContext().getMemberPointerType(
398+
FT, cast<CXXMethodDecl>(FD)->getParent()->getTypeForDecl());
399+
return getMemberFunctionPointer(getRawFunctionPointer(FD, Ty), FT);
400+
}
401+
368402
std::optional<PointerAuthQualifier>
369403
CodeGenModule::computeVTPointerAuthentication(const CXXRecordDecl *ThisClass) {
370404
auto DefaultAuthentication = getCodeGenOpts().PointerAuth.CXXVTablePointers;

clang/lib/CodeGen/CodeGenFunction.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4373,7 +4373,8 @@ class CodeGenFunction : public CodeGenTypeCache {
43734373
RValue EmitCall(const CGFunctionInfo &CallInfo, const CGCallee &Callee,
43744374
ReturnValueSlot ReturnValue, const CallArgList &Args,
43754375
llvm::CallBase **callOrInvoke, bool IsMustTail,
4376-
SourceLocation Loc);
4376+
SourceLocation Loc,
4377+
bool IsVirtualFunctionPointerThunk = false);
43774378
RValue EmitCall(const CGFunctionInfo &CallInfo, const CGCallee &Callee,
43784379
ReturnValueSlot ReturnValue, const CallArgList &Args,
43794380
llvm::CallBase **callOrInvoke = nullptr,

clang/lib/CodeGen/CodeGenModule.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -973,8 +973,16 @@ class CodeGenModule : public CodeGenTypeCache {
973973
llvm::Constant *getFunctionPointer(llvm::Constant *Pointer,
974974
QualType FunctionType);
975975

976+
llvm::Constant *getMemberFunctionPointer(const FunctionDecl *FD,
977+
llvm::Type *Ty = nullptr);
978+
979+
llvm::Constant *getMemberFunctionPointer(llvm::Constant *Pointer,
980+
QualType FT);
981+
976982
CGPointerAuthInfo getFunctionPointerAuthInfo(QualType T);
977983

984+
CGPointerAuthInfo getMemberFunctionPointerAuthInfo(QualType FT);
985+
978986
CGPointerAuthInfo getPointerAuthInfoForPointeeType(QualType type);
979987

980988
CGPointerAuthInfo getPointerAuthInfoForType(QualType type);

0 commit comments

Comments
 (0)