Skip to content

Commit 317664d

Browse files
authored
move ad before mandatory inlining (#21560)
1 parent 73c7015 commit 317664d

File tree

7 files changed

+101
-127
lines changed

7 files changed

+101
-127
lines changed

lib/SILOptimizer/Mandatory/TFDifferentiation.cpp

Lines changed: 64 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,28 @@ static SILFunction *lookUpOrLinkFunction(StringRef name, SILModule &module) {
9595
return module.findFunction(name, SILLinkage::PublicExternal);
9696
}
9797

98+
/// Computes the linkage for functions generated by the AD pass that are
99+
/// associated with an original function whose linkage is `originalLinkage`: Hidden if `isAvailableExternally(originalLinkage)` holds; otherwise Public and PublicNonABI are preserved and every other linkage maps to Hidden.
100+
static SILLinkage getAutoDiffFunctionLinkage(SILLinkage originalLinkage) {
101+
// If the original is defined externally, then the AD pass is just generating
102+
// associated functions for use in the current module and therefore these
103+
// associated functions should not be visible outside the module.
104+
if (isAvailableExternally(originalLinkage))
105+
return SILLinkage::Hidden;
106+
107+
// If the original is public, then external modules may need to link the
108+
// associated function. Make the associated function public, keeping the original's Public vs. PublicNonABI distinction.
109+
if (originalLinkage == SILLinkage::Public)
110+
return SILLinkage::Public;
111+
if (originalLinkage == SILLinkage::PublicNonABI)
112+
return SILLinkage::PublicNonABI;
113+
114+
// Otherwise, the original function is defined and used only in the current
115+
// module, so external modules will never try to access the associated
116+
// function. Make the associated function hidden.
117+
return SILLinkage::Hidden;
118+
}
119+
98120
/// Given a function, gather all of its formal results (both direct and
99121
/// indirect) in an order defined by its result type. Note that "formal results"
100122
/// refer to result values in the body of the function, not at call sites.
@@ -2104,8 +2126,11 @@ class PrimalGenCloner final : public SILClonerWithScopes<PrimalGenCloner> {
21042126
// Clone.
21052127
cloneFunctionBody(original, entry, entryArgs);
21062128
// If errors occurred, back out.
2107-
if (errorOccurred)
2129+
if (errorOccurred) {
2130+
// Delete the body so that later passes don't get confused by invalid SIL.
2131+
getPrimal()->getBlocks().clear();
21082132
return true;
2133+
}
21092134
auto *origExit = &*original->findReturnBB();
21102135
auto *exit = BBMap.lookup(origExit);
21112136
assert(exit->getParent() == getPrimal());
@@ -4131,8 +4156,10 @@ void DifferentiationTask::createEmptyPrimal() {
41314156
auto indices = getIndices();
41324157
auto *original = getOriginal();
41334158
auto &module = context.getModule();
4134-
std::string primalName =
4135-
"AD__" + original->getName().str() + "__primal_" + indices.mangle();
4159+
auto primalName = original->getASTContext()
4160+
.getIdentifier("AD__" + original->getName().str() +
4161+
"__primal_" + indices.mangle())
4162+
.str();
41364163
StructDecl *primalValueStructDecl = context.createPrimalValueStruct(this);
41374164
primalInfo = std::unique_ptr<PrimalInfo>(new PrimalInfo(primalValueStructDecl, module));
41384165
auto pvType = primalValueStructDecl->getDeclaredType()->getCanonicalType();
@@ -4151,17 +4178,21 @@ void DifferentiationTask::createEmptyPrimal() {
41514178
originalTy->getParameters(), originalTy->getYields(), results,
41524179
originalTy->getOptionalErrorResult(), context.getASTContext());
41534180
SILOptFunctionBuilder fb(context.getTransform());
4154-
auto linkage = original->getLinkage();
4155-
if (linkage == SILLinkage::Public)
4156-
linkage = SILLinkage::PublicNonABI;
4181+
// We set generated primal linkage to Hidden because generated primals are
4182+
// never called cross-module in VJP mode: all cross-module calls to associated
4183+
// functions call the VJP.
4184+
// TODO: In order for cross-module calls to work in non-VJP mode, we must use
4185+
// `getAutoDiffFunctionLinkage` to make the linkage occasionally public. We'll
4186+
// also need to update TBDGen to generate TBD entries for public primals.
4187+
auto linkage = SILLinkage::Hidden;
41574188
primal = fb.getOrCreateFunction(
41584189
original->getLocation(), primalName, linkage, primalTy,
41594190
original->isBare(), original->isTransparent(), original->isSerialized());
41604191
primal->setUnqualifiedOwnership();
41614192
LLVM_DEBUG(getADDebugStream() << "Primal function created \n"
41624193
<< *primal << '\n');
41634194

4164-
attr->setPrimalName(primal->getName());
4195+
attr->setPrimalName(primalName);
41654196
}
41664197

41674198
void DifferentiationTask::createEmptyAdjoint() {
@@ -4263,24 +4294,31 @@ void DifferentiationTask::createEmptyAdjoint() {
42634294
->getCanonicalType()));
42644295
}
42654296

4266-
auto adjName = "AD__" + original->getName().str() + "__adjoint_" +
4267-
getIndices().mangle();
4297+
auto adjName = original->getASTContext()
4298+
.getIdentifier("AD__" + original->getName().str() +
4299+
"__adjoint_" + getIndices().mangle())
4300+
.str();
42684301
auto adjType = SILFunctionType::get(
42694302
origTy->getGenericSignature(), origTy->getExtInfo(),
42704303
origTy->getCoroutineKind(), origTy->getCalleeConvention(), adjParams, {},
42714304
adjResults, None, original->getASTContext());
42724305
SILOptFunctionBuilder fb(context.getTransform());
4273-
auto linkage = original->getLinkage();
4274-
if (linkage == SILLinkage::Public)
4275-
linkage = SILLinkage::PublicNonABI;
4276-
adjoint = fb.createFunction(linkage, adjName, adjType,
4277-
original->getGenericEnvironment(), original->getLocation(),
4278-
original->isBare(), original->isTransparent(), original->isSerialized());
4306+
// We set generated adjoint linkage to Hidden because generated adjoints are
4307+
// never called cross-module in VJP mode: all cross-module calls to associated
4308+
// functions call the VJP.
4309+
// TODO: In order for cross-module calls to work in non-VJP mode, we must use
4310+
// `getAutoDiffFunctionLinkage` to make the linkage occasionally public. We'll
4311+
// also need to update TBDGen to generate TBD entries for public adjoints.
4312+
auto linkage = SILLinkage::Hidden;
4313+
adjoint = fb.createFunction(
4314+
linkage, adjName, adjType, original->getGenericEnvironment(),
4315+
original->getLocation(), original->isBare(), original->isTransparent(),
4316+
original->isSerialized());
42794317
adjoint->setUnqualifiedOwnership();
42804318
adjoint->setDebugScope(new (module)
42814319
SILDebugScope(original->getLocation(), adjoint));
42824320

4283-
attr->setAdjointName(adjoint->getName(), /*primitive*/ false);
4321+
attr->setAdjointName(adjName, /*primitive*/ false);
42844322
}
42854323

