Skip to content

Commit 4cd68ed

Browse files
ematejskadan-zheng
andauthored
[Autodiff upstream] Add DifferentiabilityWitnessDevirtualizer SILOptimizer pass (#30984)
Add DifferentiabilityWitnessDevirtualizer: an optimization pass that devirtualizes `differentiability_witness_function` instructions into `function_ref` instructions. Co-authored-by: Dan Zheng <[email protected]>
1 parent 120dcac commit 4cd68ed

File tree

5 files changed

+123
-0
lines changed

5 files changed

+123
-0
lines changed

include/swift/SILOptimizer/PassManager/Passes.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,9 @@ PASS(DiagnoseUnreachable, "diagnose-unreachable",
152152
"Diagnose Unreachable Code")
153153
PASS(DiagnosticConstantPropagation, "diagnostic-constant-propagation",
154154
"Constants Propagation for Diagnostics")
155+
PASS(DifferentiabilityWitnessDevirtualizer,
156+
"differentiability-witness-devirtualizer",
157+
"Inlines Differentiability Witnesses")
155158
PASS(EagerSpecializer, "eager-specializer",
156159
"Eager Specialization via @_specialize")
157160
PASS(EarlyCodeMotion, "early-codemotion",

lib/SILOptimizer/PassManager/PassPipeline.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,11 @@ static void addPerfEarlyModulePassPipeline(SILPassPipelinePlan &P) {
408408
// Cleanup after SILGen: remove unneeded borrows/copies.
409409
P.addSemanticARCOpts();
410410

411+
// Devirtualizes differentiability witnesses into functions that reference them.
412+
// This unblocks many other passes' optimizations (e.g. inlining) and this is
413+
// not blocked by any other passes' optimizations, so do it early.
414+
P.addDifferentiabilityWitnessDevirtualizer();
415+
411416
// Strip ownership from non-transparent functions.
412417
if (P.getOptions().StripOwnershipAfterSerialization)
413418
P.addNonTransparentFunctionOwnershipModelEliminator();

lib/SILOptimizer/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ silopt_register_sources(
1717
DeadStoreElimination.cpp
1818
DestroyHoisting.cpp
1919
Devirtualizer.cpp
20+
DifferentiabilityWitnessDevirtualizer.cpp
2021
GenericSpecializer.cpp
2122
MergeCondFail.cpp
2223
Outliner.cpp
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
//===--- DifferentiabilityWitnessDevirtualizer.cpp ------------------------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2020 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
//
13+
// Devirtualizes `differentiability_witness_function` instructions into
14+
// `function_ref` instructions for differentiability witness definitions.
15+
//
16+
//===----------------------------------------------------------------------===//
17+
18+
#include "swift/SIL/SILBuilder.h"
19+
#include "swift/SIL/SILFunction.h"
20+
#include "swift/SIL/SILInstruction.h"
21+
#include "swift/SILOptimizer/PassManager/Transforms.h"
22+
23+
using namespace swift;
24+
25+
namespace {
26+
class DifferentiabilityWitnessDevirtualizer : public SILFunctionTransform {
27+
28+
/// Returns true if any changes were made.
29+
bool devirtualizeDifferentiabilityWitnessesInFunction(SILFunction &f);
30+
31+
/// The entry point to the transformation.
32+
void run() override {
33+
if (devirtualizeDifferentiabilityWitnessesInFunction(*getFunction()))
34+
invalidateAnalysis(SILAnalysis::InvalidationKind::CallsAndInstructions);
35+
}
36+
};
37+
} // end anonymous namespace
38+
39+
bool DifferentiabilityWitnessDevirtualizer::
40+
devirtualizeDifferentiabilityWitnessesInFunction(SILFunction &f) {
41+
bool changed = false;
42+
llvm::SmallVector<DifferentiabilityWitnessFunctionInst *, 8> insts;
43+
for (auto &bb : f) {
44+
for (auto &inst : bb) {
45+
auto *dfwi = dyn_cast<DifferentiabilityWitnessFunctionInst>(&inst);
46+
if (!dfwi)
47+
continue;
48+
insts.push_back(dfwi);
49+
}
50+
}
51+
for (auto *inst : insts) {
52+
auto *witness = inst->getWitness();
53+
if (witness->isDeclaration())
54+
f.getModule().loadDifferentiabilityWitness(witness);
55+
if (witness->isDeclaration())
56+
continue;
57+
changed = true;
58+
SILBuilderWithScope builder(inst);
59+
auto kind = inst->getWitnessKind().getAsDerivativeFunctionKind();
60+
assert(kind.hasValue());
61+
auto *newInst = builder.createFunctionRefFor(inst->getLoc(),
62+
witness->getDerivative(*kind));
63+
inst->replaceAllUsesWith(newInst);
64+
inst->getParent()->erase(inst);
65+
}
66+
return changed;
67+
}
68+
69+
SILTransform *swift::createDifferentiabilityWitnessDevirtualizer() {
70+
return new DifferentiabilityWitnessDevirtualizer();
71+
}
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// RUN: %target-sil-opt -differentiability-witness-devirtualizer %s -enable-sil-verify-all | %FileCheck %s
2+
3+
sil_stage raw
4+
5+
import _Differentiation
6+
import Swift
7+
import Builtin
8+
9+
sil_differentiability_witness [parameters 0] [results 0] @witness_defined_in_module : $@convention(thin) (Float) -> Float {
10+
jvp: @witness_defined_in_module_jvp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
11+
vjp: @witness_defined_in_module_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
12+
}
13+
14+
sil_differentiability_witness [parameters 0] [results 0] @witness_definition_not_available : $@convention(thin) (Float) -> Float
15+
16+
// This is an example of a witness that is available (via deserialization)
17+
// even though it is not defined in the current module.
18+
// witness for static Swift.Float.+ infix(Swift.Float, Swift.Float) -> Swift.Float
19+
sil_differentiability_witness [parameters 0 1] [results 0] @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float
20+
21+
sil @witness_defined_in_module : $@convention(thin) (Float) -> Float
22+
23+
sil @witness_defined_in_module_jvp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
24+
25+
sil @witness_defined_in_module_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
26+
27+
sil @witness_definition_not_available : $@convention(thin) (Float) -> Float
28+
29+
sil public_external [transparent] [serialized] @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float
30+
31+
sil @test : $@convention(thin) (Float) -> () {
32+
bb0(%0 : $Float):
33+
%1 = differentiability_witness_function [vjp] [parameters 0] [results 0] @witness_defined_in_module : $@convention(thin) (Float) -> Float
34+
// CHECK: %1 = function_ref @witness_defined_in_module_vjp : $@convention(thin) (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
35+
36+
%2 = differentiability_witness_function [vjp] [parameters 0] [results 0] @witness_definition_not_available : $@convention(thin) (Float) -> Float
37+
// CHECK: %2 = differentiability_witness_function [vjp] [parameters 0] [results 0] @witness_definition_not_available : $@convention(thin) (Float) -> Float
38+
39+
%3 = differentiability_witness_function [vjp] [parameters 0 1] [results 0] @$sSf1poiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float
40+
// CHECK: %3 = function_ref @AD__$sSf1poiyS2f_SftFZ__vjp_src_0_wrt_0_1 : $@convention(method) (Float, Float, @thin Float.Type) -> (Float, @owned @callee_guaranteed (Float) -> (Float, Float))
41+
42+
return undef : $()
43+
}

0 commit comments

Comments
 (0)