Skip to content

move ad before mandatory inlining #21560

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 64 additions & 23 deletions lib/SILOptimizer/Mandatory/TFDifferentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,28 @@ static SILFunction *lookUpOrLinkFunction(StringRef name, SILModule &module) {
return module.findFunction(name, SILLinkage::PublicExternal);
}

/// Maps the linkage of an original function to the linkage that the AD pass
/// should assign to the associated functions it generates for that original.
///
/// \param originalLinkage the linkage of the original function.
/// \returns `originalLinkage` itself when it is `Public` or `PublicNonABI` and
/// the original is not available externally; `Hidden` in every other case.
static SILLinkage getAutoDiffFunctionLinkage(SILLinkage originalLinkage) {
  // An externally defined original means the associated functions are only
  // being generated for use inside the current module, so they must stay
  // invisible to other modules.
  const bool definedExternally = isAvailableExternally(originalLinkage);
  if (definedExternally)
    return SILLinkage::Hidden;

  switch (originalLinkage) {
  // External modules may need to link against the associated functions of a
  // public original, so the associated function keeps the same public
  // linkage kind as the original.
  case SILLinkage::Public:
  case SILLinkage::PublicNonABI:
    return originalLinkage;

  // Any other original is defined and used only within the current module;
  // external modules will never try to reach its associated functions.
  default:
    return SILLinkage::Hidden;
  }
}

/// Given a function, gather all of its formal results (both direct and
/// indirect) in an order defined by its result type. Note that "formal results"
/// refer to result values in the body of the function, not at call sites.
Expand Down Expand Up @@ -2104,8 +2126,11 @@ class PrimalGenCloner final : public SILClonerWithScopes<PrimalGenCloner> {
// Clone.
cloneFunctionBody(original, entry, entryArgs);
// If errors occurred, back out.
if (errorOccurred)
if (errorOccurred) {
// Delete the body so that later passes don't get confused by invalid SIL.
getPrimal()->getBlocks().clear();
return true;
}
auto *origExit = &*original->findReturnBB();
auto *exit = BBMap.lookup(origExit);
assert(exit->getParent() == getPrimal());
Expand Down Expand Up @@ -4131,8 +4156,10 @@ void DifferentiationTask::createEmptyPrimal() {
auto indices = getIndices();
auto *original = getOriginal();
auto &module = context.getModule();
std::string primalName =
"AD__" + original->getName().str() + "__primal_" + indices.mangle();
auto primalName = original->getASTContext()
.getIdentifier("AD__" + original->getName().str() +
"__primal_" + indices.mangle())
.str();
StructDecl *primalValueStructDecl = context.createPrimalValueStruct(this);
primalInfo = std::unique_ptr<PrimalInfo>(new PrimalInfo(primalValueStructDecl, module));
auto pvType = primalValueStructDecl->getDeclaredType()->getCanonicalType();
Expand All @@ -4151,17 +4178,21 @@ void DifferentiationTask::createEmptyPrimal() {
originalTy->getParameters(), originalTy->getYields(), results,
originalTy->getOptionalErrorResult(), context.getASTContext());
SILOptFunctionBuilder fb(context.getTransform());
auto linkage = original->getLinkage();
if (linkage == SILLinkage::Public)
linkage = SILLinkage::PublicNonABI;
// We set generated primal linkage to Hidden because generated primals are
// never called cross-module in VJP mode: all cross-module calls to associated
// functions call the VJP.
// TODO: In order for cross-module calls to work in non-VJP mode, we must use
// `getAutoDiffFunctionLinkage` to make the linkage occasionally public. We'll
// also need to update TBDGen to generate TBD entries for public primals.
auto linkage = SILLinkage::Hidden;
primal = fb.getOrCreateFunction(
original->getLocation(), primalName, linkage, primalTy,
original->isBare(), original->isTransparent(), original->isSerialized());
primal->setUnqualifiedOwnership();
LLVM_DEBUG(getADDebugStream() << "Primal function created \n"
<< *primal << '\n');

attr->setPrimalName(primal->getName());
attr->setPrimalName(primalName);
}

void DifferentiationTask::createEmptyAdjoint() {
Expand Down Expand Up @@ -4263,24 +4294,31 @@ void DifferentiationTask::createEmptyAdjoint() {
->getCanonicalType()));
}