42864324
void DifferentiationTask::createVJP() {
@@ -4300,22 +4338,25 @@ void DifferentiationTask::createVJP() {
43004338
auto originalTy = original->getLoweredFunctionType();
43014339

43024340
// === Create an empty VJP. ===
4303-
auto vjpName = "AD__" + original->getName().str() + "__vjp_" +
4304-
getIndices().mangle();
4341+
auto vjpName = original->getASTContext()
4342+
.getIdentifier("AD__" + original->getName().str() +
4343+
"__vjp_" + getIndices().mangle())
4344+
.str();
43054345
auto vjpType = originalTy->getAutoDiffAssociatedFunctionType(
43064346
getIndices().parameters, getIndices().source, 1,
43074347
AutoDiffAssociatedFunctionKind::VJP, module,
43084348
LookUpConformanceInModule(module.getSwiftModule()));
43094349

43104350
SILOptFunctionBuilder fb(context.getTransform());
4311-
vjp = fb.createFunction(
4312-
original->getLinkage(), vjpName, vjpType,
4313-
original->getGenericEnvironment(), original->getLocation(),
4314-
original->isBare(), original->isTransparent(), original->isSerialized());
4351+
auto linkage = getAutoDiffFunctionLinkage(original->getLinkage());
4352+
vjp = fb.createFunction(linkage, vjpName, vjpType,
4353+
original->getGenericEnvironment(),
4354+
original->getLocation(), original->isBare(),
4355+
original->isTransparent(), original->isSerialized());
43154356
vjp->setUnqualifiedOwnership();
43164357
vjp->setDebugScope(new (module)
43174358
SILDebugScope(original->getLocation(), vjp));
4318-
attr->setVJPName(vjp->getName());
4359+
attr->setVJPName(vjpName);
43194360

43204361
// Work around a bad interaction between VJPs, TFDeabstraction, and SIL
43214362
// optimizations.

lib/SILOptimizer/PassManager/PassPipeline.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ static void addMandatoryOptPipeline(SILPassPipelinePlan &P,
9797
addDefiniteInitialization(P);
9898
P.addClosureLifetimeFixup();
9999
P.addOwnershipModelEliminator();
100+
// SWIFT_ENABLE_TENSORFLOW
101+
P.addDifferentiation();
100102
P.addMandatoryInlining();
101103
P.addMandatorySILLinker();
102104
P.addPredictableMemoryOptimizations();
@@ -113,7 +115,6 @@ static void addMandatoryOptPipeline(SILPassPipelinePlan &P,
113115
P.addSplitNonCondBrCriticalEdges();
114116

115117
// SWIFT_ENABLE_TENSORFLOW
116-
P.addDifferentiation();
117118
P.addTFDeabstraction();
118119
}
119120

test/AutoDiff/builtin_math.sil

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -96,14 +96,6 @@ bb0(%0 : @trivial $Builtin.FPIEEE32):
9696
// CHECK-LABEL: [differentiable source 0 wrt 0, 1 primal @AD__add_literals__primal_src_0_wrt_0_1 adjoint @AD__add_literals__adjoint_src_0_wrt_0_1 vjp @AD__add_literals__vjp_src_0_wrt_0_1] @add_literals
9797
// CHECK-LABEL [differentiable source 0 wrt 0, 1 primal @AD__fanout__primal_src_0_wrt_0_1 adjoint @AD__fanout__adjoint_src_0_wrt_0_1 vjp @AD__fanout__vjp_src_0_wrt_0_1] @fanout
9898

99-
// CHECK-LABEL: @AD__simple_mul__primal_src_0_wrt_0_1
100-
// CHECK: bb0(%0 : $Builtin.FPIEEE32, %1 : $Builtin.FPIEEE32):
101-
// CHECK: %2 = builtin "fmul_FPIEEE32"(%0 : $Builtin.FPIEEE32, %1 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
102-
// CHECK: %3 = struct $AD__simple_mul__Type__src_0_wrt_0_1 (%2 : $Builtin.FPIEEE32)
103-
// CHECK: %4 = tuple (%3 : $AD__simple_mul__Type__src_0_wrt_0_1, %2 : $Builtin.FPIEEE32)
104-
// CHECK: return %4 : $(AD__simple_mul__Type__src_0_wrt_0_1, Builtin.FPIEEE32)
105-
// CHECK: }
106-
10799
// CHECK-LABEL: @AD__simple_mul__adjoint_src_0_wrt_0_1
108100
// CHECK: bb0(%0 : $Builtin.FPIEEE32, %1 : $AD__simple_mul__Type__src_0_wrt_0_1, %2 : $Builtin.FPIEEE32, %3 : $Builtin.FPIEEE32, %4 : $Builtin.FPIEEE32):
109101
// CHECK: %5 = builtin "fmul_FPIEEE32"(%0 : $Builtin.FPIEEE32, %4 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
@@ -134,15 +126,6 @@ bb0(%0 : @trivial $Builtin.FPIEEE32):
134126
// CHECK: return %10 : $(Builtin.FPIEEE32, Builtin.FPIEEE32)
135127
// CHECK: }
136128

137-
// CHECK-LABEL: @AD__chain_rule__primal_src_0_wrt_0_1
138-
// CHECK: bb0(%0 : $Builtin.FPIEEE32, %1 : $Builtin.FPIEEE32):
139-
// CHECK: %2 = builtin "fneg_FPIEEE32"(%0 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
140-
// CHECK: %3 = builtin "fmul_FPIEEE32"(%1 : $Builtin.FPIEEE32, %2 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
141-
// CHECK: %4 = struct $AD__chain_rule__Type__src_0_wrt_0_1 (%2 : $Builtin.FPIEEE32, %3 : $Builtin.FPIEEE32)
142-
// CHECK: %5 = tuple (%4 : $AD__chain_rule__Type__src_0_wrt_0_1, %3 : $Builtin.FPIEEE32)
143-
// CHECK: return %5 : $(AD__chain_rule__Type__src_0_wrt_0_1, Builtin.FPIEEE32)
144-
// CHECK: }
145-
146129
// CHECK-LABEL: @AD__chain_rule__adjoint_src_0_wrt_0_1
147130
// CHECK: bb0(%0 : $Builtin.FPIEEE32, %1 : $AD__chain_rule__Type__src_0_wrt_0_1, %2 : $Builtin.FPIEEE32, %3 : $Builtin.FPIEEE32, %4 : $Builtin.FPIEEE32):
148131
// CHECK: %5 = struct_extract %1 : $AD__chain_rule__Type__src_0_wrt_0_1, #AD__chain_rule__Type__src_0_wrt_0_1.v_0
@@ -153,17 +136,6 @@ bb0(%0 : @trivial $Builtin.FPIEEE32):
153136
// CHECK: return %9 : $(Builtin.FPIEEE32, Builtin.FPIEEE32)
154137
// CHECK: }
155138

156-
// CHECK-LABEL: @AD__add_literals__primal_src_0_wrt_0_1
157-
// CHECK: bb0(%0 : $Builtin.FPIEEE32, %1 : $Builtin.FPIEEE32):
158-
// CHECK: %2 = builtin "fneg_FPIEEE32"(%0 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
159-
// CHECK: %3 = builtin "fmul_FPIEEE32"(%1 : $Builtin.FPIEEE32, %2 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
160-
// CHECK: %4 = float_literal $Builtin.FPIEEE32, 0x64
161-
// CHECK: %5 = builtin "fsub_FPIEEE32"(%4 : $Builtin.FPIEEE32, %3 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
162-
// CHECK: %6 = struct $AD__add_literals__Type__src_0_wrt_0_1 (%2 : $Builtin.FPIEEE32, %3 : $Builtin.FPIEEE32, %5 : $Builtin.FPIEEE32)
163-
// CHECK: %7 = tuple (%6 : $AD__add_literals__Type__src_0_wrt_0_1, %5 : $Builtin.FPIEEE32)
164-
// CHECK: return %7 : $(AD__add_literals__Type__src_0_wrt_0_1, Builtin.FPIEEE32)
165-
// CHECK: }
166-
167139
// CHECK-LABEL: @AD__add_literals__adjoint_src_0_wrt_0_1
168140
// CHECK: bb0(%0 : $Builtin.FPIEEE32, %1 : $AD__add_literals__Type__src_0_wrt_0_1, %2 : $Builtin.FPIEEE32, %3 : $Builtin.FPIEEE32, %4 : $Builtin.FPIEEE32):
169141
// CHECK: %5 = builtin "fneg_FPIEEE32"(%0 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
@@ -175,15 +147,6 @@ bb0(%0 : @trivial $Builtin.FPIEEE32):
175147
// CHECK: return %10 : $(Builtin.FPIEEE32, Builtin.FPIEEE32)
176148
// CHECK: }
177149

178-
// CHECK-LABEL: @AD__fanout__primal_src_0_wrt_0_1
179-
// CHECK: bb0(%0 : $Builtin.FPIEEE32, %1 : $Builtin.FPIEEE32):
180-
// CHECK: %2 = builtin "fmul_FPIEEE32"(%0 : $Builtin.FPIEEE32, %0 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
181-
// CHECK: %3 = builtin "fmul_FPIEEE32"(%0 : $Builtin.FPIEEE32, %2 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
182-
// CHECK: %4 = struct $AD__fanout__Type__src_0_wrt_0_1 (%2 : $Builtin.FPIEEE32, %3 : $Builtin.FPIEEE32)
183-
// CHECK: %5 = tuple (%4 : $AD__fanout__Type__src_0_wrt_0_1, %3 : $Builtin.FPIEEE32)
184-
// CHECK: return %5 : $(AD__fanout__Type__src_0_wrt_0_1, Builtin.FPIEEE32)
185-
// CHECK: }
186-
187150
// CHECK-LABEL: @AD__fanout__adjoint_src_0_wrt_0_1
188151
// CHECK: bb0(%0 : $Builtin.FPIEEE32, %1 : $AD__fanout__Type__src_0_wrt_0_1, %2 : $Builtin.FPIEEE32, %3 : $Builtin.FPIEEE32, %4 : $Builtin.FPIEEE32):
189152
// CHECK: %5 = struct_extract %1 : $AD__fanout__Type__src_0_wrt_0_1, #AD__fanout__Type__src_0_wrt_0_1.v_0

test/AutoDiff/e2e_differentiable_property.swift

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,7 @@ E2EDifferentiablePropertyTests.test("computed property") {
7676
}
7777

7878
// FIXME: The AD pass cannot differentiate this because it sees
79-
// `struct_extract`s instead of calls to getters. This problem should fix
80-
// itself once we move the AD pass before mandatory inlining, and we should be
81-
// able to enable this test.
79+
// `struct_extract`s instead of calls to getters.
8280
// E2EDifferentiablePropertyTests.test("stored property") {
8381
// let actualGrad = gradient(at: Space(x: 0, y: 0)) { (point: Space) -> Float in
8482
// return 3 * point.y

0 commit comments

Comments
 (0)