Skip to content

Implement value witness table for @differentiable functions #60875

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/swift/Demangling/ManglingMacros.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,16 @@ _Pragma("clang diagnostic pop")
#define NO_ARGS_MANGLING yy
#define FUNC_TYPE_MANGLING c
#define NOESCAPE_FUNC_TYPE_MANGLING XE
#define DIFF_FUNC_TYPE_MANGLING Yjrc
#define OBJC_PARTIAL_APPLY_THUNK_MANGLING Ta
#define OPTIONAL_MANGLING(Ty) MANGLING_CONCAT2_IMPL(Ty, Sg)

#define FUNCTION_MANGLING \
MANGLING_CONCAT2(NO_ARGS_MANGLING, FUNC_TYPE_MANGLING)

#define DIFF_FUNCTION_MANGLING \
MANGLING_CONCAT2(NO_ARGS_MANGLING, DIFF_FUNC_TYPE_MANGLING)

#define NOESCAPE_FUNCTION_MANGLING \
MANGLING_CONCAT2(NO_ARGS_MANGLING, NOESCAPE_FUNC_TYPE_MANGLING)

Expand Down
8 changes: 7 additions & 1 deletion include/swift/Runtime/Metadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,13 @@ SWIFT_RUNTIME_EXPORT
const ValueWitnessTable
VALUE_WITNESS_SYM(FUNCTION_MANGLING); // () -> ()

// The @escaping () -> () table can be used for arbitrary escaping function types.
// The @differentiable(reverse) () -> () table can be used for differentiable
// function types.
SWIFT_RUNTIME_EXPORT
const ValueWitnessTable
VALUE_WITNESS_SYM(DIFF_FUNCTION_MANGLING); // @differentiable(reverse) () -> ()

// The @noescape () -> () table can be used for arbitrary noescaping function types.
SWIFT_RUNTIME_EXPORT
const ValueWitnessTable
VALUE_WITNESS_SYM(NOESCAPE_FUNCTION_MANGLING); // @noescape () -> ()
Expand Down
18 changes: 18 additions & 0 deletions stdlib/public/runtime/KnownMetadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,13 +233,31 @@ namespace {
return FunctionPointerBox::getExtraInhabitantTag((void *const *)src);
}
};
struct DiffFunctionBox
: AggregateBox<ThickFunctionBox, ThickFunctionBox, ThickFunctionBox> {

static constexpr unsigned numExtraInhabitants =
ThickFunctionBox::numExtraInhabitants;

static void storeExtraInhabitantTag(char *dest, unsigned tag) {
ThickFunctionBox::storeExtraInhabitantTag(dest, tag);
}

static unsigned getExtraInhabitantTag(const char *src) {
return ThickFunctionBox::getExtraInhabitantTag(src);
}
};
} // end anonymous namespace

/// The basic value-witness table for escaping function types.
const ValueWitnessTable
swift::VALUE_WITNESS_SYM(FUNCTION_MANGLING) =
ValueWitnessTableForBox<ThickFunctionBox>::table;

const ValueWitnessTable
swift::VALUE_WITNESS_SYM(DIFF_FUNCTION_MANGLING) =
ValueWitnessTableForBox<DiffFunctionBox>::table;

