Skip to content

Commit 82db8fa

Browse files
author
marcrasi
authored
[AutoDiff] devirtualize diff witnesses (#28480)
1 parent a1aace2 commit 82db8fa

File tree

13 files changed

+184
-20
lines changed

13 files changed

+184
-20
lines changed

include/swift/SIL/SILDifferentiabilityWitness.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ class SILDifferentiabilityWitness
9191
GenericSignature derivativeGenSig, SILFunction *jvp, SILFunction *vjp,
9292
bool isSerialized, DeclAttribute *attribute = nullptr);
9393

94+
void convertToDefinition(SILFunction *jvp, SILFunction *vjp,
95+
bool isSerialized);
96+
9497
SILDifferentiabilityWitnessKey getKey() const;
9598
SILModule &getModule() const { return Module; }
9699
SILLinkage getLinkage() const { return Linkage; }

include/swift/SIL/SILModule.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -613,6 +613,10 @@ class SILModule {
613613
/// Look up the differentiability witness corresponding to the given key.
614614
SILDifferentiabilityWitness *
615615
lookUpDifferentiabilityWitness(SILDifferentiabilityWitnessKey key);
616+
617+
/// Attempt to deserialize the SILDifferentiabilityWitness. Returns true if
618+
/// deserialization succeeded, false otherwise.
619+
bool loadDifferentiabilityWitness(SILDifferentiabilityWitness *W);
616620
// SWIFT_ENABLE_TENSORFLOW_END
617621

618622
// Given a protocol, attempt to create a default witness table declaration

include/swift/SILOptimizer/PassManager/Passes.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,9 @@ PASS(DiagnoseUnreachable, "diagnose-unreachable",
148148
"Diagnose Unreachable Code")
149149
PASS(DiagnosticConstantPropagation, "diagnostic-constant-propagation",
150150
"Constants Propagation for Diagnostics")
151+
PASS(DifferentiabilityWitnessDevirtualizer,
152+
"differentiability-witness-devirtualizer",
153+
"Inlines Differentiability Witnesses")
151154
PASS(EagerSpecializer, "eager-specializer",
152155
"Eager Specialization via @_specialize")
153156
PASS(EarlyCodeMotion, "early-codemotion",

lib/SIL/SILDifferentiabilityWitness.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,16 @@ SILDifferentiabilityWitness *SILDifferentiabilityWitness::createDefinition(
6060
return diffWitness;
6161
}
6262

63+
void SILDifferentiabilityWitness::convertToDefinition(SILFunction *jvp,
64+
SILFunction *vjp,
65+
bool isSerialized) {
66+
assert(IsDeclaration);
67+
IsDeclaration = false;
68+
JVP = jvp;
69+
VJP = vjp;
70+
IsSerialized = isSerialized;
71+
}
72+
6373
SILDifferentiabilityWitnessKey SILDifferentiabilityWitness::getKey() const {
6474
return std::make_pair(getOriginalFunction()->getName(), getConfig());
6575
}

lib/SIL/SILModule.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -599,6 +599,15 @@ SILModule::lookUpDifferentiabilityWitness(SILDifferentiabilityWitnessKey key) {
599599
mangler.mangleSILDifferentiabilityWitnessKey(key));
600600
}
601601

602+
bool SILModule::loadDifferentiabilityWitness(SILDifferentiabilityWitness *W) {
603+
auto *NewW = getSILLoader()->lookupDifferentiabilityWitness(W->getKey());
604+
if (!NewW)
605+
return false;
606+
607+
assert(W == NewW);
608+
return true;
609+
}
610+
602611
void SILModule::registerDeserializationNotificationHandler(
603612
std::unique_ptr<DeserializationNotificationHandler> &&handler) {
604613
deserializationNotificationHandlers.add(std::move(handler));

lib/SILGen/SILGen.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -824,13 +824,10 @@ void SILGenModule::emitDifferentiabilityWitness(
824824
derivativeCanGenSig = derivativeGenSig->getCanonicalSignature();
825825
// Create new SIL differentiability witness.
826826
// Witness JVP and VJP are set below.
827-
// TODO(TF-919): Explore creating serialized differentiability witnesses.
828-
// Currently, differentiability witnesses are never serialized to avoid
829-
// deserialization issues where JVP/VJP functions cannot be found.
830827
auto *diffWitness = SILDifferentiabilityWitness::createDefinition(
831-
M, originalFunction->getLinkage(), originalFunction,
832-
loweredParamIndices, config.resultIndices, derivativeCanGenSig,
833-
/*jvp*/ nullptr, /*vjp*/ nullptr, /*isSerialized*/ false);
828+
M, originalFunction->getLinkage(), originalFunction, loweredParamIndices,
829+
config.resultIndices, derivativeCanGenSig,
830+
/*jvp*/ nullptr, /*vjp*/ nullptr, originalFunction->isSerialized());
834831

835832
// Set derivative function in differentiability witness.
836833
auto setDerivativeInDifferentiabilityWitness =

lib/SILOptimizer/PassManager/PassPipeline.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,11 @@ static void addPerfEarlyModulePassPipeline(SILPassPipelinePlan &P) {
404404
// we do not spend time optimizing them.
405405
P.addDeadFunctionElimination();
406406

407+
// SWIFT_ENABLE_TENSORFLOW
408+
// This unblocks many other passes' optimizations (e.g. inlining) and this is
409+
// not blocked by any other passes' optimizations, so do it early.
410+
P.addDifferentiabilityWitnessDevirtualizer();
411+
407412
// Strip ownership from non-transparent functions.
408413
if (P.getOptions().StripOwnershipAfterSerialization)
409414
P.addNonTransparentFunctionOwnershipModelEliminator();

lib/SILOptimizer/Transforms/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ silopt_register_sources(
1717
DeadStoreElimination.cpp
1818
DestroyHoisting.cpp
1919
Devirtualizer.cpp
20+
# SWIFT_ENABLE_TENSORFLOW
21+
DifferentiabilityWitnessDevirtualizer.cpp
22+
# SWIFT_ENABLE_TENSORFLOW_END
2023
GenericSpecializer.cpp
2124
MergeCondFail.cpp
2225
Outliner.cpp
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
//===--- DifferentiabilityWitnessDevirtualizer.cpp ------------------------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2019 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
//
13+
// Devirtualized differentiability witnesses whose bodies are availabe, by
14+
// turning "differentiability_witness_function" instructions into "function_ref"
15+
// instructions referencing the appropriate function.
16+
//
17+
//===----------------------------------------------------------------------===//
18+
19+
#include "swift/SIL/SILBuilder.h"
20+
#include "swift/SIL/SILFunction.h"
21+
#include "swift/SIL/SILInstruction.h"
22+
#include "swift/SILOptimizer/PassManager/Transforms.h"
23+
24+
using namespace swift;
25+
26+
namespace {
27+
class DifferentiabilityWitnessDevirtualizer : public SILFunctionTransform {
28+
29+
/// Returns true if and changes were made.
30+
bool devirtualizeDifferentiabilityWitnessesInFunction(SILFunction &f);
31+
32+
/// The entry point to the transformation.
33+
void run() override {
34+
if (devirtualizeDifferentiabilityWitnessesInFunction(*getFunction()))
35+
invalidateAnalysis(SILAnalysis::InvalidationKind::CallsAndInstructions);
36+
}
37+
};
38+
} // end anonymous namespace
39+
40+
bool DifferentiabilityWitnessDevirtualizer::
41+
devirtualizeDifferentiabilityWitnessesInFunction(SILFunction &f) {
42+
bool changed = false;
43+
llvm::SmallVector<DifferentiabilityWitnessFunctionInst *, 8> insts;
44+
for (auto &bb : f) {
45+
for (auto &inst : bb) {
46+
auto *dfwi = dyn_cast<DifferentiabilityWitnessFunctionInst>(&inst);
47+
if (!dfwi)
48+
continue;
49+
insts.push_back(dfwi);
50+
}
51+
}
52+
for (auto *inst : insts) {
53+
auto *wit = inst->getWitness();
54+
if (wit->isDeclaration())
55+
f.getModule().loadDifferentiabilityWitness(wit);
56+
if (wit->isDeclaration())
57+
continue;
58+
changed = true;
59+
SILBuilderWithScope builder(inst);
60+
auto kind = inst->getWitnessKind().getAsDerivativeFunctionKind();
61+
assert(kind.hasValue());
62+
auto *newInst = builder.createFunctionRefFor(inst->getLoc(),
63+
wit->getDerivative(*kind));
64+
inst->replaceAllUsesWith(newInst);
65+
inst->getParent()->erase(inst);
66+
}
67+
return changed;
68+
}
69+
70+
SILTransform *swift::createDifferentiabilityWitnessDevirtualizer() {
71+
return new DifferentiabilityWitnessDevirtualizer();
72+
}

lib/Serialization/DeserializeSIL.cpp

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3459,18 +3459,29 @@ SILDeserializer::readDifferentiabilityWitness(DeclID DId) {
34593459
ArrayRef<unsigned>(parameterAndResultIndices)
34603460
.take_back(numResultIndices));
34613461

3462-
if (isDeclaration) {
3463-
auto *diffWitness = SILDifferentiabilityWitness::createDeclaration(
3462+
AutoDiffConfig config(parameterIndices, resultIndices, derivativeGenSig);
3463+
auto *diffWitness =
3464+
SILMod.lookUpDifferentiabilityWitness({originalName, config});
3465+
3466+
// If there is no existing differentiability witness, create one.
3467+
if (!diffWitness)
3468+
diffWitness = SILDifferentiabilityWitness::createDeclaration(
34643469
SILMod, *linkage, original, parameterIndices, resultIndices,
34653470
derivativeGenSig);
3466-
diffWitnessOrOffset.set(diffWitness, /*isFullyDeserialized*/ true);
3467-
return diffWitness;
3468-
}
34693471

3470-
auto *diffWitness = SILDifferentiabilityWitness::createDefinition(
3471-
SILMod, *linkage, original, parameterIndices, resultIndices,
3472-
derivativeGenSig, jvp, vjp, isSerialized);
3473-
diffWitnessOrOffset.set(diffWitness, /*isFullyDeserialized*/ true);
3472+
// If the current differentiability witness is merely a declaration, and the
3473+
// deserialized witness is a definition, upgrade the current differentiability
3474+
// witness to a definition. This can happen in the following situations:
3475+
// 1. The witness was just created above.
3476+
// 2. The witness started out as a declaration (e.g. the differentiation
3477+
// pass emitted a witness for an external function) and now we're loading
3478+
// the definition (e.g. an optimization pass asked for the definition and
3479+
// we found the definition serialized in this module).
3480+
if (diffWitness->isDeclaration() && !isDeclaration)
3481+
diffWitness->convertToDefinition(jvp, vjp, isSerialized);
3482+
3483+
diffWitnessOrOffset.set(diffWitness,
3484+
/*isFullyDeserialized*/ diffWitness->isDefinition());
34743485
return diffWitness;
34753486
}
34763487

lib/Serialization/SerializedSILLoader.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,10 +133,15 @@ SerializedSILLoader::lookupDifferentiabilityWitness(
133133
SILDifferentiabilityWitnessKey key) {
134134
Mangle::ASTMangler mangler;
135135
std::string mangledKey = mangler.mangleSILDifferentiabilityWitnessKey(key);
136-
for (auto &Des : LoadedSILSections)
137-
if (auto *diffWitness = Des->lookupDifferentiabilityWitness(mangledKey))
138-
return diffWitness;
139-
return nullptr;
136+
// It is possible that one module has a declaration of a
137+
// SILDifferentiabilityWitness, while another has the full definition.
138+
SILDifferentiabilityWitness *wit = nullptr;
139+
for (auto &Des : LoadedSILSections) {
140+
wit = Des->lookupDifferentiabilityWitness(mangledKey);
141+
if (wit && wit->isDefinition())
142+
return wit;
143+
}
144+
return wit;
140145
}
141146
// SWIFT_ENABLE_TENSORFLOW END
142147

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// RUN: %target-sil-opt -differentiability-witness-devirtualizer %s -enable-sil-verify-all | %FileCheck %s
2+
3+
sil_stage raw
4+
5+
import Swift
6+
import Builtin
7+
8+
sil_differentiability_witness [parameters 0] [results 0] @witness_defined_in_module : $@convention(thin) (Float) -> Float {
9+
jvp: @witness_defined_in_module_jvp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
10+
vjp: @witness_defined_in_module_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
11+
}
12+
13+
sil_differentiability_witness [parameters 0] [results 0] @witness_definition_not_available : $@convention(thin) (Float) -> Float
14+
15+
// This is an example of a witness that is available (via deserialization)
16+
// even though it is not defined in the current module.
17+
// witness for static Swift.Float.+ infix(Swift.Float, Swift.Float) -> Swift.Float
18+
sil_differentiability_witness [parameters 0 1] [results 0] @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float
19+
20+
sil @witness_defined_in_module : $@convention(thin) (Float) -> Float
21+
22+
sil @witness_defined_in_module_jvp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
23+
24+
sil @witness_defined_in_module_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
25+
26+
sil @witness_definition_not_available : $@convention(thin) (Float) -> Float
27+
28+
sil public_external [transparent] [serialized] @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float
29+
30+
sil @test : $@convention(thin) (Float) -> () {
31+
bb0(%0 : $Float):
32+
%1 = differentiability_witness_function [vjp] [parameters 0] [results 0] @witness_defined_in_module : $@convention(thin) (Float) -> Float
33+
// CHECK: %1 = function_ref @witness_defined_in_module_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
34+
35+
%2 = differentiability_witness_function [vjp] [parameters 0] [results 0] @witness_definition_not_available : $@convention(thin) (Float) -> Float
36+
// CHECK: %2 = differentiability_witness_function [vjp] [parameters 0] [results 0] @witness_definition_not_available : $@convention(thin) (Float) -> Float
37+
38+
%3 = differentiability_witness_function [vjp] [parameters 0 1] [results 0] @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float
39+
// CHECK: %3 = function_ref @AD__$sSf1poiyS2f_SftFZ__vjp_src_0_wrt_0_1 : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float))
40+
41+
return undef : $()
42+
}

test/AutoDiff/sil_differentiability_witness_silgen.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ public struct Foo: Differentiable {
7575
public var x: Float
7676

7777
// CHECK-LABEL: // differentiability witness for Foo.x.getter
78-
// CHECK-NEXT: sil_differentiability_witness [parameters 0] [results 0] @$s36sil_differentiability_witness_silgen3FooV1xSfvg : $@convention(method) (Foo) -> Float {
78+
// CHECK-NEXT: sil_differentiability_witness [serialized] [parameters 0] [results 0] @$s36sil_differentiability_witness_silgen3FooV1xSfvg : $@convention(method) (Foo) -> Float {
7979
// CHECK-NEXT: }
8080

8181
@differentiable

0 commit comments

Comments
 (0)