Skip to content

Commit c89e270

Browse files
authored
Implement value witness table for @differentiable functions (#60875)
@differentiable function is actually a triple (function, jvp, vjp). Previously normal thick function value witness table was used. As a result, for example, only function was copied, but none of differential components. This was the cause of uninitialized memory accesses and subsequent segfaults. Should fix now unavailable TF-1122
1 parent 983e2f3 commit c89e270

File tree

6 files changed

+49
-23
lines changed

6 files changed

+49
-23
lines changed

include/swift/Demangling/ManglingMacros.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,16 @@ _Pragma("clang diagnostic pop")
4646
#define NO_ARGS_MANGLING yy
4747
#define FUNC_TYPE_MANGLING c
4848
#define NOESCAPE_FUNC_TYPE_MANGLING XE
49+
#define DIFF_FUNC_TYPE_MANGLING Yjrc
4950
#define OBJC_PARTIAL_APPLY_THUNK_MANGLING Ta
5051
#define OPTIONAL_MANGLING(Ty) MANGLING_CONCAT2_IMPL(Ty, Sg)
5152

5253
#define FUNCTION_MANGLING \
5354
MANGLING_CONCAT2(NO_ARGS_MANGLING, FUNC_TYPE_MANGLING)
5455

56+
#define DIFF_FUNCTION_MANGLING \
57+
MANGLING_CONCAT2(NO_ARGS_MANGLING, DIFF_FUNC_TYPE_MANGLING)
58+
5559
#define NOESCAPE_FUNCTION_MANGLING \
5660
MANGLING_CONCAT2(NO_ARGS_MANGLING, NOESCAPE_FUNC_TYPE_MANGLING)
5761

include/swift/Runtime/Metadata.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,13 @@ SWIFT_RUNTIME_EXPORT
142142
const ValueWitnessTable
143143
VALUE_WITNESS_SYM(FUNCTION_MANGLING); // () -> ()
144144

145-
// The @escaping () -> () table can be used for arbitrary escaping function types.
145+
// The @differentiable(reverse) () -> () table can be used for differentiable
146+
// function types.
147+
SWIFT_RUNTIME_EXPORT
148+
const ValueWitnessTable
149+
VALUE_WITNESS_SYM(DIFF_FUNCTION_MANGLING); // @differentiable(reverse) () -> ()
150+
151+
// The @noescape () -> () table can be used for arbitrary noescaping function types.
146152
SWIFT_RUNTIME_EXPORT
147153
const ValueWitnessTable
148154
VALUE_WITNESS_SYM(NOESCAPE_FUNCTION_MANGLING); // @noescape () -> ()

stdlib/public/runtime/KnownMetadata.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,13 +233,31 @@ namespace {
233233
return FunctionPointerBox::getExtraInhabitantTag((void *const *)src);
234234
}
235235
};
236+
struct DiffFunctionBox
237+
: AggregateBox<ThickFunctionBox, ThickFunctionBox, ThickFunctionBox> {
238+
239+
static constexpr unsigned numExtraInhabitants =
240+
ThickFunctionBox::numExtraInhabitants;
241+
242+
static void storeExtraInhabitantTag(char *dest, unsigned tag) {
243+
ThickFunctionBox::storeExtraInhabitantTag(dest, tag);
244+
}
245+
246+
static unsigned getExtraInhabitantTag(const char *src) {
247+
return ThickFunctionBox::getExtraInhabitantTag(src);
248+
}
249+
};
236250
} // end anonymous namespace
237251

238252
/// The basic value-witness table for escaping function types.
239253
const ValueWitnessTable
240254
swift::VALUE_WITNESS_SYM(FUNCTION_MANGLING) =
241255
ValueWitnessTableForBox<ThickFunctionBox>::table;
242256

257+
const ValueWitnessTable
258+
swift::VALUE_WITNESS_SYM(DIFF_FUNCTION_MANGLING) =
259+
ValueWitnessTableForBox<DiffFunctionBox>::table;
260+
243261
/// The basic value-witness table for @noescape function types.
244262
const ValueWitnessTable
245263
swift::VALUE_WITNESS_SYM(NOESCAPE_FUNCTION_MANGLING) =