auto adjName = "AD__" + original->getName().str() + "__adjoint_" +
getIndices().mangle();
auto adjName = original->getASTContext()
.getIdentifier("AD__" + original->getName().str() +
"__adjoint_" + getIndices().mangle())
.str();
auto adjType = SILFunctionType::get(
origTy->getGenericSignature(), origTy->getExtInfo(),
origTy->getCoroutineKind(), origTy->getCalleeConvention(), adjParams, {},
adjResults, None, original->getASTContext());
SILOptFunctionBuilder fb(context.getTransform());
auto linkage = original->getLinkage();
if (linkage == SILLinkage::Public)
linkage = SILLinkage::PublicNonABI;
adjoint = fb.createFunction(linkage, adjName, adjType,
original->getGenericEnvironment(), original->getLocation(),
original->isBare(), original->isTransparent(), original->isSerialized());
// We set generated adjoint linkage to Hidden because generated adjoints are
// never called cross-module in VJP mode: all cross-module calls to associated
// functions call the VJP.
// TODO: In order for cross-module calls to work in non-VJP mode, we must use
// `getAutoDiffFunctionLinkage` to make the linkage occasionally public. We'll
// also need to update TBDGen to generate TBD entries for public adjoints.
auto linkage = SILLinkage::Hidden;
adjoint = fb.createFunction(
linkage, adjName, adjType, original->getGenericEnvironment(),
original->getLocation(), original->isBare(), original->isTransparent(),
original->isSerialized());
adjoint->setUnqualifiedOwnership();
adjoint->setDebugScope(new (module)
SILDebugScope(original->getLocation(), adjoint));

attr->setAdjointName(adjoint->getName(), /*primitive*/ false);
attr->setAdjointName(adjName, /*primitive*/ false);
}

