Skip to content

Commit 21a4bc5

Browse files
authored
[AutoDiff] Fix differentiability witness SIL serialization. (#28463)
- Create `SILSerializer::DifferentiabilityWitnessesToEmit` to track differentiability witnesses referenced by `differentiability_witness_function` instructions. These witnesses need to be serialized. - Move differentiability witness serialization before SIL function serialization but after visiting SIL functions (`differentiability_witness_function` instructions). - Use `-emit-sorted-sil` in test/AutoDiff/sil_differentiability_witness.sil for deterministic ordering for printing and deserialization.
1 parent dd47de3 commit 21a4bc5

File tree

4 files changed

+115
-79
lines changed

4 files changed

+115
-79
lines changed

lib/Serialization/DeserializeSIL.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -801,7 +801,9 @@ SILDeserializer::readSILFunctionChecked(DeclID FID, SILFunction *existingFn,
801801
// SIL_VTABLE or SIL_GLOBALVAR or SIL_WITNESS_TABLE record also means the end
802802
// of this SILFunction.
803803
while (kind != SIL_FUNCTION && kind != SIL_VTABLE && kind != SIL_GLOBALVAR &&
804-
kind != SIL_WITNESS_TABLE) {
804+
// SWIFT_ENABLE_TENSORFLOW
805+
kind != SIL_WITNESS_TABLE && kind != SIL_DIFFERENTIABILITY_WITNESS) {
806+
// SWIFT_ENABLE_TENSORFLOW END
805807
if (kind == SIL_BASIC_BLOCK)
806808
// Handle a SILBasicBlock record.
807809
CurrentBB = readSILBasicBlock(fn, CurrentBB, scratch);

lib/Serialization/SerializeSIL.cpp

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,12 @@ namespace {
201201
/// Global variables that we've emitted a reference to.
202202
llvm::DenseSet<const SILGlobalVariable *> GlobalsToEmit;
203203

204+
// SWIFT_ENABLE_TENSORFLOW
205+
/// Referenced differentiability witnesses that need to be emitted.
206+
llvm::DenseSet<const SILDifferentiabilityWitness *>
207+
DifferentiabilityWitnessesToEmit;
208+
// SWIFT_ENABLE_TENSORFLOW END
209+
204210
/// Additional functions we might need to serialize.
205211
llvm::SmallVector<const SILFunction *, 16> Worklist;
206212

@@ -1061,6 +1067,7 @@ void SILSerializer::writeSILInstruction(const SILInstruction &SI) {
10611067
case SILInstructionKind::DifferentiabilityWitnessFunctionInst: {
10621068
auto *dwfi = cast<DifferentiabilityWitnessFunctionInst>(&SI);
10631069
auto *witness = dwfi->getWitness();
1070+
DifferentiabilityWitnessesToEmit.insert(witness);
10641071
Mangle::ASTMangler mangler;
10651072
auto mangledKey = mangler.mangleSILDifferentiabilityWitnessKey(
10661073
witness->getKey());
@@ -2718,17 +2725,6 @@ void SILSerializer::writeSILBlock(const SILModule *SILMod) {
27182725
writeSILDefaultWitnessTable(wt);
27192726
}
27202727

2721-
// SWIFT_ENABLE_TENSORFLOW
2722-
// Write out differentiability witnesses.
2723-
for (const auto &diffWitness : SILMod->getDifferentiabilityWitnessList()) {
2724-
// TODO(TF-893): Consider checking
2725-
// `SILMod->shouldSerializeEntitiesAssociatedWithDeclContext` on the JVP/VJP
2726-
// functions.
2727-
if ((ShouldSerializeAll || diffWitness.isSerialized()))
2728-
writeSILDifferentiabilityWitness(diffWitness);
2729-
}
2730-
// SWIFT_ENABLE_TENSORFLOW END
2731-
27322728
// Emit only declarations if it is a module with pre-specializations.
27332729
// And only do it in optimized builds.
27342730
bool emitDeclarationsForOnoneSupport =
@@ -2751,6 +2747,26 @@ void SILSerializer::writeSILBlock(const SILModule *SILMod) {
27512747
processSILFunctionWorklist();
27522748
}
27532749

2750+
// SWIFT_ENABLE_TENSORFLOW
2751+
// Write out differentiability witnesses.
2752+
// Note: this must be done after visiting SIL functions above so that
2753+
// differentiability witness references (`differentiability_witness_function`
2754+
// instructions) have been tracked.
2755+
for (const auto &diffWitness : SILMod->getDifferentiabilityWitnessList()) {
2756+
// TODO(TF-893): Consider checking
2757+
// `SILMod->shouldSerializeEntitiesAssociatedWithDeclContext` on the JVP/VJP
2758+
// functions.
2759+
if ((ShouldSerializeAll || diffWitness.isSerialized()))
2760+
DifferentiabilityWitnessesToEmit.insert(&diffWitness);
2761+
}
2762+
for (auto *diffWitness : DifferentiabilityWitnessesToEmit)
2763+
writeSILDifferentiabilityWitness(*diffWitness);
2764+
// Process SIL functions referenced by differentiability witnesses.
2765+
// Note: this is necessary despite processing `FuncsToEmit` below because
2766+
// `Worklist` is processed separately.
2767+
processSILFunctionWorklist();
2768+
// SWIFT_ENABLE_TENSORFLOW END
2769+
27542770
// Now write function declarations for every function we've
27552771
// emitted a reference to without emitting a function body for.
27562772
for (const SILFunction &F : *SILMod) {

test/AutoDiff/sil_differentiability_witness.sil

Lines changed: 67 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
// Round-trip parsing/printing test.
22

3-
// RUN: %target-sil-opt %s | %target-sil-opt | %FileCheck --check-prefix=ROUNDTRIP %s
3+
// RUN: %target-sil-opt %s | %target-sil-opt -emit-sorted-sil | %FileCheck --check-prefix=ROUNDTRIP %s
44

55
// Round-trip serialization-deserialization test.
66

77
// RUN: %empty-directory(%t)
88
// RUN: %target-sil-opt %s -emit-sib -o %t/tmp.sib -module-name main
99
// RUN: %target-sil-opt %t/tmp.sib -o %t/tmp.2.sib -module-name main
10-
// RUN: %target-sil-opt %t/tmp.2.sib -module-name main | %FileCheck --check-prefix=ROUNDTRIP %s
10+
// RUN: %target-sil-opt %t/tmp.2.sib -module-name main -emit-sorted-sil | %FileCheck --check-prefix=ROUNDTRIP %s
1111

1212
// IRGen test.
1313

@@ -19,6 +19,71 @@ import Builtin
1919
import Swift
2020
import SwiftShims
2121

22+
// Test SIL differentiability witness for bodiless original function, with defined jvp/vjp.
23+
24+
sil @externalFn1 : $@convention(thin) (Float) -> Float
25+
26+
sil @AD__externalFn1__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
27+
bb0(%0 : $Float):
28+
return undef : $(Float, @callee_guaranteed (Float) -> Float)
29+
}
30+
31+
sil @AD__externalFn1__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
32+
bb0(%0 : $Float):
33+
return undef : $(Float, @callee_guaranteed (Float) -> Float)
34+
}
35+
36+
sil_differentiability_witness [parameters 0] [results 0] @externalFn1 : $@convention(thin) (Float) -> Float {
37+
jvp: @AD__externalFn1__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
38+
vjp: @AD__externalFn1__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
39+
}
40+
41+
// ROUNDTRIP-LABEL: // differentiability witness for externalFn1
42+
// ROUNDTRIP: sil_differentiability_witness [parameters 0] [results 0] @externalFn1 : $@convention(thin) (Float) -> Float {
43+
// ROUNDTRIP: jvp: @AD__externalFn1__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
44+
// ROUNDTRIP: vjp: @AD__externalFn1__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
45+
// ROUNDTRIP: }
46+
47+
// IRGEN-LABEL: @AD__externalFn1_PSRS ={{( protected)?}} global { i8*, i8* } {
48+
// IRGEN-SAME: @AD__externalFn1__jvp_src_0_wrt_0
49+
// IRGEN-SAME: @AD__externalFn1__vjp_src_0_wrt_0
50+
// IRGEN-SAME: }
51+
52+
// Test SIL differentiability witness for bodiless original function, with bodiless jvp/vjp.
53+
54+
sil @externalFn2 : $@convention(thin) (Float) -> Float
55+
56+
sil @AD__externalFn2__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
57+
58+
sil @AD__externalFn2__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
59+
60+
sil_differentiability_witness [parameters 0] [results 0] @externalFn2 : $@convention(thin) (Float) -> Float {
61+
jvp: @AD__externalFn2__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
62+
vjp: @AD__externalFn2__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
63+
}
64+
65+
// ROUNDTRIP-LABEL: // differentiability witness for externalFn2
66+
// ROUNDTRIP: sil_differentiability_witness [parameters 0] [results 0] @externalFn2 : $@convention(thin) (Float) -> Float {
67+
// ROUNDTRIP: jvp: @AD__externalFn2__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
68+
// ROUNDTRIP: vjp: @AD__externalFn2__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
69+
// ROUNDTRIP: }
70+
71+
// IRGEN-LABEL: @AD__externalFn2_PSRS ={{( protected)?}} global { i8*, i8* } {
72+
// IRGEN-SAME: @AD__externalFn2__jvp_src_0_wrt_0
73+
// IRGEN-SAME: @AD__externalFn2__vjp_src_0_wrt_0
74+
// IRGEN-SAME: }
75+
76+
// Test SIL differentiability witness declaration.
77+
78+
sil @externalFn3 : $@convention(thin) (Float) -> Float
79+
80+
sil_differentiability_witness [parameters 0] [results 0] @externalFn3 : $@convention(thin) (Float) -> Float
81+
82+
// ROUNDTRIP-LABEL: // differentiability witness for externalFn3
83+
// ROUNDTRIP: sil_differentiability_witness [parameters 0] [results 0] @externalFn3 : $@convention(thin) (Float) -> Float{{[^{]*$}}
84+
85+
// IRGEN-NOT: @AD__externalFn3{{.*}}={{.*}}{ i8*, i8* }
86+
2287
// Test public non-generic function.
2388
// SIL differentiability witness:
2489
// - Has public linkage (implicit).
@@ -92,68 +157,3 @@ sil_differentiability_witness hidden [parameters 0 1] [results 0] <τ_0_0 where
92157
// IRGEN-SAME: @AD__generic__jvp_src_0_wrt_0_1
93158
// IRGEN-SAME: @AD__generic__vjp_src_0_wrt_0_1
94159
// IRGEN-SAME: }
95-
96-
// Test SIL differentiability witness for bodiless original function, with defined jvp/vjp.
97-
98-
sil @externalFn1 : $@convention(thin) (Float) -> Float
99-
100-
sil @AD__externalFn1__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
101-
bb0(%0 : $Float):
102-
return undef : $(Float, @callee_guaranteed (Float) -> Float)
103-
}
104-
105-
sil @AD__externalFn1__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float) {
106-
bb0(%0 : $Float):
107-
return undef : $(Float, @callee_guaranteed (Float) -> Float)
108-
}
109-
110-
sil_differentiability_witness [parameters 0] [results 0] @externalFn1 : $@convention(thin) (Float) -> Float {
111-
jvp: @AD__externalFn1__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
112-
vjp: @AD__externalFn1__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
113-
}
114-
115-
// ROUNDTRIP-LABEL: // differentiability witness for externalFn1
116-
// ROUNDTRIP: sil_differentiability_witness [parameters 0] [results 0] @externalFn1 : $@convention(thin) (Float) -> Float {
117-
// ROUNDTRIP: jvp: @AD__externalFn1__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
118-
// ROUNDTRIP: vjp: @AD__externalFn1__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
119-
// ROUNDTRIP: }
120-
121-
// IRGEN-LABEL: @AD__externalFn1_PSRS ={{( protected)?}} global { i8*, i8* } {
122-
// IRGEN-SAME: @AD__externalFn1__jvp_src_0_wrt_0
123-
// IRGEN-SAME: @AD__externalFn1__vjp_src_0_wrt_0
124-
// IRGEN-SAME: }
125-
126-
// Test SIL differentiability witness for bodiless original function, with bodiless jvp/vjp.
127-
128-
sil @externalFn2 : $@convention(thin) (Float) -> Float
129-
130-
sil @AD__externalFn2__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
131-
132-
sil @AD__externalFn2__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
133-
134-
sil_differentiability_witness [parameters 0] [results 0] @externalFn2 : $@convention(thin) (Float) -> Float {
135-
jvp: @AD__externalFn2__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
136-
vjp: @AD__externalFn2__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
137-
}
138-
139-
// ROUNDTRIP-LABEL: // differentiability witness for externalFn2
140-
// ROUNDTRIP: sil_differentiability_witness [parameters 0] [results 0] @externalFn2 : $@convention(thin) (Float) -> Float {
141-
// ROUNDTRIP: jvp: @AD__externalFn2__jvp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
142-
// ROUNDTRIP: vjp: @AD__externalFn2__vjp_src_0_wrt_0 : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
143-
// ROUNDTRIP: }
144-
145-
// IRGEN-LABEL: @AD__externalFn2_PSRS ={{( protected)?}} global { i8*, i8* } {
146-
// IRGEN-SAME: @AD__externalFn2__jvp_src_0_wrt_0
147-
// IRGEN-SAME: @AD__externalFn2__vjp_src_0_wrt_0
148-
// IRGEN-SAME: }
149-
150-
// Test SIL differentiability witness declaration.
151-
152-
sil @externalFn3 : $@convention(thin) (Float) -> Float
153-
154-
sil_differentiability_witness [parameters 0] [results 0] @externalFn3 : $@convention(thin) (Float) -> Float
155-
156-
// ROUNDTRIP-LABEL: // differentiability witness for externalFn3
157-
// ROUNDTRIP: sil_differentiability_witness [parameters 0] [results 0] @externalFn3 : $@convention(thin) (Float) -> Float{{[^{]*$}}
158-
159-
// IRGEN-NOT: @AD__externalFn3{{.*}}={{.*}}{ i8*, i8* }
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// RUN: %empty-directory(%t)
2+
// RUN: %target-swift-frontend -emit-module -emit-module-path %t/test.swiftmodule -module-name test %s
3+
// RUN: %target-sil-opt %t/test.swiftmodule
4+
5+
sil_stage raw
6+
7+
import Swift
8+
import Builtin
9+
10+
sil_differentiability_witness [parameters 0] [results 0] @referenced_from_serialized : $@convention(thin) (Float, Float, Float) -> Float
11+
12+
sil @referenced_from_serialized : $@convention(thin) (Float, Float, Float) -> Float
13+
14+
sil [serialized] @test_serialized : $@convention(thin) () -> () {
15+
bb0:
16+
%referenced_from_serialized_jvp_wrt_0 = differentiability_witness_function [jvp] [parameters 0] [results 0] @referenced_from_serialized : $@convention(thin) (Float, Float, Float) -> Float
17+
return undef : $()
18+
}

0 commit comments

Comments
 (0)