stdlib/public/runtime/Metadata.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1337,7 +1337,16 @@ FunctionCacheEntry::FunctionCacheEntry(const Key &key) {
13371337
if (!flags.isEscaping()) {
13381338
Data.ValueWitnesses = &VALUE_WITNESS_SYM(NOESCAPE_FUNCTION_MANGLING);
13391339
} else {
1340-
Data.ValueWitnesses = &VALUE_WITNESS_SYM(FUNCTION_MANGLING);
1340+
switch (key.getDifferentiabilityKind().Value) {
1341+
case FunctionMetadataDifferentiabilityKind::Reverse:
1342+
Data.ValueWitnesses = &VALUE_WITNESS_SYM(DIFF_FUNCTION_MANGLING);
1343+
break;
1344+
default:
1345+
assert(false && "unsupported function witness");
1346+
case FunctionMetadataDifferentiabilityKind::NonDifferentiable:
1347+
Data.ValueWitnesses = &VALUE_WITNESS_SYM(FUNCTION_MANGLING);
1348+
break;
1349+
}
13411350
}
13421351
break;
13431352

test/AutoDiff/validation-test/derivative_registration.swift

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,6 @@ DerivativeRegistrationTests.testWithLeakChecking("DerivativeGenericSignature") {
217217
expectEqual(1000, dx)
218218
}
219219

220-
#if REQUIRES_SR14042
221220
// When non-canonicalized generic signatures are used to compare derivative configurations, the
222221
// `@differentiable` and `@derivative` attributes create separate derivatives, and we get a
223222
// duplicate symbol error in TBDGen.
@@ -237,7 +236,6 @@ DerivativeRegistrationTests.testWithLeakChecking("NonCanonicalizedGenericSignatu
237236
// give a gradient of 1).
238237
expectEqual(0, dx)
239238
}
240-
#endif
241239

242240
// Test derivatives of default implementations.
243241
protocol HasADefaultImplementation {

test/AutoDiff/validation-test/reabstraction.swift

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ extension Float: HasFloat {
6262
init(float: Float) { self = float }
6363
}
6464

65-
#if REQUIRES_SR14042
6665
ReabstractionE2ETests.test("diff param generic => concrete") {
6766
func inner<T: HasFloat>(x: T) -> Float {
6867
7 * x.float * x.float
@@ -71,7 +70,6 @@ ReabstractionE2ETests.test("diff param generic => concrete") {
7170
expectEqual(Float(7 * 3 * 3), transformed(3))
7271
expectEqual(Float(7 * 2 * 3), gradient(at: 3, of: transformed))
7372
}
74-
#endif
7573

7674
ReabstractionE2ETests.test("nondiff param generic => concrete") {
7775
func inner<T: HasFloat>(x: Float, y: T) -> Float {
@@ -82,7 +80,6 @@ ReabstractionE2ETests.test("nondiff param generic => concrete") {
8280
expectEqual(Float(7 * 2 * 3), gradient(at: 3) { transformed($0, 10) })
8381
}
8482

85-
#if REQUIRES_SR14042
8683
ReabstractionE2ETests.test("diff param and nondiff param generic => concrete") {
8784
func inner<T: HasFloat>(x: T, y: T) -> Float {
8885
7 * x.float * x.float + y.float
@@ -91,9 +88,7 @@ ReabstractionE2ETests.test("diff param and nondiff param generic => concrete") {
9188
expectEqual(Float(7 * 3 * 3 + 10), transformed(3, 10))
9289
expectEqual(Float(7 * 2 * 3), gradient(at: 3) { transformed($0, 10) })
9390
}
94-
#endif
9591

96-
#if REQUIRES_SR14042
9792
ReabstractionE2ETests.test("result generic => concrete") {
9893
func inner<T: HasFloat>(x: Float) -> T {
9994
T(float: 7 * x * x)
@@ -102,7 +97,6 @@ ReabstractionE2ETests.test("result generic => concrete") {
10297
expectEqual(Float(7 * 3 * 3), transformed(3))
10398
expectEqual(Float(7 * 2 * 3), gradient(at: 3, of: transformed))
10499
}
105-
#endif
106100

107101
ReabstractionE2ETests.test("diff param concrete => generic => concrete") {
108102
typealias FnTy<T: Differentiable> = @differentiable(reverse) (T) -> Float
@@ -152,21 +146,19 @@ ReabstractionE2ETests.test("@differentiable(reverse) function => opaque generic
152146
func id<T>(_ t: T) -> T { t }
153147
let inner: @differentiable(reverse) (Float) -> Float = { 7 * $0 * $0 }
154148

155-
// TODO(TF-1122): Actually using `id` causes a segfault at runtime.
156-
// let transformed = id(inner)
157-
// expectEqual(Float(7 * 3 * 3), transformed(3))
158-
// expectEqual(Float(7 * 2 * 3), gradient(at: 3, of: id(inner)))
149+
let transformed = id(inner)
150+
expectEqual(Float(7 * 3 * 3), transformed(3))
151+
expectEqual(Float(7 * 2 * 3), gradient(at: 3, of: id(inner)))
159152
}
160153

161154
ReabstractionE2ETests.test("@differentiable(reverse) function => opaque Any => concrete") {
162155
func id(_ any: Any) -> Any { any }
163156
let inner: @differentiable(reverse) (Float) -> Float = { 7 * $0 * $0 }
164157

165-
// TODO(TF-1122): Actually using `id` causes a segfault at runtime.
166-
// let transformed = id(inner)
167-
// let casted = transformed as! @differentiable(reverse) (Float) -> Float
168-
// expectEqual(Float(7 * 3 * 3), casted(3))
169-
// expectEqual(Float(7 * 2 * 3), gradient(at: 3, of: casted))
158+
let transformed = id(inner)
159+
let casted = transformed as! @differentiable(reverse) (Float) -> Float
160+
expectEqual(Float(7 * 3 * 3), casted(3))
161+
expectEqual(Float(7 * 2 * 3), gradient(at: 3, of: casted))
170162
}
171163

172164
ReabstractionE2ETests.test("access @differentiable(reverse) function using KeyPath") {
@@ -176,10 +168,9 @@ ReabstractionE2ETests.test("access @differentiable(reverse) function using KeyPa
176168
let container = Container(f: { 7 * $0 * $0 })
177169
let kp = \Container.f
178170

179-
// TODO(TF-1122): Actually using `kp` causes a segfault at runtime.
180-
// let extracted = container[keyPath: kp]
181-
// expectEqual(Float(7 * 3 * 3), extracted(3))
182-
// expectEqual(Float(7 * 2 * 3), gradient(at: 3, of: extracted))
171+
let extracted = container[keyPath: kp]
172+
expectEqual(Float(7 * 3 * 3), extracted(3))
173+
expectEqual(Float(7 * 2 * 3), gradient(at: 3, of: extracted))
183174
}
184175

185176
runAllTests()

0 commit comments

Comments
 (0)