/// The basic value-witness table for @noescape function types.
const ValueWitnessTable
swift::VALUE_WITNESS_SYM(NOESCAPE_FUNCTION_MANGLING) =
Expand Down
11 changes: 10 additions & 1 deletion stdlib/public/runtime/Metadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1337,7 +1337,16 @@ FunctionCacheEntry::FunctionCacheEntry(const Key &key) {
if (!flags.isEscaping()) {
Data.ValueWitnesses = &VALUE_WITNESS_SYM(NOESCAPE_FUNCTION_MANGLING);
} else {
Data.ValueWitnesses = &VALUE_WITNESS_SYM(FUNCTION_MANGLING);
switch (key.getDifferentiabilityKind().Value) {
case FunctionMetadataDifferentiabilityKind::Reverse:
Data.ValueWitnesses = &VALUE_WITNESS_SYM(DIFF_FUNCTION_MANGLING);
break;
default:
assert(false && "unsupported function witness");
case FunctionMetadataDifferentiabilityKind::NonDifferentiable:
Data.ValueWitnesses = &VALUE_WITNESS_SYM(FUNCTION_MANGLING);
break;
}
}
break;

Expand Down
2 changes: 0 additions & 2 deletions test/AutoDiff/validation-test/derivative_registration.swift
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,6 @@ DerivativeRegistrationTests.testWithLeakChecking("DerivativeGenericSignature") {
expectEqual(1000, dx)
}

#if REQUIRES_SR14042
// When non-canonicalized generic signatures are used to compare derivative configurations, the
// `@differentiable` and `@derivative` attributes create separate derivatives, and we get a
// duplicate symbol error in TBDGen.
Expand All @@ -237,7 +236,6 @@ DerivativeRegistrationTests.testWithLeakChecking("NonCanonicalizedGenericSignatu
// give a gradient of 1).
expectEqual(0, dx)
}
#endif

// Test derivatives of default implementations.
protocol HasADefaultImplementation {
Expand Down
29 changes: 10 additions & 19 deletions test/AutoDiff/validation-test/reabstraction.swift
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ extension Float: HasFloat {
init(float: Float) { self = float }
}

#if REQUIRES_SR14042
ReabstractionE2ETests.test("diff param generic => concrete") {
func inner<T: HasFloat>(x: T) -> Float {
7 * x.float * x.float
Expand All @@ -71,7 +70,6 @@ ReabstractionE2ETests.test("diff param generic => concrete") {
expectEqual(Float(7 * 3 * 3), transformed(3))
expectEqual(Float(7 * 2 * 3), gradient(at: 3, of: transformed))
}
#endif

ReabstractionE2ETests.test("nondiff param generic => concrete") {
func inner<T: HasFloat>(x: Float, y: T) -> Float {
Expand All @@ -82,7 +80,6 @@ ReabstractionE2ETests.test("nondiff param generic => concrete") {
expectEqual(Float(7 * 2 * 3), gradient(at: 3) { transformed($0, 10) })
}

#if REQUIRES_SR14042
ReabstractionE2ETests.test("diff param and nondiff param generic => concrete") {
func inner<T: HasFloat>(x: T, y: T) -> Float {
7 * x.float * x.float + y.float
Expand All @@ -91,9 +88,7 @@ ReabstractionE2ETests.test("diff param and nondiff param generic => concrete") {
expectEqual(Float(7 * 3 * 3 + 10), transformed(3, 10))
expectEqual(Float(7 * 2 * 3), gradient(at: 3) { transformed($0, 10) })
}
#endif

#if REQUIRES_SR14042
ReabstractionE2ETests.test("result generic => concrete") {
func inner<T: HasFloat>(x: Float) -> T {
T(float: 7 * x * x)
Expand All @@ -102,7 +97,6 @@ ReabstractionE2ETests.test("result generic => concrete") {
expectEqual(Float(7 * 3 * 3), transformed(3))
expectEqual(Float(7 * 2 * 3), gradient(at: 3, of: transformed))
}
#endif

ReabstractionE2ETests.test("diff param concrete => generic => concrete") {
typealias FnTy<T: Differentiable> = @differentiable(reverse) (T) -> Float
Expand Down Expand Up @@ -152,21 +146,19 @@ ReabstractionE2ETests.test("@differentiable(reverse) function => opaque generic
func id<T>(_ t: T) -> T { t }
let inner: @differentiable(reverse) (Float) -> Float = { 7 * $0 * $0 }

// TODO(TF-1122): Actually using `id` causes a segfault at runtime.
// let transformed = id(inner)
// expectEqual(Float(7 * 3 * 3), transformed(3))
// expectEqual(Float(7 * 2 * 3), gradient(at: 3, of: id(inner)))
let transformed = id(inner)
expectEqual(Float(7 * 3 * 3), transformed(3))
expectEqual(Float(7 * 2 * 3), gradient(at: 3, of: id(inner)))
}

ReabstractionE2ETests.test("@differentiable(reverse) function => opaque Any => concrete") {
func id(_ any: Any) -> Any { any }
let inner: @differentiable(reverse) (Float) -> Float = { 7 * $0 * $0 }

// TODO(TF-1122): Actually using `id` causes a segfault at runtime.
// let transformed = id(inner)
// let casted = transformed as! @differentiable(reverse) (Float) -> Float
// expectEqual(Float(7 * 3 * 3), casted(3))
// expectEqual(Float(7 * 2 * 3), gradient(at: 3, of: casted))
let transformed = id(inner)
let casted = transformed as! @differentiable(reverse) (Float) -> Float
expectEqual(Float(7 * 3 * 3), casted(3))
expectEqual(Float(7 * 2 * 3), gradient(at: 3, of: casted))
}

ReabstractionE2ETests.test("access @differentiable(reverse) function using KeyPath") {
Expand All @@ -176,10 +168,9 @@ ReabstractionE2ETests.test("access @differentiable(reverse) function using KeyPa
let container = Container(f: { 7 * $0 * $0 })
let kp = \Container.f

// TODO(TF-1122): Actually using `kp` causes a segfault at runtime.
// let extracted = container[keyPath: kp]
// expectEqual(Float(7 * 3 * 3), extracted(3))
// expectEqual(Float(7 * 2 * 3), gradient(at: 3, of: extracted))
let extracted = container[keyPath: kp]
expectEqual(Float(7 * 3 * 3), extracted(3))
expectEqual(Float(7 * 2 * 3), gradient(at: 3, of: extracted))
}

runAllTests()