Skip to content

Sign relative protocol witness tables #63788

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/swift/ABI/MetadataValues.h
Original file line number Diff line number Diff line change
Expand Up @@ -1442,6 +1442,9 @@ namespace SpecialPointerAuthDiscriminators {

/// C type StoreExtraInhabitantTag function descriminator
const uint16_t StoreExtraInhabitantTagFunction = 0x9bf6; // = 39926

// Relative protocol witness table descriminator
const uint16_t RelativeProtocolWitnessTable = 0xb830; // = 47152
}

/// The number of arguments that will be passed directly to a generic
Expand Down
3 changes: 3 additions & 0 deletions include/swift/AST/IRGenOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,9 @@ struct PointerAuthOptions : clang::PointerAuthOptions {

/// C type StoreExtraInhabitantTag function descriminator.
PointerAuthSchema StoreExtraInhabitantTagFunction;

/// Relative protocol witness table descriminator.
PointerAuthSchema RelativeProtocolWitnessTable;
};

enum class JITDebugArtifact : unsigned {
Expand Down
35 changes: 35 additions & 0 deletions lib/IRGen/GenProto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1192,6 +1192,17 @@ WitnessIndex ProtocolInfo::getAssociatedTypeIndex(
llvm_unreachable("didn't find entry for associated type");
}

static llvm::Constant *
getConstantSignedRelativeProtocolWitnessTable(IRGenModule &IGM,
llvm::Value *table) {
auto constantTable = cast<llvm::Constant>(table);
auto &schema = IGM.getOptions().PointerAuth.RelativeProtocolWitnessTable;
constantTable =
IGM.getConstantSignedPointer(constantTable, schema, PointerAuthEntity(),
/*storageAddress*/ nullptr);
return constantTable;
}

namespace {

/// Conformance info for a witness table that can be directly generated.
Expand Down Expand Up @@ -1694,6 +1705,8 @@ void WitnessTableBuilderBase::defineAssociatedTypeWitnessTableAccessFunction(
// If we can emit a constant table, do so.
if (auto constantTable =
conformanceI->tryGetConstantTable(IGM, associatedType)) {
constantTable =
getConstantSignedRelativeProtocolWitnessTable(IGM, constantTable);
IGF.Builder.CreateRet(constantTable);
return;
}
Expand Down Expand Up @@ -1890,6 +1903,7 @@ llvm::Function *FragileWitnessTableBuilder::buildInstantiationFunction() {
// Ask the ConformanceInfo to emit the wtable.
llvm::Value *baseWTable =
base.second->getTable(IGF, &metadata);

baseWTable = IGF.Builder.CreateBitCast(baseWTable, IGM.Int8PtrTy);

// Store that to the appropriate slot in the new witness table.
Expand Down Expand Up @@ -2626,6 +2640,11 @@ llvm::Value *IRGenFunction::optionallyLoadFromConditionalProtocolWitnessTable(
auto *phi = Builder.CreatePHI(wtable->getType(), 2);
phi->addIncoming(wtable, origBB);
phi->addIncoming(wtableDeref, isCondBB);
if (auto &schema = getOptions().PointerAuth.RelativeProtocolWitnessTable) {
auto info = PointerAuthInfo::emit(*this, schema, nullptr,
PointerAuthEntity());
return emitPointerAuthAuth(*this, phi, info);
}
return phi;
}

Expand Down Expand Up @@ -2660,11 +2679,25 @@ llvm::Value *irgen::loadParentProtocolWitnessTable(IRGenFunction &IGF,
Builder.CreateBr(endBB);

Builder.emitBlock(isNotCondBB);
if (auto &schema = IGF.getOptions().PointerAuth.RelativeProtocolWitnessTable) {
auto info = PointerAuthInfo::emit(IGF, schema, nullptr,
PointerAuthEntity());
wtable = emitPointerAuthAuth(IGF, wtable, info);
}
auto baseWTable2 =
emitInvariantLoadOfOpaqueWitness(IGF,/*isProtocolWitness*/true, wtable,
index);
baseWTable2 = IGF.Builder.CreateBitCast(baseWTable2,
IGF.IGM.WitnessTablePtrTy);
if (auto &schema = IGF.getOptions().PointerAuth.RelativeProtocolWitnessTable) {
auto info = PointerAuthInfo::emit(IGF, schema, nullptr,
PointerAuthEntity());
baseWTable2 = emitPointerAuthSign(IGF, baseWTable2, info);

baseWTable2 = IGF.Builder.CreateBitCast(baseWTable2,
IGF.IGM.WitnessTablePtrTy);
}

Builder.CreateBr(endBB);

Builder.emitBlock(endBB);
Expand Down Expand Up @@ -3268,6 +3301,8 @@ llvm::Value *irgen::emitWitnessTableRef(IRGenFunction &IGF,

auto &conformanceI = IGF.IGM.getConformanceInfo(proto, concreteConformance);
wtable = conformanceI.getTable(IGF, srcMetadataCache);
if (isa<llvm::Constant>(wtable))
wtable = getConstantSignedRelativeProtocolWitnessTable(IGF.IGM, wtable);

IGF.setScopedLocalTypeData(srcType, cacheKind, wtable);
return wtable;
Expand Down
10 changes: 8 additions & 2 deletions lib/IRGen/IRGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -869,7 +869,8 @@ bool swift::compileAndWriteLLVM(llvm::Module *module,
}

static void setPointerAuthOptions(PointerAuthOptions &opts,
const clang::PointerAuthOptions &clangOpts){
const clang::PointerAuthOptions &clangOpts,
const IRGenOptions &irgenOpts) {
// Intentionally do a slice-assignment to copy over the clang options.
static_cast<clang::PointerAuthOptions&>(opts) = clangOpts;

Expand Down Expand Up @@ -1011,6 +1012,11 @@ static void setPointerAuthOptions(PointerAuthOptions &opts,
opts.StoreExtraInhabitantTagFunction = PointerAuthSchema(
codeKey, /*address*/ false, Discrimination::Constant,
SpecialPointerAuthDiscriminators::StoreExtraInhabitantTagFunction);

if (irgenOpts.UseRelativeProtocolWitnessTables)
opts.RelativeProtocolWitnessTable = PointerAuthSchema(
dataKey, /*address*/ false, Discrimination::Constant,
SpecialPointerAuthDiscriminators::RelativeProtocolWitnessTable);
}

std::unique_ptr<llvm::TargetMachine>
Expand Down Expand Up @@ -1045,7 +1051,7 @@ swift::createTargetMachine(const IRGenOptions &Opts, ASTContext &Ctx) {
// after the module loaders are set up, and where these options are
// formally not const.
setPointerAuthOptions(const_cast<IRGenOptions &>(Opts).PointerAuth,
clangInstance.getCodeGenOpts().PointerAuth);
clangInstance.getCodeGenOpts().PointerAuth, Opts);
}
}

Expand Down
33 changes: 31 additions & 2 deletions stdlib/public/runtime/Metadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5527,14 +5527,30 @@ instantiateRelativeWitnessTable(const Metadata *Type,
// Advance the address point; the private storage area is accessed via
// negative offsets.
auto table = fullTable + privateSizeInWords;

#if SWIFT_PTRAUTH
table[0] = ptrauth_sign_unauthenticated(
(void*)pattern,
ptrauth_key_process_independent_data,
SpecialPointerAuthDiscriminators::RelativeProtocolWitnessTable);
#else
table[0] = (void*)pattern;
#endif

assert(1 == WitnessTableFirstRequirementOffset);

// Fill in the base protocols of the requirements from the pattern.
for (size_t i = 0, e = numBaseProtocols; i < e; ++i) {
size_t index = i + WitnessTableFirstRequirementOffset;
#if SWIFT_PTRAUTH
auto rawValue = ((RelativeBaseWitness const *)pattern)[index].get();
table[index] = (rawValue == nullptr) ? rawValue :
ptrauth_sign_unauthenticated(
rawValue,
ptrauth_key_process_independent_data,
SpecialPointerAuthDiscriminators::RelativeProtocolWitnessTable);
#else
table[index] = ((RelativeBaseWitness const *)pattern)[index].get();
#endif
}

// Copy any instantiation arguments that correspond to conditional
Expand Down Expand Up @@ -5627,6 +5643,13 @@ swift::swift_getWitnessTableRelative(const ProtocolConformanceDescriptor *confor
assert(!conformance->isSynthesizedNonUnique());
auto pattern = conformance->getWitnessTablePattern();
auto table = uniqueForeignWitnessTableRef(pattern);

#if SWIFT_STDLIB_USE_RELATIVE_PROTOCOL_WITNESS_TABLES && SWIFT_PTRAUTH
table = ptrauth_sign_unauthenticated(table,
ptrauth_key_process_independent_data,
SpecialPointerAuthDiscriminators::RelativeProtocolWitnessTable);
#endif

return reinterpret_cast<const RelativeWitnessTable*>(table);
}

Expand Down Expand Up @@ -5855,8 +5878,14 @@ RelativeWitnessTable *swift::lookThroughOptionalConditionalWitnessTable(
if (conditional_wtable & 0x1) {
conditional_wtable = conditional_wtable & ~(uintptr_t)(0x1);
conditional_wtable = (uintptr_t)*(void**)conditional_wtable;

}
auto table = (RelativeWitnessTable*)conditional_wtable;
#if SWIFT_PTRAUTH
table = swift_auth_data_non_address(
table,
SpecialPointerAuthDiscriminators::RelativeProtocolWitnessTable);
#endif
return table;
}

Expand Down Expand Up @@ -6169,7 +6198,7 @@ static const RelativeWitnessTable *swift_getAssociatedConformanceWitnessRelative

auto assocWitnessTable = witnessFn(assocType, conformingType, origWTable);

// The access function returns an unsigned pointer for now.
// The access function returns an signed pointer.
return assocWitnessTable;
}

Expand Down
1 change: 1 addition & 0 deletions stdlib/public/runtime/ProtocolConformance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1163,6 +1163,7 @@ swift_conformsToProtocolImpl(const Metadata *const type,
std::tie(table, hasUninstantiatedSuperclass) =
swift_conformsToProtocolMaybeInstantiateSuperclasses(
type, protocol, true /*instantiateSuperclassMetadata*/);

return table;
}

Expand Down
38 changes: 30 additions & 8 deletions test/IRGen/relative_protocol_witness_table.swift
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// RUN: %target-swift-frontend -enable-relative-protocol-witness-tables -module-name A -primary-file %s -emit-ir | %FileCheck %s
// RUN: %target-swift-frontend -enable-relative-protocol-witness-tables -module-name A -primary-file %s -emit-ir | %FileCheck %s --check-prefix=CHECK-%target-cpu --check-prefix=CHECK

// REQUIRES: CPU=x86_64 || CPU=arm64
// UNSUPPORTED: CPU=arm64e
// REQUIRES: CPU=x86_64 || CPU=arm64 || CPU=arm64e

func testVWT<T>(_ t: T) {
var local = t
Expand Down Expand Up @@ -216,8 +215,15 @@ func instantiate_conditional_conformance_2nd<T>(_ t : T) where T: Sub, T.S == T
// CHECK: br label %[[LBL2]]
// CHECK:[[LBL2]]:
// CHECK: [[T5:%.*]] = phi i8** [ [[PWT]], %[[ENTRY]] ], [ [[T4]], %[[LBL1]] ]
// CHECK: [[CAST:%.*]] = bitcast i8** [[T5]] to i32*
// CHECK: [[SLOT:%.*]] = getelementptr inbounds i32, i32* [[CAST]], i32 1
// CHECK-arm64e: [[T6:%.*]] = ptrtoint i8** [[T5]] to i64
// CHECK-arm64e: [[T7:%.*]] = call i64 @llvm.ptrauth.auth(i64 [[T6]], i32 2, i64 47152)
// CHECK-arm64e: [[T8:%.*]] = inttoptr i64 [[T7]] to i8**
// CHECK-arm64e: [[CAST:%.*]] = bitcast i8** [[T8]] to i32*
// CHECK-arm64e: [[SLOT:%.*]] = getelementptr inbounds i32, i32* [[CAST]], i32 1
// CHECK-arm64: [[CAST:%.*]] = bitcast i8** [[T5]] to i32*
// CHECK-arm64: [[SLOT:%.*]] = getelementptr inbounds i32, i32* [[CAST]], i32 1
// CHECK-x86_64: [[CAST:%.*]] = bitcast i8** [[T5]] to i32*
// CHECK-x86_64: [[SLOT:%.*]] = getelementptr inbounds i32, i32* [[CAST]], i32 1
// CHECK: [[T0:%.*]] = load i32, i32* [[SLOT]]
// CHECK: [[T1:%.*]] = sext i32 [[T0]] to i64
// CHECK: [[T2:%.*]] = ptrtoint i32* [[SLOT]] to i64
Expand All @@ -243,14 +249,22 @@ func instantiate_conditional_conformance_2nd<T>(_ t : T) where T: Sub, T.S == T
// CHECK: br label %[[L3:.*]]

// CHECK:[[L2]]:
// CHECK: [[T8:%.*]] = bitcast i8** [[T_INHERITED]] to i32*
// CHECK-arm64e: [[P0:%.*]] = ptrtoint i8** [[T_INHERITED]] to i64
// CHECK-arm64e: [[P1:%.*]] = call i64 @llvm.ptrauth.auth(i64 [[P0]], i32 2, i64 47152)
// CHECK-arm64e: [[P2:%.*]] = inttoptr i64 [[P1]] to i8**
// CHECK-arm64e: [[T8:%.*]] = bitcast i8** [[P2]] to i32*
// CHECK-arm64: [[T8:%.*]] = bitcast i8** [[T_INHERITED]] to i32*
// CHECK-x86_64: [[T8:%.*]] = bitcast i8** [[T_INHERITED]] to i32*
// CHECK: [[T9:%.*]] = getelementptr inbounds i32, i32* [[T8]], i32 1
// CHECK: [[T10:%.*]] = load i32, i32* [[T9]]
// CHECK: [[T11:%.*]] = sext i32 [[T10]] to i64
// CHECK: [[T12:%.*]] = ptrtoint i32* [[T9]] to i64
// CHECK: [[T13:%.*]] = add i64 [[T12]], [[T11]]
// CHECK: [[T14:%.*]] = inttoptr i64 [[T13]] to i8*
// CHECK: [[T15:%.*]] = bitcast i8* [[T14]] to i8**
// CHECK-arm64e: [[T16:%.*]] = ptrtoint i8** [[T15]] to i64
// CHECK-arm64e: [[T17:%.*]] = call i64 @llvm.ptrauth.sign(i64 [[T16]], i32 2, i64 47152)
// CHECK-arm64e: [[T15:%.*]] = inttoptr i64 [[T17]] to i8**
// CHECK: br label %[[L3:.*]]

// CHECK:[[L3]]:
Expand All @@ -259,7 +273,7 @@ func instantiate_conditional_conformance_2nd<T>(_ t : T) where T: Sub, T.S == T
// Passing the witness table.

// CHECK: define{{.*}} swiftcc void @"$s1A6useIt2yyF"()
// CHECK: call swiftcc void @"$s1A15requireWitness2yyxAA9InheritedRzlF"(%swift.opaque* {{.*}}, %swift.type* {{.*}} @"$s1A7BStructVMf"{{.*}}, i8** {{.*}} @"$s1A7BStructVAA9InheritedAAWP"
// CHECK: call swiftcc void @"$s1A15requireWitness2yyxAA9InheritedRzlF"(%swift.opaque* {{.*}}, %swift.type* {{.*}} @"$s1A7BStructVMf"{{.*}}, i8** {{.*}} @"$s1A7BStructVAA9InheritedAAWP{{(.ptrauth)?}}"
// CHECK: ret void

// Accessing an associated witness
Expand Down Expand Up @@ -323,14 +337,22 @@ func instantiate_conditional_conformance_2nd<T>(_ t : T) where T: Sub, T.S == T
// CHECK: br label %[[T23:.*]]

// CHECK: [[T14]]:
// CHECK: [[T15:%.*]] = bitcast i8** [[T4]] to i32*
// CHECK-arm64e: [[P0:%.*]] = ptrtoint i8** [[T4]] to i64
// CHECK-arm64e: [[P1:%.*]] = call i64 @llvm.ptrauth.auth(i64 [[P0]], i32 2, i64 47152)
// CHECK-arm64e: [[P2:%.*]] = inttoptr i64 [[P1]] to i8**
// CHECK-arm64e: [[T15:%.*]] = bitcast i8** [[P2]] to i32*
// CHECK-x86_64: [[T15:%.*]] = bitcast i8** [[T4]] to i32*
// CHECK-arm64: [[T15:%.*]] = bitcast i8** [[T4]] to i32*
// CHECK: [[T16:%.*]] = getelementptr inbounds i32, i32* [[T15]], i32 1
// CHECK: [[T17:%.*]] = load i32, i32* [[T16]]
// CHECK: [[T18:%.*]] = sext i32 [[T17]] to i64
// CHECK: [[T19:%.*]] = ptrtoint i32* [[T16]] to i64
// CHECK: [[T20:%.*]] = add i64 [[T19]], [[T18]]
// CHECK: [[T21:%.*]] = inttoptr i64 [[T20]] to i8*
// CHECK: [[T22:%.*]] = bitcast i8* [[T21]] to i8**
// CHECK-arm64e: [[P0:%.*]] = ptrtoint i8** [[T22]] to i64
// CHECK-arm64e: [[P1:%.*]] = call i64 @llvm.ptrauth.sign(i64 [[P0]], i32 2, i64 47152)
// CHECK-arm64e: [[T22:%.*]] = inttoptr i64 [[P1]] to i8**
// CHECK: br label %[[T23]]

// CHECK: [[T23]]:
Expand Down