Skip to content

Commit 1df0277

Browse files
committed
Autodiff Upstream : Adding DifferentiabilityWitnessDevirtualizer
1 parent e583f3a commit 1df0277

File tree

5 files changed

+123
-0
lines changed

5 files changed

+123
-0
lines changed

include/swift/SILOptimizer/PassManager/Passes.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,9 @@ PASS(DiagnoseUnreachable, "diagnose-unreachable",
152152
"Diagnose Unreachable Code")
153153
PASS(DiagnosticConstantPropagation, "diagnostic-constant-propagation",
154154
"Constants Propagation for Diagnostics")
155+
PASS(DifferentiabilityWitnessDevirtualizer,
156+
"differentiability-witness-devirtualizer",
157+
"Inlines Differentiability Witnesses")
155158
PASS(EagerSpecializer, "eager-specializer",
156159
"Eager Specialization via @_specialize")
157160
PASS(EarlyCodeMotion, "early-codemotion",

lib/SILOptimizer/PassManager/PassPipeline.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,11 @@ static void addPerfEarlyModulePassPipeline(SILPassPipelinePlan &P) {
408408
// Cleanup after SILGen: remove unneeded borrows/copies.
409409
P.addSemanticARCOpts();
410410

411+
// Devirtualizes differentiability witnesses into functions that reference them.
412+
// This unblocks many other passes' optimizations (e.g. inlining) and this is
413+
// not blocked by any other passes' optimizations, so do it early.
414+
P.addDifferentiabilityWitnessDevirtualizer();
415+
411416
// Strip ownership from non-transparent functions.
412417
if (P.getOptions().StripOwnershipAfterSerialization)
413418
P.addNonTransparentFunctionOwnershipModelEliminator();

lib/SILOptimizer/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ silopt_register_sources(
1717
DeadStoreElimination.cpp
1818
DestroyHoisting.cpp
1919
Devirtualizer.cpp
20+
DifferentiabilityWitnessDevirtualizer.cpp
2021
GenericSpecializer.cpp
2122
MergeCondFail.cpp
2223
Outliner.cpp
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
//===--- DifferentiabilityWitnessDevirtualizer.cpp ------------------------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2020 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
//
13+
// Devirtualized differentiability witnesses whose bodies are availabe, by
14+
// turning "differentiability_witness_function" instructions into "function_ref"
15+
// instructions referencing the appropriate function.
16+
//
17+
//===----------------------------------------------------------------------===//
18+
19+
#include "swift/SIL/SILBuilder.h"
20+
#include "swift/SIL/SILFunction.h"
21+
#include "swift/SIL/SILInstruction.h"
22+
#include "swift/SILOptimizer/PassManager/Transforms.h"
23+
24+
using namespace swift;
25+
26+
namespace {
27+
class DifferentiabilityWitnessDevirtualizer : public SILFunctionTransform {
28+
29+
/// Returns true if and changes were made.
30+
bool devirtualizeDifferentiabilityWitnessesInFunction(SILFunction &f);
31+
32+
/// The entry point to the transformation.
33+
void run() override {
34+
if (devirtualizeDifferentiabilityWitnessesInFunction(*getFunction()))
35+
invalidateAnalysis(SILAnalysis::InvalidationKind::CallsAndInstructions);
36+
}
37+
};
38+
} // end anonymous namespace
39+
40+
bool DifferentiabilityWitnessDevirtualizer::
41+
devirtualizeDifferentiabilityWitnessesInFunction(SILFunction &f) {
42+
bool changed = false;
43+
llvm::SmallVector<DifferentiabilityWitnessFunctionInst *, 8> insts;
44+
for (auto &bb : f) {
45+
for (auto &inst : bb) {
46+
auto *dfwi = dyn_cast<DifferentiabilityWitnessFunctionInst>(&inst);
47+
if (!dfwi)
48+
continue;
49+
insts.push_back(dfwi);
50+
}
51+
}
52+
for (auto *inst : insts) {
53+
auto *wit = inst->getWitness();
54+
if (wit->isDeclaration())
55+
f.getModule().loadDifferentiabilityWitness(wit);
56+
if (wit->isDeclaration())
57+
continue;
58+
changed = true;
59+
SILBuilderWithScope builder(inst);
60+
auto kind = inst->getWitnessKind().getAsDerivativeFunctionKind();
61+
assert(kind.hasValue());
62+
auto *newInst = builder.createFunctionRefFor(inst->getLoc(),
63+
wit->getDerivative(*kind));
64+
inst->replaceAllUsesWith(newInst);
65+
inst->getParent()->erase(inst);
66+
}
67+
return changed;
68+
}
69+
70+
SILTransform *swift::createDifferentiabilityWitnessDevirtualizer() {
71+
return new DifferentiabilityWitnessDevirtualizer();
72+
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// RUN: %target-sil-opt -differentiability-witness-devirtualizer %s -enable-sil-verify-all | %FileCheck %s
2+
3+
sil_stage raw
4+
5+
import Swift
6+
import Builtin
7+
8+
sil_differentiability_witness [parameters 0] [results 0] @witness_defined_in_module : $@convention(thin) (Float) -> Float {
9+
jvp: @witness_defined_in_module_jvp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
10+
vjp: @witness_defined_in_module_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
11+
}
12+
13+
sil_differentiability_witness [parameters 0] [results 0] @witness_definition_not_available : $@convention(thin) (Float) -> Float
14+
15+
// This is an example of a witness that is available (via deserialization)
16+
// even though it is not defined in the current module.
17+
// witness for static Swift.Float.+ infix(Swift.Float, Swift.Float) -> Swift.Float
18+
sil_differentiability_witness [parameters 0 1] [results 0] @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float
19+
20+
sil @witness_defined_in_module : $@convention(thin) (Float) -> Float
21+
22+
sil @witness_defined_in_module_jvp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
23+
24+
sil @witness_defined_in_module_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
25+
26+
sil @witness_definition_not_available : $@convention(thin) (Float) -> Float
27+
28+
sil public_external [transparent] [serialized] @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float
29+
30+
sil @test : $@convention(thin) (Float) -> () {
31+
bb0(%0 : $Float):
32+
%1 = differentiability_witness_function [vjp] [parameters 0] [results 0] @witness_defined_in_module : $@convention(thin) (Float) -> Float
33+
// CHECK: %1 = function_ref @witness_defined_in_module_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
34+
35+
%2 = differentiability_witness_function [vjp] [parameters 0] [results 0] @witness_definition_not_available : $@convention(thin) (Float) -> Float
36+
// CHECK: %2 = differentiability_witness_function [vjp] [parameters 0] [results 0] @witness_definition_not_available : $@convention(thin) (Float) -> Float
37+
38+
%3 = differentiability_witness_function [vjp] [parameters 0 1] [results 0] @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float
39+
// CHECK: %3 = function_ref @AD__$sSf1poiyS2f_SftFZ__vjp_src_0_wrt_0_1 : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float))
40+
41+
return undef : $()
42+
}

0 commit comments

Comments
 (0)