Skip to content

Commit 22ccb3f

Browse files
committed
Introduce experimental features for autodiff options
1 parent ad65a64 commit 22ccb3f

File tree

8 files changed

+26
-16
lines changed

8 files changed

+26
-16
lines changed

include/swift/Basic/Features.def

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,14 @@ EXPERIMENTAL_FEATURE(TypeWitnessSystemInference)
108108
/// \endcode
109109
EXPERIMENTAL_FEATURE(BoundGenericExtensions)
110110

111+
/// Whether to enable experimental differentiable programming features:
112+
/// `@differentiable` declaration attribute, etc.
113+
EXPERIMENTAL_FEATURE(DifferentiableProgramming)
114+
115+
/// Whether to enable forward mode differentiation.
116+
EXPERIMENTAL_FEATURE(ForwardModeDifferentiation)
117+
118+
111119
#undef EXPERIMENTAL_FEATURE
112120
#undef FUTURE_FEATURE
113121
#undef SUPPRESSIBLE_LANGUAGE_FEATURE

include/swift/Basic/LangOptions.h

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -426,13 +426,6 @@ namespace swift {
426426
/// file.
427427
bool EmitFineGrainedDependencySourcefileDotFiles = false;
428428

429-
/// Whether to enable experimental differentiable programming features:
430-
/// `@differentiable` declaration attribute, etc.
431-
bool EnableExperimentalDifferentiableProgramming = false;
432-
433-
/// Whether to enable forward mode differentiation.
434-
bool EnableExperimentalForwardModeDifferentiation = false;
435-
436429
/// Whether to enable experimental `AdditiveArithmetic` derived
437430
/// conformances.
438431
bool EnableExperimentalAdditiveArithmeticDerivedConformances = false;

lib/AST/ASTPrinter.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3039,6 +3039,14 @@ static bool usesFeatureBoundGenericExtensions(Decl *decl) {
30393039
return false;
30403040
}
30413041

3042+
static bool usesFeatureDifferentiableProgramming(Decl *decl) {
3043+
return false;
3044+
}
3045+
3046+
static bool usesFeatureForwardModeDifferentiation(Decl *decl) {
3047+
return false;
3048+
}
3049+
30423050
static void
30433051
suppressingFeatureNoAsyncAvailability(PrintOptions &options,
30443052
llvm::function_ref<void()> action) {

lib/AST/AutoDiff.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ void AutoDiffConfig::print(llvm::raw_ostream &s) const {
112112
bool swift::isDifferentiableProgrammingEnabled(SourceFile &SF) {
113113
auto &ctx = SF.getASTContext();
114114
// Return true if differentiable programming is explicitly enabled.
115-
if (ctx.LangOpts.EnableExperimentalDifferentiableProgramming)
115+
if (ctx.LangOpts.hasFeature(Feature::DifferentiableProgramming))
116116
return true;
117117
// Otherwise, return true iff the `_Differentiation` module is imported in
118118
// the given source file.

lib/Frontend/CompilerInvocation.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -547,9 +547,6 @@ static bool ParseLangArgs(LangOptions &Opts, ArgList &Args,
547547
if (Args.hasArg(OPT_enable_experimental_additive_arithmetic_derivation))
548548
Opts.EnableExperimentalAdditiveArithmeticDerivedConformances = true;
549549

550-
Opts.EnableExperimentalForwardModeDifferentiation |=
551-
Args.hasArg(OPT_enable_experimental_forward_mode_differentiation);
552-
553550
Opts.DebuggerSupport |= Args.hasArg(OPT_debugger_support);
554551
if (Opts.DebuggerSupport)
555552
Opts.EnableDollarIdentifiers = true;
@@ -665,6 +662,8 @@ static bool ParseLangArgs(LangOptions &Opts, ArgList &Args,
665662
Opts.Features.insert(Feature::TypeWitnessSystemInference);
666663
if (Args.hasArg(OPT_enable_experimental_bound_generic_extensions))
667664
Opts.Features.insert(Feature::BoundGenericExtensions);
665+
if (Args.hasArg(OPT_enable_experimental_forward_mode_differentiation))
666+
Opts.Features.insert(Feature::ForwardModeDifferentiation);
668667

669668
Opts.EnableAppExtensionRestrictions |= Args.hasArg(OPT_enable_app_extension);
670669

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -898,7 +898,7 @@ bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness(
898898
// - Functions with no return.
899899
// - Functions with unsupported control flow.
900900
if (context.getASTContext()
901-
.LangOpts.EnableExperimentalForwardModeDifferentiation &&
901+
.LangOpts.hasFeature(Feature::ForwardModeDifferentiation) &&
902902
(diagnoseNoReturn(context, witness->getOriginalFunction(), invoker) ||
903903
diagnoseUnsupportedControlFlow(
904904
context, witness->getOriginalFunction(), invoker)))
@@ -914,7 +914,7 @@ bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness(
914914
// generation because generated JVP may not match semantics of custom VJP.
915915
// Instead, create an empty JVP.
916916
if (context.getASTContext()
917-
.LangOpts.EnableExperimentalForwardModeDifferentiation &&
917+
.LangOpts.hasFeature(Feature::ForwardModeDifferentiation) &&
918918
!witness->getVJP()) {
919919
// JVP and differential generation do not currently support functions with
920920
// multiple basic blocks.

lib/TBDGen/TBDGen.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -559,7 +559,7 @@ void TBDGenVisitor::addAutoDiffLinearMapFunction(AbstractFunctionDecl *original,
559559

560560
// Differential functions are emitted only when forward-mode is enabled.
561561
if (kind == AutoDiffLinearMapKind::Differential &&
562-
!ctx.LangOpts.EnableExperimentalForwardModeDifferentiation)
562+
!ctx.LangOpts.hasFeature(Feature::ForwardModeDifferentiation))
563563
return;
564564
auto *loweredParamIndices = autodiff::getLoweredParameterIndices(
565565
config.parameterIndices,

tools/sil-opt/SILOpt.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -576,8 +576,10 @@ int main(int argc, char **argv) {
576576
if (EnableExperimentalStaticAssert)
577577
Invocation.getLangOptions().Features.insert(Feature::StaticAssert);
578578

579-
Invocation.getLangOptions().EnableExperimentalDifferentiableProgramming =
580-
EnableExperimentalDifferentiableProgramming;
579+
if (EnableExperimentalDifferentiableProgramming) {
580+
Invocation.getLangOptions().Features.insert(
581+
Feature::DifferentiableProgramming);
582+
}
581583

582584
Invocation.getLangOptions().EnableCXXInterop = EnableCxxInterop;
583585

0 commit comments

Comments
 (0)