void DifferentiationTask::createVJP() {
Expand All @@ -4300,22 +4338,25 @@ void DifferentiationTask::createVJP() {
auto originalTy = original->getLoweredFunctionType();

// === Create an empty VJP. ===
auto vjpName = "AD__" + original->getName().str() + "__vjp_" +
getIndices().mangle();
auto vjpName = original->getASTContext()
.getIdentifier("AD__" + original->getName().str() +
"__vjp_" + getIndices().mangle())
.str();
auto vjpType = originalTy->getAutoDiffAssociatedFunctionType(
getIndices().parameters, getIndices().source, 1,
AutoDiffAssociatedFunctionKind::VJP, module,
LookUpConformanceInModule(module.getSwiftModule()));

SILOptFunctionBuilder fb(context.getTransform());
vjp = fb.createFunction(
original->getLinkage(), vjpName, vjpType,
original->getGenericEnvironment(), original->getLocation(),
original->isBare(), original->isTransparent(), original->isSerialized());
auto linkage = getAutoDiffFunctionLinkage(original->getLinkage());
vjp = fb.createFunction(linkage, vjpName, vjpType,
original->getGenericEnvironment(),
original->getLocation(), original->isBare(),
original->isTransparent(), original->isSerialized());
vjp->setUnqualifiedOwnership();
vjp->setDebugScope(new (module)
SILDebugScope(original->getLocation(), vjp));
attr->setVJPName(vjp->getName());
attr->setVJPName(vjpName);

// Work around a bad interaction between VJPs, TFDeabstraction, and SIL
// optimizations.
Expand Down
3 changes: 2 additions & 1 deletion lib/SILOptimizer/PassManager/PassPipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ static void addMandatoryOptPipeline(SILPassPipelinePlan &P,
addDefiniteInitialization(P);
P.addClosureLifetimeFixup();
P.addOwnershipModelEliminator();
// SWIFT_ENABLE_TENSORFLOW
P.addDifferentiation();
P.addMandatoryInlining();
P.addMandatorySILLinker();
P.addPredictableMemoryOptimizations();
Expand All @@ -113,7 +115,6 @@ static void addMandatoryOptPipeline(SILPassPipelinePlan &P,
P.addSplitNonCondBrCriticalEdges();

// SWIFT_ENABLE_TENSORFLOW
P.addDifferentiation();
P.addTFDeabstraction();
}

Expand Down
37 changes: 0 additions & 37 deletions test/AutoDiff/builtin_math.sil
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,6 @@ bb0(%0 : @trivial $Builtin.FPIEEE32):
// CHECK-LABEL: [differentiable source 0 wrt 0, 1 primal @AD__add_literals__primal_src_0_wrt_0_1 adjoint @AD__add_literals__adjoint_src_0_wrt_0_1 vjp @AD__add_literals__vjp_src_0_wrt_0_1] @add_literals
// CHECK-LABEL [differentiable source 0 wrt 0, 1 primal @AD__fanout__primal_src_0_wrt_0_1 adjoint @AD__fanout__adjoint_src_0_wrt_0_1 vjp @AD__fanout__vjp_src_0_wrt_0_1] @fanout

// CHECK-LABEL: @AD__simple_mul__primal_src_0_wrt_0_1
// CHECK: bb0(%0 : $Builtin.FPIEEE32, %1 : $Builtin.FPIEEE32):
// CHECK: %2 = builtin "fmul_FPIEEE32"(%0 : $Builtin.FPIEEE32, %1 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
// CHECK: %3 = struct $AD__simple_mul__Type__src_0_wrt_0_1 (%2 : $Builtin.FPIEEE32)
// CHECK: %4 = tuple (%3 : $AD__simple_mul__Type__src_0_wrt_0_1, %2 : $Builtin.FPIEEE32)
// CHECK: return %4 : $(AD__simple_mul__Type__src_0_wrt_0_1, Builtin.FPIEEE32)
// CHECK: }

// CHECK-LABEL: @AD__simple_mul__adjoint_src_0_wrt_0_1
// CHECK: bb0(%0 : $Builtin.FPIEEE32, %1 : $AD__simple_mul__Type__src_0_wrt_0_1, %2 : $Builtin.FPIEEE32, %3 : $Builtin.FPIEEE32, %4 : $Builtin.FPIEEE32):
// CHECK: %5 = builtin "fmul_FPIEEE32"(%0 : $Builtin.FPIEEE32, %4 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
Expand Down Expand Up @@ -134,15 +126,6 @@ bb0(%0 : @trivial $Builtin.FPIEEE32):
// CHECK: return %10 : $(Builtin.FPIEEE32, Builtin.FPIEEE32)
// CHECK: }

// CHECK-LABEL: @AD__chain_rule__primal_src_0_wrt_0_1
// CHECK: bb0(%0 : $Builtin.FPIEEE32, %1 : $Builtin.FPIEEE32):
// CHECK: %2 = builtin "fneg_FPIEEE32"(%0 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
// CHECK: %3 = builtin "fmul_FPIEEE32"(%1 : $Builtin.FPIEEE32, %2 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
// CHECK: %4 = struct $AD__chain_rule__Type__src_0_wrt_0_1 (%2 : $Builtin.FPIEEE32, %3 : $Builtin.FPIEEE32)
// CHECK: %5 = tuple (%4 : $AD__chain_rule__Type__src_0_wrt_0_1, %3 : $Builtin.FPIEEE32)
// CHECK: return %5 : $(AD__chain_rule__Type__src_0_wrt_0_1, Builtin.FPIEEE32)
// CHECK: }

// CHECK-LABEL: @AD__chain_rule__adjoint_src_0_wrt_0_1
// CHECK: bb0(%0 : $Builtin.FPIEEE32, %1 : $AD__chain_rule__Type__src_0_wrt_0_1, %2 : $Builtin.FPIEEE32, %3 : $Builtin.FPIEEE32, %4 : $Builtin.FPIEEE32):
// CHECK: %5 = struct_extract %1 : $AD__chain_rule__Type__src_0_wrt_0_1, #AD__chain_rule__Type__src_0_wrt_0_1.v_0
Expand All @@ -153,17 +136,6 @@ bb0(%0 : @trivial $Builtin.FPIEEE32):
// CHECK: return %9 : $(Builtin.FPIEEE32, Builtin.FPIEEE32)
// CHECK: }

// CHECK-LABEL: @AD__add_literals__primal_src_0_wrt_0_1
// CHECK: bb0(%0 : $Builtin.FPIEEE32, %1 : $Builtin.FPIEEE32):
// CHECK: %2 = builtin "fneg_FPIEEE32"(%0 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
// CHECK: %3 = builtin "fmul_FPIEEE32"(%1 : $Builtin.FPIEEE32, %2 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
// CHECK: %4 = float_literal $Builtin.FPIEEE32, 0x64
// CHECK: %5 = builtin "fsub_FPIEEE32"(%4 : $Builtin.FPIEEE32, %3 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
// CHECK: %6 = struct $AD__add_literals__Type__src_0_wrt_0_1 (%2 : $Builtin.FPIEEE32, %3 : $Builtin.FPIEEE32, %5 : $Builtin.FPIEEE32)
// CHECK: %7 = tuple (%6 : $AD__add_literals__Type__src_0_wrt_0_1, %5 : $Builtin.FPIEEE32)
// CHECK: return %7 : $(AD__add_literals__Type__src_0_wrt_0_1, Builtin.FPIEEE32)
// CHECK: }

// CHECK-LABEL: @AD__add_literals__adjoint_src_0_wrt_0_1
// CHECK: bb0(%0 : $Builtin.FPIEEE32, %1 : $AD__add_literals__Type__src_0_wrt_0_1, %2 : $Builtin.FPIEEE32, %3 : $Builtin.FPIEEE32, %4 : $Builtin.FPIEEE32):
// CHECK: %5 = builtin "fneg_FPIEEE32"(%0 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
Expand All @@ -175,15 +147,6 @@ bb0(%0 : @trivial $Builtin.FPIEEE32):
// CHECK: return %10 : $(Builtin.FPIEEE32, Builtin.FPIEEE32)
// CHECK: }

// CHECK-LABEL: @AD__fanout__primal_src_0_wrt_0_1
// CHECK: bb0(%0 : $Builtin.FPIEEE32, %1 : $Builtin.FPIEEE32):
// CHECK: %2 = builtin "fmul_FPIEEE32"(%0 : $Builtin.FPIEEE32, %0 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
// CHECK: %3 = builtin "fmul_FPIEEE32"(%0 : $Builtin.FPIEEE32, %2 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32
// CHECK: %4 = struct $AD__fanout__Type__src_0_wrt_0_1 (%2 : $Builtin.FPIEEE32, %3 : $Builtin.FPIEEE32)
// CHECK: %5 = tuple (%4 : $AD__fanout__Type__src_0_wrt_0_1, %3 : $Builtin.FPIEEE32)
// CHECK: return %5 : $(AD__fanout__Type__src_0_wrt_0_1, Builtin.FPIEEE32)
// CHECK: }

// CHECK-LABEL: @AD__fanout__adjoint_src_0_wrt_0_1
// CHECK: bb0(%0 : $Builtin.FPIEEE32, %1 : $AD__fanout__Type__src_0_wrt_0_1, %2 : $Builtin.FPIEEE32, %3 : $Builtin.FPIEEE32, %4 : $Builtin.FPIEEE32):
// CHECK: %5 = struct_extract %1 : $AD__fanout__Type__src_0_wrt_0_1, #AD__fanout__Type__src_0_wrt_0_1.v_0
Expand Down
4 changes: 1 addition & 3 deletions test/AutoDiff/e2e_differentiable_property.swift
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,7 @@ E2EDifferentiablePropertyTests.test("computed property") {
}

// FIXME: The AD pass cannot differentiate this because it sees
// `struct_extract`s instead of calls to getters. This problem should fix
// itself once we move the AD pass before mandatory inlining, and we should be
// able to enable this test.
// `struct_extract`s instead of calls to getters.
// E2EDifferentiablePropertyTests.test("stored property") {
// let actualGrad = gradient(at: Space(x: 0, y: 0)) { (point: Space) -> Float in
// return 3 * point.y
Expand Down
Loading