Skip to content

Commit 2eb460d

Browse files
authored
[AutoDiff upstream] Add forward-mode differentiation. (#30878)
JVP functions are forward-mode derivative functions. They take original arguments and return original results and a differential function. Differential functions take derivatives wrt arguments and return derivatives wrt results. `JVPEmitter` is a cloner that emits JVP and differential functions at the same time. In JVP functions, function applications are replaced with JVP function applications. In differential functions, function applications are replaced with differential function applications. In JVP functions, each basic block takes a differential struct containing callee differentials. These structs are consumed by differential functions.
1 parent b5570a1 commit 2eb460d

File tree

9 files changed

+1957
-31
lines changed

9 files changed

+1957
-31
lines changed

include/swift/AST/DiagnosticsSIL.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,8 @@ NOTE(autodiff_cannot_param_subset_thunk_partially_applied_orig_fn,none,
512512
"function; use an explicit closure instead", ())
513513
NOTE(autodiff_cannot_differentiate_through_multiple_results,none,
514514
"cannot differentiate through multiple results", ())
515+
NOTE(autodiff_cannot_differentiate_through_inout_arguments,none,
516+
"cannot differentiate through 'inout' arguments", ())
515517
// TODO(TF-1149): Remove this diagnostic.
516518
NOTE(autodiff_loadable_value_addressonly_tangent_unsupported,none,
517519
"cannot yet differentiate value whose type %0 has a compile-time known "

include/swift/SILOptimizer/Utils/Differentiation/JVPEmitter.h

Lines changed: 410 additions & 0 deletions
Large diffs are not rendered by default.

lib/Frontend/CompilerInvocation.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,9 @@ static bool ParseLangArgs(LangOptions &Opts, ArgList &Args,
449449
if (Args.hasArg(OPT_enable_experimental_additive_arithmetic_derivation))
450450
Opts.EnableExperimentalAdditiveArithmeticDerivedConformances = true;
451451

452+
Opts.EnableExperimentalForwardModeDifferentiation |=
453+
Args.hasArg(OPT_enable_experimental_forward_mode_differentiation);
454+
452455
Opts.DebuggerSupport |= Args.hasArg(OPT_debugger_support);
453456
if (Opts.DebuggerSupport)
454457
Opts.EnableDollarIdentifiers = true;

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
#include "swift/SILOptimizer/PassManager/Passes.h"
4040
#include "swift/SILOptimizer/PassManager/Transforms.h"
4141
#include "swift/SILOptimizer/Utils/Differentiation/ADContext.h"
42+
#include "swift/SILOptimizer/Utils/Differentiation/JVPEmitter.h"
4243
#include "swift/SILOptimizer/Utils/Differentiation/Thunk.h"
4344
#include "swift/SILOptimizer/Utils/Differentiation/VJPEmitter.h"
4445
#include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h"
@@ -898,9 +899,10 @@ bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness(
898899
diagnoseUnsupportedControlFlow(context, original, invoker)))
899900
return true;
900901

901-
witness->setJVP(
902-
createEmptyJVP(context, original, witness, serializeFunctions));
903-
context.recordGeneratedFunction(witness->getJVP());
902+
// Create empty JVP.
903+
auto *jvp = createEmptyJVP(context, original, witness, serializeFunctions);
904+
witness->setJVP(jvp);
905+
context.recordGeneratedFunction(jvp);
904906

905907
// For now, only do JVP generation if the flag is enabled and if custom VJP
906908
// does not exist. If custom VJP exists but custom JVP does not, skip JVP
@@ -917,18 +919,18 @@ bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness(
917919
diag::autodiff_jvp_control_flow_not_supported);
918920
return true;
919921
}
920-
// TODO(TF-1211): Upstream and use `JVPEmitter`. Fatal error with a nice
921-
// message for now.
922-
auto *jvp = witness->getJVP();
923-
emitFatalError(context, jvp, "_fatalErrorJVPNotGenerated");
922+
// Emit JVP function.
923+
JVPEmitter emitter(context, original, witness, jvp, invoker);
924+
if (emitter.run())
925+
return true;
924926
} else {
925927
// If JVP generation is disabled or a user-defined custom VJP function
926928
// exists, fatal error with a nice message.
927-
emitFatalError(context, witness->getJVP(),
929+
emitFatalError(context, jvp,
928930
"_fatalErrorForwardModeDifferentiationDisabled");
929931
LLVM_DEBUG(getADDebugStream()
930932
<< "Generated empty JVP for " << original->getName() << ":\n"
931-
<< *witness->getJVP());
933+
<< *jvp);
932934
}
933935
}
934936

@@ -945,6 +947,7 @@ bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness(
945947
auto *vjp = createEmptyVJP(context, original, witness, serializeFunctions);
946948
witness->setVJP(vjp);
947949
context.recordGeneratedFunction(vjp);
950+
// Emit VJP function.
948951
VJPEmitter emitter(context, original, witness, vjp, invoker);
949952
return emitter.run();
950953
}

lib/SILOptimizer/Utils/Differentiation/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ silopt_register_sources(
22
ADContext.cpp
33
Common.cpp
44
DifferentiationInvoker.cpp
5+
JVPEmitter.cpp
56
LinearMapInfo.cpp
67
PullbackEmitter.cpp
78
Thunk.cpp

0 commit comments

Comments
 (0)