Skip to content

[AutoDiff upstream] Add differential operators and some utilities. #30711

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 41 additions & 9 deletions include/swift/Demangling/TypeDecoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,33 +181,53 @@ enum class ImplFunctionRepresentation {
Closure
};

enum class ImplFunctionDifferentiabilityKind {
NonDifferentiable,
Normal,
Linear
};

class ImplFunctionTypeFlags {
unsigned Rep : 3;
unsigned Pseudogeneric : 1;
unsigned Escaping : 1;
unsigned DifferentiabilityKind : 2;

public:
ImplFunctionTypeFlags() : Rep(0), Pseudogeneric(0), Escaping(0) {}
ImplFunctionTypeFlags()
: Rep(0), Pseudogeneric(0), Escaping(0), DifferentiabilityKind(0) {}

ImplFunctionTypeFlags(ImplFunctionRepresentation rep,
bool pseudogeneric, bool noescape)
: Rep(unsigned(rep)), Pseudogeneric(pseudogeneric), Escaping(noescape) {}
ImplFunctionTypeFlags(ImplFunctionRepresentation rep, bool pseudogeneric,
bool noescape,
ImplFunctionDifferentiabilityKind diffKind)
: Rep(unsigned(rep)), Pseudogeneric(pseudogeneric), Escaping(noescape),
DifferentiabilityKind(unsigned(diffKind)) {}

ImplFunctionTypeFlags
withRepresentation(ImplFunctionRepresentation rep) const {
return ImplFunctionTypeFlags(rep, Pseudogeneric, Escaping);
return ImplFunctionTypeFlags(
rep, Pseudogeneric, Escaping,
ImplFunctionDifferentiabilityKind(DifferentiabilityKind));
}

ImplFunctionTypeFlags
withEscaping() const {
return ImplFunctionTypeFlags(ImplFunctionRepresentation(Rep),
Pseudogeneric, true);
return ImplFunctionTypeFlags(
ImplFunctionRepresentation(Rep), Pseudogeneric, true,
ImplFunctionDifferentiabilityKind(DifferentiabilityKind));
}

ImplFunctionTypeFlags
withPseudogeneric() const {
return ImplFunctionTypeFlags(ImplFunctionRepresentation(Rep),
true, Escaping);
return ImplFunctionTypeFlags(
ImplFunctionRepresentation(Rep), true, Escaping,
ImplFunctionDifferentiabilityKind(DifferentiabilityKind));
}

ImplFunctionTypeFlags
withDifferentiabilityKind(ImplFunctionDifferentiabilityKind diffKind) const {
return ImplFunctionTypeFlags(ImplFunctionRepresentation(Rep), Pseudogeneric,
Escaping, diffKind);
}

ImplFunctionRepresentation getRepresentation() const {
Expand All @@ -217,6 +237,10 @@ class ImplFunctionTypeFlags {
bool isEscaping() const { return Escaping; }

bool isPseudogeneric() const { return Pseudogeneric; }

ImplFunctionDifferentiabilityKind getDifferentiabilityKind() const {
return ImplFunctionDifferentiabilityKind(DifferentiabilityKind);
}
};

#if SWIFT_OBJC_INTEROP
Expand Down Expand Up @@ -582,6 +606,14 @@ class TypeDecoder {
flags =
flags.withRepresentation(ImplFunctionRepresentation::Block);
}
} else if (child->getKind() == NodeKind::ImplDifferentiable) {
flags = flags.withDifferentiabilityKind(
ImplFunctionDifferentiabilityKind::Normal);
} else if (child->getKind() == NodeKind::ImplLinear) {
flags = flags.withDifferentiabilityKind(
ImplFunctionDifferentiabilityKind::Linear);
} else if (child->getKind() == NodeKind::ImplEscaping) {
flags = flags.withEscaping();
} else if (child->getKind() == NodeKind::ImplEscaping) {
flags = flags.withEscaping();
} else if (child->getKind() == NodeKind::ImplParameter) {
Expand Down
20 changes: 16 additions & 4 deletions lib/AST/ASTDemangler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -500,11 +500,23 @@ Type ASTBuilder::createImplFunctionType(
break;
}

DifferentiabilityKind diffKind;
switch (flags.getDifferentiabilityKind()) {
case ImplFunctionDifferentiabilityKind::NonDifferentiable:
diffKind = DifferentiabilityKind::NonDifferentiable;
break;
case ImplFunctionDifferentiabilityKind::Normal:
diffKind = DifferentiabilityKind::Normal;
break;
case ImplFunctionDifferentiabilityKind::Linear:
diffKind = DifferentiabilityKind::Linear;
break;
}

// TODO: [store-sil-clang-function-type]
auto einfo = SILFunctionType::ExtInfo(
representation, flags.isPseudogeneric(), !flags.isEscaping(),
DifferentiabilityKind::NonDifferentiable,
/*clangFunctionType*/ nullptr);
auto einfo = SILFunctionType::ExtInfo(representation, flags.isPseudogeneric(),
!flags.isEscaping(), diffKind,
/*clangFunctionType*/ nullptr);

llvm::SmallVector<SILParameterInfo, 8> funcParams;
llvm::SmallVector<SILYieldInfo, 8> funcYields;
Expand Down
26 changes: 9 additions & 17 deletions lib/SILGen/SILGenBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1106,14 +1106,9 @@ static ManagedValue emitBuiltinAutoDiffApplyTransposeFunction(
origFnArgVals.push_back(arg.getValue());

// Get the transpose function.
// TODO(TF-1142): Create a linear_function_extract instead of an undef.
auto fnTy = origFnVal->getType().castTo<SILFunctionType>();
auto transposeFnType =
fnTy->getWithoutDifferentiability()->getAutoDiffTransposeFunctionType(
fnTy->getDifferentiabilityParameterIndices(), SGF.SGM.M.Types,
LookUpConformanceInModule(SGF.SGM.M.getSwiftModule()));
SILValue transposeFn =
SILUndef::get(SILType::getPrimitiveObjectType(transposeFnType), SGF.F);
SILValue transposeFn = SGF.B.createLinearFunctionExtract(
loc, LinearDifferentiableFunctionTypeComponent::Transpose, origFnVal);
auto transposeFnType = transposeFn->getType().castTo<SILFunctionType>();
auto transposeFnUnsubstType =
transposeFnType->getUnsubstitutedType(SGF.getModule());
if (transposeFnType != transposeFnUnsubstType) {
Expand Down Expand Up @@ -1204,19 +1199,16 @@ static ManagedValue emitBuiltinLinearFunction(
assert(args.size() == 2);
auto origFn = args.front();
auto origType = origFn.getType().castTo<SILFunctionType>();
// TODO(TF-1142): Create a linear_function instead of an undef.
auto linearFnTy = origType->getWithDifferentiability(
DifferentiabilityKind::Linear,
auto linearFn = SGF.B.createLinearFunction(
loc,
IndexSubset::getDefault(
SGF.getASTContext(), origType->getNumParameters(),
/*includeAll*/ true));
SILValue linearFn = SILUndef::get(
SILType::getPrimitiveObjectType(linearFnTy), SGF.F);
SGF.getASTContext(),
origType->getNumParameters(),
/*includeAll*/ true),
origFn.forward(SGF), args[1].forward(SGF));
return SGF.emitManagedRValueWithCleanup(linearFn);
}



/// Emit SIL for the named builtin: globalStringTablePointer. Unlike the default
/// ownership convention for named builtins, which is to take (non-trivial)
/// arguments as Owned, this builtin accepts owned as well as guaranteed
Expand Down
4 changes: 4 additions & 0 deletions lib/SILOptimizer/Transforms/SemanticARCOpts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -776,6 +776,10 @@ struct SemanticARCOptVisitor
FORWARDING_INST(OpenExistentialBoxValue)
FORWARDING_INST(MarkDependence)
FORWARDING_INST(InitExistentialRef)
FORWARDING_INST(DifferentiableFunction)
FORWARDING_INST(LinearFunction)
FORWARDING_INST(DifferentiableFunctionExtract)
FORWARDING_INST(LinearFunctionExtract)
#undef FORWARDING_INST

#define FORWARDING_TERM(NAME) \
Expand Down
7 changes: 6 additions & 1 deletion stdlib/public/Differentiation/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@

add_swift_target_library(swift_Differentiation ${SWIFT_STDLIB_LIBRARY_BUILD_TYPES} IS_STDLIB
Differentiable.swift
DifferentialOperators.swift
DifferentiationUtilities.swift

SWIFT_COMPILE_FLAGS ${SWIFT_STANDARD_LIBRARY_SWIFT_FLAGS}
SWIFT_COMPILE_FLAGS
${SWIFT_STANDARD_LIBRARY_SWIFT_FLAGS}
-parse-stdlib
-Xfrontend -enable-experimental-differentiable-programming
LINK_FLAGS "${SWIFT_RUNTIME_SWIFT_LINK_FLAGS}"
INSTALL_IN_COMPONENT stdlib)
2 changes: 2 additions & 0 deletions stdlib/public/Differentiation/Differentiable.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
//
//===----------------------------------------------------------------------===//

import Swift

/// A type that mathematically represents a differentiable manifold whose
/// tangent spaces are finite-dimensional.
public protocol Differentiable {
Expand Down
Loading